diff --git a/.codespellignore b/.codespellignore index c91d0f7707..4b1c229c68 100644 --- a/.codespellignore +++ b/.codespellignore @@ -2,4 +2,5 @@ InOut inout LoadE SelectE -ser \ No newline at end of file +ser +te \ No newline at end of file diff --git a/.config/nextest.toml b/.config/nextest.toml new file mode 100644 index 0000000000..db7e0b6a84 --- /dev/null +++ b/.config/nextest.toml @@ -0,0 +1,13 @@ +# Nextest configuration for OpenVM project + +# Define test groups with different weights +[[profile.default.overrides]] +# Match all tests with "persistent" in their name +filter = 'test(~persistent)' +# Give these tests 5x the default weight because they use more memory +threads-required = 16 + +# custom profile for heavy tests +[profile.heavy] +# Run fewer tests in parallel for heavy workloads +test-threads = 2 diff --git a/.github/workflows/benchmark-call.yml b/.github/workflows/benchmark-call.yml index 737e1c81ed..4b54185970 100644 --- a/.github/workflows/benchmark-call.yml +++ b/.github/workflows/benchmark-call.yml @@ -49,7 +49,7 @@ on: features: type: string required: false - description: Host features, comma separated (aggregation,profiling) + description: Host features, comma separated (aggregation,perf-metrics) workflow_call: inputs: benchmark_name: @@ -102,12 +102,12 @@ on: features: type: string required: false - description: Host features, comma separated (aggregation,profiling) + description: Host features, comma separated (aggregation,perf-metrics) env: S3_METRICS_PATH: s3://openvm-public-data-sandbox-us-east-1/benchmark/github/metrics S3_FLAMEGRAPHS_PATH: s3://openvm-public-data-sandbox-us-east-1/benchmark/github/flamegraphs - FEATURE_FLAGS: "bench-metrics,parallel,nightly-features" + FEATURE_FLAGS: "metrics,parallel,nightly-features" INPUT_ARGS: "" CARGO_NET_GIT_FETCH_WITH_CLI: "true" @@ -230,11 +230,11 @@ jobs: s5cmd cp $METRIC_PATH ${{ env.S3_METRICS_PATH }}/${METRIC_NAME}-${current_sha}.json - name: Install inferno-flamegraph - if: ${{ contains(env.FEATURE_FLAGS, 'profiling') }} + if: ${{ contains(env.FEATURE_FLAGS, 'perf-metrics') }} run: cargo install inferno - name: Generate flamegraphs - if: ${{ contains(env.FEATURE_FLAGS, 'profiling') }} + if: ${{ contains(env.FEATURE_FLAGS, 'perf-metrics') }} run: | if [[ -f $METRIC_PATH ]]; then GUEST_SYMBOLS_PATH="${METRIC_PATH%.json}.syms" @@ -250,9 +250,15 @@ jobs: fi ########################################################################## - # Update s3 for latest main metrics upon a push event # + # Update s3 for latest branch metrics upon a push event # ########################################################################## - - name: Update latest main result in s3 - if: github.event_name == 'push' && github.ref == 'refs/heads/main' + - name: Update latest branch result in s3 + if: github.event_name == 'push' run: | - s5cmd cp $METRIC_PATH "${{ env.S3_METRICS_PATH }}/main-${METRIC_NAME}.json" + if [[ "${{ github.ref }}" == "refs/heads/main" ]]; then + # for backwards compatibility + REF_HASH="main" + else + REF_HASH=$(echo "${{ github.ref }}" | sha256sum | cut -d' ' -f1) + fi + s5cmd cp $METRIC_PATH "${{ env.S3_METRICS_PATH }}/${REF_HASH}-${METRIC_NAME}.json" diff --git a/.github/workflows/benchmarks-execute.yml b/.github/workflows/benchmarks-execute.yml index 741ccdb0f1..17c5277b0f 100644 --- a/.github/workflows/benchmarks-execute.yml +++ b/.github/workflows/benchmarks-execute.yml @@ -1,8 +1,9 @@ -name: "benchmarks-execute" +name: "Execution benchmarks" on: push: - branches: ["main"] + # TODO(ayush): remove after feat/new-execution is merged + branches: ["main", "feat/new-execution"] pull_request: types: [opened, synchronize, reopened, labeled] branches: ["**"] @@ -18,95 +19,101 @@ on: - ".github/workflows/benchmarks-execute.yml" workflow_dispatch: +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} + cancel-in-progress: true + env: CARGO_TERM_COLOR: always + S3_FIXTURES_PATH: s3://openvm-public-data-sandbox-us-east-1/benchmark/fixtures + JEMALLOC_SYS_WITH_MALLOC_CONF: "retain:true,background_thread:true,metadata_thp:always,thp:always,dirty_decay_ms:-1,muzzy_decay_ms:-1,abort_conf:true" jobs: - execute-benchmarks: + codspeed-walltime-benchmarks: + name: Run codspeed walltime benchmarks runs-on: - runs-on=${{ github.run_id }} - - runner=8cpu-linux-x64 + - family=m5a.xlarge # 2.5Ghz clock speed + - image=ubuntu24-full-x64 + - extras=s3-cache + + env: + CODSPEED_RUNNER_MODE: walltime + steps: + - uses: runs-on/action@v1 - uses: actions/checkout@v4 - - - name: Set up Rust - uses: actions-rs/toolchain@v1 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 with: - profile: minimal - toolchain: stable - override: true + cache-on-failure: true - - name: Run execution benchmarks - working-directory: benchmarks/execute - run: cargo run | tee benchmark_output.log + - name: Install architecture specific tools + run: | + source ci/scripts/utils.sh + install_s5cmd - - name: Parse benchmark results + - name: Pull fixtures from S3 + run: | + mkdir -p benchmarks/fixtures + s5cmd cp "${{ env.S3_FIXTURES_PATH }}/*" benchmarks/fixtures/ || echo "No fixtures found in S3" + + - name: Install cargo-binstall + uses: cargo-bins/cargo-binstall@main + - name: Install codspeed + run: cargo binstall --no-confirm --force cargo-codspeed + + - name: Build benchmarks working-directory: benchmarks/execute + run: cargo codspeed build --profile maxperf + - name: Run benchmarks + uses: CodSpeedHQ/action@v3 + with: + working-directory: benchmarks/execute + run: cargo codspeed run + token: ${{ secrets.CODSPEED_TOKEN }} + + codspeed-instrumentation-benchmarks: + name: Run codspeed instrumentation benchmarks + runs-on: + - runs-on=${{ github.run_id }} + - family=m5a.xlarge + - image=ubuntu24-full-x64 + - extras=s3-cache + if: github.event_name != 'pull_request' + + env: + CODSPEED_RUNNER_MODE: instrumentation + + steps: + - uses: runs-on/action@v1 + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + with: + cache-on-failure: true + + - name: Install architecture specific tools run: | - # Determine if running in GitHub Actions environment - if [ -n "$GITHUB_STEP_SUMMARY" ]; then - SUMMARY_FILE="$GITHUB_STEP_SUMMARY" - echo "### Benchmark Results Summary" >> "$SUMMARY_FILE" - else - SUMMARY_FILE="benchmark_summary.md" - echo "### Benchmark Results Summary" > "$SUMMARY_FILE" - echo "Saving summary to $SUMMARY_FILE" - fi - - # Set up summary table header - echo "| Program | Total Time (ms) |" >> "$SUMMARY_FILE" - echo "| ------- | --------------- |" >> "$SUMMARY_FILE" - - # Variables to track current program and total time - current_program="" - total_time=0 - - # Process the output file line by line - while IFS= read -r line; do - # Check if line contains "Running program" message - if [[ $line =~ i\ \[info\]:\ Running\ program:\ ([a-zA-Z0-9_-]+) ]]; then - # If we were processing a program, output its results - if [[ -n "$current_program" ]]; then - echo "| $current_program | $total_time |" >> "$SUMMARY_FILE" - fi - - # Start tracking new program - current_program="${BASH_REMATCH[1]}" - total_time=0 - fi - - # Check for program completion to catch programs that might have no execution segments - if [[ $line =~ i\ \[info\]:\ Completed\ program:\ ([a-zA-Z0-9_-]+) ]]; then - completed_program="${BASH_REMATCH[1]}" - # If no segments were found for this program, ensure it's still in the output - if [[ "$current_program" == "$completed_program" && $total_time == 0 ]]; then - echo "| $current_program | 0 |" >> "$SUMMARY_FILE" - current_program="" - fi - fi - - # Check if line contains execution time (looking for the format with ms or s) - if [[ $line =~ execute_segment\ \[\ ([0-9.]+)(ms|s)\ \|\ [0-9.]+%\ \]\ segment ]]; then - segment_time="${BASH_REMATCH[1]}" - unit="${BASH_REMATCH[2]}" - - # Convert to milliseconds if in seconds - if [[ "$unit" == "s" ]]; then - segment_time=$(echo "scale=6; $segment_time * 1000" | bc) - fi - - # Add segment time to total - total_time=$(echo "scale=6; $total_time + $segment_time" | bc) - fi - done < benchmark_output.log - - # Output the last program result if there was one - if [[ -n "$current_program" ]]; then - echo "| $current_program | $total_time |" >> "$SUMMARY_FILE" - fi - - # If not in GitHub Actions, print the summary to the terminal - if [ -z "$GITHUB_STEP_SUMMARY" ]; then - echo -e "\nBenchmark Summary:" - cat "$SUMMARY_FILE" - fi + source ci/scripts/utils.sh + install_s5cmd + + - name: Pull fixtures from S3 + run: | + mkdir -p benchmarks/fixtures + s5cmd cp "${{ env.S3_FIXTURES_PATH }}/*" benchmarks/fixtures/ || echo "No fixtures found in S3" + + - name: Install cargo-binstall + uses: cargo-bins/cargo-binstall@main + - name: Install codspeed + run: cargo binstall --no-confirm --force cargo-codspeed + + - name: Build benchmarks + working-directory: benchmarks/execute + run: cargo codspeed build + - name: Run benchmarks + uses: CodSpeedHQ/action@v3 + with: + working-directory: benchmarks/execute + run: cargo codspeed run + token: ${{ secrets.CODSPEED_TOKEN }} diff --git a/.github/workflows/benchmarks-upload-fixtures.yml b/.github/workflows/benchmarks-upload-fixtures.yml new file mode 100644 index 0000000000..d18538e419 --- /dev/null +++ b/.github/workflows/benchmarks-upload-fixtures.yml @@ -0,0 +1,42 @@ +name: "Upload benchmark fixtures" + +on: + workflow_dispatch: + +env: + CARGO_TERM_COLOR: always + S3_FIXTURES_PATH: s3://openvm-public-data-sandbox-us-east-1/benchmark/fixtures + +jobs: + generate-fixtures: + name: Generate and upload benchmark fixtures + runs-on: + - runs-on=${{ github.run_id }} + - runner=64cpu-linux-arm64 + - family=m7 + - extras=s3-cache + + steps: + - uses: runs-on/action@v1 + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + with: + cache-on-failure: true + + - name: Install architecture specific tools + run: | + source ci/scripts/utils.sh + install_s5cmd + + - name: Generate fixtures + run: cargo r -r --bin generate-fixtures --features generate-fixtures + + - name: Upload fixtures to S3 + run: | + if [ -d "benchmarks/fixtures" ]; then + s5cmd cp benchmarks/fixtures/ ${{ env.S3_FIXTURES_PATH }}/ + else + echo "No fixtures directory found" + exit 1 + fi diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index 3c2b02c574..4b1fbcc502 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -2,7 +2,7 @@ name: "OpenVM Benchmarks: Coordinate Runner & Reporting" on: push: - branches: ["main"] + branches: ["main", "feat/new-execution"] pull_request: types: [opened, synchronize, reopened, labeled] branches: ["**"] @@ -89,7 +89,7 @@ jobs: FEATURE_FLAGS="aggregation,${FEATURE_FLAGS}" fi if [[ "${{ github.event.inputs.flamegraphs }}" == "true" ]]; then - FEATURE_FLAGS="profiling,${FEATURE_FLAGS}" + FEATURE_FLAGS="perf-metrics,${FEATURE_FLAGS}" fi matrix=$(jq -c --argjson run_e2e $RUN_E2E --arg features "$FEATURE_FLAGS" ' @@ -211,9 +211,21 @@ jobs: json_file_list=$(echo -n "$json_files" | paste -sd "," -) echo $json_file_list - prev_json_files=$(echo $matrix | jq -r ' + # For PRs, get the latest commit from the target branch + if [[ "${{ github.event_name }}" == "pull_request" ]]; then + if [[ "${{ github.base_ref }}" == "main" ]]; then + REF_HASH="main" + else + REF_HASH=$(echo "refs/heads/${{ github.base_ref }}" | sha256sum | cut -d' ' -f1) + fi + echo "Target branch REF_HASH: $REF_HASH" + else + REF_HASH="main" + fi + + prev_json_files=$(echo $matrix | jq -r --arg target "$REF_HASH" ' .[] | - "main-\(.id).json"') + "\($target)-\(.id).json"') prev_json_file_list=$(echo -n "$prev_json_files" | paste -sd "," -) echo $prev_json_file_list diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 574a49be15..7f3df63f6f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -16,6 +16,7 @@ jobs: runs-on: - runs-on=${{ github.run_id }} - runner=64cpu-linux-arm64 + - image=ubuntu24-full-arm64 - extras=s3-cache steps: - uses: runs-on/action@v1 diff --git a/.github/workflows/cli.yml b/.github/workflows/cli.yml index 510a124092..d0816f6731 100644 --- a/.github/workflows/cli.yml +++ b/.github/workflows/cli.yml @@ -36,7 +36,8 @@ jobs: runs-on: - runs-on=${{ github.run_id }} - disk=large - - runner=32cpu-linux-arm64 + - runner=64cpu-linux-arm64 + - image=ubuntu24-full-arm64 - extras=s3-cache steps: @@ -47,7 +48,8 @@ jobs: cache-on-failure: true - uses: taiki-e/install-action@nextest - name: Install solc # svm should support arm64 linux - run: (hash svm 2>/dev/null || cargo install --version 0.2.23 svm-rs) && svm install 0.8.19 && solc --version + run: | + (hash svm 2>/dev/null || cargo install --version 0.2.23 svm-rs) && svm install 0.8.19 && solc --version - name: Install cargo-openvm working-directory: crates/cli @@ -80,8 +82,7 @@ jobs: working-directory: crates/cli run: | export RUST_BACKTRACE=1 - cargo build - cargo run --bin cargo-openvm -- openvm keygen --config ./example/app_config.toml --output-dir . + cargo openvm keygen --config ./example/app_config.toml --output-dir . - name: Set USE_LOCAL_OPENVM environment variable run: | @@ -94,4 +95,5 @@ jobs: - name: Run CLI tests working-directory: crates/cli run: | - cargo nextest run --cargo-profile=fast + export SKIP_INSTALL=1 + cargo nextest run --cargo-profile=fast --test-threads=1 diff --git a/.github/workflows/extension-tests.yml b/.github/workflows/extension-tests.yml index ef13b840c6..2d07bdf1f6 100644 --- a/.github/workflows/extension-tests.yml +++ b/.github/workflows/extension-tests.yml @@ -40,6 +40,7 @@ jobs: runs-on: - runs-on=${{ github.run_id }} - runner=64cpu-linux-arm64 + - image=ubuntu24-full-arm64 - tag=extension-${{ matrix.extension.name }} - extras=s3-cache @@ -69,7 +70,7 @@ jobs: - name: Run ${{ matrix.extension.name }} circuit crate tests working-directory: extensions/${{ matrix.extension.path }}/circuit - run: cargo nextest run --cargo-profile=fast + run: cargo nextest run --cargo-profile=fast --test-threads=32 - name: Run ${{ matrix.extension.name }} guest crate tests if: hashFiles(format('extensions/{0}/guest', matrix.extension.path)) != '' @@ -86,4 +87,4 @@ jobs: working-directory: extensions/${{ matrix.extension.path }}/tests run: | rustup component add rust-src --toolchain nightly-2025-02-14 - cargo nextest run --cargo-profile=fast --no-tests=pass + cargo nextest run --cargo-profile=fast --profile=heavy --no-tests=pass diff --git a/.github/workflows/guest-lib-tests.yml b/.github/workflows/guest-lib-tests.yml index 1b87b600e2..98a3743a36 100644 --- a/.github/workflows/guest-lib-tests.yml +++ b/.github/workflows/guest-lib-tests.yml @@ -13,6 +13,7 @@ on: - "guest-libs/**" - "Cargo.toml" - ".github/workflows/guest-lib-tests.yml" + - "crates/sdk/guest/fib/**" concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} @@ -41,6 +42,7 @@ jobs: runs-on: - runs-on=${{ github.run_id }} - runner=64cpu-linux-arm64 + - image=ubuntu24-full-arm64 - tag=crate-${{ matrix.crate.name }} - extras=s3-cache diff --git a/.github/workflows/lints.yml b/.github/workflows/lints.yml index e41580948e..559579713b 100644 --- a/.github/workflows/lints.yml +++ b/.github/workflows/lints.yml @@ -46,7 +46,7 @@ jobs: # list of all unique features across workspace generated using: # cargo metadata --format-version=1 --no-deps | jq -r '.packages[].features | to_entries[] | .key' | sort -u | tr '\n' ' ' && echo "" # (exclude mimalloc since it conflicts with jemalloc) - cargo clippy --all-targets --all --tests --features "aggregation bench-metrics bls12_381 bn254 build-binaries default entrypoint evm-prove evm-verify export-intrinsics export-libm function-span getrandom-unsupported halo2-compiler halo2curves heap-embedded-alloc jemalloc jemalloc-prof nightly-features panic-handler parallel profiling rust-runtime static-verifier std test-utils" -- -D warnings + cargo clippy --all-targets --all --tests --features "aggregation bls12_381 bn254 build-elfs default entrypoint evm-prove evm-verify export-intrinsics export-libm function-span getrandom-unsupported halo2-compiler halo2curves heap-embedded-alloc jemalloc jemalloc-prof metrics nightly-features panic-handler parallel perf-metrics rust-runtime static-verifier std test-utils" -- -D warnings cargo clippy --all-targets --all --tests --no-default-features --features "mimalloc" -- -D warnings - name: Run fmt, clippy for guest diff --git a/.github/workflows/native-compiler.yml b/.github/workflows/native-compiler.yml index af4f39ddff..b79a3cb1c9 100644 --- a/.github/workflows/native-compiler.yml +++ b/.github/workflows/native-compiler.yml @@ -25,6 +25,7 @@ jobs: runs-on: - runs-on=${{ github.run_id }} - runner=64cpu-linux-arm64 + - image=ubuntu24-full-arm64 - extras=s3-cache steps: diff --git a/.github/workflows/primitives.yml b/.github/workflows/primitives.yml index 714230b8cd..2d86155ab2 100644 --- a/.github/workflows/primitives.yml +++ b/.github/workflows/primitives.yml @@ -26,6 +26,7 @@ jobs: runs-on: - runs-on=${{ github.run_id }} - runner=32cpu-linux-arm64 + - image=ubuntu24-full-arm64 - extras=s3-cache steps: diff --git a/.github/workflows/recursion.yml b/.github/workflows/recursion.yml index 814c1fa44a..64538c18c1 100644 --- a/.github/workflows/recursion.yml +++ b/.github/workflows/recursion.yml @@ -26,6 +26,7 @@ jobs: runs-on: - runs-on=${{ github.run_id }} - runner=64cpu-linux-arm64 + - image=ubuntu24-full-arm64 - extras=s3-cache steps: diff --git a/.github/workflows/sdk.yml b/.github/workflows/sdk.yml index e24df21ffe..4d194a03df 100644 --- a/.github/workflows/sdk.yml +++ b/.github/workflows/sdk.yml @@ -26,6 +26,7 @@ jobs: runs-on: - runs-on=${{ github.run_id }} - family=m7a.24xlarge + - image=ubuntu24-full-x64 - disk=large - extras=s3-cache @@ -97,4 +98,10 @@ jobs: working-directory: crates/sdk run: | export RUST_BACKTRACE=1 - cargo nextest run --cargo-profile=fast --test-threads=2 --features parallel,evm-verify + cargo nextest run --cargo-profile=fast --features parallel,evm-verify + + - name: Run ignored tests + working-directory: crates/sdk + if: ${{ github.event_name == 'push' }} + run: | + cargo nextest run --cargo-profile=fast --features parallel,evm-verify --ignored test_static_verifier_custom_pv_handler diff --git a/.github/workflows/vm.yml b/.github/workflows/vm.yml index cb7f2284ca..c8c03dc931 100644 --- a/.github/workflows/vm.yml +++ b/.github/workflows/vm.yml @@ -25,6 +25,7 @@ jobs: runs-on: - runs-on=${{ github.run_id }} - runner=64cpu-linux-arm64 + - image=ubuntu24-full-arm64 - extras=s3-cache steps: @@ -40,3 +41,8 @@ jobs: working-directory: crates/vm run: | cargo nextest run --cargo-profile=fast --features parallel + + - name: Run vm crate tests with basic memory + working-directory: crates/vm + run: | + cargo nextest run --cargo-profile=fast --features parallel,basic-memory diff --git a/.gitignore b/.gitignore index d794a5dc57..aaf6aff435 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,9 @@ guest.syms # openvm generated files crates/cli/openvm/ + +# samply profile +profile.json.gz + +# test fixtures +benchmarks/fixtures diff --git a/CHANGELOG.md b/CHANGELOG.md index 928c548adf..ee6267bf99 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,20 @@ All notable changes to OpenVM will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project follows a versioning principles documented in [VERSIONING.md](./VERSIONING.md). +## [Unreleased] + +### Added +- (Config) Added `addr_spaces` vector of `AddressSpaceHostConfig` to `MemoryConfig`. + +### Changed +- (Toolchain) Removed `step` from `Program` struct because `DEFAULT_PC_STEP = 4` is always used. +- (Config) The `clk_max_bits` field in `MemoryConfig` has been renamed to `timestamp_max_bits`. +- (Prover) Guest memory is stored on host with address space-specified memory layouts. In particular address space `1` through `3` are now represented in bytes instead of field elements. +- (ISA) Field arithmetic instructions now restrict address spaces `e, f` to be either `0` or `4`, instead of allowing any address space. +- (ISA) RV32IM load instructions are now restricted to address space `2` only, instead of allowing address spaces `0`, `1`, or `2`. +- (ISA) The maximum valid pointer value in address space `1` (register address space) is now `127`, corresponding to 32 registers with 4 byte limbs each. +- (ISA) Memory accesses now have configurable minimum block size requirements per address space. Address spaces `1`, `2`, and `3` require minimum block size of 4. Native address space (`4`) allows minimum block size of 1. Address spaces beyond `4` default to minimum block size of 1 but are configurable. + ## v1.3.0 (2025-07-15) No circuit constraints or verifying keys were changed in this release. diff --git a/Cargo.lock b/Cargo.lock index ce7abadf50..7913133053 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "Inflector" @@ -34,9 +34,9 @@ dependencies = [ [[package]] name = "adler2" -version = "2.0.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" [[package]] name = "aes" @@ -51,9 +51,9 @@ dependencies = [ [[package]] name = "ahash" -version = "0.8.11" +version = "0.8.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" dependencies = [ "cfg-if", "once_cell", @@ -76,28 +76,61 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" +[[package]] +name = "alloy-eip2124" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "741bdd7499908b3aa0b159bba11e71c8cddd009a2c2eb7a06e825f1ec87900a5" +dependencies = [ + "alloy-primitives 1.2.1", + "alloy-rlp", + "crc", + "serde", + "thiserror 2.0.12", +] + [[package]] name = "alloy-eip2930" -version = "0.1.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0069cf0642457f87a01a014f6dc29d5d893cd4fd8fddf0c3cdfad1bb3ebafc41" +checksum = "7b82752a889170df67bbb36d42ca63c531eb16274f0d7299ae2a680facba17bd" dependencies = [ - "alloy-primitives 0.8.25", + "alloy-primitives 1.2.1", "alloy-rlp", "serde", ] [[package]] name = "alloy-eip7702" -version = "0.4.2" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c986539255fb839d1533c128e190e557e52ff652c9ef62939e233a81dd93f7e" +checksum = "9d4769c6ffddca380b0070d71c8b7f30bed375543fe76bb2f74ec0acf4b7cd16" dependencies = [ - "alloy-primitives 0.8.25", + "alloy-primitives 1.2.1", "alloy-rlp", - "derive_more 1.0.0", "k256 0.13.4 (registry+https://github.com/rust-lang/crates.io-index)", "serde", + "thiserror 2.0.12", +] + +[[package]] +name = "alloy-eips" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f35887da30b5fc50267109a3c61cd63e6ca1f45967983641053a40ee83468c1" +dependencies = [ + "alloy-eip2124", + "alloy-eip2930", + "alloy-eip7702", + "alloy-primitives 1.2.1", + "alloy-rlp", + "alloy-serde", + "auto_impl", + "c-kzg", + "derive_more 2.0.1", + "either", + "serde", + "sha2 0.10.9", ] [[package]] @@ -121,10 +154,10 @@ dependencies = [ "bytes", "cfg-if", "const-hex", - "derive_more 0.99.19", + "derive_more 0.99.20", "hex-literal 0.4.1", "itoa", - "ruint 1.12.3", + "ruint 1.15.0", "tiny-keccak", ] @@ -138,10 +171,10 @@ dependencies = [ "bytes", "cfg-if", "const-hex", - "derive_more 0.99.19", + "derive_more 0.99.20", "hex-literal 0.4.1", "itoa", - "ruint 1.12.3", + "ruint 1.15.0", "tiny-keccak", ] @@ -157,15 +190,42 @@ dependencies = [ "const-hex", "derive_more 2.0.1", "foldhash", - "hashbrown 0.15.2", - "indexmap 2.7.1", + "hashbrown 0.15.4", + "indexmap 2.10.0", "itoa", "k256 0.13.4 (registry+https://github.com/rust-lang/crates.io-index)", "keccak-asm", "paste", "proptest", "rand 0.8.5", - "ruint 1.12.3", + "ruint 1.15.0", + "rustc-hash 2.1.1", + "serde", + "sha3", + "tiny-keccak", +] + +[[package]] +name = "alloy-primitives" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6177ed26655d4e84e00b65cb494d4e0b8830e7cae7ef5d63087d445a2600fb55" +dependencies = [ + "alloy-rlp", + "bytes", + "cfg-if", + "const-hex", + "derive_more 2.0.1", + "foldhash", + "hashbrown 0.15.4", + "indexmap 2.10.0", + "itoa", + "k256 0.13.4 (registry+https://github.com/rust-lang/crates.io-index)", + "keccak-asm", + "paste", + "proptest", + "rand 0.9.1", + "ruint 1.15.0", "rustc-hash 2.1.1", "serde", "sha3", @@ -174,9 +234,9 @@ dependencies = [ [[package]] name = "alloy-rlp" -version = "0.3.11" +version = "0.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6c1d995bff8d011f7cd6c81820d51825e6e06d6db73914c1630ecf544d83d6" +checksum = "5f70d83b765fdc080dbcd4f4db70d8d23fe4761f2f02ebfa9146b833900634b4" dependencies = [ "alloy-rlp-derive", "arrayvec", @@ -185,13 +245,24 @@ dependencies = [ [[package]] name = "alloy-rlp-derive" -version = "0.3.11" +version = "0.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a40e1ef334153322fd878d07e86af7a529bcb86b2439525920a88eba87bcf943" +checksum = "64b728d511962dda67c1bc7ea7c03736ec275ed2cf4c35d9585298ac9ccf3b73" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", +] + +[[package]] +name = "alloy-serde" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee8d2c52adebf3e6494976c8542fbdf12f10123b26e11ad56f77274c16a2a039" +dependencies = [ + "alloy-primitives 1.2.1", + "serde", + "serde_json", ] [[package]] @@ -205,7 +276,7 @@ dependencies = [ "proc-macro-error2", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -218,11 +289,11 @@ dependencies = [ "alloy-sol-macro-input", "const-hex", "heck", - "indexmap 2.7.1", + "indexmap 2.10.0", "proc-macro-error2", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", "syn-solidity", "tiny-keccak", ] @@ -241,7 +312,7 @@ dependencies = [ "proc-macro2", "quote", "serde_json", - "syn 2.0.98", + "syn 2.0.104", "syn-solidity", ] @@ -252,7 +323,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d162f8524adfdfb0e4bd0505c734c985f3e2474eb022af32eef0d52a4f3935c" dependencies = [ "serde", - "winnow 0.7.3", + "winnow 0.7.12", ] [[package]] @@ -300,9 +371,9 @@ dependencies = [ [[package]] name = "anstream" -version = "0.6.18" +version = "0.6.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" +checksum = "301af1932e46185686725e0fad2f8f2aa7da69dd70bf6ecc44d6b703844a3933" dependencies = [ "anstyle", "anstyle-parse", @@ -315,44 +386,44 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.10" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" +checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd" [[package]] name = "anstyle-parse" -version = "0.2.6" +version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.1.2" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" +checksum = "6c8bdeb6047d8983be085bab0ba1472e6dc604e7041dbf6fcd5e71523014fae9" dependencies = [ "windows-sys 0.59.0", ] [[package]] name = "anstyle-wincon" -version = "3.0.7" +version = "3.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" +checksum = "403f75924867bb1033c59fbf0797484329750cfbe3c4325cd33127941fabc882" dependencies = [ "anstyle", - "once_cell", + "once_cell_polyfill", "windows-sys 0.59.0", ] [[package]] name = "anyhow" -version = "1.0.96" +version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b964d184e89d9b6b67dd2715bc8e74cf3107fb2b529990c90cf517326150bf4" +checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" [[package]] name = "approx" @@ -379,6 +450,18 @@ dependencies = [ "yansi 0.5.1", ] +[[package]] +name = "ark-bls12-381" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3df4dcc01ff89867cd86b0da835f23c3f02738353aaee7dde7495af71363b8d5" +dependencies = [ + "ark-ec 0.5.0", + "ark-ff 0.5.0", + "ark-serialize 0.5.0", + "ark-std 0.5.0", +] + [[package]] name = "ark-bn254" version = "0.3.0" @@ -401,6 +484,18 @@ dependencies = [ "ark-std 0.4.0", ] +[[package]] +name = "ark-bn254" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d69eab57e8d2663efa5c63135b2af4f396d66424f88954c21104125ab6b3e6bc" +dependencies = [ + "ark-ec 0.5.0", + "ark-ff 0.5.0", + "ark-r1cs-std", + "ark-std 0.5.0", +] + [[package]] name = "ark-ec" version = "0.3.0" @@ -422,7 +517,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "defd9a439d56ac24968cca0571f598a61bc8c55f71d50a89cda591cb750670ba" dependencies = [ "ark-ff 0.4.2", - "ark-poly", + "ark-poly 0.4.2", "ark-serialize 0.4.2", "ark-std 0.4.0", "derivative", @@ -432,6 +527,27 @@ dependencies = [ "zeroize", ] +[[package]] +name = "ark-ec" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d68f2d516162846c1238e755a7c4d131b892b70cc70c471a8e3ca3ed818fce" +dependencies = [ + "ahash", + "ark-ff 0.5.0", + "ark-poly 0.5.0", + "ark-serialize 0.5.0", + "ark-std 0.5.0", + "educe", + "fnv", + "hashbrown 0.15.4", + "itertools 0.13.0", + "num-bigint 0.4.6", + "num-integer", + "num-traits", + "zeroize", +] + [[package]] name = "ark-ff" version = "0.3.0" @@ -470,6 +586,26 @@ dependencies = [ "zeroize", ] +[[package]] +name = "ark-ff" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a177aba0ed1e0fbb62aa9f6d0502e9b46dad8c2eab04c14258a1212d2557ea70" +dependencies = [ + "ark-ff-asm 0.5.0", + "ark-ff-macros 0.5.0", + "ark-serialize 0.5.0", + "ark-std 0.5.0", + "arrayvec", + "digest 0.10.7", + "educe", + "itertools 0.13.0", + "num-bigint 0.4.6", + "num-traits", + "paste", + "zeroize", +] + [[package]] name = "ark-ff-asm" version = "0.3.0" @@ -490,6 +626,16 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "ark-ff-asm" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62945a2f7e6de02a31fe400aa489f0e0f5b2502e69f95f853adb82a96c7a6b60" +dependencies = [ + "quote", + "syn 2.0.104", +] + [[package]] name = "ark-ff-macros" version = "0.3.0" @@ -515,6 +661,19 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "ark-ff-macros" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09be120733ee33f7693ceaa202ca41accd5653b779563608f1234f78ae07c4b3" +dependencies = [ + "num-bigint 0.4.6", + "num-traits", + "proc-macro2", + "quote", + "syn 2.0.104", +] + [[package]] name = "ark-poly" version = "0.4.2" @@ -528,6 +687,50 @@ dependencies = [ "hashbrown 0.13.2", ] +[[package]] +name = "ark-poly" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "579305839da207f02b89cd1679e50e67b4331e2f9294a57693e5051b7703fe27" +dependencies = [ + "ahash", + "ark-ff 0.5.0", + "ark-serialize 0.5.0", + "ark-std 0.5.0", + "educe", + "fnv", + "hashbrown 0.15.4", +] + +[[package]] +name = "ark-r1cs-std" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "941551ef1df4c7a401de7068758db6503598e6f01850bdb2cfdb614a1f9dbea1" +dependencies = [ + "ark-ec 0.5.0", + "ark-ff 0.5.0", + "ark-relations", + "ark-std 0.5.0", + "educe", + "num-bigint 0.4.6", + "num-integer", + "num-traits", + "tracing", +] + +[[package]] +name = "ark-relations" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec46ddc93e7af44bcab5230937635b06fb5744464dd6a7e7b083e80ebd274384" +dependencies = [ + "ark-ff 0.5.0", + "ark-std 0.5.0", + "tracing", + "tracing-subscriber 0.2.25", +] + [[package]] name = "ark-serialize" version = "0.3.0" @@ -544,12 +747,25 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adb7b85a02b83d2f22f89bd5cac66c9c89474240cb6207cb1efc16d098e822a5" dependencies = [ - "ark-serialize-derive", + "ark-serialize-derive 0.4.2", "ark-std 0.4.0", "digest 0.10.7", "num-bigint 0.4.6", ] +[[package]] +name = "ark-serialize" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f4d068aaf107ebcd7dfb52bc748f8030e0fc930ac8e360146ca54c1203088f7" +dependencies = [ + "ark-serialize-derive 0.5.0", + "ark-std 0.5.0", + "arrayvec", + "digest 0.10.7", + "num-bigint 0.4.6", +] + [[package]] name = "ark-serialize-derive" version = "0.4.2" @@ -561,6 +777,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "ark-serialize-derive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "213888f660fddcca0d257e88e54ac05bca01885f258ccdf695bafd77031bb69d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", +] + [[package]] name = "ark-std" version = "0.3.0" @@ -581,6 +808,16 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "ark-std" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "246a225cc6131e9ee4f24619af0f19d67761fff15d7ccc22e42b80846e69449a" +dependencies = [ + "num-traits", + "rand 0.8.5", +] + [[package]] name = "arrayref" version = "0.3.9" @@ -604,24 +841,30 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.86" +version = "0.1.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "644dd749086bf3771a2fbc5f256fdb982d53f011c7d5d560304eafeecebce79d" +checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "atomic" -version = "0.6.0" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d818003e740b63afc82337e3160717f4f63078720a810b7b903e70a5d1d2994" +checksum = "a89cbf775b137e9b968e67227ef7f775587cde3fd31b0d8599dbd0f598a48340" dependencies = [ "bytemuck", ] +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "aurora-engine-modexp" version = "1.2.0" @@ -634,26 +877,26 @@ dependencies = [ [[package]] name = "auto_impl" -version = "1.2.1" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e12882f59de5360c748c4cbf569a042d5fb0eb515f7bea9c1f470b47f6ffbd73" +checksum = "ffdcb70bdbc4d478427380519163274ac86e52916e10f0a8889adf0f96d3fee7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "autocfg" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "aws-config" -version = "1.5.18" +version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90aff65e86db5fe300752551c1b015ef72b708ac54bded8ef43d0d53cb7cb0b1" +checksum = "ebd9b83179adf8998576317ce47785948bcff399ec5b15f4dfbdedd44ddf5b92" dependencies = [ "aws-credential-types", "aws-runtime", @@ -661,7 +904,7 @@ dependencies = [ "aws-sdk-ssooidc", "aws-sdk-sts", "aws-smithy-async", - "aws-smithy-http 0.61.1", + "aws-smithy-http", "aws-smithy-json", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -670,7 +913,7 @@ dependencies = [ "bytes", "fastrand", "hex", - "http 0.2.12", + "http 1.3.1", "ring", "time", "tokio", @@ -681,9 +924,9 @@ dependencies = [ [[package]] name = "aws-credential-types" -version = "1.2.1" +version = "1.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60e8f6b615cb5fc60a98132268508ad104310f0cfb25a1c22eee76efdf9154da" +checksum = "b68c2194a190e1efc999612792e25b1ab3abfefe4306494efaaabc25933c0cbe" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", @@ -691,17 +934,40 @@ dependencies = [ "zeroize", ] +[[package]] +name = "aws-lc-rs" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08b5d4e069cbc868041a64bd68dc8cb39a0d79585cd6c5a24caa8c2d622121be" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbfd150b5dbdb988bcc8fb1fe787eb6b7ee6180ca24da683b61ea5405f3d43ff" +dependencies = [ + "bindgen", + "cc", + "cmake", + "dunce", + "fs_extra", +] + [[package]] name = "aws-runtime" -version = "1.5.5" +version = "1.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76dd04d39cc12844c0994f2c9c5a6f5184c22e9188ec1ff723de41910a21dcad" +checksum = "b2090e664216c78e766b6bac10fe74d2f451c02441d43484cd76ac9a295075f7" dependencies = [ "aws-credential-types", "aws-sigv4", "aws-smithy-async", "aws-smithy-eventstream", - "aws-smithy-http 0.60.12", + "aws-smithy-http", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -710,7 +976,6 @@ dependencies = [ "fastrand", "http 0.2.12", "http-body 0.4.6", - "once_cell", "percent-encoding", "pin-project-lite", "tracing", @@ -719,9 +984,9 @@ dependencies = [ [[package]] name = "aws-sdk-s3" -version = "1.78.0" +version = "1.98.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3038614b6cf7dd68d9a7b5b39563d04337eb3678d1d4173e356e927b0356158a" +checksum = "029e89cae7e628553643aecb3a3f054a0a0912ff0fd1f5d6a0b4fda421dce64b" dependencies = [ "aws-credential-types", "aws-runtime", @@ -729,7 +994,7 @@ dependencies = [ "aws-smithy-async", "aws-smithy-checksums", "aws-smithy-eventstream", - "aws-smithy-http 0.61.1", + "aws-smithy-http", "aws-smithy-json", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -741,70 +1006,70 @@ dependencies = [ "hex", "hmac", "http 0.2.12", + "http 1.3.1", "http-body 0.4.6", "lru", - "once_cell", "percent-encoding", "regex-lite", - "sha2", + "sha2 0.10.9", "tracing", "url", ] [[package]] name = "aws-sdk-sso" -version = "1.61.0" +version = "1.76.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e65ff295979977039a25f5a0bf067a64bc5e6aa38f3cef4037cf42516265553c" +checksum = "64bf26698dd6d238ef1486bdda46f22a589dc813368ba868dc3d94c8d27b56ba" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", - "aws-smithy-http 0.61.1", + "aws-smithy-http", "aws-smithy-json", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", "aws-types", "bytes", + "fastrand", "http 0.2.12", - "once_cell", "regex-lite", "tracing", ] [[package]] name = "aws-sdk-ssooidc" -version = "1.62.0" +version = "1.77.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91430a60f754f235688387b75ee798ef00cfd09709a582be2b7525ebb5306d4f" +checksum = "09cd07ed1edd939fae854a22054299ae3576500f4e0fadc560ca44f9c6ea1664" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", - "aws-smithy-http 0.61.1", + "aws-smithy-http", "aws-smithy-json", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", "aws-types", "bytes", + "fastrand", "http 0.2.12", - "once_cell", "regex-lite", "tracing", ] [[package]] name = "aws-sdk-sts" -version = "1.62.0" +version = "1.78.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9276e139d39fff5a0b0c984fc2d30f970f9a202da67234f948fda02e5bea1dbe" +checksum = "37f7766d2344f56d10d12f3c32993da36d78217f32594fe4fb8e57a538c1cdea" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", - "aws-smithy-http 0.61.1", + "aws-smithy-http", "aws-smithy-json", "aws-smithy-query", "aws-smithy-runtime", @@ -812,21 +1077,21 @@ dependencies = [ "aws-smithy-types", "aws-smithy-xml", "aws-types", + "fastrand", "http 0.2.12", - "once_cell", "regex-lite", "tracing", ] [[package]] name = "aws-sigv4" -version = "1.2.9" +version = "1.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9bfe75fad52793ce6dec0dc3d4b1f388f038b5eb866c8d4d7f3a8e21b5ea5051" +checksum = "ddfb9021f581b71870a17eac25b52335b82211cdc092e02b6876b2bcefa61666" dependencies = [ "aws-credential-types", "aws-smithy-eventstream", - "aws-smithy-http 0.60.12", + "aws-smithy-http", "aws-smithy-runtime-api", "aws-smithy-types", "bytes", @@ -835,12 +1100,11 @@ dependencies = [ "hex", "hmac", "http 0.2.12", - "http 1.2.0", - "once_cell", + "http 1.3.1", "p256 0.11.1", "percent-encoding", "ring", - "sha2", + "sha2 0.10.9", "subtle", "time", "tracing", @@ -849,9 +1113,9 @@ dependencies = [ [[package]] name = "aws-smithy-async" -version = "1.2.4" +version = "1.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa59d1327d8b5053c54bf2eaae63bf629ba9e904434d0835a28ed3c0ed0a614e" +checksum = "1e190749ea56f8c42bf15dd76c65e14f8f765233e6df9b0506d9d934ebef867c" dependencies = [ "futures-util", "pin-project-lite", @@ -860,31 +1124,29 @@ dependencies = [ [[package]] name = "aws-smithy-checksums" -version = "0.63.0" +version = "0.63.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db2dc8d842d872529355c72632de49ef8c5a2949a4472f10e802f28cf925770c" +checksum = "5ab9472f7a8ec259ddb5681d2ef1cb1cf16c0411890063e67cdc7b62562cc496" dependencies = [ - "aws-smithy-http 0.60.12", + "aws-smithy-http", "aws-smithy-types", "bytes", - "crc32c", - "crc32fast", - "crc64fast-nvme", + "crc-fast", "hex", "http 0.2.12", "http-body 0.4.6", "md-5", "pin-project-lite", "sha1", - "sha2", + "sha2 0.10.9", "tracing", ] [[package]] name = "aws-smithy-eventstream" -version = "0.60.7" +version = "0.60.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "461e5e02f9864cba17cff30f007c2e37ade94d01e87cdb5204e44a84e6d38c17" +checksum = "338a3642c399c0a5d157648426110e199ca7fd1c689cc395676b81aa563700c4" dependencies = [ "aws-smithy-types", "bytes", @@ -893,18 +1155,19 @@ dependencies = [ [[package]] name = "aws-smithy-http" -version = "0.60.12" +version = "0.62.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7809c27ad8da6a6a68c454e651d4962479e81472aa19ae99e59f9aba1f9713cc" +checksum = "99335bec6cdc50a346fda1437f9fefe33abf8c99060739a546a16457f2862ca9" dependencies = [ + "aws-smithy-eventstream", "aws-smithy-runtime-api", "aws-smithy-types", "bytes", "bytes-utils", "futures-core", "http 0.2.12", + "http 1.3.1", "http-body 0.4.6", - "once_cell", "percent-encoding", "pin-project-lite", "pin-utils", @@ -912,35 +1175,52 @@ dependencies = [ ] [[package]] -name = "aws-smithy-http" -version = "0.61.1" +name = "aws-smithy-http-client" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6f276f21c7921fe902826618d1423ae5bf74cf8c1b8472aee8434f3dfd31824" +checksum = "f108f1ca850f3feef3009bdcc977be201bca9a91058864d9de0684e64514bee0" dependencies = [ - "aws-smithy-eventstream", + "aws-smithy-async", "aws-smithy-runtime-api", "aws-smithy-types", - "bytes", - "bytes-utils", - "futures-core", + "h2 0.3.27", + "h2 0.4.11", "http 0.2.12", + "http 1.3.1", "http-body 0.4.6", - "once_cell", - "percent-encoding", + "hyper 0.14.32", + "hyper 1.6.0", + "hyper-rustls 0.24.2", + "hyper-rustls 0.27.7", + "hyper-util", "pin-project-lite", - "pin-utils", + "rustls 0.21.12", + "rustls 0.23.29", + "rustls-native-certs 0.8.1", + "rustls-pki-types", + "tokio", + "tower", "tracing", ] [[package]] name = "aws-smithy-json" -version = "0.61.2" +version = "0.61.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "623a51127f24c30776c8b374295f2df78d92517386f77ba30773f15a30ce1422" +checksum = "a16e040799d29c17412943bdbf488fd75db04112d0c0d4b9290bacf5ae0014b9" dependencies = [ "aws-smithy-types", ] +[[package]] +name = "aws-smithy-observability" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9364d5989ac4dd918e5cc4c4bdcc61c9be17dcd2586ea7f69e348fc7c6cab393" +dependencies = [ + "aws-smithy-runtime-api", +] + [[package]] name = "aws-smithy-query" version = "0.60.7" @@ -953,42 +1233,39 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.7.8" +version = "1.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d526a12d9ed61fadefda24abe2e682892ba288c2018bcb38b1b4c111d13f6d92" +checksum = "c3aaec682eb189e43c8a19c3dab2fe54590ad5f2cc2d26ab27608a20f2acf81c" dependencies = [ "aws-smithy-async", - "aws-smithy-http 0.60.12", + "aws-smithy-http", + "aws-smithy-http-client", + "aws-smithy-observability", "aws-smithy-runtime-api", "aws-smithy-types", "bytes", "fastrand", - "h2", "http 0.2.12", + "http 1.3.1", "http-body 0.4.6", "http-body 1.0.1", - "httparse", - "hyper", - "hyper-rustls", - "once_cell", "pin-project-lite", "pin-utils", - "rustls", "tokio", "tracing", ] [[package]] name = "aws-smithy-runtime-api" -version = "1.7.3" +version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92165296a47a812b267b4f41032ff8069ab7ff783696d217f0994a0d7ab585cd" +checksum = "9852b9226cb60b78ce9369022c0df678af1cac231c882d5da97a0c4e03be6e67" dependencies = [ "aws-smithy-async", "aws-smithy-types", "bytes", "http 0.2.12", - "http 1.2.0", + "http 1.3.1", "pin-project-lite", "tokio", "tracing", @@ -997,16 +1274,16 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.2.13" +version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7b8a53819e42f10d0821f56da995e1470b199686a1809168db6ca485665f042" +checksum = "d498595448e43de7f4296b7b7a18a8a02c61ec9349128c80a368f7c3b4ab11a8" dependencies = [ "base64-simd", "bytes", "bytes-utils", "futures-core", "http 0.2.12", - "http 1.2.0", + "http 1.3.1", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", @@ -1023,18 +1300,18 @@ dependencies = [ [[package]] name = "aws-smithy-xml" -version = "0.60.9" +version = "0.60.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab0b0166827aa700d3dc519f72f8b3a91c35d0b8d042dc5d643a91e6f80648fc" +checksum = "3db87b96cb1b16c024980f133968d52882ca0daaee3a086c6decc500f6c99728" dependencies = [ "xmlparser", ] [[package]] name = "aws-types" -version = "1.3.5" +version = "1.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfbd0a668309ec1f66c0f6bda4840dd6d4796ae26d699ebc266d7cc95c6d040f" +checksum = "8a322fec39e4df22777ed3ad8ea868ac2f94cd15e1a55f6ee8d8d6305057689a" dependencies = [ "aws-credential-types", "aws-smithy-async", @@ -1046,9 +1323,9 @@ dependencies = [ [[package]] name = "backtrace" -version = "0.3.74" +version = "0.3.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" +checksum = "6806a6321ec58106fea15becdad98371e28d92ccbc7c8f1b3b6dd724fe8f1002" dependencies = [ "addr2line", "cfg-if", @@ -1102,9 +1379,9 @@ dependencies = [ [[package]] name = "base64ct" -version = "1.6.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" +checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" [[package]] name = "bincode" @@ -1115,6 +1392,29 @@ dependencies = [ "serde", ] +[[package]] +name = "bindgen" +version = "0.69.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +dependencies = [ + "bitflags 2.9.1", + "cexpr", + "clang-sys", + "itertools 0.12.1", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash 1.1.0", + "shlex", + "syn 2.0.104", + "which", +] + [[package]] name = "bit-set" version = "0.5.3" @@ -1147,9 +1447,9 @@ checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" [[package]] name = "bitcode" -version = "0.6.5" +version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18c1406a27371b2f76232a2259df6ab607b91b5a0a7476a7729ff590df5a969a" +checksum = "cf300f4aa6e66f3bdff11f1236a88c622fe47ea814524792240b4d554d9858ee" dependencies = [ "arrayvec", "bitcode_derive", @@ -1166,7 +1466,23 @@ checksum = "42b6b4cb608b8282dc3b53d0f4c9ab404655d562674c682db7e6c0458cc83c23" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", +] + +[[package]] +name = "bitcoin-io" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b47c4ab7a93edb0c7198c5535ed9b52b63095f4e9b45279c6736cec4b856baf" + +[[package]] +name = "bitcoin_hashes" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb18c03d0db0247e147a21a6faafd5a7eb851c743db062de72018b6b7e8e4d16" +dependencies = [ + "bitcoin-io", + "hex-conservative", ] [[package]] @@ -1177,9 +1493,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.8.0" +version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" +checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" [[package]] name = "bitvec" @@ -1215,16 +1531,24 @@ dependencies = [ [[package]] name = "blake3" -version = "1.6.0" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if", + "constant_time_eq 0.3.1", +] + +[[package]] +name = "block-buffer" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1230237285e3e10cde447185e8975408ae24deaa67205ce684805c25bc0c7937" +checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4" dependencies = [ - "arrayref", - "arrayvec", - "cc", - "cfg-if", - "constant_time_eq 0.3.1", - "memmap2", + "generic-array", ] [[package]] @@ -1251,9 +1575,9 @@ dependencies = [ [[package]] name = "blst" -version = "0.3.14" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47c79a94619fade3c0b887670333513a67ac28a6a7e653eb260bf0d4103db38d" +checksum = "4fd49896f12ac9b6dcd7a5998466b9b58263a695a3dd1ecc1aaca2e12a90b080" dependencies = [ "cc", "glob", @@ -1267,7 +1591,7 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c34e20109dce74b02019885a01edc8ca485380a297ed8d6eb9e63e657774074b" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", "js-sys", "primitive-types", "rustc-hex", @@ -1278,9 +1602,9 @@ dependencies = [ [[package]] name = "bon" -version = "3.3.2" +version = "3.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe7acc34ff59877422326db7d6f2d845a582b16396b6b08194942bf34c6528ab" +checksum = "f61138465baf186c63e8d9b6b613b508cd832cba4ce93cf37ce5f096f91ac1a6" dependencies = [ "bon-macros", "rustversion", @@ -1288,9 +1612,9 @@ dependencies = [ [[package]] name = "bon-macros" -version = "3.3.2" +version = "3.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4159dd617a7fbc9be6a692fe69dc2954f8e6bb6bb5e4d7578467441390d77fd0" +checksum = "40d1dad34aa19bf02295382f08d9bc40651585bd497266831d40ee6296fb49ca" dependencies = [ "darling", "ident_case", @@ -1298,7 +1622,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -1321,7 +1645,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -1342,21 +1666,21 @@ checksum = "b4ae4235e6dac0694637c763029ecea1a2ec9e4e06ec2729bd21ba4d9c863eb7" [[package]] name = "bumpalo" -version = "3.17.0" +version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" +checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" [[package]] name = "byte-slice-cast" -version = "1.2.2" +version = "1.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3ac9f8b63eca6fd385229b3675f6cc0dc5c8a5c8a54a59d4f52ffd670d87b0c" +checksum = "7575182f7272186991736b70173b0ea045398f984bf5ebbb3804736ce1330c9d" [[package]] name = "bytemuck" -version = "1.21.0" +version = "1.23.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef657dfab802224e671f5818e9a4935f9b1957ed18e58292690cc39e7a4092a3" +checksum = "5c76a5792e44e4abe34d3abf15636779261d45a7450612059293d1d2cfc63422" [[package]] name = "byteorder" @@ -1366,9 +1690,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.10.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f61dac84819c6588b558454b194026eb1f09c293b9036ae9b159e74e73ab6cf9" +checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" dependencies = [ "serde", ] @@ -1405,9 +1729,9 @@ dependencies = [ [[package]] name = "c-kzg" -version = "1.0.3" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0307f72feab3300336fb803a57134159f6e20139af1357f36c54cb90d8e8928" +checksum = "7318cfa722931cb5fe0838b98d3ce5621e75f6a6408abc21721d80de9223f2e4" dependencies = [ "blst", "cc", @@ -1420,16 +1744,16 @@ dependencies = [ [[package]] name = "camino" -version = "1.1.9" +version = "1.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b96ec4966b5813e2c0507c1f86115c8c5abaadc3980879c3424042a02fd1ad3" +checksum = "0da45bc31171d8d6960122e222a67740df867c1dd53b4d51caa297084c185cab" dependencies = [ "serde", ] [[package]] name = "cargo-openvm" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "aws-config", "aws-sdk-s3", @@ -1450,8 +1774,8 @@ dependencies = [ "target-lexicon 0.12.16", "tempfile", "tokio", - "toml 0.8.20", - "toml_edit 0.22.24", + "toml 0.8.23", + "toml_edit 0.22.27", "tracing", "vergen", ] @@ -1473,7 +1797,7 @@ checksum = "2d886547e41f740c616ae73108f6eb70afe6d940c7bc697cb30f13daec073037" dependencies = [ "camino", "cargo-platform", - "semver 1.0.25", + "semver 1.0.26", "serde", "serde_json", "thiserror 1.0.69", @@ -1487,20 +1811,29 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.14" +version = "1.2.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c3d1b2e905a3a7b00a6141adb0e4c0bb941d11caf55349d863942a1cc44e3c9" +checksum = "5c1599538de2394445747c8cf7935946e3cc27e9625f889d979bfb2aaf569362" dependencies = [ "jobserver", "libc", "shlex", ] +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-if" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" [[package]] name = "cfg_aliases" @@ -1510,15 +1843,15 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chrono" -version = "0.4.39" +version = "0.4.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825" +checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" dependencies = [ "android-tzdata", "iana-time-zone", "num-traits", "serde", - "windows-targets 0.52.6", + "windows-link", ] [[package]] @@ -1558,11 +1891,22 @@ dependencies = [ "inout", ] +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + [[package]] name = "clap" -version = "4.5.30" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92b7b18d71fad5313a1e320fa9897994228ce274b60faa4d694fe0ea89cd9e6d" +checksum = "be92d32e80243a54711e5d7ce823c35c41c9d929dc4ab58e1276f625841aadf9" dependencies = [ "clap_builder", "clap_derive", @@ -1570,39 +1914,117 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.30" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a35db2071778a7344791a4fb4f95308b5673d219dee3ae348b86642574ecc90c" +checksum = "707eab41e9622f9139419d573eca0900137718000c517d47da73045f54331c3d" dependencies = [ "anstream", "anstyle", "clap_lex", "strsim", + "terminal_size", ] [[package]] name = "clap_derive" -version = "4.5.28" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf4ced95c6f4a675af3da73304b9ac4ed991640c36374e4b46795c49e17cf1ed" +checksum = "ef4f52386a59ca4c860f7393bcf8abd8dfd91ecccc0f774635ff68e92eeef491" dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "clap_lex" -version = "0.7.4" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" + +[[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", +] + +[[package]] +name = "codspeed" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7524e02ff6173bc143d9abc01b518711b77addb60de871bbe5686843f88fb48" +dependencies = [ + "anyhow", + "bincode", + "colored", + "glob", + "libc", + "nix", + "serde", + "serde_json", + "statrs", + "uuid", +] + +[[package]] +name = "codspeed-divan-compat" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "157f6307b7400d74f3e41bd429b751b53d05c138a6a0f35853055e2523440354" +dependencies = [ + "codspeed", + "codspeed-divan-compat-macros", + "codspeed-divan-compat-walltime", +] + +[[package]] +name = "codspeed-divan-compat-macros" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5e422ac666f5871ab86d17b0f7292696ef194138bab5b49f743d23799cd6c04" +dependencies = [ + "divan-macros", + "itertools 0.14.0", + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 2.0.104", +] + +[[package]] +name = "codspeed-divan-compat-walltime" +version = "3.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" +checksum = "66715e496e52fe861695e2644577adc7573544a729585fba4737193a62fd5a8a" +dependencies = [ + "cfg-if", + "clap", + "codspeed", + "condtype", + "divan-macros", + "libc", + "regex-lite", +] [[package]] name = "colorchoice" -version = "1.0.3" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" + +[[package]] +name = "colored" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" +checksum = "117725a109d387c937a1533ce01b450cbde6b88abceea8473c4d7a85853cda3c" +dependencies = [ + "lazy_static", + "windows-sys 0.48.0", +] [[package]] name = "concurrent-queue" @@ -1613,6 +2035,12 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "condtype" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf0a07a401f374238ab8e2f11a104d2851bf9ce711ec69804834de8af45c7af" + [[package]] name = "const-default" version = "1.0.0" @@ -1621,9 +2049,9 @@ checksum = "0b396d1f76d455557e1218ec8066ae14bba60b4b36ecd55577ba979f5db7ecaa" [[package]] name = "const-hex" -version = "1.14.0" +version = "1.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b0485bab839b018a8f1723fc5391819fea5f8f0f32288ef8a735fd096b6160c" +checksum = "83e22e0ed40b96a48d3db274f72fd365bd78f67af39b6bbd47e8a15e1c6207ff" dependencies = [ "cfg-if", "cpufeatures", @@ -1686,6 +2114,16 @@ dependencies = [ "libc", ] +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -1703,9 +2141,9 @@ dependencies = [ [[package]] name = "crc" -version = "3.2.1" +version = "3.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69e6e4d7b33a94f0991c26729976b10ebde1d34c3ee82408fb536164fa10d636" +checksum = "9710d3b3739c2e349eb44fe848ad0b7c8cb1e42bd87ee49371df2f7acaf3e675" dependencies = [ "crc-catalog", ] @@ -1717,12 +2155,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" [[package]] -name = "crc32c" -version = "0.6.8" +name = "crc-fast" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a47af21622d091a8f0fb295b88bc886ac74efcc613efc19f5d0b21de5c89e47" +checksum = "6bf62af4cc77d8fe1c22dde4e721d87f2f54056139d8c412e1366b740305f56f" dependencies = [ - "rustc_version 0.4.1", + "crc", + "digest 0.10.7", + "libc", + "rand 0.9.1", + "regex", ] [[package]] @@ -1734,15 +2176,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "crc64fast-nvme" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4955638f00a809894c947f85a024020a20815b65a5eea633798ea7924edab2b3" -dependencies = [ - "crc", -] - [[package]] name = "criterion" version = "0.5.1" @@ -1843,9 +2276,9 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crunchy" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" [[package]] name = "crypto-bigint" @@ -1883,9 +2316,9 @@ dependencies = [ [[package]] name = "darling" -version = "0.20.10" +version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" +checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" dependencies = [ "darling_core", "darling_macro", @@ -1893,27 +2326,42 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.20.10" +version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" dependencies = [ "fnv", "ident_case", "proc-macro2", "quote", "strsim", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "darling_macro" -version = "0.20.10" +version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" dependencies = [ "darling_core", "quote", - "syn 2.0.98", + "syn 2.0.104", +] + +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", + "rayon", ] [[package]] @@ -1928,9 +2376,9 @@ dependencies = [ [[package]] name = "der" -version = "0.7.9" +version = "0.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f55bf8e7b65898637379c1b74eb1551107c8294ed26d855ceb9fd1a09cfc9bc0" +checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" dependencies = [ "const-oid", "pem-rfc7468", @@ -1939,9 +2387,9 @@ dependencies = [ [[package]] name = "deranged" -version = "0.3.11" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e" dependencies = [ "powerfmt", "serde", @@ -1966,7 +2414,7 @@ checksum = "d150dea618e920167e5973d70ae6ece4385b7164e0d799fe7c122dd0a5d912ad" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -1977,20 +2425,31 @@ checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", +] + +[[package]] +name = "derive-where" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "510c292c8cf384b1a340b816a9a6cf2599eb8f566a44949024af88418000c50b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", ] [[package]] name = "derive_more" -version = "0.99.19" +version = "0.99.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3da29a38df43d6f156149c9b43ded5e018ddff2a855cf2cfd62e8cd7d079c69f" +checksum = "6edb4b64a43d977b8e99788fe3a04d483834fba1215a7e02caa415b626497f7f" dependencies = [ "convert_case", "proc-macro2", "quote", "rustc_version 0.4.1", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -2019,7 +2478,7 @@ checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", "unicode-xid", ] @@ -2031,30 +2490,30 @@ checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", "unicode-xid", ] [[package]] name = "diesel" -version = "2.2.10" +version = "2.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff3e1edb1f37b4953dd5176916347289ed43d7119cc2e6c7c3f7849ff44ea506" +checksum = "229850a212cd9b84d4f0290ad9d294afc0ae70fccaa8949dbe8b43ffafa1e20c" dependencies = [ "diesel_derives", ] [[package]] name = "diesel_derives" -version = "2.2.5" +version = "2.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68d4216021b3ea446fd2047f5c8f8fe6e98af34508a254a01e4d6bc1e844f84d" +checksum = "1b96984c469425cb577bf6f17121ecb3e4fe1e81de5d8f780dd372802858d756" dependencies = [ "diesel_table_macro_syntax", "dsl_auto_type", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -2063,7 +2522,7 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "209c735641a413bc68c4923a9d6ad4bcb3ca306b794edaa7eb0b3228a99ffb25" dependencies = [ - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -2081,7 +2540,7 @@ version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ - "block-buffer", + "block-buffer 0.10.4", "const-oid", "crypto-common", "subtle", @@ -2137,7 +2596,18 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", +] + +[[package]] +name = "divan-macros" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8dc51d98e636f5e3b0759a39257458b22619cac7e96d932da6eeb052891bb67c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", ] [[package]] @@ -2157,7 +2627,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -2168,9 +2638,9 @@ checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" [[package]] name = "dyn-clone" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "feeef44e73baff3a26d371801df019877a9866a8c493d315ab00177843314f35" +checksum = "1c7a8fb8a9fbf66c1f703fe16184d10ca0ee9d23be5b4436400408ba54a95005" [[package]] name = "ecdsa" @@ -2190,7 +2660,7 @@ version = "0.16.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" dependencies = [ - "der 0.7.9", + "der 0.7.10", "digest 0.10.7", "elliptic-curve 0.13.8", "rfc6979 0.4.0", @@ -2199,11 +2669,23 @@ dependencies = [ "spki 0.7.3", ] +[[package]] +name = "educe" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d7bc049e1bd8cdeb31b68bbd586a9464ecf9f3944af3958a7a9d0f8b9799417" +dependencies = [ + "enum-ordinalize", + "proc-macro2", + "quote", + "syn 2.0.104", +] + [[package]] name = "either" -version = "1.13.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" [[package]] name = "elf" @@ -2292,6 +2774,26 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" +[[package]] +name = "enum-ordinalize" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea0dcfa4e54eeb516fe454635a95753ddd39acda650ce703031c6973e315dd5" +dependencies = [ + "enum-ordinalize-derive", +] + +[[package]] +name = "enum-ordinalize-derive" +version = "4.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", +] + [[package]] name = "enum_dispatch" version = "0.3.13" @@ -2301,7 +2803,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -2312,7 +2814,7 @@ checksum = "2f9ed6b3789237c8a0c1c505af1c7eb2c560df6186f01b098c3a1064ea532f38" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -2326,9 +2828,9 @@ dependencies = [ [[package]] name = "env_logger" -version = "0.11.6" +version = "0.11.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcaee3d8e3cfc3fd92428d477bc97fc29ec8716d180c0d74c643bb26166660e0" +checksum = "13c863f0904021b108aa8b2f55046443e6b1ebde8fd4a15c399893aae4fa069f" dependencies = [ "anstream", "anstyle", @@ -2344,12 +2846,12 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "errno" -version = "0.3.10" +version = "0.3.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" +checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" dependencies = [ "libc", - "windows-sys 0.59.0", + "windows-sys 0.52.0", ] [[package]] @@ -2447,7 +2949,7 @@ dependencies = [ "chrono", "ethers-core", "reqwest", - "semver 1.0.25", + "semver 1.0.26", "serde", "serde_json", "thiserror 1.0.69", @@ -2474,10 +2976,10 @@ dependencies = [ "path-slash", "rayon", "regex", - "semver 1.0.25", + "semver 1.0.26", "serde", "serde_json", - "sha2", + "sha2 0.10.9", "solang-parser", "svm-rs", "svm-rs-builds", @@ -2592,7 +3094,7 @@ dependencies = [ "atomic", "pear", "serde", - "toml 0.8.20", + "toml 0.8.23", "uncased", "version_check", ] @@ -2617,9 +3119,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" [[package]] name = "flate2" -version = "1.1.0" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38dc" +checksum = "4a3d7db9596fecd151c5f638c0ee5d5bd487b6e0ea232e5dc96d5250f6f94b1d" dependencies = [ "crc32fast", "miniz_oxide", @@ -2633,9 +3135,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "foldhash" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" [[package]] name = "forge-fmt" @@ -2682,7 +3184,7 @@ dependencies = [ "regex", "reqwest", "revm-primitives 1.3.0", - "semver 1.0.25", + "semver 1.0.26", "serde", "serde_json", "serde_regex", @@ -2703,6 +3205,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "funty" version = "2.0.0" @@ -2750,7 +3258,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -2801,39 +3309,39 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", "js-sys", "libc", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi 0.11.1+wasi-snapshot-preview1", "wasm-bindgen", ] [[package]] name = "getrandom" -version = "0.3.1" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" dependencies = [ "cfg-if", "libc", - "wasi 0.13.3+wasi-0.2.2", - "windows-targets 0.52.6", + "r-efi", + "wasi 0.14.2+wasi-0.2.4", ] [[package]] name = "getset" -version = "0.1.4" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eded738faa0e88d3abc9d1a13cb11adc2073c400969eeb8793cf7132589959fc" +checksum = "9cf0fc11e47561d47397154977bc219f4cf809b2974facc3ccb3b89e2436f912" dependencies = [ "proc-macro-error2", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -2848,7 +3356,7 @@ version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b903b73e45dc0c6c596f2d37eccece7c1c8bb6e4407b001096387c63d0d93724" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "libc", "libgit2-sys", "log", @@ -2857,9 +3365,9 @@ dependencies = [ [[package]] name = "glam" -version = "0.30.0" +version = "0.30.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17fcdf9683c406c2fc4d124afd29c0d595e22210d633cbdb8695ba9935ab1dc6" +checksum = "50a99dbe56b72736564cfa4b85bf9a33079f16ae8b74983ab06af3b1a3696b11" [[package]] name = "glob" @@ -2905,9 +3413,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" +checksum = "0beca50380b1fc32983fc1cb4587bfa4bb9e78fc259aad4a0032d2080309222d" dependencies = [ "bytes", "fnv", @@ -2915,7 +3423,26 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.12", - "indexmap 2.7.1", + "indexmap 2.10.0", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "h2" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17da50a276f1e01e0ba6c029e47b7100754904ee8a278f886546e98575380785" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http 1.3.1", + "indexmap 2.10.0", "slab", "tokio", "tokio-util", @@ -2924,9 +3451,9 @@ dependencies = [ [[package]] name = "half" -version = "2.4.1" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" dependencies = [ "cfg-if", "crunchy", @@ -3041,7 +3568,7 @@ dependencies = [ "rayon", "serde", "serde_arrays", - "sha2", + "sha2 0.10.9", "static_assertions", "subtle", "unroll", @@ -3069,7 +3596,7 @@ dependencies = [ "rayon", "serde", "serde_arrays", - "sha2", + "sha2 0.10.9", "static_assertions", "subtle", "unroll", @@ -3116,9 +3643,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.15.2" +version = "0.15.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +checksum = "5971ac85611da7067dbfcabef3c70ebb5606018acd9e2a3903a0da507521e0d5" dependencies = [ "allocator-api2", "equivalent", @@ -3132,7 +3659,7 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" dependencies = [ - "hashbrown 0.15.2", + "hashbrown 0.15.4", ] [[package]] @@ -3143,15 +3670,9 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "hermit-abi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" - -[[package]] -name = "hermit-abi" -version = "0.4.0" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" [[package]] name = "hex" @@ -3162,6 +3683,15 @@ dependencies = [ "serde", ] +[[package]] +name = "hex-conservative" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5313b072ce3c597065a808dbf612c4c8e8590bdbf8b579508bf7a762c5eae6cd" +dependencies = [ + "arrayvec", +] + [[package]] name = "hex-literal" version = "0.4.1" @@ -3214,9 +3744,9 @@ dependencies = [ [[package]] name = "http" -version = "1.2.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" dependencies = [ "bytes", "fnv", @@ -3241,27 +3771,27 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.2.0", + "http 1.3.1", ] [[package]] name = "http-body-util" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", - "futures-util", - "http 1.2.0", + "futures-core", + "http 1.3.1", "http-body 1.0.1", "pin-project-lite", ] [[package]] name = "httparse" -version = "1.10.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2d708df4e7140240a16cd6ab0ab65c972d7433ab77819ea693fde9c43811e2a" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" [[package]] name = "httpdate" @@ -3279,7 +3809,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2", + "h2 0.3.27", "http 0.2.12", "http-body 0.4.6", "httparse", @@ -3293,6 +3823,26 @@ dependencies = [ "want", ] +[[package]] +name = "hyper" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "h2 0.4.11", + "http 1.3.1", + "http-body 1.0.1", + "httparse", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", + "want", +] + [[package]] name = "hyper-rustls" version = "0.24.2" @@ -3300,25 +3850,64 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" dependencies = [ "futures-util", - "http 0.2.12", - "hyper", - "log", - "rustls", - "rustls-native-certs", + "http 0.2.12", + "hyper 0.14.32", + "log", + "rustls 0.21.12", + "rustls-native-certs 0.6.3", + "tokio", + "tokio-rustls 0.24.1", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http 1.3.1", + "hyper 1.6.0", + "hyper-util", + "rustls 0.23.29", + "rustls-native-certs 0.8.1", + "rustls-pki-types", + "tokio", + "tokio-rustls 0.26.2", + "tower-service", +] + +[[package]] +name = "hyper-util" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f66d5bd4c6f02bf0542fad85d626775bab9258cf795a4256dcaf3161114d1df" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "http 1.3.1", + "http-body 1.0.1", + "hyper 1.6.0", + "libc", + "pin-project-lite", + "socket2", "tokio", - "tokio-rustls", + "tower-service", + "tracing", ] [[package]] name = "iana-time-zone" -version = "0.1.61" +version = "0.1.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" +checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8" dependencies = [ "android_system_properties", "core-foundation-sys", "iana-time-zone-haiku", "js-sys", + "log", "wasm-bindgen", "windows-core", ] @@ -3334,21 +3923,22 @@ dependencies = [ [[package]] name = "icu_collections" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +checksum = "200072f5d0e3614556f94a9930d5dc3e0662a652823904c3a75dc3b0af7fee47" dependencies = [ "displaydoc", + "potential_utf", "yoke", "zerofrom", "zerovec", ] [[package]] -name = "icu_locid" -version = "1.5.0" +name = "icu_locale_core" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +checksum = "0cde2700ccaed3872079a65fb1a78f6c0a36c91570f28755dda67bc8f7d9f00a" dependencies = [ "displaydoc", "litemap", @@ -3357,31 +3947,11 @@ dependencies = [ "zerovec", ] -[[package]] -name = "icu_locid_transform" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" -dependencies = [ - "displaydoc", - "icu_locid", - "icu_locid_transform_data", - "icu_provider", - "tinystr", - "zerovec", -] - -[[package]] -name = "icu_locid_transform_data" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" - [[package]] name = "icu_normalizer" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +checksum = "436880e8e18df4d7bbc06d58432329d6458cc84531f7ac5f024e93deadb37979" dependencies = [ "displaydoc", "icu_collections", @@ -3389,67 +3959,54 @@ dependencies = [ "icu_properties", "icu_provider", "smallvec", - "utf16_iter", - "utf8_iter", - "write16", "zerovec", ] [[package]] name = "icu_normalizer_data" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" +checksum = "00210d6893afc98edb752b664b8890f0ef174c8adbb8d0be9710fa66fbbf72d3" [[package]] name = "icu_properties" -version = "1.5.1" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +checksum = "016c619c1eeb94efb86809b015c58f479963de65bdb6253345c1a1276f22e32b" dependencies = [ "displaydoc", "icu_collections", - "icu_locid_transform", + "icu_locale_core", "icu_properties_data", "icu_provider", - "tinystr", + "potential_utf", + "zerotrie", "zerovec", ] [[package]] name = "icu_properties_data" -version = "1.5.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" +checksum = "298459143998310acd25ffe6810ed544932242d3f07083eee1084d83a71bd632" [[package]] name = "icu_provider" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +checksum = "03c80da27b5f4187909049ee2d72f276f0d9f99a42c306bd0131ecfe04d8e5af" dependencies = [ "displaydoc", - "icu_locid", - "icu_provider_macros", + "icu_locale_core", "stable_deref_trait", "tinystr", "writeable", "yoke", "zerofrom", + "zerotrie", "zerovec", ] -[[package]] -name = "icu_provider_macros" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.98", -] - [[package]] name = "ident_case" version = "1.0.1" @@ -3469,9 +4026,9 @@ dependencies = [ [[package]] name = "idna_adapter" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" dependencies = [ "icu_normalizer", "icu_properties", @@ -3512,7 +4069,7 @@ checksum = "a0eb5a3343abf848c0984fe4604b2b105da9539376e24fc0a3b0007411ae4fd9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -3553,12 +4110,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.7.1" +version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" +checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661" dependencies = [ "equivalent", - "hashbrown 0.15.2", + "hashbrown 0.15.4", "serde", ] @@ -3577,6 +4134,17 @@ dependencies = [ "generic-array", ] +[[package]] +name = "io-uring" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b86e202f00093dcba4275d4636b93ef9dd75d025ae560d2521b45ea28ab49013" +dependencies = [ + "bitflags 2.9.1", + "cfg-if", + "libc", +] + [[package]] name = "ipnet" version = "2.11.0" @@ -3585,13 +4153,13 @@ checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" [[package]] name = "is-terminal" -version = "0.4.15" +version = "0.4.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e19b23d53f35ce9f56aebc7d1bb4e6ac1e9c0db7ac85c8d1760c04379edced37" +checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9" dependencies = [ - "hermit-abi 0.4.0", + "hermit-abi", "libc", - "windows-sys 0.59.0", + "windows-sys 0.52.0", ] [[package]] @@ -3618,6 +4186,24 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.14.0" @@ -3629,16 +4215,17 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.14" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] name = "jobserver" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" dependencies = [ + "getrandom 0.3.3", "libc", ] @@ -3679,7 +4266,6 @@ dependencies = [ "num-bigint 0.4.6", "once_cell", "openvm", - "openvm-algebra-circuit", "openvm-algebra-guest", "openvm-algebra-moduli-macros", "openvm-algebra-transpiler", @@ -3688,7 +4274,6 @@ dependencies = [ "openvm-ecc-guest", "openvm-ecc-sw-macros", "openvm-ecc-transpiler", - "openvm-rv32im-circuit", "openvm-rv32im-transpiler", "openvm-sha256-circuit", "openvm-sha256-transpiler", @@ -3696,6 +4281,7 @@ dependencies = [ "openvm-stark-sdk", "openvm-toolchain-tests", "openvm-transpiler", + "rand 0.8.5", "serde", "signature 2.2.0", ] @@ -3710,7 +4296,7 @@ dependencies = [ "ecdsa 0.16.9", "elliptic-curve 0.13.8", "once_cell", - "sha2", + "sha2 0.10.9", ] [[package]] @@ -3771,11 +4357,17 @@ dependencies = [ "spin", ] +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + [[package]] name = "libc" -version = "0.2.169" +version = "0.2.174" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776" [[package]] name = "libgit2-sys" @@ -3789,17 +4381,27 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "libloading" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" +dependencies = [ + "cfg-if", + "windows-targets 0.48.5", +] + [[package]] name = "libm" -version = "0.2.11" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" +checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" [[package]] name = "libmimalloc-sys" -version = "0.1.39" +version = "0.1.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23aa6811d3bd4deb8a84dde645f943476d13b248d818edcf8ce0b2f37f036b44" +checksum = "bf88cd67e9de251c1781dbe2f641a1a3ad66eaae831b8a2c38fbdc5ddae16d4d" dependencies = [ "cc", "libc", @@ -3807,19 +4409,65 @@ dependencies = [ [[package]] name = "libredox" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" +checksum = "1580801010e535496706ba011c15f8532df6b42297d2e471fec38ceadd8c0638" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "libc", ] +[[package]] +name = "libsecp256k1" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e79019718125edc905a079a70cfa5f3820bc76139fc91d6f9abc27ea2a887139" +dependencies = [ + "arrayref", + "base64 0.22.1", + "digest 0.9.0", + "libsecp256k1-core", + "libsecp256k1-gen-ecmult", + "libsecp256k1-gen-genmult", + "rand 0.8.5", + "serde", + "sha2 0.9.9", +] + +[[package]] +name = "libsecp256k1-core" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5be9b9bb642d8522a44d533eab56c16c738301965504753b03ad1de3425d5451" +dependencies = [ + "crunchy", + "digest 0.9.0", + "subtle", +] + +[[package]] +name = "libsecp256k1-gen-ecmult" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3038c808c55c87e8a172643a7d87187fc6c4174468159cb3090659d55bcb4809" +dependencies = [ + "libsecp256k1-core", +] + +[[package]] +name = "libsecp256k1-gen-genmult" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3db8d6ba2cec9eacc40e6e8ccc98931840301f1006e95647ceb2dd5c3aa06f7c" +dependencies = [ + "libsecp256k1-core", +] + [[package]] name = "libz-sys" -version = "1.1.21" +version = "1.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df9b68e50e6e0b26f672573834882eb57759f6db9b3be2ea3c35c91188bb4eaa" +checksum = "8b70e7a7df205e92a1a4cd9aaae7898dac0aa555503cc0a649494d0d60e7651d" dependencies = [ "cc", "libc", @@ -3839,17 +4487,23 @@ version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" +[[package]] +name = "linux-raw-sys" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" + [[package]] name = "litemap" -version = "0.7.4" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" +checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" [[package]] name = "lock_api" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765" dependencies = [ "autocfg", "scopeguard", @@ -3863,9 +4517,9 @@ checksum = "9374ef4228402d4b7e403e5838cb880d9ee663314b0a900d5a6aabf0c213552e" [[package]] name = "log" -version = "0.4.25" +version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" +checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" [[package]] name = "lru" @@ -3873,7 +4527,7 @@ version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" dependencies = [ - "hashbrown 0.15.2", + "hashbrown 0.15.4", ] [[package]] @@ -3884,7 +4538,7 @@ checksum = "1b27834086c65ec3f9387b096d66e99f221cf081c2b738042aa252bcd41204e3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -3918,9 +4572,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.7.4" +version = "2.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0" [[package]] name = "memmap2" @@ -3948,9 +4602,9 @@ checksum = "3d97bbf43eb4f088f8ca469930cde17fa036207c9a5e02ccc5107c4e8b17c964" [[package]] name = "metrics" -version = "0.23.0" +version = "0.23.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "884adb57038347dfbaf2d5065887b6cf4312330dc8e94bc30a1a839bd79d3261" +checksum = "3045b4193fbdc5b5681f32f11070da9be3609f189a79f3390706d42587f46bb5" dependencies = [ "ahash", "portable-atomic", @@ -3962,7 +4616,7 @@ version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62a6a1f7141f1d9bc7a886b87536bbfc97752e08b369e1e0453a9acfab5f5da4" dependencies = [ - "indexmap 2.7.1", + "indexmap 2.10.0", "itoa", "lockfree-object-pool", "metrics", @@ -3970,7 +4624,7 @@ dependencies = [ "once_cell", "tracing", "tracing-core", - "tracing-subscriber", + "tracing-subscriber 0.3.19", ] [[package]] @@ -3983,7 +4637,7 @@ dependencies = [ "crossbeam-epoch", "crossbeam-utils", "hashbrown 0.14.5", - "indexmap 2.7.1", + "indexmap 2.10.0", "metrics", "num_cpus", "ordered-float", @@ -3994,9 +4648,9 @@ dependencies = [ [[package]] name = "mimalloc" -version = "0.1.43" +version = "0.1.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68914350ae34959d83f732418d51e2427a794055d0b9529f48259ac07af65633" +checksum = "b1791cbe101e95af5764f06f20f6760521f7158f69dbf9d6baf941ee1bf6bc40" dependencies = [ "libmimalloc-sys", ] @@ -4007,24 +4661,30 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" -version = "0.8.4" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3b1c9bd4fe1f0f8b387f6eb9eb3b4a1aa26185e5750efb9140301703f62cd1b" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" dependencies = [ "adler2", ] [[package]] name = "mio" -version = "1.0.3" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" +checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" dependencies = [ "libc", - "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys 0.52.0", + "wasi 0.11.1+wasi-snapshot-preview1", + "windows-sys 0.59.0", ] [[package]] @@ -4042,6 +4702,28 @@ dependencies = [ "smallvec", ] +[[package]] +name = "nix" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +dependencies = [ + "bitflags 2.9.1", + "cfg-if", + "cfg_aliases", + "libc", +] + +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -4184,33 +4866,34 @@ dependencies = [ [[package]] name = "num_cpus" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" dependencies = [ - "hermit-abi 0.3.9", + "hermit-abi", "libc", ] [[package]] name = "num_enum" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e613fc340b2220f734a8595782c551f1250e969d87d3be1ae0579e8d4065179" +checksum = "a973b4e44ce6cad84ce69d797acf9a044532e4184c4f267913d1b546a0727b7a" dependencies = [ "num_enum_derive", + "rustversion", ] [[package]] name = "num_enum_derive" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56" +checksum = "77e878c846a8abae00dd069496dbe8751b16ac1c3d6bd2a7283a938e8228f90d" dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -4251,19 +4934,31 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.20.3" +version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" dependencies = [ "critical-section", "portable-atomic", ] +[[package]] +name = "once_cell_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad" + [[package]] name = "oorandom" -version = "11.1.4" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + +[[package]] +name = "opaque-debug" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" +checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" [[package]] name = "open-fastrlp" @@ -4298,12 +4993,12 @@ checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" [[package]] name = "openvm" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "bytemuck", "chrono", - "getrandom 0.2.15", - "getrandom 0.3.1", + "getrandom 0.2.16", + "getrandom 0.3.3", "num-bigint 0.4.6", "openvm-custom-insn", "openvm-platform", @@ -4313,7 +5008,7 @@ dependencies = [ [[package]] name = "openvm-algebra-circuit" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", @@ -4336,23 +5031,23 @@ dependencies = [ "openvm-stark-sdk", "rand 0.8.5", "serde", - "serde-big-array", "serde_with", "strum", + "test-case", ] [[package]] name = "openvm-algebra-complex-macros" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "openvm-macros-common", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "openvm-algebra-guest" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "halo2curves-axiom", "num-bigint 0.4.6", @@ -4367,18 +5062,18 @@ dependencies = [ [[package]] name = "openvm-algebra-moduli-macros" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "num-bigint 0.4.6", "num-prime", "openvm-macros-common", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "openvm-algebra-tests" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "eyre", "num-bigint 0.4.6", @@ -4395,7 +5090,7 @@ dependencies = [ [[package]] name = "openvm-algebra-transpiler" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "openvm-algebra-guest", "openvm-instructions", @@ -4408,57 +5103,61 @@ dependencies = [ [[package]] name = "openvm-benchmarks-execute" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ - "cargo-openvm", + "bitcode", "clap", - "criterion", + "codspeed-divan-compat", "derive_more 1.0.0", "eyre", + "openvm-algebra-circuit", + "openvm-algebra-transpiler", "openvm-benchmarks-utils", + "openvm-bigint-circuit", + "openvm-bigint-transpiler", "openvm-circuit", + "openvm-continuations", + "openvm-ecc-circuit", + "openvm-ecc-transpiler", "openvm-keccak256-circuit", "openvm-keccak256-transpiler", + "openvm-native-circuit", + "openvm-pairing-circuit", + "openvm-pairing-guest", + "openvm-pairing-transpiler", "openvm-rv32im-circuit", "openvm-rv32im-transpiler", "openvm-sdk", + "openvm-sha256-circuit", + "openvm-sha256-transpiler", "openvm-stark-sdk", "openvm-transpiler", + "rand 0.8.5", + "serde", "tracing", - "tracing-subscriber", + "tracing-subscriber 0.3.19", ] [[package]] name = "openvm-benchmarks-prove" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "clap", - "derive-new 0.6.0", "derive_more 1.0.0", "eyre", "k256 0.13.4 (registry+https://github.com/rust-lang/crates.io-index)", - "num-bigint 0.4.6", - "openvm-algebra-circuit", - "openvm-algebra-transpiler", "openvm-benchmarks-utils", "openvm-circuit", - "openvm-ecc-circuit", - "openvm-ecc-transpiler", - "openvm-keccak256-circuit", - "openvm-keccak256-transpiler", + "openvm-continuations", "openvm-native-circuit", "openvm-native-compiler", "openvm-native-recursion", - "openvm-pairing-circuit", - "openvm-pairing-guest", - "openvm-rv32im-circuit", - "openvm-rv32im-transpiler", "openvm-sdk", "openvm-stark-backend", "openvm-stark-sdk", "openvm-transpiler", + "rand 0.8.5", "rand_chacha 0.3.1", - "serde", "tiny-keccak", "tokio", "tracing", @@ -4466,22 +5165,29 @@ dependencies = [ [[package]] name = "openvm-benchmarks-utils" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ + "bitcode", "cargo_metadata", "clap", "eyre", "openvm-build", + "openvm-circuit", + "openvm-continuations", + "openvm-native-circuit", + "openvm-sdk", + "openvm-stark-sdk", "openvm-transpiler", "tempfile", "tracing", - "tracing-subscriber", + "tracing-subscriber 0.3.19", ] [[package]] name = "openvm-bigint-circuit" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ + "alloy-primitives 1.2.1", "derive-new 0.6.0", "derive_more 1.0.0", "openvm-bigint-transpiler", @@ -4497,11 +5203,12 @@ dependencies = [ "openvm-stark-sdk", "rand 0.8.5", "serde", + "test-case", ] [[package]] name = "openvm-bigint-guest" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "openvm-platform", "strum_macros", @@ -4509,7 +5216,7 @@ dependencies = [ [[package]] name = "openvm-bigint-transpiler" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "openvm-bigint-guest", "openvm-instructions", @@ -4523,7 +5230,7 @@ dependencies = [ [[package]] name = "openvm-build" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "cargo_metadata", "eyre", @@ -4534,10 +5241,10 @@ dependencies = [ [[package]] name = "openvm-circuit" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "backtrace", - "cfg-if", + "dashmap", "derivative", "derive-new 0.6.0", "derive_more 1.0.0", @@ -4545,6 +5252,7 @@ dependencies = [ "eyre", "getset", "itertools 0.14.0", + "memmap2", "metrics", "openvm-circuit", "openvm-circuit-derive", @@ -4570,16 +5278,17 @@ dependencies = [ [[package]] name = "openvm-circuit-derive" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "itertools 0.14.0", + "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "openvm-circuit-primitives" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "derive-new 0.6.0", "itertools 0.14.0", @@ -4595,16 +5304,16 @@ dependencies = [ [[package]] name = "openvm-circuit-primitives-derive" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "itertools 0.14.0", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "openvm-continuations" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "derivative", "openvm-circuit", @@ -4622,32 +5331,34 @@ version = "0.1.0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "openvm-ecc-circuit" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", + "halo2curves-axiom", "hex-literal 0.4.1", "lazy_static", "num-bigint 0.4.6", + "num-integer", "num-traits", "once_cell", "openvm-algebra-circuit", "openvm-circuit", "openvm-circuit-derive", "openvm-circuit-primitives", - "openvm-circuit-primitives-derive", + "openvm-ecc-guest", "openvm-ecc-transpiler", "openvm-instructions", "openvm-mod-circuit-builder", "openvm-rv32-adapters", - "openvm-rv32im-circuit", "openvm-stark-backend", "openvm-stark-sdk", + "rand 0.8.5", "serde", "serde_with", "strum", @@ -4655,17 +5366,22 @@ dependencies = [ [[package]] name = "openvm-ecc-guest" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "ecdsa 0.16.9", "elliptic-curve 0.13.8", "group 0.13.0", "halo2curves-axiom", + "hex-literal 0.4.1", + "lazy_static", + "num-bigint 0.4.6", "once_cell", "openvm", "openvm-algebra-guest", + "openvm-algebra-moduli-macros", "openvm-custom-insn", "openvm-ecc-sw-macros", + "openvm-ecc-te-macros", "openvm-rv32im-guest", "serde", "strum_macros", @@ -4673,7 +5389,7 @@ dependencies = [ [[package]] name = "openvm-ecc-integration-tests" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "eyre", "halo2curves-axiom", @@ -4690,21 +5406,30 @@ dependencies = [ "openvm-transpiler", "serde", "serde_with", - "toml 0.8.20", + "toml 0.8.23", +] + +[[package]] +name = "openvm-ecc-sw-macros" +version = "1.4.0-rc.2" +dependencies = [ + "openvm-macros-common", + "quote", + "syn 2.0.104", ] [[package]] -name = "openvm-ecc-sw-macros" -version = "1.3.0" +name = "openvm-ecc-te-macros" +version = "1.4.0-rc.2" dependencies = [ "openvm-macros-common", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "openvm-ecc-transpiler" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "openvm-ecc-guest", "openvm-instructions", @@ -4717,7 +5442,7 @@ dependencies = [ [[package]] name = "openvm-ff-derive" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "addchain", "eyre", @@ -4740,7 +5465,7 @@ dependencies = [ [[package]] name = "openvm-instructions" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "backtrace", "bitcode", @@ -4760,18 +5485,18 @@ dependencies = [ [[package]] name = "openvm-instructions-derive" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "openvm-instructions", "quote", "strum", "strum_macros", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "openvm-keccak256" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "eyre", "openvm-circuit", @@ -4788,7 +5513,7 @@ dependencies = [ [[package]] name = "openvm-keccak256-circuit" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", @@ -4806,22 +5531,20 @@ dependencies = [ "p3-keccak-air", "rand 0.8.5", "serde", - "serde-big-array", "strum", "tiny-keccak", - "tracing", ] [[package]] name = "openvm-keccak256-guest" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "openvm-platform", ] [[package]] name = "openvm-keccak256-transpiler" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "openvm-instructions", "openvm-instructions-derive", @@ -4834,14 +5557,14 @@ dependencies = [ [[package]] name = "openvm-macros-common" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "openvm-mod-circuit-builder" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "halo2curves-axiom", "itertools 0.14.0", @@ -4854,14 +5577,12 @@ dependencies = [ "openvm-stark-backend", "openvm-stark-sdk", "rand 0.8.5", - "serde", - "serde_with", "tracing", ] [[package]] name = "openvm-native-circuit" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", @@ -4875,19 +5596,19 @@ dependencies = [ "openvm-native-compiler", "openvm-poseidon2-air", "openvm-rv32im-circuit", + "openvm-rv32im-transpiler", "openvm-stark-backend", "openvm-stark-sdk", "rand 0.8.5", "serde", - "serde-big-array", "static_assertions", "strum", - "tracing", + "test-case", ] [[package]] name = "openvm-native-compiler" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "backtrace", "itertools 0.14.0", @@ -4913,15 +5634,15 @@ dependencies = [ [[package]] name = "openvm-native-compiler-derive" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "openvm-native-recursion" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "bitcode", "cfg-if", @@ -4951,7 +5672,7 @@ dependencies = [ [[package]] name = "openvm-native-transpiler" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "openvm-instructions", "openvm-transpiler", @@ -4960,7 +5681,7 @@ dependencies = [ [[package]] name = "openvm-pairing" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "eyre", "group 0.13.0", @@ -4982,6 +5703,7 @@ dependencies = [ "openvm-ecc-sw-macros", "openvm-ecc-transpiler", "openvm-instructions", + "openvm-pairing", "openvm-pairing-circuit", "openvm-pairing-guest", "openvm-pairing-transpiler", @@ -4997,27 +5719,24 @@ dependencies = [ [[package]] name = "openvm-pairing-circuit" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", "eyre", "halo2curves-axiom", - "itertools 0.14.0", "num-bigint 0.4.6", "num-traits", "openvm-algebra-circuit", "openvm-circuit", "openvm-circuit-derive", "openvm-circuit-primitives", - "openvm-circuit-primitives-derive", "openvm-ecc-circuit", "openvm-ecc-guest", "openvm-instructions", "openvm-mod-circuit-builder", "openvm-pairing-guest", "openvm-pairing-transpiler", - "openvm-rv32-adapters", "openvm-rv32im-circuit", "openvm-stark-backend", "openvm-stark-sdk", @@ -5028,7 +5747,7 @@ dependencies = [ [[package]] name = "openvm-pairing-guest" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "halo2curves-axiom", "hex-literal 0.4.1", @@ -5049,10 +5768,9 @@ dependencies = [ [[package]] name = "openvm-pairing-transpiler" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "openvm-instructions", - "openvm-instructions-derive", "openvm-pairing-guest", "openvm-stark-backend", "openvm-transpiler", @@ -5062,7 +5780,7 @@ dependencies = [ [[package]] name = "openvm-platform" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "critical-section", "embedded-alloc", @@ -5073,7 +5791,7 @@ dependencies = [ [[package]] name = "openvm-poseidon2-air" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "derivative", "lazy_static", @@ -5089,7 +5807,7 @@ dependencies = [ [[package]] name = "openvm-prof" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "clap", "eyre", @@ -5102,7 +5820,7 @@ dependencies = [ [[package]] name = "openvm-rv32-adapters" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "derive-new 0.6.0", "itertools 0.14.0", @@ -5114,14 +5832,11 @@ dependencies = [ "openvm-stark-backend", "openvm-stark-sdk", "rand 0.8.5", - "serde", - "serde-big-array", - "serde_with", ] [[package]] name = "openvm-rv32im-circuit" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", @@ -5138,13 +5853,13 @@ dependencies = [ "openvm-stark-sdk", "rand 0.8.5", "serde", - "serde-big-array", "strum", + "test-case", ] [[package]] name = "openvm-rv32im-guest" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "openvm-custom-insn", "p3-field", @@ -5153,7 +5868,7 @@ dependencies = [ [[package]] name = "openvm-rv32im-integration-tests" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "eyre", "openvm", @@ -5166,12 +5881,13 @@ dependencies = [ "openvm-toolchain-tests", "openvm-transpiler", "serde", + "strum", "test-case", ] [[package]] name = "openvm-rv32im-transpiler" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "openvm-instructions", "openvm-instructions-derive", @@ -5186,10 +5902,9 @@ dependencies = [ [[package]] name = "openvm-sdk" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "alloy-sol-types", - "async-trait", "bitcode", "bon", "clap", @@ -5228,6 +5943,7 @@ dependencies = [ "openvm-stark-sdk", "openvm-transpiler", "p3-fri", + "rand 0.8.5", "rrs-lib", "serde", "serde_json", @@ -5236,12 +5952,13 @@ dependencies = [ "snark-verifier-sdk", "tempfile", "thiserror 1.0.69", + "toml 0.8.23", "tracing", ] [[package]] name = "openvm-sha2" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "eyre", "openvm-circuit", @@ -5253,31 +5970,30 @@ dependencies = [ "openvm-stark-sdk", "openvm-toolchain-tests", "openvm-transpiler", - "sha2", + "sha2 0.10.9", ] [[package]] name = "openvm-sha256-air" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "openvm-circuit", "openvm-circuit-primitives", "openvm-stark-backend", "openvm-stark-sdk", "rand 0.8.5", - "sha2", + "sha2 0.10.9", ] [[package]] name = "openvm-sha256-circuit" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", "openvm-circuit", "openvm-circuit-derive", "openvm-circuit-primitives", - "openvm-circuit-primitives-derive", "openvm-instructions", "openvm-rv32im-circuit", "openvm-sha256-air", @@ -5286,20 +6002,20 @@ dependencies = [ "openvm-stark-sdk", "rand 0.8.5", "serde", - "sha2", + "sha2 0.10.9", "strum", ] [[package]] name = "openvm-sha256-guest" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "openvm-platform", ] [[package]] name = "openvm-sha256-transpiler" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "openvm-instructions", "openvm-instructions-derive", @@ -5312,8 +6028,8 @@ dependencies = [ [[package]] name = "openvm-stark-backend" -version = "1.1.1" -source = "git+https://github.com/openvm-org/stark-backend.git?tag=v1.1.1#0879de162658b797b8dd6b6ee4429cbb8dd78ba1" +version = "1.2.0-rc.0" +source = "git+https://github.com/openvm-org/stark-backend.git?tag=v1.2.0-rc.1#2bf6fd20e3c77cabe01f830d06e6439ea101f98e" dependencies = [ "bitcode", "cfg-if", @@ -5340,11 +6056,12 @@ dependencies = [ [[package]] name = "openvm-stark-sdk" -version = "1.1.1" -source = "git+https://github.com/openvm-org/stark-backend.git?tag=v1.1.1#0879de162658b797b8dd6b6ee4429cbb8dd78ba1" +version = "1.2.0-rc.0" +source = "git+https://github.com/openvm-org/stark-backend.git?tag=v1.2.0-rc.1#2bf6fd20e3c77cabe01f830d06e6439ea101f98e" dependencies = [ + "dashmap", "derivative", - "derive_more 0.99.19", + "derive_more 0.99.20", "ff 0.13.1", "itertools 0.14.0", "metrics", @@ -5367,16 +6084,16 @@ dependencies = [ "serde", "serde_json", "static_assertions", - "toml 0.8.20", + "toml 0.8.23", "tracing", "tracing-forest", - "tracing-subscriber", + "tracing-subscriber 0.3.19", "zkhash", ] [[package]] name = "openvm-toolchain-tests" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "derive_more 1.0.0", "eyre", @@ -5394,6 +6111,7 @@ dependencies = [ "openvm-stark-backend", "openvm-stark-sdk", "openvm-transpiler", + "rand 0.8.5", "serde", "tempfile", "test-case", @@ -5401,7 +6119,7 @@ dependencies = [ [[package]] name = "openvm-transpiler" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "elf", "eyre", @@ -5415,7 +6133,7 @@ dependencies = [ [[package]] name = "openvm-verify-stark" -version = "1.3.0" +version = "1.4.0-rc.2" dependencies = [ "eyre", "openvm-circuit", @@ -5462,7 +6180,7 @@ checksum = "51f44edd08f51e2ade572f141051021c5af22677e42b7dd28a88155151c33594" dependencies = [ "ecdsa 0.14.8", "elliptic-curve 0.12.3", - "sha2", + "sha2 0.10.9", ] [[package]] @@ -5477,7 +6195,6 @@ dependencies = [ "hex-literal 0.4.1", "num-bigint 0.4.6", "openvm", - "openvm-algebra-circuit", "openvm-algebra-guest", "openvm-algebra-moduli-macros", "openvm-algebra-transpiler", @@ -5486,7 +6203,6 @@ dependencies = [ "openvm-ecc-guest", "openvm-ecc-sw-macros", "openvm-ecc-transpiler", - "openvm-rv32im-circuit", "openvm-rv32im-transpiler", "openvm-sha256-circuit", "openvm-sha256-transpiler", @@ -5494,9 +6210,22 @@ dependencies = [ "openvm-stark-sdk", "openvm-toolchain-tests", "openvm-transpiler", + "rand 0.8.5", "serde", ] +[[package]] +name = "p256" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9863ad85fa8f4460f9c48cb909d38a0d689dba1f6f6988a5e3e0d31071bcd4b" +dependencies = [ + "ecdsa 0.16.9", + "elliptic-curve 0.13.8", + "primeorder", + "sha2 0.10.9", +] + [[package]] name = "p3-air" version = "0.1.0" @@ -5858,9 +6587,9 @@ dependencies = [ [[package]] name = "parity-scale-codec" -version = "3.7.4" +version = "3.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9fde3d0718baf5bc92f577d652001da0f8d54cd03a7974e118d04fc888dc23d" +checksum = "799781ae679d79a948e13d4824a40970bfa500058d245760dd857301059810fa" dependencies = [ "arrayvec", "bitvec", @@ -5874,14 +6603,14 @@ dependencies = [ [[package]] name = "parity-scale-codec-derive" -version = "3.7.4" +version = "3.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "581c837bb6b9541ce7faa9377c20616e4fb7650f6b0f68bc93c827ee504fb7b3" +checksum = "34b4653168b563151153c9e4c08ebed57fb8262bebfa79711552fa983c623e7a" dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -5892,9 +6621,9 @@ checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" [[package]] name = "parking_lot" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13" dependencies = [ "lock_api", "parking_lot_core", @@ -5902,9 +6631,9 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.10" +version = "0.9.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" dependencies = [ "cfg-if", "libc", @@ -5975,7 +6704,7 @@ dependencies = [ "digest 0.10.7", "hmac", "password-hash", - "sha2", + "sha2 0.10.9", ] [[package]] @@ -5998,7 +6727,7 @@ dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -6018,12 +6747,12 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pest" -version = "2.7.15" +version = "2.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b7cafe60d6cf8e62e1b9b2ea516a089c008945bb5a275416789e7db0bc199dc" +checksum = "1db05f56d34358a8b1066f67cbb203ee3e7ed2ba674a6263a1d5ec6db2204323" dependencies = [ "memchr", - "thiserror 2.0.11", + "thiserror 2.0.12", "ucd-trie", ] @@ -6034,7 +6763,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ "fixedbitset", - "indexmap 2.7.1", + "indexmap 2.10.0", ] [[package]] @@ -6067,7 +6796,7 @@ dependencies = [ "phf_shared", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -6107,15 +6836,15 @@ version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" dependencies = [ - "der 0.7.9", + "der 0.7.10", "spki 0.7.3", ] [[package]] name = "pkg-config" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" [[package]] name = "plotters" @@ -6147,9 +6876,9 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.10.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" +checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" [[package]] name = "poseidon-primitives" @@ -6162,7 +6891,7 @@ dependencies = [ "lazy_static", "log", "rand 0.8.5", - "rand_xorshift", + "rand_xorshift 0.3.0", "thiserror 1.0.69", ] @@ -6194,7 +6923,7 @@ dependencies = [ "md-5", "memchr", "rand 0.9.1", - "sha2", + "sha2 0.10.9", "stringprep", ] @@ -6209,6 +6938,15 @@ dependencies = [ "postgres-protocol", ] +[[package]] +name = "potential_utf" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5a7c30837279ca13e7c867e9e40053bc68740f988cb07f7ca6df43cc734b585" +dependencies = [ + "zerovec", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -6217,9 +6955,9 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" -version = "0.2.20" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" dependencies = [ "zerocopy", ] @@ -6232,12 +6970,21 @@ checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" [[package]] name = "prettyplease" -version = "0.2.29" +version = "0.2.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac" +checksum = "061c1221631e079b26479d25bbf2275bfe5917ae8419cd7e34f13bfc2aa7539a" dependencies = [ "proc-macro2", - "syn 2.0.98", + "syn 2.0.104", +] + +[[package]] +name = "primeorder" +version = "0.13.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "353e1ca18966c16d9deb1c69278edbc5f194139612772bd9537af60ac231e1e6" +dependencies = [ + "elliptic-curve 0.13.8", ] [[package]] @@ -6256,11 +7003,11 @@ dependencies = [ [[package]] name = "proc-macro-crate" -version = "3.2.0" +version = "3.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b" +checksum = "edce586971a4dfaa28950c6f18ed55e0406c1ab88bbce2c6f6293a7aaba73d35" dependencies = [ - "toml_edit 0.22.24", + "toml_edit 0.22.27", ] [[package]] @@ -6282,14 +7029,14 @@ dependencies = [ "proc-macro-error-attr2", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "proc-macro2" -version = "1.0.93" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" dependencies = [ "unicode-ident", ] @@ -6302,25 +7049,25 @@ checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", "version_check", "yansi 1.0.1", ] [[package]] name = "proptest" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14cae93065090804185d3b75f0bf93b8eeda30c7a9b4a33d3bdb3988d6229e50" +checksum = "6fcdab19deb5195a31cf7726a210015ff1496ba1464fd42cb4f537b8b01b471f" dependencies = [ "bit-set 0.8.0", "bit-vec 0.8.0", - "bitflags 2.8.0", + "bitflags 2.9.1", "lazy_static", "num-traits", - "rand 0.8.5", - "rand_chacha 0.3.1", - "rand_xorshift", + "rand 0.9.1", + "rand_chacha 0.9.0", + "rand_xorshift 0.4.0", "regex-syntax 0.8.5", "rusty-fork", "tempfile", @@ -6329,9 +7076,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.25.0" +version = "0.25.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f239d656363bcee73afef85277f1b281e8ac6212a1d42aa90e55b90ed43c47a4" +checksum = "8970a78afe0628a3e3430376fc5fd76b6b45c4d43360ffd6cdd40bdde72b682a" dependencies = [ "libc", "memoffset", @@ -6343,9 +7090,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.25.0" +version = "0.25.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "755ea671a1c34044fa165247aaf6f419ca39caa6003aee791a0df2713d8f1b6d" +checksum = "458eb0c55e7ece017adeba38f2248ff3ac615e53660d7c71a238d7d2a01c7598" dependencies = [ "once_cell", "target-lexicon 0.13.2", @@ -6353,9 +7100,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.25.0" +version = "0.25.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc95a2e67091e44791d4ea300ff744be5293f394f1bafd9f78c080814d35956e" +checksum = "7114fe5457c61b276ab77c5055f206295b812608083644a5c5b2640c3102565c" dependencies = [ "libc", "pyo3-build-config", @@ -6363,15 +7110,15 @@ dependencies = [ [[package]] name = "quanta" -version = "0.12.5" +version = "0.12.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bd1fe6824cea6538803de3ff1bc0cf3949024db3d43c9643024bfb33a807c0e" +checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7" dependencies = [ "crossbeam-utils", "libc", "once_cell", "raw-cpuid", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi 0.11.1+wasi-snapshot-preview1", "web-sys", "winapi", ] @@ -6393,13 +7140,19 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.38" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + [[package]] name = "radium" version = "0.7.0" @@ -6436,6 +7189,7 @@ checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" dependencies = [ "rand_chacha 0.9.0", "rand_core 0.9.3", + "serde", ] [[package]] @@ -6464,7 +7218,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", ] [[package]] @@ -6473,7 +7227,8 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ - "getrandom 0.3.1", + "getrandom 0.3.3", + "serde", ] [[package]] @@ -6485,13 +7240,22 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "rand_xorshift" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "513962919efc330f829edb2535844d1b912b0fbe2ca165d613e4e8788bb05a5a" +dependencies = [ + "rand_core 0.9.3", +] + [[package]] name = "raw-cpuid" -version = "11.4.0" +version = "11.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "529468c1335c1c03919960dfefdb1b3648858c20d7ec2d0663e728e4a717efbc" +checksum = "c6df7ab838ed27997ba19a4664507e6f82b41fe6e20be42929332156e5e85146" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", ] [[package]] @@ -6516,11 +7280,11 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.11" +version = "0.5.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2f103c6d277498fbceb16e84d317e2a400f160f46904d5f5410848c829511a3" +checksum = "0d04b7d0ee6b4a0207a0a7adb104d23ecb0b47d6beae7152d0fa34b692b29fd6" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", ] [[package]] @@ -6529,11 +7293,31 @@ version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", "libredox", "thiserror 1.0.69", ] +[[package]] +name = "ref-cast" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a0ae411dbe946a674d89546582cea4ba2bb8defac896622d6496f14c23ba5cf" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", +] + [[package]] name = "regex" version = "1.11.1" @@ -6595,11 +7379,11 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "h2", + "h2 0.3.27", "http 0.2.12", "http-body 0.4.6", - "hyper", - "hyper-rustls", + "hyper 0.14.32", + "hyper-rustls 0.24.2", "ipnet", "js-sys", "log", @@ -6607,7 +7391,7 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls", + "rustls 0.21.12", "rustls-pemfile", "serde", "serde_json", @@ -6615,7 +7399,7 @@ dependencies = [ "sync_wrapper", "system-configuration", "tokio", - "tokio-rustls", + "tokio-rustls 0.24.1", "tower-service", "url", "wasm-bindgen", @@ -6627,46 +7411,164 @@ dependencies = [ [[package]] name = "revm" -version = "18.0.0" +version = "24.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15689a3c6a8d14b647b4666f2e236ef47b5a5133cdfd423f545947986fff7013" +checksum = "01d277408ff8d6f747665ad9e52150ab4caf8d5eaf0d787614cf84633c8337b4" +dependencies = [ + "revm-bytecode", + "revm-context", + "revm-context-interface", + "revm-database", + "revm-database-interface", + "revm-handler", + "revm-inspector", + "revm-interpreter", + "revm-precompile", + "revm-primitives 19.2.0", + "revm-state", +] + +[[package]] +name = "revm-bytecode" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "942fe4724cf552fd28db6b0a2ca5b79e884d40dd8288a4027ed1e9090e0c6f49" +dependencies = [ + "bitvec", + "once_cell", + "phf", + "revm-primitives 19.2.0", + "serde", +] + +[[package]] +name = "revm-context" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b01aad49e1233f94cebda48a4e5cef022f7c7ed29b4edf0d202b081af23435ef" dependencies = [ - "auto_impl", "cfg-if", - "dyn-clone", + "derive-where", + "revm-bytecode", + "revm-context-interface", + "revm-database-interface", + "revm-primitives 19.2.0", + "revm-state", + "serde", +] + +[[package]] +name = "revm-context-interface" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b844f48a411e62c7dde0f757bf5cce49c85b86d6fc1d3b2722c07f2bec4c3ce" +dependencies = [ + "alloy-eip2930", + "alloy-eip7702", + "auto_impl", + "either", + "revm-database-interface", + "revm-primitives 19.2.0", + "revm-state", + "serde", +] + +[[package]] +name = "revm-database" +version = "4.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad3fbe34f6bb00a9c3155723b3718b9cb9f17066ba38f9eb101b678cd3626775" +dependencies = [ + "alloy-eips", + "revm-bytecode", + "revm-database-interface", + "revm-primitives 19.2.0", + "revm-state", + "serde", +] + +[[package]] +name = "revm-database-interface" +version = "4.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b8acd36784a6d95d5b9e1b7be3ce014f1e759abb59df1fa08396b30f71adc2a" +dependencies = [ + "auto_impl", + "revm-primitives 19.2.0", + "revm-state", + "serde", +] + +[[package]] +name = "revm-handler" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "481e8c3290ff4fa1c066592fdfeb2b172edfd14d12e6cade6f6f5588cad9359a" +dependencies = [ + "auto_impl", + "revm-bytecode", + "revm-context", + "revm-context-interface", + "revm-database-interface", "revm-interpreter", "revm-precompile", + "revm-primitives 19.2.0", + "revm-state", + "serde", +] + +[[package]] +name = "revm-inspector" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc1167ef8937d8867888e63581d8ece729a72073d322119ef4627d813d99ecb" +dependencies = [ + "auto_impl", + "revm-context", + "revm-database-interface", + "revm-handler", + "revm-interpreter", + "revm-primitives 19.2.0", + "revm-state", "serde", "serde_json", ] [[package]] name = "revm-interpreter" -version = "14.0.0" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74e3f11d0fed049a4a10f79820c59113a79b38aed4ebec786a79d5c667bfeb51" +checksum = "b5ee65e57375c6639b0f50555e92a4f1b2434349dd32f52e2176f5c711171697" dependencies = [ - "revm-primitives 14.0.0", + "revm-bytecode", + "revm-context-interface", + "revm-primitives 19.2.0", "serde", ] [[package]] name = "revm-precompile" -version = "15.0.0" +version = "21.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e381060af24b750069a2b2d2c54bba273d84e8f5f9e8026fc9262298e26cc336" +checksum = "0f9311e735123d8d53a02af2aa81877bba185be7c141be7f931bb3d2f3af449c" dependencies = [ + "ark-bls12-381", + "ark-bn254 0.5.0", + "ark-ec 0.5.0", + "ark-ff 0.5.0", + "ark-serialize 0.5.0", "aurora-engine-modexp", "blst", "c-kzg", "cfg-if", "k256 0.13.4 (registry+https://github.com/rust-lang/crates.io-index)", + "libsecp256k1", "once_cell", - "revm-primitives 14.0.0", + "p256 0.13.2 (registry+https://github.com/rust-lang/crates.io-index)", + "revm-primitives 19.2.0", "ripemd", "secp256k1", - "sha2", - "substrate-bn", + "sha2 0.10.9", ] [[package]] @@ -6678,7 +7580,7 @@ dependencies = [ "alloy-primitives 0.4.2", "alloy-rlp", "auto_impl", - "bitflags 2.8.0", + "bitflags 2.9.1", "bitvec", "enumn", "hashbrown 0.14.5", @@ -6687,21 +7589,24 @@ dependencies = [ [[package]] name = "revm-primitives" -version = "14.0.0" +version = "19.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3702f132bb484f4f0d0ca4f6fbde3c82cfd745041abbedd6eda67730e1868ef0" +checksum = "1c1588093530ec4442461163be49c433c07a3235d1ca6f6799fef338dacc50d3" dependencies = [ - "alloy-eip2930", - "alloy-eip7702", - "alloy-primitives 0.8.25", - "auto_impl", - "bitflags 2.8.0", - "bitvec", - "c-kzg", - "cfg-if", - "dyn-clone", - "enumn", - "hex", + "alloy-primitives 1.2.1", + "num_enum", + "serde", +] + +[[package]] +name = "revm-state" +version = "4.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0040c61c30319254b34507383ba33d85f92949933adf6525a2cede05d165e1fa" +dependencies = [ + "bitflags 2.9.1", + "revm-bytecode", + "revm-primitives 19.2.0", "serde", ] @@ -6728,13 +7633,13 @@ dependencies = [ [[package]] name = "ring" -version = "0.17.13" +version = "0.17.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70ac5d832aa16abd7d1def883a8545280c20a60f523a370aa3a9617c2b8550ee" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" dependencies = [ "cc", "cfg-if", - "getrandom 0.2.15", + "getrandom 0.2.16", "libc", "untrusted", "windows-sys 0.52.0", @@ -6793,30 +7698,6 @@ dependencies = [ "paste", ] -[[package]] -name = "ruint" -version = "1.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c3cc4c2511671f327125da14133d0c5c5d137f006a1017a16f557bc85b16286" -dependencies = [ - "alloy-rlp", - "ark-ff 0.3.0", - "ark-ff 0.4.2", - "bytes", - "fastrlp 0.3.1", - "num-bigint 0.4.6", - "num-traits", - "parity-scale-codec", - "primitive-types", - "proptest", - "rand 0.8.5", - "rlp", - "ruint-macro 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)", - "serde", - "valuable", - "zeroize", -] - [[package]] name = "ruint" version = "1.14.0" @@ -6834,7 +7715,7 @@ dependencies = [ "bytemuck", "bytes", "criterion", - "der 0.7.9", + "der 0.7.10", "diesel", "ethereum_ssz", "eyre", @@ -6859,18 +7740,45 @@ dependencies = [ "postgres-types", "primitive-types", "proptest", - "pyo3", - "quickcheck", + "pyo3", + "quickcheck", + "rand 0.8.5", + "rand 0.9.1", + "rlp", + "ruint 1.14.0", + "ruint-macro 1.2.1", + "serde", + "serde_json", + "sqlx-core", + "subtle", + "thiserror 2.0.12", + "valuable", + "zeroize", +] + +[[package]] +name = "ruint" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11256b5fe8c68f56ac6f39ef0720e592f33d2367a4782740d9c9142e889c7fb4" +dependencies = [ + "alloy-rlp", + "ark-ff 0.3.0", + "ark-ff 0.4.2", + "bytes", + "fastrlp 0.3.1", + "fastrlp 0.4.0", + "num-bigint 0.4.6", + "num-integer", + "num-traits", + "parity-scale-codec", + "primitive-types", + "proptest", "rand 0.8.5", "rand 0.9.1", "rlp", - "ruint 1.14.0", - "ruint-macro 1.2.1", + "ruint-macro 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)", "serde", - "serde_json", - "sqlx-core", - "subtle", - "thiserror 2.0.11", "valuable", "zeroize", ] @@ -6890,9 +7798,9 @@ checksum = "48fd7bd8a6377e15ad9d42a8ec25371b94ddc67abe7c8b9127bec79bebaaae18" [[package]] name = "rustc-demangle" -version = "0.1.24" +version = "0.1.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +checksum = "989e6739f80c4ad5b13e0fd7fe89531180375b18520cc8c82080e4dc4035b84f" [[package]] name = "rustc-hash" @@ -6927,7 +7835,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" dependencies = [ - "semver 1.0.25", + "semver 1.0.26", ] [[package]] @@ -6936,11 +7844,24 @@ version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "errno", "libc", - "linux-raw-sys", - "windows-sys 0.59.0", + "linux-raw-sys 0.4.15", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustix" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" +dependencies = [ + "bitflags 2.9.1", + "errno", + "libc", + "linux-raw-sys 0.9.4", + "windows-sys 0.52.0", ] [[package]] @@ -6951,10 +7872,24 @@ checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" dependencies = [ "log", "ring", - "rustls-webpki", + "rustls-webpki 0.101.7", "sct", ] +[[package]] +name = "rustls" +version = "0.23.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2491382039b29b9b11ff08b76ff6c97cf287671dbb74f0be44bda389fffe9bd1" +dependencies = [ + "aws-lc-rs", + "once_cell", + "rustls-pki-types", + "rustls-webpki 0.103.4", + "subtle", + "zeroize", +] + [[package]] name = "rustls-native-certs" version = "0.6.3" @@ -6964,7 +7899,19 @@ dependencies = [ "openssl-probe", "rustls-pemfile", "schannel", - "security-framework", + "security-framework 2.11.1", +] + +[[package]] +name = "rustls-native-certs" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" +dependencies = [ + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework 3.2.0", ] [[package]] @@ -6976,6 +7923,15 @@ dependencies = [ "base64 0.21.7", ] +[[package]] +name = "rustls-pki-types" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" +dependencies = [ + "zeroize", +] + [[package]] name = "rustls-webpki" version = "0.101.7" @@ -6986,11 +7942,23 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustls-webpki" +version = "0.103.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a17884ae0c1b773f1ccd2bd4a8c72f16da897310a98b0e84bf349ad5ead92fc" +dependencies = [ + "aws-lc-rs", + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" -version = "1.0.19" +version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" +checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d" [[package]] name = "rusty-fork" @@ -7006,9 +7974,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.19" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" [[package]] name = "same-file" @@ -7040,7 +8008,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -7052,6 +8020,30 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "schemars" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd191f9397d57d581cddd31014772520aa448f65ef991055d7f61582c65165f" +dependencies = [ + "dyn-clone", + "ref-cast", + "serde", + "serde_json", +] + +[[package]] +name = "schemars" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82d20c4491bc164fa2f6c5d44565947a52ad80b9505d8e36f8d54c27c739fcd0" +dependencies = [ + "dyn-clone", + "ref-cast", + "serde", + "serde_json", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -7089,7 +8081,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3e97a565f76233a6003f9f5c54be1d9c5bdfa3eccfb189469f11ec4901c47dc" dependencies = [ "base16ct 0.2.0", - "der 0.7.9", + "der 0.7.10", "generic-array", "pkcs8 0.10.2", "serdect", @@ -7099,10 +8091,11 @@ dependencies = [ [[package]] name = "secp256k1" -version = "0.29.1" +version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9465315bc9d4566e1724f0fffcbcc446268cb522e60f9a27bcded6b19c108113" +checksum = "b50c5943d326858130af85e049f2661ba3c78b26589b8ab98e65e80ae44a1252" dependencies = [ + "bitcoin_hashes", "rand 0.8.5", "secp256k1-sys", ] @@ -7122,8 +8115,21 @@ version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ - "bitflags 2.8.0", - "core-foundation", + "bitflags 2.9.1", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" +dependencies = [ + "bitflags 2.9.1", + "core-foundation 0.10.1", "core-foundation-sys", "libc", "security-framework-sys", @@ -7150,9 +8156,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.25" +version = "1.0.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03" +checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" dependencies = [ "serde", ] @@ -7168,9 +8174,9 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.218" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60" +checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" dependencies = [ "serde_derive", ] @@ -7195,22 +8201,22 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.218" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b" +checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "serde_json" -version = "1.0.139" +version = "1.0.140" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44f86c3acccc9c65b153fe1b85a3be07fe5515274ec9f0653b4a0875731c72a6" +checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" dependencies = [ - "indexmap 2.7.1", + "indexmap 2.10.0", "itoa", "memchr", "ryu", @@ -7229,9 +8235,9 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "0.6.8" +version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1" +checksum = "bf41e0cfaf7226dca15e8197172c295a782857fcb97fad1808a166870dee75a3" dependencies = [ "serde", ] @@ -7250,15 +8256,17 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.12.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6b6f7f2fcb69f747921f79f3926bd1e203fce4fef62c268dd3abfb6d86029aa" +checksum = "f2c45cd61fefa9db6f254525d46e392b852e0e61d9a1fd36e5bd183450a556d5" dependencies = [ "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.7.1", + "indexmap 2.10.0", + "schemars 0.9.0", + "schemars 1.0.4", "serde", "serde_derive", "serde_json", @@ -7268,14 +8276,14 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.12.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d00caa5193a3c8362ac2b73be6b9e768aa5a4b2f721d8f4b339600c3cb51f8e" +checksum = "de90945e6565ce0d9a25098082ed4ee4002e047cb59892c318d66821e14bb30f" dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -7301,9 +8309,22 @@ dependencies = [ [[package]] name = "sha2" -version = "0.10.8" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d58a1e1bf39749807d89cf2d98ac2dfa0ff1cb3faa38fbb64dd88ac8013d800" +dependencies = [ + "block-buffer 0.9.0", + "cfg-if", + "cpufeatures", + "digest 0.9.0", + "opaque-debug", +] + +[[package]] +name = "sha2" +version = "0.10.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" dependencies = [ "cfg-if", "cpufeatures", @@ -7347,9 +8368,9 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "signal-hook-registry" -version = "1.4.2" +version = "1.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +checksum = "9203b8055f63a2a00e2f593bb0510367fe707d7ff1e5c872de2f537b339e5410" dependencies = [ "libc", ] @@ -7388,24 +8409,21 @@ checksum = "85636c14b73d81f541e525f585c0a2109e6744e1565b5c1668e31c70c10ed65c" [[package]] name = "slab" -version = "0.4.9" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" -dependencies = [ - "autocfg", -] +checksum = "04dc19736151f35336d325007ac991178d504a119863a2fcb3758cdb5e52c50d" [[package]] name = "smallvec" -version = "1.14.0" +version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" [[package]] name = "snark-verifier" -version = "0.2.0" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28e4c4ed1edca41687fe2d8a09ba30badb0a5cc7fa56dd1159d62aeab7c99ace" +checksum = "c9203c416ff9de0762667270b21573ba5e6edaeda08743b3ca37dc8a5e0a4480" dependencies = [ "halo2-base", "halo2-ecc", @@ -7418,16 +8436,16 @@ dependencies = [ "pairing 0.23.0", "rand 0.8.5", "revm", - "ruint 1.12.3", + "ruint 1.15.0", "serde", "sha3", ] [[package]] name = "snark-verifier-sdk" -version = "0.2.0" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "babff70ce6292fce03f692d68569f76b8f6710dbac7be7fe5f32c915909c9065" +checksum = "290ae6e750d9d5fdf05393bbcae6bf7a63e3408eab023abf7d466156a234ac85" dependencies = [ "bincode", "ethereum-types", @@ -7448,9 +8466,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.8" +version = "0.5.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" +checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" dependencies = [ "libc", "windows-sys 0.52.0", @@ -7493,7 +8511,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" dependencies = [ "base64ct", - "der 0.7.9", + "der 0.7.10", ] [[package]] @@ -7511,15 +8529,15 @@ dependencies = [ "futures-intrusive", "futures-io", "futures-util", - "hashbrown 0.15.2", + "hashbrown 0.15.4", "hashlink", - "indexmap 2.7.1", + "indexmap 2.10.0", "log", "memchr", "once_cell", "percent-encoding", "smallvec", - "thiserror 2.0.11", + "thiserror 2.0.12", "tracing", "url", ] @@ -7536,6 +8554,16 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "statrs" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a3fe7c28c6512e766b0874335db33c94ad7b8f9054228ae1c2abd47ce7d335e" +dependencies = [ + "approx", + "num-traits", +] + [[package]] name = "strength_reduce" version = "0.2.4" @@ -7590,20 +8618,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.98", -] - -[[package]] -name = "substrate-bn" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b5bbfa79abbae15dd642ea8176a21a635ff3c00059961d1ea27ad04e5b441c" -dependencies = [ - "byteorder", - "crunchy", - "lazy_static", - "rand 0.8.5", - "rustc-hex", + "syn 2.0.104", ] [[package]] @@ -7636,10 +8651,10 @@ dependencies = [ "hex", "once_cell", "reqwest", - "semver 1.0.25", + "semver 1.0.26", "serde", "serde_json", - "sha2", + "sha2 0.10.9", "thiserror 1.0.69", "url", "zip", @@ -7653,7 +8668,7 @@ checksum = "aa64b5e8eecd3a8af7cfc311e29db31a268a62d5953233d3e8243ec77a71c4e3" dependencies = [ "build_const", "hex", - "semver 1.0.25", + "semver 1.0.26", "serde_json", "svm-rs", ] @@ -7671,9 +8686,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.98" +version = "2.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36147f1a48ae0ec2b5b3bc5b537d267457555a10dc06f3dbc8cb11ba3006d3b1" +checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" dependencies = [ "proc-macro2", "quote", @@ -7689,7 +8704,7 @@ dependencies = [ "paste", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -7700,13 +8715,13 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" [[package]] name = "synstructure" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -7716,7 +8731,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ "bitflags 1.3.2", - "core-foundation", + "core-foundation 0.9.4", "system-configuration-sys", ] @@ -7750,16 +8765,15 @@ checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" [[package]] name = "tempfile" -version = "3.17.1" +version = "3.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22e5a0acb1f3f55f65cc4a866c361b2fb2a0ff6366785ae6fbb5f85df07ba230" +checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" dependencies = [ - "cfg-if", "fastrand", - "getrandom 0.3.1", + "getrandom 0.3.3", "once_cell", - "rustix", - "windows-sys 0.59.0", + "rustix 1.0.7", + "windows-sys 0.52.0", ] [[package]] @@ -7773,6 +8787,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "terminal_size" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45c6481c4829e4cc63825e62c49186a34538b7b2750b73b266581ffb612fb5ed" +dependencies = [ + "rustix 1.0.7", + "windows-sys 0.59.0", +] + [[package]] name = "test-case" version = "3.3.1" @@ -7791,7 +8815,7 @@ dependencies = [ "cfg-if", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -7802,30 +8826,30 @@ checksum = "5c89e72a01ed4c579669add59014b9a524d609c0c88c6a585ce37485879f6ffb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", "test-case-core", ] [[package]] name = "test-log" -version = "0.2.17" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7f46083d221181166e5b6f6b1e5f1d499f3a76888826e6cb1d057554157cd0f" +checksum = "1e33b98a582ea0be1168eba097538ee8dd4bbe0f2b01b22ac92ea30054e5be7b" dependencies = [ "env_logger", "test-log-macros", - "tracing-subscriber", + "tracing-subscriber 0.3.19", ] [[package]] name = "test-log-macros" -version = "0.2.17" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "888d0c3c6db53c0fdab160d2ed5e12ba745383d3e85813f2ea0f2b1475ab553f" +checksum = "451b374529930d7601b1eef8d32bc79ae870b6079b069401709c2a8bf9e75f36" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -7839,11 +8863,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.11" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" dependencies = [ - "thiserror-impl 2.0.11", + "thiserror-impl 2.0.12", ] [[package]] @@ -7854,28 +8878,27 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "thiserror-impl" -version = "2.0.11" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "thread_local" -version = "1.1.8" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" dependencies = [ "cfg-if", - "once_cell", ] [[package]] @@ -7909,9 +8932,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.37" +version = "0.3.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21" +checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40" dependencies = [ "deranged", "itoa", @@ -7926,15 +8949,15 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.2" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" +checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" [[package]] name = "time-macros" -version = "0.2.19" +version = "0.2.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2834e6017e3e5e4b9834939793b282bc03b37a3336245fa820e35e233e2a85de" +checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49" dependencies = [ "num-conv", "time-core", @@ -7951,9 +8974,9 @@ dependencies = [ [[package]] name = "tinystr" -version = "0.7.6" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +checksum = "5d4f6d1145dcb577acf783d4e601bc1d76a13337bb54e6233add580b07344c8b" dependencies = [ "displaydoc", "zerovec", @@ -7986,16 +9009,18 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.44.2" +version = "1.46.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6b88822cbe49de4185e3a4cbf8321dd487cf5fe0c5c65695fef6346371e9c48" +checksum = "0cc3a2344dafbe23a245241fe8b09735b521110d30fcefbbd5feb1797ca35d17" dependencies = [ "backtrace", "bytes", + "io-uring", "libc", "mio", "pin-project-lite", "signal-hook-registry", + "slab", "socket2", "tokio-macros", "windows-sys 0.52.0", @@ -8009,7 +9034,7 @@ checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -8044,15 +9069,25 @@ version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" dependencies = [ - "rustls", + "rustls 0.21.12", + "tokio", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b" +dependencies = [ + "rustls 0.23.29", "tokio", ] [[package]] name = "tokio-util" -version = "0.7.13" +version = "0.7.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" +checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df" dependencies = [ "bytes", "futures-core", @@ -8067,7 +9102,7 @@ version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd79e69d3b627db300ff956027cc6c3798cef26d22526befdfcd12feeb6d2257" dependencies = [ - "indexmap 2.7.1", + "indexmap 2.10.0", "serde", "serde_spanned", "toml_datetime", @@ -8076,21 +9111,21 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.20" +version = "0.8.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd87a5cdd6ffab733b2f74bc4fd7ee5fff6634124999ac278c35fc78c6120148" +checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" dependencies = [ "serde", "serde_spanned", "toml_datetime", - "toml_edit 0.22.24", + "toml_edit 0.22.27", ] [[package]] name = "toml_datetime" -version = "0.6.8" +version = "0.6.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" +checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c" dependencies = [ "serde", ] @@ -8101,7 +9136,7 @@ version = "0.19.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" dependencies = [ - "indexmap 2.7.1", + "indexmap 2.10.0", "serde", "serde_spanned", "toml_datetime", @@ -8110,17 +9145,40 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.22.24" +version = "0.22.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474" +checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" dependencies = [ - "indexmap 2.7.1", + "indexmap 2.10.0", "serde", "serde_spanned", "toml_datetime", - "winnow 0.7.3", + "toml_write", + "winnow 0.7.12", +] + +[[package]] +name = "toml_write" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" + +[[package]] +name = "tower" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +dependencies = [ + "tower-layer", + "tower-service", ] +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + [[package]] name = "tower-service" version = "0.3.3" @@ -8141,20 +9199,20 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.28" +version = "0.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" +checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "tracing-core" -version = "0.1.33" +version = "0.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" +checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" dependencies = [ "once_cell", "valuable", @@ -8170,7 +9228,7 @@ dependencies = [ "smallvec", "thiserror 1.0.69", "tracing", - "tracing-subscriber", + "tracing-subscriber 0.3.19", ] [[package]] @@ -8184,6 +9242,15 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "tracing-subscriber" +version = "0.2.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e0d2eaa99c3c2e41547cfa109e910a68ea03823cccad4a0525dcbc9b01e8c71" +dependencies = [ + "tracing-core", +] + [[package]] name = "tracing-subscriber" version = "0.3.19" @@ -8265,9 +9332,9 @@ checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" [[package]] name = "unicode-ident" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00e2473a93778eb0bad35909dff6a10d28e63f792f16ed15e404fca9d5eeedbe" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" [[package]] name = "unicode-normalization" @@ -8329,12 +9396,6 @@ version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" -[[package]] -name = "utf16_iter" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" - [[package]] name = "utf8_iter" version = "1.0.4" @@ -8349,9 +9410,14 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.13.2" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c1f41ffb7cf259f1ecc2876861a17e7142e63ead296f671f81f6ae85903e0d6" +checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d" +dependencies = [ + "getrandom 0.3.3", + "js-sys", + "wasm-bindgen", +] [[package]] name = "valuable" @@ -8420,15 +9486,15 @@ dependencies = [ [[package]] name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" +version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasi" -version = "0.13.3+wasi-0.2.2" +version = "0.14.2+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" dependencies = [ "wit-bindgen-rt", ] @@ -8461,7 +9527,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", "wasm-bindgen-shared", ] @@ -8496,7 +9562,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -8526,6 +9592,18 @@ version = "0.25.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix 0.38.44", +] + [[package]] name = "whoami" version = "1.6.0" @@ -8559,7 +9637,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.48.0", ] [[package]] @@ -8570,11 +9648,61 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows-core" -version = "0.52.0" +version = "0.61.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" dependencies = [ - "windows-targets 0.52.6", + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", +] + +[[package]] +name = "windows-interface" +version = "0.59.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", +] + +[[package]] +name = "windows-link" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" + +[[package]] +name = "windows-result" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" +dependencies = [ + "windows-link", ] [[package]] @@ -8736,9 +9864,9 @@ dependencies = [ [[package]] name = "winnow" -version = "0.7.3" +version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e7f4ea97f6f78012141bcdb6a216b2609f0979ada50b20ca5b52dde2eac2bb1" +checksum = "f3edebf492c8125044983378ecb5766203ad3b4c2f7a922bd7dd207f6d443e95" dependencies = [ "memchr", ] @@ -8755,24 +9883,18 @@ dependencies = [ [[package]] name = "wit-bindgen-rt" -version = "0.33.0" +version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", ] -[[package]] -name = "write16" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" - [[package]] name = "writeable" -version = "0.5.5" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" +checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" [[package]] name = "wyz" @@ -8803,9 +9925,9 @@ checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" [[package]] name = "yoke" -version = "0.7.5" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" +checksum = "5f41bb01b8226ef4bfd589436a297c53d118f65921786300e427be8d487695cc" dependencies = [ "serde", "stable_deref_trait", @@ -8815,55 +9937,54 @@ dependencies = [ [[package]] name = "yoke-derive" -version = "0.7.5" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" +checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", "synstructure", ] [[package]] name = "zerocopy" -version = "0.7.35" +version = "0.8.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f" dependencies = [ - "byteorder", "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.35" +version = "0.8.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "zerofrom" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" dependencies = [ "zerofrom-derive", ] [[package]] name = "zerofrom-derive" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", "synstructure", ] @@ -8884,14 +10005,25 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", +] + +[[package]] +name = "zerotrie" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36f0bbd478583f79edad978b407914f61b2972f5af6fa089686016be8f9af595" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", ] [[package]] name = "zerovec" -version = "0.10.4" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +checksum = "4a05eb080e015ba39cc9e23bbe5e7fb04d5fb040350f99f34e338d5fdd294428" dependencies = [ "yoke", "zerofrom", @@ -8900,13 +10032,13 @@ dependencies = [ [[package]] name = "zerovec-derive" -version = "0.10.3" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -8950,7 +10082,7 @@ dependencies = [ "pasta_curves 0.5.1", "rand 0.8.5", "serde", - "sha2", + "sha2 0.10.9", "sha3", "subtle", ] diff --git a/Cargo.toml b/Cargo.toml index 2734767aff..974cf25a74 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [workspace.package] -version = "1.3.0" +version = "1.4.0-rc.2" edition = "2021" -rust-version = "1.82" +rust-version = "1.86.0" authors = ["OpenVM Authors"] homepage = "https://openvm.dev" repository = "https://github.com/openvm-org/" @@ -58,6 +58,7 @@ members = [ "extensions/ecc/transpiler", "extensions/ecc/guest", "extensions/ecc/sw-macros", + "extensions/ecc/te-macros", "extensions/ecc/tests", "extensions/pairing/circuit", "extensions/pairing/guest", @@ -70,7 +71,7 @@ members = [ "guest-libs/sha2/", "guest-libs/verify_stark/", ] -exclude = ["crates/sdk/example"] +exclude = ["crates/sdk/example", "benchmarks/guest/**"] resolver = "2" # Fastest runtime configuration @@ -85,6 +86,7 @@ codegen-units = 16 [profile.profiling] inherits = "release" debug = 2 +debug-assertions = false strip = false # Make sure debug symbols are in the bench profile for flamegraphs @@ -99,6 +101,7 @@ codegen-units = 1 [profile.dev] opt-level = 1 +debug = 2 # For O1 optimization but still fast(ish) compile times [profile.fast] @@ -110,8 +113,8 @@ lto = "thin" [workspace.dependencies] # Stark Backend -openvm-stark-backend = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.1.1", default-features = false } -openvm-stark-sdk = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.1.1", default-features = false } +openvm-stark-backend = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.2.0-rc.1", default-features = false } +openvm-stark-sdk = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.2.0-rc.1", default-features = false } # OpenVM openvm-sdk = { path = "crates/sdk", default-features = false } @@ -162,6 +165,7 @@ openvm-ecc-circuit = { path = "extensions/ecc/circuit", default-features = false openvm-ecc-transpiler = { path = "extensions/ecc/transpiler", default-features = false } openvm-ecc-guest = { path = "extensions/ecc/guest", default-features = false } openvm-ecc-sw-macros = { path = "extensions/ecc/sw-macros", default-features = false } +openvm-ecc-te-macros = { path = "extensions/ecc/te-macros", default-features = false } openvm-pairing-circuit = { path = "extensions/pairing/circuit", default-features = false } openvm-pairing-transpiler = { path = "extensions/pairing/transpiler", default-features = false } openvm-pairing-guest = { path = "extensions/pairing/guest", default-features = false } @@ -171,18 +175,16 @@ openvm-verify-stark = { path = "guest-libs/verify_stark", default-features = fal openvm-benchmarks-utils = { path = "benchmarks/utils", default-features = false } # Plonky3 -p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } -p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", features = [ - "nightly-features", -], rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } -p3-dft = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } -p3-fri = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } -p3-keccak-air = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } -p3-merkle-tree = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } -p3-monty-31 = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } -p3-poseidon2 = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } -p3-poseidon2-air = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } -p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } +p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } +p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } +p3-dft = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } +p3-fri = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } +p3-keccak-air = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } +p3-merkle-tree = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } +p3-monty-31 = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } +p3-poseidon2 = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } +p3-poseidon2-air = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } +p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } zkhash = { git = "https://github.com/HorizenLabs/poseidon2.git", rev = "bb476b9" } snark-verifier-sdk = { version = "0.2.0", default-features = false, features = [ @@ -220,12 +222,14 @@ tempfile = "3.13.0" thiserror = "1.0.65" rustc-hash = "2.0.0" static_assertions = "1.1.0" -async-trait = "0.1.83" getset = "0.1.3" rrs-lib = "0.1.0" rand = { version = "0.8.5", default-features = false } hex = { version = "0.4.3", default-features = false } serde-big-array = "0.5.1" +dashmap = "6.1.0" +memmap2 = "0.9.5" +tracing-subscriber = { version = "0.3.17", features = ["std", "env-filter"] } # default-features = false for no_std for use in guest programs itertools = { version = "0.14.0", default-features = false } @@ -258,3 +262,6 @@ sha2 = { version = "0.10", default-features = false } # p3-poseidon2 = { path = "../Plonky3/poseidon2" } # p3-poseidon2-air = { path = "../Plonky3/poseidon2-air" } # p3-symmetric = { path = "../Plonky3/symmetric" } + +[workspace.metadata.cargo-shear] +ignored = ["cargo-openvm"] diff --git a/benchmarks/execute/Cargo.toml b/benchmarks/execute/Cargo.toml index 319490220a..76ca243c2d 100644 --- a/benchmarks/execute/Cargo.toml +++ b/benchmarks/execute/Cargo.toml @@ -9,41 +9,66 @@ license.workspace = true [dependencies] openvm-benchmarks-utils.workspace = true -cargo-openvm.workspace = true openvm-circuit.workspace = true -openvm-sdk.workspace = true openvm-stark-sdk.workspace = true openvm-transpiler.workspace = true -openvm-rv32im-circuit.workspace = true -openvm-rv32im-transpiler.workspace = true +openvm-algebra-circuit.workspace = true +openvm-algebra-transpiler.workspace = true +openvm-bigint-circuit.workspace = true +openvm-bigint-transpiler.workspace = true +openvm-ecc-circuit.workspace = true +openvm-ecc-transpiler.workspace = true +openvm-native-circuit = { workspace = true } +openvm-pairing-circuit.workspace = true +openvm-pairing-guest.workspace = true +openvm-pairing-transpiler.workspace = true openvm-keccak256-circuit.workspace = true openvm-keccak256-transpiler.workspace = true +openvm-rv32im-circuit.workspace = true +openvm-rv32im-transpiler.workspace = true +openvm-sha256-circuit.workspace = true +openvm-sha256-transpiler.workspace = true +openvm-continuations = { workspace = true } +openvm-sdk = { workspace = true } -clap = { version = "4.5.9", features = ["derive", "env"] } +clap.workspace = true eyre.workspace = true -tracing.workspace = true derive_more = { workspace = true, features = ["from"] } - -tracing-subscriber = { version = "0.3.17", features = ["std", "env-filter"] } +rand.workspace = true +serde = { workspace = true, features = ["derive"] } +bitcode.workspace = true +tracing.workspace = true +tracing-subscriber.workspace = true [dev-dependencies] -criterion = { version = "0.5", features = ["html_reports"] } +divan = { package = "codspeed-divan-compat", version = "3.0.2" } [features] default = ["jemalloc"] -profiling = ["openvm-sdk/profiling"] mimalloc = ["openvm-circuit/mimalloc"] jemalloc = ["openvm-circuit/jemalloc"] jemalloc-prof = ["openvm-circuit/jemalloc-prof"] nightly-features = ["openvm-circuit/nightly-features"] +perf-metrics = [ + "openvm-circuit/perf-metrics", + "openvm-transpiler/function-span", +] -[[bench]] -name = "fibonacci_execute" -harness = false +# [[bench]] +# name = "fibonacci_execute" +# harness = false + +# [[bench]] +# name = "regex_execute" +# harness = false [[bench]] -name = "regex_execute" +name = "execute" harness = false +[[bin]] +name = "execute-leaf-verifier" +path = "src/execute-leaf-verifier.rs" + [package.metadata.cargo-shear] -ignored = ["derive_more"] +ignored = ["derive_more", "rand"] diff --git a/benchmarks/execute/benches/execute.rs b/benchmarks/execute/benches/execute.rs new file mode 100644 index 0000000000..74963e593c --- /dev/null +++ b/benchmarks/execute/benches/execute.rs @@ -0,0 +1,362 @@ +use std::{fs, path::Path, sync::OnceLock}; + +use divan::Bencher; +use eyre::Result; +use openvm_algebra_circuit::{ + AlgebraCpuProverExt, Fp2Extension, Fp2ExtensionExecutor, ModularExtension, + ModularExtensionExecutor, +}; +use openvm_algebra_transpiler::{Fp2TranspilerExtension, ModularTranspilerExtension}; +use openvm_benchmarks_utils::{get_elf_path, get_fixtures_dir, get_programs_dir, read_elf_file}; +use openvm_bigint_circuit::{Int256, Int256CpuProverExt, Int256Executor}; +use openvm_bigint_transpiler::Int256TranspilerExtension; +use openvm_circuit::{ + arch::{ + execution_mode::metered::MeteredCtx, instructions::exe::VmExe, + interpreter::InterpretedInstance, ContinuationVmProof, *, + }, + derive::VmConfig, + system::*, +}; +use openvm_continuations::{ + verifier::{common::types::VmVerifierPvs, leaf::types::LeafVmVerifierInput}, + SC, +}; +use openvm_ecc_circuit::{EccCpuProverExt, EccExtension, EccExtensionExecutor}; +use openvm_ecc_transpiler::EccTranspilerExtension; +use openvm_keccak256_circuit::{Keccak256, Keccak256CpuProverExt, Keccak256Executor}; +use openvm_keccak256_transpiler::Keccak256TranspilerExtension; +use openvm_native_circuit::{NativeConfig, NativeCpuBuilder, NATIVE_MAX_TRACE_HEIGHTS}; +use openvm_pairing_circuit::{ + PairingCurve, PairingExtension, PairingExtensionExecutor, PairingProverExt, +}; +use openvm_pairing_guest::bn254::BN254_COMPLEX_STRUCT_NAME; +use openvm_pairing_transpiler::PairingTranspilerExtension; +use openvm_rv32im_circuit::{ + Rv32I, Rv32IExecutor, Rv32ImCpuProverExt, Rv32Io, Rv32IoExecutor, Rv32M, Rv32MExecutor, +}; +use openvm_rv32im_transpiler::{ + Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, +}; +use openvm_sdk::config::{DEFAULT_LEAF_LOG_BLOWUP, SBOX_SIZE}; +use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha2CpuProverExt}; +use openvm_sha256_transpiler::Sha256TranspilerExtension; +use openvm_stark_sdk::{ + config::{baby_bear_poseidon2::BabyBearPoseidon2Engine, FriParameters}, + engine::{StarkEngine, StarkFriEngine}, + openvm_stark_backend::{ + self, + config::{StarkGenericConfig, Val}, + p3_field::PrimeField32, + prover::{ + cpu::{CpuBackend, CpuDevice}, + hal::DeviceDataTransporter, + }, + }, + p3_baby_bear::BabyBear, +}; +use openvm_transpiler::{transpiler::Transpiler, FromElf}; +use serde::{Deserialize, Serialize}; + +static AVAILABLE_PROGRAMS: &[&str] = &[ + "fibonacci_recursive", + "fibonacci_iterative", + "quicksort", + "bubblesort", + "factorial_iterative_u256", + "revm_snailtracer", + "keccak256", + "keccak256_iter", + "sha256", + "sha256_iter", + "revm_transfer", + "pairing", +]; + +static METERED_CTX: OnceLock<(MeteredCtx, Vec)> = OnceLock::new(); +static EXECUTOR: OnceLock> = OnceLock::new(); + +#[derive(Clone, Debug, VmConfig, Serialize, Deserialize)] +pub struct ExecuteConfig { + #[config(executor = "SystemExecutor")] + pub system: SystemConfig, + #[extension] + pub rv32i: Rv32I, + #[extension] + pub rv32m: Rv32M, + #[extension] + pub io: Rv32Io, + #[extension] + pub bigint: Int256, + #[extension] + pub keccak: Keccak256, + #[extension] + pub sha256: Sha256, + #[extension] + pub modular: ModularExtension, + #[extension] + pub fp2: Fp2Extension, + #[extension] + pub ecc: EccExtension, + #[extension(generics = true)] + pub pairing: PairingExtension, +} + +impl Default for ExecuteConfig { + fn default() -> Self { + let bn_config = PairingCurve::Bn254.curve_config(); + Self { + system: SystemConfig::default().with_continuations(), + rv32i: Rv32I, + rv32m: Rv32M::default(), + io: Rv32Io, + bigint: Int256::default(), + keccak: Keccak256, + sha256: Sha256, + modular: ModularExtension::new(vec![ + bn_config.modulus.clone(), + bn_config.scalar.clone(), + ]), + fp2: Fp2Extension::new(vec![( + BN254_COMPLEX_STRUCT_NAME.to_string(), + bn_config.modulus.clone(), + )]), + ecc: EccExtension::new(vec![bn_config.clone()], vec![]), + pairing: PairingExtension::new(vec![PairingCurve::Bn254]), + } + } +} + +impl InitFileGenerator for ExecuteConfig { + fn write_to_init_file( + &self, + _manifest_dir: &Path, + _init_file_name: Option<&str>, + ) -> eyre::Result<()> { + Ok(()) + } +} + +pub struct ExecuteBuilder; +impl VmBuilder for ExecuteBuilder +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + Val: PrimeField32, +{ + type VmConfig = ExecuteConfig; + type SystemChipInventory = SystemChipInventory; + type RecordArena = MatrixRecordArena>; + + fn create_chip_complex( + &self, + config: &ExecuteConfig, + circuit: AirInventory, + ) -> Result< + VmChipComplex, + ChipInventoryError, + > { + let mut chip_complex = + VmBuilder::::create_chip_complex(&SystemCpuBuilder, &config.system, circuit)?; + let inventory = &mut chip_complex.inventory; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.rv32i, inventory)?; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.rv32m, inventory)?; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.io, inventory)?; + VmProverExtension::::extend_prover( + &Int256CpuProverExt, + &config.bigint, + inventory, + )?; + VmProverExtension::::extend_prover( + &Keccak256CpuProverExt, + &config.keccak, + inventory, + )?; + VmProverExtension::::extend_prover(&Sha2CpuProverExt, &config.sha256, inventory)?; + VmProverExtension::::extend_prover( + &AlgebraCpuProverExt, + &config.modular, + inventory, + )?; + VmProverExtension::::extend_prover(&AlgebraCpuProverExt, &config.fp2, inventory)?; + VmProverExtension::::extend_prover(&EccCpuProverExt, &config.ecc, inventory)?; + VmProverExtension::::extend_prover(&PairingProverExt, &config.pairing, inventory)?; + Ok(chip_complex) + } +} + +fn main() { + divan::main(); +} + +fn create_default_transpiler() -> Transpiler { + Transpiler::::default() + .with_extension(Rv32ITranspilerExtension) + .with_extension(Rv32IoTranspilerExtension) + .with_extension(Rv32MTranspilerExtension) + .with_extension(Int256TranspilerExtension) + .with_extension(Keccak256TranspilerExtension) + .with_extension(Sha256TranspilerExtension) + .with_extension(ModularTranspilerExtension) + .with_extension(Fp2TranspilerExtension) + .with_extension(EccTranspilerExtension) + .with_extension(PairingTranspilerExtension) +} + +fn load_program_executable(program: &str) -> Result> { + let transpiler = create_default_transpiler(); + let program_dir = get_programs_dir().join(program); + let elf_path = get_elf_path(&program_dir); + let elf = read_elf_file(&elf_path)?; + Ok(VmExe::from_elf(elf, transpiler)?) +} + +fn metering_setup() -> &'static (MeteredCtx, Vec) { + METERED_CTX.get_or_init(|| { + let config = ExecuteConfig::default(); + let engine = BabyBearPoseidon2Engine::new(FriParameters::standard_fast()); + let (vm, _) = VirtualMachine::new_with_keygen(engine, ExecuteBuilder, config).unwrap(); + let ctx = vm.build_metered_ctx(); + let executor_idx_to_air_idx = vm.executor_idx_to_air_idx(); + (ctx, executor_idx_to_air_idx) + }) +} + +fn executor() -> &'static VmExecutor { + EXECUTOR.get_or_init(|| { + let vm_config = ExecuteConfig::default(); + VmExecutor::::new(vm_config).unwrap() + }) +} + +#[divan::bench(args = AVAILABLE_PROGRAMS, sample_count=10)] +fn benchmark_execute(bencher: Bencher, program: &str) { + bencher + .with_inputs(|| { + let exe = load_program_executable(program).expect("Failed to load program executable"); + let interpreter = executor().instance(&exe).unwrap(); + (interpreter, vec![]) + }) + .bench_values(|(interpreter, input)| { + interpreter + .execute(input, None) + .expect("Failed to execute program in interpreted mode"); + }); +} + +#[divan::bench(args = AVAILABLE_PROGRAMS, sample_count=5)] +fn benchmark_execute_metered(bencher: Bencher, program: &str) { + bencher + .with_inputs(|| { + let exe = load_program_executable(program).expect("Failed to load program executable"); + let (ctx, executor_idx_to_air_idx) = metering_setup(); + let interpreter = executor() + .metered_instance(&exe, executor_idx_to_air_idx) + .unwrap(); + (interpreter, vec![], ctx.clone()) + }) + .bench_values(|(interpreter, input, ctx)| { + interpreter + .execute_metered(input, ctx) + .expect("Failed to execute program"); + }); +} + +fn setup_leaf_verifier() -> ( + VirtualMachine, + VmExe, + Vec>, +) { + let fixtures_dir = get_fixtures_dir(); + let app_proof_bytes = fs::read(fixtures_dir.join("kitchen-sink.app.proof")).unwrap(); + let app_proof: ContinuationVmProof = bitcode::deserialize(&app_proof_bytes).unwrap(); + + let leaf_exe_bytes = fs::read(fixtures_dir.join("kitchen-sink.leaf.exe")).unwrap(); + let leaf_exe: VmExe = bitcode::deserialize(&leaf_exe_bytes).unwrap(); + + let leaf_pk_bytes = fs::read(fixtures_dir.join("kitchen-sink.leaf.pk")).unwrap(); + let leaf_pk = bitcode::deserialize(&leaf_pk_bytes).unwrap(); + + let leaf_inputs = LeafVmVerifierInput::chunk_continuation_vm_proof(&app_proof, 2); + let leaf_input = leaf_inputs.first().expect("No leaf input available"); + + let config = NativeConfig::aggregation( + VmVerifierPvs::::width(), + SBOX_SIZE.min(FriParameters::standard_fast().max_constraint_degree()), + ); + let fri_params = + FriParameters::standard_with_100_bits_conjectured_security(DEFAULT_LEAF_LOG_BLOWUP); + let engine = BabyBearPoseidon2Engine::new(fri_params); + let d_pk = engine.device().transport_pk_to_device(&leaf_pk); + let vm = VirtualMachine::new(engine, NativeCpuBuilder, config, d_pk).unwrap(); + let input_stream = leaf_input.write_to_stream(); + + (vm, leaf_exe, input_stream) +} + +#[divan::bench(sample_count = 5)] +fn benchmark_leaf_verifier_execute(bencher: Bencher) { + bencher + .with_inputs(|| { + let (vm, leaf_exe, input_stream) = setup_leaf_verifier(); + let interpreter = vm.executor().instance(&leaf_exe).unwrap(); + + // SAFETY: We transmute the interpreter to have the same lifetime as the VM. + // This is safe because the vm is moved into the tuple and will remain + // alive for the entire duration that the interpreter is used. + #[allow(clippy::missing_transmute_annotations)] + let interpreter = + unsafe { std::mem::transmute::<_, InterpretedInstance<'_, _, _>>(interpreter) }; + + (vm, interpreter, input_stream) + }) + .bench_values(|(_vm, interpreter, input_stream)| { + interpreter + .execute(input_stream, None) + .expect("Failed to execute program in interpreted mode"); + }); +} + +#[divan::bench(sample_count = 5)] +fn benchmark_leaf_verifier_execute_metered(bencher: Bencher) { + bencher + .with_inputs(|| { + let (vm, leaf_exe, input_stream) = setup_leaf_verifier(); + let ctx = vm.build_metered_ctx(); + let executor_idx_to_air_idx = vm.executor_idx_to_air_idx(); + let interpreter = vm + .executor() + .metered_instance(&leaf_exe, &executor_idx_to_air_idx) + .unwrap(); + + // SAFETY: We transmute the interpreter to have the same lifetime as the VM. + // This is safe because the vm is moved into the tuple and will remain + // alive for the entire duration that the interpreter is used. + #[allow(clippy::missing_transmute_annotations)] + let interpreter = + unsafe { std::mem::transmute::<_, InterpretedInstance<'_, _, _>>(interpreter) }; + + (vm, interpreter, input_stream, ctx) + }) + .bench_values(|(_vm, interpreter, input_stream, ctx)| { + interpreter + .execute_metered(input_stream, ctx) + .expect("Failed to execute program"); + }); +} + +#[divan::bench(sample_count = 5)] +fn benchmark_leaf_verifier_execute_preflight(bencher: Bencher) { + bencher + .with_inputs(|| { + let (vm, leaf_exe, input_stream) = setup_leaf_verifier(); + let state = vm.create_initial_state(&leaf_exe, input_stream); + + (vm, leaf_exe, state) + }) + .bench_values(|(vm, leaf_exe, state)| { + let _out = vm + .execute_preflight(&leaf_exe, state, None, NATIVE_MAX_TRACE_HEIGHTS) + .expect("Failed to execute preflight"); + }); +} diff --git a/benchmarks/execute/benches/fibonacci_execute.rs b/benchmarks/execute/benches/fibonacci_execute.rs index 70952b53c9..49b453d028 100644 --- a/benchmarks/execute/benches/fibonacci_execute.rs +++ b/benchmarks/execute/benches/fibonacci_execute.rs @@ -1,42 +1,44 @@ -use criterion::{criterion_group, criterion_main, Criterion}; -use openvm_benchmarks_utils::{build_elf, get_programs_dir}; -use openvm_circuit::arch::{instructions::exe::VmExe, VmExecutor}; -use openvm_rv32im_circuit::Rv32ImConfig; -use openvm_rv32im_transpiler::{ - Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, -}; -use openvm_sdk::StdIn; -use openvm_stark_sdk::p3_baby_bear::BabyBear; -use openvm_transpiler::{transpiler::Transpiler, FromElf}; +// use criterion::{criterion_group, criterion_main, Criterion}; +// use openvm_benchmarks_utils::{build_elf, get_programs_dir}; +// use openvm_circuit::arch::{instructions::exe::VmExe, VmExecutor}; +// use openvm_rv32im_circuit::Rv32ImConfig; +// use openvm_rv32im_transpiler::{ +// Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, +// }; +// // TODO(ayush): add this back +// // use openvm_sdk::StdIn; +// use openvm_stark_sdk::p3_baby_bear::BabyBear; +// use openvm_transpiler::{transpiler::Transpiler, FromElf}; -fn benchmark_function(c: &mut Criterion) { - let program_dir = get_programs_dir().join("fibonacci"); - let elf = build_elf(&program_dir, "release").unwrap(); +// fn benchmark_function(c: &mut Criterion) { +// let program_dir = get_programs_dir().join("fibonacci"); +// let elf = build_elf(&program_dir, "release").unwrap(); - let exe = VmExe::from_elf( - elf, - Transpiler::::default() - .with_extension(Rv32ITranspilerExtension) - .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), - ) - .unwrap(); +// let exe = VmExe::from_elf( +// elf, +// Transpiler::::default() +// .with_extension(Rv32ITranspilerExtension) +// .with_extension(Rv32MTranspilerExtension) +// .with_extension(Rv32IoTranspilerExtension), +// ) +// .unwrap(); - let mut group = c.benchmark_group("fibonacci"); - let config = Rv32ImConfig::default(); - let executor = VmExecutor::::new(config); +// let mut group = c.benchmark_group("fibonacci"); +// let config = Rv32ImConfig::default(); +// let executor = VmExecutor::::new(config); - group.bench_function("execute", |b| { - b.iter(|| { - let n = 100_000u64; - let mut stdin = StdIn::default(); - stdin.write(&n); - executor.execute(exe.clone(), stdin).unwrap(); - }) - }); +// group.bench_function("execute", |b| { +// b.iter(|| { +// // TODO(ayush): add this back +// // let n = 100_000u64; +// // let mut stdin = StdIn::default(); +// // stdin.write(&n); +// executor.execute(exe.clone(), vec![]).unwrap(); +// }) +// }); - group.finish(); -} +// group.finish(); +// } -criterion_group!(benches, benchmark_function); -criterion_main!(benches); +// criterion_group!(benches, benchmark_function); +// criterion_main!(benches); diff --git a/benchmarks/execute/benches/regex_execute.rs b/benchmarks/execute/benches/regex_execute.rs index a3a110e344..d4116b5aab 100644 --- a/benchmarks/execute/benches/regex_execute.rs +++ b/benchmarks/execute/benches/regex_execute.rs @@ -1,47 +1,47 @@ -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use openvm_benchmarks_utils::{build_elf, get_programs_dir}; -use openvm_circuit::arch::{instructions::exe::VmExe, VmExecutor}; -use openvm_keccak256_circuit::Keccak256Rv32Config; -use openvm_keccak256_transpiler::Keccak256TranspilerExtension; -use openvm_rv32im_transpiler::{ - Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, -}; -use openvm_sdk::StdIn; -use openvm_stark_sdk::p3_baby_bear::BabyBear; -use openvm_transpiler::{transpiler::Transpiler, FromElf}; +// TODO(ayush): add this back +// use criterion::{black_box, criterion_group, criterion_main, Criterion}; +// use openvm_benchmarks_utils::{build_elf, get_programs_dir}; +// use openvm_circuit::arch::{instructions::exe::VmExe, VmExecutor}; +// use openvm_keccak256_circuit::Keccak256Rv32Config; +// use openvm_keccak256_transpiler::Keccak256TranspilerExtension; +// use openvm_rv32im_transpiler::{ +// Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, +// }; +// use openvm_sdk::StdIn; +// use openvm_stark_sdk::p3_baby_bear::BabyBear; +// use openvm_transpiler::{transpiler::Transpiler, FromElf}; -fn benchmark_function(c: &mut Criterion) { - let program_dir = get_programs_dir().join("regex"); - let elf = build_elf(&program_dir, "release").unwrap(); +// fn benchmark_function(c: &mut Criterion) { +// let program_dir = get_programs_dir().join("regex"); +// let elf = build_elf(&program_dir, "release").unwrap(); - let exe = VmExe::from_elf( - elf, - Transpiler::::default() - .with_extension(Rv32ITranspilerExtension) - .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension) - .with_extension(Keccak256TranspilerExtension), - ) - .unwrap(); +// let exe = VmExe::from_elf( +// elf, +// Transpiler::::default() +// .with_extension(Rv32ITranspilerExtension) +// .with_extension(Rv32MTranspilerExtension) +// .with_extension(Rv32IoTranspilerExtension) +// .with_extension(Keccak256TranspilerExtension), +// ) +// .unwrap(); - let mut group = c.benchmark_group("regex"); - group.sample_size(10); - let config = Keccak256Rv32Config::default(); - let executor = VmExecutor::::new(config); +// let mut group = c.benchmark_group("regex"); +// group.sample_size(10); +// let config = Keccak256Rv32Config::default(); +// let executor = VmExecutor::::new(config); - let data = include_str!("../../guest/regex/regex_email.txt"); +// let data = include_str!("../../guest/regex/regex_email.txt"); - let fe_bytes = data.to_owned().into_bytes(); - group.bench_function("execute", |b| { - b.iter(|| { - executor - .execute(exe.clone(), black_box(StdIn::from_bytes(&fe_bytes))) - .unwrap(); - }) - }); +// let fe_bytes = data.to_owned().into_bytes(); +// group.bench_function("execute", |b| { +// b.iter(|| { +// let input = black_box(Stdin::from_bytes(&fe_bytes)); +// executor.execute(exe.clone(), input).unwrap(); +// }) +// }); - group.finish(); -} +// group.finish(); +// } -criterion_group!(benches, benchmark_function); -criterion_main!(benches); +// criterion_group!(benches, benchmark_function); +// criterion_main!(benches); diff --git a/benchmarks/execute/examples/regex_execute.rs b/benchmarks/execute/examples/regex_execute.rs index 59705a19fd..3a6fd4162f 100644 --- a/benchmarks/execute/examples/regex_execute.rs +++ b/benchmarks/execute/examples/regex_execute.rs @@ -1,35 +1,35 @@ -use openvm_circuit::arch::{instructions::exe::VmExe, VmExecutor}; -use openvm_keccak256_circuit::Keccak256Rv32Config; -use openvm_keccak256_transpiler::Keccak256TranspilerExtension; -use openvm_rv32im_transpiler::{ - Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, -}; -use openvm_sdk::StdIn; -use openvm_stark_sdk::p3_baby_bear::BabyBear; -use openvm_transpiler::{ - elf::Elf, openvm_platform::memory::MEM_SIZE, transpiler::Transpiler, FromElf, -}; +// use openvm_circuit::arch::{instructions::exe::VmExe, VmExecutor}; +// use openvm_keccak256_circuit::Keccak256Rv32Config; +// use openvm_keccak256_transpiler::Keccak256TranspilerExtension; +// use openvm_rv32im_transpiler::{ +// Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, +// }; +// use openvm_sdk::StdIn; +// use openvm_stark_sdk::p3_baby_bear::BabyBear; +// use openvm_transpiler::{ +// elf::Elf, openvm_platform::memory::MEM_SIZE, transpiler::Transpiler, FromElf, +// }; fn main() { - let elf = Elf::decode(include_bytes!("regex-elf"), MEM_SIZE as u32).unwrap(); - let exe = VmExe::from_elf( - elf, - Transpiler::::default() - .with_extension(Rv32ITranspilerExtension) - .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension) - .with_extension(Keccak256TranspilerExtension), - ) - .unwrap(); + // let elf = Elf::decode(include_bytes!("regex-elf"), MEM_SIZE as u32).unwrap(); + // let exe = VmExe::from_elf( + // elf, + // Transpiler::::default() + // .with_extension(Rv32ITranspilerExtension) + // .with_extension(Rv32MTranspilerExtension) + // .with_extension(Rv32IoTranspilerExtension) + // .with_extension(Keccak256TranspilerExtension), + // ) + // .unwrap(); - let config = Keccak256Rv32Config::default(); - let executor = VmExecutor::::new(config); + // let config = Keccak256Rv32Config::default(); + // let executor = VmExecutor::::new(config); - let data = include_str!("../../guest/regex/regex_email.txt"); + // let data = include_str!("../../guest/regex/regex_email.txt"); - let timer = std::time::Instant::now(); - executor - .execute(exe.clone(), StdIn::from_bytes(data.as_bytes())) - .unwrap(); - println!("execute_time: {:?}", timer.elapsed()); + // let timer = std::time::Instant::now(); + // executor + // .execute(exe.clone(), StdIn::from_bytes(data.as_bytes())) + // .unwrap(); + // println!("execute_time: {:?}", timer.elapsed()); } diff --git a/benchmarks/execute/src/execute-leaf-verifier.rs b/benchmarks/execute/src/execute-leaf-verifier.rs new file mode 100644 index 0000000000..ac478f652d --- /dev/null +++ b/benchmarks/execute/src/execute-leaf-verifier.rs @@ -0,0 +1,101 @@ +use std::fs; + +use clap::{arg, Parser, ValueEnum}; +use eyre::Result; +use openvm_benchmarks_utils::get_fixtures_dir; +use openvm_circuit::arch::{instructions::exe::VmExe, ContinuationVmProof, VirtualMachine}; +use openvm_continuations::{ + verifier::{common::types::VmVerifierPvs, leaf::types::LeafVmVerifierInput}, + SC, +}; +use openvm_native_circuit::{NativeConfig, NativeCpuBuilder, NATIVE_MAX_TRACE_HEIGHTS}; +use openvm_sdk::config::{DEFAULT_LEAF_LOG_BLOWUP, SBOX_SIZE}; +use openvm_stark_sdk::{ + config::{baby_bear_poseidon2::BabyBearPoseidon2Engine, FriParameters}, + engine::{StarkEngine, StarkFriEngine}, + openvm_stark_backend::prover::hal::DeviceDataTransporter, + p3_baby_bear::BabyBear, +}; +use tracing_subscriber::{fmt, EnvFilter}; + +const PROGRAM_NAME: &str = "kitchen-sink"; + +#[derive(Clone, Debug, ValueEnum)] +enum ExecutionMode { + Normal, + Metered, + Preflight, +} + +#[derive(Parser)] +#[command(author, version, about = "OpenVM leaf verifier execution")] +struct Cli { + #[arg(short, long, value_enum, default_value = "preflight")] + mode: ExecutionMode, + + #[arg(short, long)] + verbose: bool, +} + +fn main() -> Result<()> { + let cli = Cli::parse(); + + // Set up logging + let filter = if cli.verbose { + EnvFilter::from_default_env() + } else { + EnvFilter::new("info") + }; + fmt::fmt().with_env_filter(filter).init(); + + let fixtures_dir = get_fixtures_dir(); + let app_proof_bytes = + fs::read(fixtures_dir.join(format!("{}.app.proof", PROGRAM_NAME))).unwrap(); + let app_proof: ContinuationVmProof = bitcode::deserialize(&app_proof_bytes).unwrap(); + + let leaf_exe_bytes = fs::read(fixtures_dir.join(format!("{}.leaf.exe", PROGRAM_NAME))).unwrap(); + let leaf_exe: VmExe = bitcode::deserialize(&leaf_exe_bytes).unwrap(); + + let leaf_pk_bytes = fs::read(fixtures_dir.join(format!("{}.leaf.pk", PROGRAM_NAME))).unwrap(); + let leaf_pk = bitcode::deserialize(&leaf_pk_bytes).unwrap(); + + let leaf_inputs = LeafVmVerifierInput::chunk_continuation_vm_proof(&app_proof, 2); + let leaf_input = leaf_inputs.first().expect("No leaf input available"); + + let config = NativeConfig::aggregation( + VmVerifierPvs::::width(), + SBOX_SIZE.min(FriParameters::standard_fast().max_constraint_degree()), + ); + let fri_params = + FriParameters::standard_with_100_bits_conjectured_security(DEFAULT_LEAF_LOG_BLOWUP); + let engine = BabyBearPoseidon2Engine::new(fri_params); + let d_pk = engine.device().transport_pk_to_device(&leaf_pk); + let vm = VirtualMachine::new(engine, NativeCpuBuilder, config, d_pk)?; + let input_stream = leaf_input.write_to_stream(); + + match cli.mode { + ExecutionMode::Normal => { + tracing::info!("Running normal execute..."); + let interpreter = vm.executor().instance(&leaf_exe)?; + interpreter.execute(input_stream, None)?; + } + ExecutionMode::Metered => { + tracing::info!("Running metered execute..."); + let ctx = vm.build_metered_ctx(); + let executor_idx_to_air_idx = vm.executor_idx_to_air_idx(); + let interpreter = vm + .executor() + .metered_instance(&leaf_exe, &executor_idx_to_air_idx)?; + interpreter.execute_metered(input_stream, ctx)?; + } + ExecutionMode::Preflight => { + tracing::info!("Running preflight execute..."); + let state = vm.create_initial_state(&leaf_exe, input_stream); + let _out = vm + .execute_preflight(&leaf_exe, state, None, NATIVE_MAX_TRACE_HEIGHTS) + .expect("Failed to execute preflight"); + } + } + + Ok(()) +} diff --git a/benchmarks/execute/src/main.rs b/benchmarks/execute/src/main.rs deleted file mode 100644 index a05baeea44..0000000000 --- a/benchmarks/execute/src/main.rs +++ /dev/null @@ -1,121 +0,0 @@ -use cargo_openvm::util::read_config_toml_or_default; -use clap::{Parser, ValueEnum}; -use eyre::Result; -use openvm_benchmarks_utils::{get_elf_path, get_programs_dir, read_elf_file}; -use openvm_circuit::arch::{instructions::exe::VmExe, VmExecutor}; -use openvm_sdk::StdIn; -use openvm_stark_sdk::bench::run_with_metric_collection; -use openvm_transpiler::FromElf; - -#[derive(Debug, Clone, ValueEnum)] -enum BuildProfile { - Debug, - Release, -} - -static AVAILABLE_PROGRAMS: &[&str] = &[ - "fibonacci_recursive", - "fibonacci_iterative", - "quicksort", - "bubblesort", - "pairing", - "keccak256", - "keccak256_iter", - "sha256", - "sha256_iter", - "revm_transfer", - "revm_snailtracer", -]; - -#[derive(Parser)] -#[command(author, version, about = "OpenVM Benchmark CLI", long_about = None)] -struct Cli { - /// Programs to benchmark (if not specified, all programs will be run) - #[arg(short, long)] - programs: Vec, - - /// Programs to skip from benchmarking - #[arg(short, long)] - skip: Vec, - - /// Output path for benchmark results - #[arg(short, long, default_value = "OUTPUT_PATH")] - output: String, - - /// List available benchmark programs and exit - #[arg(short, long)] - list: bool, - - /// Verbose output - #[arg(short, long)] - verbose: bool, -} - -fn main() -> Result<()> { - let cli = Cli::parse(); - - if cli.list { - println!("Available benchmark programs:"); - for program in AVAILABLE_PROGRAMS { - println!(" {}", program); - } - return Ok(()); - } - - // Set up logging based on verbosity - if cli.verbose { - tracing_subscriber::fmt::init(); - } - - let mut programs_to_run = if cli.programs.is_empty() { - AVAILABLE_PROGRAMS.to_vec() - } else { - // Validate provided programs - for program in &cli.programs { - if !AVAILABLE_PROGRAMS.contains(&program.as_str()) { - eprintln!("Unknown program: {}", program); - eprintln!("Use --list to see available programs"); - std::process::exit(1); - } - } - cli.programs.iter().map(|s| s.as_str()).collect() - }; - - // Remove programs that should be skipped - if !cli.skip.is_empty() { - // Validate skipped programs - for program in &cli.skip { - if !AVAILABLE_PROGRAMS.contains(&program.as_str()) { - eprintln!("Unknown program to skip: {}", program); - eprintln!("Use --list to see available programs"); - std::process::exit(1); - } - } - - let skip_set: Vec<&str> = cli.skip.iter().map(|s| s.as_str()).collect(); - programs_to_run.retain(|&program| !skip_set.contains(&program)); - } - - tracing::info!("Starting benchmarks with metric collection"); - - run_with_metric_collection(&cli.output, || -> Result<()> { - for program in &programs_to_run { - tracing::info!("Running program: {}", program); - - let program_dir = get_programs_dir().join(program); - let elf_path = get_elf_path(&program_dir); - let elf = read_elf_file(&elf_path)?; - - let config_path = program_dir.join("openvm.toml"); - let vm_config = read_config_toml_or_default(&config_path)?.app_vm_config; - - let exe = VmExe::from_elf(elf, vm_config.transpiler())?; - - let executor = VmExecutor::new(vm_config); - executor.execute(exe, StdIn::default())?; - tracing::info!("Completed program: {}", program); - } - tracing::info!("All programs executed successfully"); - Ok(()) - }) -} diff --git a/benchmarks/guest/Cargo.toml b/benchmarks/guest/Cargo.toml new file mode 100644 index 0000000000..f27ae022c5 --- /dev/null +++ b/benchmarks/guest/Cargo.toml @@ -0,0 +1,31 @@ +[workspace.package] +version = "0.0.0" +edition = "2021" + +[workspace] +members = ["base64_json", "bincode", "bubblesort", "ecrecover", "factorial_iterative_u256", "fibonacci", "fibonacci_iterative", "fibonacci_recursive", "keccak256", "keccak256_iter", "kitchen-sink", "pairing", "quicksort", "regex", "revm_snailtracer", "revm_transfer", "rkyv", "sha256", "sha256_iter"] +resolver = "2" + +[workspace.dependencies] +openvm = { path = "../../crates/toolchain/openvm" } +openvm-algebra-guest = { path = "../../extensions/algebra/guest", default-features = false } +openvm-ecc-guest = { path = "../../extensions/ecc/guest", default-features = false } +openvm-keccak256 = { path = "../../guest-libs/keccak256/", default-features = false } +openvm-ruint = { path = "../../guest-libs/ruint/", package = "ruint", default-features = false } +openvm-pairing = { path = "../../guest-libs/pairing/", default-features = false } +openvm-sha2 = { path = "../../guest-libs/sha2/", default-features = false } +openvm-k256 = { path = "../../guest-libs/k256/", package = "k256" } +openvm-p256 = { path = "../../guest-libs/p256/", package = "p256" } + +# patch for ecrecover +[patch.crates-io] +k256 = { path = "../../guest-libs/k256/" } + +[profile.release] +panic = "abort" +lto = "thin" # faster compile time + +[profile.profiling] +inherits = "release" +debug = 2 +strip = false diff --git a/benchmarks/guest/base64_json/Cargo.toml b/benchmarks/guest/base64_json/Cargo.toml index f0f43b3479..9177070a63 100644 --- a/benchmarks/guest/base64_json/Cargo.toml +++ b/benchmarks/guest/base64_json/Cargo.toml @@ -1,19 +1,13 @@ -[workspace] [package] -version = "0.1.0" name = "openvm-json-program" -edition = "2021" +version.workspace = true +edition.workspace = true [dependencies] -openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } +openvm = { workspace = true, features = ["std"] } base64 = { version = "0.22.1", default-features = false, features = ["alloc"] } serde = { version = "1.0.214", default-features = false, features = ["derive"] } serde-json-core = "0.6.0" [features] default = [] - -[profile.profiling] -inherits = "release" -debug = 2 -strip = false diff --git a/benchmarks/guest/base64_json/elf/openvm-json-program.elf b/benchmarks/guest/base64_json/elf/openvm-json-program.elf index 29e6cac131..55335dca15 100755 Binary files a/benchmarks/guest/base64_json/elf/openvm-json-program.elf and b/benchmarks/guest/base64_json/elf/openvm-json-program.elf differ diff --git a/benchmarks/guest/bincode/Cargo.toml b/benchmarks/guest/bincode/Cargo.toml index eba3c918bf..3464800677 100644 --- a/benchmarks/guest/bincode/Cargo.toml +++ b/benchmarks/guest/bincode/Cargo.toml @@ -1,11 +1,10 @@ -[workspace] [package] name = "openvm-bincode-program" -version = "0.0.0" -edition = "2021" +version.workspace = true +edition.workspace = true [dependencies] -openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } +openvm = { workspace = true, features = ["std"] } bincode = { version = "2.0.0-rc.3", default-features = false, features = [ "derive", "alloc", @@ -16,8 +15,3 @@ rand_pcg = "0.3.1" [features] default = [] - -[profile.profiling] -inherits = "release" -debug = 2 -strip = false diff --git a/benchmarks/guest/bincode/elf/openvm-bincode-program.elf b/benchmarks/guest/bincode/elf/openvm-bincode-program.elf index 085eb7ee4f..2d4b2ae67a 100755 Binary files a/benchmarks/guest/bincode/elf/openvm-bincode-program.elf and b/benchmarks/guest/bincode/elf/openvm-bincode-program.elf differ diff --git a/benchmarks/guest/bubblesort/Cargo.toml b/benchmarks/guest/bubblesort/Cargo.toml index 68a0af82ff..957719a7df 100644 --- a/benchmarks/guest/bubblesort/Cargo.toml +++ b/benchmarks/guest/bubblesort/Cargo.toml @@ -1,16 +1,10 @@ -[workspace] [package] name = "openvm-bubblesort-program" -version = "0.0.0" -edition = "2021" +version.workspace = true +edition.workspace = true [dependencies] -openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } +openvm = { workspace = true, features = ["std"] } [features] default = [] - -[profile.profiling] -inherits = "release" -debug = 2 -strip = false diff --git a/benchmarks/guest/bubblesort/elf/openvm-bubblesort-program.elf b/benchmarks/guest/bubblesort/elf/openvm-bubblesort-program.elf index 0f81a3926f..cec789e279 100755 Binary files a/benchmarks/guest/bubblesort/elf/openvm-bubblesort-program.elf and b/benchmarks/guest/bubblesort/elf/openvm-bubblesort-program.elf differ diff --git a/benchmarks/guest/bubblesort/src/main.rs b/benchmarks/guest/bubblesort/src/main.rs index 0dd7e51146..d859641504 100644 --- a/benchmarks/guest/bubblesort/src/main.rs +++ b/benchmarks/guest/bubblesort/src/main.rs @@ -1,7 +1,7 @@ use core::hint::black_box; use openvm as _; -const ARRAY_SIZE: usize = 100; +const ARRAY_SIZE: usize = 1_000; fn bubblesort(arr: &mut [T]) { let len = arr.len(); diff --git a/benchmarks/guest/ecrecover/Cargo.toml b/benchmarks/guest/ecrecover/Cargo.toml index b9592028f7..0937e63f0a 100644 --- a/benchmarks/guest/ecrecover/Cargo.toml +++ b/benchmarks/guest/ecrecover/Cargo.toml @@ -1,14 +1,13 @@ -[workspace] [package] name = "openvm-ecdsa-recover-key-program" -version = "0.0.0" -edition = "2021" +version.workspace = true +edition.workspace = true [dependencies] -openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } -openvm-algebra-guest = { path = "../../../extensions/algebra/guest", default-features = false } -openvm-ecc-guest = { path = "../../../extensions/ecc/guest", default-features = false } -openvm-keccak256 = { path = "../../../guest-libs/keccak256/", default-features = false } +openvm = { workspace = true, features = ["std"] } +openvm-algebra-guest.workspace = true +openvm-ecc-guest.workspace = true +openvm-keccak256.workspace = true revm-precompile = { git = "https://github.com/bluealloy/revm.git", tag = "v75", default-features = false } # IMPORTANT: must be same version as used by revm; revm does not re-export this feature so we enable it here alloy-primitives = { version = "1.2.0", default-features = false, features = [ @@ -18,15 +17,3 @@ k256 = { version = "0.13.3", default-features = false } [features] default = [] - -[profile.release] -panic = "abort" -lto = "thin" # faster compile time - -[profile.profiling] -inherits = "release" -debug = 2 -strip = false - -[patch.crates-io] -k256 = { path = "../../../guest-libs/k256/" } diff --git a/benchmarks/guest/ecrecover/elf/openvm-ecdsa-recover-key-program.elf b/benchmarks/guest/ecrecover/elf/openvm-ecdsa-recover-key-program.elf index 4e54268ea4..88c87c6abc 100755 Binary files a/benchmarks/guest/ecrecover/elf/openvm-ecdsa-recover-key-program.elf and b/benchmarks/guest/ecrecover/elf/openvm-ecdsa-recover-key-program.elf differ diff --git a/benchmarks/guest/ecrecover/openvm.toml b/benchmarks/guest/ecrecover/openvm.toml index c1261ee458..265d29d89a 100644 --- a/benchmarks/guest/ecrecover/openvm.toml +++ b/benchmarks/guest/ecrecover/openvm.toml @@ -9,9 +9,10 @@ supported_moduli = [ "115792089237316195423570985008687907852837564279074904382605163141518161494337", ] -[[app_vm_config.ecc.supported_curves]] +[[app_vm_config.ecc.supported_sw_curves]] struct_name = "Secp256k1Point" modulus = "115792089237316195423570985008687907853269984665640564039457584007908834671663" scalar = "115792089237316195423570985008687907852837564279074904382605163141518161494337" +[app_vm_config.ecc.supported_sw_curves.coeffs] a = "0" b = "7" diff --git a/benchmarks/guest/ecrecover/openvm_init.rs b/benchmarks/guest/ecrecover/openvm_init.rs index bec9f527e9..d9cf1bbe09 100644 --- a/benchmarks/guest/ecrecover/openvm_init.rs +++ b/benchmarks/guest/ecrecover/openvm_init.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337" } -openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point } +openvm_ecc_guest::sw_macros::sw_init! { "Secp256k1Point" } +openvm_ecc_guest::te_macros::te_init! { } diff --git a/benchmarks/guest/factorial_iterative_u256/Cargo.toml b/benchmarks/guest/factorial_iterative_u256/Cargo.toml new file mode 100644 index 0000000000..7acc9e66f8 --- /dev/null +++ b/benchmarks/guest/factorial_iterative_u256/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "openvm-factorial-iterative-u256-program" +version.workspace = true +edition.workspace = true + +[dependencies] +openvm = { workspace = true, features = ["std"] } +openvm-ruint.workspace = true + +[features] +default = [] diff --git a/benchmarks/guest/factorial_iterative_u256/elf/openvm-factorial-iterative-u256-program.elf b/benchmarks/guest/factorial_iterative_u256/elf/openvm-factorial-iterative-u256-program.elf new file mode 100755 index 0000000000..572f71b182 Binary files /dev/null and b/benchmarks/guest/factorial_iterative_u256/elf/openvm-factorial-iterative-u256-program.elf differ diff --git a/benchmarks/guest/factorial_iterative_u256/openvm.toml b/benchmarks/guest/factorial_iterative_u256/openvm.toml new file mode 100644 index 0000000000..b226887890 --- /dev/null +++ b/benchmarks/guest/factorial_iterative_u256/openvm.toml @@ -0,0 +1,4 @@ +[app_vm_config.rv32i] +[app_vm_config.rv32m] +[app_vm_config.io] +[app_vm_config.bigint] diff --git a/benchmarks/guest/factorial_iterative_u256/src/main.rs b/benchmarks/guest/factorial_iterative_u256/src/main.rs new file mode 100644 index 0000000000..c92491d2da --- /dev/null +++ b/benchmarks/guest/factorial_iterative_u256/src/main.rs @@ -0,0 +1,16 @@ +use core::hint::black_box; +use openvm as _; +use openvm_ruint::aliases::U256; + +// This will overflow but that is fine +const N: u32 = 65_000; + +pub fn main() { + let mut acc = U256::from(1u32); + let mut i = U256::from(N); + while i > black_box(U256::ZERO) { + acc *= i.clone(); + i -= U256::from(1u32); + } + black_box(acc); +} diff --git a/benchmarks/guest/fibonacci/Cargo.toml b/benchmarks/guest/fibonacci/Cargo.toml index 4ea6659e73..469868a3b9 100644 --- a/benchmarks/guest/fibonacci/Cargo.toml +++ b/benchmarks/guest/fibonacci/Cargo.toml @@ -1,16 +1,10 @@ -[workspace] [package] name = "openvm-fibonacci-program" -version = "0.0.0" -edition = "2021" +version.workspace = true +edition.workspace = true [dependencies] -openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } +openvm = { workspace = true, features = ["std"] } [features] default = [] - -[profile.profiling] -inherits = "release" -debug = 2 -strip = false diff --git a/benchmarks/guest/fibonacci/elf/openvm-fibonacci-program.elf b/benchmarks/guest/fibonacci/elf/openvm-fibonacci-program.elf index 36ad8d359c..20335618e4 100755 Binary files a/benchmarks/guest/fibonacci/elf/openvm-fibonacci-program.elf and b/benchmarks/guest/fibonacci/elf/openvm-fibonacci-program.elf differ diff --git a/benchmarks/guest/fibonacci_iterative/Cargo.toml b/benchmarks/guest/fibonacci_iterative/Cargo.toml index 6f0c145061..75f564d2b9 100644 --- a/benchmarks/guest/fibonacci_iterative/Cargo.toml +++ b/benchmarks/guest/fibonacci_iterative/Cargo.toml @@ -1,16 +1,10 @@ -[workspace] [package] name = "openvm-fibonacci-iterative-program" -version = "0.0.0" -edition = "2021" +version.workspace = true +edition.workspace = true [dependencies] -openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } +openvm = { workspace = true, features = ["std"] } [features] default = [] - -[profile.profiling] -inherits = "release" -debug = 2 -strip = false diff --git a/benchmarks/guest/fibonacci_iterative/elf/openvm-fibonacci-iterative-program.elf b/benchmarks/guest/fibonacci_iterative/elf/openvm-fibonacci-iterative-program.elf index ac9fbf3e89..7c681ee313 100755 Binary files a/benchmarks/guest/fibonacci_iterative/elf/openvm-fibonacci-iterative-program.elf and b/benchmarks/guest/fibonacci_iterative/elf/openvm-fibonacci-iterative-program.elf differ diff --git a/benchmarks/guest/fibonacci_iterative/src/main.rs b/benchmarks/guest/fibonacci_iterative/src/main.rs index 09ceb5df41..f7ab8ec0f6 100644 --- a/benchmarks/guest/fibonacci_iterative/src/main.rs +++ b/benchmarks/guest/fibonacci_iterative/src/main.rs @@ -1,15 +1,15 @@ use core::hint::black_box; -use openvm as _; +use openvm::io::reveal_u32; -const N: u64 = 100_000; +const N: u32 = 900_000; pub fn main() { - let mut a: u64 = 0; - let mut b: u64 = 1; + let mut a: u32 = 0; + let mut b: u32 = 1; for _ in 0..black_box(N) { - let c: u64 = a.wrapping_add(b); + let c: u32 = a.wrapping_add(b); a = b; b = c; } - black_box(a); + reveal_u32(a, 0); } diff --git a/benchmarks/guest/fibonacci_recursive/Cargo.toml b/benchmarks/guest/fibonacci_recursive/Cargo.toml index 95b124df43..2b8177d1c7 100644 --- a/benchmarks/guest/fibonacci_recursive/Cargo.toml +++ b/benchmarks/guest/fibonacci_recursive/Cargo.toml @@ -1,16 +1,10 @@ -[workspace] [package] name = "openvm-fibonacci-recursive-program" -version = "0.0.0" -edition = "2021" +version.workspace = true +edition.workspace = true [dependencies] -openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } +openvm = { workspace = true, features = ["std"] } [features] default = [] - -[profile.profiling] -inherits = "release" -debug = 2 -strip = false diff --git a/benchmarks/guest/fibonacci_recursive/elf/openvm-fibonacci-recursive-program.elf b/benchmarks/guest/fibonacci_recursive/elf/openvm-fibonacci-recursive-program.elf index 7dee9d4286..d14372657c 100755 Binary files a/benchmarks/guest/fibonacci_recursive/elf/openvm-fibonacci-recursive-program.elf and b/benchmarks/guest/fibonacci_recursive/elf/openvm-fibonacci-recursive-program.elf differ diff --git a/benchmarks/guest/fibonacci_recursive/src/main.rs b/benchmarks/guest/fibonacci_recursive/src/main.rs index fae64a1b0f..9020bc91ef 100644 --- a/benchmarks/guest/fibonacci_recursive/src/main.rs +++ b/benchmarks/guest/fibonacci_recursive/src/main.rs @@ -1,14 +1,15 @@ use core::hint::black_box; -use openvm as _; +use openvm::io::reveal_u32; -const N: u64 = 25; +const N: u32 = 27; pub fn main() { let n = black_box(N); - black_box(fibonacci(n)); + let result = fibonacci(n); + reveal_u32(result, 0); } -fn fibonacci(n: u64) -> u64 { +fn fibonacci(n: u32) -> u32 { if n == 0 { 0 } else if n == 1 { diff --git a/benchmarks/guest/keccak256/Cargo.toml b/benchmarks/guest/keccak256/Cargo.toml index 35bc10320a..486330ee17 100644 --- a/benchmarks/guest/keccak256/Cargo.toml +++ b/benchmarks/guest/keccak256/Cargo.toml @@ -1,18 +1,11 @@ -[workspace] [package] name = "openvm-keccak256-program" -version = "0.0.0" -edition = "2021" +version.workspace = true +edition.workspace = true [dependencies] -openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } -openvm-keccak256 = { path = "../../../guest-libs/keccak256" } +openvm = { workspace = true, features = ["std"] } +openvm-keccak256.workspace = true [features] default = [] - -[profile.profiling] -inherits = "release" -debug = 2 - -strip = false diff --git a/benchmarks/guest/keccak256/elf/openvm-keccak256-program.elf b/benchmarks/guest/keccak256/elf/openvm-keccak256-program.elf index 7425897f99..6e0fc26837 100755 Binary files a/benchmarks/guest/keccak256/elf/openvm-keccak256-program.elf and b/benchmarks/guest/keccak256/elf/openvm-keccak256-program.elf differ diff --git a/benchmarks/guest/keccak256/src/main.rs b/benchmarks/guest/keccak256/src/main.rs index 5a00ba4067..0d8c6d17b4 100644 --- a/benchmarks/guest/keccak256/src/main.rs +++ b/benchmarks/guest/keccak256/src/main.rs @@ -3,7 +3,7 @@ use openvm as _; use openvm_keccak256::keccak256; -const INPUT_LENGTH_BYTES: usize = 100 * 1024; // 100 KB +const INPUT_LENGTH_BYTES: usize = 384 * 1024; pub fn main() { let mut input = Vec::with_capacity(INPUT_LENGTH_BYTES); diff --git a/benchmarks/guest/keccak256_iter/Cargo.toml b/benchmarks/guest/keccak256_iter/Cargo.toml index 68c2cbb5dd..73e498e9cf 100644 --- a/benchmarks/guest/keccak256_iter/Cargo.toml +++ b/benchmarks/guest/keccak256_iter/Cargo.toml @@ -1,17 +1,11 @@ -[workspace] [package] name = "openvm-keccak256-iter-program" -version = "0.0.0" -edition = "2021" +version.workspace = true +edition.workspace = true [dependencies] -openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } -openvm-keccak256 = { path = "../../../guest-libs/keccak256" } +openvm = { workspace = true, features = ["std"] } +openvm-keccak256.workspace = true [features] default = [] - -[profile.profiling] -inherits = "release" -debug = 2 -strip = false diff --git a/benchmarks/guest/keccak256_iter/elf/openvm-keccak256-iter-program.elf b/benchmarks/guest/keccak256_iter/elf/openvm-keccak256-iter-program.elf index 0cf372eec3..7a267a02ab 100755 Binary files a/benchmarks/guest/keccak256_iter/elf/openvm-keccak256-iter-program.elf and b/benchmarks/guest/keccak256_iter/elf/openvm-keccak256-iter-program.elf differ diff --git a/benchmarks/guest/keccak256_iter/src/main.rs b/benchmarks/guest/keccak256_iter/src/main.rs index ef36ff1d64..554179819a 100644 --- a/benchmarks/guest/keccak256_iter/src/main.rs +++ b/benchmarks/guest/keccak256_iter/src/main.rs @@ -3,7 +3,7 @@ use openvm as _; use openvm_keccak256::keccak256; -const ITERATIONS: usize = 10_000; +const ITERATIONS: usize = 65_000; pub fn main() { // Initialize with hash of an empty vector diff --git a/benchmarks/guest/kitchen-sink/Cargo.toml b/benchmarks/guest/kitchen-sink/Cargo.toml index f699305cea..9088004f45 100644 --- a/benchmarks/guest/kitchen-sink/Cargo.toml +++ b/benchmarks/guest/kitchen-sink/Cargo.toml @@ -1,35 +1,25 @@ -[workspace] [package] name = "openvm-kitchen-sink-program" -version = "0.0.0" -edition = "2021" +version.workspace = true +edition.workspace = true [dependencies] -openvm = { path = "../../../crates/toolchain/openvm", default-features = false, features = [ +openvm = { workspace = true, features = [ "std", ] } -openvm-algebra-guest = { path = "../../../extensions/algebra/guest", default-features = false } -openvm-ecc-guest = { path = "../../../extensions/ecc/guest", default-features = false } -openvm-pairing = { path = "../../../guest-libs/pairing/", features = [ +openvm-algebra-guest.workspace = true +openvm-ecc-guest.workspace = true +openvm-pairing = { workspace = true, features = [ "bn254", "bls12_381", ] } -openvm-keccak256 = { path = "../../../guest-libs/keccak256/", default-features = false } -openvm-sha2 = { path = "../../../guest-libs/sha2/", default-features = false } -openvm-k256 = { path = "../../../guest-libs/k256/", package = "k256" } -openvm-p256 = { path = "../../../guest-libs/p256/", package = "p256" } -openvm-ruint = { path = "../../../guest-libs/ruint/", package = "ruint", default-features = false } +openvm-keccak256.workspace = true +openvm-sha2.workspace = true +openvm-k256.workspace = true +openvm-p256.workspace = true +openvm-ruint.workspace = true hex = { version = "0.4.3", default-features = false, features = ["alloc"] } serde = "1.0" [features] default = [] - -[profile.release] -panic = "abort" -lto = "thin" # faster compile time - -[profile.profiling] -inherits = "release" -debug = 2 -strip = false diff --git a/benchmarks/guest/kitchen-sink/elf/openvm-kitchen-sink-program.elf b/benchmarks/guest/kitchen-sink/elf/openvm-kitchen-sink-program.elf index 85f3509fa5..fb59df5d0a 100755 Binary files a/benchmarks/guest/kitchen-sink/elf/openvm-kitchen-sink-program.elf and b/benchmarks/guest/kitchen-sink/elf/openvm-kitchen-sink-program.elf differ diff --git a/benchmarks/guest/kitchen-sink/openvm.toml b/benchmarks/guest/kitchen-sink/openvm.toml index 2d1b307eef..a8a39a9fac 100644 --- a/benchmarks/guest/kitchen-sink/openvm.toml +++ b/benchmarks/guest/kitchen-sink/openvm.toml @@ -39,31 +39,35 @@ supported_moduli = [ ], ] -[[app_vm_config.ecc.supported_curves]] +[[app_vm_config.ecc.supported_sw_curves]] struct_name = "Secp256k1Point" modulus = "115792089237316195423570985008687907853269984665640564039457584007908834671663" scalar = "115792089237316195423570985008687907852837564279074904382605163141518161494337" +[app_vm_config.ecc.supported_sw_curves.coeffs] a = "0" b = "7" -[[app_vm_config.ecc.supported_curves]] +[[app_vm_config.ecc.supported_sw_curves]] struct_name = "P256Point" modulus = "115792089210356248762697446949407573530086143415290314195533631308867097853951" scalar = "115792089210356248762697446949407573529996955224135760342422259061068512044369" +[app_vm_config.ecc.supported_sw_curves.coeffs] a = "115792089210356248762697446949407573530086143415290314195533631308867097853948" b = "41058363725152142129326129780047268409114441015993725554835256314039467401291" -[[app_vm_config.ecc.supported_curves]] +[[app_vm_config.ecc.supported_sw_curves]] struct_name = "Bn254G1Affine" modulus = "21888242871839275222246405745257275088696311157297823662689037894645226208583" scalar = "21888242871839275222246405745257275088548364400416034343698204186575808495617" +[app_vm_config.ecc.supported_sw_curves.coeffs] a = "0" b = "3" -[[app_vm_config.ecc.supported_curves]] +[[app_vm_config.ecc.supported_sw_curves]] struct_name = "Bls12_381G1Affine" modulus = "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787" scalar = "52435875175126190479447740508185965837690552500527637822603658699938581184513" +[app_vm_config.ecc.supported_sw_curves.coeffs] a = "0" b = "4" diff --git a/benchmarks/guest/kitchen-sink/openvm_init.rs b/benchmarks/guest/kitchen-sink/openvm_init.rs index c4a80b3602..fdec4a6bdb 100644 --- a/benchmarks/guest/kitchen-sink/openvm_init.rs +++ b/benchmarks/guest/kitchen-sink/openvm_init.rs @@ -1,4 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "1000000000000000003", "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337", "115792089210356248762697446949407573530086143415290314195533631308867097853951", "115792089210356248762697446949407573529996955224135760342422259061068512044369", "21888242871839275222246405745257275088696311157297823662689037894645226208583", "21888242871839275222246405745257275088548364400416034343698204186575808495617", "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787", "52435875175126190479447740508185965837690552500527637822603658699938581184513", "2305843009213693951", "7" } -openvm_algebra_guest::complex_macros::complex_init! { Bn254Fp2 { mod_idx = 5 }, Bls12_381Fp2 { mod_idx = 7 } } -openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point, P256Point, Bn254G1Affine, Bls12_381G1Affine } +openvm_algebra_guest::complex_macros::complex_init! { "Bn254Fp2" { mod_idx = 5 }, "Bls12_381Fp2" { mod_idx = 7 } } +openvm_ecc_guest::sw_macros::sw_init! { "Secp256k1Point", "P256Point", "Bn254G1Affine", "Bls12_381G1Affine" } diff --git a/benchmarks/guest/pairing/Cargo.toml b/benchmarks/guest/pairing/Cargo.toml index dfd73f5eb6..f616b19399 100644 --- a/benchmarks/guest/pairing/Cargo.toml +++ b/benchmarks/guest/pairing/Cargo.toml @@ -1,30 +1,16 @@ -[workspace] [package] name = "openvm-pairing-program" -version = "0.0.0" -edition = "2021" +version.workspace = true +edition.workspace = true [dependencies] -openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } -openvm-algebra-guest = { path = "../../../extensions/algebra/guest", default-features = false } -openvm-ecc-guest = { path = "../../../extensions/ecc/guest", default-features = false } -openvm-pairing = { path = "../../../guest-libs/pairing/", default-features = false, features = [ - "bn254", -] } -openvm-pairing-guest = { path = "../../../extensions/pairing/guest", default-features = false, features = [ +openvm = { workspace = true, features = ["std"] } +openvm-algebra-guest.workspace = true +openvm-ecc-guest.workspace = true +openvm-pairing = { workspace = true, features = [ "bn254", ] } hex = { version = "0.4.3", default-features = false, features = ["alloc"] } [features] default = [] -halo2curves = ["openvm-pairing-guest/halo2curves"] - -[profile.release] -panic = "abort" -lto = "thin" # faster compile time - -[profile.profiling] -inherits = "release" -debug = 2 -strip = false diff --git a/benchmarks/guest/pairing/elf/openvm-pairing-program.elf b/benchmarks/guest/pairing/elf/openvm-pairing-program.elf index bf30d5a003..69c3cd0106 100755 Binary files a/benchmarks/guest/pairing/elf/openvm-pairing-program.elf and b/benchmarks/guest/pairing/elf/openvm-pairing-program.elf differ diff --git a/benchmarks/guest/pairing/openvm.toml b/benchmarks/guest/pairing/openvm.toml index 321383b8eb..f34872b531 100644 --- a/benchmarks/guest/pairing/openvm.toml +++ b/benchmarks/guest/pairing/openvm.toml @@ -21,9 +21,10 @@ supported_moduli = [ supported_curves = ["Bn254"] # bn254 (alt bn128) -[[app_vm_config.ecc.supported_curves]] +[[app_vm_config.ecc.supported_sw_curves]] struct_name = "Bn254G1Affine" modulus = "21888242871839275222246405745257275088696311157297823662689037894645226208583" scalar = "21888242871839275222246405745257275088548364400416034343698204186575808495617" +[app_vm_config.ecc.supported_sw_curves.coeffs] a = "0" b = "3" diff --git a/benchmarks/guest/pairing/openvm_init.rs b/benchmarks/guest/pairing/openvm_init.rs index 5baf894946..25f8df7f78 100644 --- a/benchmarks/guest/pairing/openvm_init.rs +++ b/benchmarks/guest/pairing/openvm_init.rs @@ -1,4 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "21888242871839275222246405745257275088696311157297823662689037894645226208583", "21888242871839275222246405745257275088548364400416034343698204186575808495617" } -openvm_algebra_guest::complex_macros::complex_init! { Bn254Fp2 { mod_idx = 0 } } -openvm_ecc_guest::sw_macros::sw_init! { Bn254G1Affine } +openvm_algebra_guest::complex_macros::complex_init! { "Bn254Fp2" { mod_idx = 0 } } +openvm_ecc_guest::sw_macros::sw_init! { "Bn254G1Affine" } diff --git a/benchmarks/guest/pairing/src/main.rs b/benchmarks/guest/pairing/src/main.rs index 2b30297248..09e4a259dd 100644 --- a/benchmarks/guest/pairing/src/main.rs +++ b/benchmarks/guest/pairing/src/main.rs @@ -1,9 +1,9 @@ use openvm_algebra_guest::IntMod; use openvm_ecc_guest::AffinePoint; #[allow(unused_imports)] -use { - openvm_pairing::bn254::{Bn254, Bn254G1Affine, Fp, Fp2}, - openvm_pairing_guest::pairing::PairingCheck, +use openvm_pairing::{ + bn254::{Bn254, Bn254G1Affine, Fp, Fp2}, + PairingCheck, }; openvm::init!(); diff --git a/benchmarks/guest/quicksort/Cargo.toml b/benchmarks/guest/quicksort/Cargo.toml index 8556264be0..729208640f 100644 --- a/benchmarks/guest/quicksort/Cargo.toml +++ b/benchmarks/guest/quicksort/Cargo.toml @@ -1,16 +1,10 @@ -[workspace] [package] name = "openvm-quicksort-program" -version = "0.0.0" -edition = "2021" +version.workspace = true +edition.workspace = true [dependencies] -openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } +openvm = { workspace = true, features = ["std"] } [features] default = [] - -[profile.profiling] -inherits = "release" -debug = 2 -strip = false diff --git a/benchmarks/guest/quicksort/elf/openvm-quicksort-program.elf b/benchmarks/guest/quicksort/elf/openvm-quicksort-program.elf index 54af6272d6..0e7d6e6143 100755 Binary files a/benchmarks/guest/quicksort/elf/openvm-quicksort-program.elf and b/benchmarks/guest/quicksort/elf/openvm-quicksort-program.elf differ diff --git a/benchmarks/guest/quicksort/src/main.rs b/benchmarks/guest/quicksort/src/main.rs index 30218cf40e..a6579306c7 100644 --- a/benchmarks/guest/quicksort/src/main.rs +++ b/benchmarks/guest/quicksort/src/main.rs @@ -1,7 +1,7 @@ use core::hint::black_box; use openvm as _; -const ARRAY_SIZE: usize = 1_000; +const ARRAY_SIZE: usize = 3_500; fn quicksort(arr: &mut [T]) { if arr.len() <= 1 { diff --git a/benchmarks/guest/regex/Cargo.toml b/benchmarks/guest/regex/Cargo.toml index 40831a592d..1ffb3fb440 100644 --- a/benchmarks/guest/regex/Cargo.toml +++ b/benchmarks/guest/regex/Cargo.toml @@ -1,18 +1,12 @@ -[workspace] [package] -version = "0.1.0" name = "openvm-regex-program" -edition = "2021" +version.workspace = true +edition.workspace = true [dependencies] -openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } -openvm-keccak256 = { path = "../../../guest-libs/keccak256/" } +openvm = { workspace = true, features = ["std"] } +openvm-keccak256.workspace = true regex = { version = "1.11.1", default-features = false } [features] default = [] - -[profile.profiling] -inherits = "release" -debug = 2 -strip = false diff --git a/benchmarks/guest/regex/elf/openvm-regex-program.elf b/benchmarks/guest/regex/elf/openvm-regex-program.elf index 6e6074e079..05388a8223 100755 Binary files a/benchmarks/guest/regex/elf/openvm-regex-program.elf and b/benchmarks/guest/regex/elf/openvm-regex-program.elf differ diff --git a/benchmarks/guest/revm_snailtracer/Cargo.toml b/benchmarks/guest/revm_snailtracer/Cargo.toml index e37595eb36..6f0a5176b5 100644 --- a/benchmarks/guest/revm_snailtracer/Cargo.toml +++ b/benchmarks/guest/revm_snailtracer/Cargo.toml @@ -1,11 +1,10 @@ -[workspace] [package] name = "openvm-revm-snailtracer" -version = "0.0.0" -edition = "2021" +version.workspace = true +edition.workspace = true [dependencies] -openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } +openvm = { workspace = true, features = ["std"] } revm = { version = "18.0.0", default-features = false } # revm does not re-export this feature so we enable it here derive_more = { version = "1.0.0", default-features = false, features = [ @@ -15,8 +14,3 @@ derive_more = { version = "1.0.0", default-features = false, features = [ [features] default = [] - -[profile.profiling] -inherits = "release" -debug = 2 -strip = false diff --git a/benchmarks/guest/revm_snailtracer/elf/openvm-revm-snailtracer.elf b/benchmarks/guest/revm_snailtracer/elf/openvm-revm-snailtracer.elf index 9255290412..26e1d4c515 100755 Binary files a/benchmarks/guest/revm_snailtracer/elf/openvm-revm-snailtracer.elf and b/benchmarks/guest/revm_snailtracer/elf/openvm-revm-snailtracer.elf differ diff --git a/benchmarks/guest/revm_transfer/Cargo.toml b/benchmarks/guest/revm_transfer/Cargo.toml index eea02dd155..c7dc11bdec 100644 --- a/benchmarks/guest/revm_transfer/Cargo.toml +++ b/benchmarks/guest/revm_transfer/Cargo.toml @@ -1,13 +1,12 @@ -[workspace] [package] name = "openvm-revm-transfer" -version = "0.0.0" -edition = "2021" +version.workspace = true +edition.workspace = true [dependencies] +openvm = { workspace = true, features = ["std"] } +openvm-keccak256.workspace = true revm = { version = "18.0.0", default-features = false } -openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } -openvm-keccak256-guest = { path = "../../../extensions/keccak256/guest", default-features = false } tracing = { version = "0.1.40", default-features = false } alloy-primitives = { version = "0.8.10", default-features = false, features = [ "native-keccak", @@ -20,8 +19,3 @@ derive_more = { version = "1.0.0", default-features = false, features = [ [features] default = [] - -[profile.profiling] -inherits = "release" -debug = 2 -strip = false diff --git a/benchmarks/guest/revm_transfer/elf/openvm-revm-transfer.elf b/benchmarks/guest/revm_transfer/elf/openvm-revm-transfer.elf index 0aa22396e6..96f7d328e9 100755 Binary files a/benchmarks/guest/revm_transfer/elf/openvm-revm-transfer.elf and b/benchmarks/guest/revm_transfer/elf/openvm-revm-transfer.elf differ diff --git a/benchmarks/guest/revm_transfer/src/main.rs b/benchmarks/guest/revm_transfer/src/main.rs index ff725efca3..88189cc3cd 100644 --- a/benchmarks/guest/revm_transfer/src/main.rs +++ b/benchmarks/guest/revm_transfer/src/main.rs @@ -2,7 +2,7 @@ //! We run 100 transfers to take the average use alloy_primitives::{address, TxKind, U256}; #[allow(unused_imports, clippy::single_component_path_imports)] -use openvm_keccak256_guest; // export native keccak +use openvm_keccak256; // export native keccak use revm::{db::BenchmarkDB, primitives::Bytecode, Evm}; // Necessary so the linker doesn't skip importing openvm crate diff --git a/benchmarks/guest/rkyv/Cargo.toml b/benchmarks/guest/rkyv/Cargo.toml index c061e59e0f..ee61726cca 100644 --- a/benchmarks/guest/rkyv/Cargo.toml +++ b/benchmarks/guest/rkyv/Cargo.toml @@ -1,11 +1,10 @@ -[workspace] [package] name = "openvm-rkyv-program" -version = "0.0.0" -edition = "2021" +version.workspace = true +edition.workspace = true [dependencies] -openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } +openvm = { workspace = true, features = ["std"] } rand = { version = "0.8.5", default-features = false } rand_pcg = "0.3.1" rkyv = { version = "0.8.8", default-features = false, features = [ @@ -15,8 +14,3 @@ rkyv = { version = "0.8.8", default-features = false, features = [ [features] default = [] - -[profile.profiling] -inherits = "release" -debug = 2 -strip = false diff --git a/benchmarks/guest/rkyv/elf/openvm-rkyv-program.elf b/benchmarks/guest/rkyv/elf/openvm-rkyv-program.elf index 528106e233..f2b7f8d95d 100755 Binary files a/benchmarks/guest/rkyv/elf/openvm-rkyv-program.elf and b/benchmarks/guest/rkyv/elf/openvm-rkyv-program.elf differ diff --git a/benchmarks/guest/sha256/Cargo.toml b/benchmarks/guest/sha256/Cargo.toml index 1d5491f35a..4b711f2589 100644 --- a/benchmarks/guest/sha256/Cargo.toml +++ b/benchmarks/guest/sha256/Cargo.toml @@ -1,17 +1,11 @@ -[workspace] [package] name = "openvm-sha256-program" -version = "0.0.0" -edition = "2021" +version.workspace = true +edition.workspace = true [dependencies] -openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } -openvm-sha2 = { path = "../../../guest-libs/sha2" } +openvm = { workspace = true, features = ["std"] } +openvm-sha2.workspace = true [features] default = [] - -[profile.profiling] -inherits = "release" -debug = 2 -strip = false diff --git a/benchmarks/guest/sha256/elf/openvm-sha256-program.elf b/benchmarks/guest/sha256/elf/openvm-sha256-program.elf index 9524e8f552..2c03e2dad6 100755 Binary files a/benchmarks/guest/sha256/elf/openvm-sha256-program.elf and b/benchmarks/guest/sha256/elf/openvm-sha256-program.elf differ diff --git a/benchmarks/guest/sha256/src/main.rs b/benchmarks/guest/sha256/src/main.rs index 0178771d09..fc0b3fab78 100644 --- a/benchmarks/guest/sha256/src/main.rs +++ b/benchmarks/guest/sha256/src/main.rs @@ -3,7 +3,7 @@ use openvm as _; use openvm_sha2::sha256; -const INPUT_LENGTH_BYTES: usize = 100 * 1024; // 100 KB +const INPUT_LENGTH_BYTES: usize = 384 * 1024; pub fn main() { let mut input = Vec::with_capacity(INPUT_LENGTH_BYTES); diff --git a/benchmarks/guest/sha256_iter/Cargo.toml b/benchmarks/guest/sha256_iter/Cargo.toml index 8e0273858a..7934b46f66 100644 --- a/benchmarks/guest/sha256_iter/Cargo.toml +++ b/benchmarks/guest/sha256_iter/Cargo.toml @@ -1,17 +1,11 @@ -[workspace] [package] name = "openvm-sha256-iter-program" -version = "0.0.0" -edition = "2021" +version.workspace = true +edition.workspace = true [dependencies] -openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } -openvm-sha2 = { path = "../../../guest-libs/sha2" } +openvm = { workspace = true, features = ["std"] } +openvm-sha2.workspace = true [features] default = [] - -[profile.profiling] -inherits = "release" -debug = 2 -strip = false diff --git a/benchmarks/guest/sha256_iter/elf/openvm-sha256-iter-program.elf b/benchmarks/guest/sha256_iter/elf/openvm-sha256-iter-program.elf index 95b469ece5..677d9a3b7a 100755 Binary files a/benchmarks/guest/sha256_iter/elf/openvm-sha256-iter-program.elf and b/benchmarks/guest/sha256_iter/elf/openvm-sha256-iter-program.elf differ diff --git a/benchmarks/guest/sha256_iter/src/main.rs b/benchmarks/guest/sha256_iter/src/main.rs index 0b495a58a8..aea8b723e9 100644 --- a/benchmarks/guest/sha256_iter/src/main.rs +++ b/benchmarks/guest/sha256_iter/src/main.rs @@ -1,13 +1,13 @@ use core::hint::black_box; -use openvm as _; +use openvm as _; use openvm_sha2::sha256; -const ITERATIONS: usize = 20_000; +const ITERATIONS: usize = 150_000; pub fn main() { // Initialize with hash of an empty vector - let mut hash = black_box(sha256(&vec![])); + let mut hash = black_box(sha256(&[])); // Iteratively apply sha256 for _ in 0..ITERATIONS { diff --git a/benchmarks/prove/Cargo.toml b/benchmarks/prove/Cargo.toml index 9e745d3d80..0861133d09 100644 --- a/benchmarks/prove/Cargo.toml +++ b/benchmarks/prove/Cargo.toml @@ -10,20 +10,11 @@ license.workspace = true [dependencies] openvm-benchmarks-utils.workspace = true openvm-circuit.workspace = true +openvm-continuations.workspace = true openvm-sdk.workspace = true openvm-stark-backend.workspace = true openvm-stark-sdk.workspace = true openvm-transpiler.workspace = true -openvm-rv32im-circuit.workspace = true -openvm-rv32im-transpiler.workspace = true -openvm-keccak256-circuit.workspace = true -openvm-keccak256-transpiler.workspace = true -openvm-algebra-circuit.workspace = true -openvm-algebra-transpiler.workspace = true -openvm-ecc-circuit.workspace = true -openvm-ecc-transpiler.workspace = true -openvm-pairing-circuit.workspace = true -openvm-pairing-guest.workspace = true openvm-native-circuit.workspace = true openvm-native-compiler.workspace = true openvm-native-recursion = { workspace = true, features = ["test-utils"] } @@ -34,19 +25,19 @@ tokio = { version = "1.43.1", features = ["rt", "rt-multi-thread", "macros"] } rand_chacha = { version = "0.3", default-features = false } k256 = { workspace = true, features = ["ecdsa"] } tiny-keccak.workspace = true -derive-new.workspace = true derive_more = { workspace = true, features = ["from"] } -num-bigint = { workspace = true, features = ["std", "serde"] } -serde.workspace = true +rand.workspace = true tracing.workspace = true [dev-dependencies] [features] -default = ["parallel", "jemalloc", "bench-metrics"] -bench-metrics = ["openvm-sdk/bench-metrics"] -profiling = ["openvm-sdk/profiling"] -aggregation = [] # runs leaf aggregation benchmarks +default = ["parallel", "jemalloc", "metrics"] +metrics = ["openvm-sdk/metrics"] +perf-metrics = ["openvm-sdk/perf-metrics", "metrics"] +stark-debug = ["openvm-sdk/stark-debug"] +# runs leaf aggregation benchmarks: +aggregation = [] evm = ["openvm-sdk/evm-verify"] parallel = ["openvm-sdk/parallel"] mimalloc = ["openvm-sdk/mimalloc"] @@ -55,7 +46,7 @@ jemalloc-prof = ["openvm-sdk/jemalloc-prof"] nightly-features = ["openvm-sdk/nightly-features"] [package.metadata.cargo-shear] -ignored = ["derive_more"] +ignored = ["derive_more", "rand"] [[bin]] name = "fib_e2e" diff --git a/benchmarks/prove/src/bin/base64_json.rs b/benchmarks/prove/src/bin/base64_json.rs index ed366e51ca..f1e86ff571 100644 --- a/benchmarks/prove/src/bin/base64_json.rs +++ b/benchmarks/prove/src/bin/base64_json.rs @@ -2,33 +2,31 @@ use clap::Parser; use eyre::Result; use openvm_benchmarks_prove::util::BenchmarkCli; use openvm_circuit::arch::instructions::exe::VmExe; -use openvm_keccak256_circuit::Keccak256Rv32Config; -use openvm_keccak256_transpiler::Keccak256TranspilerExtension; -use openvm_rv32im_transpiler::{ - Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, +use openvm_sdk::{ + config::{SdkVmConfig, SdkVmCpuBuilder}, + StdIn, }; -use openvm_sdk::StdIn; -use openvm_stark_sdk::{bench::run_with_metric_collection, p3_baby_bear::BabyBear}; -use openvm_transpiler::{transpiler::Transpiler, FromElf}; +use openvm_stark_sdk::bench::run_with_metric_collection; +use openvm_transpiler::FromElf; fn main() -> Result<()> { let args = BenchmarkCli::parse(); - let config = Keccak256Rv32Config::default(); + let config = SdkVmConfig::from_toml(include_str!("../../../guest/base64_json/openvm.toml"))? + .app_vm_config; let elf = args.build_bench_program("base64_json", &config, None)?; - let exe = VmExe::from_elf( - elf, - Transpiler::::default() - .with_extension(Rv32ITranspilerExtension) - .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension) - .with_extension(Keccak256TranspilerExtension), - )?; + let exe = VmExe::from_elf(elf, config.transpiler())?; run_with_metric_collection("OUTPUT_PATH", || -> Result<()> { let data = include_str!("../../../guest/base64_json/json_payload_encoded.txt"); let fe_bytes = data.to_owned().into_bytes(); - args.bench_from_exe("base64_json", config, exe, StdIn::from_bytes(&fe_bytes)) + args.bench_from_exe( + "base64_json", + SdkVmCpuBuilder, + config, + exe, + StdIn::from_bytes(&fe_bytes), + ) }) } diff --git a/benchmarks/prove/src/bin/bincode.rs b/benchmarks/prove/src/bin/bincode.rs index 3cc419c1e1..810812cc40 100644 --- a/benchmarks/prove/src/bin/bincode.rs +++ b/benchmarks/prove/src/bin/bincode.rs @@ -2,29 +2,23 @@ use clap::Parser; use eyre::Result; use openvm_benchmarks_prove::util::BenchmarkCli; use openvm_circuit::arch::instructions::exe::VmExe; -use openvm_rv32im_circuit::Rv32ImConfig; -use openvm_rv32im_transpiler::{ - Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, +use openvm_sdk::{ + config::{SdkVmConfig, SdkVmCpuBuilder}, + StdIn, }; -use openvm_sdk::StdIn; -use openvm_stark_sdk::{bench::run_with_metric_collection, p3_baby_bear::BabyBear}; -use openvm_transpiler::{transpiler::Transpiler, FromElf}; +use openvm_stark_sdk::bench::run_with_metric_collection; +use openvm_transpiler::FromElf; fn main() -> Result<()> { let args = BenchmarkCli::parse(); - let config = Rv32ImConfig::default(); + let config = + SdkVmConfig::from_toml(include_str!("../../../guest/bincode/openvm.toml"))?.app_vm_config; let elf = args.build_bench_program("bincode", &config, None)?; - let exe = VmExe::from_elf( - elf, - Transpiler::::default() - .with_extension(Rv32ITranspilerExtension) - .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), - )?; + let exe = VmExe::from_elf(elf, config.transpiler())?; run_with_metric_collection("OUTPUT_PATH", || -> Result<()> { let file_data = include_bytes!("../../../guest/bincode/minecraft_savedata.bin"); let stdin = StdIn::from_bytes(file_data); - args.bench_from_exe("bincode", config, exe, stdin) + args.bench_from_exe("bincode", SdkVmCpuBuilder, config, exe, stdin) }) } diff --git a/benchmarks/prove/src/bin/ecrecover.rs b/benchmarks/prove/src/bin/ecrecover.rs index 23fe2c82af..3f3a4120ad 100644 --- a/benchmarks/prove/src/bin/ecrecover.rs +++ b/benchmarks/prove/src/bin/ecrecover.rs @@ -1,35 +1,13 @@ use clap::Parser; use eyre::Result; use k256::ecdsa::{SigningKey, VerifyingKey}; -use num_bigint::BigUint; -use openvm_algebra_circuit::{ - ModularExtension, ModularExtensionExecutor, ModularExtensionPeriphery, -}; -use openvm_algebra_transpiler::ModularTranspilerExtension; use openvm_benchmarks_prove::util::BenchmarkCli; -use openvm_circuit::{ - arch::{instructions::exe::VmExe, InitFileGenerator, SystemConfig}, - derive::VmConfig, -}; -use openvm_ecc_circuit::{ - CurveConfig, WeierstrassExtension, WeierstrassExtensionExecutor, WeierstrassExtensionPeriphery, - SECP256K1_CONFIG, -}; -use openvm_ecc_transpiler::EccTranspilerExtension; -use openvm_keccak256_circuit::{Keccak256, Keccak256Executor, Keccak256Periphery}; -use openvm_keccak256_transpiler::Keccak256TranspilerExtension; -use openvm_rv32im_circuit::{ - Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, - Rv32MExecutor, Rv32MPeriphery, -}; -use openvm_rv32im_transpiler::{ - Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, -}; -use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32}; +use openvm_circuit::arch::instructions::exe::VmExe; +use openvm_sdk::config::{SdkVmConfig, SdkVmCpuBuilder}; +use openvm_stark_backend::p3_field::FieldAlgebra; use openvm_stark_sdk::{bench::run_with_metric_collection, p3_baby_bear::BabyBear}; -use openvm_transpiler::{transpiler::Transpiler, FromElf}; +use openvm_transpiler::FromElf; use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng}; -use serde::{Deserialize, Serialize}; use tiny_keccak::{Hasher, Keccak}; fn make_input(signing_key: &SigningKey, msg: &[u8]) -> Vec { @@ -48,68 +26,13 @@ fn make_input(signing_key: &SigningKey, msg: &[u8]) -> Vec { input.into_iter().map(BabyBear::from_canonical_u8).collect() } -#[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] -pub struct Rv32ImEcRecoverConfig { - #[system] - pub system: SystemConfig, - #[extension] - pub base: Rv32I, - #[extension] - pub mul: Rv32M, - #[extension] - pub io: Rv32Io, - #[extension] - pub modular: ModularExtension, - #[extension] - pub keccak: Keccak256, - #[extension] - pub weierstrass: WeierstrassExtension, -} - -impl InitFileGenerator for Rv32ImEcRecoverConfig { - fn generate_init_file_contents(&self) -> Option { - Some(format!( - "// This file is automatically generated by cargo openvm. Do not rename or edit.\n{}\n{}\n", - self.modular.generate_moduli_init(), - self.weierstrass.generate_sw_init() - )) - } -} - -impl Rv32ImEcRecoverConfig { - pub fn for_curves(curves: Vec) -> Self { - let primes: Vec = curves - .iter() - .flat_map(|c| [c.modulus.clone(), c.scalar.clone()]) - .collect(); - Self { - system: SystemConfig::default().with_continuations(), - base: Default::default(), - mul: Default::default(), - io: Default::default(), - modular: ModularExtension::new(primes), - keccak: Default::default(), - weierstrass: WeierstrassExtension::new(curves), - } - } -} - fn main() -> Result<()> { let args = BenchmarkCli::parse(); - let config = Rv32ImEcRecoverConfig::for_curves(vec![SECP256K1_CONFIG.clone()]); - + let config = + SdkVmConfig::from_toml(include_str!("../../../guest/ecrecover/openvm.toml"))?.app_vm_config; let elf = args.build_bench_program("ecrecover", &config, None)?; - let exe = VmExe::from_elf( - elf, - Transpiler::::default() - .with_extension(Rv32ITranspilerExtension) - .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension) - .with_extension(Keccak256TranspilerExtension) - .with_extension(ModularTranspilerExtension) - .with_extension(EccTranspilerExtension), - )?; + let exe = VmExe::from_elf(elf, config.transpiler())?; run_with_metric_collection("OUTPUT_PATH", || -> Result<()> { let mut rng = ChaCha8Rng::seed_from_u64(12345); @@ -135,6 +58,12 @@ fn main() -> Result<()> { .map(|s| make_input(&signing_key, s.as_bytes())) .collect::>(), ); - args.bench_from_exe("ecrecover_program", config, exe, input_stream.into()) + args.bench_from_exe( + "ecrecover_program", + SdkVmCpuBuilder, + config, + exe, + input_stream.into(), + ) }) } diff --git a/benchmarks/prove/src/bin/fib_e2e.rs b/benchmarks/prove/src/bin/fib_e2e.rs index 41611d0970..b88dfb335f 100644 --- a/benchmarks/prove/src/bin/fib_e2e.rs +++ b/benchmarks/prove/src/bin/fib_e2e.rs @@ -3,19 +3,22 @@ use std::{path::PathBuf, sync::Arc}; use clap::Parser; use eyre::Result; use openvm_benchmarks_prove::util::BenchmarkCli; -use openvm_circuit::arch::{instructions::exe::VmExe, DEFAULT_MAX_NUM_PUBLIC_VALUES}; -use openvm_native_recursion::halo2::utils::{CacheHalo2ParamsReader, DEFAULT_PARAMS_DIR}; -use openvm_rv32im_circuit::Rv32ImConfig; -use openvm_rv32im_transpiler::{ - Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, +use openvm_circuit::arch::{ + execution_mode::metered::segment_ctx::SegmentationLimits, instructions::exe::VmExe, + DEFAULT_MAX_NUM_PUBLIC_VALUES, }; +use openvm_native_circuit::NativeCpuBuilder; +use openvm_native_recursion::halo2::utils::{CacheHalo2ParamsReader, DEFAULT_PARAMS_DIR}; use openvm_sdk::{ - commit::commit_app_exe, prover::EvmHalo2Prover, DefaultStaticVerifierPvHandler, Sdk, StdIn, + commit::commit_app_exe, + config::{SdkVmConfig, SdkVmCpuBuilder}, + prover::EvmHalo2Prover, + DefaultStaticVerifierPvHandler, Sdk, StdIn, }; use openvm_stark_sdk::{ bench::run_with_metric_collection, config::baby_bear_poseidon2::BabyBearPoseidon2Engine, }; -use openvm_transpiler::{transpiler::Transpiler, FromElf}; +use openvm_transpiler::FromElf; const NUM_PUBLIC_VALUES: usize = DEFAULT_MAX_NUM_PUBLIC_VALUES; @@ -26,12 +29,16 @@ async fn main() -> Result<()> { // Must be larger than RangeTupleCheckerAir.height == 524288 let max_segment_length = args.max_segment_length.unwrap_or(1_000_000); - let app_config = args.app_config(Rv32ImConfig::with_public_values_and_segment_len( - NUM_PUBLIC_VALUES, - max_segment_length, - )); - let elf = args.build_bench_program("fibonacci", &app_config.app_vm_config, None)?; + let mut config = + SdkVmConfig::from_toml(include_str!("../../../guest/fibonacci/openvm.toml"))?.app_vm_config; + config.as_mut().set_segmentation_limits( + SegmentationLimits::default().with_max_trace_height(max_segment_length as u32), + ); + config.as_mut().num_public_values = NUM_PUBLIC_VALUES; + let elf = args.build_bench_program("fibonacci", &config, None)?; + let exe = VmExe::from_elf(elf, config.transpiler())?; + let app_config = args.app_config(config); let agg_config = args.agg_config(); let sdk = Sdk::new(); @@ -46,29 +53,24 @@ async fn main() -> Result<()> { &halo2_params_reader, &DefaultStaticVerifierPvHandler, )?; - let exe = VmExe::from_elf( - elf, - Transpiler::default() - .with_extension(Rv32ITranspilerExtension) - .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), - )?; let app_committed_exe = commit_app_exe(app_pk.app_fri_params(), exe); let n = 800_000u64; let mut stdin = StdIn::default(); stdin.write(&n); run_with_metric_collection("OUTPUT_PATH", || { - let mut e2e_prover = EvmHalo2Prover::<_, BabyBearPoseidon2Engine>::new( + let mut e2e_prover = EvmHalo2Prover::::new( &halo2_params_reader, + SdkVmCpuBuilder, + NativeCpuBuilder, app_pk, app_committed_exe, full_agg_pk, args.agg_tree_config, - ); + )?; e2e_prover.set_program_name("fib_e2e"); - let _proof = e2e_prover.generate_proof_for_evm(stdin); - }); + e2e_prover.generate_proof_for_evm(stdin) + })?; Ok(()) } diff --git a/benchmarks/prove/src/bin/fibonacci.rs b/benchmarks/prove/src/bin/fibonacci.rs index 1c886d8130..6a0a44f00d 100644 --- a/benchmarks/prove/src/bin/fibonacci.rs +++ b/benchmarks/prove/src/bin/fibonacci.rs @@ -2,31 +2,25 @@ use clap::Parser; use eyre::Result; use openvm_benchmarks_prove::util::BenchmarkCli; use openvm_circuit::arch::instructions::exe::VmExe; -use openvm_rv32im_circuit::Rv32ImConfig; -use openvm_rv32im_transpiler::{ - Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, +use openvm_sdk::{ + config::{SdkVmConfig, SdkVmCpuBuilder}, + StdIn, }; -use openvm_sdk::StdIn; -use openvm_stark_sdk::{bench::run_with_metric_collection, p3_baby_bear::BabyBear}; -use openvm_transpiler::{transpiler::Transpiler, FromElf}; +use openvm_stark_sdk::bench::run_with_metric_collection; +use openvm_transpiler::FromElf; fn main() -> Result<()> { let args = BenchmarkCli::parse(); - let config = Rv32ImConfig::default(); + let config = + SdkVmConfig::from_toml(include_str!("../../../guest/fibonacci/openvm.toml"))?.app_vm_config; let elf = args.build_bench_program("fibonacci", &config, None)?; - let exe = VmExe::from_elf( - elf, - Transpiler::::default() - .with_extension(Rv32ITranspilerExtension) - .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), - )?; + let exe = VmExe::from_elf(elf, config.transpiler())?; run_with_metric_collection("OUTPUT_PATH", || -> Result<()> { let n = 100_000u64; let mut stdin = StdIn::default(); stdin.write(&n); - args.bench_from_exe("fibonacci_program", config, exe, stdin) + args.bench_from_exe("fibonacci_program", SdkVmCpuBuilder, config, exe, stdin) }) } diff --git a/benchmarks/prove/src/bin/kitchen_sink.rs b/benchmarks/prove/src/bin/kitchen_sink.rs index 3102c9e3fe..4385f55050 100644 --- a/benchmarks/prove/src/bin/kitchen_sink.rs +++ b/benchmarks/prove/src/bin/kitchen_sink.rs @@ -1,73 +1,94 @@ -use std::{path::PathBuf, str::FromStr, sync::Arc}; +use std::{path::PathBuf, sync::Arc}; use clap::Parser; use eyre::Result; -use num_bigint::BigUint; -use openvm_algebra_circuit::{Fp2Extension, ModularExtension}; use openvm_benchmarks_prove::util::BenchmarkCli; -use openvm_circuit::arch::{instructions::exe::VmExe, SystemConfig}; -use openvm_ecc_circuit::{WeierstrassExtension, P256_CONFIG, SECP256K1_CONFIG}; +use openvm_circuit::{arch::instructions::exe::VmExe, system::program::trace::VmCommittedExe}; +use openvm_continuations::verifier::leaf::types::LeafVmVerifierInput; +use openvm_native_circuit::{NativeConfig, NativeCpuBuilder, NATIVE_MAX_TRACE_HEIGHTS}; use openvm_native_recursion::halo2::utils::{CacheHalo2ParamsReader, DEFAULT_PARAMS_DIR}; -use openvm_pairing_circuit::{PairingCurve, PairingExtension}; -use openvm_pairing_guest::{ - bls12_381::BLS12_381_COMPLEX_STRUCT_NAME, bn254::BN254_COMPLEX_STRUCT_NAME, -}; use openvm_sdk::{ - commit::commit_app_exe, config::SdkVmConfig, prover::EvmHalo2Prover, - DefaultStaticVerifierPvHandler, Sdk, StdIn, + commit::commit_app_exe, + config::{SdkVmConfig, SdkVmCpuBuilder}, + keygen::AppProvingKey, + prover::{ + vm::{new_local_prover, types::VmProvingKey}, + EvmHalo2Prover, + }, + DefaultStaticVerifierPvHandler, Sdk, StdIn, SC, }; use openvm_stark_sdk::{ bench::run_with_metric_collection, config::baby_bear_poseidon2::BabyBearPoseidon2Engine, }; use openvm_transpiler::FromElf; +fn verify_native_max_trace_heights( + sdk: &Sdk, + app_pk: Arc>, + app_committed_exe: Arc>, + leaf_vm_pk: Arc>, + num_children_leaf: usize, +) -> Result<()> { + let app_proof = sdk.generate_app_proof( + SdkVmCpuBuilder, + app_pk.clone(), + app_committed_exe.clone(), + StdIn::default(), + )?; + let leaf_inputs = + LeafVmVerifierInput::chunk_continuation_vm_proof(&app_proof, num_children_leaf); + let mut leaf_prover = new_local_prover::( + NativeCpuBuilder, + &leaf_vm_pk, + &app_pk.leaf_committed_exe, + )?; + let executor_idx_to_air_idx = leaf_prover.vm.executor_idx_to_air_idx(); + + for leaf_input in leaf_inputs { + let exe = leaf_prover.exe().clone(); + let vm = &mut leaf_prover.vm; + let metered_ctx = vm.build_metered_ctx(); + let (segments, _) = vm + .executor() + .metered_instance(&exe, &executor_idx_to_air_idx)? + .execute_metered(leaf_input.write_to_stream(), metered_ctx)?; + assert_eq!(segments.len(), 1); + let estimated_trace_heights = &segments[0].trace_heights; + println!("estimated_trace_heights: {:?}", estimated_trace_heights); + + // Tracegen without proving since leaf proofs take a while + let state = vm.create_initial_state(&exe, leaf_input.write_to_stream()); + vm.transport_init_memory_to_device(&state.memory); + let out = vm.execute_preflight(&exe, state, None, estimated_trace_heights)?; + let actual_trace_heights = vm + .generate_proving_ctx(out.system_records, out.record_arenas)? + .per_air + .into_iter() + .map(|(_, air_ctx)| air_ctx.main_trace_height()) + .collect::>(); + println!("actual_trace_heights: {:?}", actual_trace_heights); + + actual_trace_heights + .iter() + .zip(NATIVE_MAX_TRACE_HEIGHTS) + .for_each(|(&actual, &expected)| { + assert!( + actual <= (expected as usize), + "Actual trace height {} exceeds expected height {}", + actual, + expected + ); + }); + } + Ok(()) +} + fn main() -> Result<()> { let args = BenchmarkCli::parse(); - let bn_config = PairingCurve::Bn254.curve_config(); - let bls_config = PairingCurve::Bls12_381.curve_config(); - let vm_config = SdkVmConfig::builder() - .system(SystemConfig::default().with_continuations().into()) - .rv32i(Default::default()) - .rv32m(Default::default()) - .io(Default::default()) - .keccak(Default::default()) - .sha256(Default::default()) - .bigint(Default::default()) - .modular(ModularExtension::new(vec![ - BigUint::from_str("1000000000000000003").unwrap(), - SECP256K1_CONFIG.modulus.clone(), - SECP256K1_CONFIG.scalar.clone(), - P256_CONFIG.modulus.clone(), - P256_CONFIG.scalar.clone(), - bn_config.modulus.clone(), - bn_config.scalar.clone(), - bls_config.modulus.clone(), - bls_config.scalar.clone(), - BigUint::from(2u32).pow(61) - BigUint::from(1u32), - BigUint::from(7u32), - ])) - .fp2(Fp2Extension::new(vec![ - ( - BN254_COMPLEX_STRUCT_NAME.to_string(), - bn_config.modulus.clone(), - ), - ( - BLS12_381_COMPLEX_STRUCT_NAME.to_string(), - bls_config.modulus.clone(), - ), - ])) - .ecc(WeierstrassExtension::new(vec![ - SECP256K1_CONFIG.clone(), - P256_CONFIG.clone(), - bn_config.clone(), - bls_config.clone(), - ])) - .pairing(PairingExtension::new(vec![ - PairingCurve::Bn254, - PairingCurve::Bls12_381, - ])) - .build(); + let vm_config = + SdkVmConfig::from_toml(include_str!("../../../guest/kitchen-sink/openvm.toml"))? + .app_vm_config; let elf = args.build_bench_program("kitchen-sink", &vm_config, None)?; let exe = VmExe::from_elf(elf, vm_config.transpiler())?; @@ -88,17 +109,28 @@ fn main() -> Result<()> { &DefaultStaticVerifierPvHandler, )?; - run_with_metric_collection("OUTPUT_PATH", || -> Result<()> { - let mut prover = EvmHalo2Prover::<_, BabyBearPoseidon2Engine>::new( + // Verify that NATIVE_MAX_TRACE_HEIGHTS remains valid + verify_native_max_trace_heights( + &sdk, + app_pk.clone(), + app_committed_exe.clone(), + full_agg_pk.agg_stark_pk.leaf_vm_pk.clone(), + args.agg_tree_config.num_children_leaf, + )?; + + run_with_metric_collection("OUTPUT_PATH", || { + let mut prover = EvmHalo2Prover::::new( &halo2_params_reader, + SdkVmCpuBuilder, + NativeCpuBuilder, app_pk, app_committed_exe, full_agg_pk, args.agg_tree_config, - ); + )?; prover.set_program_name("kitchen_sink"); let stdin = StdIn::default(); - let _proof = prover.generate_proof_for_evm(stdin); - Ok(()) - }) + prover.generate_proof_for_evm(stdin) + })?; + Ok(()) } diff --git a/benchmarks/prove/src/bin/pairing.rs b/benchmarks/prove/src/bin/pairing.rs index 1db6d1b491..457edddca4 100644 --- a/benchmarks/prove/src/bin/pairing.rs +++ b/benchmarks/prove/src/bin/pairing.rs @@ -1,41 +1,22 @@ use clap::Parser; use eyre::Result; -use openvm_algebra_circuit::{Fp2Extension, ModularExtension}; use openvm_benchmarks_prove::util::BenchmarkCli; -use openvm_circuit::arch::SystemConfig; -use openvm_ecc_circuit::WeierstrassExtension; -use openvm_pairing_circuit::{PairingCurve, PairingExtension}; -use openvm_pairing_guest::bn254::{BN254_COMPLEX_STRUCT_NAME, BN254_MODULUS, BN254_ORDER}; -use openvm_sdk::{config::SdkVmConfig, Sdk, StdIn}; +use openvm_sdk::{ + config::{SdkVmConfig, SdkVmCpuBuilder}, + Sdk, StdIn, +}; use openvm_stark_sdk::bench::run_with_metric_collection; fn main() -> Result<()> { let args = BenchmarkCli::parse(); - let vm_config = SdkVmConfig::builder() - .system(SystemConfig::default().with_continuations().into()) - .rv32i(Default::default()) - .rv32m(Default::default()) - .io(Default::default()) - .keccak(Default::default()) - .modular(ModularExtension::new(vec![ - BN254_MODULUS.clone(), - BN254_ORDER.clone(), - ])) - .fp2(Fp2Extension::new(vec![( - BN254_COMPLEX_STRUCT_NAME.to_string(), - BN254_MODULUS.clone(), - )])) - .ecc(WeierstrassExtension::new(vec![ - PairingCurve::Bn254.curve_config() - ])) - .pairing(PairingExtension::new(vec![PairingCurve::Bn254])) - .build(); + let vm_config = + SdkVmConfig::from_toml(include_str!("../../../guest/pairing/openvm.toml"))?.app_vm_config; let elf = args.build_bench_program("pairing", &vm_config, None)?; let sdk = Sdk::new(); let exe = sdk.transpile(elf, vm_config.transpiler()).unwrap(); run_with_metric_collection("OUTPUT_PATH", || -> Result<()> { - args.bench_from_exe("pairing", vm_config, exe, StdIn::default()) + args.bench_from_exe("pairing", SdkVmCpuBuilder, vm_config, exe, StdIn::default()) }) } diff --git a/benchmarks/prove/src/bin/regex.rs b/benchmarks/prove/src/bin/regex.rs index d1de43dad5..95b150801a 100644 --- a/benchmarks/prove/src/bin/regex.rs +++ b/benchmarks/prove/src/bin/regex.rs @@ -2,32 +2,30 @@ use clap::Parser; use eyre::Result; use openvm_benchmarks_prove::util::BenchmarkCli; use openvm_circuit::arch::instructions::exe::VmExe; -use openvm_keccak256_circuit::Keccak256Rv32Config; -use openvm_keccak256_transpiler::Keccak256TranspilerExtension; -use openvm_rv32im_transpiler::{ - Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, +use openvm_sdk::{ + config::{SdkVmConfig, SdkVmCpuBuilder}, + StdIn, }; -use openvm_sdk::StdIn; -use openvm_stark_sdk::{bench::run_with_metric_collection, p3_baby_bear::BabyBear}; -use openvm_transpiler::{transpiler::Transpiler, FromElf}; +use openvm_stark_sdk::bench::run_with_metric_collection; +use openvm_transpiler::FromElf; fn main() -> Result<()> { let args = BenchmarkCli::parse(); - let config = Keccak256Rv32Config::default(); + let config = + SdkVmConfig::from_toml(include_str!("../../../guest/regex/openvm.toml"))?.app_vm_config; let elf = args.build_bench_program("regex", &config, None)?; - let exe = VmExe::from_elf( - elf.clone(), - Transpiler::::default() - .with_extension(Rv32ITranspilerExtension) - .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension) - .with_extension(Keccak256TranspilerExtension), - )?; + let exe = VmExe::from_elf(elf.clone(), config.transpiler())?; run_with_metric_collection("OUTPUT_PATH", || -> Result<()> { let data = include_str!("../../../guest/regex/regex_email.txt"); let fe_bytes = data.to_owned().into_bytes(); - args.bench_from_exe("regex_program", config, exe, StdIn::from_bytes(&fe_bytes)) + args.bench_from_exe( + "regex_program", + SdkVmCpuBuilder, + config, + exe, + StdIn::from_bytes(&fe_bytes), + ) }) } diff --git a/benchmarks/prove/src/bin/revm_transfer.rs b/benchmarks/prove/src/bin/revm_transfer.rs index 1df994dc78..70dbae8109 100644 --- a/benchmarks/prove/src/bin/revm_transfer.rs +++ b/benchmarks/prove/src/bin/revm_transfer.rs @@ -2,28 +2,26 @@ use clap::Parser; use eyre::Result; use openvm_benchmarks_prove::util::BenchmarkCli; use openvm_circuit::arch::instructions::exe::VmExe; -use openvm_keccak256_circuit::Keccak256Rv32Config; -use openvm_keccak256_transpiler::Keccak256TranspilerExtension; -use openvm_rv32im_transpiler::{ - Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, +use openvm_sdk::{ + config::{SdkVmConfig, SdkVmCpuBuilder}, + StdIn, }; -use openvm_sdk::StdIn; -use openvm_stark_sdk::{bench::run_with_metric_collection, p3_baby_bear::BabyBear}; -use openvm_transpiler::{transpiler::Transpiler, FromElf}; +use openvm_stark_sdk::bench::run_with_metric_collection; +use openvm_transpiler::FromElf; fn main() -> Result<()> { let args = BenchmarkCli::parse(); - let config = Keccak256Rv32Config::default(); + let config = SdkVmConfig::from_toml(include_str!("../../../guest/revm_transfer/openvm.toml"))? + .app_vm_config; let elf = args.build_bench_program("revm_transfer", &config, None)?; - let exe = VmExe::from_elf( - elf, - Transpiler::::default() - .with_extension(Rv32ITranspilerExtension) - .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension) - .with_extension(Keccak256TranspilerExtension), - )?; + let exe = VmExe::from_elf(elf, config.transpiler())?; run_with_metric_collection("OUTPUT_PATH", || -> Result<()> { - args.bench_from_exe("revm_100_transfers", config, exe, StdIn::default()) + args.bench_from_exe( + "revm_100_transfers", + SdkVmCpuBuilder, + config, + exe, + StdIn::default(), + ) }) } diff --git a/benchmarks/prove/src/bin/rkyv.rs b/benchmarks/prove/src/bin/rkyv.rs index 7bdf6ed920..350a5e2352 100644 --- a/benchmarks/prove/src/bin/rkyv.rs +++ b/benchmarks/prove/src/bin/rkyv.rs @@ -2,30 +2,24 @@ use clap::Parser; use eyre::Result; use openvm_benchmarks_prove::util::BenchmarkCli; use openvm_circuit::arch::instructions::exe::VmExe; -use openvm_rv32im_circuit::Rv32ImConfig; -use openvm_rv32im_transpiler::{ - Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, +use openvm_sdk::{ + config::{SdkVmConfig, SdkVmCpuBuilder}, + StdIn, }; -use openvm_sdk::StdIn; -use openvm_stark_sdk::{bench::run_with_metric_collection, p3_baby_bear::BabyBear}; -use openvm_transpiler::{transpiler::Transpiler, FromElf}; +use openvm_stark_sdk::bench::run_with_metric_collection; +use openvm_transpiler::FromElf; fn main() -> Result<()> { let args = BenchmarkCli::parse(); - let config = Rv32ImConfig::default(); + let config = + SdkVmConfig::from_toml(include_str!("../../../guest/rkyv/openvm.toml"))?.app_vm_config; let elf = args.build_bench_program("rkyv", &config, None)?; - let exe = VmExe::from_elf( - elf, - Transpiler::::default() - .with_extension(Rv32ITranspilerExtension) - .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), - )?; + let exe = VmExe::from_elf(elf, config.transpiler())?; run_with_metric_collection("OUTPUT_PATH", || -> Result<()> { let file_data = include_bytes!("../../../guest/rkyv/minecraft_savedata.bin"); let stdin = StdIn::from_bytes(file_data); - args.bench_from_exe("rkyv", config, exe, stdin) + args.bench_from_exe("rkyv", SdkVmCpuBuilder, config, exe, stdin) }) } diff --git a/benchmarks/prove/src/bin/verify_fibair.rs b/benchmarks/prove/src/bin/verify_fibair.rs index 1d8d6072da..5cb84983a2 100644 --- a/benchmarks/prove/src/bin/verify_fibair.rs +++ b/benchmarks/prove/src/bin/verify_fibair.rs @@ -2,7 +2,7 @@ use clap::Parser; use eyre::Result; use openvm_benchmarks_prove::util::BenchmarkCli; use openvm_circuit::arch::DEFAULT_MAX_NUM_PUBLIC_VALUES; -use openvm_native_circuit::NativeConfig; +use openvm_native_circuit::{NativeConfig, NativeCpuBuilder, NATIVE_MAX_TRACE_HEIGHTS}; use openvm_native_compiler::conversion::CompilerOptions; use openvm_native_recursion::testing_utils::inner::build_verification_program; use openvm_sdk::{ @@ -12,7 +12,6 @@ use openvm_sdk::{ }; use openvm_stark_sdk::{ bench::run_with_metric_collection, - collect_airs_and_inputs, config::{baby_bear_poseidon2::BabyBearPoseidon2Engine, FriParameters}, dummy_airs::fib_air::chip::FibonacciChip, engine::StarkFriEngine, @@ -37,8 +36,11 @@ fn main() -> Result<()> { run_with_metric_collection("OUTPUT_PATH", || -> Result<()> { // run_test tries to setup tracing, but it will be ignored since run_with_metric_collection // already sets it. - let (fib_air, fib_input) = collect_airs_and_inputs!(fib_chip); - let vdata = engine.run_test(fib_air, fib_input).unwrap(); + let (fib_air, fib_ctx) = ( + vec![fib_chip.air()], + vec![fib_chip.generate_proving_ctx(())], + ); + let vdata = engine.run_test(fib_air, fib_ctx).unwrap(); // Unlike other apps, this "app" does not have continuations enabled. let app_fri_params = FriParameters::standard_with_100_bits_conjectured_security(leaf_log_blowup); @@ -60,9 +62,16 @@ fn main() -> Result<()> { let app_pk = sdk.app_keygen(app_config)?; let app_vk = app_pk.get_app_vk(); let committed_exe = sdk.commit_app_exe(app_fri_params, program.into())?; - let prover = AppProver::<_, BabyBearPoseidon2Engine>::new(app_pk.app_vm_pk, committed_exe) - .with_program_name("verify_fibair"); - let proof = prover.generate_app_proof_without_continuations(input_stream.into()); + let mut prover = AppProver::::new( + NativeCpuBuilder, + app_pk.app_vm_pk, + committed_exe, + )? + .with_program_name("verify_fibair"); + let proof = prover.generate_app_proof_without_continuations( + input_stream.into(), + NATIVE_MAX_TRACE_HEIGHTS, + )?; sdk.verify_app_proof_without_continuations(&app_vk, &proof)?; Ok(()) })?; diff --git a/benchmarks/prove/src/util.rs b/benchmarks/prove/src/util.rs index b3c17ead85..7f90015698 100644 --- a/benchmarks/prove/src/util.rs +++ b/benchmarks/prove/src/util.rs @@ -1,10 +1,14 @@ -use std::{path::PathBuf, sync::Arc}; +use std::path::PathBuf; use clap::{command, Parser}; use eyre::Result; use openvm_benchmarks_utils::{build_elf, get_programs_dir}; -use openvm_circuit::arch::{instructions::exe::VmExe, DefaultSegmentationStrategy, VmConfig}; -use openvm_native_circuit::NativeConfig; +use openvm_circuit::arch::{ + execution_mode::metered::segment_ctx::SegmentationLimits, instructions::exe::VmExe, + verify_single, Executor, MatrixRecordArena, MeteredExecutor, PreflightExecutor, SystemConfig, + VmBuilder, VmConfig, VmExecutionConfig, +}; +use openvm_native_circuit::{NativeConfig, NativeCpuBuilder}; use openvm_native_compiler::conversion::CompilerOptions; use openvm_sdk::{ commit::commit_app_exe, @@ -14,17 +18,15 @@ use openvm_sdk::{ DEFAULT_ROOT_LOG_BLOWUP, }, keygen::{leaf_keygen, AppProvingKey}, - prover::{vm::local::VmLocalProver, AppProver, LeafProvingController}, - Sdk, StdIn, + prover::{vm::new_local_prover, AppProver, LeafProvingController}, + GenericSdk, StdIn, }; -use openvm_stark_backend::utils::metrics_span; use openvm_stark_sdk::{ config::{ baby_bear_poseidon2::{BabyBearPoseidon2Config, BabyBearPoseidon2Engine}, FriParameters, }, engine::StarkFriEngine, - openvm_stark_backend::Chip, p3_baby_bear::BabyBear, }; use openvm_transpiler::elf::Elf; @@ -75,17 +77,18 @@ pub struct BenchmarkCli { } impl BenchmarkCli { - pub fn app_config>(&self, mut app_vm_config: VC) -> AppConfig { + pub fn app_config(&self, mut app_vm_config: VC) -> AppConfig + where + VC: AsMut, + { let app_log_blowup = self.app_log_blowup.unwrap_or(DEFAULT_APP_LOG_BLOWUP); let leaf_log_blowup = self.leaf_log_blowup.unwrap_or(DEFAULT_LEAF_LOG_BLOWUP); - app_vm_config.system_mut().profiling = self.profiling; + app_vm_config.as_mut().profiling = self.profiling; if let Some(max_segment_length) = self.max_segment_length { - app_vm_config - .system_mut() - .set_segmentation_strategy(Arc::new( - DefaultSegmentationStrategy::new_with_max_segment_len(max_segment_length), - )); + app_vm_config.as_mut().set_segmentation_limits( + SegmentationLimits::default().with_max_trace_height(max_segment_length as u32), + ); } AppConfig { app_fri_params: FriParameters::standard_with_100_bits_conjectured_security( @@ -143,7 +146,7 @@ impl BenchmarkCli { init_file_name: Option<&str>, ) -> Result where - VC: VmConfig, + VC: VmConfig, { let profile = if self.profiling { "profiling" @@ -156,21 +159,24 @@ impl BenchmarkCli { build_elf(&manifest_dir, profile) } - pub fn bench_from_exe( + pub fn bench_from_exe( &self, bench_name: impl ToString, + app_vm_builder: VB, vm_config: VC, exe: impl Into>, input_stream: StdIn, ) -> Result<()> where - VC: VmConfig, - VC::Executor: Chip, - VC::Periphery: Chip, + VB: VmBuilder>, + VC: VmExecutionConfig + VmConfig, + >::Executor: + Executor + MeteredExecutor + PreflightExecutor, { let app_config = self.app_config(vm_config); - bench_from_exe::( + bench_from_exe::( bench_name, + app_vm_builder, app_config, exe, input_stream, @@ -190,51 +196,56 @@ impl BenchmarkCli { /// 6. Verify STARK proofs. /// /// Returns the data necessary for proof aggregation. -pub fn bench_from_exe>( +pub fn bench_from_exe( bench_name: impl ToString, - app_config: AppConfig, + app_vm_builder: VB, + app_config: AppConfig, exe: impl Into>, input_stream: StdIn, leaf_vm_config: Option, ) -> Result<()> where - VC: VmConfig, - VC::Executor: Chip, - VC::Periphery: Chip, + E: StarkFriEngine, + VB: VmBuilder, + >::Executor: + Executor + MeteredExecutor + PreflightExecutor, + NativeBuilder: VmBuilder + Clone + Default, + >::Executor: + PreflightExecutor>::RecordArena>, { let bench_name = bench_name.to_string(); // 1. Generate proving key from config. - let app_pk = info_span!("keygen", group = &bench_name).in_scope(|| { - metrics_span("keygen_time_ms", || { - AppProvingKey::keygen(app_config.clone()) - }) - }); + let app_pk = info_span!("keygen", group = &bench_name) + .in_scope(|| AppProvingKey::keygen(app_config.clone()))?; // 2. Commit to the exe by generating cached trace for program. - let committed_exe = info_span!("commit_exe", group = &bench_name).in_scope(|| { - metrics_span("commit_exe_time_ms", || { - commit_app_exe(app_config.app_fri_params.fri_params, exe) - }) - }); + let committed_exe = info_span!("commit_exe", group = &bench_name) + .in_scope(|| commit_app_exe(app_config.app_fri_params.fri_params, exe)); // 3. Executes runtime // 4. Generate trace // 5. Generate STARK proofs for each segment (segmentation is determined by `config`), with // timer. let app_vk = app_pk.get_app_vk(); - let prover = - AppProver::::new(app_pk.app_vm_pk, committed_exe).with_program_name(bench_name); - let app_proof = prover.generate_app_proof(input_stream); + let mut prover = AppProver::::new(app_vm_builder, app_pk.app_vm_pk, committed_exe)? + .with_program_name(bench_name); + let app_proof = prover.generate_app_proof(input_stream)?; // 6. Verify STARK proofs, including boundary conditions. - let sdk = Sdk::new(); - sdk.verify_app_proof(&app_vk, &app_proof) - .expect("Verification failed"); + let sdk = GenericSdk::::new(); + sdk.verify_app_proof(&app_vk, &app_proof)?; if let Some(leaf_vm_config) = leaf_vm_config { - let leaf_vm_pk = leaf_keygen(app_config.leaf_fri_params.fri_params, leaf_vm_config); - let leaf_prover = - VmLocalProver::::new(leaf_vm_pk, app_pk.leaf_committed_exe); + let leaf_vm_pk = leaf_keygen(app_config.leaf_fri_params.fri_params, leaf_vm_config)?; + let vk = leaf_vm_pk.vm_pk.get_vk(); + let mut leaf_prover = new_local_prover( + sdk.native_builder().clone(), + &leaf_vm_pk, + &app_pk.leaf_committed_exe, + )?; let leaf_controller = LeafProvingController { num_children: AggregationTreeConfig::default().num_children_leaf, }; - leaf_controller.generate_proof(&leaf_prover, &app_proof); + let leaf_proofs = leaf_controller.generate_proof(&mut leaf_prover, &app_proof)?; + for proof in leaf_proofs { + verify_single(&leaf_prover.vm.engine, &vk, &proof)?; + } } Ok(()) } diff --git a/benchmarks/utils/Cargo.toml b/benchmarks/utils/Cargo.toml index 1b1d600a82..da1bf1866e 100644 --- a/benchmarks/utils/Cargo.toml +++ b/benchmarks/utils/Cargo.toml @@ -16,15 +16,35 @@ clap = { version = "4.5.9", features = ["derive", "env"] } eyre.workspace = true tempfile.workspace = true tracing.workspace = true -tracing-subscriber = { version = "0.3.17", features = ["std", "env-filter"] } +tracing-subscriber.workspace = true + +bitcode = { workspace = true, optional = true } +openvm-circuit = { workspace = true, optional = true } +openvm-continuations = { workspace = true, optional = true } +openvm-native-circuit = { workspace = true, optional = true } +openvm-sdk = { workspace = true, optional = true } +openvm-stark-sdk = { workspace = true, optional = true } [dev-dependencies] [features] default = [] -build-binaries = [] +build-elfs = [] +generate-fixtures = [ + "dep:bitcode", + "dep:openvm-circuit", + "dep:openvm-continuations", + "dep:openvm-native-circuit", + "dep:openvm-sdk", + "dep:openvm-stark-sdk", +] [[bin]] name = "build-elfs" path = "src/build-elfs.rs" -required-features = ["build-binaries"] +required-features = ["build-elfs"] + +[[bin]] +name = "generate-fixtures" +path = "src/generate-fixtures.rs" +required-features = ["generate-fixtures"] diff --git a/benchmarks/utils/src/build-elfs.rs b/benchmarks/utils/src/build-elfs.rs index 3bed7cf6fd..3ce24c7c5c 100644 --- a/benchmarks/utils/src/build-elfs.rs +++ b/benchmarks/utils/src/build-elfs.rs @@ -63,6 +63,12 @@ fn main() -> Result<()> { let programs_to_build = if cli.programs.is_empty() { available_programs } else { + for prog in &cli.programs { + if !available_programs.iter().any(|(name, _)| name == prog) { + tracing::warn!("Program '{}' not found in available programs", prog); + } + } + available_programs .into_iter() .filter(|(name, _)| cli.programs.contains(name)) @@ -70,6 +76,12 @@ fn main() -> Result<()> { }; // Filter out skipped programs + for prog in &cli.skip { + if !programs_to_build.iter().any(|(name, _)| name == prog) { + tracing::warn!("Program '{}' not found in programs to skip", prog); + } + } + let programs_to_build = programs_to_build .into_iter() .filter(|(name, _)| !cli.skip.contains(name)) diff --git a/benchmarks/utils/src/generate-fixtures.rs b/benchmarks/utils/src/generate-fixtures.rs new file mode 100644 index 0000000000..85c0ba0191 --- /dev/null +++ b/benchmarks/utils/src/generate-fixtures.rs @@ -0,0 +1,117 @@ +use std::{fs, sync::Arc}; + +use eyre::Result; +use openvm_benchmarks_utils::{get_elf_path, get_fixtures_dir, get_programs_dir, read_elf_file}; +use openvm_circuit::arch::{instructions::exe::VmExe, VmCircuitConfig}; +use openvm_continuations::verifier::common::types::VmVerifierPvs; +use openvm_native_circuit::NativeConfig; +use openvm_sdk::{ + commit::commit_app_exe, + config::{ + AppConfig, AppFriParams, LeafFriParams, SdkVmConfig, SdkVmCpuBuilder, + DEFAULT_APP_LOG_BLOWUP, DEFAULT_LEAF_LOG_BLOWUP, SBOX_SIZE, + }, + Sdk, StdIn, +}; +use openvm_stark_sdk::{ + config::{baby_bear_poseidon2::BabyBearPoseidon2Engine, FriParameters}, + engine::StarkFriEngine, +}; +use openvm_transpiler::FromElf; +use tracing_subscriber::{fmt, EnvFilter}; + +const PROGRAM: &str = "kitchen-sink"; + +fn main() -> Result<()> { + // Set up logging + fmt::fmt().with_env_filter(EnvFilter::new("info")).init(); + + let program_dir = get_programs_dir().join(PROGRAM); + + tracing::info!("Loading VM config"); + let config_path = program_dir.join("openvm.toml"); + let config_content = fs::read_to_string(&config_path)?; + let vm_config = SdkVmConfig::from_toml(&config_content)?.app_vm_config; + + tracing::info!("Preparing ELF"); + let elf_path = get_elf_path(&program_dir); + let elf = read_elf_file(&elf_path)?; + + let exe = VmExe::from_elf(elf, vm_config.transpiler())?; + + let sdk = Sdk::new(); + + // Create app config with default parameters + let app_config = AppConfig { + app_fri_params: AppFriParams { + fri_params: FriParameters::standard_with_100_bits_conjectured_security( + DEFAULT_APP_LOG_BLOWUP, + ), + }, + leaf_fri_params: LeafFriParams { + fri_params: FriParameters::standard_with_100_bits_conjectured_security( + DEFAULT_LEAF_LOG_BLOWUP, + ), + }, + app_vm_config: vm_config, + compiler_options: Default::default(), + }; + + tracing::info!("Generating app proving key"); + let app_pk = Arc::new(sdk.app_keygen(app_config.clone())?); + let app_committed_exe = commit_app_exe(app_pk.app_fri_params(), exe); + + tracing::info!("Generating app proof"); + let app_proof = sdk.generate_app_proof( + SdkVmCpuBuilder, + app_pk.clone(), + app_committed_exe, + StdIn::default(), + )?; + + tracing::info!("Generating leaf proving key"); + // Generate leaf VM proving key using the circuit keygen approach + let leaf_vm_config = NativeConfig::aggregation( + VmVerifierPvs::::width(), + SBOX_SIZE.min( + app_config + .leaf_fri_params + .fri_params + .max_constraint_degree(), + ), + ); + let circuit = leaf_vm_config.create_airs()?; + let engine = BabyBearPoseidon2Engine::new(app_config.leaf_fri_params.fri_params); + let pk = circuit.keygen(&engine); + + tracing::info!("Saving keys and proof to files"); + // Create fixtures directory if it doesn't exist + let fixtures_dir = get_fixtures_dir(); + fs::create_dir_all(&fixtures_dir)?; + + // Serialize and write to files in fixtures directory + let leaf_exe_bytes = bitcode::serialize(&app_pk.leaf_committed_exe.exe)?; + fs::write( + fixtures_dir.join(&format!("{}.leaf.exe", PROGRAM)), + leaf_exe_bytes, + )?; + + let leaf_pk_bytes = bitcode::serialize(&pk)?; + fs::write( + fixtures_dir.join(&format!("{}.leaf.pk", PROGRAM)), + leaf_pk_bytes, + )?; + + let app_proof_bytes = bitcode::serialize(&app_proof)?; + fs::write( + fixtures_dir.join(&format!("{}.app.proof", PROGRAM)), + app_proof_bytes, + )?; + + tracing::info!( + "Generated and saved {name}.leaf.committed.exe, {name}.leaf.pk, and {name}.app.proof", + name = PROGRAM + ); + + Ok(()) +} diff --git a/benchmarks/utils/src/lib.rs b/benchmarks/utils/src/lib.rs index 99e5ce917b..ad11ab2f99 100644 --- a/benchmarks/utils/src/lib.rs +++ b/benchmarks/utils/src/lib.rs @@ -9,6 +9,10 @@ use openvm_build::{build_guest_package, get_package, guest_methods, GuestOptions use openvm_transpiler::{elf::Elf, openvm_platform::memory::MEM_SIZE}; use tempfile::tempdir; +pub fn get_fixtures_dir() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../fixtures") +} + pub fn get_programs_dir() -> PathBuf { PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../guest") } diff --git a/book/src/custom-extensions/algebra.md b/book/src/custom-extensions/algebra.md index 2e1f830153..c4bcc936b1 100644 --- a/book/src/custom-extensions/algebra.md +++ b/book/src/custom-extensions/algebra.md @@ -80,7 +80,7 @@ moduli_init! { "21888242871839275222246405745257275088696311157297823662689037894645226208583" } complex_init! { - Bn254Fp2 { mod_idx = 0 }, + "Bn254Fp2" { mod_idx = 0 }, } */ ``` diff --git a/book/src/custom-extensions/ecc.md b/book/src/custom-extensions/ecc.md index ba0fff90d6..02ab2448c6 100644 --- a/book/src/custom-extensions/ecc.md +++ b/book/src/custom-extensions/ecc.md @@ -17,12 +17,20 @@ Developers can enable arbitrary Weierstrass curves by configuring this extension - `WeierstrassPoint` trait: It represents an affine point on a Weierstrass elliptic curve and it extends `Group`. - - `Coordinate` type is the type of the coordinates of the point, and it implements `IntMod`. - - `x()`, `y()` are used to get the affine coordinates + - `Coordinate` type is the type of the coordinates of the point, and it implements `Field`. + - `x()`, `y()` are used to get the affine coordinates. - `from_xy` is a constructor for the point, which checks if the point is either identity or on the affine curve. - The point supports elliptic curve operations through intrinsic functions `add_ne_nonidentity` and `double_nonidentity`. - `decompress`: Sometimes an elliptic curve point is compressed and represented by its `x` coordinate and the odd/even parity of the `y` coordinate. `decompress` is used to decompress the point back to `(x, y)`. +- `TwistedEdwardsPoint` trait: + It represents an affine point on a twisted Edwards elliptic curve and it extends `Group`. + + - `Coordinate` type is the type of the coordinates of the point, and it implements `Field`. + - `x()`, `y()` are used to get the affine coordinates. + - `from_xy` is a constructor for the point, which checks if the point is on the affine curve. + - The point supports elliptic curve addition through the `add_impl` method. + - `msm`: for multi-scalar multiplication. - `ecdsa`: for doing ECDSA signature verification and public key recovery from signature. @@ -31,17 +39,20 @@ Developers can enable arbitrary Weierstrass curves by configuring this extension For elliptic curve cryptography, the `openvm-ecc-guest` crate provides macros similar to those in [`openvm-algebra-guest`](./algebra.md): -1. **Declare**: Use `sw_declare!` to define elliptic curves over the previously declared moduli. For example: +1. **Declare**: Use `sw_declare!` or `te_declare!` to define short Weierstrass or twisted Edwards elliptic curves, respectively, over the previously declared moduli. For example: ```rust sw_declare! { Bls12_381G1Affine { mod_type = Bls12_381Fp, b = BLS12_381_B }, P256Affine { mod_type = P256Coord, a = P256_A, b = P256_B }, } +te_declare! { + Edwards25519 { mod_type = Edwards25519Coord, a = CURVE_A, d = CURVE_D }, +} ``` +This creates `Bls12_381G1Affine` and `P256Affine` structs which implement the `Group` and `WeierstrassPoint` traits, and the `Edwards25519` struct which implements the `Group` and `TwistedEdwardsPoint` traits. The underlying memory layout of the structs uses the memory layout of the `Bls12_381Fp`, `P256Coord`, and `Edwards25519Coord` structs, respectively. -Each declared curve must specify the `mod_type` (implementing `IntMod`) and a constant `b` for the Weierstrass curve equation \\(y^2 = x^3 + ax + b\\). `a` is optional and defaults to 0 for short Weierstrass curves. -This creates `Bls12_381G1Affine` and `P256Affine` structs which implement the `Group` and `WeierstrassPoint` traits. The underlying memory layout of the structs uses the memory layout of the `Bls12_381Fp` and `P256Coord` structs, respectively. +Each declared curve must specify the `mod_type` (implementing `Field`) and a constant `b` for the Weierstrass curve equation \\(y^2 = x^3 + ax + b\\) or `a` and `d` for the twisted Edwards curve equation \\(ax^2 + y^2 = 1 + dx^2y^2\\). For short Weierstrass curves, `a` is optional and defaults to 0. 2. **Init**: Called once, the [`openvm::init!` macro](./overview.md#automating-the-init-step) produces a call to `sw_init!` that enumerates these curves and allows the compiler to produce optimized instructions: @@ -49,19 +60,23 @@ This creates `Bls12_381G1Affine` and `P256Affine` structs which implement the `G openvm::init!(); /* This expands to sw_init! { - Bls12_381G1Affine, P256Affine, + "Bls12_381G1Affine", "P256Affine", +} +te_init! { + Edwards25519, } */ ``` **Summary**: -- `sw_declare!`: Declares elliptic curve structures. +- `sw_declare!`: Declares short Weierstrass elliptic curve structures. +- `te_declare!`: Declares twisted Edwards elliptic curve structures. - `init!`: Initializes them once, linking them to the underlying moduli. -To use elliptic curve operations on a struct defined with `sw_declare!`, it is expected that the struct for the curve's coordinate field was defined using `moduli_declare!`. In particular, the coordinate field needs to be initialized and set up as described in the [algebra extension](./algebra.md) chapter. +To use elliptic curve operations on a struct defined with `sw_declare!` or `te_declare!`, it is expected that the struct for the curve's coordinate field was defined using `moduli_declare!`. In particular, the coordinate field needs to be initialized and set up as described in the [algebra extension](./algebra.md) chapter. -For the basic operations provided by the `WeierstrassPoint` trait, the scalar field is not needed. For the ECDSA functions in the `ecdsa` module, the scalar field must also be declared, initialized, and set up. +For the basic operations provided by the `WeierstrassPoint` or `TwistedEdwardsPoint` traits, the scalar field is not needed. For the ECDSA functions in the `ecdsa` module, the scalar field must also be declared, initialized, and set up. ## ECDSA diff --git a/book/src/custom-extensions/overview.md b/book/src/custom-extensions/overview.md index 2b07a73ec4..9221a448ed 100644 --- a/book/src/custom-extensions/overview.md +++ b/book/src/custom-extensions/overview.md @@ -60,20 +60,39 @@ supported_moduli = ["", "", ...] [app_vm_config.pairing] supported_curves = ["Bls12_381", "Bn254"] -[[app_vm_config.ecc.supported_curves]] -struct_name = "" +[[app_vm_config.ecc.supported_sw_curves]] +struct_name = "" modulus = "" scalar = "" +[app_vm_config.ecc.supported_sw_curves.coeffs] a = "" b = "" -[[app_vm_config.ecc.supported_curves]] -struct_name = "" +[[app_vm_config.ecc.supported_sw_curves]] +struct_name = "" modulus = "" scalar = "" +[app_vm_config.ecc.supported_sw_curves.coeffs] a = "" b = "" + +[[app_vm_config.ecc.supported_te_curves]] +struct_name = "" +modulus = "" +scalar = "" +[app_vm_config.ecc.supported_te_curves.coeffs] +a = "" +d = "" + +[[app_vm_config.ecc.supported_te_curves]] +struct_name = "" +modulus = "" +scalar = "" +[app_vm_config.ecc.supported_te_curves.coeffs] +a = "" +d = "" +` ``` `rv32i`, `io`, and `rv32m` need to be always included if you make an `openvm.toml` file while the rest are optional and should be included if you want to use the corresponding extension. -All moduli and scalars must be provided in decimal format. Currently `pairing` supports only pre-defined `Bls12_381` and `Bn254` curves. To add more `ecc` curves you need to add more `[[app_vm_config.ecc.supported_curves]]` entries. +All moduli and scalars must be provided in decimal format. Currently `pairing` supports only pre-defined `Bls12_381` and `Bn254` curves. To add more `ecc` curves you need to add more `[[app_vm_config.ecc.supported_sw_curves]]` or `[[app_vm_config.ecc.supported_te_curves]]` entries. diff --git a/book/src/guest-libs/k256.md b/book/src/guest-libs/k256.md index 44aa4d2743..d17ebba97d 100644 --- a/book/src/guest-libs/k256.md +++ b/book/src/guest-libs/k256.md @@ -40,17 +40,18 @@ For the guest program to build successfully, all used moduli and curves must be [app_vm_config.modular] supported_moduli = ["115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337"] -[[app_vm_config.ecc.supported_curves]] +[[app_vm_config.ecc.supported_sw_curves]] struct_name = "Secp256k1Point" modulus = "115792089237316195423570985008687907853269984665640564039457584007908834671663" scalar = "115792089237316195423570985008687907852837564279074904382605163141518161494337" +[app_vm_config.ecc.supported_sw_curves.coeffs] a = "0" b = "7" ``` The `supported_moduli` parameter is a list of moduli that the guest program will use. As mentioned in the [algebra extension](../custom-extensions/algebra.md) chapter, the order of moduli in `[app_vm_config.modular]` must match the order in the `moduli_init!` macro. -The `ecc.supported_curves` parameter is a list of supported curves that the guest program will use. They must be provided in decimal format in the `.toml` file. For multiple curves create multiple `[[app_vm_config.ecc.supported_curves]]` sections. The order of curves in `[[app_vm_config.ecc.supported_curves]]` must match the order in the `sw_init!` macro. -Also, the `struct_name` field must be the name of the elliptic curve struct created by `sw_declare!`. +The `ecc.supported_curves` parameter is a list of supported curves that the guest program will use. They must be provided in decimal format in the `.toml` file. For multiple curves create multiple `[[app_vm_config.ecc.supported_sw_curves]]`/`[[app_vm_config.ecc.supported_te_curves]]` sections. The order of curves in `[[app_vm_config.ecc.supported_sw/te_curves]]` must match the order in the `sw_init!`/`te_init!` macros respectively. +Also, the `struct_name` field must be the name of the elliptic curve struct created by `sw_declare!`/`te_declare!`. In this example, the `Secp256k1Point` struct is created in `openvm_ecc_guest::k256`. diff --git a/book/src/guest-libs/p256.md b/book/src/guest-libs/p256.md index 2d39422cc0..e68f0652e6 100644 --- a/book/src/guest-libs/p256.md +++ b/book/src/guest-libs/p256.md @@ -11,15 +11,16 @@ For the guest program to build successfully, all used moduli and curves must be [app_vm_config.modular] supported_moduli = ["115792089210356248762697446949407573530086143415290314195533631308867097853951", "115792089210356248762697446949407573529996955224135760342422259061068512044369"] -[[app_vm_config.ecc.supported_curves]] +[[app_vm_config.ecc.supported_sw_curves]] struct_name = "P256Point" modulus = "115792089210356248762697446949407573530086143415290314195533631308867097853951" scalar = "115792089210356248762697446949407573529996955224135760342422259061068512044369" +[app_vm_config.ecc.supported_sw_curves.coeffs] a = "115792089210356248762697446949407573530086143415290314195533631308867097853948" b = "41058363725152142129326129780047268409114441015993725554835256314039467401291" ``` The `supported_moduli` parameter is a list of moduli that the guest program will use. As mentioned in the [algebra extension](../custom-extensions/algebra.md) chapter, the order of moduli in `[app_vm_config.modular]` must match the order in the `moduli_init!` macro. -The `ecc.supported_curves` parameter is a list of supported curves that the guest program will use. They must be provided in decimal format in the `.toml` file. For multiple curves create multiple `[[app_vm_config.ecc.supported_curves]]` sections. The order of curves in `[[app_vm_config.ecc.supported_curves]]` must match the order in the `sw_init!` macro. -Also, the `struct_name` field must be the name of the elliptic curve struct created by `sw_declare!`. +The `ecc.supported_curves` parameter is a list of supported curves that the guest program will use. They must be provided in decimal format in the `.toml` file. For multiple curves create multiple `[[app_vm_config.ecc.supported_sw_curves]]`/`[[app_vm_config.ecc.supported_te_curves]]` sections. The order of curves in `[[app_vm_config.ecc.supported_sw_curves]]`/`[[app_vm_config.ecc.supported_te_curves]]` must match the order in the `sw_init!`/`te_init!` macros respectively. +Also, the `struct_name` field must be the name of the elliptic curve struct created by `sw_declare!`/`te_declare!`. diff --git a/ci/scripts/bench.py b/ci/scripts/bench.py index 97db5180e6..9bf87f622f 100644 --- a/ci/scripts/bench.py +++ b/ci/scripts/bench.py @@ -32,7 +32,7 @@ def run_cargo_command( command.extend(["--max_segment_length", max_segment_length]) if kzg_params_dir is not None: command.extend(["--kzg-params-dir", kzg_params_dir]) - if "profiling" in feature_flags: + if "perf-metrics" in feature_flags: # set guest build args and vm config to profiling command.extend(["--profiling"]) @@ -50,7 +50,7 @@ def run_cargo_command( # Prepare the environment variables env = os.environ.copy() # Copy current environment variables env["OUTPUT_PATH"] = output_path - if "profiling" in feature_flags: + if "perf-metrics" in feature_flags: env["GUEST_SYMBOLS_PATH"] = os.path.splitext(output_path)[0] + ".syms" env["RUSTFLAGS"] = "-Ctarget-cpu=native" @@ -73,7 +73,7 @@ def bench(): parser.add_argument('--output_path', type=str, required=True, help="The path to write the metrics to") args = parser.parse_args() - feature_flags = ["bench-metrics", "parallel"] + (args.features.split(",") if args.features else []) + feature_flags = ["metrics", "parallel"] + (args.features.split(",") if args.features else []) assert (feature_flags.count("mimalloc") + feature_flags.count("jemalloc")) == 1 run_cargo_command( diff --git a/ci/scripts/metric_unify/flamegraph.py b/ci/scripts/metric_unify/flamegraph.py index f5054d864c..fe5dc157c2 100644 --- a/ci/scripts/metric_unify/flamegraph.py +++ b/ci/scripts/metric_unify/flamegraph.py @@ -60,6 +60,9 @@ def get_stack_lines(metrics_dict, group_by_kvs, stack_keys, metric_name, sum_met function_symbols = [get_function_symbol(string_table, offset) for offset in symbol_offsets] stack_values.extend(function_symbols) else: + # don't make a stack frame for empty label + if labels[key] == '': + continue stack_values.append(labels[key]) if filter: continue diff --git a/crates/circuits/mod-builder/Cargo.toml b/crates/circuits/mod-builder/Cargo.toml index d756db326b..cfd5434dde 100644 --- a/crates/circuits/mod-builder/Cargo.toml +++ b/crates/circuits/mod-builder/Cargo.toml @@ -23,8 +23,6 @@ num-traits.workspace = true tracing.workspace = true itertools.workspace = true -serde = { workspace = true, features = ["derive"] } -serde_with.workspace = true [dev-dependencies] openvm-circuit-primitives = { workspace = true } @@ -35,4 +33,8 @@ openvm-circuit = { workspace = true, features = ["test-utils"] } [features] default = [] parallel = ["openvm-stark-backend/parallel"] -test-utils = ["dep:halo2curves-axiom", "dep:openvm-pairing-guest"] +test-utils = [ + "dep:halo2curves-axiom", + "dep:openvm-pairing-guest", + "openvm-circuit/test-utils", +] diff --git a/crates/circuits/mod-builder/src/builder.rs b/crates/circuits/mod-builder/src/builder.rs index 6e1c22a009..a1a2a43c8d 100644 --- a/crates/circuits/mod-builder/src/builder.rs +++ b/crates/circuits/mod-builder/src/builder.rs @@ -289,6 +289,22 @@ impl FieldExpr { ret.setup_values = setup_values; ret } + + pub fn num_inputs(&self) -> usize { + self.builder.num_input + } + + pub fn num_vars(&self) -> usize { + self.builder.num_variables + } + + pub fn num_flags(&self) -> usize { + self.builder.num_flags + } + + pub fn output_indices(&self) -> &[usize] { + &self.builder.output_indices + } } impl Deref for FieldExpr { @@ -402,6 +418,7 @@ impl SubAir for FieldExpr { for i in 0..self.constraints.len() { let expr = self.constraints[i] .evaluate_overflow_expr::(&inputs, &vars, &constants, &flags); + self.check_carry_mod_to_zero.eval( builder, ( diff --git a/crates/circuits/mod-builder/src/core_chip.rs b/crates/circuits/mod-builder/src/core_chip.rs index 30e9c65dbb..41a1a2e1ef 100644 --- a/crates/circuits/mod-builder/src/core_chip.rs +++ b/crates/circuits/mod-builder/src/core_chip.rs @@ -1,28 +1,32 @@ +use std::{ + marker::PhantomData, + mem::{align_of, size_of}, + sync::Arc, +}; + use itertools::Itertools; use num_bigint::BigUint; use num_traits::Zero; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, DynAdapterInterface, DynArray, MinimalInstruction, - Result, VmAdapterInterface, VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::*, + system::memory::{online::TracingMemory, MemoryAuxColsFactory}, }; use openvm_circuit_primitives::{ - var_range::SharedVariableRangeCheckerChip, SubAir, TraceSubRowGenerator, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerChip}, + SubAir, TraceSubRowGenerator, }; -use openvm_instructions::instruction::Instruction; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, rap::BaseAirWithPublicValues, }; use openvm_stark_sdk::p3_baby_bear::BabyBear; -use serde::{Deserialize, Serialize}; -use serde_with::{serde_as, DisplayFromStr}; use crate::{ - utils::{biguint_to_limbs_vec, limbs_to_biguint}, - FieldExpr, FieldExprCols, + builder::{FieldExpr, FieldExprCols}, + utils::biguint_to_limbs_vec, }; #[derive(Clone)] @@ -165,174 +169,411 @@ where } } -#[serde_as] -#[derive(Serialize, Deserialize, Clone, PartialEq, Debug)] -pub struct FieldExpressionRecord { - #[serde_as(as = "Vec")] - pub inputs: Vec, - pub flags: Vec, +pub struct FieldExpressionMetadata { + pub total_input_limbs: usize, // num_inputs * limbs_per_input + _phantom: PhantomData<(F, A)>, } -pub struct FieldExpressionCoreChip { - pub air: FieldExpressionCoreAir, - pub range_checker: SharedVariableRangeCheckerChip, +impl Clone for FieldExpressionMetadata { + fn clone(&self) -> Self { + Self { + total_input_limbs: self.total_input_limbs, + _phantom: PhantomData, + } + } +} - pub name: String, +impl Default for FieldExpressionMetadata { + fn default() -> Self { + Self { + total_input_limbs: 0, + _phantom: PhantomData, + } + } +} - /// Whether to finalize the trace. True if all-zero rows don't satisfy the constraints (e.g. - /// there is int_add) - pub should_finalize: bool, +impl FieldExpressionMetadata { + pub fn new(total_input_limbs: usize) -> Self { + Self { + total_input_limbs, + _phantom: PhantomData, + } + } +} + +impl AdapterCoreMetadata for FieldExpressionMetadata +where + A: AdapterTraceExecutor, +{ + #[inline(always)] + fn get_adapter_width() -> usize { + A::WIDTH * size_of::() + } +} + +pub type FieldExpressionRecordLayout = AdapterCoreLayout>; + +pub struct FieldExpressionCoreRecordMut<'a> { + pub opcode: &'a mut u8, + pub input_limbs: &'a mut [u8], } -impl FieldExpressionCoreChip { +impl<'a, F, A> CustomBorrow<'a, FieldExpressionCoreRecordMut<'a>, FieldExpressionRecordLayout> + for [u8] +{ + fn custom_borrow( + &'a mut self, + layout: FieldExpressionRecordLayout, + ) -> FieldExpressionCoreRecordMut<'a> { + let (opcode_buf, input_limbs_buff) = unsafe { self.split_at_mut_unchecked(1) }; + + FieldExpressionCoreRecordMut { + opcode: &mut opcode_buf[0], + input_limbs: &mut input_limbs_buff[..layout.metadata.total_input_limbs], + } + } + + unsafe fn extract_layout(&self) -> FieldExpressionRecordLayout { + panic!("Should get the Layout information from FieldExpressionExecutor"); + } +} + +impl SizedRecord> for FieldExpressionCoreRecordMut<'_> { + fn size(layout: &FieldExpressionRecordLayout) -> usize { + layout.metadata.total_input_limbs + 1 + } + + fn alignment(_layout: &FieldExpressionRecordLayout) -> usize { + align_of::() + } +} + +impl<'a> FieldExpressionCoreRecordMut<'a> { + // This method is only used in testing + pub fn new_from_execution_data( + buffer: &'a mut [u8], + inputs: &[BigUint], + limbs_per_input: usize, + ) -> Self { + let record_info = FieldExpressionMetadata::<(), ()>::new(inputs.len() * limbs_per_input); + + let record: Self = buffer.custom_borrow(FieldExpressionRecordLayout { + metadata: record_info, + }); + record + } + + #[inline(always)] + pub fn fill_from_execution_data(&mut self, opcode: u8, data: &[u8]) { + // Rust will assert that length of `data` and `self.input_limbs` are the same + // That is `data.len() == num_inputs * limbs_per_input` + *self.opcode = opcode; + self.input_limbs.copy_from_slice(data); + } +} + +#[derive(Clone)] +pub struct FieldExpressionExecutor { + adapter: A, + pub expr: FieldExpr, + pub offset: usize, + pub local_opcode_idx: Vec, + pub opcode_flag_idx: Vec, + pub name: String, +} + +impl FieldExpressionExecutor { + #[allow(clippy::too_many_arguments)] pub fn new( + adapter: A, expr: FieldExpr, offset: usize, local_opcode_idx: Vec, opcode_flag_idx: Vec, - range_checker: SharedVariableRangeCheckerChip, name: &str, - should_finalize: bool, ) -> Self { - let air = FieldExpressionCoreAir::new(expr, offset, local_opcode_idx, opcode_flag_idx); + let opcode_flag_idx = if opcode_flag_idx.is_empty() && expr.needs_setup() { + // single op chip that needs setup, so there is only one default flag, must be 0. + vec![0] + } else { + // multi ops chip or no-setup chip, use as is. + opcode_flag_idx + }; + assert_eq!(opcode_flag_idx.len(), local_opcode_idx.len() - 1); tracing::info!( - "FieldExpressionCoreChip: opcode={name}, main_width={}", - BaseAir::::width(&air) + "FieldExpressionCoreExecutor: opcode={name}, main_width={}", + BaseAir::::width(&expr) ); Self { - air, - range_checker, + adapter, + expr, + offset, + local_opcode_idx, + opcode_flag_idx, name: name.to_string(), + } + } + + pub fn get_record_layout(&self) -> FieldExpressionRecordLayout { + FieldExpressionRecordLayout { + metadata: FieldExpressionMetadata::new( + self.expr.builder.num_input * self.expr.canonical_num_limbs(), + ), + } + } +} + +pub struct FieldExpressionFiller { + adapter: A, + pub expr: FieldExpr, + pub local_opcode_idx: Vec, + pub opcode_flag_idx: Vec, + pub range_checker: SharedVariableRangeCheckerChip, + pub should_finalize: bool, +} + +impl FieldExpressionFiller { + #[allow(clippy::too_many_arguments)] + pub fn new( + adapter: A, + expr: FieldExpr, + local_opcode_idx: Vec, + opcode_flag_idx: Vec, + range_checker: SharedVariableRangeCheckerChip, + should_finalize: bool, + ) -> Self { + let opcode_flag_idx = if opcode_flag_idx.is_empty() && expr.needs_setup() { + // single op chip that needs setup, so there is only one default flag, must be 0. + vec![0] + } else { + // multi ops chip or no-setup chip, use as is. + opcode_flag_idx + }; + assert_eq!(opcode_flag_idx.len(), local_opcode_idx.len() - 1); + Self { + adapter, + expr, + local_opcode_idx, + opcode_flag_idx, + range_checker, should_finalize, } } + pub fn num_inputs(&self) -> usize { + self.expr.builder.num_input + } - pub fn expr(&self) -> &FieldExpr { - &self.air.expr + pub fn num_flags(&self) -> usize { + self.expr.builder.num_flags + } + + pub fn get_record_layout(&self) -> FieldExpressionRecordLayout { + FieldExpressionRecordLayout { + metadata: FieldExpressionMetadata::new( + self.num_inputs() * self.expr.canonical_num_limbs(), + ), + } } } -impl VmCoreChip for FieldExpressionCoreChip +impl PreflightExecutor for FieldExpressionExecutor where - I: VmAdapterInterface, - I::Reads: Into>, - AdapterRuntimeContext: From>>, + F: PrimeField32, + A: 'static + + AdapterTraceExecutor>, WriteData: From>>, + for<'buf> RA: RecordArena< + 'buf, + FieldExpressionRecordLayout, + (A::RecordMut<'buf>, FieldExpressionCoreRecordMut<'buf>), + >, { - type Record = FieldExpressionRecord; - type Air = FieldExpressionCoreAir; - - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let field_element_limbs = self.air.expr.canonical_num_limbs(); - let limb_bits = self.air.expr.canonical_limb_bits(); - let data: DynArray<_> = reads.into(); - let data = data.0; - assert_eq!(data.len(), self.air.num_inputs() * field_element_limbs); - let data_u32: Vec = data.iter().map(|x| x.as_canonical_u32()).collect(); - - let mut inputs = vec![]; - for i in 0..self.air.num_inputs() { - let start = i * field_element_limbs; - let end = start + field_element_limbs; - let limb_slice = &data_u32[start..end]; - let input = limbs_to_biguint(limb_slice, limb_bits); - inputs.push(input); - } + ) -> Result<(), ExecutionError> { + let (mut adapter_record, mut core_record) = state.ctx.alloc(self.get_record_layout()); - let Instruction { opcode, .. } = instruction; - let local_opcode_idx = opcode.local_opcode_idx(self.air.offset); - let mut flags = vec![]; - - // If the chip doesn't need setup, (right now) it must be single op chip and thus no flag is - // needed. Otherwise, there is a flag for each opcode and will be derived by - // is_valid - sum(flags). - if self.expr().needs_setup() { - flags = vec![false; self.air.num_flags()]; - self.air - .opcode_flag_idx - .iter() - .enumerate() - .for_each(|(i, &flag_idx)| { - flags[flag_idx] = local_opcode_idx == self.air.local_opcode_idx[i] - }); - } + A::start(*state.pc, state.memory, &mut adapter_record); - let vars = self.air.expr.execute(inputs.clone(), flags.clone()); - assert_eq!(vars.len(), self.air.num_vars()); + let data: DynArray<_> = self + .adapter + .read(state.memory, instruction, &mut adapter_record) + .into(); - let outputs: Vec = self - .air - .output_indices() - .iter() - .map(|&i| vars[i].clone()) - .collect(); - let writes: Vec = outputs - .iter() - .map(|x| biguint_to_limbs_vec(x.clone(), limb_bits, field_element_limbs)) - .concat() - .into_iter() - .map(|x| F::from_canonical_u32(x)) - .collect(); + core_record.fill_from_execution_data( + instruction.opcode.local_opcode_idx(self.offset) as u8, + &data.0, + ); - let ctx = AdapterRuntimeContext::<_, DynAdapterInterface<_>>::without_pc(writes); - Ok((ctx.into(), FieldExpressionRecord { inputs, flags })) + let (writes, _, _) = run_field_expression( + &self.expr, + &self.local_opcode_idx, + &self.opcode_flag_idx, + core_record.input_limbs, + *core_record.opcode as usize, + ); + + self.adapter.write( + state.memory, + instruction, + writes.into(), + &mut adapter_record, + ); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + Ok(()) } fn get_opcode_name(&self, _opcode: usize) -> String { self.name.clone() } +} + +impl TraceFiller for FieldExpressionFiller +where + F: PrimeField32 + Send + Sync + Clone, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + // Get the core record from the row slice + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - self.air.expr.generate_subrow( - (self.range_checker.as_ref(), record.inputs, record.flags), - row_slice, + self.adapter.fill_trace_row(mem_helper, adapter_row); + + let record: FieldExpressionCoreRecordMut = + unsafe { get_record_from_slice(&mut core_row, self.get_record_layout::()) }; + + let (_, inputs, flags) = run_field_expression( + &self.expr, + &self.local_opcode_idx, + &self.opcode_flag_idx, + record.input_limbs, + *record.opcode as usize, ); - } - fn air(&self) -> &Self::Air { - &self.air + let range_checker = self.range_checker.as_ref(); + self.expr + .generate_subrow((range_checker, inputs, flags), core_row); } - fn finalize(&self, trace: &mut RowMajorMatrix, num_records: usize) { - if !self.should_finalize || num_records == 0 { + fn fill_dummy_trace_row(&self, row_slice: &mut [F]) { + if !self.should_finalize { return; } - let core_width = >::width(&self.air); - let adapter_width = trace.width() - core_width; - let dummy_row = self.generate_dummy_trace_row(adapter_width, core_width); - for row in trace.rows_mut().skip(num_records) { - row.copy_from_slice(&dummy_row); + let inputs: Vec = vec![BigUint::zero(); self.num_inputs()]; + let flags: Vec = vec![false; self.num_flags()]; + let core_row = &mut row_slice[A::WIDTH..]; + // We **do not** want this trace row to update the range checker + // so we must create a temporary range checker + let tmp_range_checker = Arc::new(VariableRangeCheckerChip::new(self.range_checker.bus())); + self.expr + .generate_subrow((&tmp_range_checker, inputs, flags), core_row); + core_row[0] = F::ZERO; // is_valid = 0 + } +} + +fn run_field_expression( + expr: &FieldExpr, + local_opcode_flags: &[usize], + opcode_flag_idx: &[usize], + data: &[u8], + local_opcode_idx: usize, +) -> (DynArray, Vec, Vec) { + let field_element_limbs = expr.canonical_num_limbs(); + assert_eq!(data.len(), expr.builder.num_input * field_element_limbs); + + let mut inputs = Vec::with_capacity(expr.builder.num_input); + for i in 0..expr.builder.num_input { + let start = i * field_element_limbs; + let end = start + field_element_limbs; + let limb_slice = &data[start..end]; + let input = BigUint::from_bytes_le(limb_slice); + inputs.push(input); + } + + let mut flags = vec![]; + if expr.needs_setup() { + flags = vec![false; expr.builder.num_flags]; + + // Find which opcode this is in our local_opcode_idx list + if let Some(opcode_position) = local_opcode_flags + .iter() + .position(|&idx| idx == local_opcode_idx) + { + // If this is NOT the last opcode (setup), set the corresponding flag + if opcode_position < opcode_flag_idx.len() { + let flag_idx = opcode_flag_idx[opcode_position]; + flags[flag_idx] = true; + } + // If opcode_position == step.opcode_flag_idx.len(), it's the setup operation + // and all flags should remain false (which they already are) } } + + let vars = expr.execute(inputs.clone(), flags.clone()); + assert_eq!(vars.len(), expr.builder.num_variables); + + let outputs: Vec = expr + .builder + .output_indices + .iter() + .map(|&i| vars[i].clone()) + .collect(); + let writes: DynArray<_> = outputs + .iter() + .map(|x| biguint_to_limbs_vec(x, field_element_limbs)) + .concat() + .into_iter() + .collect::>() + .into(); + + (writes, inputs, flags) } -impl FieldExpressionCoreChip { - // We will be setting is_valid = 0. That forces all flags be 0 (otherwise setup will be -1). - // We generate a dummy row with all flags set to 0, then we set is_valid = 0. - fn generate_dummy_trace_row( - &self, - adapter_width: usize, - core_width: usize, - ) -> Vec { - let record = FieldExpressionRecord { - inputs: vec![BigUint::zero(); self.air.num_inputs()], - flags: vec![false; self.air.num_flags()], - }; - let mut row = vec![F::ZERO; adapter_width + core_width]; - let core_row = &mut row[adapter_width..]; - // We **do not** want this trace row to update the range checker - // so we must create a temporary range checker - let tmp_range_checker = SharedVariableRangeCheckerChip::new(self.range_checker.bus()); - self.air.expr.generate_subrow( - (tmp_range_checker.as_ref(), record.inputs, record.flags), - core_row, - ); - core_row[0] = F::ZERO; // is_valid = 0 - row +#[inline(always)] +pub fn run_field_expression_precomputed( + expr: &FieldExpr, + flag_idx: usize, + data: &[u8], +) -> DynArray { + let field_element_limbs = expr.canonical_num_limbs(); + assert_eq!(data.len(), expr.num_inputs() * field_element_limbs); + + let mut inputs = Vec::with_capacity(expr.num_inputs()); + for i in 0..expr.num_inputs() { + let start = i * expr.canonical_num_limbs(); + let end = start + expr.canonical_num_limbs(); + let limb_slice = &data[start..end]; + let input = BigUint::from_bytes_le(limb_slice); + inputs.push(input); } + + let flags = if NEEDS_SETUP { + let mut flags = vec![false; expr.num_flags()]; + if flag_idx < expr.num_flags() { + flags[flag_idx] = true; + } + flags + } else { + vec![] + }; + + let vars = expr.execute(inputs, flags); + assert_eq!(vars.len(), expr.num_vars()); + + let outputs: Vec = expr + .output_indices() + .iter() + .map(|&i| vars[i].clone()) + .collect(); + + outputs + .iter() + .map(|x| biguint_to_limbs_vec(x, field_element_limbs)) + .concat() + .into_iter() + .collect::>() + .into() } diff --git a/crates/circuits/mod-builder/src/tests.rs b/crates/circuits/mod-builder/src/tests.rs index d217c0c5c2..628043256d 100644 --- a/crates/circuits/mod-builder/src/tests.rs +++ b/crates/circuits/mod-builder/src/tests.rs @@ -11,14 +11,126 @@ use openvm_stark_sdk::{ p3_baby_bear::BabyBear, }; -use crate::{test_utils::*, ExprBuilder, FieldExpr, FieldExprCols, FieldVariable, SymbolicExpr}; +use crate::{ + test_utils::*, utils::biguint_to_limbs_vec, ExprBuilder, FieldExpr, FieldExprCols, + FieldExpressionCoreRecordMut, FieldVariable, SymbolicExpr, +}; const LIMB_BITS: usize = 8; +use std::sync::Arc; + +use openvm_circuit_primitives::var_range::VariableRangeCheckerChip; + +fn create_field_expr_with_setup( + builder: ExprBuilder, +) -> (FieldExpr, Arc, usize) { + let prime = secp256k1_coord_prime(); + let (range_checker, _) = setup(&prime); + let expr = FieldExpr::new(builder, range_checker.bus(), false); + let width = BaseAir::::width(&expr); + (expr, range_checker, width) +} + +fn create_field_expr_with_flags_setup( + builder: ExprBuilder, +) -> (FieldExpr, Arc, usize) { + let prime = secp256k1_coord_prime(); + let (range_checker, _) = setup(&prime); + let expr = FieldExpr::new(builder, range_checker.bus(), true); + let width = BaseAir::::width(&expr); + (expr, range_checker, width) +} + +fn generate_direct_trace( + expr: &FieldExpr, + range_checker: &Arc, + inputs: Vec, + flags: Vec, + width: usize, +) -> Vec { + let mut row = BabyBear::zero_vec(width); + expr.generate_subrow((range_checker, inputs, flags), &mut row); + row +} + +fn generate_recorded_trace( + expr: &FieldExpr, + range_checker: &Arc, + inputs: &[BigUint], + flags: Vec, + width: usize, +) -> Vec { + let mut buffer = vec![0u8; 1024]; + let mut record = FieldExpressionCoreRecordMut::new_from_execution_data( + &mut buffer, + inputs, + expr.canonical_num_limbs(), + ); + let data: Vec = inputs + .iter() + .flat_map(|x| biguint_to_limbs_vec(x, expr.canonical_num_limbs())) + .collect(); + record.fill_from_execution_data(0, &data); + + let reconstructed_inputs: Vec = record + .input_limbs + .chunks(expr.canonical_num_limbs()) + .map(BigUint::from_bytes_le) + .collect(); + + let mut row = BabyBear::zero_vec(width); + expr.generate_subrow((range_checker, reconstructed_inputs, flags), &mut row); + row +} + +fn verify_stark_with_traces( + expr: FieldExpr, + range_checker: Arc, + trace: Vec, + width: usize, +) { + let trace_matrix = RowMajorMatrix::new(trace, width); + let range_trace = range_checker.generate_trace(); + BabyBearBlake3Engine::run_simple_test_no_pis_fast( + any_rap_arc_vec![expr, range_checker.air], + vec![trace_matrix, range_trace], + ) + .expect("Verification failed"); +} + +fn extract_and_verify_result( + expr: &FieldExpr, + trace: &[BabyBear], + expected: &BigUint, + var_index: usize, +) { + let FieldExprCols { vars, .. } = expr.load_vars(trace); + assert!(var_index < vars.len(), "Variable index out of bounds"); + let generated = evaluate_biguint(&vars[var_index], LIMB_BITS); + assert_eq!(generated, *expected); +} + +fn test_trace_equivalence( + expr: &FieldExpr, + range_checker: &Arc, + inputs: Vec, + flags: Vec, + width: usize, +) { + let direct_trace = + generate_direct_trace(expr, range_checker, inputs.clone(), flags.clone(), width); + let recorded_trace = generate_recorded_trace(expr, range_checker, &inputs, flags, width); + assert_eq!( + direct_trace, recorded_trace, + "Direct and recorded traces must be identical for inputs: {:?}", + inputs + ); +} #[test] fn test_add() { let prime = secp256k1_coord_prime(); - let (range_checker, builder) = setup(&prime); + let (_, builder) = setup(&prime); let x1 = ExprBuilder::new_input(builder.clone()); let x2 = ExprBuilder::new_input(builder.clone()); @@ -26,70 +138,45 @@ fn test_add() { x3.save(); let builder = builder.borrow().clone(); - let expr = FieldExpr::new(builder, range_checker.bus(), false); - let width = BaseAir::::width(&expr); + let (expr, range_checker, width) = create_field_expr_with_setup(builder); let x = generate_random_biguint(&prime); let y = generate_random_biguint(&prime); - let expected = (&x + &y) % prime; + let expected = (&x + &y) % ′ let inputs = vec![x, y]; - let mut row = BabyBear::zero_vec(width); - expr.generate_subrow((&range_checker, inputs, vec![]), &mut row); - let FieldExprCols { vars, .. } = expr.load_vars(&row); - assert_eq!(vars.len(), 1); - let generated = evaluate_biguint(&vars[0], LIMB_BITS); - assert_eq!(generated, expected); - - let trace = RowMajorMatrix::new(row, width); - let range_trace = range_checker.generate_trace(); - - BabyBearBlake3Engine::run_simple_test_no_pis_fast( - any_rap_arc_vec![expr, range_checker.air], - vec![trace, range_trace], - ) - .expect("Verification failed"); + let trace = generate_direct_trace(&expr, &range_checker, inputs, vec![], width); + extract_and_verify_result(&expr, &trace, &expected, 0); + verify_stark_with_traces(expr, range_checker, trace, width); } #[test] fn test_div() { let prime = secp256k1_coord_prime(); - let (range_checker, builder) = setup(&prime); + let (_, builder) = setup(&prime); let x1 = ExprBuilder::new_input(builder.clone()); let x2 = ExprBuilder::new_input(builder.clone()); let _x3 = x1 / x2; // auto save on division. let builder = builder.borrow().clone(); - let expr = FieldExpr::new(builder, range_checker.bus(), false); - let width = BaseAir::::width(&expr); + + let (expr, range_checker, width) = create_field_expr_with_setup(builder); let x = generate_random_biguint(&prime); let y = generate_random_biguint(&prime); let y_inv = y.modinv(&prime).unwrap(); - let expected = (&x * &y_inv) % prime; + let expected = (&x * &y_inv) % ′ let inputs = vec![x, y]; - let mut row = BabyBear::zero_vec(width); - expr.generate_subrow((&range_checker, inputs, vec![]), &mut row); - let FieldExprCols { vars, .. } = expr.load_vars(&row); - assert_eq!(vars.len(), 1); - let generated = evaluate_biguint(&vars[0], LIMB_BITS); - assert_eq!(generated, expected); - - let trace = RowMajorMatrix::new(row, width); - let range_trace = range_checker.generate_trace(); - - BabyBearBlake3Engine::run_simple_test_no_pis_fast( - any_rap_arc_vec![expr, range_checker.air], - vec![trace, range_trace], - ) - .expect("Verification failed"); + let trace = generate_direct_trace(&expr, &range_checker, inputs, vec![], width); + extract_and_verify_result(&expr, &trace, &expected, 0); + verify_stark_with_traces(expr, range_checker, trace, width); } #[test] fn test_auto_carry_mul() { let prime = secp256k1_coord_prime(); - let (range_checker, builder) = setup(&prime); + let (_, builder) = setup(&prime); let mut x1 = ExprBuilder::new_input(builder.clone()); let mut x2 = ExprBuilder::new_input(builder.clone()); @@ -101,36 +188,25 @@ fn test_auto_carry_mul() { assert_eq!(x4.expr, SymbolicExpr::Var(1)); let builder = builder.borrow().clone(); + let (expr, range_checker, width) = create_field_expr_with_setup(builder); - let expr = FieldExpr::new(builder, range_checker.bus(), false); - let width = BaseAir::::width(&expr); let x = generate_random_biguint(&prime); let y = generate_random_biguint(&prime); - let expected = (&x * &x * &y) % prime; // x4 = x3 * x1 = (x1 * x2) * x1 + let expected = (&x * &x * &y) % ′ // x4 = x3 * x1 = (x1 * x2) * x1 let inputs = vec![x, y]; - let mut row = BabyBear::zero_vec(width); - expr.generate_subrow((&range_checker, inputs, vec![]), &mut row); - let FieldExprCols { vars, .. } = expr.load_vars(&row); + let trace = generate_direct_trace(&expr, &range_checker, inputs, vec![], width); + let FieldExprCols { vars, .. } = expr.load_vars(&trace); assert_eq!(vars.len(), 2); - let generated = evaluate_biguint(&vars[1], LIMB_BITS); - assert_eq!(generated, expected); - - let trace = RowMajorMatrix::new(row, width); - let range_trace = range_checker.generate_trace(); - - BabyBearBlake3Engine::run_simple_test_no_pis_fast( - any_rap_arc_vec![expr, range_checker.air], - vec![trace, range_trace], - ) - .expect("Verification failed"); + extract_and_verify_result(&expr, &trace, &expected, 1); + verify_stark_with_traces(expr, range_checker, trace, width); } #[test] fn test_auto_carry_intmul() { let prime = secp256k1_coord_prime(); - let (range_checker, builder) = setup(&prime); - let mut x1 = ExprBuilder::new_input(builder.clone()); + let (_, builder) = setup(&prime); + let mut x1: FieldVariable = ExprBuilder::new_input(builder.clone()); let mut x2 = ExprBuilder::new_input(builder.clone()); let mut x3 = &mut x1 * &mut x2; // The int_mul below will overflow: @@ -143,35 +219,24 @@ fn test_auto_carry_intmul() { assert_eq!(x4.expr, SymbolicExpr::Var(1)); let builder = builder.borrow().clone(); + let (expr, range_checker, width) = create_field_expr_with_setup(builder); - let expr = FieldExpr::new(builder, range_checker.bus(), false); - let width = BaseAir::::width(&expr); let x = generate_random_biguint(&prime); let y = generate_random_biguint(&prime); - let expected = (&x * &x * BigUint::from(9u32)) % prime; + let expected = (&x * &x * BigUint::from(9u32)) % ′ let inputs = vec![x, y]; - let mut row = BabyBear::zero_vec(width); - expr.generate_subrow((&range_checker, inputs, vec![]), &mut row); - let FieldExprCols { vars, .. } = expr.load_vars(&row); + let trace = generate_direct_trace(&expr, &range_checker, inputs, vec![], width); + let FieldExprCols { vars, .. } = expr.load_vars(&trace); assert_eq!(vars.len(), 2); - let generated = evaluate_biguint(&vars[1], LIMB_BITS); - assert_eq!(generated, expected); - - let trace = RowMajorMatrix::new(row, width); - let range_trace = range_checker.generate_trace(); - - BabyBearBlake3Engine::run_simple_test_no_pis_fast( - any_rap_arc_vec![expr, range_checker.air], - vec![trace, range_trace], - ) - .expect("Verification failed"); + extract_and_verify_result(&expr, &trace, &expected, 1); + verify_stark_with_traces(expr, range_checker, trace, width); } #[test] fn test_auto_carry_add() { let prime = secp256k1_coord_prime(); - let (range_checker, builder) = setup(&prime); + let (_, builder) = setup(&prime); let mut x1 = ExprBuilder::new_input(builder.clone()); let mut x2 = ExprBuilder::new_input(builder.clone()); @@ -194,36 +259,24 @@ fn test_auto_carry_add() { assert_eq!(x5.expr, SymbolicExpr::Var(1)); let builder = builder.borrow().clone(); - - let expr = FieldExpr::new(builder, range_checker.bus(), false); - let width = BaseAir::::width(&expr); + let (expr, range_checker, width) = create_field_expr_with_setup(builder); let x = generate_random_biguint(&prime); let y = generate_random_biguint(&prime); - let expected = (&x * &x * BigUint::from(10u32)) % prime; + let expected = (&x * &x * BigUint::from(10u32)) % ′ let inputs = vec![x, y]; - let mut row = BabyBear::zero_vec(width); - expr.generate_subrow((&range_checker, inputs, vec![]), &mut row); - let FieldExprCols { vars, .. } = expr.load_vars(&row); + let trace = generate_direct_trace(&expr, &range_checker, inputs, vec![], width); + let FieldExprCols { vars, .. } = expr.load_vars(&trace); assert_eq!(vars.len(), 2); - let generated = evaluate_biguint(&vars[x5_id], LIMB_BITS); - assert_eq!(generated, expected); - - let trace = RowMajorMatrix::new(row, width); - let range_trace = range_checker.generate_trace(); - - BabyBearBlake3Engine::run_simple_test_no_pis_fast( - any_rap_arc_vec![expr, range_checker.air], - vec![trace, range_trace], - ) - .expect("Verification failed"); + extract_and_verify_result(&expr, &trace, &expected, x5_id); + verify_stark_with_traces(expr, range_checker, trace, width); } #[test] fn test_auto_carry_div() { let prime = secp256k1_coord_prime(); - let (range_checker, builder) = setup(&prime); + let (_, builder) = setup(&prime); let mut x1 = ExprBuilder::new_input(builder.clone()); let x2 = ExprBuilder::new_input(builder.clone()); @@ -237,29 +290,16 @@ fn test_auto_carry_div() { let builder = builder.borrow().clone(); assert_eq!(builder.num_variables, 2); // numerator autosaved, and the final division - let expr = FieldExpr::new(builder, range_checker.bus(), false); - let width = BaseAir::::width(&expr); + let (expr, range_checker, width) = create_field_expr_with_setup(builder); let x = generate_random_biguint(&prime); let y = generate_random_biguint(&prime); - // let expected = (&x * &x * BigUint::from(10u32)) % prime; let inputs = vec![x, y]; - let mut row = BabyBear::zero_vec(width); - expr.generate_subrow((&range_checker, inputs, vec![]), &mut row); - let FieldExprCols { vars, .. } = expr.load_vars(&row); + let trace = generate_direct_trace(&expr, &range_checker, inputs, vec![], width); + let FieldExprCols { vars, .. } = expr.load_vars(&trace); assert_eq!(vars.len(), 2); - // let generated = evaluate_biguint(&vars[x5_id], LIMB_BITS); - // assert_eq!(generated, expected); - - let trace = RowMajorMatrix::new(row, width); - let range_trace = range_checker.generate_trace(); - - BabyBearBlake3Engine::run_simple_test_no_pis_fast( - any_rap_arc_vec![expr, range_checker.air], - vec![trace, range_trace], - ) - .expect("Verification failed"); + verify_stark_with_traces(expr, range_checker, trace, width); } fn make_addsub_chip(builder: Rc>) -> ExprBuilder { @@ -283,65 +323,39 @@ fn make_addsub_chip(builder: Rc>) -> ExprBuilder { #[test] fn test_select() { let prime = secp256k1_coord_prime(); - let (range_checker, builder) = setup(&prime); + let (_, builder) = setup(&prime); let builder = make_addsub_chip(builder); - let expr = FieldExpr::new(builder, range_checker.bus(), true); - let width = BaseAir::::width(&expr); + let (expr, range_checker, width) = create_field_expr_with_flags_setup(builder); let x = generate_random_biguint(&prime); let y = generate_random_biguint(&prime); - let expected = (&x + &prime - &y) % prime; + let expected = (&x + &prime - &y) % ′ let inputs = vec![x, y]; - let flags = vec![false, true]; + let flags: Vec = vec![false, true]; - let mut row = BabyBear::zero_vec(width); - expr.generate_subrow((&range_checker, inputs, flags), &mut row); - let FieldExprCols { vars, .. } = expr.load_vars(&row); - assert_eq!(vars.len(), 1); - let generated = evaluate_biguint(&vars[0], LIMB_BITS); - assert_eq!(generated, expected); - - let trace = RowMajorMatrix::new(row, width); - let range_trace = range_checker.generate_trace(); - - BabyBearBlake3Engine::run_simple_test_no_pis_fast( - any_rap_arc_vec![expr, range_checker.air], - vec![trace, range_trace], - ) - .expect("Verification failed"); + let trace = generate_direct_trace(&expr, &range_checker, inputs, flags, width); + extract_and_verify_result(&expr, &trace, &expected, 0); + verify_stark_with_traces(expr, range_checker, trace, width); } #[test] fn test_select2() { let prime = secp256k1_coord_prime(); - let (range_checker, builder) = setup(&prime); + let (_, builder) = setup(&prime); let builder = make_addsub_chip(builder); - let expr = FieldExpr::new(builder, range_checker.bus(), true); - let width = BaseAir::::width(&expr); + let (expr, range_checker, width) = create_field_expr_with_flags_setup(builder); let x = generate_random_biguint(&prime); let y = generate_random_biguint(&prime); - let expected = (&x + &y) % prime; + let expected = (&x + &y) % ′ let inputs = vec![x, y]; - let flags = vec![true, false]; - - let mut row = BabyBear::zero_vec(width); - expr.generate_subrow((&range_checker, inputs, flags), &mut row); - let FieldExprCols { vars, .. } = expr.load_vars(&row); - assert_eq!(vars.len(), 1); - let generated = evaluate_biguint(&vars[0], LIMB_BITS); - assert_eq!(generated, expected); + let flags: Vec = vec![true, false]; - let trace = RowMajorMatrix::new(row, width); - let range_trace = range_checker.generate_trace(); - - BabyBearBlake3Engine::run_simple_test_no_pis_fast( - any_rap_arc_vec![expr, range_checker.air], - vec![trace, range_trace], - ) - .expect("Verification failed"); + let trace = generate_direct_trace(&expr, &range_checker, inputs, flags, width); + extract_and_verify_result(&expr, &trace, &expected, 0); + verify_stark_with_traces(expr, range_checker, trace, width); } fn test_symbolic_limbs(expr: SymbolicExpr, expected_q: usize, expected_carry: usize) { @@ -395,3 +409,299 @@ fn test_symbolic_limbs_mul() { let expected_carry = 64; test_symbolic_limbs(expr, expected_q, expected_carry); } + +#[test] +fn test_recorded_execution_records() { + let prime = secp256k1_coord_prime(); + let (_, builder) = setup(&prime); + + let x1 = ExprBuilder::new_input(builder.clone()); + let x2 = ExprBuilder::new_input(builder.clone()); + let mut x3 = x1 + x2; + x3.save(); + let builder = builder.borrow().clone(); + + let (expr, range_checker, width) = create_field_expr_with_setup(builder); + + let x = generate_random_biguint(&prime); + let y = generate_random_biguint(&prime); + let expected = (&x + &y) % ′ + let inputs = vec![x.clone(), y.clone()]; + let flags: Vec = vec![]; + + // Test record creation and reconstruction + let mut buffer = vec![0u8; 1024]; + let mut record = FieldExpressionCoreRecordMut::new_from_execution_data( + &mut buffer, + &inputs, + expr.canonical_num_limbs(), + ); + let data: Vec = inputs + .iter() + .flat_map(|x| biguint_to_limbs_vec(x, expr.canonical_num_limbs())) + .collect(); + record.fill_from_execution_data(0, &data); + assert_eq!(*record.opcode, 0); + + // Verify input reconstruction preserves data + let reconstructed_inputs: Vec = record + .input_limbs + .chunks(expr.canonical_num_limbs()) + .map(BigUint::from_bytes_le) + .collect(); + assert_eq!(reconstructed_inputs.len(), inputs.len()); + for (original, reconstructed) in inputs.iter().zip(reconstructed_inputs.iter()) { + assert_eq!(original, reconstructed); + } + + // Test standard execution and verification using reconstructed inputs + let trace = generate_direct_trace(&expr, &range_checker, reconstructed_inputs, flags, width); + extract_and_verify_result(&expr, &trace, &expected, 0); + verify_stark_with_traces(expr, range_checker, trace, width); +} + +#[test] +fn test_trace_mathematical_equivalence() { + let prime = secp256k1_coord_prime(); + let (_, builder) = setup(&prime); + + let x1 = ExprBuilder::new_input(builder.clone()); + let x2 = ExprBuilder::new_input(builder.clone()); + let x3 = &mut (x1.clone() * x2.clone()) + &mut (x1.clone().square()); + let mut x4 = x3.clone() / x2.clone(); // This will trigger auto-save + x4.save(); + let builder = builder.borrow().clone(); + + let (expr, range_checker, width) = create_field_expr_with_setup(builder); + + for _ in 0..10 { + let x = generate_random_biguint(&prime); + let y = generate_random_biguint(&prime); + + let expected = { + let temp = (&x * &y + &x * &x) % ′ + let y_inv = y.modinv(&prime).unwrap(); + (temp * y_inv) % &prime + }; + + let inputs = vec![x.clone(), y.clone()]; + let flags: Vec = vec![]; + + // Test direct/recorded equivalence + test_trace_equivalence(&expr, &range_checker, inputs.clone(), flags.clone(), width); + + // Verify the actual computation is correct + let direct_row = generate_direct_trace(&expr, &range_checker, inputs.clone(), flags, width); + let FieldExprCols { vars, .. } = expr.load_vars(&direct_row); + extract_and_verify_result(&expr, &direct_row, &expected, vars.len() - 1); + } +} + +#[test] +fn test_record_arena_allocation_patterns() { + let prime = secp256k1_coord_prime(); + let (_, builder) = setup(&prime); + + let x1 = ExprBuilder::new_input(builder.clone()); + let x2 = ExprBuilder::new_input(builder.clone()); + let mut x3 = x1 + x2; + x3.save(); + let builder = builder.borrow().clone(); + + let (expr, _range_checker, _width) = create_field_expr_with_setup(builder); + + let inputs = vec![ + generate_random_biguint(&prime), + generate_random_biguint(&prime), + ]; + + // Test record creation with various input sizes + let mut buffer = vec![0u8; 1024]; + let mut record = FieldExpressionCoreRecordMut::new_from_execution_data( + &mut buffer, + &inputs, + expr.canonical_num_limbs(), + ); + let data: Vec = inputs + .iter() + .flat_map(|x| biguint_to_limbs_vec(x, expr.canonical_num_limbs())) + .collect(); + record.fill_from_execution_data(0, &data); + assert_eq!(*record.opcode, 0); + + // Test with maximum inputs + let max_inputs = vec![BigUint::one(); 40]; // MAX_INPUT_LIMBS / 4 + let mut max_buffer = vec![0u8; 2048]; + let max_record = + FieldExpressionCoreRecordMut::new_from_execution_data(&mut max_buffer, &max_inputs, 4); + assert_eq!(*max_record.opcode, 0); + + // Test input reconstruction + let reconstructed_inputs: Vec = record + .input_limbs + .chunks(expr.canonical_num_limbs()) + .map(BigUint::from_bytes_le) + .collect(); + assert_eq!(reconstructed_inputs.len(), inputs.len()); + for (original, reconstructed) in inputs.iter().zip(reconstructed_inputs.iter()) { + assert_eq!(original, reconstructed); + } +} + +#[test] +fn test_tracestep_tracefiller_roundtrip() { + let prime = secp256k1_coord_prime(); + let (_, builder) = setup(&prime); + + let x1 = ExprBuilder::new_input(builder.clone()); + let x2 = ExprBuilder::new_input(builder.clone()); + let x3 = x1.clone() * x2.clone(); + let x4 = x3.clone() + x1.clone(); + let mut x5 = x4.clone(); + x5.save(); + let builder_data = builder.borrow().clone(); + + let (expr, _range_checker, _width) = create_field_expr_with_setup(builder_data); + + let inputs = vec![ + generate_random_biguint(&prime), + generate_random_biguint(&prime), + ]; + + let vars_direct = expr.execute(inputs.clone(), vec![]); + + // Test record creation and reconstruction roundtrip + let mut buffer = vec![0u8; 1024]; + let mut record = FieldExpressionCoreRecordMut::new_from_execution_data( + &mut buffer, + &inputs, + expr.canonical_num_limbs(), + ); + let data: Vec = inputs + .iter() + .flat_map(|x| biguint_to_limbs_vec(x, expr.canonical_num_limbs())) + .collect(); + record.fill_from_execution_data(0, &data); + + let reconstructed_inputs: Vec = record + .input_limbs + .chunks(expr.canonical_num_limbs()) + .map(BigUint::from_bytes_le) + .collect(); + let vars_reconstructed = expr.execute(reconstructed_inputs, vec![]); + + // All intermediate variables must be preserved + assert_eq!(vars_direct.len(), vars_reconstructed.len()); + for (direct, reconstructed) in vars_direct.iter().zip(vars_reconstructed.iter()) { + assert_eq!( + direct, reconstructed, + "Variable preservation failed in roundtrip" + ); + } +} + +#[test] +fn test_direct_recorded_with_complex_operations() { + let prime = secp256k1_coord_prime(); + let (_, builder) = setup(&prime); + + let x1 = ExprBuilder::new_input(builder.clone()); + let x2 = ExprBuilder::new_input(builder.clone()); + let x3 = ExprBuilder::new_input(builder.clone()); + + let numerator = x1.clone() * x2.clone() + x3.clone(); + let denominator = x1.clone() + x2.clone(); + let mut result = numerator / denominator; + result.save(); + + let builder_data = builder.borrow().clone(); + let (expr, range_checker, width) = create_field_expr_with_setup(builder_data); + + // Test edge cases with small and large numbers + let test_cases = vec![ + ( + BigUint::from(1u32), + BigUint::from(2u32), + BigUint::from(3u32), + ), + ( + BigUint::from(100u32), + BigUint::from(200u32), + BigUint::from(300u32), + ), + ( + generate_random_biguint(&prime), + generate_random_biguint(&prime), + generate_random_biguint(&prime), + ), + ]; + + for (x, y, z) in test_cases { + let inputs = vec![x.clone(), y.clone(), z.clone()]; + let flags = vec![]; + + // Test direct/recorded equivalence + test_trace_equivalence(&expr, &range_checker, inputs.clone(), flags.clone(), width); + + // Verify mathematical correctness + let expected = { + let num = (&x * &y + &z) % ′ + let den_inv = (&x + &y).modinv(&prime).unwrap(); + (num * den_inv) % &prime + }; + + let direct_row = generate_direct_trace(&expr, &range_checker, inputs, flags, width); + let FieldExprCols { vars, .. } = expr.load_vars(&direct_row); + extract_and_verify_result(&expr, &direct_row, &expected, vars.len() - 1); + } +} + +#[test] +fn test_concurrent_direct_recorded_simulation() { + // Simulate mixed direct/recorded execution to ensure RecordArena abstraction works correctly + let prime = secp256k1_coord_prime(); + let (_, builder) = setup(&prime); + + let x1 = ExprBuilder::new_input(builder.clone()); + let x2 = ExprBuilder::new_input(builder.clone()); + let mut x3 = x1 + x2; + x3.save(); + let builder_data = builder.borrow().clone(); + + let (expr, range_checker, width) = create_field_expr_with_setup(builder_data); + + // Simulate multiple "concurrent" executions with different modes + let execution_scenarios = vec![ + ("direct", true), + ("recorded", false), + ("direct", true), + ("recorded", false), + ]; + + let mut all_traces = Vec::new(); + + for (name, is_direct) in execution_scenarios { + let inputs = vec![ + generate_random_biguint(&prime), + generate_random_biguint(&prime), + ]; + + let trace = if is_direct { + generate_direct_trace(&expr, &range_checker, inputs.clone(), vec![], width) + } else { + generate_recorded_trace(&expr, &range_checker, &inputs, vec![], width) + }; + + all_traces.push((name, inputs, trace)); + } + + // Verify each trace is mathematically valid + for (_, inputs, trace) in &all_traces { + let expected = (&inputs[0] + &inputs[1]) % ′ + extract_and_verify_result(&expr, trace, &expected, 0); + } + + // Verify that direct and recorded with same inputs produce same results + let same_inputs = vec![BigUint::from(123u32), BigUint::from(456u32)]; + test_trace_equivalence(&expr, &range_checker, same_inputs, vec![], width); +} diff --git a/crates/circuits/mod-builder/src/utils.rs b/crates/circuits/mod-builder/src/utils.rs index 7540f0ae2c..2f2561ba87 100644 --- a/crates/circuits/mod-builder/src/utils.rs +++ b/crates/circuits/mod-builder/src/utils.rs @@ -1,27 +1,14 @@ use num_bigint::BigUint; -use num_traits::{FromPrimitive, ToPrimitive, Zero}; - -// little endian. -pub fn limbs_to_biguint(x: &[u32], limb_size: usize) -> BigUint { - let mut result = BigUint::zero(); - let base = BigUint::from_u32(1 << limb_size).unwrap(); - for limb in x.iter().rev() { - result = result * &base + BigUint::from_u32(*limb).unwrap(); - } - result -} // Use this when num_limbs is not a constant. // little endian. -// Warning: This function only returns the last NUM_LIMBS*LIMB_SIZE bits of +// Warning: This function only returns the last NUM_LIMBS bytes of // the input, while the input can have more than that. -pub fn biguint_to_limbs_vec(mut x: BigUint, limb_size: usize, num_limbs: usize) -> Vec { - let mut result = vec![0; num_limbs]; - let base = BigUint::from_u32(1 << limb_size).unwrap(); - for r in result.iter_mut() { - *r = (x.clone() % &base).to_u32().unwrap(); - x /= &base; - } - assert!(x.is_zero()); - result +#[inline(always)] +pub fn biguint_to_limbs_vec(x: &BigUint, num_limbs: usize) -> Vec { + x.to_bytes_le() + .into_iter() + .chain(std::iter::repeat(0u8)) + .take(num_limbs) + .collect() } diff --git a/crates/circuits/poseidon2-air/src/babybear.rs b/crates/circuits/poseidon2-air/src/babybear.rs index e12b60bfb4..6989f992c7 100644 --- a/crates/circuits/poseidon2-air/src/babybear.rs +++ b/crates/circuits/poseidon2-air/src/babybear.rs @@ -18,7 +18,7 @@ pub(crate) fn horizen_to_p3_babybear(horizen_babybear: HorizenBabyBear) -> BabyB } pub(crate) fn horizen_round_consts() -> Poseidon2Constants { - let p3_rc16: Vec> = RC16 + let p3_rc16: Vec> = RC16 .iter() .map(|round| { round @@ -29,18 +29,10 @@ pub(crate) fn horizen_round_consts() -> Poseidon2Constants { .collect(); let p_end = BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS + BABY_BEAR_POSEIDON2_PARTIAL_ROUNDS; - let beginning_full_round_constants: [[BabyBear; POSEIDON2_WIDTH]; - BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS] = from_fn(|i| p3_rc16[i].clone().try_into().unwrap()); - let partial_round_constants: [BabyBear; BABY_BEAR_POSEIDON2_PARTIAL_ROUNDS] = - from_fn(|i| p3_rc16[i + BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS][0]); - let ending_full_round_constants: [[BabyBear; POSEIDON2_WIDTH]; - BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS] = - from_fn(|i| p3_rc16[i + p_end].clone().try_into().unwrap()); - Poseidon2Constants { - beginning_full_round_constants, - partial_round_constants, - ending_full_round_constants, + beginning_full_round_constants: from_fn(|i| p3_rc16[i].clone().try_into().unwrap()), + partial_round_constants: from_fn(|i| p3_rc16[i + BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS][0]), + ending_full_round_constants: from_fn(|i| p3_rc16[i + p_end].clone().try_into().unwrap()), } } diff --git a/crates/circuits/poseidon2-air/src/config.rs b/crates/circuits/poseidon2-air/src/config.rs index be597c6dc6..6007f0b4fb 100644 --- a/crates/circuits/poseidon2-air/src/config.rs +++ b/crates/circuits/poseidon2-air/src/config.rs @@ -15,7 +15,7 @@ pub struct Poseidon2Config { pub constants: Poseidon2Constants, } -impl Default for Poseidon2Config { +impl Default for Poseidon2Config { fn default() -> Self { Self { constants: default_baby_bear_rc(), diff --git a/crates/circuits/poseidon2-air/src/lib.rs b/crates/circuits/poseidon2-air/src/lib.rs index 8a51ee88c7..747f94630e 100644 --- a/crates/circuits/poseidon2-air/src/lib.rs +++ b/crates/circuits/poseidon2-air/src/lib.rs @@ -42,7 +42,7 @@ pub const BABY_BEAR_POSEIDON2_SBOX_DEGREE: u64 = 7; /// `SBOX_REGISTERS` affects the max constraint degree of the AIR. See [p3_poseidon2_air] for more /// details. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Poseidon2SubChip { // This is Arc purely because Poseidon2Air cannot derive Clone pub air: Arc>, diff --git a/crates/circuits/primitives/derive/src/lib.rs b/crates/circuits/primitives/derive/src/lib.rs index 47ff1e220a..35e5f8fd5b 100644 --- a/crates/circuits/primitives/derive/src/lib.rs +++ b/crates/circuits/primitives/derive/src/lib.rs @@ -73,6 +73,49 @@ pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream { TokenStream::from(methods) } +/// `S` is the type the derive macro is being called on +/// Implements Borrow and BorrowMut for [u8] +/// [u8] has to have (checked via `debug_assert!`s) +/// - at least size_of(S) length +/// - at least align_of(S) alignment +#[proc_macro_derive(AlignedBytesBorrow)] +pub fn aligned_bytes_borrow_derive(input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as DeriveInput); + let name = &ast.ident; + + // Get impl generics, type generics, where clause + // Note, need to add the new type generic to the `impl_generics` + let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl(); + + let methods = quote! { + impl #impl_generics core::borrow::Borrow<#name #type_generics> for [u8] + where + #where_clause + { + fn borrow(&self) -> &#name #type_generics { + use core::mem::{align_of, size_of_val}; + debug_assert!(size_of_val(self) >= core::mem::size_of::<#name #type_generics>()); + debug_assert_eq!(self.as_ptr() as usize % align_of::<#name #type_generics>(), 0); + unsafe { &*(self.as_ptr() as *const #name #type_generics) } + } + } + + impl #impl_generics core::borrow::BorrowMut<#name #type_generics> for [u8] + where + #where_clause + { + fn borrow_mut(&mut self) -> &mut #name #type_generics { + use core::mem::{align_of, size_of_val}; + debug_assert!(size_of_val(self) >= core::mem::size_of::<#name #type_generics>()); + debug_assert_eq!(self.as_ptr() as usize % align_of::<#name #type_generics>(), 0); + unsafe { &mut *(self.as_mut_ptr() as *mut #name #type_generics) } + } + } + }; + + TokenStream::from(methods) +} + #[proc_macro_derive(Chip, attributes(chip))] pub fn chip_derive(input: TokenStream) -> TokenStream { // Parse the attributes from the struct or enum @@ -86,9 +129,10 @@ pub fn chip_derive(input: TokenStream) -> TokenStream { Data::Struct(inner) => { let generics = &ast.generics; let mut new_generics = generics.clone(); + new_generics.params.push(syn::parse_quote! { R }); new_generics .params - .push(syn::parse_quote! { SC: openvm_stark_backend::config::StarkGenericConfig }); + .push(syn::parse_quote! { PB: openvm_stark_backend::prover::hal::ProverBackend }); let (impl_generics, _, _) = new_generics.split_for_impl(); // Check if the struct has only one unnamed field @@ -105,17 +149,11 @@ pub fn chip_derive(input: TokenStream) -> TokenStream { let where_clause = new_generics.make_where_clause(); where_clause .predicates - .push(syn::parse_quote! { #inner_ty: openvm_stark_backend::Chip }); + .push(syn::parse_quote! { #inner_ty: openvm_stark_backend::Chip }); quote! { - impl #impl_generics openvm_stark_backend::Chip for #name #ty_generics #where_clause { - fn air(&self) -> openvm_stark_backend::AirRef { - self.0.air() - } - fn generate_air_proof_input(self) -> openvm_stark_backend::prover::types::AirProofInput { - self.0.generate_air_proof_input() - } - fn generate_air_proof_input_with_id(self, air_id: usize) -> (usize, openvm_stark_backend::prover::types::AirProofInput) { - self.0.generate_air_proof_input_with_id(air_id) + impl #impl_generics openvm_stark_backend::Chip for #name #ty_generics #where_clause { + fn generate_proving_ctx(&self, records: R) -> openvm_stark_backend::prover::types::AirProvingContext { + self.0.generate_proving_ctx(records) } } }.into() @@ -134,34 +172,32 @@ pub fn chip_derive(input: TokenStream) -> TokenStream { }) .collect::>(); - let (air_arms, generate_air_proof_input_arms, generate_air_proof_input_with_id_arms): (Vec<_>, Vec<_>, Vec<_>) = - multiunzip(variants.iter().map(|(variant_name, field)| { + let (generate_proving_ctx_arms, where_predicates): (Vec<_>, Vec<_>) = + variants.iter().map(|(variant_name, field)| { let field_ty = &field.ty; - let air_arm = quote! { - #name::#variant_name(x) => <#field_ty as openvm_stark_backend::Chip>::air(x) - }; - let generate_air_proof_input_arm = quote! { - #name::#variant_name(x) => <#field_ty as openvm_stark_backend::Chip>::generate_air_proof_input(x) - }; - let generate_air_proof_input_with_id_arm = quote! { - #name::#variant_name(x) => <#field_ty as openvm_stark_backend::Chip>::generate_air_proof_input_with_id(x, air_id) + let generate_proving_ctx_arm = quote! { + #name::#variant_name(x) => <#field_ty as openvm_stark_backend::Chip>::generate_proving_ctx(x, records) }; - (air_arm, generate_air_proof_input_arm, generate_air_proof_input_with_id_arm) - })); + let where_predicate = + syn::parse_quote! { #field_ty: openvm_stark_backend::Chip }; + (generate_proving_ctx_arm, where_predicate) + }).collect(); - // Attach an extra generic SC: StarkGenericConfig to the impl_generics + // Attach extra generics R and PB to the impl_generics let generics = &ast.generics; let mut new_generics = generics.clone(); + new_generics.params.push(syn::parse_quote! { R }); new_generics .params - .push(syn::parse_quote! { SC: openvm_stark_backend::config::StarkGenericConfig }); + .push(syn::parse_quote! { PB: openvm_stark_backend::prover::hal::ProverBackend }); let (impl_generics, _, _) = new_generics.split_for_impl(); // Implement Chip whenever the inner type implements Chip let mut new_generics = generics.clone(); let where_clause = new_generics.make_where_clause(); - where_clause.predicates.push(syn::parse_quote! { openvm_stark_backend::config::Domain: openvm_stark_backend::p3_commit::PolynomialSpace - }); + for predicate in where_predicates { + where_clause.predicates.push(predicate); + } let attributes = ast.attrs.iter().find(|&attr| attr.path().is_ident("chip")); if let Some(attr) = attributes { let mut fail_flag = false; @@ -195,20 +231,10 @@ pub fn chip_derive(input: TokenStream) -> TokenStream { } quote! { - impl #impl_generics openvm_stark_backend::Chip for #name #ty_generics #where_clause { - fn air(&self) -> openvm_stark_backend::AirRef { - match self { - #(#air_arms,)* - } - } - fn generate_air_proof_input(self) -> openvm_stark_backend::prover::types::AirProofInput { - match self { - #(#generate_air_proof_input_arms,)* - } - } - fn generate_air_proof_input_with_id(self, air_id: usize) -> (usize, openvm_stark_backend::prover::types::AirProofInput) { + impl #impl_generics openvm_stark_backend::Chip for #name #ty_generics #where_clause { + fn generate_proving_ctx(&self, records: R) -> openvm_stark_backend::prover::types::AirProvingContext { match self { - #(#generate_air_proof_input_with_id_arms,)* + #(#generate_proving_ctx_arms,)* } } } diff --git a/crates/circuits/primitives/src/bitwise_op_lookup/mod.rs b/crates/circuits/primitives/src/bitwise_op_lookup/mod.rs index a9e649f84e..f3a0152b35 100644 --- a/crates/circuits/primitives/src/bitwise_op_lookup/mod.rs +++ b/crates/circuits/primitives/src/bitwise_op_lookup/mod.rs @@ -11,9 +11,9 @@ use openvm_stark_backend::{ p3_air::{Air, BaseAir, PairBuilder}, p3_field::{Field, FieldAlgebra}, p3_matrix::{dense::RowMajorMatrix, Matrix}, - prover::types::AirProofInput, + prover::{cpu::CpuBackend, types::AirProvingContext}, rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, - AirRef, Chip, ChipUsageGetter, + Chip, ChipUsageGetter, }; mod bus; @@ -112,10 +112,8 @@ pub struct BitwiseOperationLookupChip { pub count_xor: Vec, } -#[derive(Clone)] -pub struct SharedBitwiseOperationLookupChip( - Arc>, -); +pub type SharedBitwiseOperationLookupChip = + Arc>; impl BitwiseOperationLookupChip { pub fn new(bus: BitwiseOperationLookupBus) -> Self { @@ -159,15 +157,17 @@ impl BitwiseOperationLookupChip { } } + /// Generates trace and resets all internal counters to 0. pub fn generate_trace(&self) -> RowMajorMatrix { let mut rows = F::zero_vec(self.count_range.len() * NUM_BITWISE_OP_LOOKUP_COLS); for (n, row) in rows.chunks_mut(NUM_BITWISE_OP_LOOKUP_COLS).enumerate() { let cols: &mut BitwiseOperationLookupCols = row.borrow_mut(); cols.mult_range = F::from_canonical_u32( - self.count_range[n].load(std::sync::atomic::Ordering::SeqCst), + self.count_range[n].swap(0, std::sync::atomic::Ordering::SeqCst), + ); + cols.mult_xor = F::from_canonical_u32( + self.count_xor[n].swap(0, std::sync::atomic::Ordering::SeqCst), ); - cols.mult_xor = - F::from_canonical_u32(self.count_xor[n].load(std::sync::atomic::Ordering::SeqCst)); } RowMajorMatrix::new(rows, NUM_BITWISE_OP_LOOKUP_COLS) } @@ -177,57 +177,13 @@ impl BitwiseOperationLookupChip { } } -impl SharedBitwiseOperationLookupChip { - pub fn new(bus: BitwiseOperationLookupBus) -> Self { - Self(Arc::new(BitwiseOperationLookupChip::new(bus))) - } - pub fn bus(&self) -> BitwiseOperationLookupBus { - self.0.bus() - } - - pub fn air_width(&self) -> usize { - self.0.air_width() - } - - pub fn request_range(&self, x: u32, y: u32) { - self.0.request_range(x, y); - } - - pub fn request_xor(&self, x: u32, y: u32) -> u32 { - self.0.request_xor(x, y) - } - - pub fn clear(&self) { - self.0.clear() - } - - pub fn generate_trace(&self) -> RowMajorMatrix { - self.0.generate_trace() - } -} - -impl Chip +impl Chip> for BitwiseOperationLookupChip { - fn air(&self) -> AirRef { - Arc::new(self.air) - } - - fn generate_air_proof_input(self) -> AirProofInput { + /// Generates trace and resets all internal counters to 0. + fn generate_proving_ctx(&self, _: R) -> AirProvingContext> { let trace = self.generate_trace::>(); - AirProofInput::simple_no_pis(trace) - } -} - -impl Chip - for SharedBitwiseOperationLookupChip -{ - fn air(&self) -> AirRef { - self.0.air() - } - - fn generate_air_proof_input(self) -> AirProofInput { - self.0.generate_air_proof_input() + AirProvingContext::simple_no_pis(Arc::new(trace)) } } @@ -245,29 +201,3 @@ impl ChipUsageGetter for BitwiseOperationLookupChip ChipUsageGetter for SharedBitwiseOperationLookupChip { - fn air_name(&self) -> String { - self.0.air_name() - } - - fn constant_trace_height(&self) -> Option { - self.0.constant_trace_height() - } - - fn current_trace_height(&self) -> usize { - self.0.current_trace_height() - } - - fn trace_width(&self) -> usize { - self.0.trace_width() - } -} - -impl AsRef> - for SharedBitwiseOperationLookupChip -{ - fn as_ref(&self) -> &BitwiseOperationLookupChip { - &self.0 - } -} diff --git a/crates/circuits/primitives/src/range/mod.rs b/crates/circuits/primitives/src/range/mod.rs index 39dd70aae7..dc94c03c9c 100644 --- a/crates/circuits/primitives/src/range/mod.rs +++ b/crates/circuits/primitives/src/range/mod.rs @@ -122,7 +122,7 @@ impl RangeCheckerChip { let cols: &mut RangeCols = (*row).borrow_mut(); // Set multiplicity for each value in range cols.mult = - F::from_canonical_u32(self.count[n].load(std::sync::atomic::Ordering::SeqCst)); + F::from_canonical_u32(self.count[n].swap(0, std::sync::atomic::Ordering::Relaxed)); } RowMajorMatrix::new(rows, NUM_RANGE_COLS) } diff --git a/crates/circuits/primitives/src/range_gate/mod.rs b/crates/circuits/primitives/src/range_gate/mod.rs index 7c1a877c49..a3401e0c97 100644 --- a/crates/circuits/primitives/src/range_gate/mod.rs +++ b/crates/circuits/primitives/src/range_gate/mod.rs @@ -143,7 +143,7 @@ impl RangeCheckerGateChip { .iter() .enumerate() .flat_map(|(i, count)| { - let c = count.load(std::sync::atomic::Ordering::Relaxed); + let c = count.swap(0, std::sync::atomic::Ordering::Relaxed); vec![F::from_canonical_usize(i), F::from_canonical_u32(c)] }) .collect(); diff --git a/crates/circuits/primitives/src/range_tuple/mod.rs b/crates/circuits/primitives/src/range_tuple/mod.rs index 3d0754cc9a..4962d567c5 100644 --- a/crates/circuits/primitives/src/range_tuple/mod.rs +++ b/crates/circuits/primitives/src/range_tuple/mod.rs @@ -16,9 +16,9 @@ use openvm_stark_backend::{ p3_air::{Air, BaseAir, PairBuilder}, p3_field::{Field, PrimeField32}, p3_matrix::{dense::RowMajorMatrix, Matrix}, - prover::types::AirProofInput, + prover::{cpu::CpuBackend, types::AirProvingContext}, rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, - AirRef, Chip, ChipUsageGetter, + Chip, ChipUsageGetter, }; mod bus; @@ -105,8 +105,7 @@ pub struct RangeTupleCheckerChip { pub count: Vec>, } -#[derive(Debug, Clone)] -pub struct SharedRangeTupleCheckerChip(Arc>); +pub type SharedRangeTupleCheckerChip = Arc>; impl RangeTupleCheckerChip { pub fn new(bus: RangeTupleCheckerBus) -> Self { @@ -154,61 +153,19 @@ impl RangeTupleCheckerChip { let rows = self .count .iter() - .map(|c| F::from_canonical_u32(c.load(std::sync::atomic::Ordering::SeqCst))) + .map(|c| F::from_canonical_u32(c.swap(0, std::sync::atomic::Ordering::Relaxed))) .collect::>(); RowMajorMatrix::new(rows, 1) } } -impl SharedRangeTupleCheckerChip { - pub fn new(bus: RangeTupleCheckerBus) -> Self { - Self(Arc::new(RangeTupleCheckerChip::new(bus))) - } - pub fn bus(&self) -> &RangeTupleCheckerBus { - self.0.bus() - } - - pub fn sizes(&self) -> &[u32; N] { - self.0.sizes() - } - - pub fn add_count(&self, ids: &[u32]) { - self.0.add_count(ids); - } - - pub fn clear(&self) { - self.0.clear(); - } - - pub fn generate_trace(&self) -> RowMajorMatrix { - self.0.generate_trace() - } -} - -impl Chip for RangeTupleCheckerChip +impl Chip> for RangeTupleCheckerChip where Val: PrimeField32, { - fn air(&self) -> AirRef { - Arc::new(self.air) - } - - fn generate_air_proof_input(self) -> AirProofInput { + fn generate_proving_ctx(&self, _: R) -> AirProvingContext> { let trace = self.generate_trace::>(); - AirProofInput::simple_no_pis(trace) - } -} - -impl Chip for SharedRangeTupleCheckerChip -where - Val: PrimeField32, -{ - fn air(&self) -> AirRef { - self.0.air() - } - - fn generate_air_proof_input(self) -> AirProofInput { - self.0.generate_air_proof_input() + AirProvingContext::simple_no_pis(Arc::new(trace)) } } @@ -226,27 +183,3 @@ impl ChipUsageGetter for RangeTupleCheckerChip { NUM_RANGE_TUPLE_COLS } } - -impl ChipUsageGetter for SharedRangeTupleCheckerChip { - fn air_name(&self) -> String { - self.0.air_name() - } - - fn constant_trace_height(&self) -> Option { - self.0.constant_trace_height() - } - - fn current_trace_height(&self) -> usize { - self.0.current_trace_height() - } - - fn trace_width(&self) -> usize { - self.0.trace_width() - } -} - -impl AsRef> for SharedRangeTupleCheckerChip { - fn as_ref(&self) -> &RangeTupleCheckerChip { - &self.0 - } -} diff --git a/crates/circuits/primitives/src/var_range/mod.rs b/crates/circuits/primitives/src/var_range/mod.rs index 1ba3f2e776..82999a8bda 100644 --- a/crates/circuits/primitives/src/var_range/mod.rs +++ b/crates/circuits/primitives/src/var_range/mod.rs @@ -16,9 +16,9 @@ use openvm_stark_backend::{ p3_air::{Air, BaseAir, PairBuilder}, p3_field::{Field, PrimeField32}, p3_matrix::{dense::RowMajorMatrix, Matrix}, - prover::types::AirProofInput, + prover::{cpu::CpuBackend, types::AirProvingContext}, rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, - AirRef, Chip, ChipUsageGetter, + Chip, ChipUsageGetter, }; use tracing::instrument; @@ -102,8 +102,7 @@ pub struct VariableRangeCheckerChip { pub count: Vec, } -#[derive(Clone)] -pub struct SharedVariableRangeCheckerChip(Arc); +pub type SharedVariableRangeCheckerChip = Arc; impl VariableRangeCheckerChip { pub fn new(bus: VariableRangeCheckerBus) -> Self { @@ -153,12 +152,13 @@ impl VariableRangeCheckerChip { } } + /// Generates trace and resets the internal counters all to 0. pub fn generate_trace(&self) -> RowMajorMatrix { let mut rows = F::zero_vec(self.count.len() * NUM_VARIABLE_RANGE_COLS); for (n, row) in rows.chunks_mut(NUM_VARIABLE_RANGE_COLS).enumerate() { let cols: &mut VariableRangeCols = row.borrow_mut(); cols.mult = - F::from_canonical_u32(self.count[n].load(std::sync::atomic::Ordering::SeqCst)); + F::from_canonical_u32(self.count[n].swap(0, std::sync::atomic::Ordering::Relaxed)); } RowMajorMatrix::new(rows, NUM_VARIABLE_RANGE_COLS) } @@ -186,60 +186,15 @@ impl VariableRangeCheckerChip { } } -impl SharedVariableRangeCheckerChip { - pub fn new(bus: VariableRangeCheckerBus) -> Self { - Self(Arc::new(VariableRangeCheckerChip::new(bus))) - } - - pub fn bus(&self) -> VariableRangeCheckerBus { - self.0.bus() - } - - pub fn range_max_bits(&self) -> usize { - self.0.range_max_bits() - } - - pub fn air_width(&self) -> usize { - self.0.air_width() - } - - pub fn add_count(&self, value: u32, max_bits: usize) { - self.0.add_count(value, max_bits) - } - - pub fn clear(&self) { - self.0.clear() - } - - pub fn generate_trace(&self) -> RowMajorMatrix { - self.0.generate_trace() - } -} - -impl Chip for VariableRangeCheckerChip +// We allow any `R` type so this can work with arbitrary record arenas. +impl Chip> for VariableRangeCheckerChip where Val: PrimeField32, { - fn air(&self) -> AirRef { - Arc::new(self.air) - } - - fn generate_air_proof_input(self) -> AirProofInput { + /// Generates trace and resets the internal counters all to 0. + fn generate_proving_ctx(&self, _: R) -> AirProvingContext> { let trace = self.generate_trace::>(); - AirProofInput::simple_no_pis(trace) - } -} - -impl Chip for SharedVariableRangeCheckerChip -where - Val: PrimeField32, -{ - fn air(&self) -> AirRef { - self.0.air() - } - - fn generate_air_proof_input(self) -> AirProofInput { - self.0.generate_air_proof_input() + AirProvingContext::simple_no_pis(Arc::new(trace)) } } @@ -257,27 +212,3 @@ impl ChipUsageGetter for VariableRangeCheckerChip { NUM_VARIABLE_RANGE_COLS } } - -impl ChipUsageGetter for SharedVariableRangeCheckerChip { - fn air_name(&self) -> String { - self.0.air_name() - } - - fn constant_trace_height(&self) -> Option { - self.0.constant_trace_height() - } - - fn current_trace_height(&self) -> usize { - self.0.current_trace_height() - } - - fn trace_width(&self) -> usize { - self.0.trace_width() - } -} - -impl AsRef for SharedVariableRangeCheckerChip { - fn as_ref(&self) -> &VariableRangeCheckerChip { - &self.0 - } -} diff --git a/crates/circuits/primitives/src/xor/lookup/mod.rs b/crates/circuits/primitives/src/xor/lookup/mod.rs index c9e76ad4c9..af9175183d 100644 --- a/crates/circuits/primitives/src/xor/lookup/mod.rs +++ b/crates/circuits/primitives/src/xor/lookup/mod.rs @@ -19,9 +19,9 @@ use openvm_stark_backend::{ p3_air::{Air, BaseAir, PairBuilder}, p3_field::Field, p3_matrix::{dense::RowMajorMatrix, Matrix}, - prover::types::AirProofInput, + prover::{cpu::CpuBackend, types::AirProvingContext}, rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, - AirRef, Chip, ChipUsageGetter, + Chip, ChipUsageGetter, }; use super::bus::XorBus; @@ -170,14 +170,10 @@ impl XorLookupChip { } } -impl Chip for XorLookupChip { - fn air(&self) -> AirRef { - Arc::new(self.air) - } - - fn generate_air_proof_input(self) -> AirProofInput { +impl Chip> for XorLookupChip { + fn generate_proving_ctx(&self, _: R) -> AirProvingContext> { let trace = self.generate_trace::>(); - AirProofInput::simple_no_pis(trace) + AirProvingContext::simple_no_pis(Arc::new(trace)) } } diff --git a/crates/circuits/sha256-air/src/air.rs b/crates/circuits/sha256-air/src/air.rs index 96578984d0..b27af6ffa9 100644 --- a/crates/circuits/sha256-air/src/air.rs +++ b/crates/circuits/sha256-air/src/air.rs @@ -15,11 +15,11 @@ use openvm_stark_backend::{ use super::{ big_sig0_field, big_sig1_field, ch_field, compose, maj_field, small_sig0_field, - small_sig1_field, u32_into_limbs, Sha256DigestCols, Sha256RoundCols, SHA256_DIGEST_WIDTH, - SHA256_H, SHA256_HASH_WORDS, SHA256_K, SHA256_ROUNDS_PER_ROW, SHA256_ROUND_WIDTH, - SHA256_WORD_BITS, SHA256_WORD_U16S, SHA256_WORD_U8S, + small_sig1_field, Sha256DigestCols, Sha256RoundCols, SHA256_DIGEST_WIDTH, SHA256_H, + SHA256_HASH_WORDS, SHA256_K, SHA256_ROUNDS_PER_ROW, SHA256_ROUND_WIDTH, SHA256_WORD_BITS, + SHA256_WORD_U16S, SHA256_WORD_U8S, }; -use crate::constraint_word_addition; +use crate::{constraint_word_addition, u32_into_u16s}; /// Expects the message to be padded to a multiple of 512 bits #[derive(Clone, Debug)] @@ -154,7 +154,7 @@ impl Sha256Air { .assert_eq( a_limb, AB::Expr::from_canonical_u32( - u32_into_limbs::<2>(SHA256_H[SHA256_ROUNDS_PER_ROW - i - 1])[j], + u32_into_u16s(SHA256_H[SHA256_ROUNDS_PER_ROW - i - 1])[j], ), ); @@ -166,7 +166,7 @@ impl Sha256Air { .assert_eq( e_limb, AB::Expr::from_canonical_u32( - u32_into_limbs::<2>(SHA256_H[SHA256_ROUNDS_PER_ROW - i + 3])[j], + u32_into_u16s(SHA256_H[SHA256_ROUNDS_PER_ROW - i + 3])[j], ), ); } @@ -561,9 +561,8 @@ impl Sha256Air { .map(|rw_idx| { ( rw_idx, - u32_into_limbs::( - SHA256_K[rw_idx * SHA256_ROUNDS_PER_ROW + i], - )[j] as usize, + u32_into_u16s(SHA256_K[rw_idx * SHA256_ROUNDS_PER_ROW + i])[j] + as usize, ) }) .collect::>(), diff --git a/crates/circuits/sha256-air/src/tests.rs b/crates/circuits/sha256-air/src/tests.rs index 903b7b0695..7ad0229185 100644 --- a/crates/circuits/sha256-air/src/tests.rs +++ b/crates/circuits/sha256-air/src/tests.rs @@ -1,11 +1,14 @@ -use std::{array, borrow::BorrowMut, cmp::max, sync::Arc}; +use std::{array, borrow::BorrowMut, sync::Arc}; use openvm_circuit::arch::{ instructions::riscv::RV32_CELL_BITS, testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, }; use openvm_circuit_primitives::{ - bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + bitwise_op_lookup::{ + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, + }, SubAir, }; use openvm_stark_backend::{ @@ -13,18 +16,19 @@ use openvm_stark_backend::{ interaction::{BusIndex, InteractionBuilder}, p3_air::{Air, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, - p3_maybe_rayon::prelude::{IndexedParallelIterator, ParallelIterator, ParallelSliceMut}, - prover::types::AirProofInput, - rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, - AirRef, Chip, ChipUsageGetter, + p3_matrix::{dense::RowMajorMatrix, Matrix}, + prover::{cpu::CpuBackend, types::AirProvingContext}, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, + utils::disable_debug_builder, + verifier::VerificationError, + AirRef, Chip, }; -use openvm_stark_sdk::utils::create_seeded_rng; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::Rng; use crate::{ - compose, small_sig0_field, Sha256Air, Sha256RoundCols, SHA256_BLOCK_U8S, SHA256_DIGEST_WIDTH, - SHA256_HASH_WORDS, SHA256_ROUNDS_PER_ROW, SHA256_ROUND_WIDTH, SHA256_ROWS_PER_BLOCK, - SHA256_WORD_U16S, SHA256_WORD_U8S, + Sha256Air, Sha256DigestCols, Sha256FillerHelper, SHA256_BLOCK_U8S, SHA256_DIGEST_WIDTH, + SHA256_HASH_WORDS, SHA256_WIDTH, SHA256_WORD_U8S, }; // A wrapper AIR purely for testing purposes @@ -47,51 +51,47 @@ impl Air for Sha256TestAir { } } +const SELF_BUS_IDX: BusIndex = 28; +type F = BabyBear; +type RecordType = Vec<([u8; SHA256_BLOCK_U8S], bool)>; + // A wrapper Chip purely for testing purposes pub struct Sha256TestChip { - pub air: Sha256TestAir, + pub step: Sha256FillerHelper, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, - pub records: Vec<([u8; SHA256_BLOCK_U8S], bool)>, } -impl Chip for Sha256TestChip +impl Chip> for Sha256TestChip where Val: PrimeField32, { - fn air(&self) -> AirRef { - Arc::new(self.air.clone()) - } - - fn generate_air_proof_input(self) -> AirProofInput { + fn generate_proving_ctx(&self, records: RecordType) -> AirProvingContext> { let trace = crate::generate_trace::>( - &self.air.sub_air, - self.bitwise_lookup_chip.clone(), - self.records, + &self.step, + self.bitwise_lookup_chip.as_ref(), + SHA256_WIDTH, + records, ); - AirProofInput::simple_no_pis(trace) + AirProvingContext::simple_no_pis(Arc::new(trace)) } } -impl ChipUsageGetter for Sha256TestChip { - fn air_name(&self) -> String { - get_air_name(&self.air) - } - fn current_trace_height(&self) -> usize { - self.records.len() * SHA256_ROWS_PER_BLOCK - } - - fn trace_width(&self) -> usize { - max(SHA256_ROUND_WIDTH, SHA256_DIGEST_WIDTH) - } -} - -const SELF_BUS_IDX: BusIndex = 28; -#[test] -fn rand_sha256_test() { +#[allow(clippy::type_complexity)] +fn create_air_with_air_ctx() -> ( + (AirRef, AirProvingContext>), + ( + BitwiseOperationLookupAir, + SharedBitwiseOperationLookupChip, + ), +) +where + Val: PrimeField32, +{ let mut rng = create_seeded_rng(); - let tester = VmChipTestBuilder::default(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); let len = rng.gen_range(1..100); let random_records: Vec<_> = (0..len) .map(|i| { @@ -101,133 +101,63 @@ fn rand_sha256_test() { ) }) .collect(); + + let air = Sha256TestAir { + sub_air: Sha256Air::new(bitwise_bus, SELF_BUS_IDX), + }; let chip = Sha256TestChip { - air: Sha256TestAir { - sub_air: Sha256Air::new(bitwise_bus, SELF_BUS_IDX), - }, + step: Sha256FillerHelper::new(), bitwise_lookup_chip: bitwise_chip.clone(), - records: random_records, }; + let air_ctx = chip.generate_proving_ctx(random_records); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); + ((Arc::new(air), air_ctx), (bitwise_chip.air, bitwise_chip)) } -// A wrapper Chip to test that the final_hash is properly constrained. -// This chip implements a malicious trace gen that violates the final_hash constraints. -pub struct Sha256TestBadFinalHashChip { - pub air: Sha256TestAir, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, - pub records: Vec<([u8; SHA256_BLOCK_U8S], bool)>, +#[test] +fn rand_sha256_test() { + let tester = VmChipTestBuilder::default(); + let (air_ctx, bitwise) = create_air_with_air_ctx(); + let tester = tester + .build() + .load_air_proving_ctx(air_ctx) + .load_periphery(bitwise) + .finalize(); + tester.simple_test().expect("Verification failed"); } -impl Chip for Sha256TestBadFinalHashChip -where - Val: PrimeField32, -{ - fn air(&self) -> AirRef { - Arc::new(self.air.clone()) - } - - fn generate_air_proof_input(self) -> AirProofInput { - let mut trace = crate::generate_trace::>( - &self.air.sub_air, - self.bitwise_lookup_chip.clone(), - self.records.clone(), - ); - - // Set the final_hash in the digest row of the last block of each hash to zero. - // That is, every hash that this chip does will result in a final_hash of zero. - for (i, row) in self.records.iter().enumerate() { - if row.1 { - let last_digest_row_idx = (i + 1) * SHA256_ROWS_PER_BLOCK - 1; - let last_digest_row: &mut crate::Sha256DigestCols> = - trace.row_mut(last_digest_row_idx)[..SHA256_DIGEST_WIDTH].borrow_mut(); - // Set the final_hash to all zeros +#[test] +fn negative_sha256_test_bad_final_hash() { + let tester = VmChipTestBuilder::default(); + let ((air, mut air_ctx), bitwise) = create_air_with_air_ctx(); + + // Set the final_hash to all zeros + let modify_trace = |trace: &mut RowMajorMatrix| { + trace.row_chunks_exact_mut(1).for_each(|row| { + let mut row_slice = row.row_slice(0).to_vec(); + let cols: &mut Sha256DigestCols = row_slice[..SHA256_DIGEST_WIDTH].borrow_mut(); + if cols.flags.is_last_block.is_one() && cols.flags.is_digest_row.is_one() { for i in 0..SHA256_HASH_WORDS { for j in 0..SHA256_WORD_U8S { - last_digest_row.final_hash[i][j] = Val::::ZERO; + cols.final_hash[i][j] = F::ZERO; } } - - let (last_round_row, last_digest_row) = - trace.row_pair_mut(last_digest_row_idx - 1, last_digest_row_idx); - let last_round_row: &mut crate::Sha256RoundCols> = - last_round_row.borrow_mut(); - let last_digest_row: &mut crate::Sha256RoundCols> = - last_digest_row.borrow_mut(); - // fix the intermed_4 for the digest row - generate_intermed_4(last_round_row, last_digest_row); + row.values.copy_from_slice(&row_slice); } - } - - let non_padded_height = self.records.len() * SHA256_ROWS_PER_BLOCK; - let width = >>::width(&self.air.sub_air); - // recalculate the missing cells (second pass of generate_trace) - trace.values[width..] - .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK) - .take(non_padded_height / SHA256_ROWS_PER_BLOCK) - .for_each(|chunk| { - self.air.sub_air.generate_missing_cells(chunk, width, 0); - }); - - AirProofInput::simple_no_pis(trace) - } -} - -// Copy of private method in Sha256Air used for testing -/// Puts the correct intermed_4 in the `next_row` -fn generate_intermed_4( - local_cols: &Sha256RoundCols, - next_cols: &mut Sha256RoundCols, -) { - let w = [local_cols.message_schedule.w, next_cols.message_schedule.w].concat(); - let w_limbs: Vec<[F; SHA256_WORD_U16S]> = w - .iter() - .map(|x| array::from_fn(|i| compose::(&x[i * 16..(i + 1) * 16], 1))) - .collect(); - for i in 0..SHA256_ROUNDS_PER_ROW { - let sig_w = small_sig0_field::(&w[i + 1]); - let sig_w_limbs: [F; SHA256_WORD_U16S] = - array::from_fn(|j| compose::(&sig_w[j * 16..(j + 1) * 16], 1)); - for (j, sig_w_limb) in sig_w_limbs.iter().enumerate() { - next_cols.schedule_helper.intermed_4[i][j] = w_limbs[i][j] + *sig_w_limb; - } - } -} - -impl ChipUsageGetter for Sha256TestBadFinalHashChip { - fn air_name(&self) -> String { - get_air_name(&self.air) - } - fn current_trace_height(&self) -> usize { - self.records.len() * SHA256_ROWS_PER_BLOCK - } - - fn trace_width(&self) -> usize { - max(SHA256_ROUND_WIDTH, SHA256_DIGEST_WIDTH) - } -} - -#[test] -#[should_panic] -fn test_sha256_final_hash_constraints() { - let mut rng = create_seeded_rng(); - let tester = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let len = rng.gen_range(1..100); - let random_records: Vec<_> = (0..len) - .map(|_| (array::from_fn(|_| rng.gen::()), true)) - .collect(); - let chip = Sha256TestBadFinalHashChip { - air: Sha256TestAir { - sub_air: Sha256Air::new(bitwise_bus, SELF_BUS_IDX), - }, - bitwise_lookup_chip: bitwise_chip.clone(), - records: random_records, + }); }; - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); + // Modify the air_ctx + let trace = Option::take(&mut air_ctx.common_main).unwrap(); + let mut trace = Arc::into_inner(trace).unwrap(); + modify_trace(&mut trace); + air_ctx.common_main = Some(Arc::new(trace)); + + disable_debug_builder(); + let tester = tester + .build() + .load_air_proving_ctx((air, air_ctx)) + .load_periphery(bitwise) + .finalize(); + tester.simple_test_with_expected_error(VerificationError::OodEvaluationMismatch); } diff --git a/crates/circuits/sha256-air/src/trace.rs b/crates/circuits/sha256-air/src/trace.rs index eaf9174f50..8cbaebbc55 100644 --- a/crates/circuits/sha256-air/src/trace.rs +++ b/crates/circuits/sha256-air/src/trace.rs @@ -1,31 +1,48 @@ use std::{array, borrow::BorrowMut, ops::Range}; use openvm_circuit_primitives::{ - bitwise_op_lookup::SharedBitwiseOperationLookupChip, utils::next_power_of_two_or_zero, + bitwise_op_lookup::BitwiseOperationLookupChip, encoder::Encoder, + utils::next_power_of_two_or_zero, }; use openvm_stark_backend::{ - p3_air::BaseAir, p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix, - p3_maybe_rayon::prelude::*, + p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix, p3_maybe_rayon::prelude::*, }; use sha2::{compress256, digest::generic_array::GenericArray}; use super::{ - air::Sha256Air, big_sig0_field, big_sig1_field, ch_field, columns::Sha256RoundCols, compose, - get_flag_pt_array, maj_field, small_sig0_field, small_sig1_field, SHA256_BLOCK_WORDS, - SHA256_DIGEST_WIDTH, SHA256_HASH_WORDS, SHA256_ROUND_WIDTH, + big_sig0_field, big_sig1_field, ch_field, columns::Sha256RoundCols, compose, get_flag_pt_array, + maj_field, small_sig0_field, small_sig1_field, SHA256_BLOCK_WORDS, SHA256_DIGEST_WIDTH, + SHA256_HASH_WORDS, SHA256_ROUND_WIDTH, }; use crate::{ big_sig0, big_sig1, ch, columns::Sha256DigestCols, limbs_into_u32, maj, small_sig0, small_sig1, - u32_into_limbs, SHA256_BLOCK_U8S, SHA256_BUFFER_SIZE, SHA256_H, SHA256_INVALID_CARRY_A, + u32_into_bits_field, u32_into_u16s, SHA256_BLOCK_U8S, SHA256_H, SHA256_INVALID_CARRY_A, SHA256_INVALID_CARRY_E, SHA256_K, SHA256_ROUNDS_PER_ROW, SHA256_ROWS_PER_BLOCK, - SHA256_WORD_BITS, SHA256_WORD_U16S, SHA256_WORD_U8S, + SHA256_WORD_U16S, SHA256_WORD_U8S, }; +/// A helper struct for the SHA256 trace generation. +/// Also, separates the inner AIR from the trace generation. +pub struct Sha256FillerHelper { + pub row_idx_encoder: Encoder, +} + +impl Default for Sha256FillerHelper { + fn default() -> Self { + Self::new() + } +} + /// The trace generation of SHA256 should be done in two passes. /// The first pass should do `get_block_trace` for every block and generate the invalid rows through /// `get_default_row` The second pass should go through all the blocks and call /// `generate_missing_cells` -impl Sha256Air { +impl Sha256FillerHelper { + pub fn new() -> Self { + Self { + row_idx_encoder: Encoder::new(18, 2, false), + } + } /// This function takes the input_message (padding not handled), the previous hash, /// and returns the new hash after processing the block input pub fn get_block_hash( @@ -52,18 +69,16 @@ impl Sha256Air { trace_width: usize, trace_start_col: usize, input: &[u32; SHA256_BLOCK_WORDS], - bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, + bitwise_lookup_chip: &BitwiseOperationLookupChip<8>, prev_hash: &[u32; SHA256_HASH_WORDS], is_last_block: bool, global_block_idx: u32, local_block_idx: u32, - buffer_vals: &[[F; SHA256_BUFFER_SIZE]; 4], ) { #[cfg(debug_assertions)] { assert!(trace.len() == trace_width * SHA256_ROWS_PER_BLOCK); assert!(trace_start_col + super::SHA256_WIDTH <= trace_width); - assert!(self.bitwise_lookup_bus == bitwise_lookup_chip.bus()); if local_block_idx == 0 { assert!(*prev_hash == SHA256_H); } @@ -87,14 +102,10 @@ impl Sha256Air { cols.flags.local_block_idx = F::from_canonical_u32(local_block_idx); // W_idx = M_idx - if i < SHA256_ROWS_PER_BLOCK / SHA256_ROUNDS_PER_ROW { + if i < 4 { for j in 0..SHA256_ROUNDS_PER_ROW { - cols.message_schedule.w[j] = u32_into_limbs::( - input[i * SHA256_ROUNDS_PER_ROW + j], - ) - .map(F::from_canonical_u32); - cols.message_schedule.carry_or_buffer[j] = - array::from_fn(|k| buffer_vals[i][j * SHA256_WORD_U16S * 2 + k]); + cols.message_schedule.w[j] = + u32_into_bits_field::(input[i * SHA256_ROUNDS_PER_ROW + j]); } } // W_idx = SIG1(W_{idx-2}) + W_{idx-7} + SIG0(W_{idx-15}) + W_{idx-16} @@ -108,14 +119,10 @@ impl Sha256Air { message_schedule[idx - 16], ]; let w: u32 = nums.iter().fold(0, |acc, &num| acc.wrapping_add(num)); - cols.message_schedule.w[j] = - u32_into_limbs::(w).map(F::from_canonical_u32); + cols.message_schedule.w[j] = u32_into_bits_field::(w); - let nums_limbs = nums - .iter() - .map(|x| u32_into_limbs::(*x)) - .collect::>(); - let w_limbs = u32_into_limbs::(w); + let nums_limbs = nums.map(u32_into_u16s); + let w_limbs = u32_into_u16s(w); // fill in the carrys for k in 0..SHA256_WORD_U16S { @@ -157,25 +164,18 @@ impl Sha256Air { // e = d + t1 let e = work_vars[3].wrapping_add(t1_sum); - cols.work_vars.e[j] = - u32_into_limbs::(e).map(F::from_canonical_u32); - let e_limbs = u32_into_limbs::(e); + cols.work_vars.e[j] = u32_into_bits_field::(e); + let e_limbs = u32_into_u16s(e); // a = t1 + t2 let a = t1_sum.wrapping_add(t2_sum); - cols.work_vars.a[j] = - u32_into_limbs::(a).map(F::from_canonical_u32); - let a_limbs = u32_into_limbs::(a); + cols.work_vars.a[j] = u32_into_bits_field::(a); + let a_limbs = u32_into_u16s(a); // fill in the carrys for k in 0..SHA256_WORD_U16S { - let t1_limb = t1.iter().fold(0, |acc, &num| { - acc + u32_into_limbs::(num)[k] - }); - let t2_limb = t2.iter().fold(0, |acc, &num| { - acc + u32_into_limbs::(num)[k] - }); + let t1_limb = t1.iter().fold(0, |acc, &num| acc + u32_into_u16s(num)[k]); + let t2_limb = t2.iter().fold(0, |acc, &num| acc + u32_into_u16s(num)[k]); - let mut e_limb = - t1_limb + u32_into_limbs::(work_vars[3])[k]; + let mut e_limb = t1_limb + u32_into_u16s(work_vars[3])[k]; let mut a_limb = t1_limb + t2_limb; if k > 0 { a_limb += cols.work_vars.carry_a[j][k - 1].as_canonical_u32(); @@ -203,16 +203,14 @@ impl Sha256Air { if i > 0 { for j in 0..SHA256_ROUNDS_PER_ROW { let idx = i * SHA256_ROUNDS_PER_ROW + j; - let w_4 = u32_into_limbs::(message_schedule[idx - 4]); - let sig_0_w_3 = u32_into_limbs::(small_sig0( - message_schedule[idx - 3], - )); + let w_4 = u32_into_u16s(message_schedule[idx - 4]); + let sig_0_w_3 = u32_into_u16s(small_sig0(message_schedule[idx - 3])); cols.schedule_helper.intermed_4[j] = array::from_fn(|k| F::from_canonical_u32(w_4[k] + sig_0_w_3[k])); if j < SHA256_ROUNDS_PER_ROW - 1 { let w_3 = message_schedule[idx - 3]; cols.schedule_helper.w_3[j] = - u32_into_limbs::(w_3).map(F::from_canonical_u32); + u32_into_u16s(w_3).map(F::from_canonical_u32); } } } @@ -223,8 +221,7 @@ impl Sha256Air { row[get_range(trace_start_col, SHA256_DIGEST_WIDTH)].borrow_mut(); for j in 0..SHA256_ROUNDS_PER_ROW - 1 { let w_3 = message_schedule[i * SHA256_ROUNDS_PER_ROW + j - 3]; - cols.schedule_helper.w_3[j] = - u32_into_limbs::(w_3).map(F::from_canonical_u32); + cols.schedule_helper.w_3[j] = u32_into_u16s(w_3).map(F::from_canonical_u32); } cols.flags.is_round_row = F::ZERO; cols.flags.is_first_4_rows = F::ZERO; @@ -237,29 +234,27 @@ impl Sha256Air { cols.flags.local_block_idx = F::from_canonical_u32(local_block_idx); let final_hash: [u32; SHA256_HASH_WORDS] = array::from_fn(|i| work_vars[i].wrapping_add(prev_hash[i])); - let final_hash_limbs: [[u32; SHA256_WORD_U8S]; SHA256_HASH_WORDS] = - array::from_fn(|i| u32_into_limbs::(final_hash[i])); + let final_hash_limbs: [[u8; SHA256_WORD_U8S]; SHA256_HASH_WORDS] = + array::from_fn(|i| final_hash[i].to_le_bytes()); // need to ensure final hash limbs are bytes, in order for // prev_hash[i] + work_vars[i] == final_hash[i] // to be constrained correctly for word in final_hash_limbs.iter() { for chunk in word.chunks(2) { - bitwise_lookup_chip.request_range(chunk[0], chunk[1]); + bitwise_lookup_chip.request_range(chunk[0] as u32, chunk[1] as u32); } } cols.final_hash = array::from_fn(|i| { - array::from_fn(|j| F::from_canonical_u32(final_hash_limbs[i][j])) + array::from_fn(|j| F::from_canonical_u8(final_hash_limbs[i][j])) }); - cols.prev_hash = prev_hash - .map(|f| u32_into_limbs::(f).map(F::from_canonical_u32)); + cols.prev_hash = prev_hash.map(|f| u32_into_u16s(f).map(F::from_canonical_u32)); let hash = if is_last_block { - SHA256_H.map(u32_into_limbs::) + SHA256_H.map(u32_into_bits_field::) } else { cols.final_hash - .map(|f| limbs_into_u32(f.map(|x| x.as_canonical_u32()))) - .map(u32_into_limbs::) - } - .map(|x| x.map(F::from_canonical_u32)); + .map(|f| u32::from_le_bytes(f.map(|x| x.as_canonical_u32() as u8))) + .map(u32_into_bits_field::) + }; for i in 0..SHA256_ROUNDS_PER_ROW { cols.hash.a[i] = hash[SHA256_ROUNDS_PER_ROW - i - 1]; @@ -338,24 +333,14 @@ impl Sha256Air { /// Fills the `cols` as a padding row /// Note: we still need to correctly fill in the hash values, carries and intermeds - pub fn generate_default_row(self: &Sha256Air, cols: &mut Sha256RoundCols) { - cols.flags.is_round_row = F::ZERO; - cols.flags.is_first_4_rows = F::ZERO; - cols.flags.is_digest_row = F::ZERO; - - cols.flags.is_last_block = F::ZERO; - cols.flags.global_block_idx = F::ZERO; + pub fn generate_default_row( + self: &Sha256FillerHelper, + cols: &mut Sha256RoundCols, + ) { cols.flags.row_idx = get_flag_pt_array(&self.row_idx_encoder, 17).map(F::from_canonical_u32); - cols.flags.local_block_idx = F::ZERO; - - cols.message_schedule.w = [[F::ZERO; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW]; - cols.message_schedule.carry_or_buffer = - [[F::ZERO; SHA256_WORD_U16S * 2]; SHA256_ROUNDS_PER_ROW]; - let hash = SHA256_H - .map(u32_into_limbs::) - .map(|x| x.map(F::from_canonical_u32)); + let hash = SHA256_H.map(u32_into_bits_field::); for i in 0..SHA256_ROUNDS_PER_ROW { cols.work_vars.a[i] = hash[SHA256_ROUNDS_PER_ROW - i - 1]; @@ -486,15 +471,16 @@ impl Sha256Air { } } +/// Generates a trace for a standalone SHA256 computation (currently only used for testing) /// `records` consists of pairs of `(input_block, is_last_block)`. pub fn generate_trace( - sub_air: &Sha256Air, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, + step: &Sha256FillerHelper, + bitwise_lookup_chip: &BitwiseOperationLookupChip<8>, + width: usize, records: Vec<([u8; SHA256_BLOCK_U8S], bool)>, ) -> RowMajorMatrix { let non_padded_height = records.len() * SHA256_ROWS_PER_BLOCK; let height = next_power_of_two_or_zero(non_padded_height); - let width = >::width(sub_air); let mut values = F::zero_vec(height * width); struct BlockContext { @@ -522,7 +508,7 @@ pub fn generate_trace( prev_hash = SHA256_H; } else { local_block_idx += 1; - prev_hash = Sha256Air::get_block_hash(&prev_hash, input); + prev_hash = Sha256FillerHelper::get_block_hash(&prev_hash, input); } } // first pass @@ -542,17 +528,16 @@ pub fn generate_trace( input[(i + 1) * SHA256_WORD_U8S - j - 1] as u32 })) }); - sub_air.generate_block_trace( + step.generate_block_trace( block, width, 0, &input_words, - bitwise_lookup_chip.clone(), + bitwise_lookup_chip, &prev_hash, is_last_block, global_block_idx, local_block_idx, - &[[F::ZERO; 16]; 4], ); }); // second pass: padding rows @@ -560,14 +545,14 @@ pub fn generate_trace( .par_chunks_mut(width) .for_each(|row| { let cols: &mut Sha256RoundCols = row.borrow_mut(); - sub_air.generate_default_row(cols); + step.generate_default_row(cols); }); // second pass: non-padding rows values[width..] .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK) .take(non_padded_height / SHA256_ROWS_PER_BLOCK) .for_each(|chunk| { - sub_air.generate_missing_cells(chunk, width, 0); + step.generate_missing_cells(chunk, width, 0); }); RowMajorMatrix::new(values, width) } diff --git a/crates/circuits/sha256-air/src/utils.rs b/crates/circuits/sha256-air/src/utils.rs index abf8b6e7f2..ba598f2604 100644 --- a/crates/circuits/sha256-air/src/utils.rs +++ b/crates/circuits/sha256-air/src/utils.rs @@ -6,7 +6,6 @@ use openvm_circuit_primitives::{ utils::{not, select}, }; use openvm_stark_backend::{p3_air::AirBuilder, p3_field::FieldAlgebra}; -use rand::{rngs::StdRng, Rng}; use super::{Sha256DigestCols, Sha256RoundCols}; @@ -74,10 +73,21 @@ pub const SHA256_H: [u32; 8] = [ 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, ]; -/// Convert a u32 into a list of limbs in little endian -pub fn u32_into_limbs(num: u32) -> [u32; NUM_LIMBS] { - let limb_bits = 32 / NUM_LIMBS; - array::from_fn(|i| (num >> (limb_bits * i)) & ((1 << limb_bits) - 1)) +/// Returns the number of blocks required to hash a message of length `len` +pub fn get_sha256_num_blocks(len: u32) -> u32 { + // need to pad with one 1 bit, 64 bits for the message length and then pad until the length + // is divisible by [SHA256_BLOCK_BITS] + ((len << 3) as usize + 1 + 64).div_ceil(SHA256_BLOCK_BITS) as u32 +} + +/// Convert a u32 into a list of bits in little endian then convert each bit into a field element +pub fn u32_into_bits_field(num: u32) -> [F; SHA256_WORD_BITS] { + array::from_fn(|i| F::from_bool((num >> i) & 1 == 1)) +} + +/// Convert a u32 into a an array of 2 16-bit limbs in little endian +pub fn u32_into_u16s(num: u32) -> [u32; 2] { + [num & 0xffff, num >> 16] } /// Convert a list of limbs in little endian into a u32 @@ -227,13 +237,6 @@ pub(crate) fn small_sig1_field( xor(&rotr::(x, 17), &rotr::(x, 19), &shr::(x, 10)) } -/// Generate a random message of a given length -pub fn get_random_message(rng: &mut StdRng, len: usize) -> Vec { - let mut random_message: Vec = vec![0u8; len]; - rng.fill(&mut random_message[..]); - random_message -} - /// Wrapper of `get_flag_pt` to get the flag pointer as an array pub fn get_flag_pt_array(encoder: &Encoder, flag_idx: usize) -> [u32; N] { encoder.get_flag_pt(flag_idx).try_into().unwrap() diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index f105c588a3..9afc91ac40 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -25,10 +25,10 @@ openvm-stark-sdk.workspace = true openvm-stark-backend.workspace = true openvm-circuit = { workspace = true } -aws-sdk-s3 = "1.78" -aws-config = "1.5" -tokio = { version = "1.43.1", features = ["rt", "rt-multi-thread", "macros"] } -clap = { version = "4.5.9", features = ["derive", "env"] } +aws-sdk-s3 = "1.98.0" +aws-config = "1.8.2" +tokio = { version = "1.46.1", features = ["rt", "rt-multi-thread", "macros"] } +clap = { workspace = true, features = ["derive", "env"] } eyre.workspace = true tracing.workspace = true serde.workspace = true @@ -42,12 +42,12 @@ toml_edit = "0.22" include_dir = "0.7" [features] -default = ["parallel", "jemalloc", "evm-verify", "bench-metrics"] +default = ["parallel", "jemalloc", "evm-verify", "metrics"] evm-prove = ["openvm-sdk/evm-prove"] evm-verify = ["evm-prove", "openvm-sdk/evm-verify"] -bench-metrics = ["openvm-sdk/bench-metrics"] +metrics = ["openvm-sdk/metrics"] # for guest profiling: -profiling = ["openvm-sdk/profiling"] +perf-metrics = ["openvm-sdk/perf-metrics", "metrics"] # performance features: # (rayon is always imported because of halo2, so "parallel" feature is redundant) parallel = ["openvm-sdk/parallel"] @@ -55,3 +55,4 @@ mimalloc = ["openvm-sdk/mimalloc"] jemalloc = ["openvm-sdk/jemalloc"] jemalloc-prof = ["openvm-sdk/jemalloc-prof"] nightly-features = ["openvm-sdk/nightly-features"] +ci = [] diff --git a/crates/cli/src/commands/prove.rs b/crates/cli/src/commands/prove.rs index 55c10901b9..3112b116c9 100644 --- a/crates/cli/src/commands/prove.rs +++ b/crates/cli/src/commands/prove.rs @@ -6,7 +6,7 @@ use eyre::Result; use openvm_sdk::fs::write_evm_proof_to_file; use openvm_sdk::{ commit::AppExecutionCommit, - config::{AggregationTreeConfig, SdkVmConfig}, + config::{AggregationTreeConfig, SdkVmConfig, SdkVmCpuBuilder}, fs::{ read_agg_stark_pk_from_file, read_app_pk_from_file, read_exe_from_file, write_app_proof_to_file, write_to_file_json, @@ -115,6 +115,7 @@ enum ProveSubCommand { impl ProveCmd { pub fn run(&self) -> Result<()> { + let vm_builder = SdkVmCpuBuilder; match &self.command { ProveSubCommand::App { app_pk, @@ -127,8 +128,12 @@ impl ProveCmd { let (committed_exe, target_name) = load_or_build_and_commit_exe(&sdk, run_args, cargo_args, &app_pk)?; - let app_proof = - sdk.generate_app_proof(app_pk, committed_exe, read_to_stdin(&run_args.input)?)?; + let app_proof = sdk.generate_app_proof( + vm_builder, + app_pk, + committed_exe, + read_to_stdin(&run_args.input)?, + )?; let proof_path = if let Some(proof) = proof { proof @@ -161,6 +166,7 @@ impl ProveCmd { eyre::eyre!("Failed to read aggregation proving key: {}\nPlease run 'cargo openvm setup' first", e) })?; let stark_proof = sdk.generate_e2e_stark_proof( + vm_builder, app_pk, committed_exe, agg_stark_pk, @@ -206,6 +212,7 @@ impl ProveCmd { let params_reader = CacheHalo2ParamsReader::new(default_params_dir()); let evm_proof = sdk.generate_evm_proof( ¶ms_reader, + vm_builder, app_pk, committed_exe, agg_pk, diff --git a/crates/cli/src/lib.rs b/crates/cli/src/lib.rs index b126946af5..1b58c45920 100644 --- a/crates/cli/src/lib.rs +++ b/crates/cli/src/lib.rs @@ -6,8 +6,7 @@ pub mod util; use std::process::{Command, Stdio}; use eyre::{Context, Result}; - -pub const RUSTUP_TOOLCHAIN_NAME: &str = "nightly-2025-02-14"; +pub use openvm_build::{get_rustup_toolchain_name, DEFAULT_RUSTUP_TOOLCHAIN_NAME}; pub const OPENVM_VERSION_MESSAGE: &str = concat!( "v", diff --git a/crates/cli/tests/app_e2e.rs b/crates/cli/tests/app_e2e.rs index 482b583ef1..30032c0f81 100644 --- a/crates/cli/tests/app_e2e.rs +++ b/crates/cli/tests/app_e2e.rs @@ -3,16 +3,27 @@ use std::{ fs::{self, read_to_string}, path::Path, process::Command, + sync::OnceLock, }; use eyre::Result; use itertools::Itertools; use tempfile::tempdir; +fn install_cli() { + static FORCE_INSTALL: OnceLock = OnceLock::new(); + FORCE_INSTALL.get_or_init(|| { + if !matches!(env::var("SKIP_INSTALL"), Ok(x) if !x.is_empty()) { + run_cmd("cargo", &["install", "--path", ".", "--force", "--locked"]).unwrap(); + } + true + }); +} + #[test] fn test_cli_app_e2e() -> Result<()> { let temp_dir = tempdir()?; - run_cmd("cargo", &["install", "--path", ".", "--force", "--locked"])?; + install_cli(); let exe_path = "tests/programs/fibonacci/target/openvm/release/openvm-cli-example-test.vmexe"; let temp_pk = temp_dir.path().join("app.pk"); let temp_vk = temp_dir.path().join("app.vk"); @@ -87,7 +98,7 @@ fn test_cli_app_e2e() -> Result<()> { #[test] fn test_cli_app_e2e_simplified() -> Result<()> { - run_cmd("cargo", &["install", "--path", ".", "--force", "--locked"])?; + install_cli(); run_cmd( "cargo", &[ @@ -128,7 +139,7 @@ fn test_cli_init_build() -> Result<()> { let temp_path = temp_dir.path(); let config_path = temp_path.join("openvm.toml"); let manifest_path = temp_path.join("Cargo.toml"); - run_cmd("cargo", &["install", "--path", ".", "--force", "--locked"])?; + install_cli(); // Cargo will not respect patches if run within a workspace run_cmd( diff --git a/crates/continuations/src/verifier/leaf/mod.rs b/crates/continuations/src/verifier/leaf/mod.rs index 969733ba41..7ab08cdb0b 100644 --- a/crates/continuations/src/verifier/leaf/mod.rs +++ b/crates/continuations/src/verifier/leaf/mod.rs @@ -1,6 +1,6 @@ use openvm_circuit::{ - arch::{instructions::program::Program, SystemConfig}, - system::memory::tree::public_values::PUBLIC_VALUES_ADDRESS_SPACE_OFFSET, + arch::{instructions::program::Program, SystemConfig, ADDR_SPACE_OFFSET}, + system::memory::merkle::public_values::PUBLIC_VALUES_ADDRESS_SPACE_OFFSET, }; use openvm_native_compiler::{conversion::CompilerOptions, prelude::*}; use openvm_native_recursion::{ @@ -113,7 +113,7 @@ impl LeafVmVerifierConfig { builder: &mut Builder, ) -> ([Felt; DIGEST_SIZE], [Felt; DIGEST_SIZE]) { let memory_dimensions = self.app_system_config.memory_config.memory_dimensions(); - let pv_as = PUBLIC_VALUES_ADDRESS_SPACE_OFFSET + memory_dimensions.as_offset; + let pv_as = PUBLIC_VALUES_ADDRESS_SPACE_OFFSET + ADDR_SPACE_OFFSET; let pv_start_idx = memory_dimensions.label_to_index((pv_as, 0)); let pv_height = log2_strict_usize(self.app_system_config.num_public_values / DIGEST_SIZE); let proof_len = memory_dimensions.overall_height() - pv_height; diff --git a/crates/continuations/src/verifier/leaf/types.rs b/crates/continuations/src/verifier/leaf/types.rs index 16aca7a169..d47b36f248 100644 --- a/crates/continuations/src/verifier/leaf/types.rs +++ b/crates/continuations/src/verifier/leaf/types.rs @@ -1,6 +1,6 @@ use derivative::Derivative; use openvm_circuit::{ - arch::ContinuationVmProof, system::memory::tree::public_values::UserPublicValuesProof, + arch::ContinuationVmProof, system::memory::merkle::public_values::UserPublicValuesProof, }; use openvm_native_compiler::ir::DIGEST_SIZE; use openvm_stark_sdk::{ diff --git a/crates/prof/src/aggregate.rs b/crates/prof/src/aggregate.rs index 047d16b30a..b19f869ad0 100644 --- a/crates/prof/src/aggregate.rs +++ b/crates/prof/src/aggregate.rs @@ -165,11 +165,14 @@ impl AggregateMetrics { let mut total_par_proof_time = MdTableCell::new(0.0, Some(0.0)); for (group_name, metrics) in &self.by_group { let stats = metrics.get(PROOF_TIME_LABEL); - let execute_stats = metrics.get(EXECUTE_TIME_LABEL); + let execute_metered_stats = metrics.get(EXECUTE_METERED_TIME_LABEL); + let execute_e1_stats = metrics.get(EXECUTE_E1_TIME_LABEL); if stats.is_none() { continue; } - let stats = stats.unwrap(); + let stats = stats.unwrap_or_else(|| { + panic!("Missing proof time statistics for group '{}'", group_name) + }); let mut sum = stats.sum; let mut max = stats.max; // convert ms to s @@ -184,26 +187,61 @@ impl AggregateMetrics { if !group_name.contains("keygen") { // Proving time in keygen group is dummy and not part of total. total_proof_time.val += sum.val; - *total_proof_time.diff.as_mut().unwrap() += sum.diff.unwrap_or(0.0); + *total_proof_time + .diff + .as_mut() + .expect("total_proof_time.diff should be initialized") += + sum.diff.unwrap_or(0.0); total_par_proof_time.val += max.val; - *total_par_proof_time.diff.as_mut().unwrap() += max.diff.unwrap_or(0.0); + *total_par_proof_time + .diff + .as_mut() + .expect("total_par_proof_time.diff should be initialized") += + max.diff.unwrap_or(0.0); - // Account for the fact that execution is serial - // Add total execution time for the app proofs, and subtract the max segment - // execution time + // Account for the serial execute_metered and execute_e1 for app outside of segments if group_name != "leaf" && group_name != "root" && group_name != "halo2_outer" && group_name != "halo2_wrapper" && !group_name.starts_with("internal") { - let execute_stats = execute_stats.unwrap(); - total_par_proof_time.val += - (execute_stats.sum.val - execute_stats.max.val) / 1000.0; - *total_par_proof_time.diff.as_mut().unwrap() += - (execute_stats.sum.diff.unwrap_or(0.0) - - execute_stats.max.diff.unwrap_or(0.0)) - / 1000.0; + if let Some(execute_metered_stats) = execute_metered_stats { + // For metered metrics without segment labels, we just use the value + // directly Count is 1, so avg = sum = max = min = + // value + total_proof_time.val += execute_metered_stats.avg.val / 1000.0; + total_par_proof_time.val += execute_metered_stats.avg.val / 1000.0; + if let Some(diff) = execute_metered_stats.avg.diff { + *total_proof_time + .diff + .as_mut() + .expect("total_proof_time.diff should be initialized") += + diff / 1000.0; + *total_par_proof_time + .diff + .as_mut() + .expect("total_par_proof_time.diff should be initialized") += + diff / 1000.0; + } + } + + if let Some(execute_e1_stats) = execute_e1_stats { + total_proof_time.val += execute_e1_stats.avg.val / 1000.0; + total_par_proof_time.val += execute_e1_stats.avg.val / 1000.0; + if let Some(diff) = execute_e1_stats.avg.diff { + *total_proof_time + .diff + .as_mut() + .expect("total_proof_time.diff should be initialized") += + diff / 1000.0; + *total_par_proof_time + .diff + .as_mut() + .expect("total_par_proof_time.diff should be initialized") += + diff / 1000.0; + } + } } } } @@ -239,7 +277,13 @@ impl AggregateMetrics { .into_iter() .map(|group_name| { let key = group_name.clone(); - let value = self.by_group.get(group_name).unwrap().clone(); + let value = self + .by_group + .get(group_name) + .unwrap_or_else(|| { + panic!("Group '{}' should exist in by_group map", group_name) + }) + .clone(); (key, value) }) .collect() @@ -252,6 +296,7 @@ impl AggregateMetrics { .map(|(group_name, metrics)| { let metrics = metrics .iter() + .filter(|(_, stats)| stats.avg.val.is_finite() && stats.sum.val.is_finite()) .flat_map(|(metric_name, stats)| { [ (format!("{metric_name}::sum"), stats.sum.into()), @@ -295,11 +340,37 @@ impl AggregateMetrics { for metric_name in names { let summary = summaries.get(metric_name); if let Some(summary) = summary { - writeln!( - writer, - "| `{:<20}` | {:<10} | {:<10} | {:<10} | {:<10} |", - metric_name, summary.avg, summary.sum, summary.max, summary.min, - )?; + // Special handling for execute_metered metrics (not aggregated across segments + // in the app proof case) + if metric_name == EXECUTE_METERED_TIME_LABEL + && group_name != "leaf" + && group_name != "root" + && group_name != "halo2_outer" + && group_name != "halo2_wrapper" + && !group_name.starts_with("internal") + { + writeln!( + writer, + "| `{:<20}` | {:<10} | {:<10} | {:<10} | {:<10} |", + metric_name, summary.avg, "-", "-", "-", + )?; + } else if metric_name == EXECUTE_E1_INSN_MI_S_LABEL + || metric_name == EXECUTE_PREFLIGHT_INSN_MI_S_LABEL + || metric_name == EXECUTE_METERED_INSN_MI_S_LABEL + { + // skip sum because it is misleading + writeln!( + writer, + "| `{:<20}` | {:<10} | {:<10} | {:<10} | {:<10} |", + metric_name, summary.avg, "-", summary.max, summary.min, + )?; + } else { + writeln!( + writer, + "| `{:<20}` | {:<10} | {:<10} | {:<10} | {:<10} |", + metric_name, summary.avg, summary.sum, summary.max, summary.min, + )?; + } } } writeln!(writer)?; @@ -317,11 +388,16 @@ impl AggregateMetrics { writeln!(writer, "|:---|---:|---:|")?; let mut rows = Vec::new(); for (group_name, summaries) in self.to_vec() { + if group_name.contains("keygen") { + continue; + } let stats = summaries.get(PROOF_TIME_LABEL); if stats.is_none() { continue; } - let stats = stats.unwrap(); + let stats = stats.unwrap_or_else(|| { + panic!("Missing proof time statistics for group '{}'", group_name) + }); let mut sum = stats.sum; let mut max = stats.max; // convert ms to s @@ -352,7 +428,12 @@ impl AggregateMetrics { self.by_group .keys() .find(|k| group_weight(k) == 0) - .unwrap_or_else(|| self.by_group.keys().next().unwrap()) + .unwrap_or_else(|| { + self.by_group + .keys() + .next() + .expect("by_group should contain at least one group") + }) .clone() } } @@ -381,18 +462,36 @@ impl BenchmarkOutput { } pub const PROOF_TIME_LABEL: &str = "total_proof_time_ms"; -pub const CELLS_USED_LABEL: &str = "main_cells_used"; -pub const CYCLES_LABEL: &str = "total_cycles"; -pub const EXECUTE_TIME_LABEL: &str = "execute_time_ms"; +pub const MAIN_CELLS_USED_LABEL: &str = "main_cells_used"; +pub const TOTAL_CELLS_USED_LABEL: &str = "total_cells_used"; +pub const INSNS_LABEL: &str = "insns"; +pub const EXECUTE_E1_TIME_LABEL: &str = "execute_e1_time_ms"; +pub const EXECUTE_E1_INSN_MI_S_LABEL: &str = "execute_e1_insn_mi/s"; +pub const EXECUTE_METERED_TIME_LABEL: &str = "execute_metered_time_ms"; +pub const EXECUTE_METERED_INSN_MI_S_LABEL: &str = "execute_metered_insn_mi/s"; +pub const EXECUTE_PREFLIGHT_TIME_LABEL: &str = "execute_preflight_time_ms"; +pub const EXECUTE_PREFLIGHT_INSN_MI_S_LABEL: &str = "execute_preflight_insn_mi/s"; pub const TRACE_GEN_TIME_LABEL: &str = "trace_gen_time_ms"; +pub const MEM_FIN_TIME_LABEL: &str = "memory_finalize_time_ms"; +pub const BOUNDARY_FIN_TIME_LABEL: &str = "boundary_finalize_time_ms"; +pub const MERKLE_FIN_TIME_LABEL: &str = "merkle_finalize_time_ms"; pub const PROVE_EXCL_TRACE_TIME_LABEL: &str = "stark_prove_excluding_trace_time_ms"; pub const VM_METRIC_NAMES: &[&str] = &[ PROOF_TIME_LABEL, - CELLS_USED_LABEL, - CYCLES_LABEL, - EXECUTE_TIME_LABEL, + MAIN_CELLS_USED_LABEL, + TOTAL_CELLS_USED_LABEL, + INSNS_LABEL, + EXECUTE_E1_TIME_LABEL, + EXECUTE_E1_INSN_MI_S_LABEL, + EXECUTE_METERED_TIME_LABEL, + EXECUTE_METERED_INSN_MI_S_LABEL, + EXECUTE_PREFLIGHT_TIME_LABEL, + EXECUTE_PREFLIGHT_INSN_MI_S_LABEL, TRACE_GEN_TIME_LABEL, + MEM_FIN_TIME_LABEL, + BOUNDARY_FIN_TIME_LABEL, + MERKLE_FIN_TIME_LABEL, PROVE_EXCL_TRACE_TIME_LABEL, "main_trace_commit_time_ms", "generate_perm_trace_time_ms", diff --git a/crates/prof/src/lib.rs b/crates/prof/src/lib.rs index 58440a8e02..ec6117c1e7 100644 --- a/crates/prof/src/lib.rs +++ b/crates/prof/src/lib.rs @@ -1,12 +1,13 @@ use std::{collections::HashMap, fs::File, path::Path}; -use aggregate::{ - EXECUTE_TIME_LABEL, PROOF_TIME_LABEL, PROVE_EXCL_TRACE_TIME_LABEL, TRACE_GEN_TIME_LABEL, -}; +use aggregate::{PROOF_TIME_LABEL, PROVE_EXCL_TRACE_TIME_LABEL, TRACE_GEN_TIME_LABEL}; use eyre::Result; use memmap2::Mmap; -use crate::types::{Labels, Metric, MetricDb, MetricsFile}; +use crate::{ + aggregate::{EXECUTE_METERED_TIME_LABEL, EXECUTE_PREFLIGHT_TIME_LABEL}, + types::{Labels, Metric, MetricDb, MetricsFile}, +}; pub mod aggregate; pub mod summary; @@ -45,13 +46,29 @@ impl MetricDb { pub fn apply_aggregations(&mut self) { for metrics in self.flat_dict.values_mut() { let get = |key: &str| metrics.iter().find(|m| m.name == key).map(|m| m.value); - let execute_time = get(EXECUTE_TIME_LABEL); + let total_proof_time = get(PROOF_TIME_LABEL); + if total_proof_time.is_some() { + // We have instrumented total_proof_time_ms + continue; + } + // otherwise, calculate it from sub-components + let execute_metered_time = get(EXECUTE_METERED_TIME_LABEL); + let execute_preflight_time = get(EXECUTE_PREFLIGHT_TIME_LABEL); let trace_gen_time = get(TRACE_GEN_TIME_LABEL); let prove_excl_trace_time = get(PROVE_EXCL_TRACE_TIME_LABEL); - if let (Some(execute_time), Some(trace_gen_time), Some(prove_excl_trace_time)) = - (execute_time, trace_gen_time, prove_excl_trace_time) - { - let total_time = execute_time + trace_gen_time + prove_excl_trace_time; + if let ( + Some(execute_preflight_time), + Some(trace_gen_time), + Some(prove_excl_trace_time), + ) = ( + execute_preflight_time, + trace_gen_time, + prove_excl_trace_time, + ) { + let total_time = execute_metered_time.unwrap_or(0.0) + + execute_preflight_time + + trace_gen_time + + prove_excl_trace_time; metrics.push(Metric::new(PROOF_TIME_LABEL.to_string(), total_time)); } } @@ -90,7 +107,12 @@ impl MetricDb { let label_values: Vec = label_keys .iter() - .map(|key| label_dict.get(key).unwrap().clone()) + .map(|key| { + label_dict + .get(key) + .unwrap_or_else(|| panic!("Label key '{}' should exist in label_dict", key)) + .clone() + }) .collect(); // Add to dict_by_label_types diff --git a/crates/prof/src/main.rs b/crates/prof/src/main.rs index 31ddb2b359..1474153a9f 100644 --- a/crates/prof/src/main.rs +++ b/crates/prof/src/main.rs @@ -84,8 +84,9 @@ fn main() -> Result<()> { // If this is a new benchmark, prev_path will not exist if let Ok(prev_db) = MetricDb::new(&prev_path) { let prev_grouped = GroupedMetrics::new(&prev_db, "group")?; - prev_aggregated = Some(prev_grouped.aggregate()); - aggregated.set_diff(prev_aggregated.as_ref().unwrap()); + let prev_grouped_aggregated = prev_grouped.aggregate(); + aggregated.set_diff(&prev_grouped_aggregated); + prev_aggregated = Some(prev_grouped_aggregated); } } if name.is_empty() { diff --git a/crates/prof/src/summary.rs b/crates/prof/src/summary.rs index 9501b03e05..b6d7284e70 100644 --- a/crates/prof/src/summary.rs +++ b/crates/prof/src/summary.rs @@ -4,7 +4,10 @@ use eyre::Result; use itertools::Itertools; use crate::{ - aggregate::{AggregateMetrics, CELLS_USED_LABEL, CYCLES_LABEL, PROOF_TIME_LABEL}, + aggregate::{ + AggregateMetrics, EXECUTE_METERED_TIME_LABEL, EXECUTE_PREFLIGHT_TIME_LABEL, INSNS_LABEL, + MAIN_CELLS_USED_LABEL, PROOF_TIME_LABEL, PROVE_EXCL_TRACE_TIME_LABEL, TRACE_GEN_TIME_LABEL, + }, types::MdTableCell, }; @@ -37,7 +40,7 @@ pub struct SingleSummaryMetrics { /// Parallel proof time is approximated as the max of proof times within a group pub par_proof_time_ms: MdTableCell, pub cells_used: MdTableCell, - pub cycles: MdTableCell, + pub insns: MdTableCell, } impl GithubSummary { @@ -52,8 +55,14 @@ impl GithubSummary { .zip_eq(md_paths.iter()) .zip_eq(names) .map(|(((aggregated, prev_aggregated), md_path), name)| { - let md_filename = md_path.file_name().unwrap().to_str().unwrap(); - let mut row = aggregated.get_summary_row(md_filename).unwrap(); + let md_filename = md_path + .file_name() + .expect("Path should have a filename") + .to_str() + .expect("Filename should be valid UTF-8"); + let mut row = aggregated.get_summary_row(md_filename).unwrap_or_else(|| { + panic!("Failed to get summary row for file '{}'", md_filename) + }); if let Some(prev_aggregated) = prev_aggregated { // md_filename doesn't matter if let Some(prev_row) = prev_aggregated.get_summary_row(md_filename) { @@ -136,14 +145,14 @@ impl SingleSummaryMetrics { write!( writer, "{} | {} | {} |", - self.proof_time_ms, self.cycles, self.cells_used, + self.proof_time_ms, self.insns, self.cells_used, )?; Ok(()) } pub fn set_diff(&mut self, prev: &Self) { self.cells_used.diff = Some(self.cells_used.val - prev.cells_used.val); - self.cycles.diff = Some(self.cycles.val - prev.cycles.val); + self.insns.diff = Some(self.insns.val - prev.insns.val); self.proof_time_ms.diff = Some(self.proof_time_ms.val - prev.proof_time_ms.val); } } @@ -152,16 +161,70 @@ impl AggregateMetrics { pub fn get_single_summary(&self, name: &str) -> Option { let stats = self.by_group.get(name)?; // Any group must have proof_time, but may not have cells_used or cycles (e.g., halo2) - let proof_time_ms = stats.get(PROOF_TIME_LABEL)?.sum; - let par_proof_time_ms = stats.get(PROOF_TIME_LABEL)?.max; + let proof_time_ms = if let Some(proof_stats) = stats.get(PROOF_TIME_LABEL) { + proof_stats.sum + } else { + // Note: execute_metered is outside any segment scope, so it should have sum = max = avg + let execute_metered = stats + .get(EXECUTE_METERED_TIME_LABEL) + .map(|s| s.sum.val) + .unwrap_or(0.0); + let execute_preflight = stats + .get(EXECUTE_PREFLIGHT_TIME_LABEL) + .map(|s| s.sum.val) + .unwrap_or(0.0); + // If total_proof_time_ms is not available, compute it from components + let trace_gen = stats + .get(TRACE_GEN_TIME_LABEL) + .map(|s| s.sum.val) + .unwrap_or(0.0); + let stark_prove = stats + .get(PROVE_EXCL_TRACE_TIME_LABEL) + .map(|s| s.sum.val) + .unwrap_or(0.0); + println!( + "{} {} {} {}", + execute_metered, execute_preflight, trace_gen, stark_prove + ); + MdTableCell::new( + execute_metered + execute_preflight + trace_gen + stark_prove, + None, + ) + }; + println!("{}", self.total_proof_time.val); + let par_proof_time_ms = if let Some(proof_stats) = stats.get(PROOF_TIME_LABEL) { + proof_stats.max + } else { + // Use the same computation for max + let execute_metered = stats + .get(EXECUTE_METERED_TIME_LABEL) + .map(|s| s.max.val) + .unwrap_or(0.0); + let execute_preflight = stats + .get(EXECUTE_PREFLIGHT_TIME_LABEL) + .map(|s| s.max.val) + .unwrap_or(0.0); + let trace_gen = stats + .get(TRACE_GEN_TIME_LABEL) + .map(|s| s.max.val) + .unwrap_or(0.0); + let stark_prove = stats + .get(PROVE_EXCL_TRACE_TIME_LABEL) + .map(|s| s.max.val) + .unwrap_or(0.0); + MdTableCell::new( + execute_metered + execute_preflight + trace_gen + stark_prove, + None, + ) + }; let cells_used = stats - .get(CELLS_USED_LABEL) + .get(MAIN_CELLS_USED_LABEL) .map(|s| s.sum) .unwrap_or_default(); - let cycles = stats.get(CYCLES_LABEL).map(|s| s.sum).unwrap_or_default(); + let insns = stats.get(INSNS_LABEL).map(|s| s.sum).unwrap_or_default(); Some(SingleSummaryMetrics { cells_used, - cycles, + insns, proof_time_ms, par_proof_time_ms, }) diff --git a/crates/sdk/Cargo.toml b/crates/sdk/Cargo.toml index 6a868a3beb..8e1bdd449a 100644 --- a/crates/sdk/Cargo.toml +++ b/crates/sdk/Cargo.toml @@ -42,7 +42,6 @@ derivative = { workspace = true } derive_more = { workspace = true } serde = { workspace = true } eyre.workspace = true -async-trait.workspace = true metrics.workspace = true tracing.workspace = true itertools.workspace = true @@ -50,14 +49,16 @@ getset.workspace = true clap = { workspace = true, features = ["derive"] } serde_with = { workspace = true, features = ["hex"] } serde_json.workspace = true +toml.workspace = true thiserror.workspace = true +rand.workspace = true snark-verifier = { workspace = true, optional = true } snark-verifier-sdk = { workspace = true, optional = true } tempfile.workspace = true hex.workspace = true forge-fmt = { workspace = true, optional = true } -rrs-lib = { workspace = true } -num-bigint = { workspace = true } +rrs-lib.workspace = true +num-bigint.workspace = true [features] default = ["parallel", "jemalloc"] @@ -73,13 +74,15 @@ evm-verify = [ "dep:alloy-sol-types", "dep:forge-fmt", ] -bench-metrics = [ - "openvm-circuit/bench-metrics", - "openvm-native-recursion/bench-metrics", - "openvm-native-compiler/bench-metrics", +metrics = [ + "openvm-circuit/metrics", + "openvm-native-recursion/metrics", + "openvm-native-compiler/metrics", ] # for guest profiling: -profiling = ["openvm-circuit/function-span", "openvm-transpiler/function-span"] +perf-metrics = ["openvm-circuit/perf-metrics", "openvm-transpiler/function-span"] +# turns on stark-backend debugger in all proofs +stark-debug = ["openvm-circuit/stark-debug"] test-utils = ["openvm-circuit/test-utils"] # performance features: # (rayon is always imported because of halo2, so "parallel" feature is redundant) @@ -93,3 +96,6 @@ nightly-features = ["openvm-circuit/nightly-features"] name = "sdk_evm" path = "examples/sdk_evm.rs" required-features = ["evm-verify"] + +[package.metadata.cargo-shear] +ignored = ["derive_more", "rand"] diff --git a/crates/sdk/examples/sdk_app.rs b/crates/sdk/examples/sdk_app.rs index 31ba0ab264..0bf90efc8f 100644 --- a/crates/sdk/examples/sdk_app.rs +++ b/crates/sdk/examples/sdk_app.rs @@ -5,7 +5,7 @@ use eyre::Result; use openvm::platform::memory::MEM_SIZE; use openvm_build::GuestOptions; use openvm_sdk::{ - config::{AppConfig, SdkVmConfig}, + config::{AppConfig, SdkVmConfig, SdkVmCpuBuilder}, prover::AppProver, Sdk, StdIn, }; @@ -93,15 +93,23 @@ fn main() -> Result<(), Box> { // 8. Generate an AppProvingKey let app_pk = Arc::new(sdk.app_keygen(app_config)?); + // Choose a VmBuilder that matches the VmConfig + let builder = SdkVmCpuBuilder; // 9a. Generate a proof - let proof = sdk.generate_app_proof(app_pk.clone(), app_committed_exe.clone(), stdin.clone())?; + let proof = sdk.generate_app_proof( + builder, + app_pk.clone(), + app_committed_exe.clone(), + stdin.clone(), + )?; // 9b. Generate a proof with an AppProver with custom fields - let app_prover = AppProver::<_, BabyBearPoseidon2Engine>::new( + let mut app_prover = AppProver::::new( + builder, app_pk.app_vm_pk.clone(), app_committed_exe.clone(), - ) + )? .with_program_name("test_program"); - let proof = app_prover.generate_app_proof(stdin.clone()); + let proof = app_prover.generate_app_proof(stdin.clone())?; // ANCHOR_END: proof_generation // ANCHOR: verification diff --git a/crates/sdk/examples/sdk_evm.rs b/crates/sdk/examples/sdk_evm.rs index 8833542b73..4f2da0943f 100644 --- a/crates/sdk/examples/sdk_evm.rs +++ b/crates/sdk/examples/sdk_evm.rs @@ -6,7 +6,7 @@ use openvm::platform::memory::MEM_SIZE; use openvm_build::GuestOptions; use openvm_native_recursion::halo2::utils::CacheHalo2ParamsReader; use openvm_sdk::{ - config::{AggConfig, AppConfig, SdkVmConfig}, + config::{AggConfig, AppConfig, SdkVmConfig, SdkVmCpuBuilder}, DefaultStaticVerifierPvHandler, Sdk, StdIn, }; use openvm_stark_sdk::config::FriParameters; @@ -105,8 +105,10 @@ fn main() -> Result<(), Box> { let verifier = sdk.generate_halo2_verifier_solidity(&halo2_params_reader, &agg_pk)?; // 10. Generate an EVM proof + let builder = SdkVmCpuBuilder; let proof = sdk.generate_evm_proof( &halo2_params_reader, + builder, app_pk, app_committed_exe, agg_pk, diff --git a/crates/sdk/guest/fib/src/main.rs b/crates/sdk/guest/fib/src/main.rs index bc6d94cda8..7b65644496 100644 --- a/crates/sdk/guest/fib/src/main.rs +++ b/crates/sdk/guest/fib/src/main.rs @@ -3,15 +3,31 @@ openvm::entry!(main); -pub fn main() { - let n = core::hint::black_box(1 << 3); +fn fibonacci(n: u32) -> (u32, u32) { + if n <= 1 { + return (0, n); + } let mut a: u32 = 0; let mut b: u32 = 1; - for _ in 1..n { + for _ in 2..=n { let sum = a + b; a = b; b = sum; } + (a, b) +} + +pub fn main() { + // arbitrary n that results in more than 1 segment + let n = core::hint::black_box(1 << 5); + + let mut a = 0; + let mut b = 0; + // calculate nth fibonacci number n times + for _ in 0..n { + (a, b) = fibonacci(n); + } + if a == 0 { panic!(); } diff --git a/crates/sdk/src/codec.rs b/crates/sdk/src/codec.rs index 9d0ab48a93..c75268bbe3 100644 --- a/crates/sdk/src/codec.rs +++ b/crates/sdk/src/codec.rs @@ -1,7 +1,7 @@ use std::io::{self, Cursor, Read, Result, Write}; use openvm_circuit::{ - arch::ContinuationVmProof, system::memory::tree::public_values::UserPublicValuesProof, + arch::ContinuationVmProof, system::memory::merkle::public_values::UserPublicValuesProof, }; use openvm_continuations::verifier::{ internal::types::VmStarkProof, root::types::RootVmVerifierInput, diff --git a/crates/sdk/src/commit.rs b/crates/sdk/src/commit.rs index 53207d463b..25fecf1f59 100644 --- a/crates/sdk/src/commit.rs +++ b/crates/sdk/src/commit.rs @@ -6,7 +6,9 @@ use openvm_circuit::{ system::program::trace::VmCommittedExe, }; use openvm_native_compiler::ir::DIGEST_SIZE; -use openvm_stark_backend::{config::StarkGenericConfig, p3_field::PrimeField32}; +use openvm_stark_backend::{ + config::StarkGenericConfig, engine::StarkEngine, p3_field::PrimeField32, +}; use openvm_stark_sdk::{ config::{baby_bear_poseidon2::BabyBearPoseidon2Engine, FriParameters}, engine::StarkFriEngine, @@ -79,15 +81,15 @@ pub struct AppExecutionCommit { impl AppExecutionCommit { /// Users should use this function to compute `AppExecutionCommit` and check it against the /// final proof. - pub fn compute>( + pub fn compute>( app_vm_config: &VC, app_exe: &NonRootCommittedExe, leaf_vm_verifier_exe: &NonRootCommittedExe, ) -> Self { let exe_commit: [F; DIGEST_SIZE] = app_exe - .compute_exe_commit(&app_vm_config.system().memory_config) + .compute_exe_commit(&app_vm_config.as_ref().memory_config) .into(); - let vm_commit: [F; DIGEST_SIZE] = leaf_vm_verifier_exe.committed_program.commitment.into(); + let vm_commit: [F; DIGEST_SIZE] = leaf_vm_verifier_exe.commitment.into(); Self::from_field_commit(exe_commit, vm_commit) } @@ -105,7 +107,7 @@ pub fn commit_app_exe( ) -> Arc { let exe: VmExe<_> = app_exe.into(); let app_engine = BabyBearPoseidon2Engine::new(app_fri_params); - Arc::new(VmCommittedExe::::commit(exe, app_engine.config.pcs())) + Arc::new(VmCommittedExe::::commit(exe, app_engine.config().pcs())) } pub(crate) fn babybear_digest_to_bn254(digest: &[F; DIGEST_SIZE]) -> Bn254Fr { diff --git a/crates/sdk/src/config/global.rs b/crates/sdk/src/config/global.rs index faf8182246..d7c1c196c7 100644 --- a/crates/sdk/src/config/global.rs +++ b/crates/sdk/src/config/global.rs @@ -1,55 +1,54 @@ use bon::Builder; -use derive_more::derive::From; use openvm_algebra_circuit::{ - Fp2Extension, Fp2ExtensionExecutor, Fp2ExtensionPeriphery, ModularExtension, - ModularExtensionExecutor, ModularExtensionPeriphery, + AlgebraCpuProverExt, Fp2Extension, Fp2ExtensionExecutor, ModularExtension, + ModularExtensionExecutor, }; use openvm_algebra_transpiler::{Fp2TranspilerExtension, ModularTranspilerExtension}; -use openvm_bigint_circuit::{Int256, Int256Executor, Int256Periphery}; +use openvm_bigint_circuit::{Int256, Int256CpuProverExt, Int256Executor}; use openvm_bigint_transpiler::Int256TranspilerExtension; use openvm_circuit::{ arch::{ - InitFileGenerator, SystemConfig, SystemExecutor, SystemPeriphery, VmChipComplex, VmConfig, - VmInventoryError, + instructions::NATIVE_AS, AirInventory, AirInventoryError, ChipInventoryError, + ExecutorInventory, ExecutorInventoryError, InitFileGenerator, MatrixRecordArena, + SystemConfig, VmBuilder, VmChipComplex, VmCircuitConfig, VmExecutionConfig, + VmProverExtension, }, - circuit_derive::{Chip, ChipUsageGetter}, - derive::{AnyEnum, InstructionExecutor}, -}; -use openvm_ecc_circuit::{ - WeierstrassExtension, WeierstrassExtensionExecutor, WeierstrassExtensionPeriphery, + derive::VmConfig, + system::{SystemChipInventory, SystemCpuBuilder, SystemExecutor}, }; +use openvm_ecc_circuit::{EccCpuProverExt, EccExtension, EccExtensionExecutor}; use openvm_ecc_transpiler::EccTranspilerExtension; -use openvm_keccak256_circuit::{Keccak256, Keccak256Executor, Keccak256Periphery}; +use openvm_keccak256_circuit::{Keccak256, Keccak256CpuProverExt, Keccak256Executor}; use openvm_keccak256_transpiler::Keccak256TranspilerExtension; use openvm_native_circuit::{ - CastFExtension, CastFExtensionExecutor, CastFExtensionPeriphery, Native, NativeExecutor, - NativePeriphery, + CastFExtension, CastFExtensionExecutor, Native, NativeCpuProverExt, NativeExecutor, }; use openvm_native_transpiler::LongFormTranspilerExtension; -use openvm_pairing_circuit::{ - PairingExtension, PairingExtensionExecutor, PairingExtensionPeriphery, -}; +use openvm_pairing_circuit::{PairingExtension, PairingExtensionExecutor, PairingProverExt}; use openvm_pairing_transpiler::PairingTranspilerExtension; use openvm_rv32im_circuit::{ - Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, - Rv32MExecutor, Rv32MPeriphery, + Rv32I, Rv32IExecutor, Rv32ImCpuProverExt, Rv32Io, Rv32IoExecutor, Rv32M, Rv32MExecutor, }; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; -use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha256Periphery}; +use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha2CpuProverExt}; use openvm_sha256_transpiler::Sha256TranspilerExtension; -use openvm_stark_backend::p3_field::PrimeField32; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + engine::StarkEngine, + p3_field::{Field, PrimeField32}, + prover::cpu::{CpuBackend, CpuDevice}, +}; use openvm_transpiler::transpiler::Transpiler; use serde::{Deserialize, Serialize}; -use crate::F; +use crate::{config::AppConfig, F}; #[derive(Builder, Clone, Debug, Serialize, Deserialize)] +#[serde(from = "SdkVmConfigWithDefaultDeser")] pub struct SdkVmConfig { - #[serde(default)] pub system: SdkSystemConfig, - pub rv32i: Option, pub io: Option, pub keccak: Option, @@ -62,69 +61,48 @@ pub struct SdkVmConfig { pub modular: Option, pub fp2: Option, pub pairing: Option, - pub ecc: Option, + pub ecc: Option, } -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] -pub enum SdkVmConfigExecutor { - #[any_enum] - System(SystemExecutor), - #[any_enum] - Rv32i(Rv32IExecutor), - #[any_enum] - Io(Rv32IoExecutor), - #[any_enum] - Keccak(Keccak256Executor), - #[any_enum] - Sha256(Sha256Executor), - #[any_enum] - Native(NativeExecutor), - #[any_enum] - Rv32m(Rv32MExecutor), - #[any_enum] - BigInt(Int256Executor), - #[any_enum] - Modular(ModularExtensionExecutor), - #[any_enum] - Fp2(Fp2ExtensionExecutor), - #[any_enum] - Pairing(PairingExtensionExecutor), - #[any_enum] - Ecc(WeierstrassExtensionExecutor), - #[any_enum] - CastF(CastFExtensionExecutor), -} +#[derive(Copy, Clone)] +pub struct SdkVmCpuBuilder; + +/// Internal struct to use for the VmConfig derive macro. +/// Can be obtained via [`SdkVmConfig::to_inner`]. +#[derive(Clone, Debug, VmConfig, Serialize, Deserialize)] +pub struct SdkVmConfigInner { + #[config(executor = "SystemExecutor")] + pub system: SystemConfig, + #[extension(executor = "Rv32IExecutor")] + pub rv32i: Option, + #[extension(executor = "Rv32IoExecutor")] + pub io: Option, + #[extension(executor = "Keccak256Executor")] + pub keccak: Option, + #[extension(executor = "Sha256Executor")] + pub sha256: Option, + #[extension(executor = "NativeExecutor")] + pub native: Option, + #[extension(executor = "CastFExtensionExecutor")] + pub castf: Option, -#[derive(From, ChipUsageGetter, Chip, AnyEnum)] -pub enum SdkVmConfigPeriphery { - #[any_enum] - System(SystemPeriphery), - #[any_enum] - Rv32i(Rv32IPeriphery), - #[any_enum] - Io(Rv32IoPeriphery), - #[any_enum] - Keccak(Keccak256Periphery), - #[any_enum] - Sha256(Sha256Periphery), - #[any_enum] - Native(NativePeriphery), - #[any_enum] - Rv32m(Rv32MPeriphery), - #[any_enum] - BigInt(Int256Periphery), - #[any_enum] - Modular(ModularExtensionPeriphery), - #[any_enum] - Fp2(Fp2ExtensionPeriphery), - #[any_enum] - Pairing(PairingExtensionPeriphery), - #[any_enum] - Ecc(WeierstrassExtensionPeriphery), - #[any_enum] - CastF(CastFExtensionPeriphery), + #[extension(executor = "Rv32MExecutor")] + pub rv32m: Option, + #[extension(executor = "Int256Executor")] + pub bigint: Option, + #[extension(executor = "ModularExtensionExecutor")] + pub modular: Option, + #[extension(executor = "Fp2ExtensionExecutor")] + pub fp2: Option, + #[extension(executor = "PairingExtensionExecutor")] + pub pairing: Option, + #[extension(executor = "EccExtensionExecutor")] + pub ecc: Option, } +// Generated by macro +pub type SdkVmConfigExecutor = SdkVmConfigInnerExecutor; + impl SdkVmConfig { pub fn transpiler(&self) -> Transpiler { let mut transpiler = Transpiler::default(); @@ -163,82 +141,159 @@ impl SdkVmConfig { } transpiler } -} -impl VmConfig for SdkVmConfig { - type Executor = SdkVmConfigExecutor; - type Periphery = SdkVmConfigPeriphery; + /// `openvm_toml` should be the TOML string read from an openvm.toml file. + pub fn from_toml(openvm_toml: &str) -> Result, toml::de::Error> { + toml::from_str(openvm_toml) + } +} - fn system(&self) -> &SystemConfig { +impl AsRef for SdkVmConfig { + fn as_ref(&self) -> &SystemConfig { &self.system.config } +} - fn system_mut(&mut self) -> &mut SystemConfig { +impl AsMut for SdkVmConfig { + fn as_mut(&mut self) -> &mut SystemConfig { &mut self.system.config } +} - fn create_chip_complex( +impl SdkVmConfig { + pub fn to_inner(&self) -> SdkVmConfigInner { + let system = self.system.config.clone(); + let rv32i = self.rv32i.map(|_| Rv32I); + let io = self.io.map(|_| Rv32Io); + let keccak = self.keccak.map(|_| Keccak256); + let sha256 = self.sha256.map(|_| Sha256); + let native = self.native.map(|_| Native); + let castf = self.castf.map(|_| CastFExtension); + let mut rv32m = self.rv32m; + let mut bigint = self.bigint; + if let Some(bigint) = &mut bigint { + if let Some(rv32m) = &mut rv32m { + rv32m.range_tuple_checker_sizes[0] = + rv32m.range_tuple_checker_sizes[0].max(bigint.range_tuple_checker_sizes[0]); + rv32m.range_tuple_checker_sizes[1] = + rv32m.range_tuple_checker_sizes[1].max(bigint.range_tuple_checker_sizes[1]); + bigint.range_tuple_checker_sizes = rv32m.range_tuple_checker_sizes; + } + } + let modular = self.modular.clone(); + let fp2 = self.fp2.clone(); + let pairing = self.pairing.clone(); + let ecc = self.ecc.clone(); + + SdkVmConfigInner { + system, + rv32i, + io, + keccak, + sha256, + native, + castf, + rv32m, + bigint, + modular, + fp2, + pairing, + ecc, + } + } +} + +impl VmExecutionConfig for SdkVmConfig +where + SdkVmConfigInner: VmExecutionConfig, +{ + type Executor = >::Executor; + + fn create_executors( &self, - ) -> Result, VmInventoryError> { - let mut complex = self.system.config.create_chip_complex()?.transmute(); + ) -> Result, ExecutorInventoryError> { + self.to_inner().create_executors() + } +} - if self.rv32i.is_some() { - complex = complex.extend(&Rv32I)?; +impl VmCircuitConfig for SdkVmConfig +where + SdkVmConfigInner: VmCircuitConfig, +{ + fn create_airs(&self) -> Result, AirInventoryError> { + self.to_inner().create_airs() + } +} + +impl VmBuilder for SdkVmCpuBuilder +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + Val: PrimeField32, +{ + type VmConfig = SdkVmConfig; + type SystemChipInventory = SystemChipInventory; + type RecordArena = MatrixRecordArena>; + + fn create_chip_complex( + &self, + config: &SdkVmConfig, + circuit: AirInventory, + ) -> Result< + VmChipComplex, + ChipInventoryError, + > { + let config = config.to_inner(); + let mut chip_complex = + VmBuilder::::create_chip_complex(&SystemCpuBuilder, &config.system, circuit)?; + let inventory = &mut chip_complex.inventory; + if let Some(rv32i) = &config.rv32i { + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, rv32i, inventory)?; } - if self.io.is_some() { - complex = complex.extend(&Rv32Io)?; + if let Some(io) = &config.io { + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, io, inventory)?; } - if self.keccak.is_some() { - complex = complex.extend(&Keccak256)?; + if let Some(keccak) = &config.keccak { + VmProverExtension::::extend_prover(&Keccak256CpuProverExt, keccak, inventory)?; } - if self.sha256.is_some() { - complex = complex.extend(&Sha256)?; + if let Some(sha256) = &config.sha256 { + VmProverExtension::::extend_prover(&Sha2CpuProverExt, sha256, inventory)?; } - if self.native.is_some() { - complex = complex.extend(&Native)?; + if let Some(native) = &config.native { + VmProverExtension::::extend_prover(&NativeCpuProverExt, native, inventory)?; } - if self.castf.is_some() { - complex = complex.extend(&CastFExtension)?; + if let Some(castf) = &config.castf { + VmProverExtension::::extend_prover(&NativeCpuProverExt, castf, inventory)?; } - if let Some(rv32m) = self.rv32m { - let mut rv32m = rv32m; - if let Some(ref bigint) = self.bigint { - rv32m.range_tuple_checker_sizes[0] = - rv32m.range_tuple_checker_sizes[0].max(bigint.range_tuple_checker_sizes[0]); - rv32m.range_tuple_checker_sizes[1] = - rv32m.range_tuple_checker_sizes[1].max(bigint.range_tuple_checker_sizes[1]); - } - complex = complex.extend(&rv32m)?; + if let Some(rv32m) = &config.rv32m { + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, rv32m, inventory)?; } - if let Some(bigint) = self.bigint { - let mut bigint = bigint; - if let Some(ref rv32m) = self.rv32m { - bigint.range_tuple_checker_sizes[0] = - rv32m.range_tuple_checker_sizes[0].max(bigint.range_tuple_checker_sizes[0]); - bigint.range_tuple_checker_sizes[1] = - rv32m.range_tuple_checker_sizes[1].max(bigint.range_tuple_checker_sizes[1]); - } - complex = complex.extend(&bigint)?; + if let Some(bigint) = &config.bigint { + VmProverExtension::::extend_prover(&Int256CpuProverExt, bigint, inventory)?; } - if let Some(ref modular) = self.modular { - complex = complex.extend(modular)?; + if let Some(modular) = &config.modular { + VmProverExtension::::extend_prover(&AlgebraCpuProverExt, modular, inventory)?; } - if let Some(ref fp2) = self.fp2 { - complex = complex.extend(fp2)?; + if let Some(fp2) = &config.fp2 { + VmProverExtension::::extend_prover(&AlgebraCpuProverExt, fp2, inventory)?; } - if let Some(ref pairing) = self.pairing { - complex = complex.extend(pairing)?; + if let Some(pairing) = &config.pairing { + VmProverExtension::::extend_prover(&PairingProverExt, pairing, inventory)?; } - if let Some(ref ecc) = self.ecc { - complex = complex.extend(ecc)?; + if let Some(ecc) = &config.ecc { + VmProverExtension::::extend_prover(&EccCpuProverExt, ecc, inventory)?; } - - Ok(complex) + Ok(chip_complex) } } impl InitFileGenerator for SdkVmConfig { + fn generate_init_file_contents(&self) -> Option { + self.to_inner().generate_init_file_contents() + } +} +impl InitFileGenerator for SdkVmConfigInner { fn generate_init_file_contents(&self) -> Option { if self.modular.is_some() || self.fp2.is_some() || self.ecc.is_some() { let mut contents = String::new(); @@ -262,7 +317,7 @@ impl InitFileGenerator for SdkVmConfig { } if let Some(ecc_config) = &self.ecc { - contents.push_str(&ecc_config.generate_sw_init()); + contents.push_str(&ecc_config.generate_ecc_init()); contents.push('\n'); } @@ -335,3 +390,49 @@ impl From for UnitStruct { UnitStruct {} } } + +#[derive(Deserialize)] +struct SdkVmConfigWithDefaultDeser { + #[serde(default)] + pub system: SdkSystemConfig, + + pub rv32i: Option, + pub io: Option, + pub keccak: Option, + pub sha256: Option, + pub native: Option, + pub castf: Option, + + pub rv32m: Option, + pub bigint: Option, + pub modular: Option, + pub fp2: Option, + pub pairing: Option, + pub ecc: Option, +} + +impl From for SdkVmConfig { + fn from(config: SdkVmConfigWithDefaultDeser) -> Self { + let mut system = config.system; + if config.native.is_none() && config.castf.is_none() { + // There should be no need to write to native address space if Native extension and + // CastF extension are not enabled. + system.config.memory_config.addr_spaces[NATIVE_AS as usize].num_cells = 0; + } + Self { + system, + rv32i: config.rv32i, + io: config.io, + keccak: config.keccak, + sha256: config.sha256, + native: config.native, + castf: config.castf, + rv32m: config.rv32m, + bigint: config.bigint, + modular: config.modular, + fp2: config.fp2, + pairing: config.pairing, + ecc: config.ecc, + } + } +} diff --git a/crates/sdk/src/config/mod.rs b/crates/sdk/src/config/mod.rs index 3a231f180d..18498b9ed1 100644 --- a/crates/sdk/src/config/mod.rs +++ b/crates/sdk/src/config/mod.rs @@ -33,7 +33,7 @@ pub struct AppConfig { pub compiler_options: CompilerOptions, } -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct AggConfig { /// STARK aggregation config pub agg_stark_config: AggStarkConfig, @@ -55,7 +55,7 @@ pub struct AggStarkConfig { pub root_max_constraint_degree: usize, } -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct Halo2Config { /// Log degree for the outer recursion verifier circuit. pub verifier_k: usize, @@ -196,7 +196,7 @@ impl From for LeafFriParams { } } -const SBOX_SIZE: usize = 7; +pub const SBOX_SIZE: usize = 7; impl AggStarkConfig { pub fn leaf_vm_config(&self) -> NativeConfig { diff --git a/crates/sdk/src/fs.rs b/crates/sdk/src/fs.rs index e795eebc86..37ada7c482 100644 --- a/crates/sdk/src/fs.rs +++ b/crates/sdk/src/fs.rs @@ -4,7 +4,7 @@ use std::{ }; use eyre::{Report, Result}; -use openvm_circuit::arch::{instructions::exe::VmExe, ContinuationVmProof, VmConfig}; +use openvm_circuit::arch::{instructions::exe::VmExe, ContinuationVmProof}; use openvm_continuations::verifier::root::types::RootVmVerifierInput; #[cfg(feature = "evm-prove")] use openvm_native_recursion::halo2::wrapper::EvmVerifierByteCode; @@ -35,13 +35,13 @@ pub fn write_exe_to_file>(exe: VmExe, path: P) -> Result<()> { write_to_file_bitcode(&path, exe) } -pub fn read_app_pk_from_file, P: AsRef>( +pub fn read_app_pk_from_file>( path: P, ) -> Result> { read_from_file_bitcode(&path) } -pub fn write_app_pk_to_file, P: AsRef>( +pub fn write_app_pk_to_file>( app_pk: AppProvingKey, path: P, ) -> Result<()> { diff --git a/crates/sdk/src/keygen/dummy.rs b/crates/sdk/src/keygen/dummy.rs index 3fe2bcd300..838dfc6aa9 100644 --- a/crates/sdk/src/keygen/dummy.rs +++ b/crates/sdk/src/keygen/dummy.rs @@ -3,11 +3,11 @@ use std::sync::Arc; use openvm_circuit::{ arch::{ instructions::{ - exe::VmExe, instruction::Instruction, program::Program, LocalOpcode, - SystemOpcode::TERMINATE, + instruction::Instruction, program::Program, LocalOpcode, SystemOpcode::TERMINATE, }, - ContinuationVmProof, SingleSegmentVmExecutor, VirtualMachine, VmComplexTraceHeights, - VmConfig, VmExecutor, + ContinuationVmProof, Executor, MatrixRecordArena, MeteredExecutor, + PreflightExecutionOutput, PreflightExecutor, SingleSegmentVmProver, SystemConfig, + VirtualMachine, VirtualMachineError, VmBuilder, VmExecutionConfig, PUBLIC_VALUES_AIR_ID, }, system::program::trace::VmCommittedExe, utils::next_power_of_two_or_zero, @@ -17,73 +17,99 @@ use openvm_continuations::verifier::{ leaf::{types::LeafVmVerifierInput, LeafVmVerifierConfig}, root::types::RootVmVerifierInput, }; -use openvm_native_circuit::NativeConfig; +use openvm_native_circuit::{NativeConfig, NativeCpuBuilder, NATIVE_MAX_TRACE_HEIGHTS}; use openvm_native_compiler::ir::DIGEST_SIZE; use openvm_native_recursion::hints::Hintable; -use openvm_rv32im_circuit::Rv32ImConfig; +use openvm_rv32im_circuit::{Rv32ImConfig, Rv32ImCpuBuilder}; +use openvm_stark_backend::{ + p3_matrix::dense::RowMajorMatrix, + prover::{ + cpu::CpuBackend, + types::{AirProvingContext, ProvingContext}, + }, +}; use openvm_stark_sdk::{ config::{ baby_bear_poseidon2::BabyBearPoseidon2Engine, - fri_params::standard_fri_params_with_100_bits_conjectured_security, FriParameters, - }, - engine::StarkFriEngine, - openvm_stark_backend::{ - config::StarkGenericConfig, p3_field::FieldAlgebra, proof::Proof, Chip, + baby_bear_poseidon2_root::{BabyBearPoseidon2RootConfig, BabyBearPoseidon2RootEngine}, + fri_params::standard_fri_params_with_100_bits_conjectured_security, + FriParameters, }, + engine::{StarkEngine, StarkFriEngine}, + openvm_stark_backend::{config::StarkGenericConfig, p3_field::FieldAlgebra, proof::Proof}, }; use crate::{ - prover::vm::{ - local::VmLocalProver, types::VmProvingKey, ContinuationVmProver, SingleSegmentVmProver, - }, + prover::vm::{new_local_prover, types::VmProvingKey}, NonRootCommittedExe, F, SC, }; +/// Given a dummy internal proof, which is the input to the root verifier circuit, we will run +/// tracegen on the root verifier circuit to determine the trace heights. These trace heights will +/// become the fixed trace heights that we **force** the root verifier circuit's trace matrices to +/// have. +/// /// Returns: /// - trace heights ordered by AIR ID -/// - internal ordering of trace heights. /// /// All trace heights are rounded to the next power of two (or 0 -> 0). pub(super) fn compute_root_proof_heights( - root_vm_config: NativeConfig, - root_exe: VmExe, + root_vm: &mut VirtualMachine, + root_committed_exe: &VmCommittedExe, dummy_internal_proof: &Proof, -) -> (Vec, VmComplexTraceHeights) { - let num_user_public_values = root_vm_config.system.num_public_values - 2 * DIGEST_SIZE; +) -> Result, VirtualMachineError> { + let num_public_values = root_vm.config().as_ref().num_public_values; + let num_user_public_values = num_public_values - 2 * DIGEST_SIZE; let root_input = RootVmVerifierInput { proofs: vec![dummy_internal_proof.clone()], public_values: vec![F::ZERO; num_user_public_values], }; - let vm = SingleSegmentVmExecutor::new(root_vm_config); - let res = vm - .execute_and_compute_heights(root_exe, root_input.write()) - .unwrap(); - let air_heights: Vec<_> = res - .air_heights + // The following is the same as impl SingleSegmentVmProver for VmLocalProver except we stop + // after tracegen: + let mut trace_heights = NATIVE_MAX_TRACE_HEIGHTS.to_vec(); + trace_heights[PUBLIC_VALUES_AIR_ID] = num_public_values as u32; + let state = root_vm.create_initial_state(&root_committed_exe.exe, root_input.write()); + let cached_program_trace = root_vm.transport_committed_exe_to_device(root_committed_exe); + root_vm.load_program(cached_program_trace); + root_vm.transport_init_memory_to_device(&state.memory); + let PreflightExecutionOutput { + system_records, + record_arenas, + .. + } = root_vm.execute_preflight(&root_committed_exe.exe, state, None, &trace_heights)?; + let ctx = root_vm.generate_proving_ctx(system_records, record_arenas)?; + let air_heights = ctx .into_iter() - .map(next_power_of_two_or_zero) + .map(|(_, air_ctx)| { + next_power_of_two_or_zero(air_ctx.main_trace_height()) + .try_into() + .unwrap() + }) .collect(); - let mut vm_heights = res.vm_heights; - vm_heights.round_to_next_power_of_two_or_zero(); - (air_heights, vm_heights) + Ok(air_heights) } pub(super) fn dummy_internal_proof( internal_vm_pk: Arc>, internal_exe: Arc, leaf_proof: Proof, -) -> Proof { +) -> Result, VirtualMachineError> { let mut internal_inputs = InternalVmVerifierInput::chunk_leaf_or_internal_proofs( internal_exe.get_program_commit().into(), &[leaf_proof], 1, ); let internal_input = internal_inputs.pop().unwrap(); - let internal_prover = VmLocalProver::::new( - internal_vm_pk, - internal_exe, - ); - SingleSegmentVmProver::prove(&internal_prover, internal_input.write()) + let mut internal_prover = new_local_prover::( + NativeCpuBuilder, + &internal_vm_pk, + &internal_exe, + )?; + SingleSegmentVmProver::prove( + &mut internal_prover, + internal_input.write(), + NATIVE_MAX_TRACE_HEIGHTS, + ) } pub(super) fn dummy_internal_proof_riscv_app_vm( @@ -91,44 +117,33 @@ pub(super) fn dummy_internal_proof_riscv_app_vm( internal_vm_pk: Arc>, internal_exe: Arc, num_public_values: usize, -) -> Proof { +) -> Result, VirtualMachineError> { let fri_params = standard_fri_params_with_100_bits_conjectured_security(1); - let leaf_proof = dummy_leaf_proof_riscv_app_vm(leaf_vm_pk, num_public_values, fri_params); + let leaf_proof = dummy_leaf_proof_riscv_app_vm(leaf_vm_pk, num_public_values, fri_params)?; dummy_internal_proof(internal_vm_pk, internal_exe, leaf_proof) } -#[allow(dead_code)] -pub fn dummy_leaf_proof>( - leaf_vm_pk: Arc>, - app_vm_pk: Arc>, - overridden_heights: Option, -) -> Proof -where - VC::Executor: Chip, - VC::Periphery: Chip, -{ - let app_proof = dummy_app_proof_impl(app_vm_pk.clone(), overridden_heights); - dummy_leaf_proof_impl(leaf_vm_pk, app_vm_pk, &app_proof) -} - pub(super) fn dummy_leaf_proof_riscv_app_vm( leaf_vm_pk: Arc>, num_public_values: usize, app_fri_params: FriParameters, -) -> Proof { - let app_vm_pk = Arc::new(dummy_riscv_app_vm_pk(num_public_values, app_fri_params)); - let app_proof = dummy_app_proof_impl(app_vm_pk.clone(), None); - dummy_leaf_proof_impl(leaf_vm_pk, app_vm_pk, &app_proof) +) -> Result, VirtualMachineError> { + let app_vm_pk = Arc::new(dummy_riscv_app_vm_pk(num_public_values, app_fri_params)?); + let app_proof = dummy_app_proof(Rv32ImCpuBuilder, app_vm_pk.clone())?; + dummy_leaf_proof(leaf_vm_pk, app_vm_pk, &app_proof) } -fn dummy_leaf_proof_impl>( +fn dummy_leaf_proof( leaf_vm_pk: Arc>, app_vm_pk: Arc>, app_proof: &ContinuationVmProof, -) -> Proof { +) -> Result, VirtualMachineError> +where + VC: AsRef, +{ let leaf_program = LeafVmVerifierConfig { app_fri_params: app_vm_pk.fri_params, - app_system_config: app_vm_pk.vm_config.system().clone(), + app_system_config: app_vm_pk.vm_config.as_ref().clone(), compiler_options: Default::default(), } .build_program(&app_vm_pk.vm_pk.get_vk()); @@ -140,71 +155,71 @@ fn dummy_leaf_proof_impl>( let e = BabyBearPoseidon2Engine::new(leaf_vm_pk.fri_params); let leaf_exe = Arc::new(VmCommittedExe::::commit( leaf_program.into(), - e.config.pcs(), + e.config().pcs(), )); - let leaf_prover = - VmLocalProver::::new(leaf_vm_pk, leaf_exe); + let mut leaf_prover = + new_local_prover::(NativeCpuBuilder, &leaf_vm_pk, &leaf_exe)?; let mut leaf_inputs = LeafVmVerifierInput::chunk_continuation_vm_proof(app_proof, 1); let leaf_input = leaf_inputs.pop().unwrap(); - SingleSegmentVmProver::prove(&leaf_prover, leaf_input.write_to_stream()) + SingleSegmentVmProver::prove( + &mut leaf_prover, + leaf_input.write_to_stream(), + NATIVE_MAX_TRACE_HEIGHTS, + ) } fn dummy_riscv_app_vm_pk( num_public_values: usize, fri_params: FriParameters, -) -> VmProvingKey { +) -> Result, VirtualMachineError> { let vm_config = Rv32ImConfig::with_public_values(num_public_values); - let vm = VirtualMachine::new(BabyBearPoseidon2Engine::new(fri_params), vm_config.clone()); - let vm_pk = vm.keygen(); - VmProvingKey { + let (_, vm_pk) = VirtualMachine::new_with_keygen( + BabyBearPoseidon2Engine::new(fri_params), + Rv32ImCpuBuilder, + vm_config.clone(), + )?; + Ok(VmProvingKey { fri_params, vm_config, vm_pk, - } + }) } -fn dummy_app_proof_impl>( +fn dummy_app_proof( + app_vm_builder: VB, app_vm_pk: Arc>, - overridden_heights: Option, -) -> ContinuationVmProof +) -> Result, VirtualMachineError> where - VC::Executor: Chip, - VC::Periphery: Chip, + VB: VmBuilder>, + VC: VmExecutionConfig, + >::Executor: Executor + MeteredExecutor + PreflightExecutor, { let fri_params = app_vm_pk.fri_params; let dummy_exe = dummy_app_committed_exe(fri_params); - // Enforce each AIR to have at least 1 row. - let overridden_heights = if let Some(overridden_heights) = overridden_heights { - overridden_heights - } else { - // We first execute once to get the trace heights from dummy_exe, then pad to powers of 2 - // (forcing trace height 0 to 1) - let executor = VmExecutor::new(app_vm_pk.vm_config.clone()); - let mut results = executor - .execute_segments(dummy_exe.exe.clone(), vec![]) - .unwrap(); - // ASSUMPTION: the dummy exe has only 1 segment - assert_eq!(results.len(), 1, "dummy exe should have only 1 segment"); - let mut result = results.pop().unwrap(); - result.chip_complex.finalize_memory(); - let mut vm_heights = result.chip_complex.get_internal_trace_heights(); - vm_heights.round_to_next_power_of_two(); - vm_heights + let mut app_prover = + new_local_prover::(app_vm_builder, &app_vm_pk, &dummy_exe)?; + // Force all AIRs to have non-empty trace matrices (height 0 -> height 1) + let modify_ctx = |_seg_idx: usize, ctx: &mut ProvingContext>| { + for (i, pk) in app_vm_pk.vm_pk.per_air.iter().enumerate() { + let width = pk.vk.params.width.common_main; + if ctx.per_air[i].0 != i { + let dummy_trace = RowMajorMatrix::new_row(F::zero_vec(width)); + let dummy_ctx = AirProvingContext::simple_no_pis(Arc::new(dummy_trace)); + ctx.per_air.insert(i, (i, dummy_ctx)); + } + } }; - // For the dummy proof, we must override the trace heights. - let app_prover = - VmLocalProver::::new_with_overridden_trace_heights( - app_vm_pk, - dummy_exe, - Some(overridden_heights), - ); - ContinuationVmProver::prove(&app_prover, vec![]) + let dummy_proof = app_prover.prove_continuations(vec![], modify_ctx)?; + Ok(dummy_proof) } fn dummy_app_committed_exe(fri_params: FriParameters) -> Arc { let program = dummy_app_program(); let e = BabyBearPoseidon2Engine::new(fri_params); - Arc::new(VmCommittedExe::::commit(program.into(), e.config.pcs())) + Arc::new(VmCommittedExe::::commit( + program.into(), + e.config().pcs(), + )) } fn dummy_app_program() -> Program { diff --git a/crates/sdk/src/keygen/mod.rs b/crates/sdk/src/keygen/mod.rs index 0806cc6f3d..0744b171ab 100644 --- a/crates/sdk/src/keygen/mod.rs +++ b/crates/sdk/src/keygen/mod.rs @@ -1,18 +1,19 @@ use std::sync::Arc; use derivative::Derivative; -use dummy::{compute_root_proof_heights, dummy_internal_proof_riscv_app_vm}; +// use dummy::{compute_root_proof_heights, dummy_internal_proof_riscv_app_vm}; use openvm_circuit::{ - arch::{VirtualMachine, VmComplexTraceHeights, VmConfig}, + arch::{AirInventoryError, SystemConfig, VirtualMachine, VirtualMachineError, VmCircuitConfig}, system::{memory::dimensions::MemoryDimensions, program::trace::VmCommittedExe}, }; use openvm_continuations::verifier::{ internal::InternalVmVerifierConfig, leaf::LeafVmVerifierConfig, root::RootVmVerifierConfig, }; -use openvm_native_circuit::NativeConfig; +use openvm_native_circuit::{NativeConfig, NativeCpuBuilder}; use openvm_native_compiler::ir::DIGEST_SIZE; use openvm_stark_backend::{ config::Val, + engine::StarkEngine, p3_field::{FieldExtensionAlgebra, PrimeField32, TwoAdicField}, }; use openvm_stark_sdk::{ @@ -25,12 +26,11 @@ use openvm_stark_sdk::{ config::{Com, StarkGenericConfig}, keygen::types::MultiStarkVerifyingKey, proof::Proof, - Chip, }, p3_bn254_fr::Bn254Fr, }; use serde::{Deserialize, Serialize}; -use tracing::info_span; +use tracing::{info_span, instrument}; #[cfg(feature = "evm-prove")] use { crate::config::AggConfig, @@ -44,7 +44,10 @@ use { use crate::{ commit::babybear_digest_to_bn254, config::{AggStarkConfig, AppConfig}, - keygen::perm::AirIdPermutation, + keygen::{ + dummy::{compute_root_proof_heights, dummy_internal_proof_riscv_app_vm}, + perm::AirIdPermutation, + }, prover::vm::types::VmProvingKey, NonRootCommittedExe, RootSC, F, SC, }; @@ -97,16 +100,14 @@ pub struct Halo2ProvingKey { pub profiling: bool, } -impl> AppProvingKey +impl AppProvingKey where - VC::Executor: Chip, - VC::Periphery: Chip, + VC: Clone + VmCircuitConfig + AsRef, { - pub fn keygen(config: AppConfig) -> Self { + pub fn keygen(config: AppConfig) -> Result { let app_engine = BabyBearPoseidon2Engine::new(config.app_fri_params.fri_params); let app_vm_pk = { - let vm = VirtualMachine::new(app_engine, config.app_vm_config.clone()); - let vm_pk = vm.keygen(); + let vm_pk = config.app_vm_config.create_airs()?.keygen(&app_engine); assert!( vm_pk.max_constraint_degree <= config.app_fri_params.fri_params.max_constraint_degree() @@ -126,24 +127,24 @@ where let leaf_engine = BabyBearPoseidon2Engine::new(config.leaf_fri_params.fri_params); let leaf_program = LeafVmVerifierConfig { app_fri_params: config.app_fri_params.fri_params, - app_system_config: config.app_vm_config.system().clone(), + app_system_config: config.app_vm_config.as_ref().clone(), compiler_options: config.compiler_options, } .build_program(&app_vm_pk.vm_pk.get_vk()); Arc::new(VmCommittedExe::commit( leaf_program.into(), - leaf_engine.config.pcs(), + leaf_engine.config().pcs(), )) }; - Self { + Ok(Self { leaf_committed_exe, leaf_fri_params: config.leaf_fri_params.fri_params, app_vm_pk: Arc::new(app_vm_pk), - } + }) } pub fn num_public_values(&self) -> usize { - self.app_vm_pk.vm_config.system().num_public_values + self.app_vm_pk.vm_config.as_ref().num_public_values } pub fn get_app_vk(&self) -> AppVerifyingKey { @@ -153,7 +154,7 @@ where memory_dimensions: self .app_vm_pk .vm_config - .system() + .as_ref() .memory_config .memory_dimensions(), } @@ -256,27 +257,37 @@ fn check_recursive_verifier_size( } impl AggStarkProvingKey { - pub fn keygen(config: AggStarkConfig) -> Self { - tracing::info_span!("agg_stark_keygen", group = "agg_stark_keygen") - .in_scope(|| Self::dummy_proof_and_keygen(config).0) + #[instrument( + name = "agg_stark_keygen", + fields(group = "agg_stark_keygen"), + skip_all + )] + pub fn keygen(config: AggStarkConfig) -> Result { + let (pk, _) = Self::dummy_proof_and_keygen(config)?; + Ok(pk) } - pub fn dummy_proof_and_keygen(config: AggStarkConfig) -> (Self, Proof) { + fn dummy_proof_and_keygen( + config: AggStarkConfig, + ) -> Result<(Self, Proof), VirtualMachineError> { let leaf_vm_config = config.leaf_vm_config(); let internal_vm_config = config.internal_vm_config(); let root_vm_config = config.root_verifier_vm_config(); let leaf_engine = BabyBearPoseidon2Engine::new(config.leaf_fri_params); - let leaf_vm_pk = Arc::new({ - let vm = VirtualMachine::new(leaf_engine, leaf_vm_config.clone()); - let vm_pk = vm.keygen(); + let leaf_vm_pk = { + let (_, vm_pk) = VirtualMachine::new_with_keygen( + leaf_engine, + NativeCpuBuilder, + leaf_vm_config.clone(), + )?; assert!(vm_pk.max_constraint_degree <= config.leaf_fri_params.max_constraint_degree()); - VmProvingKey { + Arc::new(VmProvingKey { fri_params: config.leaf_fri_params, vm_config: leaf_vm_config, vm_pk, - } - }); + }) + }; let leaf_vm_vk = leaf_vm_pk.vm_pk.get_vk(); check_recursive_verifier_size( &leaf_vm_vk, @@ -285,17 +296,16 @@ impl AggStarkProvingKey { ); let internal_engine = BabyBearPoseidon2Engine::new(config.internal_fri_params); - let internal_vm = VirtualMachine::new(internal_engine, internal_vm_config.clone()); - let internal_vm_pk = Arc::new({ - let vm_pk = internal_vm.keygen(); - assert!( - vm_pk.max_constraint_degree <= config.internal_fri_params.max_constraint_degree() - ); - VmProvingKey { - fri_params: config.internal_fri_params, - vm_config: internal_vm_config, - vm_pk, - } + let (internal_vm, vm_pk) = VirtualMachine::new_with_keygen( + internal_engine, + NativeCpuBuilder, + internal_vm_config.clone(), + )?; + assert!(vm_pk.max_constraint_degree <= config.internal_fri_params.max_constraint_degree()); + let internal_vm_pk = Arc::new(VmProvingKey { + fri_params: config.internal_fri_params, + vm_config: internal_vm_config, + vm_pk, }); let internal_vm_vk = internal_vm_pk.vm_pk.get_vk(); check_recursive_verifier_size( @@ -310,17 +320,14 @@ impl AggStarkProvingKey { compiler_options: config.compiler_options, } .build_program(&leaf_vm_vk, &internal_vm_vk); - let internal_committed_exe = Arc::new(VmCommittedExe::::commit( - internal_program.into(), - internal_vm.engine.config.pcs(), - )); + let internal_committed_exe = Arc::new(internal_vm.commit_exe(internal_program)); let internal_proof = dummy_internal_proof_riscv_app_vm( leaf_vm_pk.clone(), internal_vm_pk.clone(), internal_committed_exe.clone(), config.max_num_user_public_values, - ); + )?; let root_verifier_pk = { let mut root_engine = BabyBearPoseidon2RootEngine::new(config.root_fri_params); @@ -333,22 +340,24 @@ impl AggStarkProvingKey { compiler_options: config.compiler_options, } .build_program(&leaf_vm_vk, &internal_vm_vk); - let root_committed_exe = Arc::new(VmCommittedExe::::commit( - root_program.into(), - root_engine.config.pcs(), - )); + let (mut vm, mut vm_pk) = VirtualMachine::new_with_keygen( + root_engine, + NativeCpuBuilder, + root_vm_config.clone(), + )?; + let root_committed_exe = Arc::new(vm.commit_exe(root_program)); - let vm = VirtualMachine::new(root_engine, root_vm_config.clone()); - let mut vm_pk = vm.keygen(); assert!(vm_pk.max_constraint_degree <= config.root_fri_params.max_constraint_degree()); - let (air_heights, vm_heights) = compute_root_proof_heights( - root_vm_config.clone(), - root_committed_exe.exe.clone(), - &internal_proof, - ); + let air_heights = + compute_root_proof_heights(&mut vm, &root_committed_exe, &internal_proof)?; let root_air_perm = AirIdPermutation::compute(&air_heights); + // ATTENTION: make sure to permute everything in vm_pk that references the original AIR + // ID ordering: root_air_perm.permute(&mut vm_pk.per_air); + for thc in &mut vm_pk.trace_height_constraints { + root_air_perm.permute(&mut thc.coefficients); + } RootVerifierProvingKey { vm_pk: Arc::new(VmProvingKey { @@ -358,10 +367,9 @@ impl AggStarkProvingKey { }), root_committed_exe, air_heights, - vm_heights, } }; - ( + Ok(( Self { leaf_vm_pk, internal_vm_pk, @@ -369,7 +377,7 @@ impl AggStarkProvingKey { root_verifier_pk, }, internal_proof, - ) + )) } pub fn internal_program_commit(&self) -> [F; DIGEST_SIZE] { @@ -389,7 +397,7 @@ impl AggStarkProvingKey { /// Proving key for the root verifier. /// Properties: /// - Traces heights of each AIR is constant. This is required by the static verifier. -/// - Instead of the AIR order specified by VC. AIRs are ordered by trace heights. +/// - Instead of the AIR order specified by VmConfig. AIRs are ordered by trace heights. #[derive(Serialize, Deserialize, Derivative)] #[derivative(Clone(bound = "Com: Clone"))] pub struct RootVerifierProvingKey { @@ -400,14 +408,13 @@ pub struct RootVerifierProvingKey { pub vm_pk: Arc>, /// Committed executable for the root VM. pub root_committed_exe: Arc>, - /// The constant trace heights, ordered by AIR ID. - pub air_heights: Vec, - /// The constant trace heights in a semantic way for VM. - pub vm_heights: VmComplexTraceHeights, + /// The constant trace heights, ordered by AIR ID (the original ordering from VmConfig). + pub air_heights: Vec, } +#[cfg(feature = "evm-prove")] impl RootVerifierProvingKey { - pub fn air_id_permutation(&self) -> AirIdPermutation { + pub(crate) fn air_id_permutation(&self) -> AirIdPermutation { AirIdPermutation::compute(&self.air_heights) } } @@ -422,16 +429,16 @@ impl AggProvingKey { config: AggConfig, reader: &impl Halo2ParamsReader, pv_handler: &impl StaticVerifierPvHandler, - ) -> Self { + ) -> Result { let AggConfig { agg_stark_config, halo2_config, } = config; let (agg_stark_pk, dummy_internal_proof) = - AggStarkProvingKey::dummy_proof_and_keygen(agg_stark_config); + AggStarkProvingKey::dummy_proof_and_keygen(agg_stark_config)?; let dummy_root_proof = agg_stark_pk .root_verifier_pk - .generate_dummy_root_proof(dummy_internal_proof); + .generate_dummy_root_proof(dummy_internal_proof)?; let verifier = agg_stark_pk.root_verifier_pk.keygen_static_verifier( &reader.read_params(halo2_config.verifier_k), dummy_root_proof, @@ -448,23 +455,26 @@ impl AggProvingKey { wrapper, profiling: halo2_config.profiling, }; - Self { + Ok(Self { agg_stark_pk, halo2_pk, - } + }) } } pub fn leaf_keygen( fri_params: FriParameters, leaf_vm_config: NativeConfig, -) -> Arc> { +) -> Result>, AirInventoryError> { let leaf_engine = BabyBearPoseidon2Engine::new(fri_params); - let leaf_vm_pk = info_span!("keygen", group = "leaf") - .in_scope(|| VirtualMachine::new(leaf_engine, leaf_vm_config.clone()).keygen()); - Arc::new(VmProvingKey { + let leaf_vm_pk = info_span!("keygen", group = "leaf").in_scope(|| { + leaf_vm_config + .create_airs() + .map(|airs| airs.keygen(&leaf_engine)) + })?; + Ok(Arc::new(VmProvingKey { fri_params, vm_config: leaf_vm_config, vm_pk: leaf_vm_pk, - }) + })) } diff --git a/crates/sdk/src/keygen/perm.rs b/crates/sdk/src/keygen/perm.rs index 18d76b4958..f52d1ebd12 100644 --- a/crates/sdk/src/keygen/perm.rs +++ b/crates/sdk/src/keygen/perm.rs @@ -1,14 +1,15 @@ use std::cmp::Reverse; -use openvm_circuit::arch::{CONNECTOR_AIR_ID, PROGRAM_AIR_ID, PUBLIC_VALUES_AIR_ID}; +#[cfg(feature = "evm-prove")] use openvm_continuations::verifier::common::types::SpecialAirIds; -pub struct AirIdPermutation { +/// Permutation of the AIR IDs to order them by forced trace heights. +pub(crate) struct AirIdPermutation { pub perm: Vec, } impl AirIdPermutation { - pub fn compute(heights: &[usize]) -> AirIdPermutation { + pub fn compute(heights: &[u32]) -> AirIdPermutation { let mut height_with_air_id: Vec<_> = heights.iter().copied().enumerate().collect(); height_with_air_id.sort_by_key(|(_, h)| Reverse(*h)); AirIdPermutation { @@ -18,7 +19,10 @@ impl AirIdPermutation { .collect(), } } + #[cfg(feature = "evm-prove")] pub fn get_special_air_ids(&self) -> SpecialAirIds { + use openvm_circuit::arch::{CONNECTOR_AIR_ID, PROGRAM_AIR_ID, PUBLIC_VALUES_AIR_ID}; + let perm_len = self.perm.len(); let mut ret = SpecialAirIds { program_air_id: perm_len, diff --git a/crates/sdk/src/keygen/static_verifier.rs b/crates/sdk/src/keygen/static_verifier.rs index fd8e75a67d..31d5c1918d 100644 --- a/crates/sdk/src/keygen/static_verifier.rs +++ b/crates/sdk/src/keygen/static_verifier.rs @@ -1,7 +1,9 @@ +use openvm_circuit::arch::{SingleSegmentVmProver, VirtualMachineError}; use openvm_continuations::{ static_verifier::{StaticVerifierConfig, StaticVerifierPvHandler}, verifier::root::types::RootVmVerifierInput, }; +use openvm_native_circuit::NATIVE_MAX_TRACE_HEIGHTS; use openvm_native_compiler::prelude::*; use openvm_native_recursion::{ halo2::{verifier::Halo2VerifierProvingKey, Halo2Params, Halo2Prover}, @@ -10,11 +12,7 @@ use openvm_native_recursion::{ }; use openvm_stark_sdk::openvm_stark_backend::{p3_field::FieldAlgebra, proof::Proof}; -use crate::{ - keygen::RootVerifierProvingKey, - prover::{vm::SingleSegmentVmProver, RootVerifierLocalProver}, - RootSC, F, SC, -}; +use crate::{keygen::RootVerifierProvingKey, prover::RootVerifierLocalProver, RootSC, F, SC}; impl RootVerifierProvingKey { /// Keygen the static verifier for this root verifier. @@ -43,23 +41,21 @@ impl RootVerifierProvingKey { } } - pub fn generate_dummy_root_proof(&self, dummy_internal_proof: Proof) -> Proof { - let prover = RootVerifierLocalProver::new(self.clone()); + pub fn generate_dummy_root_proof( + &self, + dummy_internal_proof: Proof, + ) -> Result, VirtualMachineError> { + let mut prover = RootVerifierLocalProver::new(self.clone())?; // 2 * DIGEST_SIZE for exe_commit and leaf_commit - let num_public_values = prover - .root_verifier_pk - .vm_pk - .vm_config - .system - .num_public_values - - 2 * DIGEST_SIZE; + let num_public_values = prover.vm_config().as_ref().num_public_values - 2 * DIGEST_SIZE; SingleSegmentVmProver::prove( - &prover, + &mut prover, RootVmVerifierInput { proofs: vec![dummy_internal_proof], public_values: vec![F::ZERO; num_public_values], } .write(), + NATIVE_MAX_TRACE_HEIGHTS, ) } } diff --git a/crates/sdk/src/lib.rs b/crates/sdk/src/lib.rs index c2c874d3f1..fd868ce95d 100644 --- a/crates/sdk/src/lib.rs +++ b/crates/sdk/src/lib.rs @@ -5,6 +5,7 @@ use alloy_sol_types::sol; use commit::{commit_app_exe, AppExecutionCommit}; use config::{AggregationTreeConfig, AppConfig}; use eyre::Result; +use getset::{Getters, WithSetters}; use keygen::{AppProvingKey, AppVerifyingKey}; use openvm_build::{ build_guest_package, find_unique_executable, get_package, GuestOptions, TargetFilter, @@ -13,12 +14,13 @@ use openvm_circuit::{ arch::{ hasher::{poseidon2::vm_poseidon2_hasher, Hasher}, instructions::exe::VmExe, - verify_segments, ContinuationVmProof, ExecutionError, InitFileGenerator, - VerifiedExecutionPayload, VmConfig, VmExecutor, CONNECTOR_AIR_ID, PROGRAM_AIR_ID, + verify_segments, ContinuationVmProof, Executor, InitFileGenerator, MeteredExecutor, + PreflightExecutor, SystemConfig, VerifiedExecutionPayload, VmBuilder, VmCircuitConfig, + VmExecutionConfig, VmExecutor, CONNECTOR_AIR_ID, PROGRAM_AIR_ID, PROGRAM_CACHED_TRACE_INDEX, PUBLIC_VALUES_AIR_ID, }, system::{ - memory::{tree::public_values::extract_public_values, CHUNK}, + memory::{merkle::public_values::extract_public_values, CHUNK}, program::trace::{compute_exe_commit, VmCommittedExe}, }, }; @@ -29,17 +31,17 @@ pub use openvm_continuations::static_verifier::{ use openvm_continuations::verifier::{ common::types::VmVerifierPvs, internal::types::{InternalVmVerifierPvs, VmStarkProof}, - root::{types::RootVmVerifierInput, RootVmVerifierConfig}, + root::RootVmVerifierConfig, }; // Re-exports: pub use openvm_continuations::{RootSC, C, F, SC}; +use openvm_native_circuit::{NativeConfig, NativeCpuBuilder}; #[cfg(feature = "evm-prove")] use openvm_native_recursion::halo2::utils::Halo2ParamsReader; use openvm_stark_backend::proof::Proof; use openvm_stark_sdk::{ config::{baby_bear_poseidon2::BabyBearPoseidon2Engine, FriParameters}, engine::StarkFriEngine, - openvm_stark_backend::Chip, p3_bn254_fr::Bn254Fr, }; use openvm_transpiler::{ @@ -103,36 +105,45 @@ pub struct VerifiedContinuationVmPayload { pub user_public_values: Vec, } -pub struct GenericSdk> { +// The SDK is only generic in the engine for the non-root SC. The root SC is fixed to +// BabyBearPoseidon2RootEngine right now. +#[derive(Getters, WithSetters)] +pub struct GenericSdk { + #[getset(get = "pub", set_with = "pub")] agg_tree_config: AggregationTreeConfig, + #[getset(get = "pub")] + native_builder: NativeBuilder, _phantom: PhantomData, } -impl> Default for GenericSdk { +pub type Sdk = GenericSdk; + +impl Default for GenericSdk +where + NativeBuilder: Default, +{ fn default() -> Self { Self { agg_tree_config: AggregationTreeConfig::default(), + native_builder: NativeBuilder::default(), _phantom: PhantomData, } } } -pub type Sdk = GenericSdk; - -impl> GenericSdk { +// The SDK is only functional for SC = BabyBearPoseidon2Config because that is what recursive +// aggregation supports. +impl GenericSdk +where + E: StarkFriEngine, + NativeBuilder: VmBuilder + Clone + Default, + >::Executor: + PreflightExecutor>::RecordArena>, +{ pub fn new() -> Self { Self::default() } - pub fn with_agg_tree_config(mut self, agg_tree_config: AggregationTreeConfig) -> Self { - self.agg_tree_config = agg_tree_config; - self - } - - pub fn agg_tree_config(&self) -> &AggregationTreeConfig { - &self.agg_tree_config - } - pub fn build>( &self, guest_opts: GuestOptions, @@ -168,22 +179,18 @@ impl> GenericSdk { VmExe::from_elf(elf, transpiler) } - pub fn execute>( - &self, - exe: VmExe, - vm_config: VC, - inputs: StdIn, - ) -> Result, ExecutionError> + /// Returns the user public values as field elements. + pub fn execute(&self, exe: VmExe, vm_config: VC, inputs: StdIn) -> Result> where - VC::Executor: Chip, - VC::Periphery: Chip, + VC: VmExecutionConfig + AsRef + Clone, + VC::Executor: Clone + Executor + MeteredExecutor, { - let vm = VmExecutor::new(vm_config); - let final_memory = vm.execute(exe, inputs)?; + let executor = VmExecutor::new(vm_config)?; + let instance = executor.instance(&exe)?; + let final_memory = instance.execute(inputs, None)?.memory; let public_values = extract_public_values( - &vm.config.system().memory_config.memory_dimensions(), - vm.config.system().num_public_values, - final_memory.as_ref().unwrap(), + executor.config.as_ref().num_public_values, + &final_memory.memory, ); Ok(public_values) } @@ -197,27 +204,30 @@ impl> GenericSdk { Ok(committed_exe) } - pub fn app_keygen>(&self, config: AppConfig) -> Result> + pub fn app_keygen(&self, config: AppConfig) -> Result> where - VC::Executor: Chip, - VC::Periphery: Chip, + VC: Clone + VmCircuitConfig + AsRef, { - let app_pk = AppProvingKey::keygen(config); + let app_pk = AppProvingKey::keygen(config)?; Ok(app_pk) } - pub fn generate_app_proof>( + pub fn generate_app_proof( &self, - app_pk: Arc>, + app_vm_builder: VB, + app_pk: Arc>, app_committed_exe: Arc, inputs: StdIn, ) -> Result> where - VC::Executor: Chip, - VC::Periphery: Chip, + VB: VmBuilder, + VB::VmConfig: VmExecutionConfig + VmCircuitConfig, + >::Executor: + Executor + MeteredExecutor + PreflightExecutor, { - let app_prover = AppProver::::new(app_pk.app_vm_pk.clone(), app_committed_exe); - let proof = app_prover.generate_app_proof(inputs); + let mut app_prover = + AppProver::::new(app_vm_builder, app_pk.app_vm_pk.clone(), app_committed_exe)?; + let proof = app_prover.generate_app_proof(inputs)?; Ok(proof) } @@ -268,12 +278,12 @@ impl> GenericSdk { reader: &impl Halo2ParamsReader, pv_handler: &impl StaticVerifierPvHandler, ) -> Result { - let agg_pk = AggProvingKey::keygen(config, reader, pv_handler); + let agg_pk = AggProvingKey::keygen(config, reader, pv_handler)?; Ok(agg_pk) } pub fn agg_stark_keygen(&self, config: AggStarkConfig) -> Result { - let agg_pk = AggStarkProvingKey::keygen(config); + let agg_pk = AggStarkProvingKey::keygen(config)?; Ok(agg_pk) } @@ -295,37 +305,29 @@ impl> GenericSdk { program_to_asm(kernel_asm) } - pub fn generate_root_verifier_input>( - &self, - app_pk: Arc>, - app_exe: Arc, - agg_stark_pk: AggStarkProvingKey, - inputs: StdIn, - ) -> Result> - where - VC::Executor: Chip, - VC::Periphery: Chip, - { - let stark_prover = - StarkProver::::new(app_pk, app_exe, agg_stark_pk, self.agg_tree_config); - let proof = stark_prover.generate_root_verifier_input(inputs); - Ok(proof) - } - - pub fn generate_e2e_stark_proof>( + pub fn generate_e2e_stark_proof( &self, - app_pk: Arc>, + app_vm_builder: VB, + app_pk: Arc>, app_exe: Arc, agg_stark_pk: AggStarkProvingKey, inputs: StdIn, ) -> Result> where - VC::Executor: Chip, - VC::Periphery: Chip, + VB: VmBuilder, + >::Executor: Executor + + MeteredExecutor + + PreflightExecutor>::RecordArena>, { - let stark_prover = - StarkProver::::new(app_pk, app_exe, agg_stark_pk, self.agg_tree_config); - let proof = stark_prover.generate_e2e_stark_proof(inputs); + let mut stark_prover = StarkProver::::new( + app_vm_builder, + self.native_builder.clone(), + app_pk, + app_exe, + agg_stark_pk, + self.agg_tree_config, + )?; + let proof = stark_prover.generate_e2e_stark_proof(inputs)?; Ok(proof) } @@ -430,21 +432,31 @@ impl> GenericSdk { } #[cfg(feature = "evm-prove")] - pub fn generate_evm_proof>( + pub fn generate_evm_proof( &self, reader: &impl Halo2ParamsReader, - app_pk: Arc>, + app_vm_builder: VB, + app_pk: Arc>, app_exe: Arc, agg_pk: AggProvingKey, inputs: StdIn, ) -> Result where - VC::Executor: Chip, - VC::Periphery: Chip, + VB: VmBuilder, + >::Executor: Executor + + MeteredExecutor + + PreflightExecutor>::RecordArena>, { - let e2e_prover = - EvmHalo2Prover::::new(reader, app_pk, app_exe, agg_pk, self.agg_tree_config); - let proof = e2e_prover.generate_proof_for_evm(inputs); + let mut e2e_prover = EvmHalo2Prover::::new( + reader, + app_vm_builder, + self.native_builder.clone(), + app_pk, + app_exe, + agg_pk, + self.agg_tree_config, + )?; + let proof = e2e_prover.generate_proof_for_evm(inputs)?; Ok(proof) } diff --git a/crates/sdk/src/prover/agg.rs b/crates/sdk/src/prover/agg.rs index aa8fc843cb..34b1ca1310 100644 --- a/crates/sdk/src/prover/agg.rs +++ b/crates/sdk/src/prover/agg.rs @@ -1,34 +1,38 @@ use std::sync::Arc; -use openvm_circuit::arch::ContinuationVmProof; +use openvm_circuit::arch::{ + ContinuationVmProof, PreflightExecutor, SingleSegmentVmProver, VirtualMachineError, VmBuilder, + VmExecutionConfig, VmLocalProver, +}; +#[cfg(feature = "evm-prove")] +use openvm_continuations::verifier::root::types::RootVmVerifierInput; use openvm_continuations::verifier::{ internal::types::{InternalVmVerifierInput, VmStarkProof}, leaf::types::LeafVmVerifierInput, - root::types::RootVmVerifierInput, }; -use openvm_native_circuit::NativeConfig; -use openvm_native_compiler::ir::DIGEST_SIZE; +use openvm_native_circuit::{NativeConfig, NATIVE_MAX_TRACE_HEIGHTS}; use openvm_native_recursion::hints::Hintable; use openvm_stark_sdk::{engine::StarkFriEngine, openvm_stark_backend::proof::Proof}; -use tracing::info_span; +use tracing::{info_span, instrument}; use crate::{ - config::AggregationTreeConfig, - keygen::AggStarkProvingKey, - prover::{ - vm::{local::VmLocalProver, SingleSegmentVmProver}, - RootVerifierLocalProver, - }, - NonRootCommittedExe, RootSC, F, SC, + config::AggregationTreeConfig, keygen::AggStarkProvingKey, prover::vm::new_local_prover, + NonRootCommittedExe, F, SC, }; +#[cfg(feature = "evm-prove")] +use crate::{prover::RootVerifierLocalProver, RootSC}; -pub struct AggStarkProver> { - leaf_prover: VmLocalProver, +pub struct AggStarkProver +where + E: StarkFriEngine, + NativeBuilder: VmBuilder, +{ + leaf_prover: VmLocalProver, leaf_controller: LeafProvingController, - internal_prover: VmLocalProver, + internal_prover: VmLocalProver, + #[cfg(feature = "evm-prove")] root_prover: RootVerifierLocalProver, - pub num_children_internal: usize, pub max_internal_wrapper_layers: usize, } @@ -38,30 +42,43 @@ pub struct LeafProvingController { pub num_children: usize, } -impl> AggStarkProver { +impl AggStarkProver +where + E: StarkFriEngine, + NativeBuilder: VmBuilder + Clone, + >::Executor: + PreflightExecutor>::RecordArena>, +{ pub fn new( + native_builder: NativeBuilder, agg_stark_pk: AggStarkProvingKey, leaf_committed_exe: Arc, tree_config: AggregationTreeConfig, - ) -> Self { - let leaf_prover = - VmLocalProver::::new(agg_stark_pk.leaf_vm_pk, leaf_committed_exe); + ) -> Result { + let leaf_prover = new_local_prover( + native_builder.clone(), + &agg_stark_pk.leaf_vm_pk, + &leaf_committed_exe, + )?; let leaf_controller = LeafProvingController { num_children: tree_config.num_children_leaf, }; - let internal_prover = VmLocalProver::::new( - agg_stark_pk.internal_vm_pk, - agg_stark_pk.internal_committed_exe, - ); - let root_prover = RootVerifierLocalProver::new(agg_stark_pk.root_verifier_pk); - Self { + let internal_prover = new_local_prover( + native_builder, + &agg_stark_pk.internal_vm_pk, + &agg_stark_pk.internal_committed_exe, + )?; + #[cfg(feature = "evm-prove")] + let root_prover = RootVerifierLocalProver::new(agg_stark_pk.root_verifier_pk)?; + Ok(Self { leaf_prover, leaf_controller, internal_prover, + #[cfg(feature = "evm-prove")] root_prover, num_children_internal: tree_config.num_children_internal, max_internal_wrapper_layers: tree_config.max_internal_wrapper_layers, - } + }) } pub fn with_num_children_leaf(mut self, num_children_leaf: usize) -> Self { @@ -80,31 +97,41 @@ impl> AggStarkProver { } /// Generate the root proof for outer recursion. - pub fn generate_root_proof(&self, app_proofs: ContinuationVmProof) -> Proof { - let root_verifier_input = self.generate_root_verifier_input(app_proofs); + #[cfg(feature = "evm-prove")] + pub fn generate_root_proof( + &mut self, + app_proofs: ContinuationVmProof, + ) -> Result, VirtualMachineError> { + let root_verifier_input = self.generate_root_verifier_input(app_proofs)?; self.generate_root_proof_impl(root_verifier_input) } - pub fn generate_leaf_proofs(&self, app_proofs: &ContinuationVmProof) -> Vec> { + pub fn generate_leaf_proofs( + &mut self, + app_proofs: &ContinuationVmProof, + ) -> Result>, VirtualMachineError> { self.leaf_controller - .generate_proof(&self.leaf_prover, app_proofs) + .generate_proof(&mut self.leaf_prover, app_proofs) } + /// This is typically only used for the halo2 verifier. + #[cfg(feature = "evm-prove")] pub fn generate_root_verifier_input( - &self, + &mut self, app_proofs: ContinuationVmProof, - ) -> RootVmVerifierInput { - let leaf_proofs = self.generate_leaf_proofs(&app_proofs); + ) -> Result, VirtualMachineError> { + let leaf_proofs = self.generate_leaf_proofs(&app_proofs)?; let public_values = app_proofs.user_public_values.public_values; - let e2e_stark_proof = self.aggregate_leaf_proofs(leaf_proofs, public_values); - self.wrap_e2e_stark_proof(e2e_stark_proof) + let e2e_stark_proof = self.aggregate_leaf_proofs(leaf_proofs, public_values)?; + let wrapped_stark_proof = self.wrap_e2e_stark_proof(e2e_stark_proof)?; + Ok(wrapped_stark_proof) } pub fn aggregate_leaf_proofs( - &self, + &mut self, leaf_proofs: Vec>, public_values: Vec, - ) -> VmStarkProof { + ) -> Result, VirtualMachineError> { let mut internal_node_idx = -1; let mut internal_node_height = 0; let mut proofs = leaf_proofs; @@ -112,10 +139,7 @@ impl> AggStarkProver { // proof, in order to shrink the proof size while proofs.len() > 1 || internal_node_height == 0 { let internal_inputs = InternalVmVerifierInput::chunk_leaf_or_internal_proofs( - self.internal_prover - .committed_exe - .get_program_commit() - .into(), + (*self.internal_prover.exe_commitment()).into(), &proofs, self.num_children_internal, ); @@ -124,10 +148,10 @@ impl> AggStarkProver { group = format!("internal.{internal_node_height}") ) .in_scope(|| { - #[cfg(feature = "bench-metrics")] + #[cfg(feature = "metrics")] { metrics::counter!("fri.log_blowup") - .absolute(self.internal_prover.fri_params().log_blowup as u64); + .absolute(self.internal_prover.vm.engine.fri_params().log_blowup as u64); metrics::counter!("num_children").absolute(self.num_children_internal as u64); } internal_inputs @@ -135,47 +159,97 @@ impl> AggStarkProver { .map(|input| { internal_node_idx += 1; info_span!("single_internal_agg", idx = internal_node_idx,).in_scope(|| { - SingleSegmentVmProver::prove(&self.internal_prover, input.write()) + SingleSegmentVmProver::prove( + &mut self.internal_prover, + input.write(), + NATIVE_MAX_TRACE_HEIGHTS, + ) }) }) - .collect() - }); + .collect::, _>>() + })?; internal_node_height += 1; } - VmStarkProof { - proof: proofs.pop().unwrap(), + let proof = proofs.pop().unwrap(); + Ok(VmStarkProof { + proof, user_public_values: public_values, - } + }) } /// Wrap the e2e stark proof until its heights meet the requirements of the root verifier. - pub fn wrap_e2e_stark_proof( - &self, + #[cfg(feature = "evm-prove")] + fn wrap_e2e_stark_proof( + &mut self, e2e_stark_proof: VmStarkProof, - ) -> RootVmVerifierInput { - let internal_commit = self - .internal_prover - .committed_exe - .get_program_commit() - .into(); - wrap_e2e_stark_proof( - &self.internal_prover, - &self.root_prover, - internal_commit, - self.max_internal_wrapper_layers, - e2e_stark_proof, - ) - } + ) -> Result, VirtualMachineError> { + let internal_commit = (*self.internal_prover.exe_commitment()).into(); + let internal_prover = &mut self.internal_prover; + let root_prover = &mut self.root_prover; + let max_internal_wrapper_layers = self.max_internal_wrapper_layers; + fn heights_le(a: &[u32], b: &[u32]) -> bool { + assert_eq!(a.len(), b.len()); + a.iter().zip(b.iter()).all(|(a, b)| a <= b) + } - fn generate_root_proof_impl(&self, root_input: RootVmVerifierInput) -> Proof { - info_span!("agg_layer", group = "root", idx = 0).in_scope(|| { - let input = root_input.write(); - #[cfg(feature = "bench-metrics")] - metrics::counter!("fri.log_blowup") - .absolute(self.root_prover.fri_params().log_blowup as u64); - SingleSegmentVmProver::prove(&self.root_prover, input) + let VmStarkProof { + mut proof, + user_public_values, + } = e2e_stark_proof; + let mut wrapper_layers = 0; + loop { + let input = RootVmVerifierInput { + proofs: vec![proof.clone()], + public_values: user_public_values.clone(), + }; + let actual_air_heights = root_prover.execute_for_air_heights(input)?; + // Root verifier can handle the internal proof. We can stop here. + if heights_le(&actual_air_heights, root_prover.fixed_air_heights()) { + break; + } + if wrapper_layers >= max_internal_wrapper_layers { + panic!("The heights of the root verifier still exceed the required heights after {} wrapper layers", max_internal_wrapper_layers); + } + wrapper_layers += 1; + let input = InternalVmVerifierInput { + self_program_commit: internal_commit, + proofs: vec![proof.clone()], + }; + proof = info_span!( + "wrapper_layer", + group = format!("internal_wrapper.{wrapper_layers}") + ) + .in_scope(|| { + #[cfg(feature = "metrics")] + { + metrics::counter!("fri.log_blowup") + .absolute(internal_prover.vm.engine.fri_params().log_blowup as u64); + } + SingleSegmentVmProver::prove( + internal_prover, + input.write(), + NATIVE_MAX_TRACE_HEIGHTS, + ) + })?; + } + Ok(RootVmVerifierInput { + proofs: vec![proof], + public_values: user_public_values, }) } + + #[cfg(feature = "evm-prove")] + #[instrument(name = "agg_layer", skip_all, fields(group = "root", idx = 0))] + fn generate_root_proof_impl( + &mut self, + root_input: RootVmVerifierInput, + ) -> Result, VirtualMachineError> { + let input = root_input.write(); + #[cfg(feature = "metrics")] + metrics::counter!("fri.log_blowup") + .absolute(self.root_prover.fri_params().log_blowup as u64); + SingleSegmentVmProver::prove(&mut self.root_prover, input, NATIVE_MAX_TRACE_HEIGHTS) + } } impl LeafProvingController { @@ -184,85 +258,39 @@ impl LeafProvingController { self } - pub fn generate_proof>( + #[instrument(name = "agg_layer", skip_all, fields(group = "leaf"))] + pub fn generate_proof( &self, - prover: &VmLocalProver, + prover: &mut VmLocalProver, app_proofs: &ContinuationVmProof, - ) -> Vec> { - info_span!("agg_layer", group = "leaf").in_scope(|| { - #[cfg(feature = "bench-metrics")] - { - metrics::counter!("fri.log_blowup").absolute(prover.fri_params().log_blowup as u64); - metrics::counter!("num_children").absolute(self.num_children as u64); - } - let leaf_inputs = - LeafVmVerifierInput::chunk_continuation_vm_proof(app_proofs, self.num_children); - tracing::info!("num_leaf_proofs={}", leaf_inputs.len()); - leaf_inputs - .into_iter() - .enumerate() - .map(|(leaf_node_idx, input)| { - info_span!("single_leaf_agg", idx = leaf_node_idx) - .in_scope(|| SingleSegmentVmProver::prove(prover, input.write_to_stream())) - }) - .collect::>() - }) - } -} - -/// Wrap the e2e stark proof until its heights meet the requirements of the root verifier. -pub fn wrap_e2e_stark_proof>( - internal_prover: &VmLocalProver, - root_prover: &RootVerifierLocalProver, - internal_commit: [F; DIGEST_SIZE], - max_internal_wrapper_layers: usize, - e2e_stark_proof: VmStarkProof, -) -> RootVmVerifierInput { - let VmStarkProof { - mut proof, - user_public_values, - } = e2e_stark_proof; - let mut wrapper_layers = 0; - loop { - let actual_air_heights = root_prover.execute_for_air_heights(RootVmVerifierInput { - proofs: vec![proof.clone()], - public_values: user_public_values.clone(), - }); - // Root verifier can handle the internal proof. We can stop here. - if heights_le( - &actual_air_heights, - &root_prover.root_verifier_pk.air_heights, - ) { - break; - } - if wrapper_layers >= max_internal_wrapper_layers { - panic!("The heights of the root verifier still exceed the required heights after {} wrapper layers", max_internal_wrapper_layers); + ) -> Result>, VirtualMachineError> + where + E: StarkFriEngine, + NativeBuilder: VmBuilder, + >::Executor: + PreflightExecutor>::RecordArena>, + { + #[cfg(feature = "metrics")] + { + metrics::counter!("fri.log_blowup") + .absolute(prover.vm.engine.fri_params().log_blowup as u64); + metrics::counter!("num_children").absolute(self.num_children as u64); } - wrapper_layers += 1; - let input = InternalVmVerifierInput { - self_program_commit: internal_commit, - proofs: vec![proof.clone()], - }; - proof = info_span!( - "wrapper_layer", - group = format!("internal_wrapper.{wrapper_layers}") - ) - .in_scope(|| { - #[cfg(feature = "bench-metrics")] - { - metrics::counter!("fri.log_blowup") - .absolute(internal_prover.fri_params().log_blowup as u64); - } - SingleSegmentVmProver::prove(internal_prover, input.write()) - }); - } - RootVmVerifierInput { - proofs: vec![proof], - public_values: user_public_values, + let leaf_inputs = + LeafVmVerifierInput::chunk_continuation_vm_proof(app_proofs, self.num_children); + tracing::info!("num_leaf_proofs={}", leaf_inputs.len()); + leaf_inputs + .into_iter() + .enumerate() + .map(|(leaf_node_idx, input)| { + info_span!("single_leaf_agg", idx = leaf_node_idx).in_scope(|| { + SingleSegmentVmProver::prove( + prover, + input.write_to_stream(), + NATIVE_MAX_TRACE_HEIGHTS, + ) + }) + }) + .collect() } } - -fn heights_le(a: &[usize], b: &[usize]) -> bool { - assert_eq!(a.len(), b.len()); - a.iter().zip(b.iter()).all(|(a, b)| a <= b) -} diff --git a/crates/sdk/src/prover/app.rs b/crates/sdk/src/prover/app.rs index 095351677e..9bd7b4821f 100644 --- a/crates/sdk/src/prover/app.rs +++ b/crates/sdk/src/prover/app.rs @@ -1,36 +1,60 @@ use std::sync::Arc; use getset::Getters; -use openvm_circuit::arch::{ContinuationVmProof, VmConfig}; -use openvm_stark_backend::{proof::Proof, Chip}; -use openvm_stark_sdk::engine::StarkFriEngine; +use openvm_circuit::{ + arch::{ + verify_segments, ContinuationVmProof, ContinuationVmProver, Executor, MeteredExecutor, + PreflightExecutor, SingleSegmentVmProver, VirtualMachineError, VmBuilder, + VmExecutionConfig, VmLocalProver, + }, + system::{memory::CHUNK, program::trace::VmCommittedExe}, +}; +use openvm_stark_backend::{ + config::{Com, Val}, + keygen::types::MultiStarkVerifyingKey, + p3_field::PrimeField32, + proof::Proof, +}; +use openvm_stark_sdk::engine::{StarkEngine, StarkFriEngine}; use tracing::info_span; -use super::vm::SingleSegmentVmProver; use crate::{ - prover::vm::{local::VmLocalProver, types::VmProvingKey, ContinuationVmProver}, - NonRootCommittedExe, StdIn, F, SC, + prover::vm::{new_local_prover, types::VmProvingKey}, + StdIn, }; #[derive(Getters)] -pub struct AppProver> { +pub struct AppProver +where + E: StarkEngine, + VB: VmBuilder, +{ pub program_name: Option, #[getset(get = "pub")] - app_prover: VmLocalProver, + app_prover: VmLocalProver, + #[getset(get = "pub")] + app_vm_vk: MultiStarkVerifyingKey, } -impl> AppProver { +impl AppProver +where + E: StarkFriEngine, + VB: VmBuilder, + Val: PrimeField32, + Com: AsRef<[Val; CHUNK]>, +{ pub fn new( - app_vm_pk: Arc>, - app_committed_exe: Arc, - ) -> Self - where - VC: VmConfig, - { - Self { + vm_builder: VB, + app_vm_pk: Arc>, + app_committed_exe: Arc>, + ) -> Result { + let app_prover = new_local_prover(vm_builder, &app_vm_pk, &app_committed_exe)?; + let app_vm_vk = app_vm_pk.vm_pk.get_vk(); + Ok(Self { program_name: None, - app_prover: VmLocalProver::::new(app_vm_pk, app_committed_exe), - } + app_prover, + app_vm_vk, + }) } pub fn set_program_name(&mut self, program_name: impl AsRef) -> &mut Self { self.program_name = Some(program_name.as_ref().to_string()); @@ -42,17 +66,20 @@ impl> AppProver { } /// Generates proof for every continuation segment - pub fn generate_app_proof(&self, input: StdIn) -> ContinuationVmProof + pub fn generate_app_proof( + &mut self, + input: StdIn>, + ) -> Result, VirtualMachineError> where - VC: VmConfig, - VC::Executor: Chip, - VC::Periphery: Chip, + >>::Executor: Executor> + + MeteredExecutor> + + PreflightExecutor, VB::RecordArena>, { assert!( - self.vm_config().system().continuation_enabled, + self.vm_config().as_ref().continuation_enabled, "Use generate_app_proof_without_continuations instead." ); - info_span!( + let proofs = info_span!( "app proof", group = self .program_name @@ -60,21 +87,32 @@ impl> AppProver { .unwrap_or(&"app_proof".to_string()) ) .in_scope(|| { - #[cfg(feature = "bench-metrics")] + #[cfg(feature = "metrics")] metrics::counter!("fri.log_blowup") - .absolute(self.app_prover.pk.fri_params.log_blowup as u64); - ContinuationVmProver::prove(&self.app_prover, input) - }) + .absolute(self.app_prover.vm.engine.fri_params().log_blowup as u64); + ContinuationVmProver::prove(&mut self.app_prover, input) + })?; + // We skip verification of the user public values proof here because it is directly computed + // from the merkle tree above + verify_segments( + &self.app_prover.vm.engine, + &self.app_vm_vk, + &proofs.per_segment, + )?; + Ok(proofs) } - pub fn generate_app_proof_without_continuations(&self, input: StdIn) -> Proof + pub fn generate_app_proof_without_continuations( + &mut self, + input: StdIn>, + trace_heights: &[u32], + ) -> Result, VirtualMachineError> where - VC: VmConfig, - VC::Executor: Chip, - VC::Periphery: Chip, + >>::Executor: + PreflightExecutor, VB::RecordArena>, { assert!( - !self.vm_config().system().continuation_enabled, + !self.vm_config().as_ref().continuation_enabled, "Use generate_app_proof instead." ); info_span!( @@ -85,15 +123,15 @@ impl> AppProver { .unwrap_or(&"app_proof".to_string()) ) .in_scope(|| { - #[cfg(feature = "bench-metrics")] + #[cfg(feature = "metrics")] metrics::counter!("fri.log_blowup") - .absolute(self.app_prover.pk.fri_params.log_blowup as u64); - SingleSegmentVmProver::prove(&self.app_prover, input) + .absolute(self.app_prover.vm.engine.fri_params().log_blowup as u64); + SingleSegmentVmProver::prove(&mut self.app_prover, input, trace_heights) }) } /// App VM config - pub fn vm_config(&self) -> &VC { - self.app_prover.vm_config() + pub fn vm_config(&self) -> &VB::VmConfig { + self.app_prover.vm.config() } } diff --git a/crates/sdk/src/prover/mod.rs b/crates/sdk/src/prover/mod.rs index 67ccfe1eb8..589b75fafa 100644 --- a/crates/sdk/src/prover/mod.rs +++ b/crates/sdk/src/prover/mod.rs @@ -2,6 +2,7 @@ mod agg; mod app; #[cfg(feature = "evm-prove")] mod halo2; +#[cfg(feature = "evm-prove")] mod root; mod stark; pub mod vm; @@ -12,6 +13,7 @@ pub use app::*; pub use evm::*; #[cfg(feature = "evm-prove")] pub use halo2::*; +#[cfg(feature = "evm-prove")] pub use root::*; pub use stark::*; @@ -19,9 +21,13 @@ pub use stark::*; mod evm { use std::sync::Arc; - use openvm_circuit::arch::VmConfig; + use openvm_circuit::arch::{ + Executor, MeteredExecutor, PreflightExecutor, VirtualMachineError, VmBuilder, + VmExecutionConfig, + }; + use openvm_native_circuit::NativeConfig; use openvm_native_recursion::halo2::utils::Halo2ParamsReader; - use openvm_stark_sdk::{engine::StarkFriEngine, openvm_stark_backend::Chip}; + use openvm_stark_sdk::engine::StarkFriEngine; use super::{Halo2Prover, StarkProver}; use crate::{ @@ -32,32 +38,52 @@ mod evm { NonRootCommittedExe, F, SC, }; - pub struct EvmHalo2Prover> { - pub stark_prover: StarkProver, + pub struct EvmHalo2Prover + where + E: StarkFriEngine, + VB: VmBuilder, + NativeBuilder: VmBuilder, + { + pub stark_prover: StarkProver, pub halo2_prover: Halo2Prover, } - impl> EvmHalo2Prover { + impl EvmHalo2Prover + where + E: StarkFriEngine, + VB: VmBuilder, + >::Executor: Executor + + MeteredExecutor + + PreflightExecutor>::RecordArena>, + NativeBuilder: VmBuilder + Clone, + >::Executor: + PreflightExecutor>::RecordArena>, + { pub fn new( reader: &impl Halo2ParamsReader, - app_pk: Arc>, + app_vm_builder: VB, + native_builder: NativeBuilder, + app_pk: Arc>, app_committed_exe: Arc, agg_pk: AggProvingKey, agg_tree_config: AggregationTreeConfig, - ) -> Self - where - VC: VmConfig, - { + ) -> Result { let AggProvingKey { agg_stark_pk, halo2_pk, } = agg_pk; - let stark_prover = - StarkProver::new(app_pk, app_committed_exe, agg_stark_pk, agg_tree_config); - Self { + let stark_prover = StarkProver::new( + app_vm_builder, + native_builder, + app_pk, + app_committed_exe, + agg_stark_pk, + agg_tree_config, + )?; + Ok(Self { stark_prover, halo2_prover: Halo2Prover::new(reader, halo2_pk), - } + }) } pub fn set_program_name(&mut self, program_name: impl AsRef) -> &mut Self { @@ -65,14 +91,15 @@ mod evm { self } - pub fn generate_proof_for_evm(&self, input: StdIn) -> EvmProof - where - VC: VmConfig, - VC::Executor: Chip, - VC::Periphery: Chip, - { - let root_proof = self.stark_prover.generate_proof_for_outer_recursion(input); - self.halo2_prover.prove_for_evm(&root_proof) + pub fn generate_proof_for_evm( + &mut self, + input: StdIn, + ) -> Result { + let root_proof = self + .stark_prover + .generate_proof_for_outer_recursion(input)?; + let evm_proof = self.halo2_prover.prove_for_evm(&root_proof); + Ok(evm_proof) } } } diff --git a/crates/sdk/src/prover/root.rs b/crates/sdk/src/prover/root.rs index 6e69aa0f13..df3aaceb6f 100644 --- a/crates/sdk/src/prover/root.rs +++ b/crates/sdk/src/prover/root.rs @@ -1,89 +1,185 @@ -use async_trait::async_trait; -use openvm_circuit::arch::{SingleSegmentVmExecutor, Streams}; +use getset::Getters; +use itertools::zip_eq; +use openvm_circuit::arch::{ + GenerationError, PreflightExecutionOutput, SingleSegmentVmProver, Streams, VirtualMachine, + VirtualMachineError, VmLocalProver, +}; use openvm_continuations::verifier::root::types::RootVmVerifierInput; -use openvm_native_circuit::NativeConfig; +use openvm_native_circuit::{NativeConfig, NativeCpuBuilder, NATIVE_MAX_TRACE_HEIGHTS}; use openvm_native_recursion::hints::Hintable; use openvm_stark_sdk::{ config::{baby_bear_poseidon2_root::BabyBearPoseidon2RootEngine, FriParameters}, - engine::{StarkEngine, StarkFriEngine}, + engine::StarkEngine, openvm_stark_backend::proof::Proof, }; use crate::{ - keygen::RootVerifierProvingKey, - prover::vm::{AsyncSingleSegmentVmProver, SingleSegmentVmProver}, + keygen::{perm::AirIdPermutation, RootVerifierProvingKey}, + prover::vm::new_local_prover, RootSC, F, SC, }; /// Local prover for a root verifier. +#[derive(Getters)] pub struct RootVerifierLocalProver { - pub root_verifier_pk: RootVerifierProvingKey, - executor_for_heights: SingleSegmentVmExecutor, + /// The proving key in `inner` should always have ordering of AIRs in the sorted order by fixed + /// trace heights outside of the `prove` function. + // This is CPU-only for now because it uses RootSC + inner: VmLocalProver, + /// The constant trace heights, ordered by AIR ID (the original ordering from VmConfig). + #[getset(get = "pub")] + fixed_air_heights: Vec, + air_id_perm: AirIdPermutation, + air_id_inv_perm: AirIdPermutation, } impl RootVerifierLocalProver { - pub fn new(root_verifier_pk: RootVerifierProvingKey) -> Self { - let executor_for_heights = - SingleSegmentVmExecutor::::new(root_verifier_pk.vm_pk.vm_config.clone()); - Self { - root_verifier_pk, - executor_for_heights, + pub fn new(root_verifier_pk: RootVerifierProvingKey) -> Result { + let inner = new_local_prover( + NativeCpuBuilder, + &root_verifier_pk.vm_pk, + &root_verifier_pk.root_committed_exe, + )?; + let fixed_air_heights = root_verifier_pk.air_heights; + let air_id_perm = AirIdPermutation::compute(&fixed_air_heights); + let mut inverse_perm = vec![0usize; air_id_perm.perm.len()]; + for (i, &perm_i) in air_id_perm.perm.iter().enumerate() { + inverse_perm[perm_i] = i; } - } - pub fn execute_for_air_heights(&self, input: RootVmVerifierInput) -> Vec { - let result = self - .executor_for_heights - .execute_and_compute_heights( - self.root_verifier_pk.root_committed_exe.exe.clone(), - input.write(), - ) - .unwrap(); - result.air_heights + let air_id_inv_perm = AirIdPermutation { perm: inverse_perm }; + + Ok(Self { + inner, + fixed_air_heights, + air_id_perm, + air_id_inv_perm, + }) } pub fn vm_config(&self) -> &NativeConfig { - &self.root_verifier_pk.vm_pk.vm_config + self.inner.vm.config() } #[allow(dead_code)] pub(crate) fn fri_params(&self) -> &FriParameters { - &self.root_verifier_pk.vm_pk.fri_params + &self.inner.vm.engine.fri_params + } + + pub fn execute_for_air_heights( + &mut self, + input: RootVmVerifierInput, + ) -> Result, VirtualMachineError> { + let exe = self.inner.exe().clone(); + // See `SingleSegmentVmProver::prove` for explanation + let vm = &mut self.inner.vm; + Self::permute_pk(vm, &self.air_id_inv_perm); + assert!(!vm.config().as_ref().continuation_enabled); + let input = input.write(); + let state = vm.create_initial_state(&exe, input); + vm.transport_init_memory_to_device(&state.memory); + let PreflightExecutionOutput { + system_records, + record_arenas, + .. + } = vm.execute_preflight(&exe, state, None, NATIVE_MAX_TRACE_HEIGHTS)?; + // Note[jpw]: we could in theory extract trace heights from just preflight execution, but + // that requires special logic in the chips so we will just generate the traces for now + let ctx = vm.generate_proving_ctx(system_records, record_arenas)?; + let air_heights = ctx + .per_air + .iter() + .map(|(_, air_ctx)| air_ctx.main_trace_height() as u32) + .collect(); + Self::permute_pk(vm, &self.air_id_perm); + Ok(air_heights) + } + + // ATTENTION: this must exactly match the permutation done in + // `AggStarkProvingKey::dummy_proof_and_keygen` except on DeviceMultiStarkProvingKey. + fn permute_pk( + vm: &mut VirtualMachine, + perm: &AirIdPermutation, + ) { + perm.permute(&mut vm.pk_mut().per_air); + for thc in &mut vm.pk_mut().trace_height_constraints { + perm.permute(&mut thc.coefficients); + } } } impl SingleSegmentVmProver for RootVerifierLocalProver { - fn prove(&self, input: impl Into>) -> Proof { + // @dev: If this implementation is generalized to prover backends not using MatrixRecordArena, + // then it must be ensured that: + // - the Native extension chips can ensure that, if the record arenas have + // `force_matrix_dimensions()` set, then the record arena capacity heights must equal the + // trace matrix heights. + // - any chips that do not use record arenas (currently system memory chips) have a way to force + // trace heights as well. We currently use the fact that all non-system periphery chips have + // fixed height (in particular, there is no Poseidon2PeripheryChip). + fn prove( + &mut self, + input: impl Into>, + _: &[u32], + ) -> Result, VirtualMachineError> { + assert!(!self.vm_config().as_ref().continuation_enabled); + // The following is unrolled from SingleSegmentVmProver for VmLocalProver and + // VirtualMachine::prove to add special logic around ensuring trace heights are fixed and + // then reordering the trace matrices so the heights are sorted. let input = input.into(); - let mut vm = SingleSegmentVmExecutor::new(self.vm_config().clone()); - vm.set_override_trace_heights(self.root_verifier_pk.vm_heights.clone()); - let mut proof_input = vm - .execute_and_generate(self.root_verifier_pk.root_committed_exe.clone(), input) - .unwrap(); - assert_eq!( - proof_input.per_air.len(), - self.root_verifier_pk.air_heights.len(), - "All AIRs of root verifier should present" - ); - proof_input.per_air.iter().for_each(|(air_id, input)| { - assert_eq!( - input.main_trace_height(), - self.root_verifier_pk.air_heights[*air_id], - "Trace height doesn't match" - ); - }); - // Reorder the AIRs by heights. - let air_id_perm = self.root_verifier_pk.air_id_permutation(); - air_id_perm.permute(&mut proof_input.per_air); - for i in 0..proof_input.per_air.len() { - // Overwrite the AIR ID. - proof_input.per_air[i].0 = i; + let exe = self.inner.exe().clone(); + let vm = &mut self.inner.vm; + // The root_verifier_pk has the AIRs ordered by the fixed AIR height sorted ordering, but + // execute_preflight and generate_proving_ctx still expect the original AIR ID ordering from + // VmConfig, so we apply the inverse permutation here, and then undo it after tracegen. This + // could maybe be replaced by only changing `executor_idx_to_air_idx`, but applying the + // permutation is conceptually simpler to track. + Self::permute_pk(vm, &self.air_id_inv_perm); + assert!(!vm.config().as_ref().continuation_enabled); + let state = vm.create_initial_state(&exe, input); + vm.transport_init_memory_to_device(&state.memory); + + let trace_heights = &self.fixed_air_heights; + let PreflightExecutionOutput { + system_records, + mut record_arenas, + .. + } = vm.execute_preflight(&exe, state, None, trace_heights)?; + // record_arenas are created with capacity specified by trace_heights. we must ensure + // `generate_proving_ctx` does not resize the trace matrices to make them smaller: + for ra in &mut record_arenas { + ra.force_matrix_dimensions(); } - let e = BabyBearPoseidon2RootEngine::new(*self.fri_params()); - e.prove(&self.root_verifier_pk.vm_pk.vm_pk, proof_input) - } -} + vm.override_system_trace_heights(trace_heights); -#[async_trait] -impl AsyncSingleSegmentVmProver for RootVerifierLocalProver { - async fn prove(&self, input: impl Into> + Send + Sync) -> Proof { - SingleSegmentVmProver::prove(self, input) + let mut ctx = vm.generate_proving_ctx(system_records, record_arenas)?; + // Sanity check: ensure all generated trace matrices actually match the fixed heights. + for (air_idx, (fixed_height, (idx, air_ctx))) in + zip_eq(trace_heights, &ctx.per_air).enumerate() + { + let fixed_height = *fixed_height as usize; + if air_idx != *idx { + return Err(GenerationError::ForceTraceHeightIncorrect { + air_idx, + actual: 0, + expected: fixed_height, + } + .into()); + } + if fixed_height != air_ctx.main_trace_height() { + return Err(GenerationError::ForceTraceHeightIncorrect { + air_idx, + actual: air_ctx.main_trace_height(), + expected: fixed_height, + } + .into()); + } + } + // Reorder the AIRs by heights. + self.air_id_perm.permute(&mut ctx.per_air); + for (i, (air_idx, _)) in ctx.per_air.iter_mut().enumerate() { + *air_idx = i; + } + // We also undo the permutation on pk because `prove` needs pk and ctx ordering to match. + Self::permute_pk(vm, &self.air_id_perm); + let proof = vm.engine.prove(vm.pk(), ctx); + Ok(proof) } } diff --git a/crates/sdk/src/prover/stark.rs b/crates/sdk/src/prover/stark.rs index fdec583f0f..90dbf87e40 100644 --- a/crates/sdk/src/prover/stark.rs +++ b/crates/sdk/src/prover/stark.rs @@ -1,84 +1,101 @@ use std::sync::Arc; -use openvm_circuit::arch::VmConfig; -use openvm_continuations::verifier::{ - internal::types::VmStarkProof, root::types::RootVmVerifierInput, +use openvm_circuit::arch::{ + Executor, MeteredExecutor, PreflightExecutor, VirtualMachineError, VmBuilder, VmExecutionConfig, }; -use openvm_stark_backend::{proof::Proof, Chip}; +use openvm_continuations::verifier::internal::types::VmStarkProof; +#[cfg(feature = "evm-prove")] +use openvm_continuations::{verifier::root::types::RootVmVerifierInput, RootSC}; +use openvm_native_circuit::NativeConfig; +#[cfg(feature = "evm-prove")] +use openvm_stark_backend::proof::Proof; use openvm_stark_sdk::engine::StarkFriEngine; use crate::{ config::AggregationTreeConfig, keygen::{AggStarkProvingKey, AppProvingKey}, prover::{agg::AggStarkProver, app::AppProver}, - NonRootCommittedExe, RootSC, StdIn, F, SC, + NonRootCommittedExe, StdIn, F, SC, }; -pub struct StarkProver> { - pub app_prover: AppProver, - pub agg_prover: AggStarkProver, +pub struct StarkProver +where + E: StarkFriEngine, + VB: VmBuilder, + NativeBuilder: VmBuilder, +{ + pub app_prover: AppProver, + pub agg_prover: AggStarkProver, } -impl> StarkProver { +impl StarkProver +where + E: StarkFriEngine, + VB: VmBuilder, + >::Executor: + Executor + MeteredExecutor + PreflightExecutor>::RecordArena>, + NativeBuilder: VmBuilder + Clone, + >::Executor: + PreflightExecutor>::RecordArena>, +{ pub fn new( - app_pk: Arc>, + app_vm_builder: VB, + native_builder: NativeBuilder, + app_pk: Arc>, app_committed_exe: Arc, agg_stark_pk: AggStarkProvingKey, agg_tree_config: AggregationTreeConfig, - ) -> Self - where - VC: VmConfig, - { + ) -> Result { assert_eq!( app_pk.leaf_fri_params, agg_stark_pk.leaf_vm_pk.fri_params, "App VM is incompatible with Agg VM because of leaf FRI parameters" ); assert_eq!( - app_pk.app_vm_pk.vm_config.system().num_public_values, + app_pk.app_vm_pk.vm_config.as_ref().num_public_values, agg_stark_pk.num_user_public_values(), "App VM is incompatible with Agg VM because of the number of public values" ); - Self { - app_prover: AppProver::new(app_pk.app_vm_pk.clone(), app_committed_exe), + Ok(Self { + app_prover: AppProver::new( + app_vm_builder, + app_pk.app_vm_pk.clone(), + app_committed_exe, + )?, agg_prover: AggStarkProver::new( + native_builder, agg_stark_pk, app_pk.leaf_committed_exe.clone(), agg_tree_config, - ), - } + )?, + }) } pub fn set_program_name(&mut self, program_name: impl AsRef) -> &mut Self { self.app_prover.set_program_name(program_name); self } - pub fn generate_proof_for_outer_recursion(&self, input: StdIn) -> Proof - where - VC: VmConfig, - VC::Executor: Chip, - VC::Periphery: Chip, - { - let app_proof = self.app_prover.generate_app_proof(input); + #[cfg(feature = "evm-prove")] + pub fn generate_proof_for_outer_recursion( + &mut self, + input: StdIn, + ) -> Result, VirtualMachineError> { + let app_proof = self.app_prover.generate_app_proof(input)?; self.agg_prover.generate_root_proof(app_proof) } - - pub fn generate_root_verifier_input(&self, input: StdIn) -> RootVmVerifierInput - where - VC: VmConfig, - VC::Executor: Chip, - VC::Periphery: Chip, - { - let app_proof = self.app_prover.generate_app_proof(input); + #[cfg(feature = "evm-prove")] + pub fn generate_root_verifier_input( + &mut self, + input: StdIn, + ) -> Result, VirtualMachineError> { + let app_proof = self.app_prover.generate_app_proof(input)?; self.agg_prover.generate_root_verifier_input(app_proof) } - pub fn generate_e2e_stark_proof(&self, input: StdIn) -> VmStarkProof - where - VC: VmConfig, - VC::Executor: Chip, - VC::Periphery: Chip, - { - let app_proof = self.app_prover.generate_app_proof(input); - let leaf_proofs = self.agg_prover.generate_leaf_proofs(&app_proof); + pub fn generate_e2e_stark_proof( + &mut self, + input: StdIn, + ) -> Result, VirtualMachineError> { + let app_proof = self.app_prover.generate_app_proof(input)?; + let leaf_proofs = self.agg_prover.generate_leaf_proofs(&app_proof)?; self.agg_prover .aggregate_leaf_proofs(leaf_proofs, app_proof.user_public_values.public_values) } diff --git a/crates/sdk/src/prover/vm/local.rs b/crates/sdk/src/prover/vm/local.rs deleted file mode 100644 index b56c6a1ad3..0000000000 --- a/crates/sdk/src/prover/vm/local.rs +++ /dev/null @@ -1,196 +0,0 @@ -use std::{marker::PhantomData, mem, sync::Arc}; - -use async_trait::async_trait; -use openvm_circuit::{ - arch::{ - hasher::poseidon2::vm_poseidon2_hasher, GenerationError, SingleSegmentVmExecutor, Streams, - VirtualMachine, VmComplexTraceHeights, VmConfig, - }, - system::{memory::tree::public_values::UserPublicValuesProof, program::trace::VmCommittedExe}, -}; -use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - p3_field::PrimeField32, - proof::Proof, - Chip, -}; -use openvm_stark_sdk::{config::FriParameters, engine::StarkFriEngine}; -use tracing::info_span; - -use crate::prover::vm::{ - types::VmProvingKey, AsyncContinuationVmProver, AsyncSingleSegmentVmProver, - ContinuationVmProof, ContinuationVmProver, SingleSegmentVmProver, -}; - -pub struct VmLocalProver> { - pub pk: Arc>, - pub committed_exe: Arc>, - overridden_heights: Option, - _marker: PhantomData, -} - -impl> VmLocalProver { - pub fn new(pk: Arc>, committed_exe: Arc>) -> Self { - Self { - pk, - committed_exe, - overridden_heights: None, - _marker: PhantomData, - } - } - - pub fn new_with_overridden_trace_heights( - pk: Arc>, - committed_exe: Arc>, - overridden_heights: Option, - ) -> Self { - Self { - pk, - committed_exe, - overridden_heights, - _marker: PhantomData, - } - } - - pub fn set_override_trace_heights(&mut self, overridden_heights: VmComplexTraceHeights) { - self.overridden_heights = Some(overridden_heights); - } - - pub fn vm_config(&self) -> &VC { - &self.pk.vm_config - } - #[allow(dead_code)] - pub(crate) fn fri_params(&self) -> &FriParameters { - &self.pk.fri_params - } -} - -const MAX_SEGMENTATION_RETRIES: usize = 4; - -impl>, E: StarkFriEngine> ContinuationVmProver - for VmLocalProver -where - Val: PrimeField32, - VC::Executor: Chip, - VC::Periphery: Chip, -{ - fn prove(&self, input: impl Into>>) -> ContinuationVmProof { - assert!(self.pk.vm_config.system().continuation_enabled); - let e = E::new(self.pk.fri_params); - let trace_height_constraints = self.pk.vm_pk.trace_height_constraints.clone(); - let mut vm = VirtualMachine::new_with_overridden_trace_heights( - e, - self.pk.vm_config.clone(), - self.overridden_heights.clone(), - ); - vm.set_trace_height_constraints(trace_height_constraints.clone()); - let mut final_memory = None; - let VmCommittedExe { - exe, - committed_program, - } = self.committed_exe.as_ref(); - let input = input.into(); - - // This loop should typically iterate exactly once. Only in exceptional cases will the - // segmentation produce an invalid segment and we will have to retry. - let mut retries = 0; - let per_segment = loop { - match vm.executor.execute_and_then( - exe.clone(), - input.clone(), - |seg_idx, mut seg| { - final_memory = mem::take(&mut seg.final_memory); - let proof_input = info_span!("trace_gen", segment = seg_idx) - .in_scope(|| seg.generate_proof_input(Some(committed_program.clone())))?; - info_span!("prove_segment", segment = seg_idx) - .in_scope(|| Ok(vm.engine.prove(&self.pk.vm_pk, proof_input))) - }, - GenerationError::Execution, - ) { - Ok(per_segment) => break per_segment, - Err(GenerationError::Execution(err)) => panic!("execution error: {err}"), - Err(GenerationError::TraceHeightsLimitExceeded) => { - if retries >= MAX_SEGMENTATION_RETRIES { - panic!( - "trace heights limit exceeded after {MAX_SEGMENTATION_RETRIES} retries" - ); - } - retries += 1; - tracing::info!( - "trace heights limit exceeded; retrying execution (attempt {retries})" - ); - let sys_config = vm.executor.config.system_mut(); - let new_seg_strat = sys_config.segmentation_strategy.stricter_strategy(); - sys_config.set_segmentation_strategy(new_seg_strat); - // continue - } - }; - }; - - let user_public_values = UserPublicValuesProof::compute( - self.pk.vm_config.system().memory_config.memory_dimensions(), - self.pk.vm_config.system().num_public_values, - &vm_poseidon2_hasher(), - final_memory.as_ref().unwrap(), - ); - ContinuationVmProof { - per_segment, - user_public_values, - } - } -} - -#[async_trait] -impl>, E: StarkFriEngine> - AsyncContinuationVmProver for VmLocalProver -where - VmLocalProver: Send + Sync, - Val: PrimeField32, - VC::Executor: Chip, - VC::Periphery: Chip, -{ - async fn prove( - &self, - input: impl Into>> + Send + Sync, - ) -> ContinuationVmProof { - ContinuationVmProver::prove(self, input) - } -} - -impl>, E: StarkFriEngine> SingleSegmentVmProver - for VmLocalProver -where - Val: PrimeField32, - VC::Executor: Chip, - VC::Periphery: Chip, -{ - fn prove(&self, input: impl Into>>) -> Proof { - assert!(!self.pk.vm_config.system().continuation_enabled); - let e = E::new(self.pk.fri_params); - // note: use SingleSegmentVmExecutor so there's not a "segment" label in metrics - let executor = { - let mut executor = SingleSegmentVmExecutor::new(self.pk.vm_config.clone()); - executor.set_trace_height_constraints(self.pk.vm_pk.trace_height_constraints.clone()); - executor - }; - let proof_input = executor - .execute_and_generate(self.committed_exe.clone(), input) - .unwrap(); - let vm = VirtualMachine::new(e, executor.config); - vm.prove_single(&self.pk.vm_pk, proof_input) - } -} - -#[async_trait] -impl>, E: StarkFriEngine> - AsyncSingleSegmentVmProver for VmLocalProver -where - VmLocalProver: Send + Sync, - Val: PrimeField32, - VC::Executor: Chip, - VC::Periphery: Chip, -{ - async fn prove(&self, input: impl Into>> + Send + Sync) -> Proof { - SingleSegmentVmProver::prove(self, input) - } -} diff --git a/crates/sdk/src/prover/vm/mod.rs b/crates/sdk/src/prover/vm/mod.rs index bc79d7b30c..0d6bce0511 100644 --- a/crates/sdk/src/prover/vm/mod.rs +++ b/crates/sdk/src/prover/vm/mod.rs @@ -1,34 +1,31 @@ -use async_trait::async_trait; -use openvm_circuit::arch::{ContinuationVmProof, Streams}; -use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - proof::Proof, +use openvm_circuit::{ + arch::{VirtualMachine, VirtualMachineError, VmBuilder, VmLocalProver}, + system::program::trace::VmCommittedExe, }; +use openvm_stark_backend::prover::hal::DeviceDataTransporter; +use openvm_stark_sdk::engine::StarkFriEngine; -pub mod local; -pub mod types; - -/// Prover for a specific exe in a specific continuation VM using a specific Stark config. -pub trait ContinuationVmProver { - fn prove(&self, input: impl Into>>) -> ContinuationVmProof; -} +use crate::prover::vm::types::VmProvingKey; -/// Async prover for a specific exe in a specific continuation VM using a specific Stark config. -#[async_trait] -pub trait AsyncContinuationVmProver { - async fn prove( - &self, - input: impl Into>> + Send + Sync, - ) -> ContinuationVmProof; -} - -/// Prover for a specific exe in a specific single-segment VM using a specific Stark config. -pub trait SingleSegmentVmProver { - fn prove(&self, input: impl Into>>) -> Proof; -} +pub mod types; -/// Async prover for a specific exe in a specific single-segment VM using a specific Stark config. -#[async_trait] -pub trait AsyncSingleSegmentVmProver { - async fn prove(&self, input: impl Into>> + Send + Sync) -> Proof; +pub fn new_local_prover( + vm_builder: VB, + vm_pk: &VmProvingKey, + committed_exe: &VmCommittedExe, +) -> Result, VirtualMachineError> +where + E: StarkFriEngine, + VB: VmBuilder, +{ + let engine = E::new(vm_pk.fri_params); + let d_pk = engine.device().transport_pk_to_device(&vm_pk.vm_pk); + let vm = VirtualMachine::new(engine, vm_builder, vm_pk.vm_config.clone(), d_pk)?; + let cached_program_trace = vm.transport_committed_exe_to_device(committed_exe); + // TODO[jpw]: remove this clone + Ok(VmLocalProver::new( + vm, + committed_exe.exe.clone(), + cached_program_trace, + )) } diff --git a/crates/sdk/src/stdin.rs b/crates/sdk/src/stdin.rs index 9101e8d4de..db5bfbb52e 100644 --- a/crates/sdk/src/stdin.rs +++ b/crates/sdk/src/stdin.rs @@ -4,18 +4,16 @@ use std::{ }; use openvm_circuit::arch::Streams; -use openvm_stark_backend::p3_field::FieldAlgebra; +use openvm_stark_backend::p3_field::Field; use serde::{Deserialize, Serialize}; -use crate::F; - #[derive(Clone, Default, Serialize, Deserialize)] -pub struct StdIn { +pub struct StdIn { pub buffer: VecDeque>, pub kv_store: HashMap, Vec>, } -impl StdIn { +impl StdIn { pub fn from_bytes(data: &[u8]) -> Self { let mut ret = Self::default(); ret.write_bytes(data); @@ -45,8 +43,8 @@ impl StdIn { } } -impl From for Streams { - fn from(mut std_in: StdIn) -> Self { +impl From> for Streams { + fn from(mut std_in: StdIn) -> Self { let mut data = Vec::>::new(); while let Some(input) = std_in.read() { data.push(input); @@ -57,9 +55,9 @@ impl From for Streams { } } -impl From>> for StdIn { +impl From>> for StdIn { fn from(inputs: Vec>) -> Self { - let mut ret = StdIn::default(); + let mut ret = StdIn::::default(); for input in inputs { ret.write_field(&input); } diff --git a/crates/sdk/src/types.rs b/crates/sdk/src/types.rs index d83140a5ae..ba8dc9ed8e 100644 --- a/crates/sdk/src/types.rs +++ b/crates/sdk/src/types.rs @@ -10,7 +10,7 @@ use { crate::commit::CommitBytes, itertools::Itertools, openvm_native_recursion::halo2::{wrapper::EvmVerifierByteCode, Fr, RawEvmProof}, - std::iter::{once, repeat}, + std::iter::{once, repeat_n}, thiserror::Error, }; @@ -195,7 +195,7 @@ impl TryFrom for RawEvmProof { let user_public_values = user_public_values .into_iter() - .flat_map(|byte| once(byte).chain(repeat(0).take(31))) + .flat_map(|byte| once(byte).chain(repeat_n(0, 31))) .collect::>(); let mut ret = Vec::new(); diff --git a/crates/sdk/tests/integration_test.rs b/crates/sdk/tests/integration_test.rs index 9248fc5445..a3a77d4759 100644 --- a/crates/sdk/tests/integration_test.rs +++ b/crates/sdk/tests/integration_test.rs @@ -3,36 +3,32 @@ use std::{borrow::Borrow, path::PathBuf, sync::Arc}; use eyre::Result; use openvm_build::GuestOptions; use openvm_circuit::{ - arch::{ - hasher::poseidon2::vm_poseidon2_hasher, ContinuationVmProof, ExecutionError, - GenerationError, SingleSegmentVmExecutor, SystemConfig, VmConfig, VmExecutor, - }, - system::{memory::tree::public_values::UserPublicValuesProof, program::trace::VmCommittedExe}, + self, + arch::{ContinuationVmProof, ExecutionError, VirtualMachineError}, + system::program::trace::VmCommittedExe, + utils::test_system_config_with_continuations, }; use openvm_continuations::verifier::{ common::types::VmVerifierPvs, leaf::types::{LeafVmVerifierInput, UserPublicValuesRootProof}, }; -use openvm_native_circuit::{Native, NativeConfig}; +use openvm_native_circuit::{execute_program_with_config, NativeConfig, NativeCpuBuilder}; use openvm_native_compiler::{conversion::CompilerOptions, prelude::*}; -use openvm_native_recursion::types::InnerConfig; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; use openvm_sdk::{ codec::{Decode, Encode}, - config::{AggStarkConfig, AppConfig, SdkSystemConfig, SdkVmConfig}, + config::{AggStarkConfig, AppConfig, SdkSystemConfig, SdkVmConfig, SdkVmCpuBuilder}, keygen::AppProvingKey, Sdk, StdIn, }; -use openvm_stark_backend::{keygen::types::LinearConstraint, p3_matrix::Matrix}; use openvm_stark_sdk::{ config::{ baby_bear_poseidon2::{BabyBearPoseidon2Config, BabyBearPoseidon2Engine}, setup_tracing, FriParameters, }, - engine::{StarkEngine, StarkFriEngine}, - openvm_stark_backend::{p3_field::FieldAlgebra, Chip}, + openvm_stark_backend::p3_field::FieldAlgebra, p3_baby_bear::BabyBear, }; use openvm_transpiler::transpiler::Transpiler; @@ -65,7 +61,6 @@ use { }; type SC = BabyBearPoseidon2Config; -type C = InnerConfig; type F = BabyBear; const NUM_PUB_VALUES: usize = 16; @@ -91,47 +86,37 @@ fn verify_evm_halo2_proof_with_fallback( Ok(gas_cost) } -fn run_leaf_verifier>( - leaf_vm: &SingleSegmentVmExecutor, +fn run_leaf_verifier( + leaf_vm_config: &NativeConfig, leaf_committed_exe: Arc>, verifier_input: LeafVmVerifierInput, -) -> Result, ExecutionError> -where - VC::Executor: Chip, - VC::Periphery: Chip, -{ - let exe_result = leaf_vm.execute_and_compute_heights( - leaf_committed_exe.exe.clone(), +) -> Result, VirtualMachineError> { + assert!(leaf_vm_config.system.has_public_values_chip()); + let (output, _vm) = execute_program_with_config::( + leaf_committed_exe.exe.program.clone(), verifier_input.write_to_stream(), + NativeCpuBuilder, + leaf_vm_config.clone(), )?; - let runtime_pvs: Vec<_> = exe_result - .public_values - .iter() - .map(|v| v.unwrap()) - .collect(); - Ok(runtime_pvs) + Ok(output.system_records.public_values) } fn app_committed_exe_for_test(app_log_blowup: usize) -> Arc> { - let program = { - let n = 200; - let mut builder = Builder::::default(); - let a: Felt = builder.eval(F::ZERO); - let b: Felt = builder.eval(F::ONE); - let c: Felt = builder.uninit(); - builder.range(0, n).for_each(|_, builder| { - builder.assign(&c, a + b); - builder.assign(&a, b); - builder.assign(&b, c); - }); - builder.halt(); - builder.compile_isa() - }; - Sdk::new() - .commit_app_exe( - FriParameters::new_for_testing(app_log_blowup), - program.into(), + let sdk = Sdk::new(); + let mut pkg_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).to_path_buf(); + pkg_dir.push("guest/fib"); + let vm_config = app_vm_config_for_test(); + let elf = sdk + .build( + Default::default(), + &vm_config, + pkg_dir, + &Default::default(), + None, ) + .unwrap(); + let exe = sdk.transpile(elf, vm_config.transpiler()).unwrap(); + sdk.commit_app_exe(FriParameters::new_for_testing(app_log_blowup), exe) .unwrap() } @@ -162,16 +147,22 @@ fn agg_stark_config_for_test() -> AggStarkConfig { } } -fn small_test_app_config(app_log_blowup: usize) -> AppConfig { +fn app_vm_config_for_test() -> SdkVmConfig { + let config = test_system_config_with_continuations() + .with_max_segment_len(200) + .with_public_values(NUM_PUB_VALUES); + SdkVmConfig::builder() + .system(SdkSystemConfig { config }) + .rv32i(Default::default()) + .rv32m(Default::default()) + .io(Default::default()) + .build() +} + +fn small_test_app_config(app_log_blowup: usize) -> AppConfig { AppConfig { app_fri_params: FriParameters::new_for_testing(app_log_blowup).into(), - app_vm_config: NativeConfig::new( - SystemConfig::default() - .with_max_segment_len(200) - .with_continuations() - .with_public_values(NUM_PUB_VALUES), - Native, - ), + app_vm_config: app_vm_config_for_test(), leaf_fri_params: FriParameters::new_for_testing(LEAF_LOG_BLOWUP).into(), compiler_options: CompilerOptions { enable_cycle_tracker: true, @@ -181,40 +172,38 @@ fn small_test_app_config(app_log_blowup: usize) -> AppConfig { } #[test] -fn test_public_values_and_leaf_verification() { - let app_log_blowup = 3; +fn test_public_values_and_leaf_verification() -> eyre::Result<()> { + setup_tracing(); + let app_log_blowup = 1; let app_config = small_test_app_config(app_log_blowup); - let app_pk = AppProvingKey::keygen(app_config); + let app_pk = Arc::new(AppProvingKey::keygen(app_config)?); let app_committed_exe = app_committed_exe_for_test(app_log_blowup); + let pc_start = app_committed_exe.exe.pc_start; let agg_stark_config = agg_stark_config_for_test(); let leaf_vm_config = agg_stark_config.leaf_vm_config(); - let leaf_vm = SingleSegmentVmExecutor::new(leaf_vm_config); let leaf_committed_exe = app_pk.leaf_committed_exe.clone(); - let app_engine = BabyBearPoseidon2Engine::new(app_pk.app_vm_pk.fri_params); - let app_vm = VmExecutor::new(app_pk.app_vm_pk.vm_config.clone()); - let app_vm_result = app_vm - .execute_and_generate_with_cached_program(app_committed_exe.clone(), vec![]) - .unwrap(); - assert!(app_vm_result.per_segment.len() > 2); + let sdk = Sdk::new(); + let mut app_proof = sdk.generate_app_proof( + SdkVmCpuBuilder, + app_pk, + app_committed_exe.clone(), + StdIn::default(), + )?; - let mut app_vm_seg_proofs: Vec<_> = app_vm_result - .per_segment - .into_iter() - .map(|proof_input| app_engine.prove(&app_pk.app_vm_pk.vm_pk, proof_input)) - .collect(); - let app_last_proof = app_vm_seg_proofs.pop().unwrap(); + assert!(app_proof.per_segment.len() > 2); + let app_last_proof = app_proof.per_segment.pop().unwrap(); let expected_app_commit: [F; DIGEST_SIZE] = app_committed_exe.get_program_commit().into(); // Verify all segments except the last one. let (first_seg_final_pc, first_seg_final_mem_root) = { let runtime_pvs = run_leaf_verifier( - &leaf_vm, + &leaf_vm_config, leaf_committed_exe.clone(), LeafVmVerifierInput { - proofs: app_vm_seg_proofs.clone(), + proofs: app_proof.per_segment.clone(), public_values_root_proof: None, }, ) @@ -224,25 +213,23 @@ fn test_public_values_and_leaf_verification() { assert_eq!(leaf_vm_pvs.app_commit, expected_app_commit); assert_eq!(leaf_vm_pvs.connector.is_terminate, F::ZERO); - assert_eq!(leaf_vm_pvs.connector.initial_pc, F::ZERO); + assert_eq!( + leaf_vm_pvs.connector.initial_pc, + F::from_canonical_u32(pc_start) + ); ( leaf_vm_pvs.connector.final_pc, leaf_vm_pvs.memory.final_root, ) }; - let pv_proof = UserPublicValuesProof::compute( - app_vm.config.system.memory_config.memory_dimensions(), - NUM_PUB_VALUES, - &vm_poseidon2_hasher(), - app_vm_result.final_memory.as_ref().unwrap(), - ); + let pv_proof = app_proof.user_public_values; let pv_root_proof = UserPublicValuesRootProof::extract(&pv_proof); // Verify the last segment with the correct public values root proof. { let runtime_pvs = run_leaf_verifier( - &leaf_vm, + &leaf_vm_config, leaf_committed_exe.clone(), LeafVmVerifierInput { proofs: vec![app_last_proof.clone()], @@ -268,7 +255,7 @@ fn test_public_values_and_leaf_verification() { let mut wrong_pv_root_proof = pv_root_proof.clone(); wrong_pv_root_proof.public_values_commit[0] += F::ONE; let execution_result = run_leaf_verifier( - &leaf_vm, + &leaf_vm_config, leaf_committed_exe.clone(), LeafVmVerifierInput { proofs: vec![app_last_proof.clone()], @@ -276,7 +263,10 @@ fn test_public_values_and_leaf_verification() { }, ); assert!( - matches!(execution_result, Err(ExecutionError::Fail { .. })), + matches!( + execution_result, + Err(VirtualMachineError::Execution(ExecutionError::Fail { .. })) + ), "Expected failure: the public value root proof has a wrong pv commit: {:?}", execution_result ); @@ -287,7 +277,7 @@ fn test_public_values_and_leaf_verification() { let mut wrong_pv_root_proof = pv_root_proof.clone(); wrong_pv_root_proof.sibling_hashes[0][0] += F::ONE; let execution_result = run_leaf_verifier( - &leaf_vm, + &leaf_vm_config, leaf_committed_exe.clone(), LeafVmVerifierInput { proofs: vec![app_last_proof.clone()], @@ -295,15 +285,20 @@ fn test_public_values_and_leaf_verification() { }, ); assert!( - matches!(execution_result, Err(ExecutionError::Fail { .. })), + matches!( + execution_result, + Err(VirtualMachineError::Execution(ExecutionError::Fail { .. })) + ), "Expected failure: the public value root proof has a wrong path proof: {:?}", execution_result ); } + Ok(()) } #[cfg(feature = "evm-verify")] #[test] +#[ignore = "slow"] fn test_static_verifier_custom_pv_handler() { // Define custom public values handler and implement StaticVerifierPvHandler trait on it pub struct CustomPvHandler { @@ -386,6 +381,7 @@ fn test_static_verifier_custom_pv_handler() { let evm_proof = sdk .generate_evm_proof( ¶ms_reader, + SdkVmCpuBuilder, Arc::new(app_pk), app_committed_exe, agg_pk, @@ -403,33 +399,8 @@ fn test_static_verifier_custom_pv_handler() { #[cfg(feature = "evm-verify")] #[test] fn test_e2e_proof_generation_and_verification_with_pvs() { - let mut pkg_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).to_path_buf(); - pkg_dir.push("guest/fib"); - - let vm_config = SdkVmConfig::builder() - .system(SdkSystemConfig { - config: SystemConfig::default() - .with_max_segment_len(200) - .with_continuations() - .with_public_values(NUM_PUB_VALUES), - }) - .rv32i(Default::default()) - .rv32m(Default::default()) - .io(Default::default()) - .native(Default::default()) - .build(); - + let vm_config = app_vm_config_for_test(); let sdk = Sdk::new(); - let elf = sdk - .build( - Default::default(), - &vm_config, - pkg_dir, - &Default::default(), - None, - ) - .unwrap(); - let exe = sdk.transpile(elf, vm_config.transpiler()).unwrap(); let app_log_blowup = 1; let app_fri_params = FriParameters::new_for_testing(app_log_blowup); @@ -438,10 +409,7 @@ fn test_e2e_proof_generation_and_verification_with_pvs() { AppConfig::new_with_leaf_fri_params(app_fri_params, vm_config, leaf_fri_params); app_config.compiler_options.enable_cycle_tracker = true; - let app_committed_exe = sdk - .commit_app_exe(app_fri_params, exe) - .expect("failed to commit exe"); - + let app_committed_exe = app_committed_exe_for_test(app_log_blowup); let app_pk = sdk.app_keygen(app_config).unwrap(); let params_reader = CacheHalo2ParamsReader::new_with_default_params_dir(); @@ -460,6 +428,7 @@ fn test_e2e_proof_generation_and_verification_with_pvs() { let evm_proof = sdk .generate_evm_proof( ¶ms_reader, + SdkVmCpuBuilder, Arc::new(app_pk), app_committed_exe, agg_pk, @@ -475,25 +444,11 @@ fn test_e2e_proof_generation_and_verification_with_pvs() { #[test] fn test_sdk_guest_build_and_transpile() { let sdk = Sdk::new(); - let guest_opts = GuestOptions::default() - // .with_features(vec!["zkvm"]) - // .with_options(vec!["--release"]); - ; + let guest_opts = GuestOptions::default(); let mut pkg_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).to_path_buf(); pkg_dir.push("guest/fib"); - let vm_config = SdkVmConfig::builder() - .system(SdkSystemConfig { - config: SystemConfig::default() - .with_max_segment_len(200) - .with_continuations() - .with_public_values(NUM_PUB_VALUES), - }) - .rv32i(Default::default()) - .rv32m(Default::default()) - .io(Default::default()) - .native(Default::default()) - .build(); + let vm_config = app_vm_config_for_test(); let one = sdk .build( @@ -526,35 +481,18 @@ fn test_sdk_guest_build_and_transpile() { fn test_inner_proof_codec_roundtrip() -> eyre::Result<()> { // generate a proof let sdk = Sdk::new(); - let mut pkg_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).to_path_buf(); - pkg_dir.push("guest/fib"); - - let vm_config = SdkVmConfig::builder() - .system(SdkSystemConfig { - config: SystemConfig::default() - .with_max_segment_len(200) - .with_continuations() - .with_public_values(NUM_PUB_VALUES), - }) - .rv32i(Default::default()) - .rv32m(Default::default()) - .io(Default::default()) - .native(Default::default()) - .build(); - let elf = sdk.build( - Default::default(), - &vm_config, - pkg_dir, - &Default::default(), - None, - )?; + let vm_config = app_vm_config_for_test(); assert!(vm_config.system.config.continuation_enabled); - let exe = sdk.transpile(elf, vm_config.transpiler())?; let fri_params = FriParameters::standard_fast(); let app_config = AppConfig::new(fri_params, vm_config); - let committed_exe = sdk.commit_app_exe(fri_params, exe)?; + let committed_exe = app_committed_exe_for_test(fri_params.log_blowup); let app_pk = Arc::new(sdk.app_keygen(app_config)?); - let app_proof = sdk.generate_app_proof(app_pk.clone(), committed_exe, StdIn::default())?; + let app_proof = sdk.generate_app_proof( + SdkVmCpuBuilder, + app_pk.clone(), + committed_exe, + StdIn::default(), + )?; let mut app_proof_bytes = Vec::new(); app_proof.encode(&mut app_proof_bytes)?; let decoded_app_proof = ContinuationVmProof::decode(&mut &app_proof_bytes[..])?; @@ -567,59 +505,3 @@ fn test_inner_proof_codec_roundtrip() -> eyre::Result<()> { sdk.verify_app_proof(&app_pk.get_app_vk(), &decoded_app_proof)?; Ok(()) } - -#[test] -fn test_segmentation_retry() { - setup_tracing(); - let app_log_blowup = 3; - let app_config = small_test_app_config(app_log_blowup); - let app_pk = AppProvingKey::keygen(app_config); - let app_committed_exe = app_committed_exe_for_test(app_log_blowup); - - let app_vm = VmExecutor::new(app_pk.app_vm_pk.vm_config.clone()); - let app_vm_result = app_vm - .execute_and_generate_with_cached_program(app_committed_exe.clone(), vec![]) - .unwrap(); - assert!(app_vm_result.per_segment.len() > 2); - - let total_height: usize = app_vm_result.per_segment[0] - .per_air - .iter() - .map(|(_, input)| { - let main = input.raw.common_main.as_ref(); - main.map(|mat| mat.height()).unwrap_or(0) - }) - .sum(); - - // Re-run with a threshold that will be violated. - let mut app_vm = VmExecutor::new(app_pk.app_vm_pk.vm_config.clone()); - let num_airs = app_pk.app_vm_pk.vm_pk.per_air.len(); - app_vm.set_trace_height_constraints(vec![LinearConstraint { - coefficients: vec![1; num_airs], - threshold: total_height as u32 - 1, - }]); - let app_vm_result = - app_vm.execute_and_generate_with_cached_program(app_committed_exe.clone(), vec![]); - assert!(matches!( - app_vm_result, - Err(GenerationError::TraceHeightsLimitExceeded) - )); - - // Try lowering segmentation threshold. - let config = VmConfig::::system_mut(&mut app_vm.config); - config.set_segmentation_strategy(config.segmentation_strategy.stricter_strategy()); - let app_vm_result = app_vm - .execute_and_generate_with_cached_program(app_committed_exe.clone(), vec![]) - .unwrap(); - - // New max height should indeed by smaller. - let new_total_height: usize = app_vm_result.per_segment[0] - .per_air - .iter() - .map(|(_, input)| { - let main = input.raw.common_main.as_ref(); - main.map(|mat| mat.height()).unwrap_or(0) - }) - .sum(); - assert!(new_total_height < total_height); -} diff --git a/crates/toolchain/build/src/lib.rs b/crates/toolchain/build/src/lib.rs index 749a29d346..e15941faa0 100644 --- a/crates/toolchain/build/src/lib.rs +++ b/crates/toolchain/build/src/lib.rs @@ -21,7 +21,13 @@ mod config; /// The rustc compiler [target](https://doc.rust-lang.org/rustc/targets/index.html). pub const RUSTC_TARGET: &str = "riscv32im-risc0-zkvm-elf"; -const RUSTUP_TOOLCHAIN_NAME: &str = "nightly-2025-02-14"; +/// The default Rust toolchain name to use if OPENVM_RUST_TOOLCHAIN is not set +pub const DEFAULT_RUSTUP_TOOLCHAIN_NAME: &str = "nightly-2025-02-14"; + +/// Get the Rust toolchain name from environment variable or default +pub fn get_rustup_toolchain_name() -> String { + env::var("OPENVM_RUST_TOOLCHAIN").unwrap_or_else(|_| DEFAULT_RUSTUP_TOOLCHAIN_NAME.to_string()) +} const BUILD_LOCKED_ENV: &str = "OPENVM_BUILD_LOCKED"; const SKIP_BUILD_ENV: &str = "OPENVM_SKIP_BUILD"; const GUEST_LOGFILE_ENV: &str = "OPENVM_GUEST_LOGFILE"; @@ -240,7 +246,7 @@ fn sanitized_cmd(tool: &str) -> Command { /// Creates a std::process::Command to execute the given cargo /// command in an environment suitable for targeting the zkvm guest. pub fn cargo_command(subcmd: &str, rust_flags: &[&str]) -> Command { - let toolchain = format!("+{RUSTUP_TOOLCHAIN_NAME}"); + let toolchain = format!("+{}", get_rustup_toolchain_name()); let rustc = sanitized_cmd("rustup") .args([&toolchain, "which", "rustc"]) @@ -382,7 +388,7 @@ pub fn build_generic(guest_opts: &GuestOptions) -> Result> // Check if the required toolchain and rust-src component are installed, and if not, install // them. This requires that `rustup` is installed. - if let Err(code) = ensure_toolchain_installed(RUSTUP_TOOLCHAIN_NAME, &["rust-src"]) { + if let Err(code) = ensure_toolchain_installed(&get_rustup_toolchain_name(), &["rust-src"]) { eprintln!("rustup toolchain commands failed. Please ensure rustup is installed (https://www.rust-lang.org/tools/install)"); return Err(Some(code)); } diff --git a/crates/toolchain/instructions/src/exe.rs b/crates/toolchain/instructions/src/exe.rs index fb84ec7da5..9db5f242ac 100644 --- a/crates/toolchain/instructions/src/exe.rs +++ b/crates/toolchain/instructions/src/exe.rs @@ -5,8 +5,9 @@ use serde::{Deserialize, Serialize}; use crate::program::Program; -/// Memory image is a map from (address space, address) to word. -pub type MemoryImage = BTreeMap<(u32, u32), F>; +// TODO[jpw]: delete this +/// Memory image is a map from (address space, address * size_of) to u8. +pub type SparseMemoryImage = BTreeMap<(u32, u32), u8>; /// Stores the starting address, end address, and name of a set of function. pub type FnBounds = BTreeMap; @@ -22,7 +23,7 @@ pub struct VmExe { /// Start address of pc. pub pc_start: u32, /// Initial memory image. - pub init_memory: MemoryImage, + pub init_memory: SparseMemoryImage, /// Starting + ending bounds for each function. pub fn_bounds: FnBounds, } @@ -40,7 +41,7 @@ impl VmExe { self.pc_start = pc_start; self } - pub fn with_init_memory(mut self, init_memory: MemoryImage) -> Self { + pub fn with_init_memory(mut self, init_memory: SparseMemoryImage) -> Self { self.init_memory = init_memory; self } diff --git a/crates/toolchain/instructions/src/lib.rs b/crates/toolchain/instructions/src/lib.rs index c251e77d0d..76e7200cbb 100644 --- a/crates/toolchain/instructions/src/lib.rs +++ b/crates/toolchain/instructions/src/lib.rs @@ -18,6 +18,8 @@ pub mod utils; pub use phantom::*; +pub const NATIVE_AS: u32 = 4; + pub trait LocalOpcode { const CLASS_OFFSET: usize; /// Convert from the discriminant of the enum to the typed enum variant. @@ -25,8 +27,11 @@ pub trait LocalOpcode { fn from_usize(value: usize) -> Self; fn local_usize(&self) -> usize; + fn global_opcode_usize(&self) -> usize { + self.local_usize() + Self::CLASS_OFFSET + } fn global_opcode(&self) -> VmOpcode { - VmOpcode::from_usize(self.local_usize() + Self::CLASS_OFFSET) + VmOpcode::from_usize(self.global_opcode_usize()) } } @@ -36,17 +41,19 @@ pub struct VmOpcode(usize); impl VmOpcode { /// Returns the corresponding `local_opcode_idx` - pub fn local_opcode_idx(&self, offset: usize) -> usize { + #[inline(always)] + pub const fn local_opcode_idx(&self, offset: usize) -> usize { self.as_usize() - offset } /// Returns the opcode as a usize - pub fn as_usize(&self) -> usize { + #[inline(always)] + pub const fn as_usize(&self) -> usize { self.0 } /// Create a new [VmOpcode] from a usize - pub fn from_usize(value: usize) -> Self { + pub const fn from_usize(value: usize) -> Self { Self(value) } diff --git a/crates/toolchain/instructions/src/program.rs b/crates/toolchain/instructions/src/program.rs index 010b70514d..73c901be8a 100644 --- a/crates/toolchain/instructions/src/program.rs +++ b/crates/toolchain/instructions/src/program.rs @@ -1,4 +1,8 @@ -use std::{fmt, fmt::Display}; +use std::{ + fmt::{self, Display}, + ops::Deref, + sync::Arc, +}; use itertools::Itertools; use openvm_stark_backend::p3_field::Field; @@ -24,37 +28,35 @@ pub struct Program { deserialize_with = "deserialize_instructions_and_debug_infos" )] pub instructions_and_debug_infos: Vec, Option)>>, - pub step: u32, pub pc_base: u32, } +#[derive(Clone, Debug, Default)] +pub struct ProgramDebugInfo { + inner: Arc>>, + pc_base: u32, +} + impl Program { - pub fn new_empty(step: u32, pc_base: u32) -> Self { + pub fn new_empty(pc_base: u32) -> Self { Self { instructions_and_debug_infos: vec![], - step, pc_base, } } - pub fn new_without_debug_infos( - instructions: &[Instruction], - step: u32, - pc_base: u32, - ) -> Self { + pub fn new_without_debug_infos(instructions: &[Instruction], pc_base: u32) -> Self { Self { instructions_and_debug_infos: instructions .iter() .map(|instruction| Some((instruction.clone(), None))) .collect(), - step, pc_base, } } pub fn new_without_debug_infos_with_option( instructions: &[Option>], - step: u32, pc_base: u32, ) -> Self { Self { @@ -62,7 +64,6 @@ impl Program { .iter() .map(|instruction| instruction.clone().map(|instruction| (instruction, None))) .collect(), - step, pc_base, } } @@ -79,7 +80,6 @@ impl Program { .zip_eq(debug_infos.iter()) .map(|(instruction, debug_info)| Some((instruction.clone(), debug_info.clone()))) .collect(), - step: DEFAULT_PC_STEP, pc_base: 0, } } @@ -96,7 +96,7 @@ impl Program { } pub fn from_instructions(instructions: &[Instruction]) -> Self { - Self::new_without_debug_infos(instructions, DEFAULT_PC_STEP, 0) + Self::new_without_debug_infos(instructions, 0) } pub fn len(&self) -> usize { @@ -120,14 +120,6 @@ impl Program { self.defined_instructions().len() } - pub fn debug_infos(&self) -> Vec> { - self.instructions_and_debug_infos - .iter() - .flatten() - .map(|(_, debug_info)| debug_info.clone()) - .collect() - } - pub fn enumerate_by_pc(&self) -> Vec<(u32, Instruction, Option)> { self.instructions_and_debug_infos .iter() @@ -135,7 +127,7 @@ impl Program { .flat_map(|(index, option)| { option.clone().map(|(instruction, debug_info)| { ( - self.pc_base + (self.step * (index as u32)), + self.pc_base + (DEFAULT_PC_STEP * (index as u32)), instruction, debug_info, ) @@ -172,6 +164,21 @@ impl Program { .extend(other.instructions_and_debug_infos); } } + +impl Program { + pub fn debug_infos(&self) -> ProgramDebugInfo { + let debug_infos = self + .instructions_and_debug_infos + .iter() + .map(|opt| opt.as_ref().and_then(|(_, debug_info)| debug_info.clone())) + .collect(); + ProgramDebugInfo { + inner: Arc::new(debug_infos), + pc_base: self.pc_base, + } + } +} + impl Display for Program { fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { for instruction in self.defined_instructions().iter() { @@ -195,6 +202,24 @@ impl Display for Program { } } +impl ProgramDebugInfo { + /// ## Panics + /// If `pc` is out of bounds. + pub fn get(&self, pc: u32) -> &Option { + let pc_base = self.pc_base; + let pc_idx = ((pc - pc_base) / DEFAULT_PC_STEP) as usize; + &self.inner[pc_idx] + } +} + +impl Deref for ProgramDebugInfo { + type Target = [Option]; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + pub fn display_program_with_pc(program: &Program) { for (pc, instruction) in program.defined_instructions().iter().enumerate() { let Instruction { @@ -257,7 +282,7 @@ mod tests { #[test] fn test_program_serde() { - let mut program = Program::::new_empty(4, 0); + let mut program = Program::::new_empty(0); program.instructions_and_debug_infos.push(Some(( Instruction::from_isize(VmOpcode::from_usize(113), 1, 2, 3, 4, 5), None, diff --git a/crates/toolchain/instructions/src/riscv.rs b/crates/toolchain/instructions/src/riscv.rs index b2998c4539..720b323d52 100644 --- a/crates/toolchain/instructions/src/riscv.rs +++ b/crates/toolchain/instructions/src/riscv.rs @@ -5,3 +5,5 @@ pub const RV32_CELL_BITS: usize = 8; pub const RV32_IMM_AS: u32 = 0; pub const RV32_REGISTER_AS: u32 = 1; pub const RV32_MEMORY_AS: u32 = 2; + +pub const RV32_NUM_REGISTERS: usize = 32; diff --git a/crates/toolchain/platform/src/alloc.rs b/crates/toolchain/platform/src/alloc.rs new file mode 100644 index 0000000000..0af25a3671 --- /dev/null +++ b/crates/toolchain/platform/src/alloc.rs @@ -0,0 +1,62 @@ +extern crate alloc; + +use alloc::alloc::{alloc, dealloc, handle_alloc_error, Layout}; +use core::ptr::NonNull; + +/// Bytes allocated according to the given Layout +pub struct AlignedBuf { + pub ptr: *mut u8, + pub layout: Layout, +} + +impl AlignedBuf { + /// Allocate a new buffer whose start address is aligned to `align` bytes. + /// *NOTE* if `len` is zero then a creates new `NonNull` that is dangling and 16-byte aligned. + pub fn uninit(len: usize, align: usize) -> Self { + let layout = Layout::from_size_align(len, align).unwrap(); + if layout.size() == 0 { + return Self { + ptr: NonNull::::dangling().as_ptr() as *mut u8, + layout, + }; + } + // SAFETY: `len` is nonzero + let ptr = unsafe { alloc(layout) }; + if ptr.is_null() { + handle_alloc_error(layout); + } + AlignedBuf { ptr, layout } + } + + /// Allocate a new buffer whose start address is aligned to `align` bytes + /// and copy the given data into it. + /// + /// # Safety + /// - `bytes` must not be null + /// - `len` should not be zero + /// + /// See [alloc]. In particular `data` should not be empty. + pub unsafe fn new(bytes: *const u8, len: usize, align: usize) -> Self { + let buf = Self::uninit(len, align); + // SAFETY: + // - src and dst are not null + // - src and dst are allocated for size + // - no alignment requirements on u8 + // - non-overlapping since ptr is newly allocated + unsafe { + core::ptr::copy_nonoverlapping(bytes, buf.ptr, len); + } + + buf + } +} + +impl Drop for AlignedBuf { + fn drop(&mut self) { + if self.layout.size() != 0 { + unsafe { + dealloc(self.ptr, self.layout); + } + } + } +} diff --git a/crates/toolchain/platform/src/lib.rs b/crates/toolchain/platform/src/lib.rs index 1ace328a66..2a0beedef1 100644 --- a/crates/toolchain/platform/src/lib.rs +++ b/crates/toolchain/platform/src/lib.rs @@ -4,12 +4,15 @@ #![deny(rustdoc::broken_intra_doc_links)] #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] -#[cfg(all(feature = "rust-runtime", target_os = "zkvm"))] +#[cfg(target_os = "zkvm")] pub use openvm_custom_insn::{custom_insn_i, custom_insn_r}; +#[cfg(target_os = "zkvm")] +pub mod alloc; #[cfg(all(feature = "rust-runtime", target_os = "zkvm"))] pub mod heap; #[cfg(all(feature = "export-libm", target_os = "zkvm"))] mod libm_extern; + pub mod memory; pub mod print; #[cfg(feature = "rust-runtime")] @@ -19,9 +22,6 @@ pub mod rust_rt; /// 4 bytes (i.e. 32 bits) as the zkVM is an implementation of the rv32im ISA. pub const WORD_SIZE: usize = core::mem::size_of::(); -/// Size of a zkVM memory page. -pub const PAGE_SIZE: usize = 1024; - /// Standard IO file descriptors for use with sys_read and sys_write. pub mod fileno { pub const STDIN: u32 = 0; diff --git a/crates/toolchain/tests/Cargo.toml b/crates/toolchain/tests/Cargo.toml index 9f3e3caa82..c2349b893f 100644 --- a/crates/toolchain/tests/Cargo.toml +++ b/crates/toolchain/tests/Cargo.toml @@ -8,11 +8,16 @@ homepage.workspace = true repository.workspace = true [dependencies] +openvm-build.workspace = true +openvm-circuit.workspace = true +openvm-transpiler.workspace = true +eyre.workspace = true +tempfile.workspace = true + +[dev-dependencies] openvm-stark-backend.workspace = true openvm-stark-sdk.workspace = true openvm-circuit = { workspace = true, features = ["test-utils"] } -openvm-transpiler.workspace = true -openvm-build.workspace = true openvm-algebra-transpiler.workspace = true openvm-bigint-circuit.workspace = true openvm-rv32im-circuit.workspace = true @@ -21,10 +26,8 @@ openvm-algebra-circuit.workspace = true openvm-ecc-circuit = { workspace = true } openvm-instructions = { workspace = true } openvm-platform = { workspace = true } - -eyre.workspace = true test-case.workspace = true -tempfile.workspace = true +rand = { workspace = true } serde = { workspace = true, features = ["alloc"] } derive_more = { workspace = true, features = ["from"] } @@ -36,4 +39,4 @@ default = ["parallel"] parallel = ["openvm-circuit/parallel"] [package.metadata.cargo-shear] -ignored = ["derive_more", "openvm-stark-backend"] +ignored = ["derive_more", "openvm-stark-backend", "rand"] diff --git a/crates/toolchain/tests/src/utils.rs b/crates/toolchain/tests/src/utils.rs deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/crates/toolchain/tests/tests/riscv_test_vectors.rs b/crates/toolchain/tests/tests/riscv_test_vectors.rs index 9516b0cd7b..9b0c2524e4 100644 --- a/crates/toolchain/tests/tests/riscv_test_vectors.rs +++ b/crates/toolchain/tests/tests/riscv_test_vectors.rs @@ -5,7 +5,7 @@ use openvm_circuit::{ arch::{instructions::exe::VmExe, VmExecutor}, utils::air_test, }; -use openvm_rv32im_circuit::Rv32ImConfig; +use openvm_rv32im_circuit::{Rv32ImConfig, Rv32ImCpuBuilder}; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; @@ -39,9 +39,10 @@ fn test_rv32im_riscv_vector_runtime() -> Result<()> { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension), )?; - let executor = VmExecutor::::new(config.clone()); - let res = executor.execute(exe, vec![])?; - Ok(res) + let executor = VmExecutor::new(config.clone())?; + let interpreter = executor.instance(&exe)?; + let _state = interpreter.execute(vec![], None)?; + Ok(()) }); match result { @@ -80,7 +81,7 @@ fn test_rv32im_riscv_vector_prove() -> Result<()> { )?; let result = std::panic::catch_unwind(|| { - air_test(config.clone(), exe); + air_test(Rv32ImCpuBuilder, config.clone(), exe); }); match result { diff --git a/crates/toolchain/tests/tests/transpiler_tests.rs b/crates/toolchain/tests/tests/transpiler_tests.rs index bf07eccc42..727ed0032f 100644 --- a/crates/toolchain/tests/tests/transpiler_tests.rs +++ b/crates/toolchain/tests/tests/transpiler_tests.rs @@ -5,28 +5,22 @@ use std::{ use eyre::Result; use num_bigint::BigUint; -use openvm_algebra_circuit::{ - Fp2Extension, Fp2ExtensionExecutor, Fp2ExtensionPeriphery, ModularExtension, - ModularExtensionExecutor, ModularExtensionPeriphery, -}; +use openvm_algebra_circuit::*; use openvm_algebra_transpiler::{Fp2TranspilerExtension, ModularTranspilerExtension}; -use openvm_bigint_circuit::{Int256, Int256Executor, Int256Periphery}; +use openvm_bigint_circuit::*; use openvm_circuit::{ arch::{InitFileGenerator, SystemConfig, VmExecutor}, derive::VmConfig, + system::SystemExecutor, utils::air_test, }; -use openvm_ecc_circuit::{SECP256K1_MODULUS, SECP256K1_ORDER}; +use openvm_ecc_circuit::SECP256K1_CONFIG; use openvm_instructions::exe::VmExe; use openvm_platform::memory::MEM_SIZE; -use openvm_rv32im_circuit::{ - Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32ImConfig, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, - Rv32M, Rv32MExecutor, Rv32MPeriphery, -}; +use openvm_rv32im_circuit::*; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; -use openvm_stark_backend::p3_field::PrimeField32; use openvm_stark_sdk::p3_baby_bear::BabyBear; use openvm_transpiler::{elf::Elf, transpiler::Transpiler, FromElf}; use serde::{Deserialize, Serialize}; @@ -80,14 +74,15 @@ fn test_rv32im_runtime(elf_path: &str) -> Result<()> { .with_extension(Rv32IoTranspilerExtension), )?; let config = Rv32ImConfig::default(); - let executor = VmExecutor::::new(config); - executor.execute(exe, vec![])?; + let executor = VmExecutor::new(config)?; + let interpreter = executor.instance(&exe)?; + interpreter.execute(vec![], None)?; Ok(()) } #[derive(Clone, Debug, VmConfig, Serialize, Deserialize)] pub struct Rv32ModularFp2Int256Config { - #[system] + #[config(executor = "SystemExecutor")] pub system: SystemConfig, #[extension] pub base: Rv32I, @@ -130,8 +125,14 @@ impl InitFileGenerator for Rv32ModularFp2Int256Config { #[test_case("tests/data/rv32im-intrin-from-as")] fn test_intrinsic_runtime(elf_path: &str) -> Result<()> { let config = Rv32ModularFp2Int256Config::new( - vec![SECP256K1_MODULUS.clone(), SECP256K1_ORDER.clone()], - vec![("Secp256k1Coord".to_string(), SECP256K1_MODULUS.clone())], + vec![ + SECP256K1_CONFIG.modulus.clone(), + SECP256K1_CONFIG.scalar.clone(), + ], + vec![( + SECP256K1_CONFIG.struct_name.clone(), + SECP256K1_CONFIG.modulus.clone(), + )], ); let elf = get_elf(elf_path)?; let openvm_exe = VmExe::from_elf( @@ -143,8 +144,9 @@ fn test_intrinsic_runtime(elf_path: &str) -> Result<()> { .with_extension(ModularTranspilerExtension) .with_extension(Fp2TranspilerExtension), )?; - let executor = VmExecutor::::new(config); - executor.execute(openvm_exe, vec![])?; + let executor = VmExecutor::new(config)?; + let interpreter = executor.instance(&openvm_exe)?; + interpreter.execute(vec![], None)?; Ok(()) } @@ -160,6 +162,6 @@ fn test_terminate_prove() -> Result<()> { .with_extension(Rv32IoTranspilerExtension) .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32ImCpuBuilder, config, openvm_exe); Ok(()) } diff --git a/crates/toolchain/transpiler/src/lib.rs b/crates/toolchain/transpiler/src/lib.rs index 367b028393..ee85e9b153 100644 --- a/crates/toolchain/transpiler/src/lib.rs +++ b/crates/toolchain/transpiler/src/lib.rs @@ -1,10 +1,7 @@ //! A transpiler from custom RISC-V ELFs to OpenVM executable binaries. use elf::Elf; -use openvm_instructions::{ - exe::VmExe, - program::{Program, DEFAULT_PC_STEP}, -}; +use openvm_instructions::{exe::VmExe, program::Program}; pub use openvm_platform; use openvm_stark_backend::p3_field::PrimeField32; use transpiler::{Transpiler, TranspilerError}; @@ -29,11 +26,7 @@ impl FromElf for VmExe { type ElfContext = Transpiler; fn from_elf(elf: Elf, transpiler: Self::ElfContext) -> Result { let instructions = transpiler.transpile(&elf.instructions)?; - let program = Program::new_without_debug_infos_with_option( - &instructions, - DEFAULT_PC_STEP, - elf.pc_base, - ); + let program = Program::new_without_debug_infos_with_option(&instructions, elf.pc_base); let init_memory = elf_memory_image_to_openvm_memory_image(elf.memory_image); Ok(VmExe { diff --git a/crates/toolchain/transpiler/src/util.rs b/crates/toolchain/transpiler/src/util.rs index d9135de153..c5711653ff 100644 --- a/crates/toolchain/transpiler/src/util.rs +++ b/crates/toolchain/transpiler/src/util.rs @@ -1,7 +1,7 @@ use std::collections::BTreeMap; use openvm_instructions::{ - exe::MemoryImage, + exe::SparseMemoryImage, instruction::Instruction, riscv::{RV32_MEMORY_AS, RV32_REGISTER_NUM_LIMBS}, utils::isize_to_field, @@ -165,17 +165,14 @@ pub fn nop() -> Instruction { } } -/// Converts our memory image (u32 -> [u8; 4]) into Vm memory image ((as, address) -> word) -pub fn elf_memory_image_to_openvm_memory_image( +/// Converts our memory image (u32 -> [u8; 4]) into Vm memory image ((as=2, address) -> byte) +pub fn elf_memory_image_to_openvm_memory_image( memory_image: BTreeMap, -) -> MemoryImage { - let mut result = MemoryImage::new(); +) -> SparseMemoryImage { + let mut result = SparseMemoryImage::new(); for (addr, word) in memory_image { for (i, byte) in word.to_le_bytes().into_iter().enumerate() { - result.insert( - (RV32_MEMORY_AS, addr + i as u32), - F::from_canonical_u8(byte), - ); + result.insert((RV32_MEMORY_AS, addr + i as u32), byte); } } result diff --git a/crates/vm/Cargo.toml b/crates/vm/Cargo.toml index 80e6794b48..c0db02dbeb 100644 --- a/crates/vm/Cargo.toml +++ b/crates/vm/Cargo.toml @@ -27,7 +27,6 @@ backtrace.workspace = true rand.workspace = true serde.workspace = true serde-big-array.workspace = true -cfg-if.workspace = true metrics = { workspace = true, optional = true } thiserror.workspace = true rustc-hash.workspace = true @@ -35,24 +34,41 @@ eyre.workspace = true derivative.workspace = true static_assertions.workspace = true getset.workspace = true +dashmap.workspace = true + +[target.'cfg(any(unix, windows))'.dependencies] +memmap2.workspace = true [dev-dependencies] test-log.workspace = true openvm-circuit = { workspace = true, features = ["test-utils"] } openvm-stark-sdk.workspace = true -openvm-native-circuit.workspace = true +openvm-native-circuit = { workspace = true, features = ["test-utils"] } openvm-native-compiler.workspace = true openvm-rv32im-transpiler.workspace = true [features] default = ["parallel", "jemalloc"] -parallel = ["openvm-stark-backend/parallel"] -test-utils = ["dep:openvm-stark-sdk"] -bench-metrics = ["dep:metrics", "openvm-stark-backend/bench-metrics"] -function-span = ["bench-metrics"] +parallel = [ + "openvm-stark-backend/parallel", + "dashmap/rayon", + "openvm-stark-sdk?/parallel", +] +metrics = [ + "dep:metrics", + "openvm-stark-backend/metrics", + "openvm-stark-sdk?/metrics", +] +# turns on more invasive profiling for fine-grained guest metrics +perf-metrics = ["metrics"] +# use basic memory instead of mmap: +basic-memory = [] +# turns on stark-backend debugger in all proofs +stark-debug = [] +test-utils = ["openvm-stark-sdk"] # performance features: mimalloc = ["openvm-stark-backend/mimalloc"] jemalloc = ["openvm-stark-backend/jemalloc"] jemalloc-prof = ["openvm-stark-backend/jemalloc-prof"] -nightly-features = ["openvm-stark-sdk/nightly-features"] +nightly-features = ["openvm-stark-sdk?/nightly-features"] diff --git a/crates/vm/derive/Cargo.toml b/crates/vm/derive/Cargo.toml index bd3c7cb693..d2d11dcc78 100644 --- a/crates/vm/derive/Cargo.toml +++ b/crates/vm/derive/Cargo.toml @@ -12,4 +12,5 @@ proc-macro = true [dependencies] syn = { version = "2.0", features = ["parsing"] } quote = "1.0" +proc-macro2 = "1.0" itertools = { workspace = true } diff --git a/crates/vm/derive/src/lib.rs b/crates/vm/derive/src/lib.rs index 37dca6e4ed..ec103fc1c9 100644 --- a/crates/vm/derive/src/lib.rs +++ b/crates/vm/derive/src/lib.rs @@ -4,15 +4,33 @@ extern crate proc_macro; use itertools::{multiunzip, Itertools}; use proc_macro::{Span, TokenStream}; use quote::{quote, ToTokens}; -use syn::{punctuated::Punctuated, Data, Fields, GenericParam, Ident, Meta, Token}; +use syn::{ + parse_quote, punctuated::Punctuated, spanned::Spanned, Data, DataStruct, Field, Fields, + GenericParam, Ident, Meta, Token, +}; -#[proc_macro_derive(InstructionExecutor)] -pub fn instruction_executor_derive(input: TokenStream) -> TokenStream { +#[proc_macro_derive(PreflightExecutor)] +pub fn preflight_executor_derive(input: TokenStream) -> TokenStream { let ast: syn::DeriveInput = syn::parse(input).unwrap(); let name = &ast.ident; let generics = &ast.generics; - let (impl_generics, ty_generics, _) = generics.split_for_impl(); + let (_, ty_generics, _) = generics.split_for_impl(); + + let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site()); + let mut new_generics = generics.clone(); + new_generics.params.push(syn::parse_quote! { RA }); + let field_ty_generic = generics + .params + .first() + .and_then(|param| match param { + GenericParam::Type(type_param) => Some(&type_param.ident), + _ => None, + }) + .unwrap_or_else(|| { + new_generics.params.push(syn::parse_quote! { F }); + &default_ty_generic + }); match &ast.data { Data::Struct(inner) => { @@ -27,21 +45,20 @@ pub fn instruction_executor_derive(input: TokenStream) -> TokenStream { _ => panic!("Only unnamed fields are supported"), }; // Use full path ::openvm_circuit... so it can be used either within or outside the vm - // crate. Assume F is already generic of the field. - let mut new_generics = generics.clone(); + // crate. let where_clause = new_generics.make_where_clause(); where_clause.predicates.push( - syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::InstructionExecutor }, + syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA> }, ); + let (impl_generics, _, where_clause) = new_generics.split_for_impl(); quote! { - impl #impl_generics ::openvm_circuit::arch::InstructionExecutor for #name #ty_generics #where_clause { + impl #impl_generics ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA> for #name #ty_generics #where_clause { fn execute( &mut self, - memory: &mut ::openvm_circuit::system::memory::MemoryController, - instruction: &::openvm_circuit::arch::instructions::instruction::Instruction, - from_state: ::openvm_circuit::arch::ExecutionState, - ) -> ::openvm_circuit::arch::Result<::openvm_circuit::arch::ExecutionState> { - self.0.execute(memory, instruction, from_state) + state: ::openvm_circuit::arch::VmStateMut<#field_ty_generic, ::openvm_circuit::system::memory::online::TracingMemory, RA>, + instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<#field_ty_generic>, + ) -> Result<(), ::openvm_circuit::arch::ExecutionError> { + self.0.execute(state, instruction) } fn get_opcode_name(&self, opcode: usize) -> String { @@ -64,37 +81,35 @@ pub fn instruction_executor_derive(input: TokenStream) -> TokenStream { (variant_name, field) }) .collect::>(); - let first_ty_generic = ast - .generics - .params - .first() - .and_then(|param| match param { - GenericParam::Type(type_param) => Some(&type_param.ident), - _ => None, - }) - .expect("First generic must be type for Field"); // Use full path ::openvm_circuit... so it can be used either within or outside the vm // crate. Assume F is already generic of the field. - let (execute_arms, get_opcode_name_arms): (Vec<_>, Vec<_>) = + let (execute_arms, get_opcode_name_arms, where_predicates): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(|(variant_name, field)| { let field_ty = &field.ty; let execute_arm = quote! { - #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InstructionExecutor<#first_ty_generic>>::execute(x, memory, instruction, from_state) + #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA>>::execute(x, state, instruction) }; let get_opcode_name_arm = quote! { - #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InstructionExecutor<#first_ty_generic>>::get_opcode_name(x, opcode) + #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA>>::get_opcode_name(x, opcode) }; - - (execute_arm, get_opcode_name_arm) + let where_predicate = syn::parse_quote! { + #field_ty: ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA> + }; + (execute_arm, get_opcode_name_arm, where_predicate) })); + let where_clause = new_generics.make_where_clause(); + for predicate in where_predicates { + where_clause.predicates.push(predicate); + } + // Don't use these ty_generics because it might have extra "F" + let (impl_generics, _, where_clause) = new_generics.split_for_impl(); quote! { - impl #impl_generics ::openvm_circuit::arch::InstructionExecutor<#first_ty_generic> for #name #ty_generics { + impl #impl_generics ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA> for #name #ty_generics #where_clause { fn execute( &mut self, - memory: &mut ::openvm_circuit::system::memory::MemoryController<#first_ty_generic>, - instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<#first_ty_generic>, - from_state: ::openvm_circuit::arch::ExecutionState, - ) -> ::openvm_circuit::arch::Result<::openvm_circuit::arch::ExecutionState> { + state: ::openvm_circuit::arch::VmStateMut<#field_ty_generic, ::openvm_circuit::system::memory::online::TracingMemory, RA>, + instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<#field_ty_generic>, + ) -> Result<(), ::openvm_circuit::arch::ExecutionError> { match self { #(#execute_arms,)* } @@ -113,6 +128,262 @@ pub fn instruction_executor_derive(input: TokenStream) -> TokenStream { } } +#[proc_macro_derive(Executor)] +pub fn executor_derive(input: TokenStream) -> TokenStream { + let ast: syn::DeriveInput = syn::parse(input).unwrap(); + + let name = &ast.ident; + let generics = &ast.generics; + let (impl_generics, ty_generics, _) = generics.split_for_impl(); + + match &ast.data { + Data::Struct(inner) => { + // Check if the struct has only one unnamed field + let inner_ty = match &inner.fields { + Fields::Unnamed(fields) => { + if fields.unnamed.len() != 1 { + panic!("Only one unnamed field is supported"); + } + fields.unnamed.first().unwrap().ty.clone() + } + _ => panic!("Only unnamed fields are supported"), + }; + // Use full path ::openvm_circuit... so it can be used either within or outside the vm + // crate. Assume F is already generic of the field. + let mut new_generics = generics.clone(); + let where_clause = new_generics.make_where_clause(); + where_clause + .predicates + .push(syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::Executor }); + quote! { + impl #impl_generics ::openvm_circuit::arch::Executor for #name #ty_generics #where_clause { + #[inline(always)] + fn pre_compute_size(&self) -> usize { + self.0.pre_compute_size() + } + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &::openvm_circuit::arch::instructions::instruction::Instruction, + data: &mut [u8], + ) -> Result<::openvm_circuit::arch::ExecuteFunc, ::openvm_circuit::arch::StaticProgramError> + where + Ctx: ::openvm_circuit::arch::execution_mode::E1ExecutionCtx, { + self.0.pre_compute(pc, inst, data) + } + } + } + .into() + } + Data::Enum(e) => { + let variants = e + .variants + .iter() + .map(|variant| { + let variant_name = &variant.ident; + + let mut fields = variant.fields.iter(); + let field = fields.next().unwrap(); + assert!(fields.next().is_none(), "Only one field is supported"); + (variant_name, field) + }) + .collect::>(); + let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site()); + let mut new_generics = generics.clone(); + let first_ty_generic = ast + .generics + .params + .first() + .and_then(|param| match param { + GenericParam::Type(type_param) => Some(&type_param.ident), + _ => None, + }) + .unwrap_or_else(|| { + new_generics.params.push(syn::parse_quote! { F }); + &default_ty_generic + }); + // Use full path ::openvm_circuit... so it can be used either within or outside the vm + // crate. Assume F is already generic of the field. + let (pre_compute_size_arms, pre_compute_arms, where_predicates): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(|(variant_name, field)| { + let field_ty = &field.ty; + let pre_compute_size_arm = quote! { + #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::Executor<#first_ty_generic>>::pre_compute_size(x) + }; + let pre_compute_arm = quote! { + #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::Executor<#first_ty_generic>>::pre_compute(x, pc, instruction, data) + }; + let where_predicate = syn::parse_quote! { + #field_ty: ::openvm_circuit::arch::Executor<#first_ty_generic> + }; + (pre_compute_size_arm, pre_compute_arm, where_predicate) + })); + let where_clause = new_generics.make_where_clause(); + for predicate in where_predicates { + where_clause.predicates.push(predicate); + } + // Don't use these ty_generics because it might have extra "F" + let (impl_generics, _, where_clause) = new_generics.split_for_impl(); + + quote! { + impl #impl_generics ::openvm_circuit::arch::Executor<#first_ty_generic> for #name #ty_generics #where_clause { + #[inline(always)] + fn pre_compute_size(&self) -> usize { + match self { + #(#pre_compute_size_arms,)* + } + } + + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + instruction: &::openvm_circuit::arch::instructions::instruction::Instruction, + data: &mut [u8], + ) -> Result<::openvm_circuit::arch::ExecuteFunc, ::openvm_circuit::arch::StaticProgramError> + where + Ctx: ::openvm_circuit::arch::execution_mode::E1ExecutionCtx, { + match self { + #(#pre_compute_arms,)* + } + } + } + } + .into() + } + Data::Union(_) => unimplemented!("Unions are not supported"), + } +} + +#[proc_macro_derive(MeteredExecutor)] +pub fn metered_executor_derive(input: TokenStream) -> TokenStream { + let ast: syn::DeriveInput = syn::parse(input).unwrap(); + + let name = &ast.ident; + let generics = &ast.generics; + let (impl_generics, ty_generics, _) = generics.split_for_impl(); + + match &ast.data { + Data::Struct(inner) => { + // Check if the struct has only one unnamed field + let inner_ty = match &inner.fields { + Fields::Unnamed(fields) => { + if fields.unnamed.len() != 1 { + panic!("Only one unnamed field is supported"); + } + fields.unnamed.first().unwrap().ty.clone() + } + _ => panic!("Only unnamed fields are supported"), + }; + // Use full path ::openvm_circuit... so it can be used either within or outside the vm + // crate. Assume F is already generic of the field. + let mut new_generics = generics.clone(); + let where_clause = new_generics.make_where_clause(); + where_clause + .predicates + .push(syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::MeteredExecutor }); + quote! { + impl #impl_generics ::openvm_circuit::arch::MeteredExecutor for #name #ty_generics #where_clause { + #[inline(always)] + fn metered_pre_compute_size(&self) -> usize { + self.0.metered_pre_compute_size() + } + #[inline(always)] + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &::openvm_circuit::arch::instructions::instruction::Instruction, + data: &mut [u8], + ) -> Result<::openvm_circuit::arch::ExecuteFunc, ::openvm_circuit::arch::StaticProgramError> + where + Ctx: ::openvm_circuit::arch::execution_mode::E2ExecutionCtx, { + self.0.metered_pre_compute(chip_idx, pc, inst, data) + } + } + } + .into() + } + Data::Enum(e) => { + let variants = e + .variants + .iter() + .map(|variant| { + let variant_name = &variant.ident; + + let mut fields = variant.fields.iter(); + let field = fields.next().unwrap(); + assert!(fields.next().is_none(), "Only one field is supported"); + (variant_name, field) + }) + .collect::>(); + let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site()); + let mut new_generics = generics.clone(); + let first_ty_generic = ast + .generics + .params + .first() + .and_then(|param| match param { + GenericParam::Type(type_param) => Some(&type_param.ident), + _ => None, + }) + .unwrap_or_else(|| { + new_generics.params.push(syn::parse_quote! { F }); + &default_ty_generic + }); + // Use full path ::openvm_circuit... so it can be used either within or outside the vm + // crate. Assume F is already generic of the field. + let (pre_compute_size_arms, metered_pre_compute_arms, where_predicates): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(|(variant_name, field)| { + let field_ty = &field.ty; + let pre_compute_size_arm = quote! { + #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::MeteredExecutor<#first_ty_generic>>::metered_pre_compute_size(x) + }; + let metered_pre_compute_arm = quote! { + #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::MeteredExecutor<#first_ty_generic>>::metered_pre_compute(x, chip_idx, pc, instruction, data) + }; + let where_predicate = syn::parse_quote! { + #field_ty: ::openvm_circuit::arch::MeteredExecutor<#first_ty_generic> + }; + (pre_compute_size_arm, metered_pre_compute_arm, where_predicate) + })); + let where_clause = new_generics.make_where_clause(); + for predicate in where_predicates { + where_clause.predicates.push(predicate); + } + // Don't use these ty_generics because it might have extra "F" + let (impl_generics, _, where_clause) = new_generics.split_for_impl(); + + quote! { + impl #impl_generics ::openvm_circuit::arch::MeteredExecutor<#first_ty_generic> for #name #ty_generics #where_clause { + #[inline(always)] + fn metered_pre_compute_size(&self) -> usize { + match self { + #(#pre_compute_size_arms,)* + } + } + + #[inline(always)] + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + instruction: &::openvm_circuit::arch::instructions::instruction::Instruction, + data: &mut [u8], + ) -> Result<::openvm_circuit::arch::ExecuteFunc, ::openvm_circuit::arch::StaticProgramError> + where + Ctx: ::openvm_circuit::arch::execution_mode::E2ExecutionCtx, { + match self { + #(#metered_pre_compute_arms,)* + } + } + } + } + .into() + } + Data::Union(_) => unimplemented!("Unions are not supported"), + } +} + /// Derives `AnyEnum` trait on an enum type. /// By default an enum arm will just return `self` as `&dyn Any`. /// @@ -189,18 +460,23 @@ pub fn any_enum_derive(input: TokenStream) -> TokenStream { } } -// VmConfig derive macro -#[derive(Debug)] -enum Source { - System(Ident), - Config(Ident), -} - -#[proc_macro_derive(VmConfig, attributes(system, config, extension))] +#[proc_macro_derive(VmConfig, attributes(config, extension))] pub fn vm_generic_config_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let ast = syn::parse_macro_input!(input as syn::DeriveInput); let name = &ast.ident; + match &ast.data { + syn::Data::Struct(inner) => match generate_config_traits_impl(name, inner) { + Ok(tokens) => tokens, + Err(err) => err.to_compile_error().into(), + }, + _ => syn::Error::new(name.span(), "Only structs are supported") + .to_compile_error() + .into(), + } +} + +fn generate_config_traits_impl(name: &Ident, inner: &DataStruct) -> syn::Result { let gen_name_with_uppercase_idents = |ident: &Ident| { let mut name = ident.to_string().chars().collect::>(); assert!(name[0].is_lowercase(), "Field name must not be capitalized"); @@ -210,180 +486,217 @@ pub fn vm_generic_config_derive(input: proc_macro::TokenStream) -> proc_macro::T (res_lower, res_upper) }; - match &ast.data { - syn::Data::Struct(inner) => { - let fields = match &inner.fields { - Fields::Named(named) => named.named.iter().collect(), - Fields::Unnamed(_) => { - return syn::Error::new(name.span(), "Only named fields are supported") - .to_compile_error() - .into(); - } - Fields::Unit => vec![], - }; + let fields = match &inner.fields { + Fields::Named(named) => named.named.iter().collect(), + Fields::Unnamed(_) => { + return Err(syn::Error::new( + name.span(), + "Only named fields are supported", + )) + } + Fields::Unit => vec![], + }; - let source = fields - .iter() - .filter_map(|f| { - if f.attrs.iter().any(|attr| attr.path().is_ident("system")) { - Some(Source::System(f.ident.clone().unwrap())) - } else if f.attrs.iter().any(|attr| attr.path().is_ident("config")) { - Some(Source::Config(f.ident.clone().unwrap())) - } else { - None - } - }) - .exactly_one() - .expect("Exactly one field must have #[system] or #[config] attribute"); - let (source_name, source_name_upper) = match &source { - Source::System(ident) | Source::Config(ident) => { - gen_name_with_uppercase_idents(ident) - } - }; + let source_field = fields + .iter() + .filter(|f| f.attrs.iter().any(|attr| attr.path().is_ident("config"))) + .exactly_one() + .clone() + .expect("Exactly one field must have the #[config] attribute"); + let (source_name, source_name_upper) = + gen_name_with_uppercase_idents(source_field.ident.as_ref().unwrap()); - let extensions = fields - .iter() - .filter(|f| f.attrs.iter().any(|attr| attr.path().is_ident("extension"))) - .cloned() - .collect::>(); + let extensions = fields + .iter() + .filter(|f| f.attrs.iter().any(|attr| attr.path().is_ident("extension"))) + .cloned() + .collect::>(); - let mut executor_enum_fields = Vec::new(); - let mut periphery_enum_fields = Vec::new(); - let mut create_chip_complex = Vec::new(); - for &e in extensions.iter() { - let (field_name, field_name_upper) = - gen_name_with_uppercase_idents(&e.ident.clone().unwrap()); - // TRACKING ISSUE: - // We cannot just use >::Executor because of this: - let mut executor_name = Ident::new( - &format!("{}Executor", e.ty.to_token_stream()), - Span::call_site().into(), - ); - let mut periphery_name = Ident::new( - &format!("{}Periphery", e.ty.to_token_stream()), - Span::call_site().into(), - ); - if let Some(attr) = e - .attrs - .iter() - .find(|attr| attr.path().is_ident("extension")) - { - match attr.meta { - Meta::Path(_) => {} - Meta::NameValue(_) => { - return syn::Error::new( - name.span(), - "Only `#[extension]` or `#[extension(...)] formats are supported", - ) - .to_compile_error() - .into() - } - _ => { - let nested = attr - .parse_args_with(Punctuated::::parse_terminated) - .unwrap(); - for meta in nested { - match meta { - Meta::NameValue(nv) => { - if nv.path.is_ident("executor") { - executor_name = Ident::new( - &nv.value.to_token_stream().to_string(), - Span::call_site().into(), - ); - Ok(()) - } else if nv.path.is_ident("periphery") { - periphery_name = Ident::new( - &nv.value.to_token_stream().to_string(), - Span::call_site().into(), - ); - Ok(()) - } else { - Err("only executor and periphery keys are supported") - } - } - _ => Err("only name = value format is supported"), - } - .expect("wrong attributes format"); - } - } - } - }; - executor_enum_fields.push(quote! { - #[any_enum] - #field_name_upper(#executor_name), - }); - periphery_enum_fields.push(quote! { - #[any_enum] - #field_name_upper(#periphery_name), - }); - create_chip_complex.push(quote! { - let complex: ::openvm_circuit::arch::VmChipComplex = complex.extend(&self.#field_name)?; - }); - } + let mut executor_enum_fields = Vec::new(); + let mut create_executors = Vec::new(); + let mut create_airs = Vec::new(); + let mut execution_where_predicates: Vec = Vec::new(); + let mut circuit_where_predicates: Vec = Vec::new(); - let (source_executor_type, source_periphery_type) = match &source { - Source::System(_) => ( - quote! { ::openvm_circuit::arch::SystemExecutor }, - quote! { ::openvm_circuit::arch::SystemPeriphery }, - ), - Source::Config(field_ident) => { - let field_type = fields - .iter() - .find(|f| f.ident.as_ref() == Some(field_ident)) - .map(|f| &f.ty) - .expect("Field not found"); + let source_field_ty = source_field.ty.clone(); - let executor_type = format!("{}Executor", quote!(#field_type)); - let periphery_type = format!("{}Periphery", quote!(#field_type)); + for e in extensions.iter() { + let (ext_field_name, ext_name_upper) = + gen_name_with_uppercase_idents(e.ident.as_ref().expect("field must be named")); + let executor_type = parse_executor_type(e, false)?; + executor_enum_fields.push(quote! { + #[any_enum] + #ext_name_upper(#executor_type), + }); + create_executors.push(quote! { + let inventory: ::openvm_circuit::arch::ExecutorInventory = inventory.extend::(&self.#ext_field_name)?; + }); + let extension_ty = e.ty.clone(); + execution_where_predicates.push(parse_quote! { + #extension_ty: ::openvm_circuit::arch::VmExecutionExtension + }); + create_airs.push(quote! { + inventory.start_new_extension(); + ::openvm_circuit::arch::VmCircuitExtension::extend_circuit(&self.#ext_field_name, &mut inventory)?; + }); + circuit_where_predicates.push(parse_quote! { + #extension_ty: ::openvm_circuit::arch::VmCircuitExtension + }); + } - let executor_ident = Ident::new(&executor_type, field_ident.span()); - let periphery_ident = Ident::new(&periphery_type, field_ident.span()); + // The config type always needs due to SystemExecutor + let source_executor_type = parse_executor_type(source_field, true)?; + execution_where_predicates.push(parse_quote! { + #source_field_ty: ::openvm_circuit::arch::VmExecutionConfig + }); + circuit_where_predicates.push(parse_quote! { + #source_field_ty: ::openvm_circuit::arch::VmCircuitConfig + }); + let execution_where_clause = quote! { where #(#execution_where_predicates),* }; + let circuit_where_clause = quote! { where #(#circuit_where_predicates),* }; - (quote! { #executor_ident }, quote! { #periphery_ident }) - } - }; + let executor_type = Ident::new(&format!("{}Executor", name), name.span()); - let executor_type = Ident::new(&format!("{}Executor", name), name.span()); - let periphery_type = Ident::new(&format!("{}Periphery", name), name.span()); + let token_stream = TokenStream::from(quote! { + #[derive( + Clone, + ::derive_more::derive::From, + ::openvm_circuit::derive::AnyEnum, + ::openvm_circuit::derive::Executor, + ::openvm_circuit::derive::MeteredExecutor, + ::openvm_circuit::derive::PreflightExecutor, + )] + pub enum #executor_type { + #[any_enum] + #source_name_upper(#source_executor_type), + #(#executor_enum_fields)* + } - TokenStream::from(quote! { - #[derive(::openvm_circuit::circuit_derive::ChipUsageGetter, ::openvm_circuit::circuit_derive::Chip, ::openvm_circuit::derive::InstructionExecutor, ::derive_more::derive::From, ::openvm_circuit::derive::AnyEnum)] - pub enum #executor_type { - #[any_enum] - #source_name_upper(#source_executor_type), - #(#executor_enum_fields)* - } + impl ::openvm_circuit::arch::VmExecutionConfig for #name #execution_where_clause { + type Executor = #executor_type; - #[derive(::openvm_circuit::circuit_derive::ChipUsageGetter, ::openvm_circuit::circuit_derive::Chip, ::derive_more::derive::From, ::openvm_circuit::derive::AnyEnum)] - pub enum #periphery_type { - #[any_enum] - #source_name_upper(#source_periphery_type), - #(#periphery_enum_fields)* - } + fn create_executors( + &self, + ) -> Result<::openvm_circuit::arch::ExecutorInventory, ::openvm_circuit::arch::ExecutorInventoryError> { + let inventory = self.#source_name.create_executors()?.transmute::(); + #(#create_executors)* + Ok(inventory) + } + } - impl ::openvm_circuit::arch::VmConfig for #name { - type Executor = #executor_type; - type Periphery = #periphery_type; + impl ::openvm_circuit::arch::VmCircuitConfig for #name #circuit_where_clause { + fn create_airs( + &self, + ) -> Result<::openvm_circuit::arch::AirInventory, ::openvm_circuit::arch::AirInventoryError> { + let mut inventory = self.#source_name.create_airs()?; + #(#create_airs)* + Ok(inventory) + } + } - fn system(&self) -> &::openvm_circuit::arch::SystemConfig { - ::openvm_circuit::arch::VmConfig::::system(&self.#source_name) - } - fn system_mut(&mut self) -> &mut ::openvm_circuit::arch::SystemConfig { - ::openvm_circuit::arch::VmConfig::::system_mut(&mut self.#source_name) - } + impl AsRef for #name { + fn as_ref(&self) -> &SystemConfig { + self.#source_name.as_ref() + } + } - fn create_chip_complex( - &self, - ) -> Result<::openvm_circuit::arch::VmChipComplex, ::openvm_circuit::arch::VmInventoryError> { - let complex = self.#source_name.create_chip_complex()?; - #(#create_chip_complex)* - Ok(complex) + impl AsMut for #name { + fn as_mut(&mut self) -> &mut SystemConfig { + self.#source_name.as_mut() + } + } + }); + Ok(token_stream) +} + +// Parse the executor name as either +// `{type_name}Executor` or whatever the attribute `executor = ` specifies +// Also determines whether the executor type needs generic parameters +fn parse_executor_type( + f: &Field, + default_needs_generics: bool, +) -> syn::Result { + // TRACKING ISSUE: + // We cannot just use >::Executor because of this: + let mut executor_type = None; + // Do not unwrap the Result until needed + let executor_name = syn::parse_str::(&format!("{}Executor", f.ty.to_token_stream())); + + if let Some(attr) = f + .attrs + .iter() + .find(|attr| attr.path().is_ident("extension") || attr.path().is_ident("config")) + { + match attr.meta { + Meta::Path(_) => {} + Meta::NameValue(_) => { + return Err(syn::Error::new( + f.ty.span(), + "Only `#[config]`, `#[extension]`, `#[config(...)]` or `#[extension(...)]` formats are supported", + )) + } + _ => { + let nested = attr + .parse_args_with(Punctuated::::parse_terminated)?; + for meta in nested { + match meta { + Meta::NameValue(nv) => { + if nv.path.is_ident("executor") { + executor_type = match nv.value { + syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(lit_str), .. + }) => { + let executor_type: syn::Type = syn::parse_str(&lit_str.value())?; + Some(quote! { #executor_type }) + }, + syn::Expr::Path(path) => { + // Handle identifier paths like `executor = MyExecutor` + Some(path.to_token_stream()) + }, + _ => { + return Err(syn::Error::new( + nv.value.span(), + "executor value must be a string literal or identifier" + )); + } + }; + } else if nv.path.is_ident("generics") { + // Parse boolean value for generics + let value_str = nv.value.to_token_stream().to_string(); + let needs_generics = match value_str.as_str() { + "true" => true, + "false" => false, + _ => return Err(syn::Error::new( + nv.value.span(), + "generics attribute must be either true or false" + )) + }; + let executor_name = executor_name.clone()?; + executor_type = Some(if needs_generics { + quote! { #executor_name } + } else { + quote! { #executor_name } + }); + } else { + return Err(syn::Error::new(nv.span(), "only executor and generics keys are supported")); + } + } + _ => { + return Err(syn::Error::new(meta.span(), "only name = value format is supported")); + } } } - }) + } } - _ => syn::Error::new(name.span(), "Only structs are supported") - .to_compile_error() - .into(), + } + if let Some(executor_type) = executor_type { + Ok(executor_type) + } else { + let executor_name = executor_name?; + Ok(if default_needs_generics { + quote! { #executor_name } + } else { + quote! { #executor_name } + }) } } diff --git a/crates/vm/src/arch/config.rs b/crates/vm/src/arch/config.rs index d82b5f7cf0..f0de3f80d1 100644 --- a/crates/vm/src/arch/config.rs +++ b/crates/vm/src/arch/config.rs @@ -1,17 +1,32 @@ -use std::{fs::File, io::Write, path::Path, sync::Arc}; +use std::{fs::File, io::Write, path::Path}; use derive_new::new; -use openvm_circuit::system::memory::MemoryTraceHeights; +use getset::{Setters, WithSetters}; +use openvm_instructions::{ + riscv::{RV32_IMM_AS, RV32_MEMORY_AS, RV32_REGISTER_AS}, + NATIVE_AS, +}; use openvm_poseidon2_air::Poseidon2Config; -use openvm_stark_backend::{p3_field::PrimeField32, ChipUsageGetter}; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + engine::StarkEngine, + p3_field::Field, +}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use super::{ - segment::DefaultSegmentationStrategy, AnyEnum, InstructionExecutor, SegmentationStrategy, - SystemComplex, SystemExecutor, SystemPeriphery, VmChipComplex, VmInventoryError, - PUBLIC_VALUES_AIR_ID, +use super::{AnyEnum, VmChipComplex, PUBLIC_VALUES_AIR_ID}; +use crate::{ + arch::{ + execution_mode::metered::segment_ctx::SegmentationLimits, AirInventory, AirInventoryError, + Arena, ChipInventoryError, ExecutorInventory, ExecutorInventoryError, + }, + system::{ + memory::{ + merkle::public_values::PUBLIC_VALUES_AS, num_memory_airs, CHUNK, POINTER_MAX_BITS, + }, + SystemChipComplex, + }, }; -use crate::system::memory::BOUNDARY_AIR_OFFSET; // sbox is decomposed to have this max degree for Poseidon2. We set to 3 so quotient_degree = 2 // allows log_blowup = 1 @@ -19,28 +34,86 @@ const DEFAULT_POSEIDON2_MAX_CONSTRAINT_DEGREE: usize = 3; pub const DEFAULT_MAX_NUM_PUBLIC_VALUES: usize = 32; /// Width of Poseidon2 VM uses. pub const POSEIDON2_WIDTH: usize = 16; +/// Offset for address space indices. This is used to distinguish between different memory spaces. +pub const ADDR_SPACE_OFFSET: u32 = 1; /// Returns a Poseidon2 config for the VM. -pub fn vm_poseidon2_config() -> Poseidon2Config { +pub fn vm_poseidon2_config() -> Poseidon2Config { Poseidon2Config::default() } -pub trait VmConfig: - Clone + Serialize + DeserializeOwned + InitFileGenerator +/// A VM configuration is the minimum serializable format to be able to create the execution +/// environment and circuit for a zkVM supporting a fixed set of instructions. +/// +/// For users who only need to create an execution environment, use the sub-trait +/// [VmExecutionConfig] to avoid the `SC` generic. +/// +/// This trait does not contain the [VmProverBuilder] trait, because a single VM configuration may +/// implement multiple [VmProverBuilder]s for different prover backends. +pub trait VmConfig: + Clone + + Serialize + + DeserializeOwned + + InitFileGenerator + + VmExecutionConfig> + + VmCircuitConfig + + AsRef + + AsMut +where + SC: StarkGenericConfig, { - type Executor: InstructionExecutor + AnyEnum + ChipUsageGetter; - type Periphery: AnyEnum + ChipUsageGetter; +} - /// Must contain system config - fn system(&self) -> &SystemConfig; - fn system_mut(&mut self) -> &mut SystemConfig; +pub trait VmExecutionConfig { + type Executor: AnyEnum + Send + Sync; + + fn create_executors(&self) + -> Result, ExecutorInventoryError>; +} +pub trait VmCircuitConfig { + fn create_airs(&self) -> Result, AirInventoryError>; +} + +/// This trait is intended to be implemented on a new type wrapper of the VmConfig struct to get +/// around Rust orphan rules. +pub trait VmBuilder: Sized { + type VmConfig: VmConfig; + type RecordArena: Arena; + type SystemChipInventory: SystemChipComplex; + + /// Create a [VmChipComplex] from the full [AirInventory], which should be the output of + /// [VmCircuitConfig::create_airs]. + #[allow(clippy::type_complexity)] fn create_chip_complex( &self, - ) -> Result, VmInventoryError>; + config: &Self::VmConfig, + circuit: AirInventory, + ) -> Result< + VmChipComplex, + ChipInventoryError, + >; +} + +impl VmConfig for VC +where + SC: StarkGenericConfig, + VC: Clone + + Serialize + + DeserializeOwned + + InitFileGenerator + + VmExecutionConfig> + + VmCircuitConfig + + AsRef + + AsMut, +{ } pub const OPENVM_DEFAULT_INIT_FILE_BASENAME: &str = "openvm_init"; pub const OPENVM_DEFAULT_INIT_FILE_NAME: &str = "openvm_init.rs"; +/// The minimum block size is 4, but RISC-V `lb` only requires alignment of 1 and `lh` only requires +/// alignment of 2 because the instructions are implemented by doing an access of block size 4. +const DEFAULT_U8_BLOCK_SIZE: usize = 4; +const DEFAULT_NATIVE_BLOCK_SIZE: usize = 1; /// Trait for generating a init.rs file that contains a call to moduli_init!, /// complex_init!, sw_init! with the supported moduli and curves. @@ -68,37 +141,93 @@ pub trait InitFileGenerator { } } -#[derive(Debug, Serialize, Deserialize, Clone, new, Copy)] +/// Each address space in guest memory may be configured with a different type `T` to represent a +/// memory cell in the address space. On host, the address space will be mapped to linear host +/// memory in bytes. The type `T` must be plain old data (POD) and be safely transmutable from a +/// fixed size array of bytes. Moreover, each type `T` must be convertible to a field element `F`. +/// +/// We currently implement this trait on the enum [MemoryCellType], which includes all cell types +/// that we expect to be used in the VM context. +pub trait AddressSpaceHostLayout { + /// Size in bytes of the memory cell type. + fn size(&self) -> usize; + + /// # Safety + /// - This function must only be called when `value` is guaranteed to be of size `self.size()`. + /// - Alignment of `value` must be a multiple of the alignment of `F`. + /// - The field type `F` must be plain old data. + unsafe fn to_field(&self, value: &[u8]) -> F; +} + +#[derive(Debug, Serialize, Deserialize, Clone, new)] pub struct MemoryConfig { - /// The maximum height of the address space. This means the trie has `as_height` layers for - /// searching the address space. The allowed address spaces are those in the range `[as_offset, - /// as_offset + 2^as_height)` where `as_offset` is currently fixed to `1` to not allow address - /// space `0` in memory. - pub as_height: usize, - /// The offset of the address space. Should be fixed to equal `1`. - pub as_offset: u32, + /// The maximum height of the address space. This means the trie has `addr_space_height` layers + /// for searching the address space. The allowed address spaces are those in the range `[1, + /// 1 + 2^addr_space_height)` where it starts from 1 to not allow address space 0 in memory. + pub addr_space_height: usize, + /// It is expected that the size of the list is `(1 << addr_space_height) + 1` and the first + /// element is 0, which means no address space. + pub addr_spaces: Vec, pub pointer_max_bits: usize, - /// All timestamps must be in the range `[0, 2^clk_max_bits)`. Maximum allowed: 29. - pub clk_max_bits: usize, + /// All timestamps must be in the range `[0, 2^timestamp_max_bits)`. Maximum allowed: 29. + pub timestamp_max_bits: usize, /// Limb size used by the range checker pub decomp: usize, /// Maximum N AccessAdapter AIR to support. pub max_access_adapter_n: usize, - /// An expected upper bound on the number of memory accesses. - pub access_capacity: usize, } impl Default for MemoryConfig { fn default() -> Self { - Self::new(3, 1, 29, 29, 17, 32, 1 << 24) + let mut addr_spaces = + Self::empty_address_space_configs((1 << 3) + ADDR_SPACE_OFFSET as usize); + const MAX_CELLS: usize = 1 << 29; + addr_spaces[RV32_REGISTER_AS as usize].num_cells = 32 * size_of::(); + addr_spaces[RV32_MEMORY_AS as usize].num_cells = MAX_CELLS; + addr_spaces[PUBLIC_VALUES_AS as usize].num_cells = DEFAULT_MAX_NUM_PUBLIC_VALUES; + addr_spaces[NATIVE_AS as usize].num_cells = MAX_CELLS; + Self::new(3, addr_spaces, POINTER_MAX_BITS, 29, 17, 32) + } +} + +impl MemoryConfig { + pub fn empty_address_space_configs(num_addr_spaces: usize) -> Vec { + // All except address spaces 0..4 default to native 32-bit field. + // By default only address spaces 1..=4 have non-empty cell counts. + let mut addr_spaces = vec![ + AddressSpaceHostConfig::new( + 0, + DEFAULT_NATIVE_BLOCK_SIZE, + MemoryCellType::native32() + ); + num_addr_spaces + ]; + addr_spaces[RV32_IMM_AS as usize] = AddressSpaceHostConfig::new(0, 1, MemoryCellType::Null); + addr_spaces[RV32_REGISTER_AS as usize] = + AddressSpaceHostConfig::new(0, DEFAULT_U8_BLOCK_SIZE, MemoryCellType::U8); + addr_spaces[RV32_MEMORY_AS as usize] = + AddressSpaceHostConfig::new(0, DEFAULT_U8_BLOCK_SIZE, MemoryCellType::U8); + addr_spaces[PUBLIC_VALUES_AS as usize] = + AddressSpaceHostConfig::new(0, DEFAULT_U8_BLOCK_SIZE, MemoryCellType::U8); + + addr_spaces + } + + /// Config for aggregation usage with only native address space. + pub fn aggregation() -> Self { + let mut addr_spaces = + Self::empty_address_space_configs((1 << 3) + ADDR_SPACE_OFFSET as usize); + addr_spaces[NATIVE_AS as usize].num_cells = 1 << 29; + Self::new(3, addr_spaces, POINTER_MAX_BITS, 29, 17, 8) } } /// System-level configuration for the virtual machine. Contains all configuration parameters that /// are managed by the architecture, including configuration for continuations support. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, Setters, WithSetters)] pub struct SystemConfig { /// The maximum constraint degree any chip is allowed to use. + #[getset(set_with = "pub")] pub max_constraint_degree: usize, /// True if the VM is in continuation mode. In this mode, an execution could be segmented and /// each segment is proved by a proof. Each proof commits the before and after state of the @@ -119,47 +248,41 @@ pub struct SystemConfig { /// Whether to collect detailed profiling metrics. /// **Warning**: this slows down the runtime. pub profiling: bool, - /// Segmentation strategy + /// Segmentation limits /// This field is skipped in serde as it's only used in execution and /// not needed after any serialize/deserialize. - #[serde(skip, default = "get_default_segmentation_strategy")] - pub segmentation_strategy: Arc, -} - -pub fn get_default_segmentation_strategy() -> Arc { - Arc::new(DefaultSegmentationStrategy::default()) -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct SystemTraceHeights { - pub memory: MemoryTraceHeights, - // All other chips have constant heights. + #[serde(skip, default = "SegmentationLimits::default")] + #[getset(set = "pub")] + pub segmentation_limits: SegmentationLimits, } impl SystemConfig { pub fn new( max_constraint_degree: usize, - memory_config: MemoryConfig, + mut memory_config: MemoryConfig, num_public_values: usize, ) -> Self { - let segmentation_strategy = get_default_segmentation_strategy(); assert!( - memory_config.clk_max_bits <= 29, + memory_config.timestamp_max_bits <= 29, "Timestamp max bits must be <= 29 for LessThan to work in 31-bit field" ); + memory_config.addr_spaces[PUBLIC_VALUES_AS as usize].num_cells = num_public_values; Self { max_constraint_degree, continuation_enabled: false, memory_config, num_public_values, - segmentation_strategy, profiling: false, + segmentation_limits: SegmentationLimits::default(), } } - pub fn with_max_constraint_degree(mut self, max_constraint_degree: usize) -> Self { - self.max_constraint_degree = max_constraint_degree; - self + pub fn default_from_memory(memory_config: MemoryConfig) -> Self { + Self::new( + DEFAULT_POSEIDON2_MAX_CONSTRAINT_DEGREE, + memory_config, + DEFAULT_MAX_NUM_PUBLIC_VALUES, + ) } pub fn with_continuations(mut self) -> Self { @@ -174,20 +297,15 @@ impl SystemConfig { pub fn with_public_values(mut self, num_public_values: usize) -> Self { self.num_public_values = num_public_values; + self.memory_config.addr_spaces[PUBLIC_VALUES_AS as usize].num_cells = num_public_values; self } pub fn with_max_segment_len(mut self, max_segment_len: usize) -> Self { - self.segmentation_strategy = Arc::new( - DefaultSegmentationStrategy::new_with_max_segment_len(max_segment_len), - ); + self.segmentation_limits.max_trace_height = max_segment_len as u32; self } - pub fn set_segmentation_strategy(&mut self, strategy: Arc) { - self.segmentation_strategy = strategy; - } - pub fn with_profiling(mut self) -> Self { self.profiling = true; self @@ -204,55 +322,123 @@ impl SystemConfig { /// Returns the AIR ID of the memory boundary AIR. Panic if the boundary AIR is not enabled. pub fn memory_boundary_air_id(&self) -> usize { - let mut ret = PUBLIC_VALUES_AIR_ID; - if self.has_public_values_chip() { - ret += 1; + PUBLIC_VALUES_AIR_ID + usize::from(self.has_public_values_chip()) + } + + /// AIR ID for the first memory access adapter AIR. + pub fn access_adapter_air_id_offset(&self) -> usize { + let boundary_idx = self.memory_boundary_air_id(); + // boundary, (if persistent memory) merkle AIRs + boundary_idx + 1 + usize::from(self.continuation_enabled) + } + + /// This is O(1) and returns the length of + /// [`SystemAirInventory::into_airs`](crate::system::SystemAirInventory::into_airs). + pub fn num_airs(&self) -> usize { + self.memory_boundary_air_id() + + num_memory_airs( + self.continuation_enabled, + self.memory_config.max_access_adapter_n, + ) + } + + pub fn initial_block_size(&self) -> usize { + match self.continuation_enabled { + true => CHUNK, + false => 1, } - ret += BOUNDARY_AIR_OFFSET; - ret } } impl Default for SystemConfig { fn default() -> Self { - Self::new( - DEFAULT_POSEIDON2_MAX_CONSTRAINT_DEGREE, - Default::default(), - DEFAULT_MAX_NUM_PUBLIC_VALUES, - ) + Self::default_from_memory(MemoryConfig::default()) } } -impl SystemTraceHeights { - /// Round all trace heights to the next power of two. This will round trace heights of 0 to 1. - pub fn round_to_next_power_of_two(&mut self) { - self.memory.round_to_next_power_of_two(); +impl AsRef for SystemConfig { + fn as_ref(&self) -> &SystemConfig { + self } +} - /// Round all trace heights to the next power of two, except 0 stays 0. - pub fn round_to_next_power_of_two_or_zero(&mut self) { - self.memory.round_to_next_power_of_two_or_zero(); +impl AsMut for SystemConfig { + fn as_mut(&mut self) -> &mut SystemConfig { + self } } -impl VmConfig for SystemConfig { - type Executor = SystemExecutor; - type Periphery = SystemPeriphery; +// Default implementation uses no init file +impl InitFileGenerator for SystemConfig {} - fn system(&self) -> &SystemConfig { - self - } - fn system_mut(&mut self) -> &mut SystemConfig { - self +#[derive(Debug, Serialize, Deserialize, Clone, Copy, new)] +pub struct AddressSpaceHostConfig { + /// The number of memory cells in each address space, where a memory cell refers to a single + /// addressable unit of memory as defined by the ISA. + pub num_cells: usize, + /// Minimum block size for memory accesses supported. This is a property of the address space + /// that is determined by the ISA. + /// + /// **Note**: Block size is in terms of memory cells. + pub min_block_size: usize, + pub layout: MemoryCellType, +} + +impl AddressSpaceHostConfig { + /// The total size in bytes of the address space in a linear memory layout. + pub fn size(&self) -> usize { + self.num_cells * self.layout.size() } +} - fn create_chip_complex( - &self, - ) -> Result, VmInventoryError> { - let complex = SystemComplex::new(self.clone()); - Ok(complex) +pub(crate) const MAX_CELL_BYTE_SIZE: usize = 8; + +#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq)] +pub enum MemoryCellType { + Null, + U8, + U16, + /// Represented in little-endian format. + U32, + /// `size` is the size in bytes of the native field type. This should not exceed 8. + Native { + size: u8, + }, +} + +impl MemoryCellType { + pub fn native32() -> Self { + Self::Native { + size: size_of::() as u8, + } } } -// Default implementation uses no init file -impl InitFileGenerator for SystemConfig {} +impl AddressSpaceHostLayout for MemoryCellType { + fn size(&self) -> usize { + match self { + Self::Null => 1, // to avoid divide by zero + Self::U8 => size_of::(), + Self::U16 => size_of::(), + Self::U32 => size_of::(), + Self::Native { size } => *size as usize, + } + } + + /// # Safety + /// - This function must only be called when `value` is guaranteed to be of size `self.size()`. + /// - Alignment of `value` must be a multiple of the alignment of `F`. + /// - The field type `F` must be plain old data. + /// + /// # Panics + /// If the value is of integer type and overflows the field. + unsafe fn to_field(&self, value: &[u8]) -> F { + match self { + Self::Null => unreachable!(), + Self::U8 => F::from_canonical_u8(*value.get_unchecked(0)), + Self::U16 => F::from_canonical_u16(core::ptr::read(value.as_ptr() as *const u16)), + Self::U32 => F::from_canonical_u32(core::ptr::read(value.as_ptr() as *const u32)), + Self::Native { .. } => core::ptr::read(value.as_ptr() as *const F), + } + } +} diff --git a/crates/vm/src/arch/execution.rs b/crates/vm/src/arch/execution.rs index 4edc88d355..8c6eba0675 100644 --- a/crates/vm/src/arch/execution.rs +++ b/crates/vm/src/arch/execution.rs @@ -1,5 +1,4 @@ -use std::{cell::RefCell, rc::Rc}; - +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ instruction::Instruction, program::DEFAULT_PC_STEP, PhantomDiscriminant, VmOpcode, @@ -8,32 +7,33 @@ use openvm_stark_backend::{ interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, p3_field::FieldAlgebra, }; +use rand::rngs::StdRng; use serde::{Deserialize, Serialize}; use thiserror::Error; -use super::Streams; -use crate::system::{memory::MemoryController, program::ProgramBus}; - -pub type Result = std::result::Result; +use super::{execution_mode::E1ExecutionCtx, Streams, VmExecState}; +#[cfg(feature = "metrics")] +use crate::metrics::VmMetrics; +use crate::{ + arch::{execution_mode::E2ExecutionCtx, ExecutorInventoryError, MatrixRecordArena}, + system::{ + memory::online::{GuestMemory, TracingMemory}, + program::ProgramBus, + }, +}; #[derive(Error, Debug)] pub enum ExecutionError { - #[error("execution failed at pc {pc}")] - Fail { pc: u32 }, - #[error("pc {pc} not found for program of length {program_len}, with pc_base {pc_base} and step = {step}")] - PcNotFound { - pc: u32, - step: u32, - pc_base: u32, - program_len: usize, - }, - #[error("pc {pc} out of bounds for program of length {program_len}, with pc_base {pc_base} and step = {step}")] + #[error("execution failed at pc {pc}, err: {msg}")] + Fail { pc: u32, msg: &'static str }, + #[error("pc {pc} out of bounds for program of length {program_len}, with pc_base {pc_base}")] PcOutOfBounds { pc: u32, - step: u32, pc_base: u32, program_len: usize, }, + #[error("unreachable instruction at pc {0}")] + Unreachable(u32), #[error("at pc {pc}, opcode {opcode} was not enabled")] DisabledOperation { pc: u32, opcode: VmOpcode }, #[error("at pc = {pc}")] @@ -66,51 +66,109 @@ pub enum ExecutionError { DidNotTerminate, #[error("program exit code {0}")] FailedWithExitCode(u32), + #[error("trace buffer out of bounds: requested {requested} but capacity is {capacity}")] + TraceBufferOutOfBounds { requested: usize, capacity: usize }, + #[error("inventory error: {0}")] + Inventory(#[from] ExecutorInventoryError), + #[error("static program error: {0}")] + Static(#[from] StaticProgramError), +} + +/// Errors in the program that can be statically analyzed before runtime. +#[derive(Error, Debug)] +pub enum StaticProgramError { + #[error("invalid instruction at pc {0}")] + InvalidInstruction(u32), + #[error("Too many executors")] + TooManyExecutors, + #[error("at pc {pc}, opcode {opcode} was not enabled")] + DisabledOperation { pc: u32, opcode: VmOpcode }, + #[error("Executor not found for opcode {opcode}")] + ExecutorNotFound { opcode: VmOpcode }, +} + +/// Function pointer for interpreter execution with function signature `(pre_compute, exec_state)`. +/// The `pre_compute: &[u8]` is a pre-computed buffer of data corresponding to a single instruction. +/// The contents of `pre_compute` are determined from the program code as specified by the +/// [Executor] and [MeteredExecutor] traits. +pub type ExecuteFunc = unsafe fn(&[u8], &mut VmExecState); + +/// Trait for pure execution via a host interpreter. The trait methods provide the methods to +/// pre-process the program code into function pointers which operate on `pre_compute` instruction +/// data. +// @dev: In the codebase this is sometimes referred to as (E1). +pub trait Executor { + fn pre_compute_size(&self) -> usize; + + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E1ExecutionCtx; } -pub trait InstructionExecutor { +/// Trait for metered execution via a host interpreter. The trait methods provide the methods to +/// pre-process the program code into function pointers which operate on `pre_compute` instruction +/// data which contains auxiliary data (e.g., corresponding AIR ID) for metering purposes. +// @dev: In the codebase this is sometimes referred to as (E2). +pub trait MeteredExecutor { + fn metered_pre_compute_size(&self) -> usize; + + fn metered_pre_compute( + &self, + air_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx; +} + +// TODO[jpw]: Avoid Clone by making executors stateless? +/// Trait for preflight execution via a host interpreter. The trait methods allow execution of +/// instructions via enum dispatch within an interpreter. This execution is specialized to record +/// "records" of execution which will be ingested later for trace matrix generation. The records are +/// stored in a record arena, which is provided in the [VmStateMut] argument. +// @dev: In the codebase this is sometimes referred to as (E3). +pub trait PreflightExecutor>: Clone { /// Runtime execution of the instruction, if the instruction is owned by the /// current instance. May internally store records of this call for later trace generation. fn execute( &mut self, - memory: &mut MemoryController, + state: VmStateMut, instruction: &Instruction, - from_state: ExecutionState, - ) -> Result>; + ) -> Result<(), ExecutionError>; /// For display purposes. From absolute opcode as `usize`, return the string name of the opcode /// if it is a supported opcode by the present executor. fn get_opcode_name(&self, opcode: usize) -> String; } -impl> InstructionExecutor for RefCell { - fn execute( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - prev_state: ExecutionState, - ) -> Result> { - self.borrow_mut().execute(memory, instruction, prev_state) - } - - fn get_opcode_name(&self, opcode: usize) -> String { - self.borrow().get_opcode_name(opcode) - } +/// Global VM state accessible during instruction execution. +/// The state is generic in guest memory `MEM` and additional record arena `RA`. +/// The host state is execution context specific. +#[derive(derive_new::new)] +pub struct VmStateMut<'a, F, MEM, RA> { + pub pc: &'a mut u32, + pub memory: &'a mut MEM, + pub streams: &'a mut Streams, + pub rng: &'a mut StdRng, + pub ctx: &'a mut RA, + #[cfg(feature = "metrics")] + pub metrics: &'a mut VmMetrics, } -impl> InstructionExecutor for Rc> { - fn execute( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - prev_state: ExecutionState, - ) -> Result> { - self.borrow_mut().execute(memory, instruction, prev_state) - } - - fn get_opcode_name(&self, opcode: usize) -> String { - self.borrow().get_opcode_name(opcode) - } +/// Wrapper type for metered pre-computed data, which is always an AIR index together with the +/// pre-computed data for pure execution. +#[derive(Clone, AlignedBytesBorrow)] +#[repr(C)] +pub struct E2PreCompute { + pub chip_idx: u32, + pub data: DATA, } #[repr(C)] @@ -322,14 +380,16 @@ impl From<(u32, Option)> for PcIncOrSet { /// /// Phantom sub-instructions are only allowed to use operands /// `a,b` and `c_upper = c.as_canonical_u32() >> 16`. -pub trait PhantomSubExecutor: Send { +#[allow(clippy::too_many_arguments)] +pub trait PhantomSubExecutor: Send + Sync { fn phantom_execute( - &mut self, - memory: &MemoryController, + &self, + memory: &GuestMemory, streams: &mut Streams, + rng: &mut StdRng, discriminant: PhantomDiscriminant, - a: F, - b: F, + a: u32, + b: u32, c_upper: u16, ) -> eyre::Result<()>; } diff --git a/crates/vm/src/arch/execution_mode/e1.rs b/crates/vm/src/arch/execution_mode/e1.rs new file mode 100644 index 0000000000..920c4a81fb --- /dev/null +++ b/crates/vm/src/arch/execution_mode/e1.rs @@ -0,0 +1,35 @@ +use crate::{ + arch::{execution_mode::E1ExecutionCtx, VmExecState}, + system::memory::online::GuestMemory, +}; + +pub struct E1Ctx { + instret_end: u64, +} + +impl E1Ctx { + pub fn new(instret_end: Option) -> Self { + E1Ctx { + instret_end: if let Some(end) = instret_end { + end + } else { + u64::MAX + }, + } + } +} + +impl Default for E1Ctx { + fn default() -> Self { + Self::new(None) + } +} + +impl E1ExecutionCtx for E1Ctx { + #[inline(always)] + fn on_memory_operation(&mut self, _address_space: u32, _ptr: u32, _size: u32) {} + #[inline(always)] + fn should_suspend(vm_state: &mut VmExecState) -> bool { + vm_state.instret >= vm_state.ctx.instret_end + } +} diff --git a/crates/vm/src/arch/execution_mode/metered/ctx.rs b/crates/vm/src/arch/execution_mode/metered/ctx.rs new file mode 100644 index 0000000000..77b5a5a7d4 --- /dev/null +++ b/crates/vm/src/arch/execution_mode/metered/ctx.rs @@ -0,0 +1,270 @@ +use std::num::NonZero; + +use getset::WithSetters; +use openvm_instructions::riscv::{ + RV32_IMM_AS, RV32_NUM_REGISTERS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS, +}; + +use super::{ + memory_ctx::MemoryCtx, + segment_ctx::{Segment, SegmentationCtx}, +}; +use crate::{ + arch::{ + execution_mode::{ + metered::segment_ctx::SegmentationLimits, E1ExecutionCtx, E2ExecutionCtx, + }, + VmExecState, + }, + system::memory::{dimensions::MemoryDimensions, online::GuestMemory}, +}; + +pub const DEFAULT_PAGE_BITS: usize = 6; +pub const DEFAULT_SEGMENT_CHECK_INSNS: u64 = 1000; + +#[derive(Clone, Debug, WithSetters)] +pub struct MeteredCtx { + pub trace_heights: Vec, + pub is_trace_height_constant: Vec, + + pub memory_ctx: MemoryCtx, + pub segmentation_ctx: SegmentationCtx, + pub continuations_enabled: bool, + instret_last_segment_check: u64, + #[getset(set_with = "pub")] + segment_check_insns: u64, +} + +impl MeteredCtx { + // Note[jpw]: this is indeed too many arguments, prefer to use `build_metered_ctx` in + // `VmExecutor` or `VirtualMachine`. + #[allow(clippy::too_many_arguments)] + pub fn new( + constant_trace_heights: Vec>, + has_public_values_chip: bool, + continuations_enabled: bool, + as_byte_alignment_bits: Vec, + memory_dimensions: MemoryDimensions, + air_names: Vec, + widths: Vec, + interactions: Vec, + segmentation_limits: SegmentationLimits, + ) -> Self { + let (trace_heights, is_trace_height_constant): (Vec, Vec) = + constant_trace_heights + .iter() + .map(|&constant_height| { + if let Some(height) = constant_height { + (height as u32, true) + } else { + (0, false) + } + }) + .unzip(); + + let memory_ctx = MemoryCtx::new( + has_public_values_chip, + continuations_enabled, + as_byte_alignment_bits, + memory_dimensions, + ); + + // Assert that the indices are correct + debug_assert!( + air_names[memory_ctx.boundary_idx].contains("Boundary"), + "air_name={}", + air_names[memory_ctx.boundary_idx] + ); + if let Some(merkle_tree_index) = memory_ctx.merkle_tree_index { + debug_assert!( + air_names[merkle_tree_index].contains("Merkle"), + "air_name={}", + air_names[merkle_tree_index] + ); + } + debug_assert!( + air_names[memory_ctx.adapter_offset].contains("AccessAdapterAir<2>"), + "air_name={}", + air_names[memory_ctx.adapter_offset] + ); + + let segmentation_ctx = + SegmentationCtx::new(air_names, widths, interactions, segmentation_limits); + + let mut ctx = Self { + trace_heights, + is_trace_height_constant, + memory_ctx, + segmentation_ctx, + continuations_enabled, + segment_check_insns: DEFAULT_SEGMENT_CHECK_INSNS, + instret_last_segment_check: 0, + }; + if !continuations_enabled { + // force single segment + ctx.segment_check_insns = u64::MAX; + } + + // Add merkle height contributions for all registers + ctx.add_register_merkle_heights(); + + ctx + } + + #[inline(always)] + fn add_register_merkle_heights(&mut self) { + if self.continuations_enabled { + self.memory_ctx.update_boundary_merkle_heights( + RV32_REGISTER_AS, + 0, + (RV32_NUM_REGISTERS * RV32_REGISTER_NUM_LIMBS) as u32, + ); + } + } + + pub fn with_max_trace_height(mut self, max_trace_height: u32) -> Self { + self.segmentation_ctx.set_max_trace_height(max_trace_height); + let max_check_freq = (max_trace_height / 2) as u64; + if max_check_freq < self.segment_check_insns { + self.segment_check_insns = max_check_freq; + } + self + } + + pub fn with_max_cells(mut self, max_cells: usize) -> Self { + self.segmentation_ctx.set_max_cells(max_cells); + self + } + + pub fn with_max_interactions(mut self, max_interactions: usize) -> Self { + self.segmentation_ctx.set_max_interactions(max_interactions); + self + } + + pub fn segments(&self) -> &[Segment] { + &self.segmentation_ctx.segments + } + + pub fn into_segments(self) -> Vec { + self.segmentation_ctx.segments + } + + fn reset_segment(&mut self) { + self.memory_ctx.clear(); + for (i, &is_constant) in self.is_trace_height_constant.iter().enumerate() { + if !is_constant { + self.trace_heights[i] = 0; + } + } + + // Add merkle height contributions for all registers + self.add_register_merkle_heights(); + } + + #[inline(always)] + pub fn check_and_segment(&mut self, instret: u64) { + let threshold = self + .instret_last_segment_check + .wrapping_add(self.segment_check_insns); + debug_assert!( + threshold >= self.instret_last_segment_check, + "overflow in segment check threshold calculation" + ); + if instret < threshold { + return; + } + + self.memory_ctx + .lazy_update_boundary_heights(&mut self.trace_heights); + let did_segment = self.segmentation_ctx.check_and_segment( + instret, + &self.trace_heights, + &self.is_trace_height_constant, + ); + + self.instret_last_segment_check = instret; + if did_segment { + self.reset_segment(); + } + } + + #[allow(dead_code)] + pub fn print_heights(&self) { + println!("{:>10} {:<30}", "Height", "Air Name"); + println!("{}", "-".repeat(42)); + for (i, height) in self.trace_heights.iter().enumerate() { + let air_name = self + .segmentation_ctx + .air_names + .get(i) + .map(|s| s.as_str()) + .unwrap_or("Unknown"); + println!("{:>10} {:<30}", height, air_name); + } + } +} + +impl E1ExecutionCtx for MeteredCtx { + #[inline(always)] + fn on_memory_operation(&mut self, address_space: u32, ptr: u32, size: u32) { + debug_assert!( + address_space != RV32_IMM_AS, + "address space must not be immediate" + ); + debug_assert!(size > 0, "size must be greater than 0, got {}", size); + debug_assert!( + size.is_power_of_two(), + "size must be a power of 2, got {}", + size + ); + + // Handle access adapter updates + // SAFETY: size passed is always a non-zero power of 2 + let size_bits = unsafe { NonZero::new_unchecked(size).ilog2() }; + self.memory_ctx + .update_adapter_heights(&mut self.trace_heights, address_space, size_bits); + + // Handle merkle tree updates + if address_space != RV32_REGISTER_AS { + self.memory_ctx + .update_boundary_merkle_heights(address_space, ptr, size); + } + } + + #[inline(always)] + fn should_suspend(vm_state: &mut VmExecState) -> bool { + // E2 always runs until termination. Here we use the function as a hook called every + // instruction. + vm_state.ctx.check_and_segment(vm_state.instret); + false + } + + #[inline(always)] + fn on_terminate(vm_state: &mut VmExecState) { + vm_state + .ctx + .memory_ctx + .lazy_update_boundary_heights(&mut vm_state.ctx.trace_heights); + vm_state + .ctx + .segmentation_ctx + .segment(vm_state.instret, &vm_state.ctx.trace_heights); + } +} + +impl E2ExecutionCtx for MeteredCtx { + #[inline(always)] + fn on_height_change(&mut self, chip_idx: usize, height_delta: u32) { + debug_assert!( + chip_idx < self.trace_heights.len(), + "chip_idx out of bounds" + ); + // SAFETY: chip_idx is created in executor_idx_to_air_idx and is always within bounds + unsafe { + *self.trace_heights.get_unchecked_mut(chip_idx) = self + .trace_heights + .get_unchecked(chip_idx) + .wrapping_add(height_delta); + } + } +} diff --git a/crates/vm/src/arch/execution_mode/metered/memory_ctx.rs b/crates/vm/src/arch/execution_mode/metered/memory_ctx.rs new file mode 100644 index 0000000000..4facf544cd --- /dev/null +++ b/crates/vm/src/arch/execution_mode/metered/memory_ctx.rs @@ -0,0 +1,336 @@ +use crate::{ + arch::PUBLIC_VALUES_AIR_ID, + system::memory::{dimensions::MemoryDimensions, CHUNK}, +}; + +#[derive(Clone, Debug)] +pub struct BitSet { + words: Box<[u64]>, +} + +impl BitSet { + pub fn new(num_bits: usize) -> Self { + Self { + words: vec![0; num_bits.div_ceil(u64::BITS as usize)].into_boxed_slice(), + } + } + + #[inline(always)] + pub fn insert(&mut self, index: usize) -> bool { + let word_index = index >> 6; + let bit_index = index & 63; + let mask = 1u64 << bit_index; + + debug_assert!(word_index < self.words.len(), "BitSet index out of bounds"); + + // SAFETY: word_index is derived from a memory address that is bounds-checked + // during memory access. The bitset is sized to accommodate all valid + // memory addresses, so word_index is always within bounds. + let word = unsafe { self.words.get_unchecked_mut(word_index) }; + let was_set = (*word & mask) != 0; + *word |= mask; + !was_set + } + + /// Set all bits within [start, end) to 1, return the number of flipped bits. + /// Assumes start < end and end <= self.words.len() * 64. + #[inline(always)] + pub fn insert_range(&mut self, start: usize, end: usize) -> usize { + debug_assert!(start < end); + debug_assert!(end <= self.words.len() * 64, "BitSet range out of bounds"); + + let mut ret = 0; + let start_word_index = start >> 6; + let end_word_index = (end - 1) >> 6; + let start_bit = (start & 63) as u32; + + if start_word_index == end_word_index { + let end_bit = ((end - 1) & 63) as u32 + 1; + let mask_bits = end_bit - start_bit; + let mask = (u64::MAX >> (64 - mask_bits)) << start_bit; + // SAFETY: Caller ensures start < end and end <= self.words.len() * 64, + // so start_word_index < self.words.len() + let word = unsafe { self.words.get_unchecked_mut(start_word_index) }; + ret += mask_bits - (*word & mask).count_ones(); + *word |= mask; + } else { + let end_bit = (end & 63) as u32; + let mask_bits = 64 - start_bit; + let mask = u64::MAX << start_bit; + // SAFETY: Caller ensures start < end and end <= self.words.len() * 64, + // so start_word_index < self.words.len() + let start_word = unsafe { self.words.get_unchecked_mut(start_word_index) }; + ret += mask_bits - (*start_word & mask).count_ones(); + *start_word |= mask; + + let mask_bits = end_bit; + let mask = if end_bit == 0 { + 0 + } else { + u64::MAX >> (64 - end_bit) + }; + // SAFETY: Caller ensures end <= self.words.len() * 64, so + // end_word_index < self.words.len() + let end_word = unsafe { self.words.get_unchecked_mut(end_word_index) }; + ret += mask_bits - (*end_word & mask).count_ones(); + *end_word |= mask; + } + + if start_word_index + 1 < end_word_index { + for i in (start_word_index + 1)..end_word_index { + // SAFETY: Caller ensures proper start and end, so i is within bounds + // of self.words.len() + let word = unsafe { self.words.get_unchecked_mut(i) }; + ret += word.count_zeros(); + *word = u64::MAX; + } + } + ret as usize + } + + #[inline(always)] + pub fn clear(&mut self) { + // SAFETY: words is valid for self.words.len() elements + unsafe { + std::ptr::write_bytes(self.words.as_mut_ptr(), 0, self.words.len()); + } + } +} + +#[derive(Clone, Debug)] +pub struct MemoryCtx { + pub page_indices: BitSet, + memory_dimensions: MemoryDimensions, + as_byte_alignment_bits: Vec, + pub boundary_idx: usize, + pub merkle_tree_index: Option, + pub adapter_offset: usize, + chunk: u32, + chunk_bits: u32, + page_access_count: usize, + // Note: 32 is the maximum access adapter size. + addr_space_access_count: Vec, +} + +impl MemoryCtx { + pub fn new( + has_public_values_chip: bool, + continuations_enabled: bool, + as_byte_alignment_bits: Vec, + memory_dimensions: MemoryDimensions, + ) -> Self { + let boundary_idx = if has_public_values_chip { + PUBLIC_VALUES_AIR_ID + 1 + } else { + PUBLIC_VALUES_AIR_ID + }; + + let merkle_tree_index = if continuations_enabled { + Some(boundary_idx + 1) + } else { + None + }; + + let adapter_offset = if continuations_enabled { + boundary_idx + 2 + } else { + boundary_idx + 1 + }; + + let chunk = if continuations_enabled { + // Persistent memory uses CHUNK-sized blocks + CHUNK as u32 + } else { + // Volatile memory uses single units + 1 + }; + + let chunk_bits = chunk.ilog2(); + let merkle_height = memory_dimensions.overall_height(); + + Self { + // Address height already considers `chunk_bits`. + page_indices: BitSet::new(1 << (merkle_height.saturating_sub(PAGE_BITS))), + as_byte_alignment_bits, + boundary_idx, + merkle_tree_index, + adapter_offset, + chunk, + chunk_bits, + memory_dimensions, + page_access_count: 0, + addr_space_access_count: vec![0; (1 << memory_dimensions.addr_space_height) + 1], + } + } + + #[inline(always)] + pub fn clear(&mut self) { + self.page_indices.clear(); + } + + /// For each memory access, record the minimal necessary data to update heights of + /// memory-related chips. The actual height updates happen during segment checks. The + /// implementation is in `lazy_update_boundary_heights`. + #[inline(always)] + pub(crate) fn update_boundary_merkle_heights( + &mut self, + address_space: u32, + ptr: u32, + size: u32, + ) { + debug_assert!((address_space as usize) < self.addr_space_access_count.len()); + + let num_blocks = (size + self.chunk - 1) >> self.chunk_bits; + let start_chunk_id = ptr >> self.chunk_bits; + let start_block_id = if self.chunk == 1 { + start_chunk_id + } else { + self.memory_dimensions + .label_to_index((address_space, start_chunk_id)) as u32 + }; + // Because `self.chunk == 1 << self.chunk_bits` + let end_block_id = start_block_id + num_blocks; + let start_page_id = start_block_id >> PAGE_BITS; + let end_page_id = ((end_block_id - 1) >> PAGE_BITS) + 1; + + for page_id in start_page_id..end_page_id { + if self.page_indices.insert(page_id as usize) { + self.page_access_count += 1; + // SAFETY: address_space passed is usually a hardcoded constant or derived from an + // Instruction where it is bounds checked before passing + unsafe { + *self + .addr_space_access_count + .get_unchecked_mut(address_space as usize) += 1; + } + } + } + } + + #[inline(always)] + pub fn update_adapter_heights( + &mut self, + trace_heights: &mut [u32], + address_space: u32, + size_bits: u32, + ) { + self.update_adapter_heights_batch(trace_heights, address_space, size_bits, 1); + } + + #[inline(always)] + pub fn update_adapter_heights_batch( + &self, + trace_heights: &mut [u32], + address_space: u32, + size_bits: u32, + num: u32, + ) { + debug_assert!((address_space as usize) < self.as_byte_alignment_bits.len()); + + // SAFETY: address_space passed is usually a hardcoded constant or derived from an + // Instruction where it is bounds checked before passing + let align_bits = unsafe { + *self + .as_byte_alignment_bits + .get_unchecked(address_space as usize) + }; + debug_assert!( + align_bits as u32 <= size_bits, + "align_bits ({}) must be <= size_bits ({})", + align_bits, + size_bits + ); + + for adapter_bits in (align_bits as u32 + 1..=size_bits).rev() { + let adapter_idx = self.adapter_offset + adapter_bits as usize - 1; + debug_assert!(adapter_idx < trace_heights.len()); + // SAFETY: trace_heights is initialized taking access adapters into account + unsafe { + *trace_heights.get_unchecked_mut(adapter_idx) += + num << (size_bits - adapter_bits + 1); + } + } + } + + // TODO(ayush): check if batching is actually even faster + /// Resolve all lazy updates of each memory access for memory adapters/poseidon2/merkle chip. + #[inline(always)] + pub(crate) fn lazy_update_boundary_heights(&mut self, trace_heights: &mut [u32]) { + debug_assert!(self.boundary_idx < trace_heights.len()); + + // On page fault, assume we add all leaves in a page + let leaves = (self.page_access_count << PAGE_BITS) as u32; + // SAFETY: boundary_idx is a compile time constant within bounds + unsafe { + *trace_heights.get_unchecked_mut(self.boundary_idx) += leaves; + } + + if let Some(merkle_tree_idx) = self.merkle_tree_index { + debug_assert!(merkle_tree_idx < trace_heights.len()); + debug_assert!(trace_heights.len() >= 2); + + let poseidon2_idx = trace_heights.len() - 2; + // SAFETY: poseidon2_idx is trace_heights.len() - 2, guaranteed to be in bounds + unsafe { + *trace_heights.get_unchecked_mut(poseidon2_idx) += leaves * 2; + } + + let merkle_height = self.memory_dimensions.overall_height(); + let nodes = (((1 << PAGE_BITS) - 1) + (merkle_height - PAGE_BITS)) as u32; + // SAFETY: merkle_tree_idx is guaranteed to be in bounds + unsafe { + *trace_heights.get_unchecked_mut(poseidon2_idx) += nodes * 2; + *trace_heights.get_unchecked_mut(merkle_tree_idx) += nodes * 2; + } + } + self.page_access_count = 0; + + for address_space in 0..self.addr_space_access_count.len() { + // SAFETY: address_space is from 0 to len(), guaranteed to be in bounds + let x = unsafe { *self.addr_space_access_count.get_unchecked(address_space) }; + if x > 0 { + // After finalize, we'll need to read it in chunk-sized units for the merkle chip + self.update_adapter_heights_batch( + trace_heights, + address_space as u32, + self.chunk_bits, + (x << PAGE_BITS) as u32, + ); + // SAFETY: address_space is from 0 to len(), guaranteed to be in bounds + unsafe { + *self + .addr_space_access_count + .get_unchecked_mut(address_space) = 0; + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_bitset_insert_range() { + // 513 bits + let mut bit_set = BitSet::new(8 * 64 + 1); + let num_flips = bit_set.insert_range(2, 29); + assert_eq!(num_flips, 27); + let num_flips = bit_set.insert_range(1, 31); + assert_eq!(num_flips, 3); + + let num_flips = bit_set.insert_range(32, 65); + assert_eq!(num_flips, 33); + let num_flips = bit_set.insert_range(0, 66); + assert_eq!(num_flips, 3); + let num_flips = bit_set.insert_range(0, 66); + assert_eq!(num_flips, 0); + + let num_flips = bit_set.insert_range(256, 320); + assert_eq!(num_flips, 64); + let num_flips = bit_set.insert_range(256, 377); + assert_eq!(num_flips, 57); + let num_flips = bit_set.insert_range(100, 513); + assert_eq!(num_flips, 413 - 121); + } +} diff --git a/crates/vm/src/arch/execution_mode/metered/mod.rs b/crates/vm/src/arch/execution_mode/metered/mod.rs new file mode 100644 index 0000000000..9bd0799194 --- /dev/null +++ b/crates/vm/src/arch/execution_mode/metered/mod.rs @@ -0,0 +1,6 @@ +pub mod ctx; +pub mod memory_ctx; +pub mod segment_ctx; + +pub use ctx::MeteredCtx; +pub use segment_ctx::Segment; diff --git a/crates/vm/src/arch/execution_mode/metered/segment_ctx.rs b/crates/vm/src/arch/execution_mode/metered/segment_ctx.rs new file mode 100644 index 0000000000..fbbe3700cf --- /dev/null +++ b/crates/vm/src/arch/execution_mode/metered/segment_ctx.rs @@ -0,0 +1,225 @@ +use getset::WithSetters; +use openvm_stark_backend::p3_field::PrimeField32; +use p3_baby_bear::BabyBear; +use serde::{Deserialize, Serialize}; + +const DEFAULT_MAX_TRACE_HEIGHT: u32 = (1 << 23) - 10000; +const DEFAULT_MAX_CELLS: usize = 2_000_000_000; // 2B +const DEFAULT_MAX_INTERACTIONS: usize = BabyBear::ORDER_U32 as usize; + +#[derive(derive_new::new, Clone, Debug, Serialize, Deserialize)] +pub struct Segment { + pub instret_start: u64, + pub num_insns: u64, + pub trace_heights: Vec, +} + +#[derive(Clone, Copy, Debug, WithSetters)] +pub struct SegmentationLimits { + #[getset(set_with = "pub")] + pub max_trace_height: u32, + #[getset(set_with = "pub")] + pub max_cells: usize, + #[getset(set_with = "pub")] + pub max_interactions: usize, +} + +impl Default for SegmentationLimits { + fn default() -> Self { + Self { + max_trace_height: DEFAULT_MAX_TRACE_HEIGHT, + max_cells: DEFAULT_MAX_CELLS, + max_interactions: DEFAULT_MAX_INTERACTIONS, + } + } +} + +#[derive(Clone, Debug)] +pub struct SegmentationCtx { + pub segments: Vec, + pub(crate) air_names: Vec, + widths: Vec, + interactions: Vec, + pub(crate) segmentation_limits: SegmentationLimits, +} + +impl SegmentationCtx { + pub fn new( + air_names: Vec, + widths: Vec, + interactions: Vec, + segmentation_limits: SegmentationLimits, + ) -> Self { + assert_eq!(air_names.len(), widths.len()); + assert_eq!(air_names.len(), interactions.len()); + + Self { + segments: Vec::new(), + air_names, + widths, + interactions, + segmentation_limits, + } + } + + pub fn new_with_default_segmentation_limits( + air_names: Vec, + widths: Vec, + interactions: Vec, + ) -> Self { + assert_eq!(air_names.len(), widths.len()); + assert_eq!(air_names.len(), interactions.len()); + + Self { + segments: Vec::new(), + air_names, + widths, + interactions, + segmentation_limits: SegmentationLimits::default(), + } + } + + pub fn set_max_trace_height(&mut self, max_trace_height: u32) { + self.segmentation_limits.max_trace_height = max_trace_height; + } + + pub fn set_max_cells(&mut self, max_cells: usize) { + self.segmentation_limits.max_cells = max_cells; + } + + pub fn set_max_interactions(&mut self, max_interactions: usize) { + self.segmentation_limits.max_interactions = max_interactions; + } + + /// Calculate the total cells used based on trace heights and widths + #[inline(always)] + fn calculate_total_cells(&self, trace_heights: &[u32]) -> usize { + debug_assert_eq!(trace_heights.len(), self.widths.len()); + + // SAFETY: Length equality is asserted during initialization + let widths_slice = unsafe { self.widths.get_unchecked(..trace_heights.len()) }; + + trace_heights + .iter() + .zip(widths_slice) + .map(|(&height, &width)| height as usize * width) + .sum() + } + + /// Calculate the total interactions based on trace heights and interaction counts + #[inline(always)] + fn calculate_total_interactions(&self, trace_heights: &[u32]) -> usize { + debug_assert_eq!(trace_heights.len(), self.interactions.len()); + + // SAFETY: Length equality is asserted during initialization + let interactions_slice = unsafe { self.interactions.get_unchecked(..trace_heights.len()) }; + + trace_heights + .iter() + .zip(interactions_slice) + // We add 1 for the zero messages from the padding rows + .map(|(&height, &interactions)| (height + 1) as usize * interactions) + .sum() + } + + #[inline(always)] + fn should_segment( + &self, + instret: u64, + trace_heights: &[u32], + is_trace_height_constant: &[bool], + ) -> bool { + debug_assert_eq!(trace_heights.len(), is_trace_height_constant.len()); + debug_assert_eq!(trace_heights.len(), self.air_names.len()); + + let instret_start = self + .segments + .last() + .map_or(0, |s| s.instret_start + s.num_insns); + let num_insns = instret - instret_start; + + // Segment should contain at least one cycle + if num_insns == 0 { + return false; + } + + for (i, (height, is_constant)) in trace_heights + .iter() + .zip(is_trace_height_constant.iter()) + .enumerate() + { + // Only segment if the height is not constant and exceeds the maximum height + if !is_constant && *height > self.segmentation_limits.max_trace_height { + let air_name = &self.air_names[i]; + tracing::info!( + "Segment {:2} | instret {:9} | chip {} ({}) height ({:8}) > max ({:8})", + self.segments.len(), + instret, + i, + air_name, + height, + self.segmentation_limits.max_trace_height + ); + return true; + } + } + + let total_cells = self.calculate_total_cells(trace_heights); + if total_cells > self.segmentation_limits.max_cells { + tracing::info!( + "Segment {:2} | instret {:9} | total cells ({:10}) > max ({:10})", + self.segments.len(), + instret, + total_cells, + self.segmentation_limits.max_cells + ); + return true; + } + + let total_interactions = self.calculate_total_interactions(trace_heights); + if total_interactions > self.segmentation_limits.max_interactions { + tracing::info!( + "Segment {:2} | instret {:9} | total interactions ({:11}) > max ({:11})", + self.segments.len(), + instret, + total_interactions, + self.segmentation_limits.max_interactions + ); + return true; + } + + false + } + + #[inline(always)] + pub fn check_and_segment( + &mut self, + instret: u64, + trace_heights: &[u32], + is_trace_height_constant: &[bool], + ) -> bool { + let ret = self.should_segment(instret, trace_heights, is_trace_height_constant); + if ret { + self.segment(instret, trace_heights); + } + ret + } + + /// Try segment if there is at least one cycle + #[inline(always)] + pub fn segment(&mut self, instret: u64, trace_heights: &[u32]) { + let instret_start = self + .segments + .last() + .map_or(0, |s| s.instret_start + s.num_insns); + let num_insns = instret - instret_start; + + debug_assert!(num_insns > 0, "Segment should contain at least one cycle"); + + self.segments.push(Segment { + instret_start, + num_insns, + trace_heights: trace_heights.to_vec(), + }); + } +} diff --git a/crates/vm/src/arch/execution_mode/mod.rs b/crates/vm/src/arch/execution_mode/mod.rs new file mode 100644 index 0000000000..187082e630 --- /dev/null +++ b/crates/vm/src/arch/execution_mode/mod.rs @@ -0,0 +1,15 @@ +use crate::{arch::VmExecState, system::memory::online::GuestMemory}; + +pub mod e1; +pub mod metered; +pub mod tracegen; + +pub trait E1ExecutionCtx: Sized { + fn on_memory_operation(&mut self, address_space: u32, ptr: u32, size: u32); + fn should_suspend(vm_state: &mut VmExecState) -> bool; + fn on_terminate(_vm_state: &mut VmExecState) {} +} + +pub trait E2ExecutionCtx: E1ExecutionCtx { + fn on_height_change(&mut self, chip_idx: usize, height_delta: u32); +} diff --git a/crates/vm/src/arch/execution_mode/tracegen.rs b/crates/vm/src/arch/execution_mode/tracegen.rs new file mode 100644 index 0000000000..d4a611dcac --- /dev/null +++ b/crates/vm/src/arch/execution_mode/tracegen.rs @@ -0,0 +1,24 @@ +use crate::arch::Arena; + +pub struct TracegenCtx { + pub arenas: Vec, + pub instret_end: Option, +} + +impl TracegenCtx { + /// `capacities` is list of `(height, width)` dimensions for each arena, indexed by AIR index. + /// The length of `capacities` must equal the number of AIRs. + /// Here `height` will always mean an overestimate of the trace height for that AIR, while + /// `width` may have different meanings depending on the `RA` type. + pub fn new_with_capacity(capacities: &[(usize, usize)], instret_end: Option) -> Self { + let arenas = capacities + .iter() + .map(|&(height, main_width)| RA::with_capacity(height, main_width)) + .collect(); + + Self { + arenas, + instret_end, + } + } +} diff --git a/crates/vm/src/arch/extensions.rs b/crates/vm/src/arch/extensions.rs index adda318f6a..7d0bdcf52e 100644 --- a/crates/vm/src/arch/extensions.rs +++ b/crates/vm/src/arch/extensions.rs @@ -1,56 +1,47 @@ +/// A full VM extension consists of three components, represented by sub-traits: +/// - [VmExecutionExtension] +/// - [VmCircuitExtension] +/// - [VmProverExtension]: there may be multiple implementations of `VmProverExtension` for the +/// same `VmCircuitExtension` for different prover backends. +/// +/// It is intended that `VmExecutionExtension` and `VmCircuitExtension` are implemented on the +/// same struct and `VmProverExtension` is implemented on a separate struct (usually a ZST) to +/// get around Rust orphan rules. use std::{ - any::{Any, TypeId}, - cell::RefCell, - iter::once, - sync::{Arc, Mutex}, + any::{type_name, Any}, + iter::{self, zip}, + sync::Arc, }; -use derive_more::derive::From; -use getset::Getters; -use itertools::{zip_eq, Itertools}; -#[cfg(feature = "bench-metrics")] -use metrics::counter; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; -use openvm_circuit_primitives::{ - utils::next_power_of_two_or_zero, - var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_instructions::{ - program::Program, LocalOpcode, PhantomDiscriminant, PublishOpcode, SystemOpcode, VmOpcode, +use getset::{CopyGetters, Getters}; +use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerAir, }; +use openvm_instructions::{PhantomDiscriminant, VmOpcode}; use openvm_stark_backend::{ - config::{Domain, StarkGenericConfig}, - interaction::{BusIndex, PermutationCheckBus}, - keygen::types::LinearConstraint, - p3_commit::PolynomialSpace, - p3_field::{FieldAlgebra, PrimeField32, TwoAdicField}, - p3_matrix::Matrix, - p3_util::log2_ceil_usize, - prover::types::{AirProofInput, CommittedTraceData, ProofInput}, - AirRef, Chip, ChipUsageGetter, + config::{StarkGenericConfig, Val}, + engine::StarkEngine, + interaction::BusIndex, + keygen::types::MultiStarkProvingKey, + prover::{ + cpu::CpuBackend, + hal::ProverBackend, + types::{AirProvingContext, ProvingContext}, + }, + rap::AnyRap, + AirRef, AnyChip, Chip, }; -use p3_baby_bear::BabyBear; use rustc_hash::FxHashMap; -use serde::{Deserialize, Serialize}; - -use super::{ - vm_poseidon2_config, ExecutionBus, GenerationError, InstructionExecutor, PhantomSubExecutor, - Streams, SystemConfig, SystemTraceHeights, -}; -#[cfg(feature = "bench-metrics")] -use crate::metrics::VmMetrics; -use crate::system::{ - connector::VmConnectorChip, - memory::{ - offline_checker::{MemoryBridge, MemoryBus}, - MemoryController, MemoryImage, OfflineMemory, BOUNDARY_AIR_OFFSET, MERKLE_AIR_OFFSET, +use tracing::info_span; + +use super::{GenerationError, PhantomSubExecutor, SystemConfig}; +use crate::{ + arch::Arena, + system::{ + memory::{BOUNDARY_AIR_OFFSET, MERKLE_AIR_OFFSET}, + phantom::PhantomExecutor, + SystemAirInventory, SystemChipComplex, SystemRecords, }, - native_adapter::NativeAdapterChip, - phantom::PhantomChip, - poseidon2::Poseidon2PeripheryChip, - program::{ProgramBus, ProgramChip}, - public_values::{core::PublicValuesCoreChip, PublicValuesChip}, }; /// Global AIR ID in the VM circuit verifying key. @@ -67,240 +58,169 @@ pub const BOUNDARY_AIR_ID: usize = PUBLIC_VALUES_AIR_ID + 1 + BOUNDARY_AIR_OFFSE /// Merkle AIR commits start/final memory states. pub const MERKLE_AIR_ID: usize = CONNECTOR_AIR_ID + 1 + MERKLE_AIR_OFFSET; -/// Configuration for a processor extension. -/// -/// There are two associated types: -/// - `Executor`: enum for chips that are [`InstructionExecutor`]s. -/// - -pub trait VmExtension { - /// Enum of chips that implement [`InstructionExecutor`] for instruction execution. - /// `Executor` **must** implement `Chip` but the trait bound is omitted to omit the - /// `StarkGenericConfig` generic parameter. - type Executor: InstructionExecutor + AnyEnum; - /// Enum of periphery chips that do not implement [`InstructionExecutor`]. - /// `Periphery` **must** implement `Chip` but the trait bound is omitted to omit the - /// `StarkGenericConfig` generic parameter. - type Periphery: AnyEnum; - - fn build( +pub type ExecutorId = u32; + +// ======================= VM Extension Traits ============================= + +/// Extension of VM execution. Allows registration of custom execution of new instructions by +/// opcode. +pub trait VmExecutionExtension { + /// Enum of executor variants + type Executor: AnyEnum; + + fn extend_execution( &self, - builder: &mut VmInventoryBuilder, - ) -> Result, VmInventoryError>; + inventory: &mut ExecutorInventoryBuilder, + ) -> Result<(), ExecutorInventoryError>; } -impl> VmExtension for Option { - type Executor = E::Executor; - type Periphery = E::Periphery; +/// Extension of the VM circuit. Allows _in-order_ addition of new AIRs with interactions. +pub trait VmCircuitExtension { + fn extend_circuit(&self, inventory: &mut AirInventory) -> Result<(), AirInventoryError>; +} - fn build( +/// Extension of VM trace generation. The generics are `E` for [StarkEngine], `RA` for record arena, +/// and `EXT` for execution and circuit extension. The returned vector should exactly match the +/// order of AIRs in [`VmCircuitExtension`] for this extension. +/// +/// Note that this trait differs from [VmExecutionExtension] and [VmCircuitExtension]. This trait is +/// meant to be implemented on a separate ZST which may be different for different [ProverBackend]s. +/// This is done to get around Rust orphan rules. +pub trait VmProverExtension +where + E: StarkEngine, + EXT: VmExecutionExtension> + VmCircuitExtension, +{ + /// We do not provide access to the [ExecutorInventory] because the process to find an executor + /// from the inventory seems more cumbersome than to simply re-construct any necessary executors + /// directly within this function implementation. + fn extend_prover( &self, - builder: &mut VmInventoryBuilder, - ) -> Result, VmInventoryError> { - if let Some(extension) = self { - extension.build(builder) - } else { - Ok(VmInventory::new()) - } - } + extension: &EXT, + inventory: &mut ChipInventory, + ) -> Result<(), ChipInventoryError>; } -/// SystemPort combines system resources needed by most extensions -#[derive(Clone, Copy)] -pub struct SystemPort { - pub execution_bus: ExecutionBus, - pub program_bus: ProgramBus, - pub memory_bridge: MemoryBridge, +// ======================= Different Inventory Struct Definitions ============================= + +pub struct ExecutorInventory { + config: SystemConfig, + /// Lookup table to executor ID. + /// This is stored in a hashmap because it is _not_ expected to be used in the hot path. + /// A direct opcode -> executor mapping should be generated before runtime execution. + pub instruction_lookup: FxHashMap, + pub executors: Vec, + /// `ext_start[i]` will have the starting index in `executors` for extension `i` + ext_start: Vec, } -/// Builder for processing unit. Processing units extend an existing system unit. -pub struct VmInventoryBuilder<'a, F: PrimeField32> { - system_config: &'a SystemConfig, - system: &'a SystemBase, - streams: &'a Arc>>, - bus_idx_mgr: BusIndexManager, +// @dev: We need ExecutorInventoryBuilder separate from ExecutorInventory because of how +// ExecutorInventory::extend works: we want to build an inventory with some big E3 enum that +// includes both enum types E1, E2. However the interface for an ExecutionExtension will only know +// about the enum E2. In order to be able to allow access to the old executors with type E1 without +// referring to the type E1, we need to create this separate builder struct. +pub struct ExecutorInventoryBuilder<'a, F, E> { /// Chips that are already included in the chipset and may be used /// as dependencies. The order should be that depended-on chips are ordered /// **before** their dependents. - chips: Vec<&'a dyn AnyEnum>, + old_executors: Vec<&'a dyn AnyEnum>, + new_inventory: ExecutorInventory, + phantom_executors: FxHashMap>>, } -impl<'a, F: PrimeField32> VmInventoryBuilder<'a, F> { - pub fn new( - system_config: &'a SystemConfig, - system: &'a SystemBase, - streams: &'a Arc>>, - bus_idx_mgr: BusIndexManager, - ) -> Self { - Self { - system_config, - system, - streams, - bus_idx_mgr, - chips: Vec::new(), - } - } - - pub fn system_config(&self) -> &SystemConfig { - self.system_config - } - - pub fn system_base(&self) -> &SystemBase { - self.system - } - - pub fn system_port(&self) -> SystemPort { - SystemPort { - execution_bus: self.system_base().execution_bus(), - program_bus: self.system_base().program_bus(), - memory_bridge: self.system_base().memory_bridge(), - } - } - - pub fn new_bus_idx(&mut self) -> BusIndex { - self.bus_idx_mgr.new_bus_idx() - } - - /// Looks through built chips to see if there exists any of type `C` by downcasting. - /// Returns all chips of type `C` in the chipset. +#[derive(Clone, Getters, CopyGetters)] +pub struct AirInventory { + #[get = "pub"] + config: SystemConfig, + /// The system AIRs required by the circuit architecture. + #[get = "pub"] + system: SystemAirInventory, + /// List of all non-system AIRs in the circuit, in insertion order, which is the **reverse** of + /// the order they appear in the verifying key. /// - /// Note: the type `C` will usually be a smart pointer to a chip. - pub fn find_chip(&self) -> Vec<&C> { - self.chips - .iter() - .filter_map(|c| c.as_any_kind().downcast_ref()) - .collect() - } + /// Note that the system will ensure that the first AIR in the list is always the + /// [VariableRangeCheckerAir]. + #[get = "pub"] + ext_airs: Vec>, + /// `ext_start[i]` will have the starting index in `ext_airs` for extension `i` + ext_start: Vec, - /// The generic `F` must match that of the `PhantomChip`. - pub fn add_phantom_sub_executor + 'static>( - &self, - phantom_sub: PE, - discriminant: PhantomDiscriminant, - ) -> Result<(), VmInventoryError> { - let chip_ref: &RefCell> = - self.find_chip().first().expect("PhantomChip always exists"); - let mut chip = chip_ref.borrow_mut(); - let existing = chip.add_sub_executor(phantom_sub, discriminant); - if existing.is_some() { - return Err(VmInventoryError::PhantomSubExecutorExists { discriminant }); - } - Ok(()) - } - - /// Shareable streams. Clone to get a shared mutable reference. - pub fn streams(&self) -> &Arc>> { - self.streams - } - - fn add_chip(&mut self, chip: &'a E) { - self.chips.push(chip); - } -} - -#[derive(Clone, Debug)] -pub struct VmInventory { - /// Lookup table to executor ID. We store executors separately due to mutable borrow issues. - instruction_lookup: FxHashMap, - pub executors: Vec, - pub periphery: Vec

, - /// Order of insertion. The reverse of this will be the order the chips are destroyed - /// to generate trace. - insertion_order: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct VmInventoryTraceHeights { - pub chips: FxHashMap, + bus_idx_mgr: BusIndexManager, } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, derive_new::new)] -pub struct VmComplexTraceHeights { - pub system: SystemTraceHeights, - pub inventory: VmInventoryTraceHeights, +#[derive(Clone, Copy, Debug, Default)] +pub struct BusIndexManager { + /// All existing buses use indices in [0, bus_idx_max) + bus_idx_max: BusIndex, } -type ExecutorId = usize; - -#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] -pub enum ChipId { - Executor(usize), - Periphery(usize), +// @dev: ChipInventory does not have the SystemChipComplex because that is custom depending on `PB`. +// The full struct with SystemChipComplex is VmChipComplex +#[derive(Getters)] +pub struct ChipInventory +where + SC: StarkGenericConfig, + PB: ProverBackend, +{ + /// Read-only view of AIRs, as constructed via the [VmCircuitExtension] trait. + #[get = "pub"] + airs: AirInventory, + /// Chips that are being built. + #[get = "pub"] + chips: Vec>>, + + /// Number of extensions that have chips added, including the current one that is still being + /// built. + cur_num_exts: usize, + /// Mapping from executor index to chip insertion index. Chips must be added in order so the + /// chip insertion index matches the AIR insertion index. Reminder: this is in **reverse** + /// order of the verifying key AIR ordering. + /// + /// Note: if public values chip exists, then it will be the first entry and point to + /// `usize::MAX`. This entry should never be used. + pub executor_idx_to_insertion_idx: Vec, } -#[derive(thiserror::Error, Debug)] -pub enum VmInventoryError { - #[error("Opcode {opcode} already owned by executor id {id}")] - ExecutorExists { opcode: VmOpcode, id: ExecutorId }, - #[error("Phantom discriminant {} already has sub-executor", .discriminant.0)] - PhantomSubExecutorExists { discriminant: PhantomDiscriminant }, - #[error("Chip {name} not found")] - ChipNotFound { name: String }, +/// The collection of all chips in the VM. The chips should correspond 1-to-1 with the associated +/// [AirInventory]. The [VmChipComplex] coordinates the trace generation for all chips in the VM +/// after construction. +#[derive(Getters)] +pub struct VmChipComplex +where + SC: StarkGenericConfig, + PB: ProverBackend, +{ + /// System chip complex responsible for trace generation of [SystemAirInventory] + pub system: SCC, + pub inventory: ChipInventory, } -impl Default for VmInventory { - fn default() -> Self { - Self::new() - } -} +// ======================= Inventory Function Definitions ============================= -impl VmInventory { - pub fn new() -> Self { +impl ExecutorInventory { + /// Empty inventory should be created at the start of the declaration of a new extension. + #[allow(clippy::new_without_default)] + pub fn new(config: SystemConfig) -> Self { Self { - instruction_lookup: FxHashMap::default(), - executors: Vec::new(), - periphery: Vec::new(), - insertion_order: Vec::new(), - } - } - - pub fn transmute(self) -> VmInventory - where - E: Into, - P: Into, - { - VmInventory { - instruction_lookup: self.instruction_lookup, - executors: self.executors.into_iter().map(|e| e.into()).collect(), - periphery: self.periphery.into_iter().map(|p| p.into()).collect(), - insertion_order: self.insertion_order, - } - } - - /// Append `other` to current inventory. This means `self` comes earlier in the dependency - /// chain. - pub fn append(&mut self, mut other: VmInventory) -> Result<(), VmInventoryError> { - let num_executors = self.executors.len(); - let num_periphery = self.periphery.len(); - for (opcode, mut id) in other.instruction_lookup.into_iter() { - id += num_executors; - if let Some(old_id) = self.instruction_lookup.insert(opcode, id) { - return Err(VmInventoryError::ExecutorExists { opcode, id: old_id }); - } - } - for chip_id in other.insertion_order.iter_mut() { - match chip_id { - ChipId::Executor(id) => *id += num_executors, - ChipId::Periphery(id) => *id += num_periphery, - } + config, + instruction_lookup: Default::default(), + executors: Default::default(), + ext_start: vec![0], } - self.executors.append(&mut other.executors); - self.periphery.append(&mut other.periphery); - self.insertion_order.append(&mut other.insertion_order); - Ok(()) } /// Inserts an executor with the collection of opcodes that it handles. - /// If some executor already owns one of the opcodes, it will be replaced and the old - /// executor ID is returned. + /// If some executor already owns one of the opcodes, an error is returned with the existing + /// executor. pub fn add_executor( &mut self, executor: impl Into, opcodes: impl IntoIterator, - ) -> Result<(), VmInventoryError> { + ) -> Result<(), ExecutorInventoryError> { let opcodes: Vec<_> = opcodes.into_iter().collect(); for opcode in &opcodes { if let Some(id) = self.instruction_lookup.get(opcode) { - return Err(VmInventoryError::ExecutorExists { + return Err(ExecutorInventoryError::ExecutorExists { opcode: *opcode, id: *id, }); @@ -308,897 +228,521 @@ impl VmInventory { } let id = self.executors.len(); self.executors.push(executor.into()); - self.insertion_order.push(ChipId::Executor(id)); for opcode in opcodes { - self.instruction_lookup.insert(opcode, id); + self.instruction_lookup + .insert(opcode, id.try_into().unwrap()); } Ok(()) } - pub fn add_periphery_chip(&mut self, periphery_chip: impl Into

) { - let id = self.periphery.len(); - self.periphery.push(periphery_chip.into()); - self.insertion_order.push(ChipId::Periphery(id)); - } - - pub fn get_executor(&self, opcode: VmOpcode) -> Option<&E> { - let id = self.instruction_lookup.get(&opcode)?; - self.executors.get(*id) - } - - pub fn get_mut_executor(&mut self, opcode: &VmOpcode) -> Option<&mut E> { - let id = self.instruction_lookup.get(opcode)?; - self.executors.get_mut(*id) - } - - pub fn executors(&self) -> &[E] { - &self.executors - } - - pub fn periphery(&self) -> &[P] { - &self.periphery - } + /// Extend the inventory with a new extension. + /// A new inventory with different type generics is returned with the combined inventory. + pub fn extend( + self, + other: &EXT, + ) -> Result, ExecutorInventoryError> + where + F: 'static, + E: Into + AnyEnum, + E3: AnyEnum, + EXT: VmExecutionExtension, + EXT::Executor: Into, + { + let mut builder: ExecutorInventoryBuilder = self.builder(); + other.extend_execution(&mut builder)?; + let other_inventory = builder.new_inventory; + let other_phantom_executors = builder.phantom_executors; + let mut inventory_ext = self.transmute(); + inventory_ext.append(other_inventory.transmute())?; + let phantom_chip: &mut PhantomExecutor = inventory_ext + .find_executor_mut() + .next() + .expect("system always has phantom chip"); + let phantom_executors = &mut phantom_chip.phantom_executors; + for (discriminant, sub_executor) in other_phantom_executors { + if phantom_executors + .insert(discriminant, sub_executor) + .is_some() + { + return Err(ExecutorInventoryError::PhantomSubExecutorExists { discriminant }); + } + } - pub fn num_airs(&self) -> usize { - self.executors.len() + self.periphery.len() + Ok(inventory_ext) } - /// Return trace heights of all chips in the inventory. - /// The order is deterministic: - /// - All executors come first, in the order they were added. - /// - All periphery chips come after, in the order they were added. - pub fn get_trace_heights(&self) -> VmInventoryTraceHeights + pub fn builder(&self) -> ExecutorInventoryBuilder<'_, F, E2> where - E: ChipUsageGetter, - P: ChipUsageGetter, + F: 'static, + E: AnyEnum, { - VmInventoryTraceHeights { - chips: self - .executors - .iter() - .enumerate() - .map(|(i, chip)| (ChipId::Executor(i), chip.current_trace_height())) - .chain( - self.periphery - .iter() - .enumerate() - .map(|(i, chip)| (ChipId::Periphery(i), chip.current_trace_height())), - ) - .collect(), + let old_executors = self.executors.iter().map(|e| e as &dyn AnyEnum).collect(); + ExecutorInventoryBuilder { + old_executors, + new_inventory: ExecutorInventory::new(self.config.clone()), + phantom_executors: Default::default(), } } - /// Return the dummy trace heights of the inventory. This is used for generating a dummy proof. - /// Regular users should not need this. - pub fn get_dummy_trace_heights(&self) -> VmInventoryTraceHeights + pub fn transmute(self) -> ExecutorInventory where - E: ChipUsageGetter, - P: ChipUsageGetter, + E: Into, { - VmInventoryTraceHeights { - chips: self - .executors - .iter() - .enumerate() - .map(|(i, _)| (ChipId::Executor(i), 1)) - .chain(self.periphery.iter().enumerate().map(|(i, chip)| { - ( - ChipId::Periphery(i), - chip.constant_trace_height().unwrap_or(1), - ) - })) - .collect(), + ExecutorInventory { + config: self.config, + instruction_lookup: self.instruction_lookup, + executors: self.executors.into_iter().map(|e| e.into()).collect(), + ext_start: self.ext_start, } } -} - -impl VmInventoryTraceHeights { - /// Round all trace heights to the next power of two. This will round trace heights of 0 to 1. - pub fn round_to_next_power_of_two(&mut self) { - self.chips - .values_mut() - .for_each(|v| *v = v.next_power_of_two()); - } - - /// Round all trace heights to the next power of two, except 0 stays 0. - pub fn round_to_next_power_of_two_or_zero(&mut self) { - self.chips - .values_mut() - .for_each(|v| *v = next_power_of_two_or_zero(*v)); - } -} - -impl VmComplexTraceHeights { - /// Round all trace heights to the next power of two. This will round trace heights of 0 to 1. - pub fn round_to_next_power_of_two(&mut self) { - self.system.round_to_next_power_of_two(); - self.inventory.round_to_next_power_of_two(); - } - - /// Round all trace heights to the next power of two, except 0 stays 0. - pub fn round_to_next_power_of_two_or_zero(&mut self) { - self.system.round_to_next_power_of_two_or_zero(); - self.inventory.round_to_next_power_of_two_or_zero(); - } -} - -// PublicValuesChip needs F: PrimeField32 due to Adapter -/// The minimum collection of chips that any VM must have. -#[derive(Getters)] -pub struct VmChipComplex { - #[getset(get = "pub")] - config: SystemConfig, - // ATTENTION: chip destruction should follow the **reverse** of the following field order: - pub base: SystemBase, - /// Extendable collection of chips for executing instructions. - /// System ensures it contains: - /// - PhantomChip - /// - PublicValuesChip if continuations disabled - /// - Poseidon2Chip if continuations enabled - pub inventory: VmInventory, - overridden_inventory_heights: Option, - - /// Absolute maximum value a trace height can be and still be provable. - max_trace_height: usize, - - streams: Arc>>, - bus_idx_mgr: BusIndexManager, -} - -#[derive(Clone, Copy, Debug, Default)] -pub struct BusIndexManager { - /// All existing buses use indices in [0, bus_idx_max) - bus_idx_max: BusIndex, -} - -impl BusIndexManager { - pub fn new() -> Self { - Self { bus_idx_max: 0 } - } - - pub fn new_bus_idx(&mut self) -> BusIndex { - let idx = self.bus_idx_max; - self.bus_idx_max = self.bus_idx_max.checked_add(1).unwrap(); - idx - } -} - -/// The base [VmChipComplex] with only system chips. -pub type SystemComplex = VmChipComplex, SystemPeriphery>; - -/// Base system chips. -/// The following don't execute instructions, but are essential -/// for the VM architecture. -pub struct SystemBase { - // RangeCheckerChip **must** be the last chip to have trace generation called on - pub range_checker_chip: SharedVariableRangeCheckerChip, - pub memory_controller: MemoryController, - pub connector_chip: VmConnectorChip, - pub program_chip: ProgramChip, -} -impl SystemBase { - pub fn range_checker_bus(&self) -> VariableRangeCheckerBus { - self.range_checker_chip.bus() - } - - pub fn memory_bus(&self) -> MemoryBus { - self.memory_controller.memory_bus + /// Append `other` to current inventory. This means `self` comes earlier in the dependency + /// chain. + fn append(&mut self, mut other: ExecutorInventory) -> Result<(), ExecutorInventoryError> { + let num_executors = self.executors.len(); + for (opcode, mut id) in other.instruction_lookup.into_iter() { + id = id.checked_add(num_executors.try_into().unwrap()).unwrap(); + if let Some(old_id) = self.instruction_lookup.insert(opcode, id) { + return Err(ExecutorInventoryError::ExecutorExists { opcode, id: old_id }); + } + } + for id in &mut other.ext_start { + *id = id.checked_add(num_executors).unwrap(); + } + self.executors.append(&mut other.executors); + self.ext_start.append(&mut other.ext_start); + Ok(()) } - pub fn program_bus(&self) -> ProgramBus { - self.program_chip.air.bus + pub fn get_executor(&self, opcode: VmOpcode) -> Option<&E> { + let id = self.instruction_lookup.get(&opcode)?; + self.executors.get(*id as usize) } - pub fn memory_bridge(&self) -> MemoryBridge { - self.memory_controller.memory_bridge() + pub fn get_mut_executor(&mut self, opcode: &VmOpcode) -> Option<&mut E> { + let id = self.instruction_lookup.get(opcode)?; + self.executors.get_mut(*id as usize) } - pub fn offline_memory(&self) -> Arc>> { - self.memory_controller.offline_memory().clone() + pub fn executors(&self) -> &[E] { + &self.executors } - pub fn execution_bus(&self) -> ExecutionBus { - self.connector_chip.air.execution_bus + pub fn find_executor(&self) -> impl Iterator + where + E: AnyEnum, + { + self.executors + .iter() + .filter_map(|e| e.as_any_kind().downcast_ref()) } - /// Return trace heights of SystemBase. Usually this is for aggregation and not useful for - /// regular users. - pub fn get_system_trace_heights(&self) -> SystemTraceHeights { - SystemTraceHeights { - memory: self.memory_controller.get_memory_trace_heights(), - } + pub fn find_executor_mut(&mut self) -> impl Iterator + where + E: AnyEnum, + { + self.executors + .iter_mut() + .filter_map(|e| e.as_any_kind_mut().downcast_mut()) } - /// Return dummy trace heights of SystemBase. Usually this is for aggregation to generate a - /// dummy proof and not useful for regular users. - pub fn get_dummy_system_trace_heights(&self) -> SystemTraceHeights { - SystemTraceHeights { - memory: self.memory_controller.get_dummy_memory_trace_heights(), - } + /// Returns the system config of the inventory. + pub fn config(&self) -> &SystemConfig { + &self.config } } -#[derive(ChipUsageGetter, Chip, AnyEnum, From, InstructionExecutor)] -pub enum SystemExecutor { - PublicValues(PublicValuesChip), - Phantom(RefCell>), -} - -#[derive(ChipUsageGetter, Chip, AnyEnum, From)] -pub enum SystemPeriphery { - /// Poseidon2 chip with direct compression interactions - Poseidon2(Poseidon2PeripheryChip), -} - -impl SystemComplex { - pub fn new(config: SystemConfig) -> Self { - let mut bus_idx_mgr = BusIndexManager::new(); - let execution_bus = ExecutionBus::new(bus_idx_mgr.new_bus_idx()); - let memory_bus = MemoryBus::new(bus_idx_mgr.new_bus_idx()); - let program_bus = ProgramBus::new(bus_idx_mgr.new_bus_idx()); - let range_bus = - VariableRangeCheckerBus::new(bus_idx_mgr.new_bus_idx(), config.memory_config.decomp); - - let range_checker = SharedVariableRangeCheckerChip::new(range_bus); - let memory_controller = if config.continuation_enabled { - MemoryController::with_persistent_memory( - memory_bus, - config.memory_config, - range_checker.clone(), - PermutationCheckBus::new(bus_idx_mgr.new_bus_idx()), - PermutationCheckBus::new(bus_idx_mgr.new_bus_idx()), - ) - } else { - MemoryController::with_volatile_memory( - memory_bus, - config.memory_config, - range_checker.clone(), - ) - }; - let memory_bridge = memory_controller.memory_bridge(); - let offline_memory = memory_controller.offline_memory(); - let program_chip = ProgramChip::new(program_bus); - let connector_chip = VmConnectorChip::new( - execution_bus, - program_bus, - range_checker.clone(), - config.memory_config.clk_max_bits, - ); - - let mut inventory = VmInventory::new(); - // PublicValuesChip is required when num_public_values > 0 in single segment mode. - if config.has_public_values_chip() { - assert_eq!(inventory.executors().len(), Self::PV_EXECUTOR_IDX); - let chip = PublicValuesChip::new( - NativeAdapterChip::new(execution_bus, program_bus, memory_bridge), - PublicValuesCoreChip::new( - config.num_public_values, - config.max_constraint_degree as u32 - 1, - ), - offline_memory, - ); - inventory - .add_executor(chip, [PublishOpcode::PUBLISH.global_opcode()]) - .unwrap(); - } - if config.continuation_enabled { - assert_eq!(inventory.periphery().len(), Self::POSEIDON2_PERIPHERY_IDX); - // Add direct poseidon2 chip for persistent memory. - // This is **not** an instruction executor. - // Currently we never use poseidon2 opcodes when continuations is enabled: we will need - // special handling when that happens - let direct_bus_idx = memory_controller - .interface_chip - .compression_bus() - .unwrap() - .index; - let chip = Poseidon2PeripheryChip::new( - vm_poseidon2_config(), - direct_bus_idx, - config.max_constraint_degree, - ); - inventory.add_periphery_chip(chip); - } - let streams = Arc::new(Mutex::new(Streams::default())); - let phantom_opcode = SystemOpcode::PHANTOM.global_opcode(); - let mut phantom_chip = - PhantomChip::new(execution_bus, program_bus, SystemOpcode::CLASS_OFFSET); - phantom_chip.set_streams(streams.clone()); - inventory - .add_executor(RefCell::new(phantom_chip), [phantom_opcode]) - .unwrap(); - - let base = SystemBase { - program_chip, - connector_chip, - memory_controller, - range_checker_chip: range_checker, - }; - - let max_trace_height = if TypeId::of::() == TypeId::of::() { - let min_log_blowup = log2_ceil_usize(config.max_constraint_degree - 1); - 1 << (BabyBear::TWO_ADICITY - min_log_blowup) - } else { - tracing::warn!( - "constructing SystemComplex for unrecognized field; using max_trace_height = 2^30" - ); - 1 << 30 - }; - - Self { - config, - base, - inventory, - bus_idx_mgr, - streams, - overridden_inventory_heights: None, - max_trace_height, - } +impl ExecutorInventoryBuilder<'_, F, E> { + pub fn add_executor( + &mut self, + executor: impl Into, + opcodes: impl IntoIterator, + ) -> Result<(), ExecutorInventoryError> { + self.new_inventory.add_executor(executor, opcodes) } -} - -impl VmChipComplex { - /// **If** public values chip exists, then its executor index is 0. - pub(super) const PV_EXECUTOR_IDX: ExecutorId = 0; - /// **If** internal poseidon2 chip exists, then its periphery index is 0. - pub(super) const POSEIDON2_PERIPHERY_IDX: usize = 0; - // @dev: Remember to update self.bus_idx_mgr after dropping this! - pub fn inventory_builder(&self) -> VmInventoryBuilder + pub fn add_phantom_sub_executor( + &mut self, + phantom_sub: PE, + discriminant: PhantomDiscriminant, + ) -> Result<(), ExecutorInventoryError> where E: AnyEnum, - P: AnyEnum, + F: 'static, + PE: PhantomSubExecutor + 'static, { - let mut builder = - VmInventoryBuilder::new(&self.config, &self.base, &self.streams, self.bus_idx_mgr); - // Add range checker for convenience, the other system base chips aren't included - they can - // be accessed directly from builder - builder.add_chip(&self.base.range_checker_chip); - for chip in self.inventory.executors() { - builder.add_chip(chip); - } - for chip in self.inventory.periphery() { - builder.add_chip(chip); + let existing = self + .phantom_executors + .insert(discriminant, Arc::new(phantom_sub)); + if existing.is_some() { + return Err(ExecutorInventoryError::PhantomSubExecutorExists { discriminant }); } - - builder - } - - /// Extend the chip complex with a new extension. - /// A new chip complex with different type generics is returned with the combined inventory. - pub fn extend( - mut self, - config: &Ext, - ) -> Result, VmInventoryError> - where - Ext: VmExtension, - E: Into + AnyEnum, - P: Into + AnyEnum, - Ext::Executor: Into, - Ext::Periphery: Into, - { - let mut builder = self.inventory_builder(); - let inventory_ext = config.build(&mut builder)?; - self.bus_idx_mgr = builder.bus_idx_mgr; - let mut ext_complex = self.transmute(); - ext_complex.append(inventory_ext.transmute())?; - Ok(ext_complex) + Ok(()) } - pub fn transmute(self) -> VmChipComplex + pub fn find_executor(&self) -> impl Iterator where - E: Into, - P: Into, + E: AnyEnum, { - VmChipComplex { - config: self.config, - base: self.base, - inventory: self.inventory.transmute(), - bus_idx_mgr: self.bus_idx_mgr, - streams: self.streams, - overridden_inventory_heights: self.overridden_inventory_heights, - max_trace_height: self.max_trace_height, - } + self.old_executors + .iter() + .filter_map(|e| e.as_any_kind().downcast_ref()) } - /// Appends `other` to the current inventory. - /// This means `self` comes earlier in the dependency chain. - pub fn append(&mut self, other: VmInventory) -> Result<(), VmInventoryError> { - self.inventory.append(other) + /// Returns the maximum number of bits used to represent addresses in memory + pub fn pointer_max_bits(&self) -> usize { + self.new_inventory.config().memory_config.pointer_max_bits } +} - pub fn program_chip(&self) -> &ProgramChip { - &self.base.program_chip +impl AirInventory { + /// Outside of this crate, [AirInventory] must be constructed via [SystemConfig]. + pub(crate) fn new( + config: SystemConfig, + system: SystemAirInventory, + bus_idx_mgr: BusIndexManager, + ) -> Self { + Self { + config, + system, + ext_start: Vec::new(), + ext_airs: Vec::new(), + bus_idx_mgr, + } } - pub fn program_chip_mut(&mut self) -> &mut ProgramChip { - &mut self.base.program_chip + /// This should be called **exactly once** at the start of the declaration of a new extension. + pub fn start_new_extension(&mut self) { + self.ext_start.push(self.ext_airs.len()); } - pub fn connector_chip(&self) -> &VmConnectorChip { - &self.base.connector_chip + pub fn new_bus_idx(&mut self) -> BusIndex { + self.bus_idx_mgr.new_bus_idx() } - pub fn connector_chip_mut(&mut self) -> &mut VmConnectorChip { - &mut self.base.connector_chip + /// Looks through already-defined AIRs to see if there exists any of type `A` by downcasting. + /// Returns all chips of type `A` in the circuit. + /// + /// This should not be used to look for system AIRs. + pub fn find_air(&self) -> impl Iterator { + self.ext_airs + .iter() + .filter_map(|air| air.as_any().downcast_ref()) } - pub fn memory_controller(&self) -> &MemoryController { - &self.base.memory_controller + pub fn add_air + 'static>(&mut self, air: A) { + self.add_air_ref(Arc::new(air)); } - pub fn range_checker_chip(&self) -> &SharedVariableRangeCheckerChip { - &self.base.range_checker_chip + pub fn add_air_ref(&mut self, air: AirRef) { + self.ext_airs.push(air); } - pub fn public_values_chip(&self) -> Option<&PublicValuesChip> - where - E: AnyEnum, - { - let chip = self.inventory.executors().get(Self::PV_EXECUTOR_IDX)?; - chip.as_any_kind().downcast_ref() + pub fn range_checker(&self) -> &VariableRangeCheckerAir { + self.find_air() + .next() + .expect("system always has range checker AIR") } - pub fn poseidon2_chip(&self) -> Option<&Poseidon2PeripheryChip> - where - P: AnyEnum, - { - let chip = self - .inventory - .periphery - .get(Self::POSEIDON2_PERIPHERY_IDX)?; - chip.as_any_kind().downcast_ref() + /// The AIRs in the order they appear in the verifying key. + /// This is the system AIRs, followed by the other AIRs in the **reverse** of the order they + /// were added in the VM extension definitions. In particular, the AIRs that have dependencies + /// appear later. The system guarantees that the last AIR is the [VariableRangeCheckerAir]. + pub fn into_airs(self) -> impl Iterator> { + self.system + .into_airs() + .into_iter() + .chain(self.ext_airs.into_iter().rev()) } - pub fn poseidon2_chip_mut(&mut self) -> Option<&mut Poseidon2PeripheryChip> - where - P: AnyEnum, - { - let chip = self - .inventory - .periphery - .get_mut(Self::POSEIDON2_PERIPHERY_IDX)?; - chip.as_any_kind_mut().downcast_mut() + /// This is O(1). Returns the total number of AIRs and equals the length of [`Self::into_airs`]. + pub fn num_airs(&self) -> usize { + self.config.num_airs() + self.ext_airs.len() } - pub fn finalize_memory(&mut self) - where - P: AnyEnum, - { - if self.config.continuation_enabled { - let chip = self - .inventory - .periphery - .get_mut(Self::POSEIDON2_PERIPHERY_IDX) - .expect("Poseidon2 chip required for persistent memory"); - let hasher: &mut Poseidon2PeripheryChip = chip - .as_any_kind_mut() - .downcast_mut() - .expect("Poseidon2 chip required for persistent memory"); - self.base.memory_controller.finalize(Some(hasher)) - } else { - self.base - .memory_controller - .finalize(None::<&mut Poseidon2PeripheryChip>) - }; + /// Standalone function to generate proving key and verifying key for this circuit. + pub fn keygen>(self, engine: &E) -> MultiStarkProvingKey { + let mut builder = engine.keygen_builder(); + for air in self.into_airs() { + builder.add_air(air); + } + builder.generate_pk() } - pub(crate) fn set_program(&mut self, program: Program) { - self.base.program_chip.set_program(program); + /// Returns the maximum number of bits used to represent addresses in memory + pub fn pointer_max_bits(&self) -> usize { + self.config.memory_config.pointer_max_bits } +} - pub(crate) fn set_initial_memory(&mut self, memory: MemoryImage) { - self.base.memory_controller.set_initial_memory(memory); +impl BusIndexManager { + pub fn new() -> Self { + Self { bus_idx_max: 0 } } - /// Warning: this sets the stream in all chips which have a shared mutable reference to the - /// streams. - pub(crate) fn set_streams(&mut self, streams: Streams) { - *self.streams.lock().unwrap() = streams; + pub fn new_bus_idx(&mut self) -> BusIndex { + let idx = self.bus_idx_max; + self.bus_idx_max = self.bus_idx_max.checked_add(1).unwrap(); + idx } +} - /// This should **only** be called after segment execution has finished. - pub fn take_streams(&mut self) -> Streams { - std::mem::take(&mut self.streams.lock().unwrap()) +impl ChipInventory +where + SC: StarkGenericConfig, + PB: ProverBackend, +{ + pub fn new(airs: AirInventory) -> Self { + Self { + airs, + chips: Vec::new(), + cur_num_exts: 0, + executor_idx_to_insertion_idx: Vec::new(), + } } - // This is O(1). - pub fn num_airs(&self) -> usize { - 3 + self.memory_controller().num_airs() + self.inventory.num_airs() + pub fn config(&self) -> &SystemConfig { + &self.airs.config } - // we always need to special case it because we need to fix the air id. - fn public_values_chip_idx(&self) -> Option { - self.config - .has_public_values_chip() - .then_some(Self::PV_EXECUTOR_IDX) - } + pub fn start_new_extension(&mut self) -> Result<(), ChipInventoryError> { + if self.cur_num_exts >= self.airs.ext_start.len() { + return Err(ChipInventoryError::MissingCircuitExtension( + self.airs.ext_start.len(), + )); + } + if self.chips.len() != self.airs.ext_start[self.cur_num_exts] { + return Err(ChipInventoryError::MissingChip { + actual: self.chips.len(), + expected: self.airs.ext_start[self.cur_num_exts], + }); + } - // Avoids a downcast when you don't need the concrete type. - fn _public_values_chip(&self) -> Option<&E> { - self.config - .has_public_values_chip() - .then(|| &self.inventory.executors[Self::PV_EXECUTOR_IDX]) + self.cur_num_exts += 1; + Ok(()) } - // All inventory chips except public values chip, in reverse order they were added. - pub(crate) fn chips_excluding_pv_chip(&self) -> impl Iterator> { - let public_values_chip_idx = self.public_values_chip_idx(); - self.inventory - .insertion_order - .iter() - .rev() - .flat_map(move |chip_idx| match *chip_idx { - // Skip public values chip if it exists. - ChipId::Executor(id) => (Some(id) != public_values_chip_idx) - .then(|| Either::Executor(&self.inventory.executors[id])), - ChipId::Periphery(id) => Some(Either::Periphery(&self.inventory.periphery[id])), + /// Gets the next AIR from the pre-existing AIR inventory according to the index of the next + /// chip to be built. + pub fn next_air(&self) -> Result<&A, ChipInventoryError> { + let cur_idx = self.chips.len(); + self.airs + .ext_airs + .get(cur_idx) + .and_then(|air| air.as_any().downcast_ref()) + .ok_or_else(|| ChipInventoryError::AirNotFound { + name: type_name::().to_string(), }) } - /// Return air names of all chips in order. - pub(crate) fn air_names(&self) -> Vec - where - E: ChipUsageGetter, - P: ChipUsageGetter, - { - once(self.program_chip().air_name()) - .chain([self.connector_chip().air_name()]) - .chain(self._public_values_chip().map(|c| c.air_name())) - .chain(self.memory_controller().air_names()) - .chain(self.chips_excluding_pv_chip().map(|c| c.air_name())) - .chain([self.range_checker_chip().air_name()]) - .collect() - } - /// Return trace heights of all chips in order corresponding to `air_names`. - pub(crate) fn current_trace_heights(&self) -> Vec - where - E: ChipUsageGetter, - P: ChipUsageGetter, - { - once(self.program_chip().current_trace_height()) - .chain([self.connector_chip().current_trace_height()]) - .chain(self._public_values_chip().map(|c| c.current_trace_height())) - .chain(self.memory_controller().current_trace_heights()) - .chain( - self.chips_excluding_pv_chip() - .map(|c| c.current_trace_height()), - ) - .chain([self.range_checker_chip().current_trace_height()]) - .collect() - } - - /// Return trace heights of (SystemBase, Inventory). Usually this is for aggregation and not - /// useful for regular users. + /// Looks through built chips to see if there exists any of type `C` by downcasting. + /// Returns all chips of type `C` in the chipset. /// - /// **Warning**: the order of `get_trace_heights` is deterministic, but it is not the same as - /// the order of `air_names`. In other words, the order here does not match the order of AIR - /// IDs. - pub fn get_internal_trace_heights(&self) -> VmComplexTraceHeights - where - E: ChipUsageGetter, - P: ChipUsageGetter, - { - VmComplexTraceHeights::new( - self.base.get_system_trace_heights(), - self.inventory.get_trace_heights(), - ) + /// Note: the type `C` will usually be a smart pointer to a chip. + pub fn find_chip(&self) -> impl Iterator { + self.chips.iter().filter_map(|c| c.as_any().downcast_ref()) } - /// Return dummy trace heights of (SystemBase, Inventory). Usually this is for aggregation to - /// generate a dummy proof and not useful for regular users. - /// - /// **Warning**: the order of `get_dummy_trace_heights` is deterministic, but it is not the same - /// as the order of `air_names`. In other words, the order here does not match the order of - /// AIR IDs. - pub fn get_dummy_internal_trace_heights(&self) -> VmComplexTraceHeights - where - E: ChipUsageGetter, - P: ChipUsageGetter, - { - VmComplexTraceHeights::new( - self.base.get_dummy_system_trace_heights(), - self.inventory.get_dummy_trace_heights(), - ) + /// Adds a chip that is not associated with any executor, as defined by the + /// [VmExecutionExtension] trait. + pub fn add_periphery_chip + 'static>(&mut self, chip: C) { + self.chips.push(Box::new(chip)); } - /// Override the trace heights for chips in the inventory. Usually this is for aggregation to - /// generate a dummy proof and not useful for regular users. - pub(crate) fn set_override_inventory_trace_heights( - &mut self, - overridden_inventory_heights: VmInventoryTraceHeights, - ) { - self.overridden_inventory_heights = Some(overridden_inventory_heights); + /// Adds a chip and associates it to the next executor. + /// **Caution:** you must add chips in the order matching the order that executors were added in + /// the [VmExecutionExtension] implementation. + pub fn add_executor_chip + 'static>(&mut self, chip: C) { + tracing::debug!("add_executor_chip: {}", type_name::()); + self.executor_idx_to_insertion_idx.push(self.chips.len()); + self.chips.push(Box::new(chip)); } - pub(crate) fn set_override_system_trace_heights( - &mut self, - overridden_system_heights: SystemTraceHeights, - ) { - let memory_controller = &mut self.base.memory_controller; - memory_controller.set_override_trace_heights(overridden_system_heights.memory); + /// Returns the mapping from executor index to the AIR index, where AIR index is the index of + /// the AIR within the verifying key. + /// + /// This should only be called after the `ChipInventory` is fully built. + pub fn executor_idx_to_air_idx(&self) -> Vec { + let num_airs = self.airs.num_airs(); + assert_eq!( + num_airs, + self.config().num_airs() + self.chips.len(), + "Number of chips does not match number of AIRs" + ); + // system AIRs are at the front of vkey, and then insertion index is the reverse ordering of + // AIR index + self.executor_idx_to_insertion_idx + .iter() + .map(|insertion_idx| { + num_airs + .checked_sub(insertion_idx.checked_add(1).unwrap()) + .unwrap_or_else(|| { + panic!( + "Attempt to subtract num_airs={num_airs} by {}", + insertion_idx + 1 + ) + }) + }) + .collect() } - /// Return dynamic trace heights of all chips in order, or 0 if - /// chip has constant height. - // Used for continuation segmentation logic, so this is performance-sensitive. - // Return iterator so we can break early. - pub(crate) fn dynamic_trace_heights(&self) -> impl Iterator + '_ - where - E: ChipUsageGetter, - P: ChipUsageGetter, - { - // program_chip, connector_chip - [0, 0] - .into_iter() - .chain(self._public_values_chip().map(|c| c.current_trace_height())) - .chain(self.memory_controller().current_trace_heights()) - .chain(self.chips_excluding_pv_chip().map(|c| match c { - // executor should never be constant height - Either::Executor(c) => c.current_trace_height(), - Either::Periphery(c) => { - if c.constant_trace_height().is_some() { - 0 - } else { - c.current_trace_height() - } - } - })) - .chain([0]) // range_checker_chip + pub fn timestamp_max_bits(&self) -> usize { + self.airs.config().memory_config.timestamp_max_bits } +} - /// Return trace cells of all chips in order. - /// This returns 0 cells for chips with preprocessed trace because the number of trace cells is - /// constant in those cases. This function is used to sample periodically and provided to - /// the segmentation strategy to decide whether to segment during execution. - pub(crate) fn current_trace_cells(&self) -> Vec - where - E: ChipUsageGetter, - P: ChipUsageGetter, - { - // program_chip, connector_chip - [0, 0] - .into_iter() - .chain(self._public_values_chip().map(|c| c.current_trace_cells())) - .chain(self.memory_controller().current_trace_cells()) - .chain(self.chips_excluding_pv_chip().map(|c| match c { - Either::Executor(c) => c.current_trace_cells(), - Either::Periphery(c) => { - if c.constant_trace_height().is_some() { - 0 - } else { - c.current_trace_cells() - } - } - })) - .chain([0]) // range_checker_chip - .collect() +// SharedVariableRangeCheckerChip is only used by the CPU backend. +impl ChipInventory> +where + SC: StarkGenericConfig, +{ + pub fn range_checker(&self) -> Result<&SharedVariableRangeCheckerChip, ChipInventoryError> { + self.find_chip::() + .next() + .ok_or_else(|| ChipInventoryError::ChipNotFound { + name: "VariableRangeCheckerChip".to_string(), + }) } +} - pub fn airs(&self) -> Vec> - where - Domain: PolynomialSpace, - E: Chip, - P: Chip, - { - // ATTENTION: The order of AIR MUST be consistent with `generate_proof_input`. - let program_rap = Arc::new(self.program_chip().air) as AirRef; - let connector_rap = Arc::new(self.connector_chip().air) as AirRef; - [program_rap, connector_rap] - .into_iter() - .chain(self._public_values_chip().map(|chip| chip.air())) - .chain(self.memory_controller().airs()) - .chain(self.chips_excluding_pv_chip().map(|chip| match chip { - Either::Executor(chip) => chip.air(), - Either::Periphery(chip) => chip.air(), - })) - .chain(once(self.range_checker_chip().air())) - .collect() - } +// ================================== Error Types ===================================== - pub(crate) fn generate_proof_input( - mut self, - cached_program: Option>, - trace_height_constraints: &[LinearConstraint], - #[cfg(feature = "bench-metrics")] metrics: &mut VmMetrics, - ) -> Result, GenerationError> - where - Domain: PolynomialSpace, - E: Chip, - P: AnyEnum + Chip, - { - // System: Finalize memory. - self.finalize_memory(); +#[derive(thiserror::Error, Debug)] +pub enum ExecutorInventoryError { + #[error("Opcode {opcode} already owned by executor id {id}")] + ExecutorExists { opcode: VmOpcode, id: ExecutorId }, + #[error("Phantom discriminant {} already has sub-executor", .discriminant.0)] + PhantomSubExecutorExists { discriminant: PhantomDiscriminant }, +} - let trace_heights = self - .current_trace_heights() - .iter() - .map(|h| next_power_of_two_or_zero(*h)) - .collect_vec(); - if let Some(index) = trace_heights - .iter() - .position(|h| *h > self.max_trace_height) - { - tracing::info!( - "trace height of air {index} has height {} greater than maximum {}", - trace_heights[index], - self.max_trace_height - ); - return Err(GenerationError::TraceHeightsLimitExceeded); - } - if trace_height_constraints.is_empty() { - tracing::warn!("generating proof input without trace height constraints"); - } - for (i, constraint) in trace_height_constraints.iter().enumerate() { - let value = zip_eq(&constraint.coefficients, &trace_heights) - .map(|(&c, &h)| c as u64 * h as u64) - .sum::(); - - if value >= constraint.threshold as u64 { - tracing::info!( - "trace heights {:?} violate linear constraint {} ({} >= {})", - trace_heights, - i, - value, - constraint.threshold - ); - return Err(GenerationError::TraceHeightsLimitExceeded); - } - } +#[derive(thiserror::Error, Debug)] +pub enum AirInventoryError { + #[error("AIR {name} not found")] + AirNotFound { name: String }, +} - #[cfg(feature = "bench-metrics")] - self.finalize_metrics(metrics); - - let has_pv_chip = self.public_values_chip_idx().is_some(); - // ATTENTION: The order of AIR proof input generation MUST be consistent with `airs`. - let mut builder = VmProofInputBuilder::new(); - let SystemBase { - range_checker_chip, - memory_controller, - connector_chip, - program_chip, - .. - } = self.base; - - // System: Program Chip - debug_assert_eq!(builder.curr_air_id, PROGRAM_AIR_ID); - builder.add_air_proof_input(program_chip.generate_air_proof_input(cached_program)); - // System: Connector Chip - debug_assert_eq!(builder.curr_air_id, CONNECTOR_AIR_ID); - builder.add_air_proof_input(connector_chip.generate_air_proof_input()); - - // Go through all chips in inventory in reverse order they were added (to resolve - // dependencies) Important Note: for air_id ordering reasons, we want to - // generate_air_proof_input for public values and memory chips **last** but include - // them into the `builder` **first**. - let mut public_values_input = None; - let mut insertion_order = self.inventory.insertion_order; - insertion_order.reverse(); - let mut non_sys_inputs = Vec::with_capacity(insertion_order.len()); - for chip_id in insertion_order { - let mut height = None; - if let Some(overridden_heights) = self.overridden_inventory_heights.as_ref() { - height = overridden_heights.chips.get(&chip_id).copied(); - } - let air_proof_input = match chip_id { - ChipId::Executor(id) => { - let chip = self.inventory.executors.pop().unwrap(); - assert_eq!(id, self.inventory.executors.len()); - generate_air_proof_input(chip, height) - } - ChipId::Periphery(id) => { - let chip = self.inventory.periphery.pop().unwrap(); - assert_eq!(id, self.inventory.periphery.len()); - generate_air_proof_input(chip, height) - } - }; - if has_pv_chip && chip_id == ChipId::Executor(Self::PV_EXECUTOR_IDX) { - public_values_input = Some(air_proof_input); - } else { - non_sys_inputs.push(air_proof_input); - } - } +#[derive(thiserror::Error, Debug)] +pub enum ChipInventoryError { + #[error("Air {name} not found")] + AirNotFound { name: String }, + #[error("Chip {name} not found")] + ChipNotFound { name: String }, + #[error("Adding prover extension without execution extension. Number of execution extensions is {0}")] + MissingExecutionExtension(usize), + #[error( + "Adding prover extension without circuit extension. Number of circuit extensions is {0}" + )] + MissingCircuitExtension(usize), + #[error("Missing chip. Number of chips is {actual}, expected number is {expected}")] + MissingChip { actual: usize, expected: usize }, + #[error("Missing executor chip. Number of executors with associated chips is {actual}, expected number is {expected}")] + MissingExecutor { actual: usize, expected: usize }, +} - if let Some(input) = public_values_input { - debug_assert_eq!(builder.curr_air_id, PUBLIC_VALUES_AIR_ID); - builder.add_air_proof_input(input); - } - // System: Memory Controller - { - // memory - let air_proof_inputs = memory_controller.generate_air_proof_inputs(); - for air_proof_input in air_proof_inputs { - builder.add_air_proof_input(air_proof_input); - } - } - // Non-system chips - non_sys_inputs - .into_iter() - .for_each(|input| builder.add_air_proof_input(input)); - // System: Range Checker Chip - builder.add_air_proof_input(range_checker_chip.generate_air_proof_input()); +// ======================= VM Chip Complex Implementation ============================= - Ok(builder.build()) +impl VmChipComplex +where + SC: StarkGenericConfig, + RA: Arena, + PB: ProverBackend, + SCC: SystemChipComplex, +{ + pub fn system_config(&self) -> &SystemConfig { + self.inventory.config() } - #[cfg(feature = "bench-metrics")] - fn finalize_metrics(&self, metrics: &mut VmMetrics) - where - E: ChipUsageGetter, - P: ChipUsageGetter, - { - tracing::info!(metrics.cycle_count); - counter!("total_cycles").absolute(metrics.cycle_count as u64); - counter!("main_cells_used") - .absolute(self.current_trace_cells().into_iter().sum::() as u64); - - if self.config.profiling { - metrics.chip_heights = - itertools::izip!(self.air_names(), self.current_trace_heights()).collect(); - metrics.emit(); - } + /// `record_arenas` is expected to have length equal to the number of AIRs in the verifying key + /// and in the same order as the AIRs appearing in the verifying key, even though some chips may + /// not require a record arena. + pub(crate) fn generate_proving_ctx( + &mut self, + system_records: SystemRecords, + record_arenas: Vec, + // trace_height_constraints: &[LinearConstraint], + ) -> Result, GenerationError> { + // ATTENTION: The order of AIR proving context generation MUST be consistent with + // `AirInventory::into_airs`. + + // Execution has finished at this point. + // ASSUMPTION WHICH MUST HOLD: non-system chips do not have a dependency on the system chips + // during trace generation. Given this assumption, we can generate trace on the system chips + // first. + let num_sys_airs = self.system_config().num_airs(); + let num_airs = num_sys_airs + self.inventory.chips.len(); + if num_airs != record_arenas.len() { + return Err(GenerationError::UnexpectedNumArenas { + actual: record_arenas.len(), + expected: num_airs, + }); + } + let mut _record_arenas = record_arenas; + let record_arenas = _record_arenas.split_off(num_sys_airs); + let sys_record_arenas = _record_arenas; + + // First go through all system chips + // Then go through all other chips in inventory in **reverse** order they were added (to + // resolve dependencies) + // + // Perf[jpw]: currently we call tracegen on each chip **serially** (although tracegen per + // chip is parallelized). We could introduce more parallelism, while potentially increasing + // the peak memory usage, by keeping a dependency tree and generating traces at the same + // layer of the tree in parallel. + let ctx_without_empties: Vec<(usize, AirProvingContext<_>)> = iter::empty() + .chain(info_span!("system_trace_gen").in_scope(|| { + self.system + .generate_proving_ctx(system_records, sys_record_arenas) + })) + .chain( + zip(self.inventory.chips.iter().enumerate().rev(), record_arenas).map( + |((insertion_idx, chip), records)| { + // Only create a span if record is not empty: + let _span = (!records.is_empty()).then(|| { + let air_name = self.inventory.airs.ext_airs[insertion_idx].name(); + info_span!("single_trace_gen", air = air_name).entered() + }); + chip.generate_proving_ctx(records) + }, + ), + ) + .enumerate() + .filter(|(_air_id, ctx)| { + (!ctx.cached_mains.is_empty() || ctx.common_main.is_some()) + && ctx.main_trace_height() > 0 + }) + .collect(); + + Ok(ProvingContext { + per_air: ctx_without_empties, + }) } } -struct VmProofInputBuilder { - curr_air_id: usize, - proof_input_per_air: Vec<(usize, AirProofInput)>, -} +// ============ Blanket implementation of VM extension traits for Option =========== -impl VmProofInputBuilder { - fn new() -> Self { - Self { - curr_air_id: 0, - proof_input_per_air: vec![], - } - } - /// Adds air proof input if one of the main trace matrices is non-empty. - /// Always increments the internal `curr_air_id` regardless of whether a new air proof input was - /// added or not. - fn add_air_proof_input(&mut self, air_proof_input: AirProofInput) { - let h = if !air_proof_input.raw.cached_mains.is_empty() { - air_proof_input.raw.cached_mains[0].height() - } else { - air_proof_input - .raw - .common_main - .as_ref() - .map(|trace| trace.height()) - .unwrap() - }; - if h > 0 { - self.proof_input_per_air - .push((self.curr_air_id, air_proof_input)); - } - self.curr_air_id += 1; - } +impl> VmExecutionExtension for Option { + type Executor = EXT::Executor; - fn build(self) -> ProofInput { - ProofInput { - per_air: self.proof_input_per_air, + fn extend_execution( + &self, + inventory: &mut ExecutorInventoryBuilder, + ) -> Result<(), ExecutorInventoryError> { + if let Some(extension) = self { + extension.extend_execution(inventory) + } else { + Ok(()) } } } -/// Generates an AIR proof input of the chip with the given height, if any. -/// -/// Assumption: an all-0 row is a valid dummy row for `chip`. -pub fn generate_air_proof_input>( - chip: C, - height: Option, -) -> AirProofInput { - let mut proof_input = chip.generate_air_proof_input(); - if let Some(height) = height { - let height = height.next_power_of_two(); - let main = proof_input.raw.common_main.as_mut().unwrap(); - assert!( - height >= main.height(), - "Overridden height must be greater than or equal to the used height" - ); - main.pad_to_height(height, FieldAlgebra::ZERO); +impl> VmCircuitExtension for Option { + fn extend_circuit(&self, inventory: &mut AirInventory) -> Result<(), AirInventoryError> { + if let Some(extension) = self { + extension.extend_circuit(inventory) + } else { + Ok(()) + } } - proof_input } /// A helper trait for downcasting types that may be enums. @@ -1219,57 +763,13 @@ impl AnyEnum for () { } } -impl AnyEnum for SharedVariableRangeCheckerChip { - fn as_any_kind(&self) -> &dyn Any { - self - } - fn as_any_kind_mut(&mut self) -> &mut dyn Any { - self - } -} - -pub(crate) enum Either { - Executor(E), - Periphery(P), -} - -impl<'a, E, P> ChipUsageGetter for Either<&'a E, &'a P> -where - E: ChipUsageGetter, - P: ChipUsageGetter, -{ - fn air_name(&self) -> String { - match self { - Either::Executor(chip) => chip.air_name(), - Either::Periphery(chip) => chip.air_name(), - } - } - fn current_trace_height(&self) -> usize { - match self { - Either::Executor(chip) => chip.current_trace_height(), - Either::Periphery(chip) => chip.current_trace_height(), - } - } - fn trace_width(&self) -> usize { - match self { - Either::Executor(chip) => chip.trace_width(), - Either::Periphery(chip) => chip.trace_width(), - } - } - fn current_trace_cells(&self) -> usize { - match self { - Either::Executor(chip) => chip.current_trace_cells(), - Either::Periphery(chip) => chip.current_trace_cells(), - } - } -} - #[cfg(test)] mod tests { - use p3_baby_bear::BabyBear; + use openvm_circuit_derive::AnyEnum; + use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Config; use super::*; - use crate::system::memory::interface::MemoryInterface; + use crate::{arch::VmCircuitConfig, system::memory::interface::MemoryInterfaceAirs}; #[allow(dead_code)] #[derive(Copy, Clone)] @@ -1343,15 +843,17 @@ mod tests { #[test] fn test_system_bus_indices() { let config = SystemConfig::default().with_continuations(); - let complex = SystemComplex::::new(config); - assert_eq!(complex.base.execution_bus().index(), 0); - assert_eq!(complex.base.memory_bus().index(), 1); - assert_eq!(complex.base.program_bus().index(), 2); - assert_eq!(complex.base.range_checker_bus().index(), 3); - match &complex.memory_controller().interface_chip { - MemoryInterface::Persistent { boundary_chip, .. } => { - assert_eq!(boundary_chip.air.merkle_bus.index, 4); - assert_eq!(boundary_chip.air.compression_bus.index, 5); + let inventory: AirInventory = config.create_airs().unwrap(); + let system = inventory.system(); + let port = system.port(); + assert_eq!(port.execution_bus.index(), 0); + assert_eq!(port.memory_bridge.memory_bus().index(), 1); + assert_eq!(port.program_bus.index(), 2); + assert_eq!(port.memory_bridge.range_bus().index(), 3); + match &system.memory.interface { + MemoryInterfaceAirs::Persistent { boundary, .. } => { + assert_eq!(boundary.merkle_bus.index, 4); + assert_eq!(boundary.compression_bus.index, 5); } _ => unreachable!(), }; diff --git a/crates/vm/src/arch/hasher/mod.rs b/crates/vm/src/arch/hasher/mod.rs index df90a55e4b..e858da25f9 100644 --- a/crates/vm/src/arch/hasher/mod.rs +++ b/crates/vm/src/arch/hasher/mod.rs @@ -24,10 +24,10 @@ pub trait Hasher { leaves[0] } } -pub trait HasherChip: Hasher { +pub trait HasherChip: Hasher + Send + Sync { /// Stateful version of `hash` for recording the event in the chip. - fn compress_and_record(&mut self, left: &[F; CHUNK], right: &[F; CHUNK]) -> [F; CHUNK]; - fn hash_and_record(&mut self, values: &[F; CHUNK]) -> [F; CHUNK] { + fn compress_and_record(&self, left: &[F; CHUNK], right: &[F; CHUNK]) -> [F; CHUNK]; + fn hash_and_record(&self, values: &[F; CHUNK]) -> [F; CHUNK] { self.compress_and_record(values, &[F::ZERO; CHUNK]) } } diff --git a/crates/vm/src/arch/integration_api.rs b/crates/vm/src/arch/integration_api.rs index b1116d8c48..1105cb40a8 100644 --- a/crates/vm/src/arch/integration_api.rs +++ b/crates/vm/src/arch/integration_api.rs @@ -1,28 +1,23 @@ -use std::{ - array::from_fn, - borrow::Borrow, - marker::PhantomData, - sync::{Arc, Mutex}, -}; +use std::{array::from_fn, borrow::Borrow, marker::PhantomData, sync::Arc}; -use openvm_circuit_primitives::utils::next_power_of_two_or_zero; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{instruction::Instruction, LocalOpcode}; use openvm_stark_backend::{ - air_builders::{debug::DebugConstraintBuilder, symbolic::SymbolicRapBuilder}, config::{StarkGenericConfig, Val}, p3_air::{Air, AirBuilder, BaseAir}, - p3_field::{FieldAlgebra, PrimeField32}, + p3_field::FieldAlgebra, p3_matrix::{dense::RowMajorMatrix, Matrix}, p3_maybe_rayon::prelude::*, - prover::types::AirProofInput, - rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, - AirRef, Chip, ChipUsageGetter, + prover::{cpu::CpuBackend, types::AirProvingContext}, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, + Chip, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde::{Deserialize, Serialize}; -use super::{ExecutionState, InstructionExecutor, Result}; -use crate::system::memory::{MemoryController, OfflineMemory}; +use crate::{ + arch::RowMajorMatrixArena, + system::memory::{online::TracingMemory, MemoryAuxColsFactory, SharedMemoryHelper}, +}; /// The interface between primitive AIR and machine adapter AIR. pub trait VmAdapterInterface { @@ -37,60 +32,6 @@ pub trait VmAdapterInterface { type ProcessedInstruction; } -/// The adapter owns all memory accesses and timestamp changes. -/// The adapter AIR should also own `ExecutionBridge` and `MemoryBridge`. -pub trait VmAdapterChip { - /// Records generated by adapter before main instruction execution - type ReadRecord: Send + Serialize + DeserializeOwned; - /// Records generated by adapter after main instruction execution - type WriteRecord: Send + Serialize + DeserializeOwned; - /// AdapterAir should not have public values - type Air: BaseAir + Clone; - - type Interface: VmAdapterInterface; - - /// Given instruction, perform memory reads and return only the read data that the integrator - /// needs to use. This is called at the start of instruction execution. - /// - /// The implementer may choose to store data in the `Self::ReadRecord` struct, for example in - /// an [Option], which will later be sent to the `postprocess` method. - #[allow(clippy::type_complexity)] - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )>; - - /// Given instruction and the data to write, perform memory writes and return the `(record, - /// next_timestamp)` of the full adapter record for this instruction. This is guaranteed to - /// be called after `preprocess`. - fn postprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)>; - - /// Populates `row_slice` with values corresponding to `record`. - /// The provided `row_slice` will have length equal to `self.air().width()`. - /// This function will be called for each row in the trace which is being used, and all other - /// rows in the trace will be filled with zeroes. - fn generate_trace_row( - &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, - ); - - fn air(&self) -> &Self::Air; -} - pub trait VmAdapterAir: BaseAir { type Interface: VmAdapterInterface; @@ -111,47 +52,6 @@ pub trait VmAdapterAir: BaseAir { fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var; } -/// Trait to be implemented on primitive chip to integrate with the machine. -pub trait VmCoreChip> { - /// Minimum data that must be recorded to be able to generate trace for one row of - /// `PrimitiveAir`. - type Record: Send + Serialize + DeserializeOwned; - /// The primitive AIR with main constraints that do not depend on memory and other - /// architecture-specifics. - type Air: BaseAirWithPublicValues + Clone; - - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, - instruction: &Instruction, - from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)>; - - fn get_opcode_name(&self, opcode: usize) -> String; - - /// Populates `row_slice` with values corresponding to `record`. - /// The provided `row_slice` will have length equal to `self.air().width()`. - /// This function will be called for each row in the trace which is being used, and all other - /// rows in the trace will be filled with zeroes. - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record); - - /// Returns a list of public values to publish. - fn generate_public_values(&self) -> Vec { - vec![] - } - - fn air(&self) -> &Self::Air; - - /// Finalize the trace, especially the padded rows if the all-zero rows don't satisfy the - /// constraints. This is done **after** records are consumed and the trace matrix is - /// generated. Most implementations should just leave the default implementation if padding - /// with rows of all 0s satisfies the constraints. - fn finalize(&self, _trace: &mut RowMajorMatrix, _num_records: usize) { - // do nothing by default - } -} - pub trait VmCoreAir: BaseAirWithPublicValues where AB: AirBuilder, @@ -183,22 +83,6 @@ where } } -pub struct AdapterRuntimeContext> { - /// Leave as `None` to allow the adapter to decide the `to_pc` automatically. - pub to_pc: Option, - pub writes: I::Writes, -} - -impl> AdapterRuntimeContext { - /// Leave `to_pc` as `None` to allow the adapter to decide the `to_pc` automatically. - pub fn without_pc(writes: impl Into) -> Self { - Self { - to_pc: None, - writes: writes.into(), - } - } -} - pub struct AdapterAirContext> { /// Leave as `None` to allow the adapter to decide the `to_pc` automatically. pub to_pc: Option, @@ -207,140 +91,125 @@ pub struct AdapterAirContext> { pub instruction: I::ProcessedInstruction, } -pub struct VmChipWrapper, C: VmCoreChip> { - pub adapter: A, - pub core: C, - pub records: Vec<(A::ReadRecord, A::WriteRecord, C::Record)>, - offline_memory: Arc>>, -} - -const DEFAULT_RECORDS_CAPACITY: usize = 1 << 20; +/// Helper trait for CPU tracegen. +pub trait TraceFiller: Send + Sync { + /// Populates `trace`. This function will always be called after + /// [`TraceExecutor::execute`], so the `trace` should already contain the records necessary to + /// fill in the rest of it. + fn fill_trace( + &self, + mem_helper: &MemoryAuxColsFactory, + trace: &mut RowMajorMatrix, + rows_used: usize, + ) where + F: Send + Sync + Clone, + { + let width = trace.width(); + trace.values[..rows_used * width] + .par_chunks_exact_mut(width) + .for_each(|row_slice| { + self.fill_trace_row(mem_helper, row_slice); + }); + trace.values[rows_used * width..] + .par_chunks_exact_mut(width) + .for_each(|row_slice| { + self.fill_dummy_trace_row(row_slice); + }); + } -impl VmChipWrapper -where - A: VmAdapterChip, - C: VmCoreChip, -{ - pub fn new(adapter: A, core: C, offline_memory: Arc>>) -> Self { - Self { - adapter, - core, - records: Vec::with_capacity(DEFAULT_RECORDS_CAPACITY), - offline_memory, - } + /// Populates `row_slice`. This function will always be called after + /// [`TraceExecutor::execute`], so the `row_slice` should already contain context necessary to + /// fill in the rest of the row. This function will be called for each row in the trace which + /// is being used, and for all other rows in the trace see `fill_dummy_trace_row`. + /// + /// The provided `row_slice` will have length equal to the width of the AIR. + fn fill_trace_row(&self, _mem_helper: &MemoryAuxColsFactory, _row_slice: &mut [F]) { + unreachable!("fill_trace_row is not implemented") } -} -impl InstructionExecutor for VmChipWrapper -where - F: PrimeField32, - A: VmAdapterChip + Send + Sync, - M: VmCoreChip + Send + Sync, -{ - fn execute( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - ) -> Result> { - let (reads, read_record) = self.adapter.preprocess(memory, instruction)?; - let (output, core_record) = - self.core - .execute_instruction(instruction, from_state.pc, reads)?; - let (to_state, write_record) = - self.adapter - .postprocess(memory, instruction, from_state, output, &read_record)?; - self.records.push((read_record, write_record, core_record)); - Ok(to_state) + /// Populates `row_slice`. This function will be called on dummy rows. + /// By default the trace is padded with empty (all 0) rows to make the height a power of 2. + /// + /// The provided `row_slice` will have length equal to the width of the AIR. + fn fill_dummy_trace_row(&self, _row_slice: &mut [F]) { + // By default, the row is filled with zeroes } - fn get_opcode_name(&self, opcode: usize) -> String { - self.core.get_opcode_name(opcode) + /// Returns a list of public values to publish. + fn generate_public_values(&self) -> Vec { + vec![] } } -// Note[jpw]: the statement we want is: -// - when A::Air is an AdapterAir for all AirBuilders needed by stark-backend -// - and when M::Air is an CoreAir for all AirBuilders needed by stark-backend, -// then VmAirWrapper is an Air for all AirBuilders needed -// by stark-backend, which is equivalent to saying it implements AirRef -// The where clauses to achieve this statement is unfortunately really verbose. -impl Chip for VmChipWrapper, A, C> +/// We want a blanket implementation of `Chip` on any struct that +/// implements [TraceFiller] but due to Rust orphan rules, we need a wrapper struct. +// @dev: You could make a macro, but it's hard to handle generics in the struct definition. +#[derive(derive_new::new)] +pub struct VmChipWrapper { + pub inner: FILLER, + pub mem_helper: SharedMemoryHelper, +} + +impl Chip> for VmChipWrapper, FILLER> where SC: StarkGenericConfig, - Val: PrimeField32, - A: VmAdapterChip> + Send + Sync, - C: VmCoreChip, A::Interface> + Send + Sync, - A::Air: Send + Sync + 'static, - A::Air: VmAdapterAir>>, - A::Air: for<'a> VmAdapterAir>, - C::Air: Send + Sync + 'static, - C::Air: VmCoreAir< - SymbolicRapBuilder>, - >>>::Interface, - >, - C::Air: for<'a> VmCoreAir< - DebugConstraintBuilder<'a, SC>, - >>::Interface, - >, + FILLER: TraceFiller>, + RA: RowMajorMatrixArena>, { - fn air(&self) -> AirRef { - let air: VmAirWrapper = VmAirWrapper { - adapter: self.adapter.air().clone(), - core: self.core.air().clone(), - }; - Arc::new(air) + fn generate_proving_ctx(&self, arena: RA) -> AirProvingContext> { + let rows_used = arena.trace_offset() / arena.width(); + let mut trace = arena.into_matrix(); + let mem_helper = self.mem_helper.as_borrowed(); + self.inner.fill_trace(&mem_helper, &mut trace, rows_used); + + AirProvingContext::simple(Arc::new(trace), self.inner.generate_public_values()) } +} - fn generate_air_proof_input(self) -> AirProofInput { - let num_records = self.records.len(); - let height = next_power_of_two_or_zero(num_records); - let core_width = self.core.air().width(); - let adapter_width = self.adapter.air().width(); - let width = core_width + adapter_width; - let mut values = Val::::zero_vec(height * width); - - let memory = self.offline_memory.lock().unwrap(); - - // This zip only goes through records. - // The padding rows between records.len()..height are filled with zeros. - values - .par_chunks_mut(width) - .zip(self.records.into_par_iter()) - .for_each(|(row_slice, record)| { - let (adapter_row, core_row) = row_slice.split_at_mut(adapter_width); - self.adapter - .generate_trace_row(adapter_row, record.0, record.1, &memory); - self.core.generate_trace_row(core_row, record.2); - }); +/// A helper trait for expressing generic state accesses within the implementation of +/// [TraceExecutor]. Note that this is only a helper trait when the same interface of state access +/// is reused or shared by multiple implementations. It is not required to implement this trait if +/// it is easier to implement the [TraceExecutor] trait directly without this trait. +pub trait AdapterTraceExecutor: Clone { + const WIDTH: usize; + type ReadData; + type WriteData; + // @dev This can either be a &mut _ type or a struct with &mut _ fields. + // The latter is helpful if we want to directly write certain values in place into a trace + // matrix. + type RecordMut<'a> + where + Self: 'a; - let mut trace = RowMajorMatrix::new(values, width); - self.core.finalize(&mut trace, num_records); + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>); - AirProofInput::simple(trace, self.core.generate_public_values()) - } + fn read( + &self, + memory: &mut TracingMemory, + instruction: &Instruction, + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData; + + fn write( + &self, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, + ); } -impl ChipUsageGetter for VmChipWrapper -where - A: VmAdapterChip + Sync, - M: VmCoreChip + Sync, -{ - fn air_name(&self) -> String { - format!( - "<{},{}>", - get_air_name(self.adapter.air()), - get_air_name(self.core.air()) - ) - } - fn current_trace_height(&self) -> usize { - self.records.len() - } - fn trace_width(&self) -> usize { - self.adapter.air().width() + self.core.air().width() - } +// NOTE[jpw]: cannot reuse `TraceSubRowGenerator` trait because we need associated constant +// `WIDTH`. +pub trait AdapterTraceFiller: Send + Sync { + const WIDTH: usize; + /// Post-execution filling of rest of adapter row. + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, adapter_row: &mut [F]); } +// ============================== Adapter|Core Air Wrapper =============================== + +#[derive(Clone, Copy, derive_new::new)] pub struct VmAirWrapper { pub adapter: A, pub core: C, @@ -455,40 +324,6 @@ impl< type ProcessedInstruction = MinimalInstruction; } -pub struct VecHeapTwoReadsAdapterInterface< - T, - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, ->(PhantomData); - -impl< - T, - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > VmAdapterInterface - for VecHeapTwoReadsAdapterInterface< - T, - BLOCKS_PER_READ1, - BLOCKS_PER_READ2, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - > -{ - type Reads = ( - [[T; READ_SIZE]; BLOCKS_PER_READ1], - [[T; READ_SIZE]; BLOCKS_PER_READ2], - ); - type Writes = [[T; WRITE_SIZE]; BLOCKS_PER_WRITE]; - type ProcessedInstruction = MinimalInstruction; -} - /// Similar to `BasicAdapterInterface`, but it flattens the reads and writes into a single flat /// array for each pub struct FlatInterface( @@ -608,49 +443,6 @@ mod conversions { } } - // AdapterRuntimeContext: VecHeapAdapterInterface -> DynInterface - impl< - T, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > - From< - AdapterRuntimeContext< - T, - VecHeapAdapterInterface< - T, - NUM_READS, - BLOCKS_PER_READ, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >, - >, - > for AdapterRuntimeContext> - { - fn from( - ctx: AdapterRuntimeContext< - T, - VecHeapAdapterInterface< - T, - NUM_READS, - BLOCKS_PER_READ, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >, - >, - ) -> Self { - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes: ctx.writes.into(), - } - } - } - // AdapterAirContext: DynInterface -> VecHeapAdapterInterface impl< T, @@ -682,155 +474,6 @@ mod conversions { } } - // AdapterRuntimeContext: DynInterface -> VecHeapAdapterInterface - impl< - T, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > From>> - for AdapterRuntimeContext< - T, - VecHeapAdapterInterface< - T, - NUM_READS, - BLOCKS_PER_READ, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >, - > - { - fn from(ctx: AdapterRuntimeContext>) -> Self { - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes: ctx.writes.into(), - } - } - } - - // AdapterAirContext: DynInterface -> VecHeapTwoReadsAdapterInterface - impl< - T: Clone, - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > From>> - for AdapterAirContext< - T, - VecHeapTwoReadsAdapterInterface< - T, - BLOCKS_PER_READ1, - BLOCKS_PER_READ2, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >, - > - { - fn from(ctx: AdapterAirContext>) -> Self { - AdapterAirContext { - to_pc: ctx.to_pc, - reads: ctx.reads.into(), - writes: ctx.writes.into(), - instruction: ctx.instruction.into(), - } - } - } - - // AdapterRuntimeContext: DynInterface -> VecHeapAdapterInterface - impl< - T, - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > From>> - for AdapterRuntimeContext< - T, - VecHeapTwoReadsAdapterInterface< - T, - BLOCKS_PER_READ1, - BLOCKS_PER_READ2, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >, - > - { - fn from(ctx: AdapterRuntimeContext>) -> Self { - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes: ctx.writes.into(), - } - } - } - - // AdapterRuntimeContext: BasicInterface -> VecHeapAdapterInterface - impl< - T, - PI, - const BASIC_NUM_READS: usize, - const BASIC_NUM_WRITES: usize, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > - From< - AdapterRuntimeContext< - T, - BasicAdapterInterface< - T, - PI, - BASIC_NUM_READS, - BASIC_NUM_WRITES, - READ_SIZE, - WRITE_SIZE, - >, - >, - > - for AdapterRuntimeContext< - T, - VecHeapAdapterInterface< - T, - NUM_READS, - BLOCKS_PER_READ, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >, - > - { - fn from( - ctx: AdapterRuntimeContext< - T, - BasicAdapterInterface< - T, - PI, - BASIC_NUM_READS, - BASIC_NUM_WRITES, - READ_SIZE, - WRITE_SIZE, - >, - >, - ) -> Self { - assert_eq!(BASIC_NUM_WRITES, BLOCKS_PER_WRITE); - let mut writes_it = ctx.writes.into_iter(); - let writes = from_fn(|_| writes_it.next().unwrap()); - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes, - } - } - } - // AdapterAirContext: BasicInterface -> VecHeapAdapterInterface impl< T, @@ -985,79 +628,6 @@ mod conversions { } } - // AdapterRuntimeContext: BasicInterface -> FlatInterface - impl< - T, - PI, - const NUM_READS: usize, - const NUM_WRITES: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - const READ_CELLS: usize, - const WRITE_CELLS: usize, - > - From< - AdapterRuntimeContext< - T, - BasicAdapterInterface, - >, - > for AdapterRuntimeContext> - { - /// ## Panics - /// If `WRITE_CELLS != NUM_WRITES * WRITE_SIZE`. - /// This is a runtime assertion until Rust const generics expressions are stabilized. - fn from( - ctx: AdapterRuntimeContext< - T, - BasicAdapterInterface, - >, - ) -> AdapterRuntimeContext> { - assert_eq!(WRITE_CELLS, NUM_WRITES * WRITE_SIZE); - let mut writes_it = ctx.writes.into_iter().flatten(); - let writes = from_fn(|_| writes_it.next().unwrap()); - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes, - } - } - } - - // AdapterRuntimeContext: FlatInterface -> BasicInterface - impl< - T: FieldAlgebra, - PI, - const NUM_READS: usize, - const NUM_WRITES: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - const READ_CELLS: usize, - const WRITE_CELLS: usize, - > From>> - for AdapterRuntimeContext< - T, - BasicAdapterInterface, - > - { - /// ## Panics - /// If `WRITE_CELLS != NUM_WRITES * WRITE_SIZE`. - /// This is a runtime assertion until Rust const generics expressions are stabilized. - fn from( - ctx: AdapterRuntimeContext>, - ) -> AdapterRuntimeContext< - T, - BasicAdapterInterface, - > { - assert_eq!(WRITE_CELLS, NUM_WRITES * WRITE_SIZE); - let mut writes_it = ctx.writes.into_iter(); - let writes: [[T; WRITE_SIZE]; NUM_WRITES] = - from_fn(|_| from_fn(|_| writes_it.next().unwrap())); - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes, - } - } - } - impl From> for DynArray { fn from(v: Vec) -> Self { Self(v) @@ -1169,35 +739,6 @@ mod conversions { } } - // AdapterRuntimeContext: BasicInterface -> DynInterface - impl< - T, - PI, - const NUM_READS: usize, - const NUM_WRITES: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > - From< - AdapterRuntimeContext< - T, - BasicAdapterInterface, - >, - > for AdapterRuntimeContext> - { - fn from( - ctx: AdapterRuntimeContext< - T, - BasicAdapterInterface, - >, - ) -> Self { - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes: ctx.writes.into(), - } - } - } - // AdapterAirContext: DynInterface -> BasicInterface impl< T, @@ -1224,28 +765,6 @@ mod conversions { } } - // AdapterRuntimeContext: DynInterface -> BasicInterface - impl< - T, - PI, - const NUM_READS: usize, - const NUM_WRITES: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > From>> - for AdapterRuntimeContext< - T, - BasicAdapterInterface, - > - { - fn from(ctx: AdapterRuntimeContext>) -> Self { - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes: ctx.writes.into(), - } - } - } - // AdapterAirContext: FlatInterface -> DynInterface impl>, const READ_CELLS: usize, const WRITE_CELLS: usize> From>> @@ -1261,21 +780,6 @@ mod conversions { } } - // AdapterRuntimeContext: FlatInterface -> DynInterface - impl - From>> - for AdapterRuntimeContext> - { - fn from( - ctx: AdapterRuntimeContext>, - ) -> Self { - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes: ctx.writes.to_vec().into(), - } - } - } - impl From> for DynArray { fn from(m: MinimalInstruction) -> Self { Self(vec![m.is_valid, m.opcode]) diff --git a/crates/vm/src/arch/interpreter.rs b/crates/vm/src/arch/interpreter.rs new file mode 100644 index 0000000000..753b11ff2a --- /dev/null +++ b/crates/vm/src/arch/interpreter.rs @@ -0,0 +1,592 @@ +use std::{ + alloc::{alloc, dealloc, handle_alloc_error, Layout}, + borrow::{Borrow, BorrowMut}, + ptr::NonNull, +}; + +use itertools::Itertools; +use openvm_circuit_primitives_derive::AlignedBytesBorrow; +use openvm_instructions::{ + exe::{SparseMemoryImage, VmExe}, + instruction::Instruction, + program::{Program, DEFAULT_PC_STEP}, + LocalOpcode, SystemOpcode, +}; +use openvm_stark_backend::p3_field::PrimeField32; +use tracing::info_span; + +use crate::{ + arch::{ + execution_mode::{ + e1::E1Ctx, + metered::{MeteredCtx, Segment}, + E1ExecutionCtx, E2ExecutionCtx, + }, + ExecuteFunc, ExecutionError, Executor, ExecutorInventory, ExitCode, MeteredExecutor, + StaticProgramError, Streams, SystemConfig, VmExecState, VmState, + }, + system::memory::online::GuestMemory, +}; + +/// VM pure executor(E1/E2 executor) which doesn't consider trace generation. +/// Note: This executor doesn't hold any VM state and can be used for multiple execution. +/// +/// The generic `Ctx` and constructor determine whether this supported pure execution or metered +/// execution. +// @dev: the lifetime 'a represents the lifetime of borrowed ExecutorInventory, which must outlive +// the InterpretedInstance because `pre_compute_buf` may contain pointers to references held by +// executors. +pub struct InterpretedInstance<'a, F, Ctx> { + system_config: SystemConfig, + // SAFETY: this is not actually dead code, but `pre_compute_insns` contains raw pointer refers + // to this buffer. + #[allow(dead_code)] + pre_compute_buf: AlignedBuf, + /// Instruction table of function pointers and pointers to the pre-computed buffer. Indexed by + /// `pc_index = (pc - pc_base) / DEFAULT_PC_STEP`. + pre_compute_insns: Vec>, + + pc_base: u32, + pc_start: u32, + + init_memory: SparseMemoryImage, +} + +struct PreComputeInstruction<'a, F, Ctx> { + pub handler: ExecuteFunc, + pub pre_compute: &'a [u8], +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct TerminatePreCompute { + exit_code: u32, +} + +macro_rules! execute_with_metrics { + ($span:literal, $pc_base:expr, $exec_state:expr, $pre_compute_insts:expr) => {{ + #[cfg(feature = "metrics")] + let start = std::time::Instant::now(); + #[cfg(feature = "metrics")] + let start_instret = $exec_state.instret; + + info_span!($span).in_scope(|| unsafe { + execute_trampoline($pc_base, $exec_state, $pre_compute_insts); + }); + + #[cfg(feature = "metrics")] + { + let elapsed = start.elapsed(); + let insns = $exec_state.instret - start_instret; + metrics::counter!("insns").absolute(insns); + metrics::gauge!(concat!($span, "_insn_mi/s")) + .set(insns as f64 / elapsed.as_micros() as f64); + } + }}; +} + +// Constructors for E1 and E2 respectively, which generate pre-computed buffers and function +// pointers +// - Generic in `Ctx` + +impl<'a, F, Ctx> InterpretedInstance<'a, F, Ctx> +where + F: PrimeField32, + Ctx: E1ExecutionCtx, +{ + /// Creates a new interpreter instance for pure execution. + // (E1 execution) + pub fn new( + inventory: &'a ExecutorInventory, + exe: &VmExe, + ) -> Result + where + E: Executor, + { + let program = &exe.program; + let pre_compute_max_size = get_pre_compute_max_size(program, inventory); + let mut pre_compute_buf = alloc_pre_compute_buf(program.len(), pre_compute_max_size); + let mut split_pre_compute_buf = + split_pre_compute_buf(program, &mut pre_compute_buf, pre_compute_max_size); + let pre_compute_insns = get_pre_compute_instructions::( + program, + inventory, + &mut split_pre_compute_buf, + )?; + let pc_base = program.pc_base; + let pc_start = exe.pc_start; + let init_memory = exe.init_memory.clone(); + + Ok(Self { + system_config: inventory.config().clone(), + pre_compute_buf, + pre_compute_insns, + pc_base, + pc_start, + init_memory, + }) + } +} + +impl<'a, F, Ctx> InterpretedInstance<'a, F, Ctx> +where + F: PrimeField32, + Ctx: E2ExecutionCtx, +{ + /// Creates a new interpreter instance for pure execution. + // (E1 execution) + pub fn new_metered( + inventory: &'a ExecutorInventory, + exe: &VmExe, + executor_idx_to_air_idx: &[usize], + ) -> Result + where + E: MeteredExecutor, + { + let program = &exe.program; + let pre_compute_max_size = get_metered_pre_compute_max_size(program, inventory); + let mut pre_compute_buf = alloc_pre_compute_buf(program.len(), pre_compute_max_size); + let mut split_pre_compute_buf = + split_pre_compute_buf(program, &mut pre_compute_buf, pre_compute_max_size); + let pre_compute_insns = get_metered_pre_compute_instructions::( + program, + inventory, + executor_idx_to_air_idx, + &mut split_pre_compute_buf, + )?; + + let pc_base = program.pc_base; + let pc_start = exe.pc_start; + let init_memory = exe.init_memory.clone(); + + Ok(Self { + system_config: inventory.config().clone(), + pre_compute_buf, + pre_compute_insns, + pc_base, + pc_start, + init_memory, + }) + } +} + +// Execute functions specialize to relevant Ctx types to provide more streamlines APIs + +impl InterpretedInstance<'_, F, E1Ctx> +where + F: PrimeField32, +{ + /// Pure execution, without metering, for the given `inputs`. Execution begins from the initial + /// state specified by the `VmExe`. This function executes the program until either termination + /// if `num_insns` is `None` or for exactly `num_insns` instructions if `num_insns` is `Some`. + /// + /// Returns the final VM state when execution stops. + pub fn execute( + &self, + inputs: impl Into>, + num_insns: Option, + ) -> Result, ExecutionError> { + let vm_state = VmState::initial( + &self.system_config.memory_config, + self.init_memory.clone(), + self.pc_start, + inputs, + ); + self.execute_from_state(vm_state, num_insns) + } + + /// Pure execution, without metering, from the given `VmState`. This function executes the + /// program until either termination if `num_insns` is `None` or for exactly `num_insns` + /// instructions if `num_insns` is `Some`. + /// + /// Returns the final VM state when execution stops. + pub fn execute_from_state( + &self, + from_state: VmState, + num_insns: Option, + ) -> Result, ExecutionError> { + let ctx = E1Ctx::new(num_insns); + let mut exec_state = VmExecState::new(from_state, ctx); + // Start execution + execute_with_metrics!( + "execute_e1", + self.pc_base, + &mut exec_state, + &self.pre_compute_insns + ); + if num_insns.is_some() { + check_exit_code(exec_state.exit_code)?; + } else { + check_termination(exec_state.exit_code)?; + } + Ok(exec_state.vm_state) + } +} + +impl InterpretedInstance<'_, F, MeteredCtx> +where + F: PrimeField32, +{ + /// Metered execution for the given `inputs`. Execution begins from the initial + /// state specified by the `VmExe`. This function executes the program until termination. + /// + /// Returns the segmentation boundary data and the final VM state when execution stops. + pub fn execute_metered( + &self, + inputs: impl Into>, + ctx: MeteredCtx, + ) -> Result<(Vec, VmState), ExecutionError> { + let vm_state = VmState::initial( + &self.system_config.memory_config, + self.init_memory.clone(), + self.pc_start, + inputs, + ); + self.execute_metered_from_state(vm_state, ctx) + } + + /// Metered execution for the given `VmState`. This function executes the program until + /// termination. + /// + /// Returns the segmentation boundary data and the final VM state when execution stops. + /// + /// The [MeteredCtx] can be constructed using either + /// [VmExecutor::build_metered_ctx](super::VmExecutor::build_metered_ctx) or + /// [VirtualMachine::build_metered_ctx](super::VirtualMachine::build_metered_ctx). + pub fn execute_metered_from_state( + &self, + from_state: VmState, + ctx: MeteredCtx, + ) -> Result<(Vec, VmState), ExecutionError> { + let mut exec_state = VmExecState::new(from_state, ctx); + // Start execution + execute_with_metrics!( + "execute_metered", + self.pc_base, + &mut exec_state, + &self.pre_compute_insns + ); + check_termination(exec_state.exit_code)?; + let VmExecState { vm_state, ctx, .. } = exec_state; + Ok((ctx.into_segments(), vm_state)) + } +} + +fn alloc_pre_compute_buf(program_len: usize, pre_compute_max_size: usize) -> AlignedBuf { + let buf_len = program_len * pre_compute_max_size; + AlignedBuf::uninit(buf_len, pre_compute_max_size) +} + +fn split_pre_compute_buf<'a, F>( + program: &Program, + pre_compute_buf: &'a mut AlignedBuf, + pre_compute_max_size: usize, +) -> Vec<&'a mut [u8]> { + let program_len = program.instructions_and_debug_infos.len(); + let buf_len = program_len * pre_compute_max_size; + let mut pre_compute_buf_ptr = + unsafe { std::slice::from_raw_parts_mut(pre_compute_buf.ptr, buf_len) }; + let mut split_pre_compute_buf = Vec::with_capacity(program_len); + for _ in 0..program_len { + let (first, last) = pre_compute_buf_ptr.split_at_mut(pre_compute_max_size); + pre_compute_buf_ptr = last; + split_pre_compute_buf.push(first); + } + split_pre_compute_buf +} + +/// Executes using function pointers with the trampoline (loop) approach. +/// +/// # Safety +/// The `fn_ptrs` pointer to pre-computed buffers that outlive this function. +#[inline(always)] +unsafe fn execute_trampoline( + pc_base: u32, + vm_state: &mut VmExecState, + fn_ptrs: &[PreComputeInstruction], +) { + while vm_state + .exit_code + .as_ref() + .is_ok_and(|exit_code| exit_code.is_none()) + { + if Ctx::should_suspend(vm_state) { + break; + } + let pc_index = get_pc_index(pc_base, vm_state.pc); + if let Some(inst) = fn_ptrs.get(pc_index) { + // SAFETY: pre_compute assumed to live long enough + unsafe { (inst.handler)(inst.pre_compute, vm_state) }; + } else { + vm_state.exit_code = Err(ExecutionError::PcOutOfBounds { + pc: vm_state.pc, + pc_base, + program_len: fn_ptrs.len(), + }); + } + } + if vm_state + .exit_code + .as_ref() + .is_ok_and(|exit_code| exit_code.is_some()) + { + Ctx::on_terminate(vm_state); + } +} + +#[inline(always)] +fn get_pc_index(pc_base: u32, pc: u32) -> usize { + ((pc - pc_base) / DEFAULT_PC_STEP) as usize +} + +/// Bytes allocated according to the given Layout +// @dev: This is duplicate from the openvm crate, but it doesn't seem worth importing `openvm` here +// just for this. +pub struct AlignedBuf { + pub ptr: *mut u8, + pub layout: Layout, +} + +impl AlignedBuf { + /// Allocate a new buffer whose start address is aligned to `align` bytes. + /// *NOTE* if `len` is zero then a creates new `NonNull` that is dangling and 16-byte aligned. + pub fn uninit(len: usize, align: usize) -> Self { + let layout = Layout::from_size_align(len, align).unwrap(); + if layout.size() == 0 { + return Self { + ptr: NonNull::::dangling().as_ptr() as *mut u8, + layout, + }; + } + // SAFETY: `len` is nonzero + let ptr = unsafe { alloc(layout) }; + if ptr.is_null() { + handle_alloc_error(layout); + } + AlignedBuf { ptr, layout } + } +} + +impl Drop for AlignedBuf { + fn drop(&mut self) { + if self.layout.size() != 0 { + unsafe { + dealloc(self.ptr, self.layout); + } + } + } +} + +unsafe fn terminate_execute_e12_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &TerminatePreCompute = pre_compute.borrow(); + vm_state.instret += 1; + vm_state.exit_code = Ok(Some(pre_compute.exit_code)); +} + +fn get_pre_compute_max_size>( + program: &Program, + inventory: &ExecutorInventory, +) -> usize { + program + .instructions_and_debug_infos + .iter() + .map(|inst_opt| { + if let Some((inst, _)) = inst_opt { + if let Some(size) = system_opcode_pre_compute_size(inst) { + size + } else { + inventory + .get_executor(inst.opcode) + .map(|executor| executor.pre_compute_size()) + .unwrap() + } + } else { + 0 + } + }) + .max() + .unwrap() + .next_power_of_two() +} + +fn get_metered_pre_compute_max_size>( + program: &Program, + inventory: &ExecutorInventory, +) -> usize { + program + .instructions_and_debug_infos + .iter() + .map(|inst_opt| { + if let Some((inst, _)) = inst_opt { + if let Some(size) = system_opcode_pre_compute_size(inst) { + size + } else { + inventory + .get_executor(inst.opcode) + .map(|executor| executor.metered_pre_compute_size()) + .unwrap() + } + } else { + 0 + } + }) + .max() + .unwrap() + .next_power_of_two() +} + +fn system_opcode_pre_compute_size(inst: &Instruction) -> Option { + if inst.opcode == SystemOpcode::TERMINATE.global_opcode() { + return Some(size_of::()); + } + None +} + +fn get_pre_compute_instructions<'a, F, Ctx, E>( + program: &Program, + inventory: &'a ExecutorInventory, + pre_compute: &mut [&mut [u8]], +) -> Result>, StaticProgramError> +where + F: PrimeField32, + Ctx: E1ExecutionCtx, + E: Executor, +{ + program + .instructions_and_debug_infos + .iter() + .zip_eq(pre_compute.iter_mut()) + .enumerate() + .map(|(i, (inst_opt, buf))| { + // SAFETY: we cast to raw pointer and then borrow to remove the lifetime. This is safe + // only in the current context because `buf` comes from `pre_compute_buf` which will + // outlive the returned `PreComputeInstruction`s. + let buf: &mut [u8] = unsafe { &mut *(*buf as *mut [u8]) }; + let pre_inst = if let Some((inst, _)) = inst_opt { + tracing::trace!("get_pre_compute_instruction {inst:?}"); + let pc = program.pc_base + i as u32 * DEFAULT_PC_STEP; + if let Some(handler) = get_system_opcode_handler(inst, buf) { + PreComputeInstruction { + handler, + pre_compute: buf, + } + } else if let Some(executor) = inventory.get_executor(inst.opcode) { + PreComputeInstruction { + handler: executor.pre_compute(pc, inst, buf)?, + pre_compute: buf, + } + } else { + return Err(StaticProgramError::DisabledOperation { + pc, + opcode: inst.opcode, + }); + } + } else { + // Dead instruction at this pc + PreComputeInstruction { + handler: |_, vm_state| { + vm_state.exit_code = Err(ExecutionError::Unreachable(vm_state.pc)); + }, + pre_compute: buf, + } + }; + Ok(pre_inst) + }) + .collect::, _>>() +} + +fn get_metered_pre_compute_instructions<'a, F, Ctx, E>( + program: &Program, + inventory: &'a ExecutorInventory, + executor_idx_to_air_idx: &[usize], + pre_compute: &mut [&mut [u8]], +) -> Result>, StaticProgramError> +where + F: PrimeField32, + Ctx: E2ExecutionCtx, + E: MeteredExecutor, +{ + program + .instructions_and_debug_infos + .iter() + .zip_eq(pre_compute.iter_mut()) + .enumerate() + .map(|(i, (inst_opt, buf))| { + // SAFETY: we cast to raw pointer and then borrow to remove the lifetime. This is safe + // only in the current context because `buf` comes from `pre_compute_buf` which will + // outlive the returned `PreComputeInstruction`s. + let buf: &mut [u8] = unsafe { &mut *(*buf as *mut [u8]) }; + let pre_inst = if let Some((inst, _)) = inst_opt { + tracing::trace!("get_metered_pre_compute_instruction {inst:?}"); + let pc = program.pc_base + i as u32 * DEFAULT_PC_STEP; + if let Some(handler) = get_system_opcode_handler(inst, buf) { + PreComputeInstruction { + handler, + pre_compute: buf, + } + } else if let Some(&executor_idx) = inventory.instruction_lookup.get(&inst.opcode) { + let executor_idx = executor_idx as usize; + let executor = inventory + .executors + .get(executor_idx) + .expect("ExecutorInventory ensures executor_idx is in bounds"); + let air_idx = executor_idx_to_air_idx[executor_idx]; + PreComputeInstruction { + handler: executor.metered_pre_compute(air_idx, pc, inst, buf)?, + pre_compute: buf, + } + } else { + return Err(StaticProgramError::DisabledOperation { + pc, + opcode: inst.opcode, + }); + } + } else { + PreComputeInstruction { + handler: |_, vm_state| { + vm_state.exit_code = Err(ExecutionError::Unreachable(vm_state.pc)); + }, + pre_compute: buf, + } + }; + Ok(pre_inst) + }) + .collect::, _>>() +} + +fn get_system_opcode_handler( + inst: &Instruction, + buf: &mut [u8], +) -> Option> { + if inst.opcode == SystemOpcode::TERMINATE.global_opcode() { + let pre_compute: &mut TerminatePreCompute = buf.borrow_mut(); + pre_compute.exit_code = inst.c.as_canonical_u32(); + return Some(terminate_execute_e12_impl); + } + None +} + +/// Errors if exit code is either error or terminated with non-successful exit code. +fn check_exit_code(exit_code: Result, ExecutionError>) -> Result<(), ExecutionError> { + let exit_code = exit_code?; + if let Some(exit_code) = exit_code { + // This means execution did terminate + if exit_code != ExitCode::Success as u32 { + return Err(ExecutionError::FailedWithExitCode(exit_code)); + } + } + Ok(()) +} + +/// Same as [check_exit_code] but errors if program did not terminate. +fn check_termination(exit_code: Result, ExecutionError>) -> Result<(), ExecutionError> { + let did_terminate = matches!(exit_code.as_ref(), Ok(Some(_))); + check_exit_code(exit_code)?; + match did_terminate { + true => Ok(()), + false => Err(ExecutionError::DidNotTerminate), + } +} diff --git a/crates/vm/src/arch/interpreter_preflight.rs b/crates/vm/src/arch/interpreter_preflight.rs new file mode 100644 index 0000000000..c878579b9a --- /dev/null +++ b/crates/vm/src/arch/interpreter_preflight.rs @@ -0,0 +1,139 @@ +use openvm_stark_backend::p3_field::PrimeField32; + +use super::ExecutionError; +use crate::{ + arch::{ + execution_mode::tracegen::TracegenCtx, instructions::*, Arena, PreflightExecutor, + VmExecState, VmStateMut, + }, + system::{memory::online::TracingMemory, program::ProgramHandler}, +}; + +pub struct PreflightInterpretedInstance { + pub handler: ProgramHandler, + executor_idx_to_air_idx: Vec, +} + +impl PreflightInterpretedInstance +where + F: PrimeField32, +{ + /// Creates a new execution segment from a program and initial state, using parent VM config + pub fn new(handler: ProgramHandler, executor_idx_to_air_idx: Vec) -> Self { + Self { + handler, + executor_idx_to_air_idx, + } + } + + /// Stopping is triggered by should_stop() or if VM is terminated + pub fn execute_from_state( + &mut self, + state: &mut VmExecState>, + ) -> Result<(), ExecutionError> + where + RA: Arena, + E: PreflightExecutor, + { + loop { + if let Ok(Some(_)) = state.exit_code { + // should terminate + break; + } + if state + .ctx + .instret_end + .is_some_and(|instret_end| state.instret >= instret_end) + { + // should suspend + break; + } + + // Fetch, decode and execute single instruction + self.execute_instruction(state)?; + state.instret += 1; + } + + Ok(()) + } + + /// Executes a single instruction and updates VM state + #[inline(always)] + fn execute_instruction( + &mut self, + state: &mut VmExecState>, + ) -> Result<(), ExecutionError> + where + RA: Arena, + E: PreflightExecutor, + { + let pc = state.pc; + let (executor, pc_entry) = self.handler.get_executor(pc)?; + tracing::trace!("pc: {pc:#x} | {:?}", pc_entry.insn); + + let opcode = pc_entry.insn.opcode; + let c = pc_entry.insn.c; + // Handle termination instruction + if opcode.as_usize() == SystemOpcode::CLASS_OFFSET + SystemOpcode::TERMINATE as usize { + state.exit_code = Ok(Some(c.as_canonical_u32())); + return Ok(()); + } + + // Execute the instruction using the control implementation + tracing::trace!( + "opcode: {} | timestamp: {}", + executor.get_opcode_name(pc_entry.insn.opcode.as_usize()), + state.memory.timestamp() + ); + let arena = unsafe { + // SAFETY: executor_idx is guarantee to be within bounds by ProgramHandler constructor + let air_idx = *self + .executor_idx_to_air_idx + .get_unchecked(pc_entry.executor_idx as usize); + // SAFETY: air_idx is a valid AIR index in the vkey, and always construct arenas with + // length equal to num_airs + state.ctx.arenas.get_unchecked_mut(air_idx) + }; + let state_mut = VmStateMut { + pc: &mut state.vm_state.pc, + memory: &mut state.vm_state.memory, + streams: &mut state.vm_state.streams, + rng: &mut state.vm_state.rng, + ctx: arena, + #[cfg(feature = "metrics")] + metrics: &mut state.vm_state.metrics, + }; + executor.execute(state_mut, &pc_entry.insn)?; + + #[cfg(feature = "metrics")] + { + crate::metrics::update_instruction_metrics(state, executor, pc, pc_entry); + } + + Ok(()) + } +} + +/// Macro for executing and emitting metrics for instructions/s and number of instructions executed. +/// Does not include any tracing span. +#[macro_export] +macro_rules! execute_spanned { + ($name:literal, $executor:expr, $state:expr) => {{ + #[cfg(feature = "metrics")] + let start = std::time::Instant::now(); + #[cfg(feature = "metrics")] + let start_instret = $state.instret; + + let result = $executor.execute_from_state($state); + + #[cfg(feature = "metrics")] + { + let elapsed = start.elapsed(); + let insns = $state.instret - start_instret; + metrics::counter!("insns").absolute(insns); + metrics::gauge!(concat!($name, "_insn_mi/s")) + .set(insns as f64 / elapsed.as_micros() as f64); + } + result + }}; +} diff --git a/crates/vm/src/arch/mod.rs b/crates/vm/src/arch/mod.rs index 63ee5e6f8b..72612d1c30 100644 --- a/crates/vm/src/arch/mod.rs +++ b/crates/vm/src/arch/mod.rs @@ -1,26 +1,36 @@ mod config; /// Instruction execution traits and types. /// Execution bus and interface. -mod execution; +pub mod execution; +/// Execution context types for different execution modes. +pub mod execution_mode; /// Traits and builders to compose collections of chips into a virtual machine. mod extensions; /// Traits and wrappers to facilitate VM chip integration mod integration_api; -/// Runtime execution and segmentation -pub mod segment; -/// Top level [VirtualMachine] constructor and API. +/// [RecordArena] trait definitions and implementations. Currently there are two concrete +/// implementations: [MatrixRecordArena] and [DenseRecordArena]. +mod record_arena; +/// VM state definitions +mod state; +/// Top level [VmExecutor] and [VirtualMachine] constructor and API. pub mod vm; -pub use openvm_instructions as instructions; - pub mod hasher; +/// Interpreter for pure and metered VM execution +pub mod interpreter; +/// Interpreter for preflight VM execution, for trace generation purposes. +pub mod interpreter_preflight; /// Testing framework #[cfg(any(test, feature = "test-utils"))] pub mod testing; pub use config::*; pub use execution::*; +pub use execution_mode::{E1ExecutionCtx, E2ExecutionCtx}; pub use extensions::*; pub use integration_api::*; -pub use segment::*; +pub use openvm_instructions as instructions; +pub use record_arena::*; +pub use state::*; pub use vm::*; diff --git a/crates/vm/src/arch/record_arena.rs b/crates/vm/src/arch/record_arena.rs new file mode 100644 index 0000000000..cfb269d095 --- /dev/null +++ b/crates/vm/src/arch/record_arena.rs @@ -0,0 +1,667 @@ +use std::{ + borrow::BorrowMut, + io::Cursor, + marker::PhantomData, + ptr::{copy_nonoverlapping, slice_from_raw_parts_mut}, +}; + +use openvm_circuit_primitives::utils::next_power_of_two_or_zero; +use openvm_stark_backend::{ + p3_field::{Field, PrimeField32}, + p3_matrix::dense::RowMajorMatrix, +}; + +pub trait Arena { + /// Currently `width` always refers to the main trace width. + fn with_capacity(height: usize, width: usize) -> Self; + + fn is_empty(&self) -> bool; + + /// Only used for metric collection purposes. Intended usage is that for a record arena that + /// corresponds to a single trace matrix, this function can extract the current number of used + /// rows of the corresponding trace matrix. This is currently expected to work only for + /// [MatrixRecordArena]. + #[cfg(feature = "metrics")] + fn current_trace_height(&self) -> usize { + 0 + } +} + +/// Given some minimum layout of type `Layout`, the `RecordArena` should allocate a buffer, of +/// size possibly larger than the record, and then return mutable pointers to the record within the +/// buffer. +pub trait RecordArena<'a, Layout, RecordMut> { + /// Allocates underlying buffer and returns a mutable reference `RecordMut`. + /// Note that calling this function may not call an underlying memory allocation as the record + /// arena may be virtual. + fn alloc(&'a mut self, layout: Layout) -> RecordMut; +} + +/// Helper trait for arenas backed by row-major matrices. +pub trait RowMajorMatrixArena: Arena { + /// Set the arena's capacity based on the projected trace height. + fn set_capacity(&mut self, trace_height: usize); + fn width(&self) -> usize; + fn trace_offset(&self) -> usize; + fn into_matrix(self) -> RowMajorMatrix; +} + +/// `SizedRecord` is a trait that provides additional information about the size and alignment +/// requirements of a record. Should be implemented on RecordMut types +pub trait SizedRecord { + /// The minimal size in bytes that the RecordMut requires to be properly constructed + /// given the layout. + fn size(layout: &Layout) -> usize; + /// The minimal alignment required for the RecordMut to be properly constructed + /// given the layout. + fn alignment(layout: &Layout) -> usize; +} + +impl SizedRecord for &mut Record +where + Record: Sized, +{ + fn size(_layout: &Layout) -> usize { + size_of::() + } + + fn alignment(_layout: &Layout) -> usize { + align_of::() + } +} + +// =================== Arena Implementations ========================= + +#[derive(Default)] +pub struct MatrixRecordArena { + pub trace_buffer: Vec, + pub width: usize, + pub trace_offset: usize, + /// The arena is created with a specified capacity, but may be truncated before being converted + /// into a [RowMajorMatrix] if `allow_truncate == true`. If `allow_truncate == false`, then the + /// matrix will never be truncated. The latter is used if the trace matrix must have fixed + /// dimensions (e.g., for a static verifier). + pub(super) allow_truncate: bool, +} + +impl MatrixRecordArena { + pub fn alloc_single_row(&mut self) -> &mut [u8] { + self.alloc_buffer(1) + } + + pub fn alloc_buffer(&mut self, num_rows: usize) -> &mut [u8] { + let start = self.trace_offset; + self.trace_offset += num_rows * self.width; + let row_slice = &mut self.trace_buffer[start..self.trace_offset]; + let size = size_of_val(row_slice); + let ptr = row_slice as *mut [F] as *mut u8; + // SAFETY: + // - `ptr` is non-null + // - `size` is correct + // - alignment of `u8` is always satisfied + unsafe { &mut *std::ptr::slice_from_raw_parts_mut(ptr, size) } + } + + pub fn force_matrix_dimensions(&mut self) { + self.allow_truncate = false; + } +} + +impl Arena for MatrixRecordArena { + fn with_capacity(height: usize, width: usize) -> Self { + let height = next_power_of_two_or_zero(height); + let trace_buffer = F::zero_vec(height * width); + Self { + trace_buffer, + width, + trace_offset: 0, + allow_truncate: true, + } + } + + fn is_empty(&self) -> bool { + self.trace_offset == 0 + } + + #[cfg(feature = "metrics")] + fn current_trace_height(&self) -> usize { + self.trace_offset / self.width + } +} + +impl RowMajorMatrixArena for MatrixRecordArena { + fn set_capacity(&mut self, trace_height: usize) { + let size = trace_height * self.width; + // PERF: use memset + self.trace_buffer.resize(size, F::ZERO); + } + + fn width(&self) -> usize { + self.width + } + + fn trace_offset(&self) -> usize { + self.trace_offset + } + + fn into_matrix(mut self) -> RowMajorMatrix { + let width = self.width(); + assert_eq!(self.trace_offset() % width, 0); + let rows_used = self.trace_offset() / width; + let height = next_power_of_two_or_zero(rows_used); + // This should be automatic since trace_buffer's height is a power of two: + assert!(height.checked_mul(width).unwrap() <= self.trace_buffer.len()); + if self.allow_truncate { + self.trace_buffer.truncate(height * width); + } else { + assert_eq!(self.trace_buffer.len() % width, 0); + let height = self.trace_buffer.len() / width; + assert!(height.is_power_of_two() || height == 0); + } + RowMajorMatrix::new(self.trace_buffer, self.width) + } +} + +pub struct DenseRecordArena { + pub records_buffer: Cursor>, +} + +const MAX_ALIGNMENT: usize = 32; + +impl DenseRecordArena { + /// Creates a new [DenseRecordArena] with the given capacity in bytes. + pub fn with_byte_capacity(size_bytes: usize) -> Self { + let buffer = vec![0; size_bytes + MAX_ALIGNMENT]; + let offset = (MAX_ALIGNMENT - (buffer.as_ptr() as usize % MAX_ALIGNMENT)) % MAX_ALIGNMENT; + let mut cursor = Cursor::new(buffer); + cursor.set_position(offset as u64); + Self { + records_buffer: cursor, + } + } + + pub fn set_byte_capacity(&mut self, size_bytes: usize) { + let buffer = vec![0; size_bytes + MAX_ALIGNMENT]; + let offset = (MAX_ALIGNMENT - (buffer.as_ptr() as usize % MAX_ALIGNMENT)) % MAX_ALIGNMENT; + let mut cursor = Cursor::new(buffer); + cursor.set_position(offset as u64); + self.records_buffer = cursor; + } + + /// Returns the allocated size of the arena in bytes. + /// + /// **Note**: This may include additional bytes for alignment. + pub fn capacity(&self) -> usize { + self.records_buffer.get_ref().len() + } + + /// Allocates `count` bytes and returns as a mutable slice. + pub fn alloc_bytes<'a>(&mut self, count: usize) -> &'a mut [u8] { + let begin = self.records_buffer.position(); + debug_assert!( + begin as usize + count <= self.records_buffer.get_ref().len(), + "failed to allocate {count} bytes from {begin} when the capacity is {}", + self.records_buffer.get_ref().len() + ); + self.records_buffer.set_position(begin + count as u64); + unsafe { + std::slice::from_raw_parts_mut( + self.records_buffer + .get_mut() + .as_mut_ptr() + .add(begin as usize), + count, + ) + } + } + + pub fn allocated(&self) -> &[u8] { + let size = self.records_buffer.position() as usize; + let offset = (MAX_ALIGNMENT + - (self.records_buffer.get_ref().as_ptr() as usize % MAX_ALIGNMENT)) + % MAX_ALIGNMENT; + &self.records_buffer.get_ref()[offset..size] + } + + pub fn allocated_mut(&mut self) -> &mut [u8] { + let size = self.records_buffer.position() as usize; + let offset = (MAX_ALIGNMENT + - (self.records_buffer.get_ref().as_ptr() as usize % MAX_ALIGNMENT)) + % MAX_ALIGNMENT; + &mut self.records_buffer.get_mut()[offset..size] + } + + pub fn align_to(&mut self, alignment: usize) { + debug_assert!(MAX_ALIGNMENT % alignment == 0); + let offset = + (alignment - (self.records_buffer.get_ref().as_ptr() as usize % alignment)) % alignment; + self.records_buffer.set_position(offset as u64); + } + + // Returns a [RecordSeeker] on the allocated buffer + pub fn get_record_seeker(&mut self) -> RecordSeeker { + RecordSeeker::new(self.allocated_mut()) + } +} + +impl Arena for DenseRecordArena { + // TODO[jpw]: treat `width` as AIR width in number of columns for now + fn with_capacity(height: usize, width: usize) -> Self { + let size_bytes = height * (width * size_of::()); + Self::with_byte_capacity(size_bytes) + } + + fn is_empty(&self) -> bool { + self.allocated().is_empty() + } +} + +// =================== Helper Functions ================================= + +/// Converts a field element slice into a record type. +/// This function transmutes the `&mut [F]` to raw bytes, +/// then uses the `CustomBorrow` trait to transmute to the desired record type `T`. +/// ## Safety +/// `slice` must satisfy the requirements of the `CustomBorrow` trait. +pub unsafe fn get_record_from_slice<'a, T, F, L>(slice: &mut &'a mut [F], layout: L) -> T +where + [u8]: CustomBorrow<'a, T, L>, +{ + // The alignment of `[u8]` is always satisfiedƒ + let record_buffer = + &mut *slice_from_raw_parts_mut(slice.as_mut_ptr() as *mut u8, size_of_val::<[F]>(*slice)); + let record: T = record_buffer.custom_borrow(layout); + record +} + +/// A trait that allows for custom implementation of `borrow` given the necessary information +/// This is useful for record structs that have dynamic size +pub trait CustomBorrow<'a, T, L> { + fn custom_borrow(&'a mut self, layout: L) -> T; + + /// Given `&self` as a valid starting pointer of a reference that has already been previously + /// allocated and written to, extracts and returns the corresponding layout. + /// This must work even if `T` is not sized. + /// + /// # Safety + /// - `&self` must be a valid starting pointer on which `custom_borrow` has already been called + /// - The data underlying `&self` has already been written to and is self-describing, so layout + /// can be extracted + unsafe fn extract_layout(&self) -> L; +} + +// This is a helper struct that implements a few utility methods +pub struct RecordSeeker<'a, RA, RecordMut, Layout> { + pub buffer: &'a mut [u8], // The buffer that the records are written to + _phantom: PhantomData<(RA, RecordMut, Layout)>, +} + +impl<'a, RA, RecordMut, Layout> RecordSeeker<'a, RA, RecordMut, Layout> { + pub fn new(record_buffer: &'a mut [u8]) -> Self { + Self { + buffer: record_buffer, + _phantom: PhantomData, + } + } +} + +// `RecordSeeker` implementation for [DenseRecordArena], with [MultiRowLayout] +// **NOTE** Assumes that `layout` can be extracted from the record alone +impl<'a, R, M> RecordSeeker<'a, DenseRecordArena, R, MultiRowLayout> +where + [u8]: CustomBorrow<'a, R, MultiRowLayout>, + R: SizedRecord>, + M: MultiRowMetadata + Clone, +{ + // Returns the layout at the given offset in the buffer + // **SAFETY**: `offset` has to be a valid offset, pointing to the start of a record + pub fn get_layout_at(offset: &mut usize, buffer: &[u8]) -> MultiRowLayout { + let buffer = &buffer[*offset..]; + unsafe { buffer.extract_layout() } + } + + // Returns a record at the given offset in the buffer + // **SAFETY**: `offset` has to be a valid offset, pointing to the start of a record + pub fn get_record_at(offset: &mut usize, buffer: &'a mut [u8]) -> R { + let layout = Self::get_layout_at(offset, buffer); + let buffer = &mut buffer[*offset..]; + let record_size = R::size(&layout); + let record_alignment = R::alignment(&layout); + let aligned_record_size = record_size.next_multiple_of(record_alignment); + let record: R = buffer.custom_borrow(layout); + *offset += aligned_record_size; + record + } + + // Returns a vector of all the records in the buffer + pub fn extract_records(&'a mut self) -> Vec { + let mut records = Vec::new(); + let len = self.buffer.len(); + let buff = &mut self.buffer[..]; + let mut offset = 0; + while offset < len { + let record: R = { + let buff = unsafe { &mut *slice_from_raw_parts_mut(buff.as_mut_ptr(), len) }; + Self::get_record_at(&mut offset, buff) + }; + records.push(record); + } + records + } + + // Transfers the records in the buffer to a [MatrixRecordArena], used in testing + pub fn transfer_to_matrix_arena( + &'a mut self, + arena: &mut MatrixRecordArena, + ) { + let len = self.buffer.len(); + arena.trace_offset = 0; + let mut offset = 0; + while offset < len { + let layout = Self::get_layout_at(&mut offset, self.buffer); + let record_size = R::size(&layout); + let record_alignment = R::alignment(&layout); + let aligned_record_size = record_size.next_multiple_of(record_alignment); + let src_ptr = unsafe { self.buffer.as_ptr().add(offset) }; + let dst_ptr = arena + .alloc_buffer(layout.metadata.get_num_rows()) + .as_mut_ptr(); + unsafe { copy_nonoverlapping(src_ptr, dst_ptr, aligned_record_size) }; + offset += aligned_record_size; + } + } +} + +// `RecordSeeker` implementation for [DenseRecordArena], with [AdapterCoreLayout] +// **NOTE** Assumes that `layout` is the same for all the records, so it is expected to be passed as +// a parameter +impl<'a, A, C, M> RecordSeeker<'a, DenseRecordArena, (A, C), AdapterCoreLayout> +where + [u8]: CustomBorrow<'a, A, AdapterCoreLayout> + CustomBorrow<'a, C, AdapterCoreLayout>, + A: SizedRecord>, + C: SizedRecord>, + M: AdapterCoreMetadata + Clone, +{ + // Returns the aligned sizes of the adapter and core records given their layout + pub fn get_aligned_sizes(layout: &AdapterCoreLayout) -> (usize, usize) { + let adapter_alignment = A::alignment(layout); + let core_alignment = C::alignment(layout); + let adapter_size = A::size(layout); + let aligned_adapter_size = adapter_size.next_multiple_of(core_alignment); + let core_size = C::size(layout); + let aligned_core_size = (aligned_adapter_size + core_size) + .next_multiple_of(adapter_alignment) + - aligned_adapter_size; + (aligned_adapter_size, aligned_core_size) + } + + // Returns the aligned size of a single record given its layout + pub fn get_aligned_record_size(layout: &AdapterCoreLayout) -> usize { + let (adapter_size, core_size) = Self::get_aligned_sizes(layout); + adapter_size + core_size + } + + // Returns a record at the given offset in the buffer + // **SAFETY**: `offset` has to be a valid offset, pointing to the start of a record + pub fn get_record_at( + offset: &mut usize, + buffer: &'a mut [u8], + layout: AdapterCoreLayout, + ) -> (A, C) { + let buffer = &mut buffer[*offset..]; + let (adapter_size, core_size) = Self::get_aligned_sizes(&layout); + let (adapter_buffer, core_buffer) = unsafe { buffer.split_at_mut_unchecked(adapter_size) }; + let adapter_record: A = adapter_buffer.custom_borrow(layout.clone()); + let core_record: C = core_buffer.custom_borrow(layout); + *offset += adapter_size + core_size; + (adapter_record, core_record) + } + + // Returns a vector of all the records in the buffer + pub fn extract_records(&'a mut self, layout: AdapterCoreLayout) -> Vec<(A, C)> { + let mut records = Vec::new(); + let len = self.buffer.len(); + let buff = &mut self.buffer[..]; + let mut offset = 0; + while offset < len { + let record: (A, C) = { + let buff = unsafe { &mut *slice_from_raw_parts_mut(buff.as_mut_ptr(), len) }; + Self::get_record_at(&mut offset, buff, layout.clone()) + }; + records.push(record); + } + records + } + + // Transfers the records in the buffer to a [MatrixRecordArena], used in testing + pub fn transfer_to_matrix_arena( + &'a mut self, + arena: &mut MatrixRecordArena, + layout: AdapterCoreLayout, + ) { + let len = self.buffer.len(); + arena.trace_offset = 0; + let mut offset = 0; + let (adapter_size, core_size) = Self::get_aligned_sizes(&layout); + while offset < len { + let dst_buffer = arena.alloc_single_row(); + let (adapter_buf, core_buf) = + unsafe { dst_buffer.split_at_mut_unchecked(M::get_adapter_width()) }; + unsafe { + let src_ptr = self.buffer.as_ptr().add(offset); + copy_nonoverlapping(src_ptr, adapter_buf.as_mut_ptr(), adapter_size); + copy_nonoverlapping(src_ptr.add(adapter_size), core_buf.as_mut_ptr(), core_size); + } + offset += adapter_size + core_size; + } + } +} + +// ============================== MultiRowLayout ======================================= + +/// Minimal layout information that [RecordArena] requires for record allocation +/// in scenarios involving chips that: +/// - can have multiple rows per record, and +/// - have possibly variable length records +/// +/// **NOTE**: `M` is the metadata type that implements `MultiRowMetadata` +#[derive(Debug, Clone, Default, derive_new::new)] +pub struct MultiRowLayout { + pub metadata: M, +} + +/// `Metadata` types need to implement this trait to be used with `MultiRowLayout` +pub trait MultiRowMetadata { + fn get_num_rows(&self) -> usize; +} + +/// Empty metadata that implements `MultiRowMetadata` with `get_num_rows` always returning 1 +#[derive(Debug, Clone, Default, derive_new::new)] +pub struct EmptyMultiRowMetadata {} + +impl MultiRowMetadata for EmptyMultiRowMetadata { + #[inline(always)] + fn get_num_rows(&self) -> usize { + 1 + } +} + +/// Empty metadata that implements `MultiRowMetadata` +pub type EmptyMultiRowLayout = MultiRowLayout; + +/// If a struct implements `BorrowMut`, then the same implementation can be used for +/// `CustomBorrow::custom_borrow` with any layout +impl<'a, T: Sized, L: Default> CustomBorrow<'a, &'a mut T, L> for [u8] +where + [u8]: BorrowMut, +{ + fn custom_borrow(&'a mut self, _layout: L) -> &'a mut T { + self.borrow_mut() + } + + unsafe fn extract_layout(&self) -> L { + L::default() + } +} + +/// [RecordArena] implementation for [MatrixRecordArena], with [MultiRowLayout] +/// **NOTE**: `R` is the RecordMut type +impl<'a, F: Field, M: MultiRowMetadata, R> RecordArena<'a, MultiRowLayout, R> + for MatrixRecordArena +where + [u8]: CustomBorrow<'a, R, MultiRowLayout>, +{ + fn alloc(&'a mut self, layout: MultiRowLayout) -> R { + let buffer = self.alloc_buffer(layout.metadata.get_num_rows()); + let record: R = buffer.custom_borrow(layout); + record + } +} + +/// [RecordArena] implementation for [DenseRecordArena], with [MultiRowLayout] +/// **NOTE**: `R` is the RecordMut type +impl<'a, R, M> RecordArena<'a, MultiRowLayout, R> for DenseRecordArena +where + [u8]: CustomBorrow<'a, R, MultiRowLayout>, + R: SizedRecord>, +{ + fn alloc(&'a mut self, layout: MultiRowLayout) -> R { + let record_size = R::size(&layout); + let record_alignment = R::alignment(&layout); + let aligned_record_size = record_size.next_multiple_of(record_alignment); + let buffer = self.alloc_bytes(aligned_record_size); + let record: R = buffer.custom_borrow(layout); + record + } +} + +// ============================== AdapterCoreLayout ======================================= +// This is for integration_api usage + +/// Minimal layout information that [RecordArena] requires for record allocation +/// in scenarios involving chips that: +/// - have a single row per record, and +/// - have trace row = [adapter_row, core_row] +/// +/// **NOTE**: `M` is the metadata type that implements `AdapterCoreMetadata` +#[derive(Debug, Clone, Default)] +pub struct AdapterCoreLayout { + pub metadata: M, +} + +/// `Metadata` types need to implement this trait to be used with `AdapterCoreLayout` +/// **NOTE**: get_adapter_width returns the size in bytes +pub trait AdapterCoreMetadata { + fn get_adapter_width() -> usize; +} + +impl AdapterCoreLayout { + pub fn new() -> Self + where + M: Default, + { + Self::default() + } + + pub fn with_metadata(metadata: M) -> Self { + Self { metadata } + } +} + +/// Empty metadata that implements `AdapterCoreMetadata` +/// **NOTE**: `AS` is the adapter type that implements `AdapterTraceExecutor` +/// **WARNING**: `AS::WIDTH` is the number of field elements, not the size in bytes +pub struct AdapterCoreEmptyMetadata { + _phantom: PhantomData<(F, AS)>, +} + +impl Clone for AdapterCoreEmptyMetadata { + fn clone(&self) -> Self { + Self { + _phantom: PhantomData, + } + } +} + +impl AdapterCoreEmptyMetadata { + pub fn new() -> Self { + Self { + _phantom: PhantomData, + } + } +} + +impl Default for AdapterCoreEmptyMetadata { + fn default() -> Self { + Self { + _phantom: PhantomData, + } + } +} + +impl AdapterCoreMetadata for AdapterCoreEmptyMetadata +where + AS: super::AdapterTraceExecutor, +{ + #[inline(always)] + fn get_adapter_width() -> usize { + AS::WIDTH * size_of::() + } +} + +/// AdapterCoreLayout with empty metadata that can be used by chips that have record type +/// (&mut A, &mut C) where `A` and `C` are `Sized` +pub type EmptyAdapterCoreLayout = AdapterCoreLayout>; + +/// [RecordArena] implementation for [MatrixRecordArena], with [AdapterCoreLayout] +/// **NOTE**: `A` is the adapter RecordMut type and `C` is the core RecordMut type +impl<'a, F: Field, A, C, M: AdapterCoreMetadata> RecordArena<'a, AdapterCoreLayout, (A, C)> + for MatrixRecordArena +where + [u8]: CustomBorrow<'a, A, AdapterCoreLayout> + CustomBorrow<'a, C, AdapterCoreLayout>, + M: Clone, +{ + fn alloc(&'a mut self, layout: AdapterCoreLayout) -> (A, C) { + let adapter_width = M::get_adapter_width(); + let buffer = self.alloc_single_row(); + // Doing a unchecked split here for perf + let (adapter_buffer, core_buffer) = unsafe { buffer.split_at_mut_unchecked(adapter_width) }; + + let adapter_record: A = adapter_buffer.custom_borrow(layout.clone()); + let core_record: C = core_buffer.custom_borrow(layout); + + (adapter_record, core_record) + } +} + +/// [RecordArena] implementation for [DenseRecordArena], with [AdapterCoreLayout] +/// **NOTE**: `A` is the adapter RecordMut type and `C` is the core record type +impl<'a, A, C, M> RecordArena<'a, AdapterCoreLayout, (A, C)> for DenseRecordArena +where + [u8]: CustomBorrow<'a, A, AdapterCoreLayout> + CustomBorrow<'a, C, AdapterCoreLayout>, + M: Clone, + A: SizedRecord>, + C: SizedRecord>, +{ + fn alloc(&'a mut self, layout: AdapterCoreLayout) -> (A, C) { + let adapter_alignment = A::alignment(&layout); + let core_alignment = C::alignment(&layout); + let adapter_size = A::size(&layout); + let aligned_adapter_size = adapter_size.next_multiple_of(core_alignment); + let core_size = C::size(&layout); + let aligned_core_size = (aligned_adapter_size + core_size) + .next_multiple_of(adapter_alignment) + - aligned_adapter_size; + debug_assert_eq!(MAX_ALIGNMENT % adapter_alignment, 0); + debug_assert_eq!(MAX_ALIGNMENT % core_alignment, 0); + let buffer = self.alloc_bytes(aligned_adapter_size + aligned_core_size); + // Doing an unchecked split here for perf + let (adapter_buffer, core_buffer) = + unsafe { buffer.split_at_mut_unchecked(aligned_adapter_size) }; + + let adapter_record: A = adapter_buffer.custom_borrow(layout.clone()); + let core_record: C = core_buffer.custom_borrow(layout); + + (adapter_record, core_record) + } +} diff --git a/crates/vm/src/arch/segment.rs b/crates/vm/src/arch/segment.rs deleted file mode 100644 index 634632ce2b..0000000000 --- a/crates/vm/src/arch/segment.rs +++ /dev/null @@ -1,387 +0,0 @@ -use std::sync::Arc; - -use backtrace::Backtrace; -use openvm_instructions::{ - exe::FnBounds, - instruction::{DebugInfo, Instruction}, - program::Program, -}; -use openvm_stark_backend::{ - config::{Domain, StarkGenericConfig}, - keygen::types::LinearConstraint, - p3_commit::PolynomialSpace, - p3_field::PrimeField32, - prover::types::{CommittedTraceData, ProofInput}, - utils::metrics_span, - Chip, -}; - -use super::{ - ExecutionError, GenerationError, Streams, SystemBase, SystemConfig, VmChipComplex, - VmComplexTraceHeights, VmConfig, -}; -#[cfg(feature = "bench-metrics")] -use crate::metrics::VmMetrics; -use crate::{ - arch::{instructions::*, ExecutionState, InstructionExecutor}, - system::memory::MemoryImage, -}; - -/// Check segment every 100 instructions. -const SEGMENT_CHECK_INTERVAL: usize = 100; - -const DEFAULT_MAX_SEGMENT_LEN: usize = (1 << 22) - 100; -// a heuristic number for the maximum number of cells per chip in a segment -// a few reasons for this number: -// 1. `VmAirWrapper` is -// the chip with the most cells in a segment from the reth-benchmark. -// 2. `VmAirWrapper`: -// its trace width is 36 and its after challenge trace width is 80. -const DEFAULT_MAX_CELLS_PER_CHIP_IN_SEGMENT: usize = DEFAULT_MAX_SEGMENT_LEN * 120; - -pub trait SegmentationStrategy: - std::fmt::Debug + Send + Sync + std::panic::UnwindSafe + std::panic::RefUnwindSafe -{ - /// Whether the execution should segment based on the trace heights and cells. - /// - /// Air names are provided for debugging purposes. - fn should_segment( - &self, - air_names: &[String], - trace_heights: &[usize], - trace_cells: &[usize], - ) -> bool; - - /// A strategy that segments more aggressively than the current one. - /// - /// Called when `should_segment` results in a segment that is infeasible. Execution will be - /// re-run with the stricter segmentation strategy. - fn stricter_strategy(&self) -> Arc; -} - -/// Default segmentation strategy: segment if any chip's height or cells exceed the limits. -#[derive(Debug, Clone)] -pub struct DefaultSegmentationStrategy { - max_segment_len: usize, - max_cells_per_chip_in_segment: usize, -} - -impl Default for DefaultSegmentationStrategy { - fn default() -> Self { - Self { - max_segment_len: DEFAULT_MAX_SEGMENT_LEN, - max_cells_per_chip_in_segment: DEFAULT_MAX_CELLS_PER_CHIP_IN_SEGMENT, - } - } -} - -impl DefaultSegmentationStrategy { - pub fn new_with_max_segment_len(max_segment_len: usize) -> Self { - Self { - max_segment_len, - max_cells_per_chip_in_segment: max_segment_len * 120, - } - } - - pub fn new(max_segment_len: usize, max_cells_per_chip_in_segment: usize) -> Self { - Self { - max_segment_len, - max_cells_per_chip_in_segment, - } - } - - pub fn max_segment_len(&self) -> usize { - self.max_segment_len - } -} - -const SEGMENTATION_BACKOFF_FACTOR: usize = 4; - -impl SegmentationStrategy for DefaultSegmentationStrategy { - fn should_segment( - &self, - air_names: &[String], - trace_heights: &[usize], - trace_cells: &[usize], - ) -> bool { - for (i, &height) in trace_heights.iter().enumerate() { - if height > self.max_segment_len { - tracing::info!( - "Should segment because chip {} (name: {}) has height {}", - i, - air_names[i], - height - ); - return true; - } - } - for (i, &num_cells) in trace_cells.iter().enumerate() { - if num_cells > self.max_cells_per_chip_in_segment { - tracing::info!( - "Should segment because chip {} (name: {}) has {} cells", - i, - air_names[i], - num_cells - ); - return true; - } - } - false - } - - fn stricter_strategy(&self) -> Arc { - Arc::new(Self { - max_segment_len: self.max_segment_len / SEGMENTATION_BACKOFF_FACTOR, - max_cells_per_chip_in_segment: self.max_cells_per_chip_in_segment - / SEGMENTATION_BACKOFF_FACTOR, - }) - } -} - -pub struct ExecutionSegment -where - F: PrimeField32, - VC: VmConfig, -{ - pub chip_complex: VmChipComplex, - /// Memory image after segment was executed. Not used in trace generation. - pub final_memory: Option>, - - pub since_last_segment_check: usize, - pub trace_height_constraints: Vec, - - /// Air names for debug purposes only. - pub(crate) air_names: Vec, - /// Metrics collected for this execution segment alone. - #[cfg(feature = "bench-metrics")] - pub metrics: VmMetrics, -} - -pub struct ExecutionSegmentState { - pub pc: u32, - pub is_terminated: bool, -} - -impl> ExecutionSegment { - /// Creates a new execution segment from a program and initial state, using parent VM config - pub fn new( - config: &VC, - program: Program, - init_streams: Streams, - initial_memory: Option>, - trace_height_constraints: Vec, - #[allow(unused_variables)] fn_bounds: FnBounds, - ) -> Self { - let mut chip_complex = config.create_chip_complex().unwrap(); - chip_complex.set_streams(init_streams); - let program = if !config.system().profiling { - program.strip_debug_infos() - } else { - program - }; - chip_complex.set_program(program); - - if let Some(initial_memory) = initial_memory { - chip_complex.set_initial_memory(initial_memory); - } - let air_names = chip_complex.air_names(); - - Self { - chip_complex, - final_memory: None, - air_names, - trace_height_constraints, - #[cfg(feature = "bench-metrics")] - metrics: VmMetrics { - fn_bounds, - ..Default::default() - }, - since_last_segment_check: 0, - } - } - - pub fn system_config(&self) -> &SystemConfig { - self.chip_complex.config() - } - - pub fn set_override_trace_heights(&mut self, overridden_heights: VmComplexTraceHeights) { - self.chip_complex - .set_override_system_trace_heights(overridden_heights.system); - self.chip_complex - .set_override_inventory_trace_heights(overridden_heights.inventory); - } - - /// Stopping is triggered by should_segment() - pub fn execute_from_pc( - &mut self, - mut pc: u32, - ) -> Result { - let mut timestamp = self.chip_complex.memory_controller().timestamp(); - let mut prev_backtrace: Option = None; - - self.chip_complex - .connector_chip_mut() - .begin(ExecutionState::new(pc, timestamp)); - - let mut did_terminate = false; - - loop { - #[allow(unused_variables)] - let (opcode, dsl_instr) = { - let Self { - chip_complex, - #[cfg(feature = "bench-metrics")] - metrics, - .. - } = self; - let SystemBase { - program_chip, - memory_controller, - .. - } = &mut chip_complex.base; - - let (instruction, debug_info) = program_chip.get_instruction(pc)?; - tracing::trace!("pc: {pc:#x} | time: {timestamp} | {:?}", instruction); - - #[allow(unused_variables)] - let (dsl_instr, trace) = debug_info.as_ref().map_or( - (None, None), - |DebugInfo { - dsl_instruction, - trace, - }| (Some(dsl_instruction), trace.as_ref()), - ); - - let &Instruction { opcode, c, .. } = instruction; - if opcode == SystemOpcode::TERMINATE.global_opcode() { - did_terminate = true; - self.chip_complex.connector_chip_mut().end( - ExecutionState::new(pc, timestamp), - Some(c.as_canonical_u32()), - ); - break; - } - - // Some phantom instruction handling is more convenient to do here than in - // PhantomChip. - if opcode == SystemOpcode::PHANTOM.global_opcode() { - // Note: the discriminant is the lower 16 bits of the c operand. - let discriminant = c.as_canonical_u32() as u16; - let phantom = SysPhantom::from_repr(discriminant); - tracing::trace!("pc: {pc:#x} | system phantom: {phantom:?}"); - match phantom { - Some(SysPhantom::DebugPanic) => { - if let Some(mut backtrace) = prev_backtrace { - backtrace.resolve(); - eprintln!("openvm program failure; backtrace:\n{:?}", backtrace); - } else { - eprintln!("openvm program failure; no backtrace"); - } - return Err(ExecutionError::Fail { pc }); - } - Some(SysPhantom::CtStart) => - { - #[cfg(feature = "bench-metrics")] - metrics - .cycle_tracker - .start(dsl_instr.cloned().unwrap_or("Default".to_string())) - } - Some(SysPhantom::CtEnd) => - { - #[cfg(feature = "bench-metrics")] - metrics - .cycle_tracker - .end(dsl_instr.cloned().unwrap_or("Default".to_string())) - } - _ => {} - } - } - prev_backtrace = trace.cloned(); - - if let Some(executor) = chip_complex.inventory.get_mut_executor(&opcode) { - let next_state = InstructionExecutor::execute( - executor, - memory_controller, - instruction, - ExecutionState::new(pc, timestamp), - )?; - assert!(next_state.timestamp > timestamp); - pc = next_state.pc; - timestamp = next_state.timestamp; - } else { - return Err(ExecutionError::DisabledOperation { pc, opcode }); - }; - (opcode, dsl_instr.cloned()) - }; - - #[cfg(feature = "bench-metrics")] - self.update_instruction_metrics(pc, opcode, dsl_instr); - - if self.should_segment() { - self.chip_complex - .connector_chip_mut() - .end(ExecutionState::new(pc, timestamp), None); - break; - } - } - self.final_memory = Some( - self.chip_complex - .base - .memory_controller - .memory_image() - .clone(), - ); - - Ok(ExecutionSegmentState { - pc, - is_terminated: did_terminate, - }) - } - - /// Generate ProofInput to prove the segment. Should be called after ::execute - pub fn generate_proof_input( - #[allow(unused_mut)] mut self, - cached_program: Option>, - ) -> Result, GenerationError> - where - Domain: PolynomialSpace, - VC::Executor: Chip, - VC::Periphery: Chip, - { - metrics_span("trace_gen_time_ms", || { - self.chip_complex.generate_proof_input( - cached_program, - &self.trace_height_constraints, - #[cfg(feature = "bench-metrics")] - &mut self.metrics, - ) - }) - } - - /// Returns bool of whether to switch to next segment or not. This is called every clock cycle - /// inside of Core trace generation. - fn should_segment(&mut self) -> bool { - if !self.system_config().continuation_enabled { - return false; - } - // Avoid checking segment too often. - if self.since_last_segment_check != SEGMENT_CHECK_INTERVAL { - self.since_last_segment_check += 1; - return false; - } - self.since_last_segment_check = 0; - let segmentation_strategy = &self.system_config().segmentation_strategy; - segmentation_strategy.should_segment( - &self.air_names, - &self - .chip_complex - .dynamic_trace_heights() - .collect::>(), - &self.chip_complex.current_trace_cells(), - ) - } - - pub fn current_trace_cells(&self) -> Vec { - self.chip_complex.current_trace_cells() - } -} diff --git a/crates/vm/src/arch/state.rs b/crates/vm/src/arch/state.rs new file mode 100644 index 0000000000..d4caeb3402 --- /dev/null +++ b/crates/vm/src/arch/state.rs @@ -0,0 +1,160 @@ +use std::{ + fmt::Debug, + ops::{Deref, DerefMut}, +}; + +use openvm_instructions::exe::SparseMemoryImage; +use rand::{rngs::StdRng, SeedableRng}; + +use super::{create_memory_image, ExecutionError, Streams}; +#[cfg(feature = "metrics")] +use crate::metrics::VmMetrics; +use crate::{ + arch::{execution_mode::E1ExecutionCtx, MemoryConfig}, + system::memory::online::GuestMemory, +}; + +/// Represents the core state of a VM. +pub struct VmState { + pub instret: u64, + pub pc: u32, + pub memory: MEM, + pub streams: Streams, + pub rng: StdRng, + #[cfg(feature = "metrics")] + pub metrics: VmMetrics, +} + +impl VmState { + pub fn new( + instret: u64, + pc: u32, + memory: MEM, + streams: impl Into>, + seed: u64, + ) -> Self { + Self { + instret, + pc, + memory, + streams: streams.into(), + rng: StdRng::seed_from_u64(seed), + #[cfg(feature = "metrics")] + metrics: VmMetrics::default(), + } + } +} + +impl VmState { + pub fn initial( + memory_config: &MemoryConfig, + init_memory: SparseMemoryImage, + pc_start: u32, + inputs: impl Into>, + ) -> Self { + let memory = create_memory_image(memory_config, init_memory); + let seed = 0; + VmState::new(0, pc_start, memory, inputs.into(), seed) + } +} + +/// Represents the full execution state of a VM during execution. +/// The global state is generic in guest memory `MEM` and additional context `CTX`. +/// The host state is execution context specific. +// @dev: Do not confuse with `ExecutionState` struct. +pub struct VmExecState { + /// Core VM state + pub vm_state: VmState, + /// Execution-specific fields + pub exit_code: Result, ExecutionError>, + pub ctx: CTX, +} + +impl VmExecState { + pub fn new(vm_state: VmState, ctx: CTX) -> Self { + Self { + vm_state, + ctx, + exit_code: Ok(None), + } + } +} + +impl Deref for VmExecState { + type Target = VmState; + + fn deref(&self) -> &Self::Target { + &self.vm_state + } +} + +impl DerefMut for VmExecState { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.vm_state + } +} + +impl VmExecState +where + CTX: E1ExecutionCtx, +{ + /// Runtime read operation for a block of memory + #[inline(always)] + pub fn vm_read( + &mut self, + addr_space: u32, + ptr: u32, + ) -> [T; BLOCK_SIZE] { + self.ctx + .on_memory_operation(addr_space, ptr, BLOCK_SIZE as u32); + self.host_read(addr_space, ptr) + } + + /// Runtime write operation for a block of memory + #[inline(always)] + pub fn vm_write( + &mut self, + addr_space: u32, + ptr: u32, + data: &[T; BLOCK_SIZE], + ) { + self.ctx + .on_memory_operation(addr_space, ptr, BLOCK_SIZE as u32); + self.host_write(addr_space, ptr, data) + } + + #[inline(always)] + pub fn vm_read_slice( + &mut self, + addr_space: u32, + ptr: u32, + len: usize, + ) -> &[T] { + self.ctx.on_memory_operation(addr_space, ptr, len as u32); + self.host_read_slice(addr_space, ptr, len) + } + + #[inline(always)] + pub fn host_read( + &self, + addr_space: u32, + ptr: u32, + ) -> [T; BLOCK_SIZE] { + unsafe { self.memory.read(addr_space, ptr) } + } + + #[inline(always)] + pub fn host_write( + &mut self, + addr_space: u32, + ptr: u32, + data: &[T; BLOCK_SIZE], + ) { + unsafe { self.memory.write(addr_space, ptr, *data) } + } + + #[inline(always)] + pub fn host_read_slice(&self, addr_space: u32, ptr: u32, len: usize) -> &[T] { + unsafe { self.memory.get_slice(addr_space, ptr, len) } + } +} diff --git a/crates/vm/src/arch/testing/execution/mod.rs b/crates/vm/src/arch/testing/execution/mod.rs index c0fdb71c71..3177e7250b 100644 --- a/crates/vm/src/arch/testing/execution/mod.rs +++ b/crates/vm/src/arch/testing/execution/mod.rs @@ -1,12 +1,12 @@ use std::{borrow::BorrowMut, mem::size_of, sync::Arc}; -use air::{DummyExecutionInteractionCols, ExecutionDummyAir}; +use air::DummyExecutionInteractionCols; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, p3_field::{Field, FieldAlgebra, PrimeField32}, p3_matrix::dense::RowMajorMatrix, - prover::types::AirProofInput, - AirRef, Chip, ChipUsageGetter, + prover::{cpu::CpuBackend, types::AirProvingContext}, + Chip, ChipUsageGetter, }; use crate::arch::{ExecutionBus, ExecutionState}; @@ -48,24 +48,20 @@ impl ExecutionTester { } } -impl Chip for ExecutionTester> +impl Chip> for ExecutionTester> where Val: Field, { - fn air(&self) -> AirRef { - Arc::new(ExecutionDummyAir::new(self.bus)) - } - - fn generate_air_proof_input(self) -> AirProofInput { + fn generate_proving_ctx(&self, _: RA) -> AirProvingContext> { let height = self.records.len().next_power_of_two(); let width = self.trace_width(); let mut values = Val::::zero_vec(height * width); // This zip only goes through records. The padding rows between records.len()..height // are filled with zeros - in particular count = 0 so nothing is added to bus. - for (row, record) in values.chunks_mut(width).zip(self.records) { - *row.borrow_mut() = record; + for (row, record) in values.chunks_mut(width).zip(&self.records) { + *row.borrow_mut() = *record; } - AirProofInput::simple_no_pis(RowMajorMatrix::new(values, width)) + AirProvingContext::simple_no_pis(Arc::new(RowMajorMatrix::new(values, width))) } } impl ChipUsageGetter for ExecutionTester { diff --git a/crates/vm/src/arch/testing/memory/air.rs b/crates/vm/src/arch/testing/memory/air.rs index 8a394c0cce..efca131ae8 100644 --- a/crates/vm/src/arch/testing/memory/air.rs +++ b/crates/vm/src/arch/testing/memory/air.rs @@ -1,46 +1,153 @@ -use std::{borrow::Borrow, mem::size_of}; +use std::{mem::size_of, sync::Arc}; -use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, interaction::InteractionBuilder, p3_air::{Air, BaseAir}, - p3_matrix::Matrix, + p3_field::{FieldAlgebra, PrimeField32}, + p3_matrix::{dense::RowMajorMatrix, Matrix}, + prover::{cpu::CpuBackend, types::AirProvingContext}, rap::{BaseAirWithPublicValues, PartitionedBaseAir}, + Chip, ChipUsageGetter, }; use crate::system::memory::{offline_checker::MemoryBus, MemoryAddress}; -#[derive(Clone, Copy, Debug, AlignedBorrow, derive_new::new)] #[repr(C)] -pub struct DummyMemoryInteractionCols { - pub address: MemoryAddress, - pub data: [T; BLOCK_SIZE], - pub timestamp: T, +#[derive(Clone, Copy)] +pub struct DummyMemoryInteractionColsRef<'a, T> { + pub address: MemoryAddress<&'a T, &'a T>, + pub data: &'a [T], + pub timestamp: &'a T, /// The send frequency. Send corresponds to write. To read, set to negative. - pub count: T, + pub count: &'a T, +} + +#[repr(C)] +pub struct DummyMemoryInteractionColsMut<'a, T> { + pub address: MemoryAddress<&'a mut T, &'a mut T>, + pub data: &'a mut [T], + pub timestamp: &'a mut T, + /// The send frequency. Send corresponds to write. To read, set to negative. + pub count: &'a mut T, +} + +impl<'a, T> DummyMemoryInteractionColsRef<'a, T> { + pub fn from_slice(slice: &'a [T]) -> Self { + let (address, slice) = slice.split_at(size_of::>()); + let (count, slice) = slice.split_last().unwrap(); + let (timestamp, data) = slice.split_last().unwrap(); + Self { + address: MemoryAddress::new(&address[0], &address[1]), + data, + timestamp, + count, + } + } +} + +impl<'a, T> DummyMemoryInteractionColsMut<'a, T> { + pub fn from_mut_slice(slice: &'a mut [T]) -> Self { + let (addr_space, slice) = slice.split_first_mut().unwrap(); + let (ptr, slice) = slice.split_first_mut().unwrap(); + let (count, slice) = slice.split_last_mut().unwrap(); + let (timestamp, data) = slice.split_last_mut().unwrap(); + Self { + address: MemoryAddress::new(addr_space, ptr), + data, + timestamp, + count, + } + } } #[derive(Clone, Copy, Debug, derive_new::new)] -pub struct MemoryDummyAir { +pub struct MemoryDummyAir { pub bus: MemoryBus, + pub block_size: usize, } -impl BaseAirWithPublicValues for MemoryDummyAir {} -impl PartitionedBaseAir for MemoryDummyAir {} -impl BaseAir for MemoryDummyAir { +impl BaseAirWithPublicValues for MemoryDummyAir {} +impl PartitionedBaseAir for MemoryDummyAir {} +impl BaseAir for MemoryDummyAir { fn width(&self) -> usize { - size_of::>() + self.block_size + 4 } } -impl Air for MemoryDummyAir { +impl Air for MemoryDummyAir { fn eval(&self, builder: &mut AB) { let main = builder.main(); let local = main.row_slice(0); - let local: &DummyMemoryInteractionCols = (*local).borrow(); + let local = DummyMemoryInteractionColsRef::from_slice(&local); self.bus - .send(local.address, local.data.to_vec(), local.timestamp) - .eval(builder, local.count); + .send( + MemoryAddress::new(*local.address.address_space, *local.address.pointer), + local.data.to_vec(), + *local.timestamp, + ) + .eval(builder, *local.count); + } +} + +#[derive(Clone)] +pub struct MemoryDummyChip { + pub air: MemoryDummyAir, + pub trace: Vec, +} + +impl MemoryDummyChip { + pub fn new(air: MemoryDummyAir) -> Self { + Self { + air, + trace: Vec::new(), + } + } +} + +impl MemoryDummyChip { + pub fn send(&mut self, addr_space: u32, ptr: u32, data: &[F], timestamp: u32) { + self.push(addr_space, ptr, data, timestamp, F::ONE); + } + + pub fn receive(&mut self, addr_space: u32, ptr: u32, data: &[F], timestamp: u32) { + self.push(addr_space, ptr, data, timestamp, F::NEG_ONE); + } + + pub fn push(&mut self, addr_space: u32, ptr: u32, data: &[F], timestamp: u32, count: F) { + assert_eq!(data.len(), self.air.block_size); + self.trace.push(F::from_canonical_u32(addr_space)); + self.trace.push(F::from_canonical_u32(ptr)); + self.trace.extend_from_slice(data); + self.trace.push(F::from_canonical_u32(timestamp)); + self.trace.push(count); + } +} + +impl Chip> for MemoryDummyChip> +where + Val: PrimeField32, +{ + fn generate_proving_ctx(&self, _: RA) -> AirProvingContext> { + let height = self.current_trace_height().next_power_of_two(); + let width = self.trace_width(); + let mut trace = self.trace.clone(); + trace.resize(height * width, Val::::ZERO); + + let trace = Arc::new(RowMajorMatrix::new(trace, width)); + AirProvingContext::simple_no_pis(trace) + } +} + +impl ChipUsageGetter for MemoryDummyChip { + fn air_name(&self) -> String { + format!("MemoryDummyAir<{}>", self.air.block_size) + } + fn current_trace_height(&self) -> usize { + self.trace.len() / self.trace_width() + } + fn trace_width(&self) -> usize { + BaseAir::::width(&self.air) } } diff --git a/crates/vm/src/arch/testing/memory/mod.rs b/crates/vm/src/arch/testing/memory/mod.rs index ae1136bc7f..a16adc7e2d 100644 --- a/crates/vm/src/arch/testing/memory/mod.rs +++ b/crates/vm/src/arch/testing/memory/mod.rs @@ -1,138 +1,91 @@ -use std::{array::from_fn, borrow::BorrowMut as _, cell::RefCell, mem::size_of, rc::Rc, sync::Arc}; +use std::collections::HashMap; -use air::{DummyMemoryInteractionCols, MemoryDummyAir}; -use openvm_circuit::system::memory::MemoryController; -use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - p3_field::{FieldAlgebra, PrimeField32}, - p3_matrix::dense::RowMajorMatrix, - prover::types::AirProofInput, - AirRef, Chip, ChipUsageGetter, -}; -use rand::{seq::SliceRandom, Rng}; +use air::{MemoryDummyAir, MemoryDummyChip}; +use openvm_stark_backend::p3_field::{Field, PrimeField32}; +use rand::Rng; -use crate::system::memory::{offline_checker::MemoryBus, MemoryAddress, RecordId}; +use crate::system::memory::{online::TracingMemory, MemoryController}; pub mod air; -const WORD_SIZE: usize = 1; - /// A dummy testing chip that will add unconstrained messages into the [MemoryBus]. /// Stores a log of raw messages to send/receive to the [MemoryBus]. /// /// It will create a [air::MemoryDummyAir] to add messages to MemoryBus. -pub struct MemoryTester { - pub bus: MemoryBus, - pub controller: Rc>>, - /// Log of record ids - pub records: Vec, +pub struct MemoryTester { + /// Map from `block_size` to [MemoryDummyChip] of that block size + pub chip_for_block: HashMap>, + pub memory: TracingMemory, + pub(super) controller: MemoryController, } impl MemoryTester { - pub fn new(controller: Rc>>) -> Self { - let bus = controller.borrow().memory_bus; + pub fn new(controller: MemoryController, memory: TracingMemory) -> Self { + let bus = controller.memory_bus; + let mut chip_for_block = HashMap::new(); + for log_block_size in 0..6 { + let block_size = 1 << log_block_size; + let chip = MemoryDummyChip::new(MemoryDummyAir::new(bus, block_size)); + chip_for_block.insert(block_size, chip); + } Self { - bus, + chip_for_block, + memory, controller, - records: Vec::new(), } } - /// Returns the cell value at the current timestamp according to `MemoryController`. - pub fn read_cell(&mut self, address_space: usize, pointer: usize) -> F { - let [addr_space, pointer] = [address_space, pointer].map(F::from_canonical_usize); - // core::BorrowMut confuses compiler - let (record_id, value) = - RefCell::borrow_mut(&self.controller).read_cell(addr_space, pointer); - self.records.push(record_id); - value - } - - pub fn write_cell(&mut self, address_space: usize, pointer: usize, value: F) { - let [addr_space, pointer] = [address_space, pointer].map(F::from_canonical_usize); - let (record_id, _) = - RefCell::borrow_mut(&self.controller).write_cell(addr_space, pointer, value); - self.records.push(record_id); + pub fn read(&mut self, addr_space: usize, ptr: usize) -> [F; N] { + let memory = &mut self.memory; + let t = memory.timestamp(); + // TODO: this could be improved if we added a TracingMemory::get_f function + let (t_prev, data) = if addr_space <= 3 { + let (t_prev, data) = unsafe { memory.read::(addr_space as u32, ptr as u32) }; + (t_prev, data.map(F::from_canonical_u8)) + } else { + unsafe { memory.read::(addr_space as u32, ptr as u32) } + }; + self.chip_for_block.get_mut(&N).unwrap().receive( + addr_space as u32, + ptr as u32, + &data, + t_prev, + ); + self.chip_for_block + .get_mut(&N) + .unwrap() + .send(addr_space as u32, ptr as u32, &data, t); + + data } - pub fn read(&mut self, address_space: usize, pointer: usize) -> [F; N] { - from_fn(|i| self.read_cell(address_space, pointer + i)) - } - - pub fn write( - &mut self, - address_space: usize, - mut pointer: usize, - cells: [F; N], - ) { - for cell in cells { - self.write_cell(address_space, pointer, cell); - pointer += 1; - } - } -} - -impl Chip for MemoryTester> -where - Val: PrimeField32, -{ - fn air(&self) -> AirRef { - Arc::new(MemoryDummyAir::::new(self.bus)) - } - - fn generate_air_proof_input(self) -> AirProofInput { - let offline_memory = self.controller.borrow().offline_memory(); - let offline_memory = offline_memory.lock().unwrap(); - - let height = self.records.len().next_power_of_two(); - let width = self.trace_width(); - let mut values = Val::::zero_vec(2 * height * width); - // This zip only goes through records. The padding rows between records.len()..height - // are filled with zeros - in particular count = 0 so nothing is added to bus. - for (row, id) in values.chunks_mut(2 * width).zip(self.records) { - let (first, second) = row.split_at_mut(width); - let row: &mut DummyMemoryInteractionCols, WORD_SIZE> = first.borrow_mut(); - let record = offline_memory.record_by_id(id); - row.address = MemoryAddress { - address_space: record.address_space, - pointer: record.pointer, + pub fn write(&mut self, addr_space: usize, ptr: usize, data: [F; N]) { + let memory = &mut self.memory; + let t = memory.timestamp(); + // TODO: this could be improved if we added a TracingMemory::write_f function + let (t_prev, data_prev) = if addr_space <= 3 { + let (t_prev, data_prev) = unsafe { + memory.write::( + addr_space as u32, + ptr as u32, + data.map(|x| x.as_canonical_u32() as u8), + ) }; - row.data - .copy_from_slice(record.prev_data_slice().unwrap_or(record.data_slice())); - row.timestamp = Val::::from_canonical_u32(record.prev_timestamp); - row.count = -Val::::ONE; - - let row: &mut DummyMemoryInteractionCols, WORD_SIZE> = second.borrow_mut(); - row.address = MemoryAddress { - address_space: record.address_space, - pointer: record.pointer, - }; - row.data.copy_from_slice(record.data_slice()); - row.timestamp = Val::::from_canonical_u32(record.timestamp); - row.count = Val::::ONE; - } - AirProofInput::simple_no_pis(RowMajorMatrix::new(values, width)) - } -} - -impl ChipUsageGetter for MemoryTester { - fn air_name(&self) -> String { - "MemoryDummyAir".to_string() - } - fn current_trace_height(&self) -> usize { - self.records.len() + (t_prev, data_prev.map(F::from_canonical_u8)) + } else { + unsafe { memory.write::(addr_space as u32, ptr as u32, data) } + }; + self.chip_for_block.get_mut(&N).unwrap().receive( + addr_space as u32, + ptr as u32, + &data_prev, + t_prev, + ); + self.chip_for_block + .get_mut(&N) + .unwrap() + .send(addr_space as u32, ptr as u32, &data, t); } - - fn trace_width(&self) -> usize { - size_of::>() - } -} - -pub fn gen_address_space(rng: &mut R) -> usize -where - R: Rng + ?Sized, -{ - *[1, 2].choose(rng).unwrap() } pub fn gen_pointer(rng: &mut R, len: usize) -> usize diff --git a/crates/vm/src/arch/testing/mod.rs b/crates/vm/src/arch/testing/mod.rs index 44b19177be..588e3fa3b1 100644 --- a/crates/vm/src/arch/testing/mod.rs +++ b/crates/vm/src/arch/testing/mod.rs @@ -1,21 +1,26 @@ -use std::{ - cell::RefCell, - iter::zip, - rc::Rc, - sync::{Arc, Mutex}, -}; +use std::{marker::PhantomData, sync::Arc}; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +use itertools::zip_eq; +use openvm_circuit_primitives::{ + utils::next_power_of_two_or_zero, + var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, VariableRangeCheckerChip, + }, }; -use openvm_instructions::instruction::Instruction; +use openvm_instructions::{instruction::Instruction, riscv::RV32_REGISTER_AS, NATIVE_AS}; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, engine::VerificationData, - interaction::BusIndex, - p3_field::PrimeField32, - p3_matrix::dense::{DenseMatrix, RowMajorMatrix}, - prover::types::AirProofInput, + interaction::{BusIndex, PermutationCheckBus}, + p3_air::BaseAir, + p3_field::{Field, PrimeField32}, + p3_matrix::dense::RowMajorMatrix, + p3_util::log2_strict_usize, + prover::{ + cpu::{CpuBackend, CpuDevice}, + types::AirProvingContext, + }, + rap::AnyRap, verifier::VerificationError, AirRef, Chip, }; @@ -32,27 +37,32 @@ use program::ProgramTester; use rand::{rngs::StdRng, RngCore, SeedableRng}; use tracing::Level; -use super::{ExecutionBus, InstructionExecutor, SystemPort}; +use super::{ExecutionBridge, ExecutionBus, PreflightExecutor}; use crate::{ - arch::{ExecutionState, MemoryConfig}, + arch::{ + testing::{execution::air::ExecutionDummyAir, program::air::ProgramDummyAir}, + vm_poseidon2_config, Arena, ExecutionState, MatrixRecordArena, MemoryConfig, Streams, + VmStateMut, + }, system::{ memory::{ + adapter::records::arena_size_bound, offline_checker::{MemoryBridge, MemoryBus}, - MemoryController, OfflineMemory, + online::TracingMemory, + MemoryAirInventory, MemoryController, SharedMemoryHelper, CHUNK, }, poseidon2::Poseidon2PeripheryChip, program::ProgramBus, + SystemPort, }, }; pub mod execution; pub mod memory; pub mod program; -pub mod test_adapter; pub use execution::ExecutionTester; pub use memory::MemoryTester; -pub use test_adapter::TestAdapterChip; pub const EXECUTION_BUS: BusIndex = 0; pub const MEMORY_BUS: BusIndex = 1; @@ -63,79 +73,123 @@ pub const BYTE_XOR_BUS: BusIndex = 10; pub const RANGE_TUPLE_CHECKER_BUS: BusIndex = 11; pub const MEMORY_MERKLE_BUS: BusIndex = 12; -const RANGE_CHECKER_BUS: BusIndex = 4; +pub const RANGE_CHECKER_BUS: BusIndex = 4; + +pub type ArenaId = usize; + +pub struct TestChipHarness> { + pub executor: E, + pub air: A, + pub chip: C, + pub arena: RA, + phantom: PhantomData, +} -pub struct VmChipTestBuilder { +impl TestChipHarness +where + F: Field, + A: BaseAir, + RA: Arena, +{ + pub fn with_capacity(executor: E, air: A, chip: C, height: usize) -> Self { + let width = air.width(); + let height = next_power_of_two_or_zero(height); + let arena = RA::with_capacity(height, width); + Self { + executor, + air, + chip, + arena, + phantom: PhantomData, + } + } +} + +pub struct VmChipTestBuilder { pub memory: MemoryTester, + pub streams: Streams, + pub rng: StdRng, pub execution: ExecutionTester, pub program: ProgramTester, - rng: StdRng, + internal_rng: StdRng, default_register: usize, default_pointer: usize, } impl VmChipTestBuilder { pub fn new( - memory_controller: Rc>>, + controller: MemoryController, + memory: TracingMemory, + streams: Streams, + rng: StdRng, execution_bus: ExecutionBus, program_bus: ProgramBus, - rng: StdRng, + internal_rng: StdRng, ) -> Self { setup_tracing_with_log_level(Level::WARN); Self { - memory: MemoryTester::new(memory_controller), + memory: MemoryTester::new(controller, memory), + streams, + rng, execution: ExecutionTester::new(execution_bus), program: ProgramTester::new(program_bus), - rng, + internal_rng, default_register: 0, default_pointer: 0, } } // Passthrough functions from ExecutionTester and MemoryTester for better dev-ex - pub fn execute>( + pub fn execute( &mut self, - executor: &mut E, + harness: &mut TestChipHarness, instruction: &Instruction, - ) { + ) where + E: PreflightExecutor, + { let initial_pc = self.next_elem_size_u32(); - self.execute_with_pc(executor, instruction, initial_pc); + self.execute_with_pc(harness, instruction, initial_pc); } - pub fn execute_with_pc>( + pub fn execute_with_pc( &mut self, - executor: &mut E, + harness: &mut TestChipHarness, instruction: &Instruction, initial_pc: u32, - ) { + ) where + E: PreflightExecutor, + { let initial_state = ExecutionState { pc: initial_pc, - timestamp: self.memory.controller.borrow().timestamp(), + timestamp: self.memory.memory.timestamp(), + }; + tracing::debug!("initial_timestamp={}", self.memory.memory.timestamp()); + + let mut pc = initial_pc; + let state_mut = VmStateMut { + pc: &mut pc, + memory: &mut self.memory.memory, + streams: &mut self.streams, + rng: &mut self.rng, + ctx: &mut harness.arena, + #[cfg(feature = "metrics")] + metrics: &mut Default::default(), }; - tracing::debug!(?initial_state.timestamp); - - let final_state = executor - .execute( - &mut *self.memory.controller.borrow_mut(), - instruction, - initial_state, - ) + harness + .executor + .execute(state_mut, instruction) .expect("Expected the execution not to fail"); + let final_state = ExecutionState { + pc, + timestamp: self.memory.memory.timestamp(), + }; self.program.execute(instruction, &initial_state); self.execution.execute(initial_state, final_state); } fn next_elem_size_u32(&mut self) -> u32 { - self.rng.next_u32() % (1 << (F::bits() - 2)) - } - - pub fn read_cell(&mut self, address_space: usize, pointer: usize) -> F { - self.memory.read_cell(address_space, pointer) - } - - pub fn write_cell(&mut self, address_space: usize, pointer: usize, value: F) { - self.memory.write_cell(address_space, pointer, value); + self.internal_rng.next_u32() % (1 << (F::bits() - 2)) } pub fn read(&mut self, address_space: usize, pointer: usize) -> [F; N] { @@ -162,9 +216,22 @@ impl VmChipTestBuilder { pointer: usize, writes: Vec<[F; NUM_LIMBS]>, ) { - self.write(1usize, register, [F::from_canonical_usize(pointer)]); - for (i, &write) in writes.iter().enumerate() { - self.write(2usize, pointer + i * NUM_LIMBS, write); + self.write( + 1usize, + register, + pointer.to_le_bytes().map(F::from_canonical_u8), + ); + if NUM_LIMBS.is_power_of_two() { + for (i, &write) in writes.iter().enumerate() { + self.write(2usize, pointer + i * NUM_LIMBS, write); + } + } else { + for (i, &write) in writes.iter().enumerate() { + let ptr = pointer + i * NUM_LIMBS; + for j in (0..NUM_LIMBS).step_by(4) { + self.write::<4>(2usize, ptr + j, write[j..j + 4].try_into().unwrap()); + } + } } } @@ -176,6 +243,10 @@ impl VmChipTestBuilder { } } + pub fn execution_bridge(&self) -> ExecutionBridge { + ExecutionBridge::new(self.execution.bus, self.program.bus) + } + pub fn execution_bus(&self) -> ExecutionBus { self.execution.bus } @@ -185,27 +256,23 @@ impl VmChipTestBuilder { } pub fn memory_bus(&self) -> MemoryBus { - self.memory.bus - } - - pub fn memory_controller(&self) -> Rc>> { - self.memory.controller.clone() + self.memory.controller.memory_bus } pub fn range_checker(&self) -> SharedVariableRangeCheckerChip { - self.memory.controller.borrow().range_checker.clone() + self.memory.controller.range_checker.clone() } pub fn memory_bridge(&self) -> MemoryBridge { - self.memory.controller.borrow().memory_bridge() + self.memory.controller.memory_bridge() } - pub fn address_bits(&self) -> usize { - self.memory.controller.borrow().mem_config.pointer_max_bits + pub fn memory_helper(&self) -> SharedMemoryHelper { + self.memory.controller.helper() } - pub fn offline_memory_mutex_arc(&self) -> Arc>> { - self.memory_controller().borrow().offline_memory().clone() + pub fn address_bits(&self) -> usize { + self.memory.controller.memory_config().pointer_max_bits } pub fn get_default_register(&mut self, increment: usize) -> usize { @@ -243,68 +310,126 @@ impl VmChipTestBuilder { } // Use Blake3 as hash for faster tests. -type TestSC = BabyBearBlake3Config; +pub(crate) type TestSC = BabyBearBlake3Config; impl VmChipTestBuilder { pub fn build(self) -> VmChipTester { - self.memory - .controller - .borrow_mut() - .finalize(None::<&mut Poseidon2PeripheryChip>); let tester = VmChipTester { memory: Some(self.memory), ..Default::default() }; - let tester = tester.load(self.execution); - tester.load(self.program) + let tester = + tester.load_periphery((ExecutionDummyAir::new(self.execution.bus), self.execution)); + tester.load_periphery((ProgramDummyAir::new(self.program.bus), self.program)) } pub fn build_babybear_poseidon2(self) -> VmChipTester { - self.memory - .controller - .borrow_mut() - .finalize(None::<&mut Poseidon2PeripheryChip>); let tester = VmChipTester { memory: Some(self.memory), ..Default::default() }; - let tester = tester.load(self.execution); - tester.load(self.program) + let tester = + tester.load_periphery((ExecutionDummyAir::new(self.execution.bus), self.execution)); + tester.load_periphery((ProgramDummyAir::new(self.program.bus), self.program)) } } -impl Default for VmChipTestBuilder { - fn default() -> Self { - let mem_config = MemoryConfig::default(); - let range_checker = SharedVariableRangeCheckerChip::new(VariableRangeCheckerBus::new( +impl VmChipTestBuilder { + pub fn default_persistent() -> Self { + let mut mem_config = MemoryConfig::default(); + mem_config.addr_spaces[RV32_REGISTER_AS as usize].num_cells = 1 << 29; + mem_config.addr_spaces[NATIVE_AS as usize].num_cells = 0; + Self::persistent(mem_config) + } + + pub fn default_native() -> Self { + Self::volatile(MemoryConfig::aggregation()) + } + + fn range_checker_and_memory( + mem_config: &MemoryConfig, + init_block_size: usize, + ) -> (SharedVariableRangeCheckerChip, TracingMemory) { + let range_checker = Arc::new(VariableRangeCheckerChip::new(VariableRangeCheckerBus::new( RANGE_CHECKER_BUS, mem_config.decomp, + ))); + let max_access_adapter_n = log2_strict_usize(mem_config.max_access_adapter_n); + let arena_size_bound = arena_size_bound(&vec![1 << 16; max_access_adapter_n]); + let memory = TracingMemory::new(mem_config, init_block_size, arena_size_bound); + + (range_checker, memory) + } + + pub fn persistent(mem_config: MemoryConfig) -> Self { + setup_tracing_with_log_level(Level::INFO); + let (range_checker, memory) = Self::range_checker_and_memory(&mem_config, CHUNK); + let hasher_chip = Arc::new(Poseidon2PeripheryChip::new( + vm_poseidon2_config(), + POSEIDON2_DIRECT_BUS, + 3, )); - let memory_controller = MemoryController::with_volatile_memory( + let memory_controller = MemoryController::with_persistent_memory( MemoryBus::new(MEMORY_BUS), mem_config, range_checker, + PermutationCheckBus::new(MEMORY_MERKLE_BUS), + PermutationCheckBus::new(POSEIDON2_DIRECT_BUS), + hasher_chip, ); Self { - memory: MemoryTester::new(Rc::new(RefCell::new(memory_controller))), + memory: MemoryTester::new(memory_controller, memory), + streams: Default::default(), + rng: StdRng::seed_from_u64(0), execution: ExecutionTester::new(ExecutionBus::new(EXECUTION_BUS)), program: ProgramTester::new(ProgramBus::new(READ_INSTRUCTION_BUS)), + internal_rng: StdRng::seed_from_u64(0), + default_register: 0, + default_pointer: 0, + } + } + + pub fn volatile(mem_config: MemoryConfig) -> Self { + setup_tracing_with_log_level(Level::INFO); + let (range_checker, memory) = Self::range_checker_and_memory(&mem_config, 1); + let memory_controller = MemoryController::with_volatile_memory( + MemoryBus::new(MEMORY_BUS), + mem_config, + range_checker, + ); + Self { + memory: MemoryTester::new(memory_controller, memory), + streams: Default::default(), rng: StdRng::seed_from_u64(0), + execution: ExecutionTester::new(ExecutionBus::new(EXECUTION_BUS)), + program: ProgramTester::new(ProgramBus::new(READ_INSTRUCTION_BUS)), + internal_rng: StdRng::seed_from_u64(0), default_register: 0, default_pointer: 0, } } } +impl Default for VmChipTestBuilder { + fn default() -> Self { + let mut mem_config = MemoryConfig::default(); + // TODO[jpw]: this is because old tests use `gen_pointer` on address space 1; this can be + // removed when tests are updated. + mem_config.addr_spaces[RV32_REGISTER_AS as usize].num_cells = 1 << 29; + mem_config.addr_spaces[NATIVE_AS as usize].num_cells = 0; + Self::volatile(mem_config) + } +} + pub struct VmChipTester { pub memory: Option>>, - pub air_proof_inputs: Vec<(AirRef, AirProofInput)>, + pub air_ctxs: Vec<(AirRef, AirProvingContext>)>, } impl Default for VmChipTester { fn default() -> Self { Self { memory: None, - air_proof_inputs: vec![], + air_ctxs: vec![], } } } @@ -313,90 +438,149 @@ impl VmChipTester where Val: PrimeField32, { - pub fn load>(mut self, chip: C) -> Self { - if chip.current_trace_height() > 0 { - let air = chip.air(); - let air_proof_input = chip.generate_air_proof_input(); - tracing::debug!("Generated air proof input for {}", air.name()); - self.air_proof_inputs.push((air, air_proof_input)); + pub fn load( + mut self, + harness: TestChipHarness, E, A, C, MatrixRecordArena>>, + ) -> Self + where + A: AnyRap + 'static, + C: Chip>, CpuBackend>, + { + let arena = harness.arena; + let rows_used = arena.trace_offset.div_ceil(arena.width); + if rows_used > 0 { + let air = Arc::new(harness.air) as AirRef; + let ctx = harness.chip.generate_proving_ctx(arena); + tracing::debug!("Generated air proving context for {}", air.name()); + self.air_ctxs.push((air, ctx)); } self } + pub fn load_periphery(self, (air, chip): (A, C)) -> Self + where + A: AnyRap + 'static, + C: Chip<(), CpuBackend>, + { + let air = Arc::new(air) as AirRef; + self.load_periphery_ref((air, chip)) + } + + pub fn load_periphery_ref(mut self, (air, chip): (AirRef, C)) -> Self + where + C: Chip<(), CpuBackend>, + { + let ctx = chip.generate_proving_ctx(()); + tracing::debug!("Generated air proving context for {}", air.name()); + self.air_ctxs.push((air, ctx)); + + self + } + pub fn finalize(mut self) -> Self { if let Some(memory_tester) = self.memory.take() { - let memory_controller = memory_tester.controller.clone(); - let range_checker = memory_controller.borrow().range_checker.clone(); - self = self.load(memory_tester); // dummy memory interactions + let mut memory_controller = memory_tester.controller; + let is_persistent = memory_controller.continuation_enabled(); + let mut memory = memory_tester.memory; + let touched_memory = memory.finalize::>(is_persistent); + // Balance memory boundaries + let range_checker = memory_controller.range_checker.clone(); + for mem_chip in memory_tester.chip_for_block.into_values() { + self = self.load_periphery((mem_chip.air, mem_chip)); + } + let mem_inventory = MemoryAirInventory::new( + memory_controller.memory_bridge(), + memory_controller.memory_config(), + range_checker.bus(), + is_persistent.then_some(( + PermutationCheckBus::new(MEMORY_MERKLE_BUS), + PermutationCheckBus::new(POSEIDON2_DIRECT_BUS), + )), + ); + let ctxs = memory_controller + .generate_proving_ctx(memory.access_adapter_records, touched_memory); + for (air, ctx) in zip_eq(mem_inventory.into_airs(), ctxs) + .filter(|(_, ctx)| ctx.main_trace_height() > 0) { - let airs = memory_controller.borrow().airs(); - let air_proof_inputs = Rc::try_unwrap(memory_controller) - .unwrap_or_else(|_| panic!("Memory controller was not dropped")) - .into_inner() - .generate_air_proof_inputs(); - self.air_proof_inputs.extend( - zip(airs, air_proof_inputs).filter(|(_, input)| input.main_trace_height() > 0), - ); + self.air_ctxs.push((air, ctx)); } - self = self.load(range_checker); // this must be last because other trace generation - // mutates its state + if let Some(hasher_chip) = memory_controller.hasher_chip { + let air: AirRef = match hasher_chip.as_ref() { + Poseidon2PeripheryChip::Register0(chip) => chip.air.clone(), + Poseidon2PeripheryChip::Register1(chip) => chip.air.clone(), + }; + self = self.load_periphery_ref((air, hasher_chip)); + } + // this must be last because other trace generation mutates its state + self = self.load_periphery((range_checker.air, range_checker)); } self } - pub fn load_air_proof_input( + pub fn load_air_proving_ctx( mut self, - air_proof_input: (AirRef, AirProofInput), + air_proving_ctx: (AirRef, AirProvingContext>), ) -> Self { - self.air_proof_inputs.push(air_proof_input); + self.air_ctxs.push(air_proving_ctx); self } - pub fn load_with_custom_trace>( - mut self, - chip: C, - trace: RowMajorMatrix>, - ) -> Self { - let air = chip.air(); - let mut air_proof_input = chip.generate_air_proof_input(); - air_proof_input.raw.common_main = Some(trace); - self.air_proof_inputs.push((air, air_proof_input)); - self - } + // pub fn load_with_custom_trace>( + // mut self, + // chip: C, + // trace: RowMajorMatrix>, + // ) -> Self { + // let air = chip.air(); + // let mut air_proof_input = chip.generate_air_proof_input(); + // air_proof_input.raw.common_main = Some(trace); + // self.air_proof_inputs.push((air, air_proof_input)); + // self + // } - pub fn load_and_prank_trace, P>(mut self, chip: C, modify_trace: P) -> Self + pub fn load_and_prank_trace( + mut self, + harness: TestChipHarness, E, A, C, MatrixRecordArena>>, + modify_trace: P, + ) -> Self where - P: Fn(&mut DenseMatrix>), + A: AnyRap + 'static, + C: Chip>, CpuBackend>, + P: Fn(&mut RowMajorMatrix>), { - let air = chip.air(); - let mut air_proof_input = chip.generate_air_proof_input(); - let trace = air_proof_input.raw.common_main.as_mut().unwrap(); - modify_trace(trace); - self.air_proof_inputs.push((air, air_proof_input)); + let arena = harness.arena; + let mut ctx = harness.chip.generate_proving_ctx(arena); + let trace: Arc>> = Option::take(&mut ctx.common_main).unwrap(); + let mut trace = Arc::into_inner(trace).unwrap(); + modify_trace(&mut trace); + ctx.common_main = Some(Arc::new(trace)); + self.air_ctxs.push((Arc::new(harness.air), ctx)); self } /// Given a function to produce an engine from the max trace height, /// runs a simple test on that engine - pub fn test, P: Fn() -> E>( - &self, // do no take ownership so it's easier to prank + pub fn test E>( + self, // do no take ownership so it's easier to prank engine_provider: P, - ) -> Result, VerificationError> { + ) -> Result, VerificationError> + where + E: StarkEngine, PD = CpuDevice>, + { assert!(self.memory.is_none(), "Memory must be finalized"); - let (airs, air_proof_inputs) = self.air_proof_inputs.iter().cloned().unzip(); - engine_provider().run_test_impl(airs, air_proof_inputs) + let (airs, ctxs): (Vec<_>, Vec<_>) = self.air_ctxs.into_iter().unzip(); + engine_provider().run_test_impl(airs, ctxs) } } impl VmChipTester { pub fn simple_test( - &self, + self, ) -> Result, VerificationError> { self.test(|| BabyBearPoseidon2Engine::new(FriParameters::new_for_testing(1))) } - pub fn simple_test_with_expected_error(&self, expected_error: VerificationError) { + pub fn simple_test_with_expected_error(self, expected_error: VerificationError) { let msg = format!( "Expected verification to fail with {:?}, but it didn't", &expected_error @@ -407,11 +591,11 @@ impl VmChipTester { } impl VmChipTester { - pub fn simple_test(&self) -> Result, VerificationError> { + pub fn simple_test(self) -> Result, VerificationError> { self.test(|| BabyBearBlake3Engine::new(FriParameters::new_for_testing(1))) } - pub fn simple_test_with_expected_error(&self, expected_error: VerificationError) { + pub fn simple_test_with_expected_error(self, expected_error: VerificationError) { let msg = format!( "Expected verification to fail with {:?}, but it didn't", &expected_error diff --git a/crates/vm/src/arch/testing/program/mod.rs b/crates/vm/src/arch/testing/program/mod.rs index 04c4feee60..224743cab5 100644 --- a/crates/vm/src/arch/testing/program/mod.rs +++ b/crates/vm/src/arch/testing/program/mod.rs @@ -1,13 +1,12 @@ use std::{borrow::BorrowMut, mem::size_of, sync::Arc}; -use air::ProgramDummyAir; use openvm_instructions::instruction::Instruction; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, p3_field::{Field, FieldAlgebra, PrimeField32}, p3_matrix::dense::RowMajorMatrix, - prover::types::AirProofInput, - AirRef, Chip, ChipUsageGetter, + prover::{cpu::CpuBackend, types::AirProvingContext}, + Chip, ChipUsageGetter, }; use crate::{ @@ -15,7 +14,7 @@ use crate::{ system::program::{ProgramBus, ProgramExecutionCols}, }; -mod air; +pub mod air; #[derive(Debug)] pub struct ProgramTester { @@ -52,22 +51,18 @@ impl ProgramTester { } } -impl Chip for ProgramTester> { - fn air(&self) -> AirRef { - Arc::new(ProgramDummyAir::new(self.bus)) - } - - fn generate_air_proof_input(self) -> AirProofInput { +impl Chip> for ProgramTester> { + fn generate_proving_ctx(&self, _: RA) -> AirProvingContext> { let height = self.records.len().next_power_of_two(); let width = self.trace_width(); let mut values = Val::::zero_vec(height * width); // This zip only goes through records. The padding rows between records.len()..height // are filled with zeros - in particular count = 0 so nothing is added to bus. - for (row, record) in values.chunks_mut(width).zip(self.records) { - *(row[..width - 1]).borrow_mut() = record; + for (row, record) in values.chunks_mut(width).zip(&self.records) { + *(row[..width - 1]).borrow_mut() = *record; row[width - 1] = Val::::ONE; } - AirProofInput::simple_no_pis(RowMajorMatrix::new(values, width)) + AirProvingContext::simple_no_pis(Arc::new(RowMajorMatrix::new(values, width))) } } diff --git a/crates/vm/src/arch/testing/test_adapter.rs b/crates/vm/src/arch/testing/test_adapter.rs deleted file mode 100644 index bca9eed724..0000000000 --- a/crates/vm/src/arch/testing/test_adapter.rs +++ /dev/null @@ -1,175 +0,0 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - collections::VecDeque, - fmt::Debug, -}; - -use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::instruction::Instruction; -use openvm_stark_backend::{ - interaction::InteractionBuilder, - p3_air::BaseAir, - p3_field::{Field, FieldAlgebra, PrimeField32}, -}; -use serde::{Deserialize, Serialize}; - -use crate::{ - arch::{ - AdapterAirContext, AdapterRuntimeContext, DynAdapterInterface, DynArray, ExecutionBridge, - ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - }, - system::memory::{MemoryController, OfflineMemory}, -}; - -// Replaces A: VmAdapterChip while testing VmCoreChip functionality, as it has no -// constraints and thus cannot cause a failure. -pub struct TestAdapterChip { - /// List of the return values of `preprocess` this chip should provide on each sequential call. - pub prank_reads: VecDeque>, - /// List of `pc_inc` to use in `postprocess` on each sequential call. - /// Defaults to `4` if not provided. - pub prank_pc_inc: VecDeque>, - - pub air: TestAdapterAir, -} - -impl TestAdapterChip { - pub fn new( - prank_reads: Vec>, - prank_pc_inc: Vec>, - execution_bridge: ExecutionBridge, - ) -> Self { - Self { - prank_reads: prank_reads.into(), - prank_pc_inc: prank_pc_inc.into(), - air: TestAdapterAir { execution_bridge }, - } - } -} - -#[derive(Clone, Serialize, Deserialize)] -pub struct TestAdapterRecord { - pub from_pc: u32, - pub operands: [T; 7], -} - -impl VmAdapterChip for TestAdapterChip { - type ReadRecord = (); - type WriteRecord = TestAdapterRecord; - type Air = TestAdapterAir; - type Interface = DynAdapterInterface; - - fn preprocess( - &mut self, - _memory: &mut MemoryController, - _instruction: &Instruction, - ) -> Result<(DynArray, Self::ReadRecord)> { - Ok(( - self.prank_reads - .pop_front() - .expect("Not enough prank reads provided") - .into(), - (), - )) - } - - fn postprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - _output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let pc_inc = self - .prank_pc_inc - .pop_front() - .map(|x| x.unwrap_or(4)) - .unwrap_or(4); - Ok(( - ExecutionState { - pc: from_state.pc + pc_inc, - timestamp: memory.timestamp(), - }, - TestAdapterRecord { - operands: [ - instruction.a, - instruction.b, - instruction.c, - instruction.d, - instruction.e, - instruction.f, - instruction.g, - ], - from_pc: from_state.pc, - }, - )) - } - - fn generate_trace_row( - &self, - row_slice: &mut [F], - _read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - _memory: &OfflineMemory, - ) { - let cols: &mut TestAdapterCols = row_slice.borrow_mut(); - cols.from_pc = F::from_canonical_u32(write_record.from_pc); - cols.operands = write_record.operands; - // row_slice[0] = F::from_canonical_u32(write_record.from_pc); - // row_slice[1..].copy_from_slice(&write_record.operands); - } - - fn air(&self) -> &Self::Air { - &self.air - } -} - -#[derive(Clone, Copy, Debug)] -pub struct TestAdapterAir { - pub execution_bridge: ExecutionBridge, -} - -#[repr(C)] -#[derive(AlignedBorrow)] -pub struct TestAdapterCols { - pub from_pc: T, - pub operands: [T; 7], -} - -impl BaseAir for TestAdapterAir { - fn width(&self) -> usize { - TestAdapterCols::::width() - } -} - -impl VmAdapterAir for TestAdapterAir { - type Interface = DynAdapterInterface; - - fn eval( - &self, - builder: &mut AB, - local: &[AB::Var], - ctx: AdapterAirContext, - ) { - let processed_instruction: MinimalInstruction = ctx.instruction.into(); - let cols: &TestAdapterCols = local.borrow(); - self.execution_bridge - .execute_and_increment_or_set_pc( - processed_instruction.opcode, - cols.operands.to_vec(), - ExecutionState { - pc: cols.from_pc.into(), - timestamp: AB::Expr::ONE, - }, - AB::Expr::ZERO, - (4, ctx.to_pc), - ) - .eval(builder, processed_instruction.is_valid); - } - - fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var { - let cols: &TestAdapterCols = local.borrow(); - cols.from_pc - } -} diff --git a/crates/vm/src/arch/vm.rs b/crates/vm/src/arch/vm.rs index c9d5cb2ffc..498bc47a52 100644 --- a/crates/vm/src/arch/vm.rs +++ b/crates/vm/src/arch/vm.rs @@ -1,60 +1,101 @@ +//! [VmExecutor] is the struct that can execute an _arbitrary_ program, provided in the form of a +//! [VmExe], for a fixed set of OpenVM instructions corresponding to a [VmExecutionConfig]. +//! Internally once it is given a program, it will preprocess the program to rewrite it into a more +//! optimized format for runtime execution. This **instance** of the executor will be a separate +//! struct specialized to running a _fixed_ program on different program inputs. +//! +//! [VirtualMachine] will similarly be the struct that has done all the setup so it can +//! execute+prove an arbitrary program for a fixed config - it will internally still hold VmExecutor use std::{ + any::TypeId, borrow::Borrow, collections::{HashMap, VecDeque}, marker::PhantomData, - mem, sync::Arc, }; +use getset::{Getters, MutGetters, Setters, WithSetters}; +use itertools::{zip_eq, Itertools}; use openvm_circuit::system::program::trace::compute_exe_commit; -use openvm_instructions::exe::VmExe; +use openvm_instructions::exe::{SparseMemoryImage, VmExe}; use openvm_stark_backend::{ - config::{Com, Domain, StarkGenericConfig, Val}, + config::{Com, StarkGenericConfig, Val}, engine::StarkEngine, - keygen::types::{LinearConstraint, MultiStarkProvingKey, MultiStarkVerifyingKey}, - p3_commit::PolynomialSpace, - p3_field::{FieldAlgebra, PrimeField32}, + keygen::types::{MultiStarkProvingKey, MultiStarkVerifyingKey}, + p3_field::{FieldAlgebra, FieldExtensionAlgebra, PrimeField32, TwoAdicField}, + p3_util::{log2_ceil_usize, log2_strict_usize}, proof::Proof, - prover::types::{CommittedTraceData, ProofInput}, - utils::metrics_span, + prover::{ + hal::{DeviceDataTransporter, MatrixDimensions}, + types::{CommittedTraceData, DeviceMultiStarkProvingKey, ProvingContext}, + }, verifier::VerificationError, - Chip, }; +use p3_baby_bear::BabyBear; use serde::{Deserialize, Serialize}; use thiserror::Error; -use tracing::info_span; +use tracing::{info_span, instrument}; use super::{ - ExecutionError, VmComplexTraceHeights, VmConfig, CONNECTOR_AIR_ID, MERKLE_AIR_ID, - PROGRAM_AIR_ID, PROGRAM_CACHED_TRACE_INDEX, + execution_mode::e1::E1Ctx, ExecutionError, Executor, MemoryConfig, VmChipComplex, + CONNECTOR_AIR_ID, MERKLE_AIR_ID, PROGRAM_AIR_ID, PROGRAM_CACHED_TRACE_INDEX, }; -#[cfg(feature = "bench-metrics")] -use crate::metrics::VmMetrics; use crate::{ - arch::{hasher::poseidon2::vm_poseidon2_hasher, segment::ExecutionSegment}, + arch::{ + execution_mode::{ + metered::{MeteredCtx, Segment}, + tracegen::TracegenCtx, + }, + hasher::poseidon2::vm_poseidon2_hasher, + interpreter::InterpretedInstance, + interpreter_preflight::PreflightInterpretedInstance, + AirInventoryError, AnyEnum, ChipInventoryError, ExecutionState, ExecutorInventory, + ExecutorInventoryError, MeteredExecutor, PreflightExecutor, StaticProgramError, + SystemConfig, TraceFiller, VmBuilder, VmCircuitConfig, VmExecState, VmExecutionConfig, + VmState, PUBLIC_VALUES_AIR_ID, + }, + execute_spanned, system::{ connector::{VmConnectorPvs, DEFAULT_SUSPEND_EXIT_CODE}, memory::{ - merkle::MemoryMerklePvs, - paged_vec::AddressMap, - tree::public_values::{UserPublicValuesProof, UserPublicValuesProofError}, - MemoryImage, CHUNK, + adapter::records, + merkle::{ + public_values::{UserPublicValuesProof, UserPublicValuesProofError}, + MemoryMerklePvs, + }, + online::{GuestMemory, TracingMemory}, + AddressMap, CHUNK, }, - program::trace::VmCommittedExe, + program::{trace::VmCommittedExe, ProgramHandler}, + public_values::PublicValuesExecutor, + SystemChipComplex, SystemRecords, SystemWithFixedTraceHeights, PV_EXECUTOR_IDX, }, }; #[derive(Error, Debug)] pub enum GenerationError { - #[error("generated trace heights violate constraints")] - TraceHeightsLimitExceeded, - #[error(transparent)] - Execution(#[from] ExecutionError), + #[error("unexpected number of arenas: {actual} (expected num_airs={expected})")] + UnexpectedNumArenas { actual: usize, expected: usize }, + #[error("trace height for air_idx={air_idx} must be fixed to {expected}, actual={actual}")] + ForceTraceHeightIncorrect { + air_idx: usize, + actual: usize, + expected: usize, + }, + #[error("trace height of air {air_idx} has height {height} greater than maximum {max_height}")] + TraceHeightsLimitExceeded { + air_idx: usize, + height: usize, + max_height: usize, + }, + #[error("trace heights violate linear constraint {constraint_idx} ({value} >= {threshold})")] + LinearTraceHeightConstraintExceeded { + constraint_idx: usize, + value: u64, + threshold: u32, + }, } -/// VM memory state for continuations. -pub type VmMemoryState = MemoryImage; - /// A trait for key-value store for `Streams`. pub trait KvStore: Send + Sync { fn get(&self, key: &[u8]) -> Option<&[u8]>; @@ -105,11 +146,21 @@ impl From>> for Streams { } } -pub struct VmExecutor { +/// [VmExecutor] is the struct that can execute an _arbitrary_ program, provided in the form of a +/// [VmExe], for a fixed set of OpenVM instructions corresponding to a [VmExecutionConfig]. +/// Internally once it is given a program, it will preprocess the program to rewrite it into a more +/// optimized format for runtime execution. This **instance** of the executor will be a separate +/// struct specialized to running a _fixed_ program on different program inputs. +pub struct VmExecutor +where + VC: VmExecutionConfig, +{ pub config: VC, - pub overridden_heights: Option, - pub trace_height_constraints: Vec, - _marker: PhantomData, + /// If any executors are stateful (i.e., they mutate during execution), then the `inventory` + /// must store the executors in their initialized state. Internally, the executors are cloned + /// into a separate instance before running a program. + inventory: ExecutorInventory, + phantom: PhantomData, } #[repr(i32)] @@ -119,395 +170,94 @@ pub enum ExitCode { Suspended = -1, // Continuations } -pub struct VmExecutorResult { - pub per_segment: Vec>, - /// When VM is running on persistent mode, public values are stored in a special memory space. - pub final_memory: Option>>, -} - -pub struct VmExecutorNextSegmentState { - pub memory: MemoryImage, - pub input: Streams, - pub pc: u32, - #[cfg(feature = "bench-metrics")] - pub metrics: VmMetrics, -} - -impl VmExecutorNextSegmentState { - pub fn new(memory: MemoryImage, input: impl Into>, pc: u32) -> Self { - Self { - memory, - input: input.into(), - pc, - #[cfg(feature = "bench-metrics")] - metrics: VmMetrics::default(), - } - } -} - -pub struct VmExecutorOneSegmentResult> { - pub segment: ExecutionSegment, - pub next_state: Option>, +pub struct PreflightExecutionOutput { + pub system_records: SystemRecords, + pub record_arenas: Vec, + pub to_state: VmState, } impl VmExecutor where - F: PrimeField32, - VC: VmConfig, + VC: VmExecutionConfig, { /// Create a new VM executor with a given config. /// /// The VM will start with a single segment, which is created from the initial state. - pub fn new(config: VC) -> Self { - Self::new_with_overridden_trace_heights(config, None) - } - - pub fn set_override_trace_heights(&mut self, overridden_heights: VmComplexTraceHeights) { - self.overridden_heights = Some(overridden_heights); - } - - pub fn new_with_overridden_trace_heights( - config: VC, - overridden_heights: Option, - ) -> Self { - Self { + pub fn new(config: VC) -> Result { + let inventory = config.create_executors()?; + Ok(Self { config, - overridden_heights, - trace_height_constraints: vec![], - _marker: Default::default(), - } - } - - pub fn continuation_enabled(&self) -> bool { - self.config.system().continuation_enabled - } - - /// Executes the program in segments. - /// After each segment is executed, call the provided closure on the execution result. - /// Returns the results from each closure, one per segment. - /// - /// The closure takes `f(segment_idx, segment) -> R`. - pub fn execute_and_then( - &self, - exe: impl Into>, - input: impl Into>, - mut f: impl FnMut(usize, ExecutionSegment) -> Result, - map_err: impl Fn(ExecutionError) -> E, - ) -> Result, E> { - let mem_config = self.config.system().memory_config; - let exe = exe.into(); - let mut segment_results = vec![]; - let memory = AddressMap::from_iter( - mem_config.as_offset, - 1 << mem_config.as_height, - 1 << mem_config.pointer_max_bits, - exe.init_memory.clone(), - ); - let pc = exe.pc_start; - let mut state = VmExecutorNextSegmentState::new(memory, input, pc); - - #[cfg(feature = "bench-metrics")] - { - state.metrics.fn_bounds = exe.fn_bounds.clone(); - } - - let mut segment_idx = 0; - - loop { - let _span = info_span!("execute_segment", segment = segment_idx).entered(); - let one_segment_result = self - .execute_until_segment(exe.clone(), state) - .map_err(&map_err)?; - segment_results.push(f(segment_idx, one_segment_result.segment)?); - if one_segment_result.next_state.is_none() { - break; - } - state = one_segment_result.next_state.unwrap(); - segment_idx += 1; - } - tracing::debug!("Number of continuation segments: {}", segment_results.len()); - #[cfg(feature = "bench-metrics")] - metrics::counter!("num_segments").absolute(segment_results.len() as u64); - - Ok(segment_results) - } - - pub fn execute_segments( - &self, - exe: impl Into>, - input: impl Into>, - ) -> Result>, ExecutionError> { - self.execute_and_then(exe, input, |_, seg| Ok(seg), |err| err) - } - - /// Executes a program until a segmentation happens. - /// Returns the last segment and the vm state for next segment. - /// This is so that the tracegen and proving of this segment can be immediately started (on a - /// separate machine). - pub fn execute_until_segment( - &self, - exe: impl Into>, - from_state: VmExecutorNextSegmentState, - ) -> Result, ExecutionError> { - let exe = exe.into(); - let mut segment = ExecutionSegment::new( - &self.config, - exe.program.clone(), - from_state.input, - Some(from_state.memory), - self.trace_height_constraints.clone(), - exe.fn_bounds.clone(), - ); - #[cfg(feature = "bench-metrics")] - { - segment.metrics = from_state.metrics; - } - if let Some(overridden_heights) = self.overridden_heights.as_ref() { - segment.set_override_trace_heights(overridden_heights.clone()); - } - let state = metrics_span("execute_time_ms", || segment.execute_from_pc(from_state.pc))?; - - if state.is_terminated { - return Ok(VmExecutorOneSegmentResult { - segment, - next_state: None, - }); - } - - assert!( - self.continuation_enabled(), - "multiple segments require to enable continuations" - ); - assert_eq!( - state.pc, - segment.chip_complex.connector_chip().boundary_states[1] - .unwrap() - .pc - ); - let final_memory = mem::take(&mut segment.final_memory) - .expect("final memory should be set in continuations segment"); - let streams = segment.chip_complex.take_streams(); - #[cfg(feature = "bench-metrics")] - let metrics = segment.metrics.partial_take(); - Ok(VmExecutorOneSegmentResult { - segment, - next_state: Some(VmExecutorNextSegmentState { - memory: final_memory, - input: streams, - pc: state.pc, - #[cfg(feature = "bench-metrics")] - metrics, - }), + inventory, + phantom: PhantomData, }) } +} - pub fn execute( - &self, - exe: impl Into>, - input: impl Into>, - ) -> Result>, ExecutionError> { - let mut last = None; - self.execute_and_then( - exe, - input, - |_, seg| { - last = Some(seg); - Ok(()) - }, - |err| err, - )?; - let last = last.expect("at least one segment must be executed"); - let final_memory = last.final_memory; - let end_state = - last.chip_complex.connector_chip().boundary_states[1].expect("end state must be set"); - if end_state.is_terminate != 1 { - return Err(ExecutionError::DidNotTerminate); - } - if end_state.exit_code != ExitCode::Success as u32 { - return Err(ExecutionError::FailedWithExitCode(end_state.exit_code)); - } - Ok(final_memory) - } - - pub fn execute_and_generate( - &self, - exe: impl Into>, - input: impl Into>, - ) -> Result, GenerationError> - where - Domain: PolynomialSpace, - VC::Executor: Chip, - VC::Periphery: Chip, - { - self.execute_and_generate_impl(exe.into(), None, input) - } - - pub fn execute_and_generate_with_cached_program( +impl VmExecutor +where + VC: VmExecutionConfig + AsRef, +{ + pub fn build_metered_ctx( &self, - committed_exe: Arc>, - input: impl Into>, - ) -> Result, GenerationError> - where - Domain: PolynomialSpace, - VC::Executor: Chip, - VC::Periphery: Chip, - { - self.execute_and_generate_impl( - committed_exe.exe.clone(), - Some(committed_exe.committed_program.clone()), - input, + constant_trace_heights: &[Option], + air_names: &[String], + widths: &[usize], + interactions: &[usize], + ) -> MeteredCtx { + let system_config = self.config.as_ref(); + let as_byte_alignment_bits = system_config + .memory_config + .addr_spaces + .iter() + .map(|addr_sp| log2_strict_usize(addr_sp.min_block_size) as u8) + .collect(); + + MeteredCtx::new( + constant_trace_heights.to_vec(), + system_config.has_public_values_chip(), + system_config.continuation_enabled, + as_byte_alignment_bits, + system_config.memory_config.memory_dimensions(), + air_names.to_vec(), + widths.to_vec(), + interactions.to_vec(), + system_config.segmentation_limits, ) } - - fn execute_and_generate_impl( - &self, - exe: VmExe, - committed_program: Option>, - input: impl Into>, - ) -> Result, GenerationError> - where - Domain: PolynomialSpace, - VC::Executor: Chip, - VC::Periphery: Chip, - { - let mut final_memory = None; - let per_segment = self.execute_and_then( - exe, - input, - |seg_idx, mut seg| { - // Note: this will only be Some on the last segment; otherwise it is - // already moved into next segment state - final_memory = mem::take(&mut seg.final_memory); - tracing::info_span!("trace_gen", segment = seg_idx) - .in_scope(|| seg.generate_proof_input(committed_program.clone())) - }, - GenerationError::Execution, - )?; - - Ok(VmExecutorResult { - per_segment, - final_memory, - }) - } - - pub fn set_trace_height_constraints(&mut self, constraints: Vec) { - self.trace_height_constraints = constraints; - } -} - -/// A single segment VM. -pub struct SingleSegmentVmExecutor { - pub config: VC, - pub overridden_heights: Option, - pub trace_height_constraints: Vec, - _marker: PhantomData, -} - -/// Execution result of a single segment VM execution. -pub struct SingleSegmentVmExecutionResult { - /// All user public values - pub public_values: Vec>, - /// Heights of each AIR, ordered by AIR ID. - pub air_heights: Vec, - /// Heights of (SystemBase, Inventory), in an internal ordering. - pub vm_heights: VmComplexTraceHeights, } -impl SingleSegmentVmExecutor +impl VmExecutor where F: PrimeField32, - VC: VmConfig, + VC: VmExecutionConfig, + VC::Executor: Executor, { - pub fn new(config: VC) -> Self { - Self::new_with_overridden_trace_heights(config, None) - } - - pub fn new_with_overridden_trace_heights( - config: VC, - overridden_heights: Option, - ) -> Self { - assert!( - !config.system().continuation_enabled, - "Single segment VM doesn't support continuation mode" - ); - Self { - config, - overridden_heights, - trace_height_constraints: vec![], - _marker: Default::default(), - } - } - - pub fn set_override_trace_heights(&mut self, overridden_heights: VmComplexTraceHeights) { - self.overridden_heights = Some(overridden_heights); - } - - pub fn set_trace_height_constraints(&mut self, constraints: Vec) { - self.trace_height_constraints = constraints; - } - - /// Executes a program, compute the trace heights, and returns the public values. - pub fn execute_and_compute_heights( - &self, - exe: impl Into>, - input: impl Into>, - ) -> Result, ExecutionError> { - let segment = { - let mut segment = self.execute_impl(exe.into(), input.into())?; - segment.chip_complex.finalize_memory(); - segment - }; - let air_heights = segment.chip_complex.current_trace_heights(); - let vm_heights = segment.chip_complex.get_internal_trace_heights(); - let public_values = if let Some(pv_chip) = segment.chip_complex.public_values_chip() { - pv_chip.core.get_custom_public_values() - } else { - vec![] - }; - Ok(SingleSegmentVmExecutionResult { - public_values, - air_heights, - vm_heights, - }) - } - - /// Executes a program and returns its proof input. - pub fn execute_and_generate( + /// Creates an instance of the interpreter specialized for pure execution, without metering, of + /// the given `exe`. + /// + /// For metered execution, use the [`metered_instance`](Self::metered_instance) constructor. + pub fn instance( &self, - committed_exe: Arc>, - input: impl Into>, - ) -> Result, GenerationError> - where - Domain: PolynomialSpace, - VC::Executor: Chip, - VC::Periphery: Chip, - { - let segment = self.execute_impl(committed_exe.exe.clone(), input)?; - let proof_input = tracing::info_span!("trace_gen").in_scope(|| { - segment.generate_proof_input(Some(committed_exe.committed_program.clone())) - })?; - Ok(proof_input) + exe: &VmExe, + ) -> Result, StaticProgramError> { + InterpretedInstance::new(&self.inventory, exe) } +} - fn execute_impl( +impl VmExecutor +where + F: PrimeField32, + VC: VmExecutionConfig, + VC::Executor: MeteredExecutor, +{ + /// Creates an instance of the interpreter specialized for pure execution, without metering, of + /// the given `exe`. + pub fn metered_instance( &self, - exe: VmExe, - input: impl Into>, - ) -> Result, ExecutionError> { - let pc_start = exe.pc_start; - let mut segment = ExecutionSegment::new( - &self.config, - exe.program.clone(), - input.into(), - None, - self.trace_height_constraints.clone(), - exe.fn_bounds.clone(), - ); - if let Some(overridden_heights) = self.overridden_heights.as_ref() { - segment.set_override_trace_heights(overridden_heights.clone()); - } - metrics_span("execute_time_ms", || segment.execute_from_pc(pc_start))?; - Ok(segment) + exe: &VmExe, + executor_idx_to_air_idx: &[usize], + ) -> Result, StaticProgramError> { + InterpretedInstance::new_metered(&self.inventory, exe, executor_idx_to_air_idx) } } @@ -544,122 +294,349 @@ pub enum VmVerificationError { UserPublicValuesError(#[from] UserPublicValuesProofError), } -pub struct VirtualMachine { +#[derive(Error, Debug)] +pub enum VirtualMachineError { + #[error("executor inventory error: {0}")] + ExecutorInventory(#[from] ExecutorInventoryError), + #[error("air inventory error: {0}")] + AirInventory(#[from] AirInventoryError), + #[error("chip inventory error: {0}")] + ChipInventory(#[from] ChipInventoryError), + #[error("static program error: {0}")] + StaticProgram(#[from] StaticProgramError), + #[error("execution error: {0}")] + Execution(#[from] ExecutionError), + #[error("trace generation error: {0}")] + Generation(#[from] GenerationError), + #[error("program committed trade data not loaded")] + ProgramIsNotCommitted, + #[error("verification error: {0}")] + Verification(#[from] VmVerificationError), +} + +/// The [VirtualMachine] struct contains the API to generate proofs for _arbitrary_ programs for a +/// fixed set of OpenVM instructions and a fixed VM circuit corresponding to those instructions. The +/// API is specific to a particular [StarkEngine], which specifies a fixed [StarkGenericConfig] and +/// [ProverBackend] via associated types. The [VmProverBuilder] also fixes the choice of +/// `RecordArena` associated to the prover backend via an associated type. +/// +/// In other words, this struct _is_ the zkVM. +#[derive(Getters, MutGetters, Setters, WithSetters)] +pub struct VirtualMachine +where + E: StarkEngine, + VB: VmBuilder, +{ /// Proving engine pub engine: E, /// Runtime executor - pub executor: VmExecutor, VC>, - _marker: PhantomData, + #[getset(get = "pub")] + executor: VmExecutor, VB::VmConfig>, + #[getset(get = "pub", get_mut = "pub")] + pk: DeviceMultiStarkProvingKey, + chip_complex: VmChipComplex, + #[cfg(feature = "stark-debug")] + pub h_pk: Option>, } -impl VirtualMachine +impl VirtualMachine where - F: PrimeField32, - SC: StarkGenericConfig, - E: StarkEngine, - Domain: PolynomialSpace, - VC: VmConfig, - VC::Executor: Chip, - VC::Periphery: Chip, + E: StarkEngine, + VB: VmBuilder, { - pub fn new(engine: E, config: VC) -> Self { - let executor = VmExecutor::new(config); - Self { + pub fn new( + engine: E, + builder: VB, + config: VB::VmConfig, + d_pk: DeviceMultiStarkProvingKey, + ) -> Result { + let circuit = config.create_airs()?; + let chip_complex = builder.create_chip_complex(&config, circuit)?; + let executor = VmExecutor::, _>::new(config)?; + Ok(Self { engine, executor, - _marker: PhantomData, - } + pk: d_pk, + chip_complex, + #[cfg(feature = "stark-debug")] + h_pk: None, + }) } - pub fn new_with_overridden_trace_heights( + pub fn new_with_keygen( engine: E, - config: VC, - overridden_heights: Option, - ) -> Self { - let executor = VmExecutor::new_with_overridden_trace_heights(config, overridden_heights); - Self { - engine, - executor, - _marker: PhantomData, - } + builder: VB, + config: VB::VmConfig, + ) -> Result<(Self, MultiStarkProvingKey), VirtualMachineError> { + let circuit = config.create_airs()?; + let pk = circuit.keygen(&engine); + let d_pk = engine.device().transport_pk_to_device(&pk); + let vm = Self::new(engine, builder, config, d_pk)?; + Ok((vm, pk)) } - pub fn config(&self) -> &VC { + pub fn config(&self) -> &VB::VmConfig { &self.executor.config } - pub fn keygen(&self) -> MultiStarkProvingKey { - let mut keygen_builder = self.engine.keygen_builder(); - let chip_complex = self.config().create_chip_complex().unwrap(); - for air in chip_complex.airs() { - keygen_builder.add_air(air); - } - keygen_builder.generate_pk() - } - - pub fn set_trace_height_constraints( - &mut self, - trace_height_constraints: Vec, - ) { - self.executor - .set_trace_height_constraints(trace_height_constraints); - } - - pub fn commit_exe(&self, exe: impl Into>) -> Arc> { - let exe = exe.into(); - Arc::new(VmCommittedExe::commit(exe, self.engine.config().pcs())) - } - - pub fn execute( + // TODO[jpw]: I'd like to make a VmInstance struct that has a loaded program + // + /// Preflight execution for a single segment. Executes for exactly `num_insns` instructions + /// using an interpreter. Preflight execution must be provided with `trace_heights` + /// instrumentation data that was collected from a previous run of metered execution so that the + /// preflight execution knows how much memory to allocate for record arenas. + /// + /// This function should rarely be called on its own. Users are advised to call + /// [`prove`](Self::prove) directly. + #[instrument(name = "execute_preflight", skip_all)] + pub fn execute_preflight( &self, - exe: impl Into>, - input: impl Into>, - ) -> Result>, ExecutionError> { - self.executor.execute(exe, input) + exe: &VmExe>, + state: VmState, GuestMemory>, + num_insns: Option, + trace_heights: &[u32], + ) -> Result, VB::RecordArena>, ExecutionError> + where + Val: PrimeField32, + >>::Executor: + PreflightExecutor, VB::RecordArena>, + { + let handler = ProgramHandler::new(&exe.program, &self.executor.inventory)?; + let executor_idx_to_air_idx = self.executor_idx_to_air_idx(); + debug_assert!(executor_idx_to_air_idx + .iter() + .all(|&air_idx| air_idx < trace_heights.len())); + let mut instance = PreflightInterpretedInstance::new(handler, executor_idx_to_air_idx); + + let instret_end = num_insns.map(|ni| state.instret.saturating_add(ni)); + // TODO[jpw]: figure out how to compute RA specific main_widths + let main_widths = self + .pk + .per_air + .iter() + .map(|pk| pk.vk.params.width.main_width()) + .collect_vec(); + let capacities = zip_eq(trace_heights, main_widths) + .map(|(&h, w)| (h as usize, w)) + .collect::>(); + let ctx = TracegenCtx::new_with_capacity(&capacities, instret_end); + + let system_config: &SystemConfig = self.config().as_ref(); + let adapter_offset = system_config.access_adapter_air_id_offset(); + // ATTENTION: this must agree with `num_memory_airs` + let num_adapters = log2_strict_usize(system_config.memory_config.max_access_adapter_n); + assert_eq!(adapter_offset + num_adapters, system_config.num_airs()); + let access_adapter_arena_size_bound = records::arena_size_bound( + &trace_heights[adapter_offset..adapter_offset + num_adapters], + ); + let memory = TracingMemory::from_image( + state.memory, + system_config.initial_block_size(), + access_adapter_arena_size_bound, + ); + let from_state = ExecutionState::new(state.pc, memory.timestamp()); + let vm_state = VmState { + instret: state.instret, + pc: state.pc, + memory, + streams: state.streams, + rng: state.rng, + #[cfg(feature = "metrics")] + metrics: state.metrics, + }; + let mut exec_state = VmExecState::new(vm_state, ctx); + execute_spanned!("execute_preflight", instance, &mut exec_state)?; + let filtered_exec_frequencies = instance.handler.filtered_execution_frequencies(); + let touched_memory = exec_state + .vm_state + .memory + .finalize::>(system_config.continuation_enabled); + #[cfg(feature = "perf-metrics")] + crate::metrics::end_segment_metrics(&mut exec_state); + + let memory = exec_state.vm_state.memory; + let to_state = ExecutionState::new(exec_state.vm_state.pc, memory.timestamp()); + let public_values = system_config + .has_public_values_chip() + .then(|| { + instance.handler.executors[PV_EXECUTOR_IDX] + .as_any_kind() + .downcast_ref::>>() + .unwrap() + .generate_public_values() + }) + .unwrap_or_default(); + let exit_code = exec_state.exit_code?; + let system_records = SystemRecords { + from_state, + to_state, + exit_code, + filtered_exec_frequencies, + access_adapter_records: memory.access_adapter_records, + touched_memory, + public_values, + }; + let record_arenas = exec_state.ctx.arenas; + let to_state = VmState { + instret: exec_state.vm_state.instret, + pc: exec_state.vm_state.pc, + memory: memory.data, + streams: exec_state.vm_state.streams, + rng: exec_state.vm_state.rng, + #[cfg(feature = "metrics")] + metrics: exec_state.vm_state.metrics, + }; + Ok(PreflightExecutionOutput { + system_records, + record_arenas, + to_state, + }) } - pub fn execute_and_generate( + /// Calls [`VmState::initial`] but sets more information for + /// performance metrics when feature "perf-metrics" is enabled. + pub fn create_initial_state( &self, - exe: impl Into>, - input: impl Into>, - ) -> Result, GenerationError> { - self.executor.execute_and_generate(exe, input) + exe: &VmExe>, + inputs: impl Into>>, + ) -> VmState, GuestMemory> { + let memory_config = &self.config().as_ref().memory_config; + #[allow(unused_mut)] + let mut state = + VmState::initial(memory_config, exe.init_memory.clone(), exe.pc_start, inputs); + // Add backtrace information for either: + // - debugging + // - performance metrics + #[cfg(all(feature = "metrics", any(feature = "perf-metrics", debug_assertions)))] + { + state.metrics.fn_bounds = exe.fn_bounds.clone(); + state.metrics.debug_infos = exe.program.debug_infos(); + } + #[cfg(feature = "perf-metrics")] + { + state.metrics.set_pk_info(&self.pk); + state.metrics.num_sys_airs = self.config().as_ref().num_airs(); + state.metrics.access_adapter_offset = + self.config().as_ref().access_adapter_air_id_offset(); + } + state } - pub fn execute_and_generate_with_cached_program( - &self, - committed_exe: Arc>, - input: impl Into>, - ) -> Result, GenerationError> - where - Domain: PolynomialSpace, - { - self.executor - .execute_and_generate_with_cached_program(committed_exe, input) - } + /// This function mutates `self` but should only depend on internal state in the sense that: + /// - program must already be loaded as cached trace via [`load_program`](Self::load_program). + /// - initial memory image was already sent to device via + /// [`transport_init_memory_to_device`](Self::transport_init_memory_to_device). + /// - all other state should be given by `system_records` and `record_arenas` + #[instrument(name = "trace_gen", skip_all)] + pub fn generate_proving_ctx( + &mut self, + system_records: SystemRecords>, + record_arenas: Vec, + ) -> Result, GenerationError> { + #[cfg(feature = "metrics")] + let mut current_trace_heights = + self.get_trace_heights_from_arenas(&system_records, &record_arenas); + // main tracegen call: + let ctx = self + .chip_complex + .generate_proving_ctx(system_records, record_arenas)?; + + // ==== Defensive checks that the trace heights satisfy the linear constraints: ==== + let idx_trace_heights = ctx + .per_air + .iter() + .map(|(air_idx, ctx)| (*air_idx, ctx.main_trace_height())) + .collect_vec(); + // 1. check max trace height isn't exceeded + let max_trace_height = if TypeId::of::>() == TypeId::of::() { + let min_log_blowup = log2_ceil_usize(self.config().as_ref().max_constraint_degree - 1); + 1 << (BabyBear::TWO_ADICITY - min_log_blowup) + } else { + tracing::warn!( + "constructing VirtualMachine for unrecognized field; using max_trace_height=2^30" + ); + 1 << 30 + }; + if let Some(&(air_idx, height)) = idx_trace_heights + .iter() + .find(|(_, height)| *height > max_trace_height) + { + return Err(GenerationError::TraceHeightsLimitExceeded { + air_idx, + height, + max_height: max_trace_height, + }); + } + // 2. check linear constraints on trace heights are satisfied + let trace_height_constraints = &self.pk.trace_height_constraints; + if trace_height_constraints.is_empty() { + tracing::warn!("generating proving context without trace height constraints"); + } + for (i, constraint) in trace_height_constraints.iter().enumerate() { + let value = idx_trace_heights + .iter() + .map(|&(air_idx, h)| constraint.coefficients[air_idx] as u64 * h as u64) + .sum::(); + + if value >= constraint.threshold as u64 { + tracing::info!( + "trace heights {:?} violate linear constraint {} ({} >= {})", + idx_trace_heights, + i, + value, + constraint.threshold + ); + return Err(GenerationError::LinearTraceHeightConstraintExceeded { + constraint_idx: i, + value, + threshold: constraint.threshold, + }); + } + } + #[cfg(feature = "metrics")] + self.finalize_metrics(&mut current_trace_heights); + #[cfg(feature = "stark-debug")] + self.debug_proving_ctx(&ctx); - pub fn prove_single( - &self, - pk: &MultiStarkProvingKey, - proof_input: ProofInput, - ) -> Proof { - self.engine.prove(pk, proof_input) + Ok(ctx) } + /// Generates proof for zkVM execution for exactly `num_insns` instructions for a given program + /// and a given starting state. + /// + /// **Note**: The cached program trace must be loaded via [`load_program`](Self::load_program) + /// before calling this function. + /// + /// Returns: + /// - proof for the execution segment + /// - final memory state only if execution ends in successful termination (exit code 0). This + /// final memory state may be used to extract user public values afterwards. pub fn prove( - &self, - pk: &MultiStarkProvingKey, - results: VmExecutorResult, - ) -> Vec> { - results - .per_segment - .into_iter() - .enumerate() - .map(|(seg_idx, proof_input)| { - tracing::info_span!("prove_segment", segment = seg_idx) - .in_scope(|| self.engine.prove(pk, proof_input)) - }) - .collect() + &mut self, + exe: &VmExe>, + state: VmState, GuestMemory>, + num_insns: Option, + trace_heights: &[u32], + ) -> Result<(Proof, Option), VirtualMachineError> + where + Val: PrimeField32, + >>::Executor: + PreflightExecutor, VB::RecordArena>, + { + self.transport_init_memory_to_device(&state.memory); + + let PreflightExecutionOutput { + system_records, + record_arenas, + to_state, + } = self.execute_preflight(exe, state, num_insns, trace_heights)?; + // drop final memory unless this is a terminal segment and the exit code is success + let final_memory = + (system_records.exit_code == Some(ExitCode::Success as u32)).then_some(to_state.memory); + let ctx = self.generate_proving_ctx(system_records, record_arenas)?; + let proof = self.engine.prove(&self.pk, ctx); + + Ok((proof, final_memory)) } /// Verify segment proofs, checking continuation boundary conditions between segments if VM @@ -668,21 +645,293 @@ where /// or [`verify_single`] directly instead. pub fn verify( &self, - vk: &MultiStarkVerifyingKey, - proofs: Vec>, + vk: &MultiStarkVerifyingKey, + proofs: &[Proof], ) -> Result<(), VmVerificationError> where - Val: PrimeField32, - Com: AsRef<[Val; CHUNK]> + From<[Val; CHUNK]>, + Com: AsRef<[Val; CHUNK]> + From<[Val; CHUNK]>, + Val: PrimeField32, { - if self.config().system().continuation_enabled { - verify_segments(&self.engine, vk, &proofs).map(|_| ()) + if self.config().as_ref().continuation_enabled { + verify_segments(&self.engine, vk, proofs).map(|_| ()) } else { assert_eq!(proofs.len(), 1); - verify_single(&self.engine, vk, &proofs.into_iter().next().unwrap()) - .map_err(VmVerificationError::StarkError) + verify_single(&self.engine, vk, &proofs[0]).map_err(VmVerificationError::StarkError) } } + + /// Generates and then commits to program trace entirely on host. + pub fn commit_exe(&self, exe: impl Into>>) -> VmCommittedExe { + let exe = exe.into(); + VmCommittedExe::commit(exe, self.engine.config().pcs()) + } + + /// Convenience method to transport a host committed Exe to device. If the Exe has already been + /// committed directly on device (either via a different caching mechanism or directly using + /// device committer), then you can directly call [`load_program`](Self::load_program) and skip + /// this function. + pub fn transport_committed_exe_to_device( + &self, + committed_exe: &VmCommittedExe, + ) -> CommittedTraceData { + let commitment = committed_exe.commitment.clone(); + let trace = &committed_exe.trace; + let prover_data = &committed_exe.prover_data; + self.engine + .device() + .transport_committed_trace_to_device(commitment, trace, prover_data) + } + + pub fn load_program(&mut self, cached_program_trace: CommittedTraceData) { + self.chip_complex.system.load_program(cached_program_trace); + } + + pub fn transport_init_memory_to_device(&mut self, memory: &GuestMemory) { + self.chip_complex + .system + .transport_init_memory_to_device(memory); + } + + pub fn executor_idx_to_air_idx(&self) -> Vec { + let ret = self.chip_complex.inventory.executor_idx_to_air_idx(); + tracing::debug!("executor_idx_to_air_idx: {:?}", ret); + assert_eq!(self.executor().inventory.executors().len(), ret.len()); + ret + } + + /// Convenience method to construct a [MeteredCtx] using data from the stored proving key. + pub fn build_metered_ctx(&self) -> MeteredCtx { + let (constant_trace_heights, air_names, widths, interactions): ( + Vec<_>, + Vec<_>, + Vec<_>, + Vec<_>, + ) = self + .pk + .per_air + .iter() + .map(|pk| { + let constant_trace_height = + pk.preprocessed_data.as_ref().map(|pd| pd.trace.height()); + let air_names = pk.air_name.clone(); + let width = pk + .vk + .params + .width + .total_width(<::Challenge>::D); + let num_interactions = pk.vk.symbolic_constraints.interactions.len(); + (constant_trace_height, air_names, width, num_interactions) + }) + .multiunzip(); + + self.executor().build_metered_ctx( + &constant_trace_heights, + &air_names, + &widths, + &interactions, + ) + } + + pub fn num_airs(&self) -> usize { + let num_airs = self.pk.per_air.len(); + debug_assert_eq!(num_airs, self.chip_complex.inventory.airs().num_airs()); + num_airs + } + + pub fn air_names(&self) -> impl Iterator { + self.pk.per_air.iter().map(|pk| pk.air_name.as_str()) + } + + /// See [`debug_proving_ctx`]. + #[cfg(feature = "stark-debug")] + pub fn debug_proving_ctx(&mut self, ctx: &ProvingContext) { + if self.h_pk.is_none() { + let air_inv = self.config().create_airs().unwrap(); + self.h_pk = Some(air_inv.keygen(&self.engine)); + } + let pk = self.h_pk.as_ref().unwrap(); + debug_proving_ctx(self, pk, ctx); + } +} + +#[derive(Serialize, Deserialize)] +#[serde(bound( + serialize = "Com: Serialize", + deserialize = "Com: Deserialize<'de>" +))] +pub struct ContinuationVmProof { + pub per_segment: Vec>, + pub user_public_values: UserPublicValuesProof<{ CHUNK }, Val>, +} + +/// Prover for a specific exe in a specific continuation VM using a specific Stark config. +pub trait ContinuationVmProver { + fn prove( + &mut self, + input: impl Into>>, + ) -> Result, VirtualMachineError>; +} + +/// Prover for a specific exe in a specific single-segment VM using a specific Stark config. +/// +/// Does not run metered execution and directly runs preflight execution. The `prove` function must +/// be provided with the expected maximum `trace_heights` to use to allocate record arena +/// capacities. +pub trait SingleSegmentVmProver { + fn prove( + &mut self, + input: impl Into>>, + trace_heights: &[u32], + ) -> Result, VirtualMachineError>; +} + +/// Virtual machine prover instance for a fixed VM config and a fixed program. For use in proving a +/// program directly on bare metal. +#[derive(Getters)] +pub struct VmLocalProver +where + E: StarkEngine, + VB: VmBuilder, +{ + pub vm: VirtualMachine, + #[getset(get = "pub")] + exe_commitment: Com, + // TODO: store immutable parts of program handler here + #[getset(get = "pub")] + exe: VmExe>, +} + +impl VmLocalProver +where + E: StarkEngine, + VB: VmBuilder, +{ + pub fn new( + mut vm: VirtualMachine, + exe: VmExe>, + cached_program_trace: CommittedTraceData, + ) -> Self { + let exe_commitment = cached_program_trace.commitment.clone(); + vm.load_program(cached_program_trace); + Self { + vm, + exe, + exe_commitment, + } + } +} + +impl ContinuationVmProver for VmLocalProver +where + E: StarkEngine, + Val: PrimeField32, + VB: VmBuilder, + >>::Executor: Executor> + + MeteredExecutor> + + PreflightExecutor, VB::RecordArena>, +{ + /// First performs metered execution (E2) to determine segments. Then sequentially proves each + /// segment. The proof for each segment uses the specified [ProverBackend], but the proof for + /// the next segment does not start before the current proof finishes. + fn prove( + &mut self, + input: impl Into>>, + ) -> Result, VirtualMachineError> { + self.prove_continuations(input, |_, _| {}) + } +} + +impl VmLocalProver +where + E: StarkEngine, + Val: PrimeField32, + VB: VmBuilder, + >>::Executor: Executor> + + MeteredExecutor> + + PreflightExecutor, VB::RecordArena>, +{ + /// For internal use to resize trace matrices before proving. + /// + /// The closure `modify_ctx(seg_idx, &mut ctx)` is called sequentially for each segment. + pub fn prove_continuations( + &mut self, + input: impl Into>>, + mut modify_ctx: impl FnMut(usize, &mut ProvingContext), + ) -> Result, VirtualMachineError> { + let input = input.into(); + let vm = &mut self.vm; + let exe = &self.exe; + let executor_idx_to_air_idx = vm.executor_idx_to_air_idx(); + let e2_ctx = vm.build_metered_ctx(); + let interpreter = vm + .executor() + .metered_instance(&self.exe, &executor_idx_to_air_idx)?; + let (segments, _) = interpreter.execute_metered(input.clone(), e2_ctx)?; + let mut proofs = Vec::with_capacity(segments.len()); + let mut state = Some(vm.create_initial_state(exe, input)); + for (seg_idx, segment) in segments.into_iter().enumerate() { + let _segment_span = info_span!("prove_segment", segment = seg_idx).entered(); + // We need a separate span so the metric label includes "segment" from _segment_span + let _prove_span = info_span!("total_proof").entered(); + let Segment { + instret_start, + num_insns, + trace_heights, + } = segment; + assert_eq!(state.as_ref().unwrap().instret, instret_start); + let from_state = Option::take(&mut state).unwrap(); + vm.transport_init_memory_to_device(&from_state.memory); + let PreflightExecutionOutput { + system_records, + record_arenas, + to_state, + } = vm.execute_preflight(exe, from_state, Some(num_insns), &trace_heights)?; + state = Some(to_state); + + let mut ctx = vm.generate_proving_ctx(system_records, record_arenas)?; + modify_ctx(seg_idx, &mut ctx); + let proof = vm.engine.prove(vm.pk(), ctx); + proofs.push(proof); + } + let to_state = state.unwrap(); + let final_memory = to_state.memory.memory; + let user_public_values = UserPublicValuesProof::compute( + vm.config().as_ref().memory_config.memory_dimensions(), + vm.config().as_ref().num_public_values, + &vm_poseidon2_hasher(), + &final_memory, + ); + Ok(ContinuationVmProof { + per_segment: proofs, + user_public_values, + }) + } +} + +impl SingleSegmentVmProver for VmLocalProver +where + E: StarkEngine, + Val: PrimeField32, + VB: VmBuilder, + >>::Executor: + PreflightExecutor, VB::RecordArena>, +{ + #[instrument(name = "total_proof", skip_all)] + fn prove( + &mut self, + input: impl Into>>, + trace_heights: &[u32], + ) -> Result, VirtualMachineError> { + let input = input.into(); + let vm = &mut self.vm; + let exe = &self.exe; + assert!(!vm.config().as_ref().continuation_enabled); + let mut trace_heights = trace_heights.to_vec(); + trace_heights[PUBLIC_VALUES_AIR_ID] = vm.config().as_ref().num_public_values as u32; + let state = vm.create_initial_state(exe, input); + let (proof, _) = vm.prove(exe, state, None, &trace_heights)?; + Ok(proof) + } } /// Verifies a single proof. This should be used for proof of VM without continuations. @@ -690,14 +939,13 @@ where /// ## Note /// This function does not check any public values or extract the starting pc or commitment /// to the [VmCommittedExe]. -pub fn verify_single( +pub fn verify_single( engine: &E, - vk: &MultiStarkVerifyingKey, - proof: &Proof, + vk: &MultiStarkVerifyingKey, + proof: &Proof, ) -> Result<(), VerificationError> where - SC: StarkGenericConfig, - E: StarkEngine, + E: StarkEngine, { engine.verify(vk, proof) } @@ -732,16 +980,15 @@ pub struct VerifiedExecutionPayload { /// This verification requires an additional Merkle proof with respect to the Merkle root of /// the final memory state. // @dev: This function doesn't need to be generic in `VC`. -pub fn verify_segments( +pub fn verify_segments( engine: &E, - vk: &MultiStarkVerifyingKey, - proofs: &[Proof], -) -> Result>, VmVerificationError> + vk: &MultiStarkVerifyingKey, + proofs: &[Proof], +) -> Result>, VmVerificationError> where - SC: StarkGenericConfig, - E: StarkEngine, - Val: PrimeField32, - Com: AsRef<[Val; CHUNK]>, + E: StarkEngine, + Val: PrimeField32, + Com: AsRef<[Val; CHUNK]>, { if proofs.is_empty() { return Err(VmVerificationError::ProofNotFound); @@ -865,16 +1112,6 @@ where }) } -#[derive(Serialize, Deserialize)] -#[serde(bound( - serialize = "Com: Serialize", - deserialize = "Com: Deserialize<'de>" -))] -pub struct ContinuationVmProof { - pub per_segment: Vec>, - pub user_public_values: UserPublicValuesProof<{ CHUNK }, Val>, -} - impl Clone for ContinuationVmProof where Com: Clone, @@ -886,3 +1123,155 @@ where } } } + +pub(super) fn create_memory_image( + memory_config: &MemoryConfig, + init_memory: SparseMemoryImage, +) -> GuestMemory { + GuestMemory::new(AddressMap::from_sparse( + memory_config.addr_spaces.clone(), + init_memory, + )) +} + +impl VirtualMachine +where + E: StarkEngine, + VC: VmBuilder, + VC::SystemChipInventory: SystemWithFixedTraceHeights, +{ + /// Sets fixed trace heights for the system AIRs' trace matrices. + pub fn override_system_trace_heights(&mut self, heights: &[u32]) { + let num_sys_airs = self.config().as_ref().num_airs(); + assert!(heights.len() >= num_sys_airs); + self.chip_complex + .system + .override_trace_heights(&heights[..num_sys_airs]); + } +} + +/// Runs the STARK backend debugger to check the constraints against the trace matrices +/// logically, instead of cryptographically. This will panic if any constraint is violated, and +/// using `RUST_BACKTRACE=1` can be used to read the stack backtrace of where the constraint +/// failed in the code (this requires the code to be compiled with debug=true). Using lower +/// optimization levels like -O0 will prevent the compiler from inlining and give better +/// debugging information. +// @dev The debugger needs the host proving key. +// This function is used both by VirtualMachine::debug_proving_ctx and by +// stark_utils::air_test_impl +#[cfg(any(debug_assertions, feature = "test-utils", feature = "stark-debug"))] +#[tracing::instrument(level = "debug", skip_all)] +pub fn debug_proving_ctx( + vm: &VirtualMachine, + pk: &MultiStarkProvingKey, + ctx: &ProvingContext, +) where + E: StarkEngine, + VB: VmBuilder, +{ + use itertools::multiunzip; + use openvm_stark_backend::prover::types::AirProofRawInput; + + let device = vm.engine.device(); + let air_inv = vm.config().create_airs().unwrap(); + let global_airs = air_inv.into_airs().collect_vec(); + let (airs, pks, proof_inputs): (Vec<_>, Vec<_>, Vec<_>) = + multiunzip(ctx.per_air.iter().map(|(air_id, air_ctx)| { + // Transfer from device **back** to host so the debugger can read the data. + let cached_mains = air_ctx + .cached_mains + .iter() + .map(|pre| device.transport_matrix_from_device_to_host(&pre.trace)) + .collect_vec(); + let common_main = air_ctx + .common_main + .as_ref() + .map(|m| device.transport_matrix_from_device_to_host(m)); + let public_values = air_ctx.public_values.clone(); + let raw = AirProofRawInput { + cached_mains, + common_main, + public_values, + }; + ( + global_airs[*air_id].clone(), + pk.per_air[*air_id].clone(), + raw, + ) + })); + vm.engine.debug(&airs, &pks, &proof_inputs); +} + +#[cfg(feature = "metrics")] +mod vm_metrics { + use std::iter::zip; + + use metrics::counter; + + use super::*; + use crate::arch::Arena; + + impl VirtualMachine + where + E: StarkEngine, + VB: VmBuilder, + { + /// Assumed that `record_arenas` has length equal to number of AIRs. + /// + /// Best effort calculation of the used trace heights per chip without padding to powers of + /// two. This is best effort because some periphery chips may not have record arenas to + /// instrument. This function includes the constant trace heights, and the used height of + /// the program trace. It does not include the memory access adapter trace heights, + /// which is included in `SystemChipComplex::finalize_trace_heights`. + pub(crate) fn get_trace_heights_from_arenas( + &self, + system_records: &SystemRecords>, + record_arenas: &[VB::RecordArena], + ) -> Vec { + let num_airs = self.num_airs(); + assert_eq!(num_airs, record_arenas.len()); + let mut heights: Vec = record_arenas + .iter() + .map(|arena| arena.current_trace_height()) + .collect(); + // If there are any constant trace heights, set them + for (pk, height) in zip(&self.pk.per_air, &mut heights) { + if let Some(constant_height) = + pk.preprocessed_data.as_ref().map(|pd| pd.trace.height()) + { + *height = constant_height; + } + } + // Program chip used height + heights[PROGRAM_AIR_ID] = system_records.filtered_exec_frequencies.len(); + + heights + } + + /// Update used trace heights after tracegen is done (primarily updating memory-related + /// metrics) and then emit the final metrics. + pub(crate) fn finalize_metrics(&self, heights: &mut [usize]) { + self.chip_complex.system.finalize_trace_heights(heights); + let mut main_cells_used = 0usize; + let mut total_cells_used = 0usize; + for (pk, height) in zip(&self.pk.per_air, heights.iter()) { + let width = &pk.vk.params.width; + main_cells_used += width.main_width() * *height; + total_cells_used += + width.total_width(::Challenge::D) * *height; + } + tracing::debug!(?heights); + tracing::info!(main_cells_used, total_cells_used); + counter!("main_cells_used").absolute(main_cells_used as u64); + counter!("total_cells_used").absolute(total_cells_used as u64); + + #[cfg(feature = "perf-metrics")] + { + for (name, value) in zip(self.air_names(), heights) { + let labels = [("air_name", name.to_string())]; + counter!("rows_used", &labels).absolute(*value as u64); + } + } + } + } +} diff --git a/crates/vm/src/lib.rs b/crates/vm/src/lib.rs index 2e3ba461c5..271ea04b82 100644 --- a/crates/vm/src/lib.rs +++ b/crates/vm/src/lib.rs @@ -8,7 +8,7 @@ pub use openvm_stark_sdk; /// Traits and constructs for the OpenVM architecture. pub mod arch; /// Instrumentation metrics for performance analysis and debugging -#[cfg(feature = "bench-metrics")] +#[cfg(feature = "metrics")] pub mod metrics; /// System chips that are always required by the architecture. /// (The [PhantomChip](system::phantom::PhantomChip) is not technically required for a functioning diff --git a/crates/vm/src/metrics/cycle_tracker/mod.rs b/crates/vm/src/metrics/cycle_tracker/mod.rs index b1ef065451..3d989bc44b 100644 --- a/crates/vm/src/metrics/cycle_tracker/mod.rs +++ b/crates/vm/src/metrics/cycle_tracker/mod.rs @@ -46,7 +46,7 @@ impl CycleTracker { } } -#[cfg(feature = "bench-metrics")] +#[cfg(feature = "metrics")] mod emit { use metrics::counter; diff --git a/crates/vm/src/metrics/mod.rs b/crates/vm/src/metrics/mod.rs index 916e8251ac..e8d44b9fc3 100644 --- a/crates/vm/src/metrics/mod.rs +++ b/crates/vm/src/metrics/mod.rs @@ -1,70 +1,165 @@ use std::{collections::BTreeMap, mem}; +use backtrace::Backtrace; use cycle_tracker::CycleTracker; +use itertools::Itertools; use metrics::counter; use openvm_instructions::{ exe::{FnBound, FnBounds}, - VmOpcode, + program::ProgramDebugInfo, }; -use openvm_stark_backend::p3_field::PrimeField32; +use openvm_stark_backend::prover::{hal::ProverBackend, types::DeviceMultiStarkProvingKey}; -use crate::arch::{ExecutionSegment, InstructionExecutor, VmConfig}; +use crate::{ + arch::{execution_mode::tracegen::TracegenCtx, Arena, PreflightExecutor, VmExecState}, + system::{memory::online::TracingMemory, program::PcEntry}, +}; pub mod cycle_tracker; #[derive(Clone, Debug, Default)] pub struct VmMetrics { - pub cycle_count: usize, - pub chip_heights: Vec<(String, usize)>, + // Static info + pub air_names: Vec, + pub debug_infos: ProgramDebugInfo, + #[cfg(feature = "perf-metrics")] + pub(crate) num_sys_airs: usize, + #[cfg(feature = "perf-metrics")] + pub(crate) access_adapter_offset: usize, + pub(crate) main_widths: Vec, + pub(crate) total_widths: Vec, + + // Dynamic stats /// Maps (dsl_ir, opcode) to number of times opcode was executed pub counts: BTreeMap<(Option, String), usize>, /// Maps (dsl_ir, opcode, air_name) to number of trace cells generated by opcode pub trace_cells: BTreeMap<(Option, String, String), usize>, - /// Metric collection tools. Only collected when `config.profiling` is true. + /// Metric collection tools. Only collected when "perf-metrics" feature is enabled. pub cycle_tracker: CycleTracker, + + pub(crate) current_trace_cells: Vec, + + /// Backtrace for guest debug panic display + pub prev_backtrace: Option, #[allow(dead_code)] pub(crate) fn_bounds: FnBounds, /// Cycle span by function if function start/end addresses are available #[allow(dead_code)] pub(crate) current_fn: FnBound, - pub(crate) current_trace_cells: Vec, } -impl ExecutionSegment +/// We assume this will be called after execute_instruction, so less error-handling is needed. +#[allow(unused_variables)] +#[inline(always)] +pub fn update_instruction_metrics( + state: &mut VmExecState>, + executor: &mut Executor, + prev_pc: u32, // the pc of the instruction executed, state.pc is next pc + pc_entry: &PcEntry, +) where + F: Clone + Send + Sync, + RA: Arena, + Executor: PreflightExecutor, +{ + #[cfg(any(debug_assertions, feature = "perf-metrics"))] + { + let pc = state.pc; + state.metrics.update_backtrace(pc); + } + + #[cfg(feature = "perf-metrics")] + { + use std::iter::zip; + + let pc = state.pc; + let opcode = pc_entry.insn.opcode; + let opcode_name = executor.get_opcode_name(opcode.as_usize()); + + let debug_info = state.metrics.debug_infos.get(prev_pc); + let dsl_instr = debug_info.as_ref().map(|info| info.dsl_instruction.clone()); + + let now_trace_heights: Vec = state + .ctx + .arenas + .iter() + .map(|arena| arena.current_trace_height()) + .collect(); + let now_trace_cells = zip(&state.metrics.main_widths, &now_trace_heights) + .map(|(main_width, h)| main_width * h) + .collect_vec(); + state + .metrics + .update_trace_cells(now_trace_cells, opcode_name, dsl_instr); + + state.metrics.update_current_fn(pc); + } +} + +// Memory access adapter height calculation is slow, so only do it if this is the end of +// execution. +// We also clear the current trace cell counts so there aren't negative diffs at the start of the +// next segment. +#[cfg(feature = "perf-metrics")] +pub fn end_segment_metrics(state: &mut VmExecState>) where - F: PrimeField32, - VC: VmConfig, + F: Clone + Send + Sync, + RA: Arena, { - /// Update metrics that increment per instruction - #[allow(unused_variables)] - pub fn update_instruction_metrics( - &mut self, - pc: u32, - opcode: VmOpcode, - dsl_instr: Option, + use std::iter::zip; + + use crate::system::memory::adapter::AccessAdapterInventory; + + let access_adapter_offset = state.metrics.access_adapter_offset; + let num_sys_airs = state.metrics.num_sys_airs; + let mut now_heights = vec![0; num_sys_airs - access_adapter_offset]; + AccessAdapterInventory::::compute_heights_from_arena( + &state.memory.access_adapter_records, + &mut now_heights, + ); + let now_trace_cells = zip( + &state.metrics.main_widths[access_adapter_offset..], + &now_heights, + ) + .map(|(main_width, h)| main_width * h) + .collect_vec(); + for (air_name, &now_value) in itertools::izip!( + &state.metrics.air_names[access_adapter_offset..], + &now_trace_cells, ) { - self.metrics.cycle_count += 1; - - if self.system_config().profiling { - let executor = self.chip_complex.inventory.get_executor(opcode).unwrap(); - let opcode_name = executor.get_opcode_name(opcode.as_usize()); - self.metrics.update_trace_cells( - &self.air_names, - self.current_trace_cells(), - opcode_name, - dsl_instr, - ); - - #[cfg(feature = "function-span")] - self.metrics.update_current_fn(pc); + if now_value != 0 { + let labels = [ + ("air_name", air_name.clone()), + ("opcode", String::default()), + ("dsl_ir", String::default()), + ("cycle_tracker_span", "memory_access_adapters".to_owned()), + ]; + counter!("cells_used", &labels).increment(now_value as u64); } } + state.metrics.current_trace_cells.fill(0); } impl VmMetrics { - fn update_trace_cells( + pub fn set_pk_info(&mut self, pk: &DeviceMultiStarkProvingKey) { + let (air_names, main_widths, total_widths): (Vec<_>, Vec<_>, Vec<_>) = pk + .per_air + .iter() + .map(|pk| { + let air_names = pk.air_name.clone(); + let width = &pk.vk.params.width; + let main_width = width.main_width(); + let total_width = width.total_width(PB::CHALLENGE_EXT_DEGREE as usize); + (air_names, main_width, total_width) + }) + .multiunzip(); + self.air_names = air_names; + self.main_widths = main_widths; + self.total_widths = total_widths; + self.current_trace_cells = vec![0; self.air_names.len()]; + } + + pub fn update_trace_cells( &mut self, - air_names: &[String], now_trace_cells: Vec, opcode_name: String, dsl_instr: Option, @@ -74,7 +169,7 @@ impl VmMetrics { *self.counts.entry(key.clone()).or_insert(0) += 1; for (air_name, now_value, prev_value) in - itertools::izip!(air_names, &now_trace_cells, &self.current_trace_cells) + itertools::izip!(&self.air_names, &now_trace_cells, &self.current_trace_cells) { if prev_value != now_value { let key = (key.0.clone(), key.1.clone(), air_name.to_owned()); @@ -104,8 +199,17 @@ impl VmMetrics { *self = self.partial_take(); } - #[cfg(feature = "function-span")] - fn update_current_fn(&mut self, pc: u32) { + #[cfg(any(debug_assertions, feature = "perf-metrics"))] + pub fn update_backtrace(&mut self, pc: u32) { + if let Some(info) = self.debug_infos.get(pc) { + if let Some(trace) = &info.trace { + self.prev_backtrace = Some(trace.clone()); + } + } + } + + #[cfg(feature = "perf-metrics")] + pub(super) fn update_current_fn(&mut self, pc: u32) { if self.fn_bounds.is_empty() { return; } @@ -130,11 +234,6 @@ impl VmMetrics { } pub fn emit(&self) { - for (name, value) in self.chip_heights.iter() { - let labels = [("chip_name", name.clone())]; - counter!("rows_used", &labels).absolute(*value as u64); - } - for ((dsl_ir, opcode), value) in self.counts.iter() { let labels = [ ("dsl_ir", dsl_ir.clone().unwrap_or_else(String::new)), diff --git a/crates/vm/src/system/connector/mod.rs b/crates/vm/src/system/connector/mod.rs index 88a03c484b..6785a027a5 100644 --- a/crates/vm/src/system/connector/mod.rs +++ b/crates/vm/src/system/connector/mod.rs @@ -15,9 +15,9 @@ use openvm_stark_backend::{ p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir, PairBuilder}, p3_field::{Field, FieldAlgebra, PrimeField32}, p3_matrix::{dense::RowMajorMatrix, Matrix}, - prover::types::AirProofInput, + prover::{cpu::CpuBackend, types::AirProvingContext}, rap::{BaseAirWithPublicValues, PartitionedBaseAir}, - AirRef, Chip, ChipUsageGetter, + Chip, ChipUsageGetter, }; use serde::{Deserialize, Serialize}; @@ -88,6 +88,26 @@ impl BaseAir for VmConnectorAir { } impl VmConnectorAir { + pub fn new( + execution_bus: ExecutionBus, + program_bus: ProgramBus, + range_bus: VariableRangeCheckerBus, + timestamp_max_bits: usize, + ) -> Self { + assert!( + range_bus.range_max_bits * 2 >= timestamp_max_bits, + "Range checker not large enough: range_max_bits={}, timestamp_max_bits={}", + range_bus.range_max_bits, + timestamp_max_bits + ); + Self { + execution_bus, + program_bus, + range_bus, + timestamp_max_bits, + } + } + /// Returns (low_bits, high_bits) to range check. fn timestamp_limb_bits(&self) -> (usize, usize) { let range_max_bits = self.range_bus.range_max_bits; @@ -194,34 +214,25 @@ impl Air } pub struct VmConnectorChip { - pub air: VmConnectorAir, pub range_checker: SharedVariableRangeCheckerChip, pub boundary_states: [Option>; 2], + timestamp_max_bits: usize, _marker: PhantomData, } -impl VmConnectorChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - range_checker: SharedVariableRangeCheckerChip, - timestamp_max_bits: usize, - ) -> Self { +impl VmConnectorChip { + pub fn new(range_checker: SharedVariableRangeCheckerChip, timestamp_max_bits: usize) -> Self { + let range_bus = range_checker.bus(); assert!( - range_checker.bus().range_max_bits * 2 >= timestamp_max_bits, + range_bus.range_max_bits * 2 >= timestamp_max_bits, "Range checker not large enough: range_max_bits={}, timestamp_max_bits={}", - range_checker.bus().range_max_bits, + range_bus.range_max_bits, timestamp_max_bits ); Self { - air: VmConnectorAir { - execution_bus, - program_bus, - range_bus: range_checker.bus(), - timestamp_max_bits, - }, range_checker, boundary_states: [None, None], + timestamp_max_bits, _marker: PhantomData, } } @@ -245,25 +256,30 @@ impl VmConnectorChip { timestamp_low_limb: 0, // will be computed during tracegen }); } + + fn timestamp_limb_bits(&self) -> (usize, usize) { + let range_max_bits = self.range_checker.bus().range_max_bits; + if self.timestamp_max_bits <= range_max_bits { + (self.timestamp_max_bits, 0) + } else { + (range_max_bits, self.timestamp_max_bits - range_max_bits) + } + } } -impl Chip for VmConnectorChip> +impl Chip> for VmConnectorChip> where SC: StarkGenericConfig, Val: PrimeField32, { - fn air(&self) -> AirRef { - Arc::new(self.air) - } - - fn generate_air_proof_input(self) -> AirProofInput { + fn generate_proving_ctx(&self, _: RA) -> AirProvingContext> { let [initial_state, final_state] = self.boundary_states.map(|state| { let mut state = state.unwrap(); // Decompose and range check timestamp let range_max_bits = self.range_checker.range_max_bits(); let timestamp_low_limb = state.timestamp & ((1u32 << range_max_bits) - 1); state.timestamp_low_limb = timestamp_low_limb; - let (low_bits, high_bits) = self.air.timestamp_limb_bits(); + let (low_bits, high_bits) = self.timestamp_limb_bits(); self.range_checker.add_count(timestamp_low_limb, low_bits); self.range_checker .add_count(state.timestamp >> range_max_bits, high_bits); @@ -271,10 +287,10 @@ where state.map(Val::::from_canonical_u32) }); - let trace = RowMajorMatrix::new( + let trace = Arc::new(RowMajorMatrix::new( [initial_state.flatten(), final_state.flatten()].concat(), self.trace_width(), - ); + )); let mut public_values = Val::::zero_vec(VmConnectorPvs::>::width()); *public_values.as_mut_slice().borrow_mut() = VmConnectorPvs { @@ -283,7 +299,7 @@ where exit_code: final_state.exit_code, is_terminate: final_state.is_terminate, }; - AirProofInput::simple(trace, public_values) + AirProvingContext::simple(trace, public_values) } } diff --git a/crates/vm/src/system/connector/tests.rs b/crates/vm/src/system/connector/tests.rs index f3ded1812c..672516bdc6 100644 --- a/crates/vm/src/system/connector/tests.rs +++ b/crates/vm/src/system/connector/tests.rs @@ -7,8 +7,10 @@ use openvm_instructions::{ instruction::Instruction, program::Program, LocalOpcode, SystemOpcode::TERMINATE, }; use openvm_stark_backend::{ - config::StarkGenericConfig, engine::StarkEngine, p3_field::FieldAlgebra, - prover::types::AirProofInput, utils::disable_debug_builder, + config::StarkGenericConfig, + engine::StarkEngine, + p3_field::FieldAlgebra, + prover::{cpu::CpuBackend, types::AirProvingContext}, }; use openvm_stark_sdk::{ config::{ @@ -21,16 +23,25 @@ use openvm_stark_sdk::{ use super::VmConnectorPvs; use crate::{ - arch::{SingleSegmentVmExecutor, SystemConfig, VirtualMachine, CONNECTOR_AIR_ID}, - system::program::trace::VmCommittedExe, + arch::{ + PreflightExecutionOutput, Streams, SystemConfig, VirtualMachine, VmState, CONNECTOR_AIR_ID, + }, + system::{ + memory::{online::GuestMemory, AddressMap}, + program::trace::VmCommittedExe, + SystemCpuBuilder, + }, }; type F = BabyBear; +type SC = BabyBearPoseidon2Config; +type PB = CpuBackend; + #[test] fn test_vm_connector_happy_path() { let exit_code = 1789; - test_impl(true, exit_code, |air_proof_input| { - let pvs: &VmConnectorPvs = air_proof_input.raw.public_values.as_slice().borrow(); + test_impl(true, exit_code, |air_ctx| { + let pvs: &VmConnectorPvs = air_ctx.public_values.as_slice().borrow(); assert_eq!(pvs.is_terminate, F::ONE); assert_eq!(pvs.exit_code, F::from_canonical_u32(exit_code)); }); @@ -39,12 +50,8 @@ fn test_vm_connector_happy_path() { #[test] fn test_vm_connector_wrong_exit_code() { let exit_code = 1789; - test_impl(false, exit_code, |air_proof_input| { - let pvs: &mut VmConnectorPvs = air_proof_input - .raw - .public_values - .as_mut_slice() - .borrow_mut(); + test_impl(false, exit_code, |air_ctx| { + let pvs: &mut VmConnectorPvs = air_ctx.public_values.as_mut_slice().borrow_mut(); pvs.exit_code = F::from_canonical_u32(exit_code + 1); }); } @@ -52,57 +59,59 @@ fn test_vm_connector_wrong_exit_code() { #[test] fn test_vm_connector_wrong_is_terminate() { let exit_code = 1789; - test_impl(false, exit_code, |air_proof_input| { - let pvs: &mut VmConnectorPvs = air_proof_input - .raw - .public_values - .as_mut_slice() - .borrow_mut(); + test_impl(false, exit_code, |air_ctx| { + let pvs: &mut VmConnectorPvs = air_ctx.public_values.as_mut_slice().borrow_mut(); pvs.is_terminate = F::ZERO; }); } -fn test_impl( - should_pass: bool, - exit_code: u32, - f: impl FnOnce(&mut AirProofInput), -) { +fn test_impl(should_pass: bool, exit_code: u32, f: impl FnOnce(&mut AirProvingContext)) { let vm_config = SystemConfig::default(); - let engine = BabyBearPoseidon2Engine::new(FriParameters::new_for_testing(3)); - let vm = VirtualMachine::new(engine, vm_config.clone()); - let pk = vm.keygen(); + let engine = BabyBearPoseidon2Engine::new(FriParameters::new_for_testing(1)); + let (mut vm, pk) = + VirtualMachine::new_with_keygen(engine, SystemCpuBuilder, vm_config.clone()).unwrap(); + let vk = pk.get_vk(); - { - let instructions = vec![Instruction::from_isize( - TERMINATE.global_opcode(), - 0, - 0, - exit_code as isize, - 0, - 0, - )]; + let instructions = vec![Instruction::::from_isize( + TERMINATE.global_opcode(), + 0, + 0, + exit_code as isize, + 0, + 0, + )]; - let program = Program::from_instructions(&instructions); - let committed_exe = Arc::new(VmCommittedExe::commit( - program.into(), - vm.engine.config.pcs(), - )); - let single_vm = SingleSegmentVmExecutor::new(vm_config); - let mut proof_input = single_vm - .execute_and_generate(committed_exe, vec![]) - .unwrap(); - let connector_air_input = proof_input - .per_air - .iter_mut() - .find(|(air_id, _)| *air_id == CONNECTOR_AIR_ID); - f(&mut connector_air_input.unwrap().1); - if should_pass { - vm.engine - .prove_then_verify(&pk, proof_input) - .expect("Verification failed"); - } else { - disable_debug_builder(); - assert!(vm.engine.prove_then_verify(&pk, proof_input).is_err()); - } + let program = Program::from_instructions(&instructions); + let committed_exe = Arc::new(VmCommittedExe::::commit( + program.into(), + vm.engine.config().pcs(), + )); + let max_trace_heights = vec![0; vk.total_widths().len()]; + let memory = GuestMemory::new(AddressMap::from_mem_config(&vm_config.memory_config)); + vm.transport_init_memory_to_device(&memory); + vm.load_program(committed_exe.get_committed_trace()); + let from_state = VmState::new(0, 0, memory, Streams::default(), 0); + let PreflightExecutionOutput { + system_records, + record_arenas, + .. + } = vm + .execute_preflight(&committed_exe.exe, from_state, None, &max_trace_heights) + .unwrap(); + let mut ctx = vm + .generate_proving_ctx(system_records, record_arenas) + .unwrap(); + let connector_air_ctx = &mut ctx + .per_air + .iter_mut() + .find(|(air_id, _)| *air_id == CONNECTOR_AIR_ID) + .unwrap() + .1; + f(connector_air_ctx); + let proof = vm.engine.prove(vm.pk(), ctx); + if should_pass { + vm.engine.verify(&vk, &proof).expect("Verification failed"); + } else { + assert!(vm.engine.verify(&vk, &proof).is_err()); } } diff --git a/crates/vm/src/system/memory/adapter/mod.rs b/crates/vm/src/system/memory/adapter/mod.rs index 64e79a920b..46df5d968e 100644 --- a/crates/vm/src/system/memory/adapter/mod.rs +++ b/crates/vm/src/system/memory/adapter/mod.rs @@ -1,145 +1,273 @@ -use std::{borrow::BorrowMut, cmp::max, sync::Arc}; +use std::{ + borrow::{Borrow, BorrowMut}, + marker::PhantomData, + ptr::copy_nonoverlapping, + sync::Arc, +}; pub use air::*; pub use columns::*; use enum_dispatch::enum_dispatch; +use getset::Setters; use openvm_circuit_primitives::{ is_less_than::IsLtSubAir, utils::next_power_of_two_or_zero, var_range::SharedVariableRangeCheckerChip, TraceSubRowGenerator, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_stark_backend::{ - config::{Domain, StarkGenericConfig, Val}, + config::{Domain, StarkGenericConfig}, p3_air::BaseAir, p3_commit::PolynomialSpace, p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix, - p3_maybe_rayon::prelude::*, p3_util::log2_strict_usize, - prover::types::AirProofInput, - AirRef, Chip, ChipUsageGetter, + prover::{cpu::CpuBackend, types::AirProvingContext}, }; -use crate::system::memory::{offline_checker::MemoryBus, MemoryAddress}; +use crate::{ + arch::{ + AddressSpaceHostConfig, AddressSpaceHostLayout, CustomBorrow, DenseRecordArena, + MemoryCellType, MemoryConfig, SizedRecord, + }, + system::memory::{ + adapter::records::{ + arena_size_bound, AccessLayout, AccessRecordHeader, AccessRecordMut, + MERGE_AND_NOT_SPLIT_FLAG, + }, + offline_checker::MemoryBus, + MemoryAddress, + }, +}; mod air; mod columns; -#[cfg(test)] -mod tests; +pub mod records; +#[derive(Setters)] pub struct AccessAdapterInventory { + pub(super) memory_config: MemoryConfig, chips: Vec>, - air_names: Vec, + #[getset(set = "pub")] + arena: DenseRecordArena, + #[cfg(feature = "metrics")] + pub(crate) trace_heights: Vec, } -impl AccessAdapterInventory { +impl AccessAdapterInventory { pub fn new( range_checker: SharedVariableRangeCheckerChip, memory_bus: MemoryBus, - clk_max_bits: usize, - max_access_adapter_n: usize, + memory_config: MemoryConfig, ) -> Self { let rc = range_checker; let mb = memory_bus; - let cmb = clk_max_bits; - let maan = max_access_adapter_n; + let tmb = memory_config.timestamp_max_bits; + let maan = memory_config.max_access_adapter_n; assert!(matches!(maan, 2 | 4 | 8 | 16 | 32)); let chips: Vec<_> = [ - Self::create_access_adapter_chip::<2>(rc.clone(), mb, cmb, maan), - Self::create_access_adapter_chip::<4>(rc.clone(), mb, cmb, maan), - Self::create_access_adapter_chip::<8>(rc.clone(), mb, cmb, maan), - Self::create_access_adapter_chip::<16>(rc.clone(), mb, cmb, maan), - Self::create_access_adapter_chip::<32>(rc.clone(), mb, cmb, maan), + Self::create_access_adapter_chip::<2>(rc.clone(), mb, tmb, maan), + Self::create_access_adapter_chip::<4>(rc.clone(), mb, tmb, maan), + Self::create_access_adapter_chip::<8>(rc.clone(), mb, tmb, maan), + Self::create_access_adapter_chip::<16>(rc.clone(), mb, tmb, maan), + Self::create_access_adapter_chip::<32>(rc.clone(), mb, tmb, maan), ] .into_iter() .flatten() .collect(); - let air_names = (0..chips.len()).map(|i| air_name(1 << (i + 1))).collect(); - Self { chips, air_names } + Self { + memory_config, + chips, + arena: DenseRecordArena::with_byte_capacity(0), + #[cfg(feature = "metrics")] + trace_heights: Vec::new(), + } } + pub fn num_access_adapters(&self) -> usize { self.chips.len() } - pub fn set_override_trace_heights(&mut self, overridden_heights: Vec) { - assert_eq!(overridden_heights.len(), self.chips.len()); - for (chip, oh) in self.chips.iter_mut().zip(overridden_heights) { - chip.set_override_trace_heights(oh); - } - } - pub fn add_record(&mut self, record: AccessAdapterRecord) { - let n = record.data.len(); - let idx = log2_strict_usize(n) - 1; - let chip = &mut self.chips[idx]; - debug_assert!(chip.n() == n); - chip.add_record(record); - } - pub fn extend_records(&mut self, records: Vec>) { - for record in records { - self.add_record(record); + pub(super) fn set_override_trace_heights(&mut self, overridden_heights: Vec) { + self.set_arena_from_trace_heights( + &overridden_heights + .iter() + .map(|&h| h as u32) + .collect::>(), + ); + for (chip, oh) in self.chips.iter_mut().zip(overridden_heights) { + chip.set_override_trace_height(oh); } } - #[cfg(test)] - pub fn records_for_n(&self, n: usize) -> &[AccessAdapterRecord] { - let idx = log2_strict_usize(n) - 1; - let chip = &self.chips[idx]; - chip.records() - } - - #[cfg(test)] - pub fn total_records(&self) -> usize { - self.chips.iter().map(|chip| chip.records().len()).sum() + pub(super) fn set_arena_from_trace_heights(&mut self, trace_heights: &[u32]) { + assert_eq!(trace_heights.len(), self.chips.len()); + let size_bound = arena_size_bound(trace_heights); + tracing::debug!( + "Allocating {} bytes for memory adapters arena from heights {:?}", + size_bound, + trace_heights + ); + self.arena.set_byte_capacity(size_bound); } - pub fn get_heights(&self) -> Vec { - self.chips - .iter() - .map(|chip| chip.current_trace_height()) - .collect() - } - #[allow(dead_code)] pub fn get_widths(&self) -> Vec { - self.chips.iter().map(|chip| chip.trace_width()).collect() - } - pub fn get_cells(&self) -> Vec { self.chips .iter() - .map(|chip| chip.current_trace_cells()) + .map(|chip: &GenericAccessAdapterChip| chip.trace_width()) .collect() } - pub fn airs(&self) -> Vec> - where - F: PrimeField32, - Domain: PolynomialSpace, - { - self.chips.iter().map(|chip| chip.air()).collect() + + /// `heights` should have length equal to the number of access adapter chips. + pub(crate) fn compute_heights_from_arena(arena: &DenseRecordArena, heights: &mut [usize]) { + let bytes = arena.allocated(); + tracing::debug!( + "Computing heights from memory adapters arena: used {} bytes", + bytes.len() + ); + let mut ptr = 0; + while ptr < bytes.len() { + let header: &AccessRecordHeader = bytes[ptr..].borrow(); + let layout: AccessLayout = unsafe { bytes[ptr..].extract_layout() }; + ptr += as SizedRecord>::size(&layout); + + let log_max_block_size = log2_strict_usize(header.block_size as usize); + for (i, h) in heights + .iter_mut() + .enumerate() + .take(log_max_block_size) + .skip(log2_strict_usize(header.lowest_block_size as usize)) + { + *h += 1 << (log_max_block_size - i - 1); + } + } + tracing::debug!("Computed heights from memory adapters arena: {:?}", heights); } - pub fn air_names(&self) -> Vec { - self.air_names.clone() + + fn apply_overridden_heights(&mut self, heights: &mut [usize]) { + for (i, h) in heights.iter_mut().enumerate() { + if let Some(oh) = self.chips[i].overridden_trace_height() { + assert!( + oh >= *h, + "Overridden height {oh} is less than the required height {}", + *h + ); + *h = oh; + } + *h = next_power_of_two_or_zero(*h); + } } - pub fn generate_air_proof_inputs(self) -> Vec> + + pub fn generate_proving_ctx( + &mut self, + ) -> Vec>> where F: PrimeField32, Domain: PolynomialSpace, { - self.chips + let num_adapters = self.chips.len(); + + let mut heights = vec![0; num_adapters]; + Self::compute_heights_from_arena(&self.arena, &mut heights); + self.apply_overridden_heights(&mut heights); + + let widths = self + .chips + .iter() + .map(|chip| chip.trace_width()) + .collect::>(); + let mut traces = widths + .iter() + .zip(heights.iter()) + .map(|(&width, &height)| RowMajorMatrix::new(vec![F::ZERO; width * height], width)) + .collect::>(); + #[cfg(feature = "metrics")] + { + self.trace_heights = heights; + } + + let mut trace_ptrs = vec![0; num_adapters]; + + let bytes = self.arena.allocated_mut(); + let mut ptr = 0; + while ptr < bytes.len() { + let layout: AccessLayout = unsafe { bytes[ptr..].extract_layout() }; + let record: AccessRecordMut<'_> = bytes[ptr..].custom_borrow(layout.clone()); + ptr += as SizedRecord>::size(&layout); + + let log_min_block_size = log2_strict_usize(record.header.lowest_block_size as usize); + let log_max_block_size = log2_strict_usize(record.header.block_size as usize); + + if record.header.timestamp_and_mask & MERGE_AND_NOT_SPLIT_FLAG != 0 { + for i in log_min_block_size..log_max_block_size { + let data_len = layout.type_size << i; + let ts_len = 1 << (i - log_min_block_size); + for j in 0..record.data.len() / (2 * data_len) { + let row_slice = + &mut traces[i].values[trace_ptrs[i]..trace_ptrs[i] + widths[i]]; + trace_ptrs[i] += widths[i]; + self.chips[i].fill_trace_row( + &self.memory_config.addr_spaces, + row_slice, + false, + MemoryAddress::new( + record.header.address_space, + record.header.pointer + (j << (i + 1)) as u32, + ), + &record.data[j * 2 * data_len..(j + 1) * 2 * data_len], + *record.timestamps[2 * j * ts_len..(2 * j + 1) * ts_len] + .iter() + .max() + .unwrap(), + *record.timestamps[(2 * j + 1) * ts_len..(2 * j + 2) * ts_len] + .iter() + .max() + .unwrap(), + ); + } + } + } else { + let timestamp = record.header.timestamp_and_mask; + for i in log_min_block_size..log_max_block_size { + let data_len = layout.type_size << i; + for j in 0..record.data.len() / (2 * data_len) { + let row_slice = + &mut traces[i].values[trace_ptrs[i]..trace_ptrs[i] + widths[i]]; + trace_ptrs[i] += widths[i]; + self.chips[i].fill_trace_row( + &self.memory_config.addr_spaces, + row_slice, + true, + MemoryAddress::new( + record.header.address_space, + record.header.pointer + (j << (i + 1)) as u32, + ), + &record.data[j * 2 * data_len..(j + 1) * 2 * data_len], + timestamp, + timestamp, + ); + } + } + } + } + traces .into_iter() - .map(|chip| chip.generate_air_proof_input()) + .map(|trace| AirProvingContext::simple_no_pis(Arc::new(trace))) .collect() } fn create_access_adapter_chip( range_checker: SharedVariableRangeCheckerChip, memory_bus: MemoryBus, - clk_max_bits: usize, + timestamp_max_bits: usize, max_access_adapter_n: usize, - ) -> Option> { + ) -> Option> + where + F: Clone + Send + Sync, + { if N <= max_access_adapter_n { Some(GenericAccessAdapterChip::new::( range_checker, memory_bus, - clk_max_bits, + timestamp_max_bits, )) } else { None @@ -147,37 +275,27 @@ impl AccessAdapterInventory { } } -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum AccessAdapterRecordKind { - Split, - Merge { +#[enum_dispatch] +pub(crate) trait GenericAccessAdapterChipTrait { + fn trace_width(&self) -> usize; + fn set_override_trace_height(&mut self, overridden_height: usize); + fn overridden_trace_height(&self) -> Option; + + #[allow(clippy::too_many_arguments)] + fn fill_trace_row( + &self, + addr_spaces: &[AddressSpaceHostConfig], + row: &mut [F], + is_split: bool, + address: MemoryAddress, + values: &[u8], left_timestamp: u32, right_timestamp: u32, - }, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct AccessAdapterRecord { - pub timestamp: u32, - pub address_space: T, - pub start_index: T, - pub data: Vec, - pub kind: AccessAdapterRecordKind, -} - -#[enum_dispatch] -pub trait GenericAccessAdapterChipTrait { - fn set_override_trace_heights(&mut self, overridden_height: usize); - fn add_record(&mut self, record: AccessAdapterRecord); - fn n(&self) -> usize; - fn generate_trace(self) -> RowMajorMatrix - where + ) where F: PrimeField32; } -#[derive(Chip, ChipUsageGetter)] #[enum_dispatch(GenericAccessAdapterChipTrait)] -#[chip(where = "F: PrimeField32")] enum GenericAccessAdapterChip { N2(AccessAdapterChip), N4(AccessAdapterChip), @@ -186,15 +304,15 @@ enum GenericAccessAdapterChip { N32(AccessAdapterChip), } -impl GenericAccessAdapterChip { +impl GenericAccessAdapterChip { fn new( range_checker: SharedVariableRangeCheckerChip, memory_bus: MemoryBus, - clk_max_bits: usize, + timestamp_max_bits: usize, ) -> Self { let rc = range_checker; let mb = memory_bus; - let cmb = clk_max_bits; + let cmb = timestamp_max_bits; match N { 2 => GenericAccessAdapterChip::N2(AccessAdapterChip::new(rc, mb, cmb)), 4 => GenericAccessAdapterChip::N4(AccessAdapterChip::new(rc, mb, cmb)), @@ -204,127 +322,89 @@ impl GenericAccessAdapterChip { _ => panic!("Only supports N in (2, 4, 8, 16, 32)"), } } - - #[cfg(test)] - fn records(&self) -> &[AccessAdapterRecord] { - match &self { - GenericAccessAdapterChip::N2(chip) => &chip.records, - GenericAccessAdapterChip::N4(chip) => &chip.records, - GenericAccessAdapterChip::N8(chip) => &chip.records, - GenericAccessAdapterChip::N16(chip) => &chip.records, - GenericAccessAdapterChip::N32(chip) => &chip.records, - } - } } -pub struct AccessAdapterChip { + +pub(crate) struct AccessAdapterChip { air: AccessAdapterAir, range_checker: SharedVariableRangeCheckerChip, - pub records: Vec>, overridden_height: Option, + _marker: PhantomData, } -impl AccessAdapterChip { + +impl AccessAdapterChip { pub fn new( range_checker: SharedVariableRangeCheckerChip, memory_bus: MemoryBus, - clk_max_bits: usize, + timestamp_max_bits: usize, ) -> Self { - let lt_air = IsLtSubAir::new(range_checker.bus(), clk_max_bits); + let lt_air = IsLtSubAir::new(range_checker.bus(), timestamp_max_bits); Self { air: AccessAdapterAir:: { memory_bus, lt_air }, range_checker, - records: vec![], overridden_height: None, + _marker: PhantomData, } } } impl GenericAccessAdapterChipTrait for AccessAdapterChip { - fn set_override_trace_heights(&mut self, overridden_height: usize) { - self.overridden_height = Some(overridden_height); - } - fn add_record(&mut self, record: AccessAdapterRecord) { - self.records.push(record); - } - fn n(&self) -> usize { - N - } - fn generate_trace(self) -> RowMajorMatrix - where - F: PrimeField32, - { - let width = BaseAir::::width(&self.air); - let height = if let Some(oh) = self.overridden_height { - assert!( - oh >= self.records.len(), - "Overridden height is less than the required height" - ); - oh - } else { - self.records.len() - }; - let height = next_power_of_two_or_zero(height); - let mut values = F::zero_vec(height * width); - - values - .par_chunks_mut(width) - .zip(self.records.into_par_iter()) - .for_each(|(row, record)| { - let row: &mut AccessAdapterCols = row.borrow_mut(); - - row.is_valid = F::ONE; - row.values = record.data.try_into().unwrap(); - row.address = MemoryAddress::new(record.address_space, record.start_index); - - let (left_timestamp, right_timestamp) = match record.kind { - AccessAdapterRecordKind::Split => (record.timestamp, record.timestamp), - AccessAdapterRecordKind::Merge { - left_timestamp, - right_timestamp, - } => (left_timestamp, right_timestamp), - }; - debug_assert_eq!(max(left_timestamp, right_timestamp), record.timestamp); - - row.left_timestamp = F::from_canonical_u32(left_timestamp); - row.right_timestamp = F::from_canonical_u32(right_timestamp); - row.is_split = F::from_bool(record.kind == AccessAdapterRecordKind::Split); - - self.air.lt_air.generate_subrow( - (self.range_checker.as_ref(), left_timestamp, right_timestamp), - (&mut row.lt_aux, &mut row.is_right_larger), - ); - }); - RowMajorMatrix::new(values, width) - } -} - -impl Chip for AccessAdapterChip, N> -where - Val: PrimeField32, -{ - fn air(&self) -> AirRef { - Arc::new(self.air.clone()) - } - - fn generate_air_proof_input(self) -> AirProofInput { - let trace = self.generate_trace(); - AirProofInput::simple_no_pis(trace) + fn trace_width(&self) -> usize { + BaseAir::::width(&self.air) } -} -impl ChipUsageGetter for AccessAdapterChip { - fn air_name(&self) -> String { - air_name(N) + fn set_override_trace_height(&mut self, overridden_height: usize) { + self.overridden_height = Some(overridden_height); } - fn current_trace_height(&self) -> usize { - self.records.len() + fn overridden_trace_height(&self) -> Option { + self.overridden_height } - fn trace_width(&self) -> usize { - BaseAir::::width(&self.air) + fn fill_trace_row( + &self, + addr_spaces: &[AddressSpaceHostConfig], + row: &mut [F], + is_split: bool, + address: MemoryAddress, + values: &[u8], + left_timestamp: u32, + right_timestamp: u32, + ) where + F: PrimeField32, + { + let row: &mut AccessAdapterCols = row.borrow_mut(); + row.is_valid = F::ONE; + row.is_split = F::from_bool(is_split); + row.address = MemoryAddress::new( + F::from_canonical_u32(address.address_space), + F::from_canonical_u32(address.pointer), + ); + let addr_space_layout = addr_spaces[address.address_space as usize].layout; + // SAFETY: values will be a slice of the cell type + unsafe { + match addr_space_layout { + MemoryCellType::Native { .. } => { + copy_nonoverlapping( + values.as_ptr(), + row.values.as_mut_ptr() as *mut u8, + N * size_of::(), + ); + } + _ => { + for (dst, src) in row + .values + .iter_mut() + .zip(values.chunks_exact(addr_space_layout.size())) + { + *dst = addr_space_layout.to_field(src); + } + } + } + } + row.left_timestamp = F::from_canonical_u32(left_timestamp); + row.right_timestamp = F::from_canonical_u32(right_timestamp); + self.air.lt_air.generate_subrow( + (self.range_checker.as_ref(), left_timestamp, right_timestamp), + (&mut row.lt_aux, &mut row.is_right_larger), + ); } } - -#[inline] -fn air_name(n: usize) -> String { - format!("AccessAdapter<{}>", n) -} diff --git a/crates/vm/src/system/memory/adapter/records.rs b/crates/vm/src/system/memory/adapter/records.rs new file mode 100644 index 0000000000..2a82bc51a4 --- /dev/null +++ b/crates/vm/src/system/memory/adapter/records.rs @@ -0,0 +1,144 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + mem::{align_of, size_of}, +}; + +use openvm_circuit_primitives::AlignedBytesBorrow; + +use crate::arch::{CustomBorrow, DenseRecordArena, RecordArena, SizedRecord}; + +#[repr(C)] +#[derive(Debug, Clone, Copy, AlignedBytesBorrow, PartialEq, Eq, PartialOrd, Ord)] +pub struct AccessRecordHeader { + /// Iff we need to merge before, this has the `MERGE_AND_NOT_SPLIT_FLAG` bit set + pub timestamp_and_mask: u32, + pub address_space: u32, + pub pointer: u32, + // PERF: these three are easily mergeable into a single u32 + pub block_size: u32, + pub lowest_block_size: u32, + pub type_size: u32, +} + +#[repr(C)] +#[derive(Debug)] +pub struct AccessRecordMut<'a> { + pub header: &'a mut AccessRecordHeader, + // PERF(AG): optimize with some `Option` serialization stuff + pub timestamps: &'a mut [u32], // len is block_size / lowest_block_size + pub data: &'a mut [u8], // len is block_size * type_size +} + +#[derive(Debug, Clone)] +pub struct AccessLayout { + /// The size of the block in elements. + pub block_size: usize, + /// The size of the minimal block we may split into/merge from (usually 1 or 4) + pub lowest_block_size: usize, + /// The size of the type in bytes (1 for u8, 4 for F). + pub type_size: usize, +} + +impl AccessLayout { + pub(crate) fn from_record_header(header: &AccessRecordHeader) -> Self { + Self { + block_size: header.block_size as usize, + lowest_block_size: header.lowest_block_size as usize, + type_size: header.type_size as usize, + } + } +} + +pub(crate) const MERGE_AND_NOT_SPLIT_FLAG: u32 = 1 << 31; + +pub(crate) fn size_by_layout(layout: &AccessLayout) -> usize { + size_of::() // header struct + + (layout.block_size / layout.lowest_block_size) * size_of::() // timestamps + + (layout.block_size * layout.type_size).next_multiple_of(4) // data +} + +impl SizedRecord for AccessRecordMut<'_> { + fn size(layout: &AccessLayout) -> usize { + size_by_layout(layout) + } + + fn alignment(_: &AccessLayout) -> usize { + align_of::() + } +} + +impl<'a> CustomBorrow<'a, AccessRecordMut<'a>, AccessLayout> for [u8] { + fn custom_borrow(&'a mut self, layout: AccessLayout) -> AccessRecordMut<'a> { + // header: AccessRecordHeader (using trivial borrowing) + let (header_buf, rest) = + unsafe { self.split_at_mut_unchecked(size_of::()) }; + let header = header_buf.borrow_mut(); + + let mut offset = 0; + + // timestamps: [u32] (block_size / cell_size * 4 bytes) + let timestamps = unsafe { + std::slice::from_raw_parts_mut( + rest.as_mut_ptr().add(offset) as *mut u32, + layout.block_size / layout.lowest_block_size, + ) + }; + offset += layout.block_size / layout.lowest_block_size * size_of::(); + + // data: [u8] (block_size * type_size bytes) + let data = unsafe { + std::slice::from_raw_parts_mut( + rest.as_mut_ptr().add(offset), + layout.block_size * layout.type_size, + ) + }; + + AccessRecordMut { + header, + data, + timestamps, + } + } + + unsafe fn extract_layout(&self) -> AccessLayout { + let header: &AccessRecordHeader = self.borrow(); + AccessLayout { + block_size: header.block_size as usize, + lowest_block_size: header.lowest_block_size as usize, + type_size: header.type_size as usize, + } + } +} + +impl<'a> RecordArena<'a, AccessLayout, AccessRecordMut<'a>> for DenseRecordArena { + fn alloc(&'a mut self, layout: AccessLayout) -> AccessRecordMut<'a> { + let bytes = self.alloc_bytes( as SizedRecord>::size( + &layout, + )); + <[u8] as CustomBorrow, AccessLayout>>::custom_borrow(bytes, layout) + } +} + +/// `trace_heights[i]` is assumed to correspond to `Adapter< 2^(i+1) >`. +pub fn arena_size_bound(trace_heights: &[u32]) -> usize { + // At the very worst, each row in `Adapter` + // corresponds to a unique record of `block_size` being `2 * N`, + // and its `lowest_block_size` is at least 1 and `type_size` is at most 4. + let size_bound = trace_heights + .iter() + .enumerate() + .map(|(i, &h)| { + size_by_layout(&AccessLayout { + block_size: 1 << (i + 1), + lowest_block_size: 1, + type_size: 4, + }) * h as usize + }) + .sum::(); + tracing::debug!( + "Allocating {} bytes for memory adapters arena from heights {:?}", + size_bound, + trace_heights + ); + size_bound +} diff --git a/crates/vm/src/system/memory/adapter/tests.rs b/crates/vm/src/system/memory/adapter/tests.rs deleted file mode 100644 index 8b13789179..0000000000 --- a/crates/vm/src/system/memory/adapter/tests.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/crates/vm/src/system/memory/controller/dimensions.rs b/crates/vm/src/system/memory/controller/dimensions.rs index 1082d3adf0..77345c2e82 100644 --- a/crates/vm/src/system/memory/controller/dimensions.rs +++ b/crates/vm/src/system/memory/controller/dimensions.rs @@ -2,23 +2,24 @@ use derive_new::new; use openvm_stark_backend::p3_util::log2_strict_usize; use serde::{Deserialize, Serialize}; -use crate::{arch::MemoryConfig, system::memory::CHUNK}; +use crate::{ + arch::{MemoryConfig, ADDR_SPACE_OFFSET}, + system::memory::CHUNK, +}; -// indicates that there are 2^`as_height` address spaces numbered starting from `as_offset`, +// indicates that there are 2^`addr_space_height` address spaces numbered starting from 1, // and that each address space has 2^`address_height` addresses numbered starting from 0 #[derive(Clone, Copy, Debug, Serialize, Deserialize, new)] pub struct MemoryDimensions { /// Address space height - pub as_height: usize, + pub addr_space_height: usize, /// Pointer height pub address_height: usize, - /// Address space offset - pub as_offset: u32, } impl MemoryDimensions { pub fn overall_height(&self) -> usize { - self.as_height + self.address_height + self.addr_space_height + self.address_height } /// Convert an address label (address space, block id) to its index in the memory merkle tree. /// @@ -27,17 +28,29 @@ impl MemoryDimensions { /// This function is primarily for internal use for accessing the memory merkle tree. /// Users should use a higher-level API when possible. pub fn label_to_index(&self, (addr_space, block_id): (u32, u32)) -> u64 { - debug_assert!(block_id < (1 << self.address_height)); - (((addr_space - self.as_offset) as u64) << self.address_height) + block_id as u64 + debug_assert!( + block_id < (1 << self.address_height), + "block_id={block_id} exceeds address_height={}", + self.address_height + ); + (((addr_space - ADDR_SPACE_OFFSET) as u64) << self.address_height) + block_id as u64 + } + + /// Convert an index in the memory merkle tree to an address label (address space, block id). + /// + /// This function performs the inverse operation of `label_to_index`. + pub fn index_to_label(&self, index: u64) -> (u32, u32) { + let block_id = (index & ((1 << self.address_height) - 1)) as u32; + let addr_space = (index >> self.address_height) as u32 + ADDR_SPACE_OFFSET; + (addr_space, block_id) } } impl MemoryConfig { pub fn memory_dimensions(&self) -> MemoryDimensions { MemoryDimensions { - as_height: self.as_height, + addr_space_height: self.addr_space_height, address_height: self.pointer_max_bits - log2_strict_usize(CHUNK), - as_offset: self.as_offset, } } } diff --git a/crates/vm/src/system/memory/controller/interface.rs b/crates/vm/src/system/memory/controller/interface.rs index b51e960a32..ff0a0b64a9 100644 --- a/crates/vm/src/system/memory/controller/interface.rs +++ b/crates/vm/src/system/memory/controller/interface.rs @@ -1,10 +1,23 @@ use openvm_stark_backend::{interaction::PermutationCheckBus, p3_field::PrimeField32}; use crate::system::memory::{ - merkle::MemoryMerkleChip, persistent::PersistentBoundaryChip, volatile::VolatileBoundaryChip, + merkle::{MemoryMerkleAir, MemoryMerkleChip}, + persistent::{PersistentBoundaryAir, PersistentBoundaryChip}, + volatile::{VolatileBoundaryAir, VolatileBoundaryChip}, MemoryImage, CHUNK, }; +#[derive(Clone)] +pub enum MemoryInterfaceAirs { + Volatile { + boundary: VolatileBoundaryAir, + }, + Persistent { + boundary: PersistentBoundaryAir, + merkle: MemoryMerkleAir, + }, +} + #[allow(clippy::large_enum_variant)] pub enum MemoryInterface { Volatile { @@ -13,25 +26,11 @@ pub enum MemoryInterface { Persistent { boundary_chip: PersistentBoundaryChip, merkle_chip: MemoryMerkleChip, - initial_memory: MemoryImage, + initial_memory: MemoryImage, }, } impl MemoryInterface { - pub fn touch_range(&mut self, addr_space: u32, pointer: u32, len: u32) { - match self { - MemoryInterface::Volatile { .. } => {} - MemoryInterface::Persistent { - boundary_chip, - merkle_chip, - .. - } => { - boundary_chip.touch_range(addr_space, pointer, len); - merkle_chip.touch_range(addr_space, pointer, len); - } - } - } - pub fn compression_bus(&self) -> Option { match self { MemoryInterface::Volatile { .. } => None, diff --git a/crates/vm/src/system/memory/controller/mod.rs b/crates/vm/src/system/memory/controller/mod.rs index 680a03ab8e..aabe4df08d 100644 --- a/crates/vm/src/system/memory/controller/mod.rs +++ b/crates/vm/src/system/memory/controller/mod.rs @@ -1,51 +1,40 @@ -use std::{ - array, - collections::BTreeMap, - iter, - marker::PhantomData, - mem, - sync::{Arc, Mutex}, -}; +//! [MemoryController] can be considered as the Memory Chip Complex for the CPU Backend. +use std::{collections::BTreeMap, fmt::Debug, marker::PhantomData, sync::Arc}; use getset::{Getters, MutGetters}; use openvm_circuit_primitives::{ assert_less_than::{AssertLtSubAir, LessThanAuxCols}, - is_zero::IsZeroSubAir, - utils::next_power_of_two_or_zero, - var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, VariableRangeCheckerChip, + }, TraceSubRowGenerator, }; use openvm_stark_backend::{ config::{Domain, StarkGenericConfig}, interaction::PermutationCheckBus, p3_commit::PolynomialSpace, - p3_field::PrimeField32, + p3_field::{Field, PrimeField32}, p3_maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator}, p3_util::{log2_ceil_usize, log2_strict_usize}, - prover::types::AirProofInput, - AirRef, Chip, ChipUsageGetter, + prover::{cpu::CpuBackend, types::AirProvingContext}, + Chip, }; use serde::{Deserialize, Serialize}; use self::interface::MemoryInterface; -use super::{ - paged_vec::{AddressMap, PAGE_SIZE}, - volatile::VolatileBoundaryChip, -}; +use super::{volatile::VolatileBoundaryChip, AddressMap}; use crate::{ - arch::{hasher::HasherChip, MemoryConfig}, - system::memory::{ - adapter::AccessAdapterInventory, - dimensions::MemoryDimensions, - merkle::{MemoryMerkleChip, SerialReceiver}, - offline::{MemoryRecord, OfflineMemory, INITIAL_TIMESTAMP}, - offline_checker::{ - MemoryBaseAuxCols, MemoryBridge, MemoryBus, MemoryReadAuxCols, - MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols, AUX_LEN, + arch::{DenseRecordArena, MemoryConfig, ADDR_SPACE_OFFSET}, + system::{ + memory::{ + adapter::AccessAdapterInventory, + dimensions::MemoryDimensions, + merkle::MemoryMerkleChip, + offline_checker::{MemoryBaseAuxCols, MemoryBridge, MemoryBus, AUX_LEN}, + persistent::PersistentBoundaryChip, }, - online::{Memory, MemoryLogEntry}, - persistent::PersistentBoundaryChip, - tree::MemoryNode, + poseidon2::Poseidon2PeripheryChip, + TouchedMemory, }, }; @@ -53,16 +42,13 @@ pub mod dimensions; pub mod interface; pub const CHUNK: usize = 8; + /// The offset of the Merkle AIR in AIRs of MemoryController. pub const MERKLE_AIR_OFFSET: usize = 1; /// The offset of the boundary AIR in AIRs of MemoryController. pub const BOUNDARY_AIR_OFFSET: usize = 0; -#[repr(C)] -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] -pub struct RecordId(pub usize); - -pub type MemoryImage = AddressMap; +pub type MemoryImage = AddressMap; #[repr(C)] #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -71,14 +57,11 @@ pub struct TimestampedValues { pub values: [T; N], } -/// An equipartition of memory, with timestamps and values. +/// A sorted equipartition of memory, with timestamps and values. /// -/// The key is a pair `(address_space, label)`, where `label` is the index of the block in the +/// The "key" is a pair `(address_space, label)`, where `label` is the index of the block in the /// partition. I.e., the starting address of the block is `(address_space, label * N)`. -/// -/// If a key is not present in the map, then the block is uninitialized (and therefore zero). -pub type TimestampedEquipartition = - BTreeMap<(u32, u32), TimestampedValues>; +pub type TimestampedEquipartition = Vec<((u32, u32), TimestampedValues)>; /// An equipartition of memory values. /// @@ -89,69 +72,14 @@ pub type TimestampedEquipartition = pub type Equipartition = BTreeMap<(u32, u32), [F; N]>; #[derive(Getters, MutGetters)] -pub struct MemoryController { +pub struct MemoryController { pub memory_bus: MemoryBus, pub interface_chip: MemoryInterface, - #[getset(get = "pub")] - pub(crate) mem_config: MemoryConfig, pub range_checker: SharedVariableRangeCheckerChip, // Store separately to avoid smart pointer reference each time range_checker_bus: VariableRangeCheckerBus, - // addr_space -> Memory data structure - memory: Memory, - /// A reference to the `OfflineMemory`. Will be populated after `finalize()`. - offline_memory: Arc>>, - pub access_adapters: AccessAdapterInventory, - // Filled during finalization. - final_state: Option>, -} - -#[allow(clippy::large_enum_variant)] -#[derive(Debug)] -enum FinalState { - Volatile(VolatileFinalState), - #[allow(dead_code)] - Persistent(PersistentFinalState), -} -#[derive(Debug, Default)] -struct VolatileFinalState { - _marker: PhantomData, -} -#[allow(dead_code)] -#[derive(Debug)] -struct PersistentFinalState { - final_memory: Equipartition, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub enum MemoryTraceHeights { - Volatile(VolatileMemoryTraceHeights), - Persistent(PersistentMemoryTraceHeights), -} - -impl MemoryTraceHeights { - fn flatten(&self) -> Vec { - match self { - MemoryTraceHeights::Volatile(oh) => oh.flatten(), - MemoryTraceHeights::Persistent(oh) => oh.flatten(), - } - } - - /// Round all trace heights to the next power of two. This will round trace heights of 0 to 1. - pub fn round_to_next_power_of_two(&mut self) { - match self { - MemoryTraceHeights::Volatile(oh) => oh.round_to_next_power_of_two(), - MemoryTraceHeights::Persistent(oh) => oh.round_to_next_power_of_two(), - } - } - - /// Round all trace heights to the next power of two, except 0 stays 0. - pub fn round_to_next_power_of_two_or_zero(&mut self) { - match self { - MemoryTraceHeights::Volatile(oh) => oh.round_to_next_power_of_two_or_zero(), - MemoryTraceHeights::Persistent(oh) => oh.round_to_next_power_of_two_or_zero(), - } - } + pub(crate) access_adapter_inventory: AccessAdapterInventory, + pub(crate) hasher_chip: Option>>, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] @@ -161,24 +89,14 @@ pub struct VolatileMemoryTraceHeights { } impl VolatileMemoryTraceHeights { - pub fn flatten(&self) -> Vec { - iter::once(self.boundary) - .chain(self.access_adapters.iter().copied()) - .collect() - } - - fn round_to_next_power_of_two(&mut self) { - self.boundary = self.boundary.next_power_of_two(); - self.access_adapters - .iter_mut() - .for_each(|v| *v = v.next_power_of_two()); - } - - fn round_to_next_power_of_two_or_zero(&mut self) { - self.boundary = next_power_of_two_or_zero(self.boundary); - self.access_adapters - .iter_mut() - .for_each(|v| *v = next_power_of_two_or_zero(*v)); + /// `heights` must consist of only memory trace heights, in order of AIR IDs. + pub fn from_slice(heights: &[u32]) -> Self { + let boundary = heights[0] as usize; + let access_adapters = heights[1..].iter().map(|&h| h as usize).collect(); + Self { + boundary, + access_adapters, + } } } @@ -189,32 +107,21 @@ pub struct PersistentMemoryTraceHeights { access_adapters: Vec, } impl PersistentMemoryTraceHeights { - pub fn flatten(&self) -> Vec { - vec![self.boundary, self.merkle] - .into_iter() - .chain(self.access_adapters.iter().copied()) - .collect() - } - - fn round_to_next_power_of_two(&mut self) { - self.boundary = self.boundary.next_power_of_two(); - self.merkle = self.merkle.next_power_of_two(); - self.access_adapters - .iter_mut() - .for_each(|v| *v = v.next_power_of_two()); - } - - fn round_to_next_power_of_two_or_zero(&mut self) { - self.boundary = next_power_of_two_or_zero(self.boundary); - self.merkle = next_power_of_two_or_zero(self.merkle); - self.access_adapters - .iter_mut() - .for_each(|v| *v = next_power_of_two_or_zero(*v)); + /// `heights` must consist of only memory trace heights, in order of AIR IDs. + pub fn from_slice(heights: &[u32]) -> Self { + let boundary = heights[0] as usize; + let merkle = heights[1] as usize; + let access_adapters = heights[2..].iter().map(|&h| h as usize).collect(); + Self { + boundary, + merkle, + access_adapters, + } } } impl MemoryController { - pub fn continuation_enabled(&self) -> bool { + pub(crate) fn continuation_enabled(&self) -> bool { match &self.interface_chip { MemoryInterface::Volatile { .. } => false, MemoryInterface::Persistent { .. } => true, @@ -226,15 +133,17 @@ impl MemoryController { range_checker: SharedVariableRangeCheckerChip, ) -> Self { let range_checker_bus = range_checker.bus(); - let initial_memory = AddressMap::from_mem_config(&mem_config); assert!(mem_config.pointer_max_bits <= F::bits() - 2); - assert!(mem_config.as_height < F::bits() - 2); + assert!(mem_config + .addr_spaces + .iter() + .all(|&space| space.num_cells <= (1 << mem_config.pointer_max_bits))); + assert!(mem_config.addr_space_height < F::bits() - 2); let addr_space_max_bits = log2_ceil_usize( - (mem_config.as_offset + 2u32.pow(mem_config.as_height as u32)) as usize, + (ADDR_SPACE_OFFSET + 2u32.pow(mem_config.addr_space_height as u32)) as usize, ); Self { memory_bus, - mem_config, interface_chip: MemoryInterface::Volatile { boundary_chip: VolatileBoundaryChip::new( memory_bus, @@ -243,23 +152,14 @@ impl MemoryController { range_checker.clone(), ), }, - memory: Memory::new(&mem_config), - offline_memory: Arc::new(Mutex::new(OfflineMemory::new( - initial_memory, - 1, - memory_bus, - range_checker.clone(), - mem_config, - ))), - access_adapters: AccessAdapterInventory::new( + access_adapter_inventory: AccessAdapterInventory::new( range_checker.clone(), memory_bus, - mem_config.clk_max_bits, - mem_config.max_access_adapter_n, + mem_config, ), range_checker, range_checker_bus, - final_state: None, + hasher_chip: None, } } @@ -272,12 +172,11 @@ impl MemoryController { range_checker: SharedVariableRangeCheckerChip, merkle_bus: PermutationCheckBus, compression_bus: PermutationCheckBus, + hasher_chip: Arc>, ) -> Self { - assert_eq!(mem_config.as_offset, 1); let memory_dims = MemoryDimensions { - as_height: mem_config.as_height, + addr_space_height: mem_config.addr_space_height, address_height: mem_config.pointer_max_bits - log2_strict_usize(CHUNK), - as_offset: 1, }; let range_checker_bus = range_checker.bus(); let interface_chip = MemoryInterface::Persistent { @@ -292,73 +191,50 @@ impl MemoryController { }; Self { memory_bus, - mem_config, interface_chip, - memory: Memory::new(&mem_config), // it is expected that the memory will be set later - offline_memory: Arc::new(Mutex::new(OfflineMemory::new( - AddressMap::from_mem_config(&mem_config), - CHUNK, - memory_bus, - range_checker.clone(), - mem_config, - ))), - access_adapters: AccessAdapterInventory::new( + access_adapter_inventory: AccessAdapterInventory::new( range_checker.clone(), memory_bus, - mem_config.clk_max_bits, - mem_config.max_access_adapter_n, + mem_config, ), range_checker, range_checker_bus, - final_state: None, + hasher_chip: Some(hasher_chip), } } - pub fn memory_image(&self) -> &MemoryImage { - &self.memory.data + pub fn memory_config(&self) -> &MemoryConfig { + &self.access_adapter_inventory.memory_config } - pub fn set_override_trace_heights(&mut self, overridden_heights: MemoryTraceHeights) { + pub(crate) fn set_override_trace_heights(&mut self, overridden_heights: &[u32]) { match &mut self.interface_chip { - MemoryInterface::Volatile { boundary_chip } => match overridden_heights { - MemoryTraceHeights::Volatile(oh) => { - boundary_chip.set_overridden_height(oh.boundary); - self.access_adapters - .set_override_trace_heights(oh.access_adapters); - } - _ => panic!("Expect overridden_heights to be MemoryTraceHeights::Volatile"), - }, + MemoryInterface::Volatile { boundary_chip } => { + let oh = VolatileMemoryTraceHeights::from_slice(overridden_heights); + boundary_chip.set_overridden_height(oh.boundary); + self.access_adapter_inventory + .set_override_trace_heights(oh.access_adapters); + } MemoryInterface::Persistent { boundary_chip, merkle_chip, .. - } => match overridden_heights { - MemoryTraceHeights::Persistent(oh) => { - boundary_chip.set_overridden_height(oh.boundary); - merkle_chip.set_overridden_height(oh.merkle); - self.access_adapters - .set_override_trace_heights(oh.access_adapters); - } - _ => panic!("Expect overridden_heights to be MemoryTraceHeights::Persistent"), - }, + } => { + let oh = PersistentMemoryTraceHeights::from_slice(overridden_heights); + boundary_chip.set_overridden_height(oh.boundary); + merkle_chip.set_overridden_height(oh.merkle); + self.access_adapter_inventory + .set_override_trace_heights(oh.access_adapters); + } } } - pub fn set_initial_memory(&mut self, memory: MemoryImage) { - if self.timestamp() > INITIAL_TIMESTAMP + 1 { - panic!("Cannot set initial memory after first timestamp"); - } - let mut offline_memory = self.offline_memory.lock().unwrap(); - offline_memory.set_initial_memory(memory.clone(), self.mem_config); - - self.memory = Memory::from_image(memory.clone(), self.mem_config.access_capacity); - + /// This only sets the initial memory image for the persistent boundary and merkle tree chips. + /// Tracing memory should be set separately. + pub(crate) fn set_initial_memory(&mut self, memory: AddressMap) { match &mut self.interface_chip { MemoryInterface::Volatile { .. } => { - assert!( - memory.is_empty(), - "Cannot set initial memory for volatile memory" - ); + // Skip initialization for volatile memory } MemoryInterface::Persistent { initial_memory, .. } => { *initial_memory = memory; @@ -369,207 +245,68 @@ impl MemoryController { pub fn memory_bridge(&self) -> MemoryBridge { MemoryBridge::new( self.memory_bus, - self.mem_config.clk_max_bits, + self.memory_config().timestamp_max_bits, self.range_checker_bus, ) } - pub fn read_cell(&mut self, address_space: F, pointer: F) -> (RecordId, F) { - let (record_id, [data]) = self.read(address_space, pointer); - (record_id, data) - } - - pub fn read(&mut self, address_space: F, pointer: F) -> (RecordId, [F; N]) { - let address_space_u32 = address_space.as_canonical_u32(); - let ptr_u32 = pointer.as_canonical_u32(); - assert!( - address_space == F::ZERO || ptr_u32 < (1 << self.mem_config.pointer_max_bits), - "memory out of bounds: {ptr_u32:?}", - ); - - let (record_id, values) = self.memory.read::(address_space_u32, ptr_u32); - - (record_id, values) - } - - /// Reads a word directly from memory without updating internal state. - /// - /// Any value returned is unconstrained. - pub fn unsafe_read_cell(&self, addr_space: F, ptr: F) -> F { - self.unsafe_read::<1>(addr_space, ptr)[0] - } - - /// Reads a word directly from memory without updating internal state. - /// - /// Any value returned is unconstrained. - pub fn unsafe_read(&self, addr_space: F, ptr: F) -> [F; N] { - let addr_space = addr_space.as_canonical_u32(); - let ptr = ptr.as_canonical_u32(); - array::from_fn(|i| self.memory.get(addr_space, ptr + i as u32)) - } - - /// Writes `data` to the given cell. - /// - /// Returns the `RecordId` and previous data. - pub fn write_cell(&mut self, address_space: F, pointer: F, data: F) -> (RecordId, F) { - let (record_id, [data]) = self.write(address_space, pointer, [data]); - (record_id, data) - } - - pub fn write( - &mut self, - address_space: F, - pointer: F, - data: [F; N], - ) -> (RecordId, [F; N]) { - assert_ne!(address_space, F::ZERO); - let address_space_u32 = address_space.as_canonical_u32(); - let ptr_u32 = pointer.as_canonical_u32(); - assert!( - ptr_u32 < (1 << self.mem_config.pointer_max_bits), - "memory out of bounds: {ptr_u32:?}", - ); - - self.memory.write(address_space_u32, ptr_u32, data) - } - - pub fn aux_cols_factory(&self) -> MemoryAuxColsFactory { + pub fn helper(&self) -> SharedMemoryHelper { let range_bus = self.range_checker.bus(); - MemoryAuxColsFactory { + SharedMemoryHelper { range_checker: self.range_checker.clone(), - timestamp_lt_air: AssertLtSubAir::new(range_bus, self.mem_config.clk_max_bits), + timestamp_lt_air: AssertLtSubAir::new( + range_bus, + self.memory_config().timestamp_max_bits, + ), _marker: Default::default(), } } - pub fn increment_timestamp(&mut self) { - self.memory.increment_timestamp_by(1); - } - - pub fn increment_timestamp_by(&mut self, change: u32) { - self.memory.increment_timestamp_by(change); - } - - pub fn timestamp(&self) -> u32 { - self.memory.timestamp() - } - - fn replay_access_log(&mut self) { - let log = mem::take(&mut self.memory.log); - if log.is_empty() { - // Online memory logs may be empty, but offline memory may be replayed from external - // sources. In these cases, we skip the calls to replay access logs because - // `set_log_capacity` would panic. - tracing::debug!("skipping replay_access_log"); - return; - } - - let mut offline_memory = self.offline_memory.lock().unwrap(); - offline_memory.set_log_capacity(log.len()); - - for entry in log { - Self::replay_access( - entry, - &mut offline_memory, - &mut self.interface_chip, - &mut self.access_adapters, - ); - } - } - - /// Low-level API to replay a single memory access log entry and populate the [OfflineMemory], - /// [MemoryInterface], and `AccessAdapterInventory`. - pub fn replay_access( - entry: MemoryLogEntry, - offline_memory: &mut OfflineMemory, - interface_chip: &mut MemoryInterface, - adapter_records: &mut AccessAdapterInventory, - ) { - match entry { - MemoryLogEntry::Read { - address_space, - pointer, - len, - } => { - if address_space != 0 { - interface_chip.touch_range(address_space, pointer, len as u32); - } - offline_memory.read(address_space, pointer, len, adapter_records); - } - MemoryLogEntry::Write { - address_space, - pointer, - data, - } => { - if address_space != 0 { - interface_chip.touch_range(address_space, pointer, data.len() as u32); - } - offline_memory.write(address_space, pointer, data, adapter_records); - } - MemoryLogEntry::IncrementTimestampBy(amount) => { - offline_memory.increment_timestamp_by(amount); - } - }; - } - - /// Returns the final memory state if persistent. - pub fn finalize(&mut self, hasher: Option<&mut H>) + // @dev: Memory is complicated and allowed to break all the rules (e.g., 1 arena per chip) and + // there's no need for any memory chip to implement the Chip trait. We do it when convenient, + // but all that matters is that you can tracegen all the trace matrices for the memory AIRs + // _somehow_. + pub fn generate_proving_ctx( + &mut self, + access_adapter_records: DenseRecordArena, + touched_memory: TouchedMemory, + ) -> Vec>> where - H: HasherChip + Sync + for<'a> SerialReceiver<&'a [F]>, + Domain: PolynomialSpace, { - if self.final_state.is_some() { - return; - } - - self.replay_access_log(); - let mut offline_memory = self.offline_memory.lock().unwrap(); - - match &mut self.interface_chip { - MemoryInterface::Volatile { boundary_chip } => { - let final_memory = offline_memory.finalize::<1>(&mut self.access_adapters); + match (&mut self.interface_chip, touched_memory) { + ( + MemoryInterface::Volatile { boundary_chip }, + TouchedMemory::Volatile(final_memory), + ) => { boundary_chip.finalize(final_memory); - self.final_state = Some(FinalState::Volatile(VolatileFinalState::default())); } - MemoryInterface::Persistent { - merkle_chip, - boundary_chip, - initial_memory, - } => { - let hasher = hasher.unwrap(); - let final_partition = offline_memory.finalize::(&mut self.access_adapters); - - boundary_chip.finalize(initial_memory, &final_partition, hasher); - let final_memory_values = final_partition + ( + MemoryInterface::Persistent { + boundary_chip, + merkle_chip, + initial_memory, + }, + TouchedMemory::Persistent(final_memory), + ) => { + let hasher = self.hasher_chip.as_ref().unwrap(); + boundary_chip.finalize(initial_memory, &final_memory, hasher.as_ref()); + let final_memory_values = final_memory .into_par_iter() .map(|(key, value)| (key, value.values)) .collect(); - let initial_node = MemoryNode::tree_from_memory( - merkle_chip.air.memory_dimensions, - initial_memory, - hasher, - ); - merkle_chip.finalize(&initial_node, &final_memory_values, hasher); - self.final_state = Some(FinalState::Persistent(PersistentFinalState { - final_memory: final_memory_values.clone(), - })); + merkle_chip.finalize(initial_memory, &final_memory_values, hasher.as_ref()); } - }; - } + _ => panic!("TouchedMemory incorrect type"), + } - pub fn generate_air_proof_inputs(self) -> Vec> - where - Domain: PolynomialSpace, - { let mut ret = Vec::new(); - let Self { - interface_chip, - access_adapters, - .. - } = self; - match interface_chip { + let access_adapters = &mut self.access_adapter_inventory; + access_adapters.set_arena(access_adapter_records); + match &mut self.interface_chip { MemoryInterface::Volatile { boundary_chip } => { - ret.push(boundary_chip.generate_air_proof_input()); + ret.push(boundary_chip.generate_proving_ctx(())); } MemoryInterface::Persistent { merkle_chip, @@ -577,191 +314,66 @@ impl MemoryController { .. } => { debug_assert_eq!(ret.len(), BOUNDARY_AIR_OFFSET); - ret.push(boundary_chip.generate_air_proof_input()); + ret.push(boundary_chip.generate_proving_ctx(())); debug_assert_eq!(ret.len(), MERKLE_AIR_OFFSET); - ret.push(merkle_chip.generate_air_proof_input()); + ret.push(merkle_chip.generate_proving_ctx()); } } - ret.extend(access_adapters.generate_air_proof_inputs()); + ret.extend(access_adapters.generate_proving_ctx()); ret } - pub fn airs(&self) -> Vec> - where - Domain: PolynomialSpace, - { - let mut airs = Vec::>::new(); - - match &self.interface_chip { - MemoryInterface::Volatile { boundary_chip } => { - debug_assert_eq!(airs.len(), BOUNDARY_AIR_OFFSET); - airs.push(boundary_chip.air()) - } - MemoryInterface::Persistent { - boundary_chip, - merkle_chip, - .. - } => { - debug_assert_eq!(airs.len(), BOUNDARY_AIR_OFFSET); - airs.push(boundary_chip.air()); - debug_assert_eq!(airs.len(), MERKLE_AIR_OFFSET); - airs.push(merkle_chip.air()); - } - } - airs.extend(self.access_adapters.airs()); - - airs - } - /// Return the number of AIRs in the memory controller. pub fn num_airs(&self) -> usize { let mut num_airs = 1; if self.continuation_enabled() { num_airs += 1; } - num_airs += self.access_adapters.num_access_adapters(); + num_airs += self.access_adapter_inventory.num_access_adapters(); num_airs } - - pub fn air_names(&self) -> Vec { - let mut air_names = vec!["Boundary".to_string()]; - if self.continuation_enabled() { - air_names.push("Merkle".to_string()); - } - air_names.extend(self.access_adapters.air_names()); - air_names - } - - pub fn current_trace_heights(&self) -> Vec { - self.get_memory_trace_heights().flatten() - } - - pub fn get_memory_trace_heights(&self) -> MemoryTraceHeights { - let access_adapters = self.access_adapters.get_heights(); - match &self.interface_chip { - MemoryInterface::Volatile { boundary_chip } => { - MemoryTraceHeights::Volatile(VolatileMemoryTraceHeights { - boundary: boundary_chip.current_trace_height(), - access_adapters, - }) - } - MemoryInterface::Persistent { - boundary_chip, - merkle_chip, - .. - } => MemoryTraceHeights::Persistent(PersistentMemoryTraceHeights { - boundary: boundary_chip.current_trace_height(), - merkle: merkle_chip.current_trace_height(), - access_adapters, - }), - } - } - - pub fn get_dummy_memory_trace_heights(&self) -> MemoryTraceHeights { - let access_adapters = vec![1; self.access_adapters.num_access_adapters()]; - match &self.interface_chip { - MemoryInterface::Volatile { .. } => { - MemoryTraceHeights::Volatile(VolatileMemoryTraceHeights { - boundary: 1, - access_adapters, - }) - } - MemoryInterface::Persistent { .. } => { - MemoryTraceHeights::Persistent(PersistentMemoryTraceHeights { - boundary: 1, - merkle: 1, - access_adapters, - }) - } - } - } - - pub fn current_trace_cells(&self) -> Vec { - let mut ret = Vec::new(); - match &self.interface_chip { - MemoryInterface::Volatile { boundary_chip } => { - ret.push(boundary_chip.current_trace_cells()) - } - MemoryInterface::Persistent { - boundary_chip, - merkle_chip, - .. - } => { - ret.push(boundary_chip.current_trace_cells()); - ret.push(merkle_chip.current_trace_cells()); - } - } - ret.extend(self.access_adapters.get_cells()); - ret - } - - /// Returns a reference to the offline memory. - /// - /// Until `finalize` is called, the `OfflineMemory` does not contain useful state, and should - /// therefore not be used by any chip during execution. However, to obtain a reference to the - /// offline memory that will be useful in trace generation, a chip can call `offline_memory()` - /// and store the returned reference for later use. - pub fn offline_memory(&self) -> Arc>> { - self.offline_memory.clone() - } - pub fn get_memory_logs(&self) -> &Vec> { - &self.memory.log - } - pub fn set_memory_logs(&mut self, logs: Vec>) { - self.memory.log = logs; - } - pub fn take_memory_logs(&mut self) -> Vec> { - std::mem::take(&mut self.memory.log) - } } -pub struct MemoryAuxColsFactory { +/// Owned version of [MemoryAuxColsFactory]. +#[derive(Clone)] +pub struct SharedMemoryHelper { pub(crate) range_checker: SharedVariableRangeCheckerChip, pub(crate) timestamp_lt_air: AssertLtSubAir, - pub(crate) _marker: PhantomData, + pub(crate) _marker: PhantomData, } -// NOTE[jpw]: The `make_*_aux_cols` functions should be thread-safe so they can be used in -// parallelized trace generation. -impl MemoryAuxColsFactory { - pub fn generate_read_aux(&self, read: &MemoryRecord, buffer: &mut MemoryReadAuxCols) { - assert!( - !read.address_space.is_zero(), - "cannot make `MemoryReadAuxCols` for address space 0" - ); - self.generate_base_aux(read, &mut buffer.base); +impl SharedMemoryHelper { + pub fn new(range_checker: SharedVariableRangeCheckerChip, timestamp_max_bits: usize) -> Self { + let timestamp_lt_air = AssertLtSubAir::new(range_checker.bus(), timestamp_max_bits); + Self { + range_checker, + timestamp_lt_air, + _marker: PhantomData, + } } +} - pub fn generate_read_or_immediate_aux( - &self, - read: &MemoryRecord, - buffer: &mut MemoryReadOrImmediateAuxCols, - ) { - IsZeroSubAir.generate_subrow( - read.address_space, - (&mut buffer.is_zero_aux, &mut buffer.is_immediate), - ); - self.generate_base_aux(read, &mut buffer.base); - } +/// A helper for generating trace values in auxiliary memory columns related to the offline memory +/// argument. +pub struct MemoryAuxColsFactory<'a, F> { + pub(crate) range_checker: &'a VariableRangeCheckerChip, + pub(crate) timestamp_lt_air: AssertLtSubAir, + pub(crate) _marker: PhantomData, +} - pub fn generate_write_aux( - &self, - write: &MemoryRecord, - buffer: &mut MemoryWriteAuxCols, - ) { - buffer - .prev_data - .copy_from_slice(write.prev_data_slice().unwrap()); - self.generate_base_aux(write, &mut buffer.base); +impl MemoryAuxColsFactory<'_, F> { + /// Fill the trace assuming `prev_timestamp` is already provided in `buffer`. + pub fn fill(&self, prev_timestamp: u32, timestamp: u32, buffer: &mut MemoryBaseAuxCols) { + self.generate_timestamp_lt(prev_timestamp, timestamp, &mut buffer.timestamp_lt_aux); + // Safety: even if prev_timestamp were obtained by transmute_ref from + // `buffer.prev_timestamp`, this should still work because it is a direct assignment + buffer.prev_timestamp = F::from_canonical_u32(prev_timestamp); } - pub fn generate_base_aux(&self, record: &MemoryRecord, buffer: &mut MemoryBaseAuxCols) { - buffer.prev_timestamp = F::from_canonical_u32(record.prev_timestamp); - self.generate_timestamp_lt( - record.prev_timestamp, - record.timestamp, - &mut buffer.timestamp_lt_aux, - ); + /// # Safety + /// We assume that `F::ZERO` has underlying memory equivalent to `mem::zeroed()`. + pub fn fill_zero(&self, buffer: &mut MemoryBaseAuxCols) { + *buffer = unsafe { std::mem::zeroed() }; } fn generate_timestamp_lt( @@ -770,102 +382,23 @@ impl MemoryAuxColsFactory { timestamp: u32, buffer: &mut LessThanAuxCols, ) { - debug_assert!(prev_timestamp < timestamp); - self.timestamp_lt_air.generate_subrow( - (self.range_checker.as_ref(), prev_timestamp, timestamp), - &mut buffer.lower_decomp, + debug_assert!( + prev_timestamp < timestamp, + "prev_timestamp {prev_timestamp} >= timestamp {timestamp}" ); - } - - /// In general, prefer `generate_read_aux` which writes in-place rather than this function. - pub fn make_read_aux_cols(&self, read: &MemoryRecord) -> MemoryReadAuxCols { - assert!( - !read.address_space.is_zero(), - "cannot make `MemoryReadAuxCols` for address space 0" - ); - MemoryReadAuxCols::new( - read.prev_timestamp, - self.generate_timestamp_lt_cols(read.prev_timestamp, read.timestamp), - ) - } - - /// In general, prefer `generate_write_aux` which writes in-place rather than this function. - pub fn make_write_aux_cols( - &self, - write: &MemoryRecord, - ) -> MemoryWriteAuxCols { - let prev_data = write.prev_data_slice().unwrap(); - MemoryWriteAuxCols::new( - prev_data.try_into().unwrap(), - F::from_canonical_u32(write.prev_timestamp), - self.generate_timestamp_lt_cols(write.prev_timestamp, write.timestamp), - ) - } - - fn generate_timestamp_lt_cols( - &self, - prev_timestamp: u32, - timestamp: u32, - ) -> LessThanAuxCols { - debug_assert!(prev_timestamp < timestamp); - let mut decomp = [F::ZERO; AUX_LEN]; self.timestamp_lt_air.generate_subrow( - (self.range_checker.as_ref(), prev_timestamp, timestamp), - &mut decomp, + (self.range_checker, prev_timestamp, timestamp), + &mut buffer.lower_decomp, ); - LessThanAuxCols::new(decomp) } } -#[cfg(test)] -mod tests { - use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, - }; - use openvm_stark_backend::{interaction::BusIndex, p3_field::FieldAlgebra}; - use openvm_stark_sdk::p3_baby_bear::BabyBear; - use rand::{prelude::SliceRandom, thread_rng, Rng}; - - use super::MemoryController; - use crate::{ - arch::{testing::MEMORY_BUS, MemoryConfig}, - system::memory::offline_checker::MemoryBus, - }; - - const RANGE_CHECKER_BUS: BusIndex = 3; - - #[test] - fn test_no_adapter_records_for_singleton_accesses() { - type F = BabyBear; - - let memory_bus = MemoryBus::new(MEMORY_BUS); - let memory_config = MemoryConfig::default(); - let range_bus = VariableRangeCheckerBus::new(RANGE_CHECKER_BUS, memory_config.decomp); - let range_checker = SharedVariableRangeCheckerChip::new(range_bus); - - let mut memory_controller = MemoryController::with_volatile_memory( - memory_bus, - memory_config, - range_checker.clone(), - ); - - let mut rng = thread_rng(); - for _ in 0..1000 { - let address_space = F::from_canonical_u32(*[1, 2].choose(&mut rng).unwrap()); - let pointer = - F::from_canonical_u32(rng.gen_range(0..1 << memory_config.pointer_max_bits)); - - if rng.gen_bool(0.5) { - let data = F::from_canonical_u32(rng.gen_range(0..1 << 30)); - memory_controller.write(address_space, pointer, [data]); - } else { - memory_controller.read::<1>(address_space, pointer); - } +impl SharedMemoryHelper { + pub fn as_borrowed(&self) -> MemoryAuxColsFactory<'_, F> { + MemoryAuxColsFactory { + range_checker: self.range_checker.as_ref(), + timestamp_lt_air: self.timestamp_lt_air, + _marker: PhantomData, } - assert!(memory_controller - .access_adapters - .get_heights() - .iter() - .all(|&h| h == 0)); } } diff --git a/crates/vm/src/system/memory/merkle/mod.rs b/crates/vm/src/system/memory/merkle/mod.rs index 74f8951bc4..6d974ddde0 100644 --- a/crates/vm/src/system/memory/merkle/mod.rs +++ b/crates/vm/src/system/memory/merkle/mod.rs @@ -1,27 +1,39 @@ -use openvm_stark_backend::{interaction::PermutationCheckBus, p3_field::PrimeField32}; -use rustc_hash::FxHashSet; +use std::array; + +use openvm_stark_backend::{ + interaction::PermutationCheckBus, p3_field::PrimeField32, p3_maybe_rayon::prelude::*, +}; + +use super::{controller::dimensions::MemoryDimensions, online::LinearMemory}; +use crate::{ + arch::AddressSpaceHostLayout, + system::memory::{online::PAGE_SIZE, AddressMap}, +}; -use super::controller::dimensions::MemoryDimensions; mod air; mod columns; +pub mod public_values; mod trace; +mod tree; pub use air::*; pub use columns::*; pub(super) use trace::SerialReceiver; +pub use tree::*; #[cfg(test)] mod tests; pub struct MemoryMerkleChip { pub air: MemoryMerkleAir, - touched_nodes: FxHashSet<(usize, u32, u32)>, - num_touched_nonleaves: usize, final_state: Option>, overridden_height: Option, + /// Used for metric collection purposes only + #[cfg(feature = "metrics")] + pub(crate) current_height: usize, } #[derive(Debug)] -struct FinalState { +pub struct FinalState { rows: Vec>, init_root: [F; CHUNK], final_root: [F; CHUNK], @@ -35,46 +47,76 @@ impl MemoryMerkleChip { merkle_bus: PermutationCheckBus, compression_bus: PermutationCheckBus, ) -> Self { - assert!(memory_dimensions.as_height > 0); + assert!(memory_dimensions.addr_space_height > 0); assert!(memory_dimensions.address_height > 0); - let mut touched_nodes = FxHashSet::default(); - touched_nodes.insert((memory_dimensions.overall_height(), 0, 0)); Self { air: MemoryMerkleAir { memory_dimensions, merkle_bus, compression_bus, }, - touched_nodes, - num_touched_nonleaves: 1, final_state: None, overridden_height: None, + #[cfg(feature = "metrics")] + current_height: 0, } } pub fn set_overridden_height(&mut self, override_height: usize) { self.overridden_height = Some(override_height); } +} - fn touch_node(&mut self, height: usize, as_label: u32, address_label: u32) { - if self.touched_nodes.insert((height, as_label, address_label)) { - assert_ne!(height, self.air.memory_dimensions.overall_height()); - if height != 0 { - self.num_touched_nonleaves += 1; - } - if height >= self.air.memory_dimensions.address_height { - self.touch_node(height + 1, as_label / 2, address_label); - } else { - self.touch_node(height + 1, as_label, address_label / 2); - } - } - } +#[tracing::instrument(level = "info", skip_all)] +fn memory_to_vec_partition( + memory: &AddressMap, + md: &MemoryDimensions, +) -> Vec<(u64, [F; N])> { + (0..memory.mem.len()) + .into_par_iter() + .map(move |as_idx| { + let space_mem = memory.mem[as_idx].as_slice(); + let addr_space_layout = memory.config[as_idx].layout; + let cell_size = addr_space_layout.size(); + debug_assert_eq!(PAGE_SIZE % (cell_size * N), 0); - pub fn touch_range(&mut self, address_space: u32, address: u32, len: u32) { - let as_label = address_space - self.air.memory_dimensions.as_offset; - let first_address_label = address / CHUNK as u32; - let last_address_label = (address + len - 1) / CHUNK as u32; - for address_label in first_address_label..=last_address_label { - self.touch_node(0, as_label, address_label); - } - } + let num_nonzero_pages = space_mem + .par_chunks(PAGE_SIZE) + .enumerate() + .flat_map(|(idx, page)| { + if page.iter().any(|x| *x != 0) { + Some(idx + 1) + } else { + None + } + }) + .max() + .unwrap_or(0); + + let space_mem = &space_mem[..(num_nonzero_pages * PAGE_SIZE).min(space_mem.len())]; + let mut num_elements = space_mem.len() / (cell_size * N); + // virtual memory may be larger than dimensions due to rounding up to page size + num_elements = num_elements.min(1 << md.address_height); + + (0..num_elements) + .into_par_iter() + .map(move |idx| { + ( + md.label_to_index((as_idx as u32, idx as u32)), + array::from_fn(|i| unsafe { + // SAFETY: idx < num_elements = space_mem.len() / (cell_size * N) so ptr + // is within bounds. We are reading one cell at a time, so alignment is + // guaranteed. + let ptr: *const u8 = + space_mem.as_ptr().add(idx * cell_size * N + i * cell_size); + addr_space_layout + .to_field(&*core::ptr::slice_from_raw_parts(ptr, cell_size)) + }), + ) + }) + .collect::>() + }) + .collect::>() + .into_iter() + .flatten() + .collect::>() } diff --git a/crates/vm/src/system/memory/tree/public_values.rs b/crates/vm/src/system/memory/merkle/public_values.rs similarity index 65% rename from crates/vm/src/system/memory/tree/public_values.rs rename to crates/vm/src/system/memory/merkle/public_values.rs index 1c6866b959..e0f079c799 100644 --- a/crates/vm/src/system/memory/tree/public_values.rs +++ b/crates/vm/src/system/memory/merkle/public_values.rs @@ -1,17 +1,17 @@ -use std::{collections::BTreeMap, sync::Arc}; - use openvm_stark_backend::{p3_field::PrimeField32, p3_util::log2_strict_usize}; use serde::{Deserialize, Serialize}; use thiserror::Error; +use tracing::instrument; use crate::{ - arch::hasher::Hasher, + arch::{hasher::Hasher, MemoryCellType, ADDR_SPACE_OFFSET}, system::memory::{ - dimensions::MemoryDimensions, paged_vec::Address, tree::MemoryNode, MemoryImage, + dimensions::MemoryDimensions, merkle::tree::MerkleTree, online::LinearMemory, MemoryImage, }, }; -pub const PUBLIC_VALUES_ADDRESS_SPACE_OFFSET: u32 = 2; +pub const PUBLIC_VALUES_AS: u32 = 3; +pub const PUBLIC_VALUES_ADDRESS_SPACE_OFFSET: u32 = PUBLIC_VALUES_AS - ADDR_SPACE_OFFSET; /// Merkle proof for user public values in the memory state. #[derive(Clone, Debug, Serialize, Deserialize)] @@ -47,11 +47,14 @@ impl UserPublicValuesProof { /// Computes the proof of the public values from the final memory state. /// Assumption: /// - `num_public_values` is a power of two * CHUNK. It cannot be 0. + // TODO[jpw]: this currently reconstructs the merkle tree from final memory; we should avoid + // this. We should make this a function within SystemChipComplex + #[instrument(name = "compute_user_public_values_proof", skip_all)] pub fn compute( memory_dimensions: MemoryDimensions, num_public_values: usize, hasher: &(impl Hasher + Sync), - final_memory: &MemoryImage, + final_memory: &MemoryImage, ) -> Self { let proof = compute_merkle_proof_to_user_public_values_root( memory_dimensions, @@ -59,8 +62,7 @@ impl UserPublicValuesProof { hasher, final_memory, ); - let public_values = - extract_public_values(&memory_dimensions, num_public_values, final_memory); + let public_values = extract_public_values(num_public_values, final_memory); let public_values_commit = hasher.merkle_root(&public_values); UserPublicValuesProof { proof, @@ -81,7 +83,7 @@ impl UserPublicValuesProof { // 2. Compare user public values commitment with Merkle root of user public values. let pv_commit = self.public_values_commit; // 0. - let pv_as = PUBLIC_VALUES_ADDRESS_SPACE_OFFSET + memory_dimensions.as_offset; + let pv_as = PUBLIC_VALUES_AS; let pv_start_idx = memory_dimensions.label_to_index((pv_as, 0)); let pvs = &self.public_values; if pvs.len() % CHUNK != 0 || !(pvs.len() / CHUNK).is_power_of_two() { @@ -121,14 +123,14 @@ fn compute_merkle_proof_to_user_public_values_root + Sync), - final_memory: &MemoryImage, + final_memory: &MemoryImage, ) -> Vec<[F; CHUNK]> { assert_eq!( num_public_values % CHUNK, 0, "num_public_values must be a multiple of memory chunk {CHUNK}" ); - let root = MemoryNode::tree_from_memory(memory_dimensions, final_memory, hasher); + let tree = MerkleTree::::from_memory(final_memory, &memory_dimensions, hasher); let num_pv_chunks: usize = num_public_values / CHUNK; // This enforces the number of public values cannot be 0. assert!( @@ -138,63 +140,50 @@ fn compute_merkle_proof_to_user_public_values_root( - memory_dimensions: &MemoryDimensions, num_public_values: usize, - final_memory: &MemoryImage, + final_memory: &MemoryImage, ) -> Vec { - // All (addr, value) pairs in the public value address space. - let f_as_start = PUBLIC_VALUES_ADDRESS_SPACE_OFFSET + memory_dimensions.as_offset; - let f_as_end = PUBLIC_VALUES_ADDRESS_SPACE_OFFSET + memory_dimensions.as_offset + 1; - - // This clones the entire memory. Ideally this should run in time proportional to - // the size of the PV address space, not entire memory. - let final_memory: BTreeMap = final_memory.items().collect(); - - let used_pvs: Vec<_> = final_memory - .range((f_as_start, 0)..(f_as_end, 0)) - .map(|(&(_, pointer), &value)| (pointer as usize, value)) - .collect(); - if let Some(&last_pv) = used_pvs.last() { - assert!( - last_pv.0 < num_public_values || last_pv.1 == F::ZERO, - "Last public value is out of bounds" + let mut public_values: Vec = { + assert_eq!( + final_memory.config[PUBLIC_VALUES_AS as usize].layout, + MemoryCellType::U8 ); - } - let mut public_values = F::zero_vec(num_public_values); - for (i, pv) in used_pvs { - if i < num_public_values { - public_values[i] = pv; - } - } + final_memory.mem[PUBLIC_VALUES_AS as usize] + .as_slice() + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect() + }; + + assert!( + public_values.len() >= num_public_values, + "Public values address space has {} elements, but configuration has num_public_values={}", + public_values.len(), + num_public_values + ); + public_values.truncate(num_public_values); public_values } @@ -203,27 +192,32 @@ mod tests { use openvm_stark_backend::p3_field::FieldAlgebra; use openvm_stark_sdk::p3_baby_bear::BabyBear; - use super::{UserPublicValuesProof, PUBLIC_VALUES_ADDRESS_SPACE_OFFSET}; + use super::UserPublicValuesProof; use crate::{ - arch::{hasher::poseidon2::vm_poseidon2_hasher, SystemConfig}, - system::memory::{paged_vec::AddressMap, tree::MemoryNode, CHUNK}, + arch::{hasher::poseidon2::vm_poseidon2_hasher, MemoryConfig, SystemConfig}, + system::memory::{ + merkle::{public_values::PUBLIC_VALUES_AS, tree::MerkleTree}, + online::GuestMemory, + AddressMap, CHUNK, + }, }; type F = BabyBear; #[test] fn test_public_value_happy_path() { let mut vm_config = SystemConfig::default(); - vm_config.memory_config.as_height = 4; + vm_config.memory_config.addr_space_height = 4; vm_config.memory_config.pointer_max_bits = 5; let memory_dimensions = vm_config.memory_config.memory_dimensions(); - let pv_as = PUBLIC_VALUES_ADDRESS_SPACE_OFFSET + memory_dimensions.as_offset; let num_public_values = 16; - let memory = AddressMap::from_iter( - memory_dimensions.as_offset, - 1 << memory_dimensions.as_height, - 1 << memory_dimensions.address_height, - [((pv_as, 15), F::ONE)], - ); + let mut addr_spaces_config = MemoryConfig::empty_address_space_configs(4); + addr_spaces_config[PUBLIC_VALUES_AS as usize].num_cells = num_public_values; + let mut memory = GuestMemory { + memory: AddressMap::new(addr_spaces_config), + }; + unsafe { + memory.write::(PUBLIC_VALUES_AS, 12, [0, 0, 0, 1]); + } let mut expected_pvs = F::zero_vec(num_public_values); expected_pvs[15] = F::ONE; @@ -232,12 +226,13 @@ mod tests { memory_dimensions, num_public_values, &hasher, - &memory, + &memory.memory, ); assert_eq!(pv_proof.public_values, expected_pvs); - let final_memory_root = MemoryNode::tree_from_memory(memory_dimensions, &memory, &hasher); + let final_memory_root = + MerkleTree::from_memory(&memory.memory, &memory_dimensions, &hasher).root(); pv_proof - .verify(&hasher, memory_dimensions, final_memory_root.hash()) + .verify(&hasher, memory_dimensions, final_memory_root) .unwrap(); } } diff --git a/crates/vm/src/system/memory/merkle/tests/mod.rs b/crates/vm/src/system/memory/merkle/tests/mod.rs index 05c966dc23..09d996393e 100644 --- a/crates/vm/src/system/memory/merkle/tests/mod.rs +++ b/crates/vm/src/system/memory/merkle/tests/mod.rs @@ -1,7 +1,7 @@ use std::{ array, borrow::BorrowMut, - collections::{BTreeSet, HashSet}, + collections::{BTreeMap, BTreeSet, HashSet}, sync::Arc, }; @@ -9,8 +9,7 @@ use openvm_stark_backend::{ interaction::{PermutationCheckBus, PermutationInteractionType}, p3_field::FieldAlgebra, p3_matrix::dense::RowMajorMatrix, - prover::types::AirProofInput, - Chip, ChipUsageGetter, + prover::types::AirProvingContext, }; use openvm_stark_sdk::{ config::baby_bear_poseidon2::BabyBearPoseidon2Engine, @@ -20,84 +19,83 @@ use openvm_stark_sdk::{ use rand::RngCore; use crate::{ - arch::testing::{MEMORY_MERKLE_BUS, POSEIDON2_DIRECT_BUS}, + arch::{ + testing::{MEMORY_MERKLE_BUS, POSEIDON2_DIRECT_BUS}, + AddressSpaceHostConfig, MemoryCellType, MemoryConfig, ADDR_SPACE_OFFSET, + }, system::memory::{ merkle::{ - columns::MemoryMerkleCols, tests::util::HashTestChip, MemoryDimensions, - MemoryMerkleChip, + memory_to_vec_partition, tests::util::HashTestChip, MemoryDimensions, MemoryMerkleChip, + MemoryMerkleCols, MerkleTree, }, - paged_vec::{AddressMap, PAGE_SIZE}, - tree::MemoryNode, - Equipartition, MemoryImage, + online::{GuestMemory, LinearMemory}, + AddressMap, MemoryImage, }, }; mod util; -const DEFAULT_CHUNK: usize = 8; +const CHUNK: usize = 8; const COMPRESSION_BUS: PermutationCheckBus = PermutationCheckBus::new(POSEIDON2_DIRECT_BUS); +type F = BabyBear; -fn test( +fn test( memory_dimensions: MemoryDimensions, - initial_memory: &MemoryImage, + initial_memory: &MemoryImage, touched_labels: BTreeSet<(u32, u32)>, - final_memory: &MemoryImage, + final_memory: &MemoryImage, ) { let MemoryDimensions { - as_height, + addr_space_height, address_height, - as_offset, } = memory_dimensions; + let merkle_bus = PermutationCheckBus::new(MEMORY_MERKLE_BUS); - // checking validity of test data - for ((address_space, pointer), value) in final_memory.items() { - let label = pointer / CHUNK as u32; - assert!(address_space - as_offset < (1 << as_height)); - assert!(pointer < ((CHUNK << address_height).div_ceil(PAGE_SIZE) * PAGE_SIZE) as u32); - if initial_memory.get(&(address_space, pointer)) != Some(&value) { - assert!(touched_labels.contains(&(address_space, label))); - } - } - for key in initial_memory.items().map(|(key, _)| key) { - assert!(final_memory.get(&key).is_some()); - } - for &(address_space, label) in touched_labels.iter() { - let mut contains_some_key = false; - for i in 0..CHUNK { - if final_memory - .get(&(address_space, label * CHUNK as u32 + i as u32)) - .is_some() - { - contains_some_key = true; - break; + for address_space in 0..final_memory.config.len() { + for pointer in 0..final_memory.mem[address_space].size() / 4 { + if unsafe { + initial_memory.get_f::(address_space as u32, pointer as u32) + != final_memory.get_f(address_space as u32, pointer as u32) + } { + let label = (pointer / CHUNK) as u32; + assert!(address_space - (ADDR_SPACE_OFFSET as usize) < (1 << addr_space_height)); + assert!(pointer < (CHUNK << address_height)); + assert!(touched_labels.contains(&(address_space as u32, label))); } } - assert!(contains_some_key); } let mut hash_test_chip = HashTestChip::new(); - let initial_tree = - MemoryNode::tree_from_memory(memory_dimensions, initial_memory, &hash_test_chip); let final_tree_check = - MemoryNode::tree_from_memory(memory_dimensions, final_memory, &hash_test_chip); + MerkleTree::from_memory(final_memory, &memory_dimensions, &hash_test_chip); let mut chip = MemoryMerkleChip::::new(memory_dimensions, merkle_bus, COMPRESSION_BUS); - for &(address_space, label) in touched_labels.iter() { - chip.touch_range(address_space, label * CHUNK as u32, CHUNK as u32); - } + let final_partition: BTreeMap<_, [F; CHUNK]> = + memory_to_vec_partition::(final_memory, &memory_dimensions) + .into_iter() + .map(|(idx, values)| { + let address_space = + (idx >> memory_dimensions.address_height) as u32 + ADDR_SPACE_OFFSET; + let label = (idx & ((1 << memory_dimensions.address_height) - 1)) as u32; + ((address_space, label * (CHUNK as u32)), values) + }) + .collect(); + let final_partition = final_partition + .into_iter() + .filter(|((address_space, pointer), _)| { + touched_labels.contains(&(*address_space, pointer / CHUNK as u32)) + }) + .collect(); + chip.finalize(initial_memory, &final_partition, &hash_test_chip); - let final_partition = memory_to_partition(final_memory); - println!("trace height = {}", chip.current_trace_height()); - chip.finalize(&initial_tree, &final_partition, &mut hash_test_chip); assert_eq!( chip.final_state.as_ref().unwrap().final_root, - final_tree_check.hash() + final_tree_check.root() ); - let chip_air = chip.air(); - let chip_api = chip.generate_air_proof_input(); + let chip_api = chip.generate_proving_ctx(); let dummy_interaction_air = DummyInteractionAir::new(4 + CHUNK, true, merkle_bus.index); let mut dummy_interaction_trace_rows = vec![]; @@ -126,13 +124,12 @@ fn test( }; for (address_space, address_label) in touched_labels { - let initial_values = array::from_fn(|i| { - initial_memory - .get(&(address_space, address_label * CHUNK as u32 + i as u32)) - .copied() - .unwrap_or_default() - }); - let as_label = address_space - as_offset; + let initial_values = unsafe { + array::from_fn(|i| { + initial_memory.get((address_space, address_label * CHUNK as u32 + i as u32)) + }) + }; + let as_label = address_space - ADDR_SPACE_OFFSET; interaction( PermutationInteractionType::Send, false, @@ -142,7 +139,7 @@ fn test( initial_values, ); let final_values = *final_partition - .get(&(address_space, address_label)) + .get(&(address_space, address_label * (CHUNK as u32))) .unwrap(); interaction( PermutationInteractionType::Send, @@ -163,38 +160,24 @@ fn test( dummy_interaction_trace_rows, dummy_interaction_air.field_width() + 1, ); - let dummy_interaction_api = AirProofInput::simple_no_pis(dummy_interaction_trace); + let dummy_interaction_api = AirProvingContext::simple_no_pis(Arc::new(dummy_interaction_trace)); BabyBearPoseidon2Engine::run_test_fast( vec![ - chip_air, + Arc::new(chip.air), Arc::new(dummy_interaction_air), Arc::new(hash_test_chip.air()), ], vec![ chip_api, dummy_interaction_api, - hash_test_chip.generate_air_proof_input(), + hash_test_chip.generate_proving_ctx(), ], ) .expect("Verification failed"); } -fn memory_to_partition( - memory: &MemoryImage, -) -> Equipartition { - let mut memory_partition = Equipartition::new(); - for ((address_space, pointer), value) in memory.items() { - let label = (address_space, pointer / N as u32); - let chunk = memory_partition - .entry(label) - .or_insert_with(|| [F::default(); N]); - chunk[(pointer % N as u32) as usize] = value; - } - memory_partition -} - -fn random_test( +fn random_test( height: usize, max_value: u32, mut num_initial_addresses: usize, @@ -203,8 +186,34 @@ fn random_test( let mut rng = create_seeded_rng(); let mut next_u32 = || rng.next_u64() as u32; - let mut initial_memory = AddressMap::new(1, 2, CHUNK << height); - let mut final_memory = AddressMap::new(1, 2, CHUNK << height); + let mem_config = MemoryConfig::new( + 1, + vec![ + AddressSpaceHostConfig { + num_cells: 0, + min_block_size: 0, + layout: MemoryCellType::Null, + }, + AddressSpaceHostConfig { + num_cells: CHUNK << height, + min_block_size: 1, + layout: MemoryCellType::Native { size: 4 }, + }, + AddressSpaceHostConfig { + num_cells: CHUNK << height, + min_block_size: 1, + layout: MemoryCellType::Native { size: 4 }, + }, + ], + height + 3, + 20, + 17, + 32, + ); + + let mut initial_memory = GuestMemory::new(AddressMap::from_mem_config(&mem_config)); + let mut final_memory = GuestMemory::new(AddressMap::from_mem_config(&mem_config)); + let mut seen = HashSet::new(); let mut touched_labels = BTreeSet::new(); @@ -221,132 +230,155 @@ fn random_test( if is_initial && num_initial_addresses != 0 { num_initial_addresses -= 1; let value = BabyBear::from_canonical_u32(next_u32() % max_value); - initial_memory.insert(&(address_space, pointer), value); - final_memory.insert(&(address_space, pointer), value); + unsafe { + initial_memory.write(address_space, pointer, [value]); + final_memory.write(address_space, pointer, [value]); + } } if is_touched && num_touched_addresses != 0 { num_touched_addresses -= 1; touched_labels.insert((address_space, label)); if value_changes || !is_initial { let value = BabyBear::from_canonical_u32(next_u32() % max_value); - final_memory.insert(&(address_space, pointer), value); + unsafe { + final_memory.write(address_space, pointer, [value]); + } } } } } - test::( + test( MemoryDimensions { - as_height: 1, + addr_space_height: 1, address_height: height, - as_offset: 1, }, - &initial_memory, + &initial_memory.memory, touched_labels, - &final_memory, + &final_memory.memory, ); } #[test] fn expand_test_0() { - random_test::(2, 3000, 2, 3); + random_test(2, 3000, 2, 3); } #[test] fn expand_test_1() { - random_test::(10, 3000, 400, 30); + random_test(10, 3000, 400, 30); } #[test] fn expand_test_2() { - random_test::(3, 3000, 3, 2); + random_test(3, 3000, 3, 2); } #[test] fn expand_test_no_accesses() { - let memory_dimensions = MemoryDimensions { - as_height: 2, - address_height: 1, - as_offset: 7, - }; let mut hash_test_chip = HashTestChip::new(); + let height = 1; - let memory = AddressMap::new( - memory_dimensions.as_offset, - 1 << memory_dimensions.as_height, - 1 << memory_dimensions.address_height, - ); - let tree = MemoryNode::::tree_from_memory( - memory_dimensions, - &memory, - &hash_test_chip, + let mem_config = MemoryConfig::new( + 1, + vec![ + AddressSpaceHostConfig { + num_cells: 0, + min_block_size: 0, + layout: MemoryCellType::Null, + }, + AddressSpaceHostConfig { + num_cells: CHUNK << height, + min_block_size: 1, + layout: MemoryCellType::Native { size: 4 }, + }, + AddressSpaceHostConfig { + num_cells: CHUNK << height, + min_block_size: 1, + layout: MemoryCellType::Native { size: 4 }, + }, + ], + height + 3, + 20, + 17, + 32, ); + let md = mem_config.memory_dimensions(); - let mut chip: MemoryMerkleChip = MemoryMerkleChip::new( - memory_dimensions, + let memory = AddressMap::from_mem_config(&mem_config); + + let mut chip: MemoryMerkleChip = MemoryMerkleChip::new( + md, PermutationCheckBus::new(MEMORY_MERKLE_BUS), COMPRESSION_BUS, ); - let partition = memory_to_partition(&memory); - chip.finalize(&tree, &partition, &mut hash_test_chip); + chip.finalize(&memory, &BTreeMap::new(), &hash_test_chip); + let trace = chip.generate_proving_ctx(); BabyBearPoseidon2Engine::run_test_fast( - vec![chip.air(), Arc::new(hash_test_chip.air())], - vec![ - chip.generate_air_proof_input(), - hash_test_chip.generate_air_proof_input(), - ], + vec![Arc::new(chip.air), Arc::new(hash_test_chip.air())], + vec![trace, hash_test_chip.generate_proving_ctx()], ) - .expect("This should occur"); + .expect("Empty touched memory doesn't work"); } #[test] #[should_panic] fn expand_test_negative() { - let memory_dimensions = MemoryDimensions { - as_height: 2, - address_height: 1, - as_offset: 7, - }; - let mut hash_test_chip = HashTestChip::new(); + let height = 1; - let memory = AddressMap::new( - memory_dimensions.as_offset, - 1 << memory_dimensions.as_height, - 1 << memory_dimensions.address_height, - ); - let tree = MemoryNode::::tree_from_memory( - memory_dimensions, - &memory, - &hash_test_chip, + let mem_config = MemoryConfig::new( + 1, + vec![ + AddressSpaceHostConfig { + num_cells: 0, + min_block_size: 0, + layout: MemoryCellType::Null, + }, + AddressSpaceHostConfig { + num_cells: CHUNK << height, + min_block_size: 1, + layout: MemoryCellType::Native { size: 4 }, + }, + AddressSpaceHostConfig { + num_cells: CHUNK << height, + min_block_size: 1, + layout: MemoryCellType::Native { size: 4 }, + }, + ], + height + 3, + 20, + 17, + 32, ); + let md = mem_config.memory_dimensions(); + + let memory = AddressMap::from_mem_config(&mem_config); - let mut chip = MemoryMerkleChip::::new( - memory_dimensions, + let mut chip: MemoryMerkleChip = MemoryMerkleChip::new( + md, PermutationCheckBus::new(MEMORY_MERKLE_BUS), COMPRESSION_BUS, ); - let partition = memory_to_partition(&memory); - chip.finalize(&tree, &partition, &mut hash_test_chip); - let air = chip.air(); - let mut chip_api = chip.generate_air_proof_input(); + chip.finalize(&memory, &BTreeMap::new(), &hash_test_chip); + let mut chip_ctx = chip.generate_proving_ctx(); { - let trace = chip_api.raw.common_main.as_mut().unwrap(); + let mut trace = (*chip_ctx.clone().common_main.unwrap()).clone(); for row in trace.rows_mut() { - let row: &mut MemoryMerkleCols<_, DEFAULT_CHUNK> = row.borrow_mut(); + let row: &mut MemoryMerkleCols<_, CHUNK> = row.borrow_mut(); if row.expand_direction == BabyBear::NEG_ONE { row.left_direction_different = BabyBear::ZERO; row.right_direction_different = BabyBear::ZERO; } } + chip_ctx.common_main.replace(Arc::new(trace)); } - let hash_air = Arc::new(hash_test_chip.air()); BabyBearPoseidon2Engine::run_test_fast( - vec![air, hash_air], - vec![chip_api, hash_test_chip.generate_air_proof_input()], + vec![Arc::new(chip.air), Arc::new(hash_test_chip.air())], + vec![chip_ctx, hash_test_chip.generate_proving_ctx()], ) - .expect("This should occur"); + .expect("We tinkered with the trace and now it doesn't pass"); } diff --git a/crates/vm/src/system/memory/merkle/tests/util.rs b/crates/vm/src/system/memory/merkle/tests/util.rs index c838fa06db..d976979d6b 100644 --- a/crates/vm/src/system/memory/merkle/tests/util.rs +++ b/crates/vm/src/system/memory/merkle/tests/util.rs @@ -1,4 +1,7 @@ -use std::array::from_fn; +use std::{ + array::from_fn, + sync::{Arc, Mutex}, +}; use openvm_stark_backend::{ config::{Domain, StarkGenericConfig}, @@ -6,7 +9,7 @@ use openvm_stark_backend::{ p3_commit::PolynomialSpace, p3_field::Field, p3_matrix::dense::RowMajorMatrix, - prover::types::AirProofInput, + prover::{cpu::CpuBackend, types::AirProvingContext}, }; use openvm_stark_sdk::dummy_airs::interaction::dummy_interaction_air::DummyInteractionAir; @@ -23,12 +26,14 @@ pub fn test_hash_sum( } pub struct HashTestChip { - requests: Vec<[[F; CHUNK]; 3]>, + requests: Mutex>, } impl HashTestChip { pub fn new() -> Self { - Self { requests: vec![] } + Self { + requests: Mutex::new(vec![]), + } } pub fn air(&self) -> DummyInteractionAir { @@ -37,7 +42,8 @@ impl HashTestChip { pub fn trace(&self) -> RowMajorMatrix { let mut rows = vec![]; - for request in self.requests.iter() { + let requests = self.requests.lock().expect("mutex poisoned"); + for request in requests.iter() { rows.push(F::ONE); rows.extend(request.iter().flatten()); } @@ -47,11 +53,12 @@ impl HashTestChip { } RowMajorMatrix::new(rows, width) } - pub fn generate_air_proof_input(&self) -> AirProofInput + pub fn generate_proving_ctx(&mut self) -> AirProvingContext> where + SC: StarkGenericConfig, Domain: PolynomialSpace, { - AirProofInput::simple_no_pis(self.trace()) + AirProvingContext::simple_no_pis(Arc::new(self.trace())) } } @@ -60,10 +67,12 @@ impl Hasher for HashTestChip { test_hash_sum(*left, *right) } } + impl HasherChip for HashTestChip { - fn compress_and_record(&mut self, left: &[F; CHUNK], right: &[F; CHUNK]) -> [F; CHUNK] { + fn compress_and_record(&self, left: &[F; CHUNK], right: &[F; CHUNK]) -> [F; CHUNK] { let result = test_hash_sum(*left, *right); - self.requests.push([*left, *right, result]); + let mut requests = self.requests.lock().expect("mutex poisoned"); + requests.push([*left, *right, result]); result } } diff --git a/crates/vm/src/system/memory/merkle/trace.rs b/crates/vm/src/system/memory/merkle/trace.rs index 52609f259a..f6135e014d 100644 --- a/crates/vm/src/system/memory/merkle/trace.rs +++ b/crates/vm/src/system/memory/merkle/trace.rs @@ -1,26 +1,24 @@ use std::{ borrow::BorrowMut, - cmp::Reverse, sync::{atomic::AtomicU32, Arc}, }; use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - p3_field::{FieldAlgebra, PrimeField32}, + config::{Domain, StarkGenericConfig, Val}, + p3_commit::PolynomialSpace, + p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix, - prover::types::AirProofInput, - AirRef, Chip, ChipUsageGetter, + prover::{cpu::CpuBackend, types::AirProvingContext}, + ChipUsageGetter, }; -use rustc_hash::FxHashSet; +use tracing::instrument; use crate::{ arch::hasher::HasherChip, system::{ memory::{ - controller::dimensions::MemoryDimensions, - merkle::{FinalState, MemoryMerkleChip, MemoryMerkleCols}, - tree::MemoryNode::{self, NonLeaf}, - Equipartition, + merkle::{tree::MerkleTree, FinalState, MemoryMerkleChip, MemoryMerkleCols}, + Equipartition, MemoryImage, }, poseidon2::{ Poseidon2PeripheryBaseChip, Poseidon2PeripheryChip, PERIPHERY_POSEIDON2_WIDTH, @@ -29,51 +27,28 @@ use crate::{ }; impl MemoryMerkleChip { - pub fn finalize( + #[instrument(name = "merkle_finalize", level = "debug", skip_all)] + pub(crate) fn finalize( &mut self, - initial_tree: &MemoryNode, + initial_memory: &MemoryImage, final_memory: &Equipartition, - hasher: &mut impl HasherChip, + hasher: &impl HasherChip, ) { assert!(self.final_state.is_none(), "Merkle chip already finalized"); - // there needs to be a touched node with `height_section` = 0 - // shouldn't be a leaf because - // trace generation will expect an interaction from MemoryInterfaceChip in that case - if self.touched_nodes.len() == 1 { - self.touch_node(1, 0, 0); - } - - let mut rows = vec![]; - let mut tree_helper = TreeHelper { - memory_dimensions: self.air.memory_dimensions, - final_memory, - touched_nodes: &self.touched_nodes, - trace_rows: &mut rows, - }; - let final_tree = tree_helper.recur( - self.air.memory_dimensions.overall_height(), - initial_tree, - 0, - 0, - hasher, - ); - self.final_state = Some(FinalState { - rows, - init_root: initial_tree.hash(), - final_root: final_tree.hash(), - }); + let mut tree = MerkleTree::from_memory(initial_memory, &self.air.memory_dimensions, hasher); + self.final_state = Some(tree.finalize(hasher, final_memory, &self.air.memory_dimensions)); } } -impl Chip for MemoryMerkleChip> +impl MemoryMerkleChip where - Val: PrimeField32, + F: PrimeField32, { - fn air(&self) -> AirRef { - Arc::new(self.air.clone()) - } - - fn generate_air_proof_input(self) -> AirProofInput { + pub fn generate_proving_ctx(&mut self) -> AirProvingContext> + where + SC: StarkGenericConfig, + Domain: PolynomialSpace, + { assert!( self.final_state.is_some(), "Merkle chip must finalize before trace generation" @@ -82,11 +57,16 @@ where mut rows, init_root, final_root, - } = self.final_state.unwrap(); + } = self.final_state.take().unwrap(); // important that this sort be stable, // because we need the initial root to be first and the final root to be second - rows.sort_by_key(|row| Reverse(row.parent_height)); + rows.reverse(); + rows.swap(0, 1); + #[cfg(feature = "metrics")] + { + self.current_height = rows.len(); + } let width = MemoryMerkleCols::, CHUNK>::width(); let mut height = rows.len().next_power_of_two(); if let Some(mut oh) = self.overridden_height { @@ -103,9 +83,9 @@ where *trace_row.borrow_mut() = row; } - let trace = RowMajorMatrix::new(trace, width); + let trace = Arc::new(RowMajorMatrix::new(trace, width)); let pvs = init_root.into_iter().chain(final_root).collect(); - AirProofInput::simple(trace, pvs) + AirProvingContext::simple(trace, pvs) } } impl ChipUsageGetter for MemoryMerkleChip { @@ -114,7 +94,7 @@ impl ChipUsageGetter for MemoryMerkleChip usize { - 2 * self.num_touched_nonleaves + self.final_state.as_ref().map(|s| s.rows.len()).unwrap_or(0) } fn trace_width(&self) -> usize { @@ -122,138 +102,8 @@ impl ChipUsageGetter for MemoryMerkleChip { - memory_dimensions: MemoryDimensions, - final_memory: &'a Equipartition, - touched_nodes: &'a FxHashSet<(usize, u32, u32)>, - trace_rows: &'a mut Vec>, -} - -impl TreeHelper<'_, CHUNK, F> { - fn recur( - &mut self, - height: usize, - initial_node: &MemoryNode, - as_label: u32, - address_label: u32, - hasher: &mut impl HasherChip, - ) -> MemoryNode { - if height == 0 { - let address_space = as_label + self.memory_dimensions.as_offset; - let leaf_values = *self - .final_memory - .get(&(address_space, address_label)) - .unwrap_or(&[F::ZERO; CHUNK]); - MemoryNode::new_leaf(hasher.hash(&leaf_values)) - } else if let NonLeaf { - left: initial_left_node, - right: initial_right_node, - .. - } = initial_node.clone() - { - // Tell the hasher about this hash. - hasher.compress_and_record(&initial_left_node.hash(), &initial_right_node.hash()); - - let is_as_section = height > self.memory_dimensions.address_height; - - let (left_as_label, right_as_label) = if is_as_section { - (2 * as_label, 2 * as_label + 1) - } else { - (as_label, as_label) - }; - let (left_address_label, right_address_label) = if is_as_section { - (address_label, address_label) - } else { - (2 * address_label, 2 * address_label + 1) - }; - - let left_is_final = - !self - .touched_nodes - .contains(&(height - 1, left_as_label, left_address_label)); - - let final_left_node = if left_is_final { - initial_left_node - } else { - Arc::new(self.recur( - height - 1, - &initial_left_node, - left_as_label, - left_address_label, - hasher, - )) - }; - - let right_is_final = - !self - .touched_nodes - .contains(&(height - 1, right_as_label, right_address_label)); - - let final_right_node = if right_is_final { - initial_right_node - } else { - Arc::new(self.recur( - height - 1, - &initial_right_node, - right_as_label, - right_address_label, - hasher, - )) - }; - - let final_node = MemoryNode::new_nonleaf(final_left_node, final_right_node, hasher); - self.add_trace_row(height, as_label, address_label, initial_node, None); - self.add_trace_row( - height, - as_label, - address_label, - &final_node, - Some([left_is_final, right_is_final]), - ); - final_node - } else { - panic!("Leaf {:?} found at nonzero height {}", initial_node, height); - } - } - - /// Expects `node` to be NonLeaf - fn add_trace_row( - &mut self, - parent_height: usize, - as_label: u32, - address_label: u32, - node: &MemoryNode, - direction_changes: Option<[bool; 2]>, - ) { - let [left_direction_change, right_direction_change] = - direction_changes.unwrap_or([false; 2]); - let cols = if let NonLeaf { hash, left, right } = node { - MemoryMerkleCols { - expand_direction: if direction_changes.is_none() { - F::ONE - } else { - F::NEG_ONE - }, - height_section: F::from_bool(parent_height > self.memory_dimensions.address_height), - parent_height: F::from_canonical_usize(parent_height), - is_root: F::from_bool(parent_height == self.memory_dimensions.overall_height()), - parent_as_label: F::from_canonical_u32(as_label), - parent_address_label: F::from_canonical_u32(address_label), - parent_hash: *hash, - left_child_hash: left.hash(), - right_child_hash: right.hash(), - left_direction_different: F::from_bool(left_direction_change), - right_direction_different: F::from_bool(right_direction_change), - } - } else { - panic!("trace_rows expects node = {:?} to be NonLeaf", node); - }; - self.trace_rows.push(cols); - } -} - pub trait SerialReceiver { - fn receive(&mut self, msg: T); + fn receive(&self, msg: T); } impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> SerialReceiver<&'a [F]> @@ -261,7 +111,7 @@ impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> SerialReceiver<&'a [F]> { /// Receives a permutation preimage, pads with zeros to the permutation width, and records. /// The permutation preimage must have length at most the permutation width (panics otherwise). - fn receive(&mut self, perm_preimage: &'a [F]) { + fn receive(&self, perm_preimage: &'a [F]) { assert!(perm_preimage.len() <= PERIPHERY_POSEIDON2_WIDTH); let mut state = [F::ZERO; PERIPHERY_POSEIDON2_WIDTH]; state[..perm_preimage.len()].copy_from_slice(perm_preimage); @@ -271,7 +121,7 @@ impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> SerialReceiver<&'a [F]> } impl<'a, F: PrimeField32> SerialReceiver<&'a [F]> for Poseidon2PeripheryChip { - fn receive(&mut self, perm_preimage: &'a [F]) { + fn receive(&self, perm_preimage: &'a [F]) { match self { Poseidon2PeripheryChip::Register0(chip) => chip.receive(perm_preimage), Poseidon2PeripheryChip::Register1(chip) => chip.receive(perm_preimage), diff --git a/crates/vm/src/system/memory/merkle/tree.rs b/crates/vm/src/system/memory/merkle/tree.rs new file mode 100644 index 0000000000..956908abf2 --- /dev/null +++ b/crates/vm/src/system/memory/merkle/tree.rs @@ -0,0 +1,267 @@ +use openvm_stark_backend::{ + p3_field::PrimeField32, + p3_maybe_rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}, +}; +use rustc_hash::FxHashMap; + +use super::{FinalState, MemoryMerkleCols}; +use crate::{ + arch::hasher::{Hasher, HasherChip}, + system::memory::{ + dimensions::MemoryDimensions, merkle::memory_to_vec_partition, AddressMap, Equipartition, + }, +}; + +#[derive(Debug)] +pub struct MerkleTree { + /// Height of the tree -- the root is the only node at height `height`, + /// and the leaves are at height `0`. + height: usize, + /// Nodes corresponding to all zeroes. + zero_nodes: Vec<[F; CHUNK]>, + /// Nodes in the tree that have ever been touched. + nodes: FxHashMap, +} + +impl MerkleTree { + pub fn new(height: usize, hasher: &impl Hasher) -> Self { + Self { + height, + zero_nodes: (0..height + 1) + .scan(hasher.hash(&[F::ZERO; CHUNK]), |acc, _| { + let result = Some(*acc); + *acc = hasher.compress(acc, acc); + result + }) + .collect(), + nodes: FxHashMap::default(), + } + } + + pub fn root(&self) -> [F; CHUNK] { + self.get_node(1) + } + + pub fn get_node(&self, index: u64) -> [F; CHUNK] { + self.nodes + .get(&index) + .cloned() + .unwrap_or(self.zero_nodes[self.height - index.ilog2() as usize]) + } + + #[allow(clippy::type_complexity)] + /// Shared logic for both from_memory and finalize. + fn process_layers( + &mut self, + layer: Vec<(u64, [F; CHUNK])>, + md: &MemoryDimensions, + mut rows: Option<&mut Vec>>, + compress: CompressFn, + ) where + CompressFn: Fn(&[F; CHUNK], &[F; CHUNK]) -> [F; CHUNK] + Send + Sync, + { + let mut new_entries = layer; + let mut layer = new_entries + .par_iter() + .map(|(index, values)| { + let old_values = self.nodes.get(index).unwrap_or(&self.zero_nodes[0]); + (*index, *values, *old_values) + }) + .collect::>(); + for height in 1..=self.height { + let new_layer = layer + .iter() + .enumerate() + .filter_map(|(i, (index, values, old_values))| { + if i > 0 && layer[i - 1].0 ^ 1 == *index { + return None; + } + + let par_index = index >> 1; + + if i + 1 < layer.len() && layer[i + 1].0 == index ^ 1 { + let (_, sibling_values, sibling_old_values) = &layer[i + 1]; + Some(( + par_index, + Some((values, old_values)), + Some((sibling_values, sibling_old_values)), + )) + } else if index & 1 == 0 { + Some((par_index, Some((values, old_values)), None)) + } else { + Some((par_index, None, Some((values, old_values)))) + } + }) + .collect::>(); + + match rows { + None => { + layer = new_layer + .into_par_iter() + .map(|(par_index, left, right)| { + let left = if let Some(left) = left { + left.0 + } else { + &self.get_node(2 * par_index) + }; + let right = if let Some(right) = right { + right.0 + } else { + &self.get_node(2 * par_index + 1) + }; + let combined = compress(left, right); + let par_old_values = self.get_node(par_index); + (par_index, combined, par_old_values) + }) + .collect(); + } + Some(ref mut rows) => { + let label_section_height = md.address_height.saturating_sub(height); + let (tmp, new_rows): (Vec<(u64, [F; CHUNK], [F; CHUNK])>, Vec<[_; 2]>) = + new_layer + .into_par_iter() + .map(|(par_index, left, right)| { + let parent_address_label = + (par_index & ((1 << label_section_height) - 1)) as u32; + let parent_as_label = ((par_index & !(1 << (self.height - height))) + >> label_section_height) + as u32; + let left_node; + let (left, old_left, changed_left) = match left { + Some((left, old_left)) => (left, old_left, true), + None => { + left_node = self.get_node(2 * par_index); + (&left_node, &left_node, false) + } + }; + let right_node; + let (right, old_right, changed_right) = match right { + Some((right, old_right)) => (right, old_right, true), + None => { + right_node = self.get_node(2 * par_index + 1); + (&right_node, &right_node, false) + } + }; + let combined = compress(left, right); + // This is a hacky way to say: + // "and we also want to record the old values" + compress(old_left, old_right); + let par_old_values = self.get_node(par_index); + ( + (par_index, combined, par_old_values), + [ + MemoryMerkleCols { + expand_direction: F::ONE, + height_section: F::from_bool( + height > md.address_height, + ), + parent_height: F::from_canonical_usize(height), + is_root: F::from_bool(height == md.overall_height()), + parent_as_label: F::from_canonical_u32(parent_as_label), + parent_address_label: F::from_canonical_u32( + parent_address_label, + ), + parent_hash: par_old_values, + left_child_hash: *old_left, + right_child_hash: *old_right, + left_direction_different: F::ZERO, + right_direction_different: F::ZERO, + }, + MemoryMerkleCols { + expand_direction: F::NEG_ONE, + height_section: F::from_bool( + height > md.address_height, + ), + parent_height: F::from_canonical_usize(height), + is_root: F::from_bool(height == md.overall_height()), + parent_as_label: F::from_canonical_u32(parent_as_label), + parent_address_label: F::from_canonical_u32( + parent_address_label, + ), + parent_hash: combined, + left_child_hash: *left, + right_child_hash: *right, + left_direction_different: F::from_bool(!changed_left), + right_direction_different: F::from_bool(!changed_right), + }, + ], + ) + }) + .unzip(); + rows.extend(new_rows.into_iter().flatten()); + layer = tmp; + } + } + new_entries.extend(layer.iter().map(|(idx, values, _)| (*idx, *values))); + } + + if self.nodes.is_empty() { + // This, for example, should happen in every `from_memory` call + self.nodes = FxHashMap::from_iter(new_entries); + } else { + self.nodes.extend(new_entries); + } + } + + pub fn from_memory( + memory: &AddressMap, + md: &MemoryDimensions, + hasher: &(impl Hasher + Sync), + ) -> Self { + let mut tree = Self::new(md.overall_height(), hasher); + let layer: Vec<_> = memory_to_vec_partition(memory, md) + .par_iter() + .map(|(idx, v)| ((1 << tree.height) + idx, hasher.hash(v))) + .collect(); + tree.process_layers(layer, md, None, |left, right| hasher.compress(left, right)); + tree + } + + pub fn finalize( + &mut self, + hasher: &impl HasherChip, + touched: &Equipartition, + md: &MemoryDimensions, + ) -> FinalState { + let init_root = self.get_node(1); + let layer: Vec<_> = if !touched.is_empty() { + touched + .iter() + .map(|((addr_sp, ptr), v)| { + ( + (1 << self.height) + md.label_to_index((*addr_sp, *ptr / CHUNK as u32)), + hasher.hash(v), + ) + }) + .collect() + } else { + let index = 1 << self.height; + vec![(index, self.get_node(index))] + }; + let mut rows = Vec::with_capacity(if layer.is_empty() { + 0 + } else { + layer + .iter() + .zip(layer.iter().skip(1)) + .fold(md.overall_height(), |acc, ((lhs, _), (rhs, _))| { + acc + (lhs ^ rhs).ilog2() as usize + }) + }); + self.process_layers(layer, md, Some(&mut rows), |left, right| { + hasher.compress_and_record(left, right) + }); + if touched.is_empty() { + // If we made an artificial touch, we need to change the direction changes for the + // leaves + rows[1].left_direction_different = F::ONE; + rows[1].right_direction_different = F::ONE; + } + let final_root = self.get_node(1); + FinalState { + rows, + init_root, + final_root, + } + } +} diff --git a/crates/vm/src/system/memory/mod.rs b/crates/vm/src/system/memory/mod.rs index ac6a7d85cf..411e7a5473 100644 --- a/crates/vm/src/system/memory/mod.rs +++ b/crates/vm/src/system/memory/mod.rs @@ -1,21 +1,40 @@ +use std::sync::Arc; + +use openvm_circuit_primitives::{is_less_than::IsLtSubAir, var_range::VariableRangeCheckerBus}; use openvm_circuit_primitives_derive::AlignedBorrow; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + interaction::PermutationCheckBus, + p3_field::Field, + p3_util::{log2_ceil_usize, log2_strict_usize}, + AirRef, +}; -mod adapter; +pub mod adapter; mod controller; pub mod merkle; -mod offline; pub mod offline_checker; pub mod online; -pub mod paged_vec; -mod persistent; +pub mod persistent; #[cfg(test)] mod tests; -pub mod tree; -mod volatile; +pub mod volatile; pub use controller::*; -pub use offline::*; -pub use paged_vec::*; +pub use online::{Address, AddressMap, INITIAL_TIMESTAMP}; + +use crate::{ + arch::{MemoryConfig, ADDR_SPACE_OFFSET}, + system::memory::{ + adapter::AccessAdapterAir, dimensions::MemoryDimensions, interface::MemoryInterfaceAirs, + merkle::MemoryMerkleAir, offline_checker::MemoryBridge, persistent::PersistentBoundaryAir, + volatile::VolatileBoundaryAir, + }, +}; + +// @dev Currently this is only used for debug assertions, but we may switch to making it constant +// and removing from MemoryConfig +pub const POINTER_MAX_BITS: usize = 29; #[derive(PartialEq, Copy, Clone, Debug, Eq)] pub enum OpType { @@ -52,9 +71,95 @@ impl MemoryAddress { } } -#[derive(Clone, Copy, Debug, PartialEq, Eq, AlignedBorrow)] -#[repr(C)] -pub struct HeapAddress { - pub address: MemoryAddress, - pub data: MemoryAddress, +#[derive(Clone)] +pub struct MemoryAirInventory { + pub bridge: MemoryBridge, + pub interface: MemoryInterfaceAirs, + pub access_adapters: Vec>, +} + +impl MemoryAirInventory { + pub fn new( + bridge: MemoryBridge, + mem_config: &MemoryConfig, + range_bus: VariableRangeCheckerBus, + merkle_compression_buses: Option<(PermutationCheckBus, PermutationCheckBus)>, + ) -> Self { + let memory_bus = bridge.memory_bus(); + let interface = if let Some((merkle_bus, compression_bus)) = merkle_compression_buses { + // Persistent memory + let memory_dims = MemoryDimensions { + addr_space_height: mem_config.addr_space_height, + address_height: mem_config.pointer_max_bits - log2_strict_usize(CHUNK), + }; + let boundary = PersistentBoundaryAir:: { + memory_dims, + memory_bus, + merkle_bus, + compression_bus, + }; + let merkle = MemoryMerkleAir:: { + memory_dimensions: memory_dims, + merkle_bus, + compression_bus, + }; + MemoryInterfaceAirs::Persistent { boundary, merkle } + } else { + // Volatile memory + let addr_space_height = mem_config.addr_space_height; + assert!(addr_space_height < Val::::bits() - 2); + let addr_space_max_bits = + log2_ceil_usize((ADDR_SPACE_OFFSET + 2u32.pow(addr_space_height as u32)) as usize); + let boundary = VolatileBoundaryAir::new( + memory_bus, + addr_space_max_bits, + mem_config.pointer_max_bits, + range_bus, + ); + MemoryInterfaceAirs::Volatile { boundary } + }; + // Memory access adapters + let lt_air = IsLtSubAir::new(range_bus, mem_config.timestamp_max_bits); + let maan = mem_config.max_access_adapter_n; + assert!(matches!(maan, 2 | 4 | 8 | 16 | 32)); + let access_adapters: Vec> = [ + Arc::new(AccessAdapterAir::<2> { memory_bus, lt_air }) as AirRef, + Arc::new(AccessAdapterAir::<4> { memory_bus, lt_air }) as AirRef, + Arc::new(AccessAdapterAir::<8> { memory_bus, lt_air }) as AirRef, + Arc::new(AccessAdapterAir::<16> { memory_bus, lt_air }) as AirRef, + Arc::new(AccessAdapterAir::<32> { memory_bus, lt_air }) as AirRef, + ] + .into_iter() + .take(log2_strict_usize(maan)) + .collect(); + + Self { + bridge, + interface, + access_adapters, + } + } + + /// The order of memory AIRs is boundary, merkle (if exists), access adapters + pub fn into_airs(self) -> Vec> { + let mut airs: Vec> = Vec::new(); + match self.interface { + MemoryInterfaceAirs::Volatile { boundary } => { + airs.push(Arc::new(boundary)); + } + MemoryInterfaceAirs::Persistent { boundary, merkle } => { + airs.push(Arc::new(boundary)); + airs.push(Arc::new(merkle)); + } + } + airs.extend(self.access_adapters); + airs + } +} + +/// This is O(1) and returns the length of +/// [`MemoryAirInventory::into_airs`]. +pub fn num_memory_airs(is_persistent: bool, max_access_adapter_n: usize) -> usize { + // boundary + { merkle if is_persistent } + access_adapters + 1 + usize::from(is_persistent) + log2_strict_usize(max_access_adapter_n) } diff --git a/crates/vm/src/system/memory/offline.rs b/crates/vm/src/system/memory/offline.rs deleted file mode 100644 index 74bb238811..0000000000 --- a/crates/vm/src/system/memory/offline.rs +++ /dev/null @@ -1,1070 +0,0 @@ -use std::{array, cmp::max}; - -use openvm_circuit_primitives::{ - assert_less_than::AssertLtSubAir, var_range::SharedVariableRangeCheckerChip, -}; -use openvm_stark_backend::p3_field::PrimeField32; -use rustc_hash::FxHashSet; - -use super::{AddressMap, PagedVec, PAGE_SIZE}; -use crate::{ - arch::MemoryConfig, - system::memory::{ - adapter::{AccessAdapterInventory, AccessAdapterRecord, AccessAdapterRecordKind}, - offline_checker::{MemoryBridge, MemoryBus}, - MemoryAuxColsFactory, MemoryImage, RecordId, TimestampedEquipartition, TimestampedValues, - }, -}; - -pub const INITIAL_TIMESTAMP: u32 = 0; - -#[repr(C)] -#[derive(Clone, Default, PartialEq, Eq, Debug)] -struct BlockData { - pointer: u32, - timestamp: u32, - size: usize, -} - -struct BlockMap { - /// Block ids. 0 is a special value standing for the default block. - id: AddressMap, - /// The place where non-default blocks are stored. - storage: Vec, - initial_block_size: usize, -} - -impl BlockMap { - pub fn from_mem_config(mem_config: &MemoryConfig, initial_block_size: usize) -> Self { - assert!(initial_block_size.is_power_of_two()); - Self { - id: AddressMap::from_mem_config(mem_config), - storage: vec![], - initial_block_size, - } - } - - fn initial_block_data(pointer: u32, initial_block_size: usize) -> BlockData { - let aligned_pointer = (pointer / initial_block_size as u32) * initial_block_size as u32; - BlockData { - pointer: aligned_pointer, - size: initial_block_size, - timestamp: INITIAL_TIMESTAMP, - } - } - - pub fn get_without_adding(&self, address: &(u32, u32)) -> BlockData { - let idx = self.id.get(address).unwrap_or(&0); - if idx == &0 { - Self::initial_block_data(address.1, self.initial_block_size) - } else { - self.storage[idx - 1].clone() - } - } - - pub fn get(&mut self, address: &(u32, u32)) -> &BlockData { - let (address_space, pointer) = *address; - let idx = self.id.get(&(address_space, pointer)).unwrap_or(&0); - if idx == &0 { - // `initial_block_size` is a power of two, as asserted in `from_mem_config`. - let pointer = pointer & !(self.initial_block_size as u32 - 1); - self.set_range( - &(address_space, pointer), - self.initial_block_size, - Self::initial_block_data(pointer, self.initial_block_size), - ); - self.storage.last().unwrap() - } else { - &self.storage[idx - 1] - } - } - - pub fn get_mut(&mut self, address: &(u32, u32)) -> &mut BlockData { - let (address_space, pointer) = *address; - let idx = self.id.get(&(address_space, pointer)).unwrap_or(&0); - if idx == &0 { - let pointer = pointer - pointer % self.initial_block_size as u32; - self.set_range( - &(address_space, pointer), - self.initial_block_size, - Self::initial_block_data(pointer, self.initial_block_size), - ); - self.storage.last_mut().unwrap() - } else { - &mut self.storage[idx - 1] - } - } - - pub fn set_range(&mut self, address: &(u32, u32), len: usize, block: BlockData) { - let (address_space, pointer) = address; - self.storage.push(block); - for i in 0..len { - self.id - .insert(&(*address_space, pointer + i as u32), self.storage.len()); - } - } - - pub fn items(&self) -> impl Iterator + '_ { - self.id - .items() - .filter(|(_, idx)| *idx > 0) - .map(|(address, idx)| (address, &self.storage[idx - 1])) - } -} - -#[derive(Debug, Clone, PartialEq)] -pub struct MemoryRecord { - pub address_space: T, - pub pointer: T, - pub timestamp: u32, - pub prev_timestamp: u32, - data: Vec, - /// None if a read. - prev_data: Option>, -} - -impl MemoryRecord { - pub fn data_slice(&self) -> &[T] { - self.data.as_slice() - } - - pub fn prev_data_slice(&self) -> Option<&[T]> { - self.prev_data.as_deref() - } -} - -impl MemoryRecord { - pub fn data_at(&self, index: usize) -> T { - self.data[index] - } -} - -pub struct OfflineMemory { - block_data: BlockMap, - data: Vec>, - as_offset: u32, - timestamp: u32, - timestamp_max_bits: usize, - - memory_bus: MemoryBus, - range_checker: SharedVariableRangeCheckerChip, - - log: Vec>>, -} - -impl OfflineMemory { - /// Creates a new partition with the given initial block size. - /// - /// Panics if the initial block size is not a power of two. - pub fn new( - initial_memory: MemoryImage, - initial_block_size: usize, - memory_bus: MemoryBus, - range_checker: SharedVariableRangeCheckerChip, - config: MemoryConfig, - ) -> Self { - assert_eq!(initial_memory.as_offset, config.as_offset); - Self { - block_data: BlockMap::from_mem_config(&config, initial_block_size), - data: initial_memory.paged_vecs, - as_offset: config.as_offset, - timestamp: INITIAL_TIMESTAMP + 1, - timestamp_max_bits: config.clk_max_bits, - memory_bus, - range_checker, - log: vec![], - } - } - - pub fn set_initial_memory(&mut self, initial_memory: MemoryImage, config: MemoryConfig) { - assert_eq!(self.timestamp, INITIAL_TIMESTAMP + 1); - assert_eq!(initial_memory.as_offset, config.as_offset); - self.as_offset = config.as_offset; - self.data = initial_memory.paged_vecs; - } - - pub(super) fn set_log_capacity(&mut self, access_capacity: usize) { - assert!(self.log.is_empty()); - self.log = Vec::with_capacity(access_capacity); - } - - pub fn memory_bridge(&self) -> MemoryBridge { - MemoryBridge::new( - self.memory_bus, - self.timestamp_max_bits, - self.range_checker.bus(), - ) - } - - pub fn timestamp(&self) -> u32 { - self.timestamp - } - - /// Increments the current timestamp by one and returns the new value. - pub fn increment_timestamp(&mut self) { - self.increment_timestamp_by(1) - } - - /// Increments the current timestamp by a specified delta and returns the new value. - pub fn increment_timestamp_by(&mut self, delta: u32) { - self.log.push(None); - self.timestamp += delta; - } - - /// Writes an array of values to the memory at the specified address space and start index. - pub fn write( - &mut self, - address_space: u32, - pointer: u32, - values: Vec, - records: &mut AccessAdapterInventory, - ) { - let len = values.len(); - assert!(len.is_power_of_two()); - assert_ne!(address_space, 0); - - let prev_timestamp = self.access_updating_timestamp(address_space, pointer, len, records); - - debug_assert!(prev_timestamp < self.timestamp); - - let pointer = pointer as usize; - let prev_data = self.data[(address_space - self.as_offset) as usize] - .set_range(pointer..pointer + len, &values); - - let record = MemoryRecord { - address_space: F::from_canonical_u32(address_space), - pointer: F::from_canonical_usize(pointer), - timestamp: self.timestamp, - prev_timestamp, - data: values, - prev_data: Some(prev_data), - }; - self.log.push(Some(record)); - self.timestamp += 1; - } - - /// Reads an array of values from the memory at the specified address space and start index. - pub fn read( - &mut self, - address_space: u32, - pointer: u32, - len: usize, - adapter_records: &mut AccessAdapterInventory, - ) { - assert!(len.is_power_of_two()); - if address_space == 0 { - let pointer = F::from_canonical_u32(pointer); - self.log.push(Some(MemoryRecord { - address_space: F::ZERO, - pointer, - timestamp: self.timestamp, - prev_timestamp: 0, - data: vec![pointer], - prev_data: None, - })); - self.timestamp += 1; - return; - } - - let prev_timestamp = - self.access_updating_timestamp(address_space, pointer, len, adapter_records); - - debug_assert!(prev_timestamp < self.timestamp); - - let values = self.range_vec(address_space, pointer, len); - - self.log.push(Some(MemoryRecord { - address_space: F::from_canonical_u32(address_space), - pointer: F::from_canonical_u32(pointer), - timestamp: self.timestamp, - prev_timestamp, - data: values, - prev_data: None, - })); - self.timestamp += 1; - } - - pub fn record_by_id(&self, id: RecordId) -> &MemoryRecord { - self.log[id.0].as_ref().unwrap() - } - - pub fn finalize( - &mut self, - adapter_records: &mut AccessAdapterInventory, - ) -> TimestampedEquipartition { - // First make sure the partition we maintain in self.block_data is an equipartition. - // Grab all aligned pointers that need to be re-accessed. - let to_access: FxHashSet<_> = self - .block_data - .items() - .map(|((address_space, pointer), _)| (address_space, (pointer / N as u32) * N as u32)) - .collect(); - - for &(address_space, pointer) in to_access.iter() { - let block = self.block_data.get(&(address_space, pointer)); - if block.pointer != pointer || block.size != N { - self.access(address_space, pointer, N, adapter_records); - } - } - - let mut equipartition = TimestampedEquipartition::::new(); - for (address_space, pointer) in to_access { - let block = self.block_data.get(&(address_space, pointer)); - - debug_assert_eq!(block.pointer % N as u32, 0); - debug_assert_eq!(block.size, N); - - equipartition.insert( - (address_space, pointer / N as u32), - TimestampedValues { - timestamp: block.timestamp, - values: self.range_array::(address_space, pointer), - }, - ); - } - equipartition - } - - // Modifies the partition to ensure that there is a block starting at (address_space, query). - fn split_to_make_boundary( - &mut self, - address_space: u32, - query: u32, - records: &mut AccessAdapterInventory, - ) { - let lim = (self.data[(address_space - self.as_offset) as usize].memory_size()) as u32; - if query == lim { - return; - } - assert!(query < lim); - let original_block = self.block_containing(address_space, query); - if original_block.pointer == query { - return; - } - - let data = self.range_vec(address_space, original_block.pointer, original_block.size); - - let timestamp = original_block.timestamp; - - let mut cur_ptr = original_block.pointer; - let mut cur_size = original_block.size; - while cur_size > 0 { - // Split. - records.add_record(AccessAdapterRecord { - timestamp, - address_space: F::from_canonical_u32(address_space), - start_index: F::from_canonical_u32(cur_ptr), - data: data[(cur_ptr - original_block.pointer) as usize - ..(cur_ptr - original_block.pointer) as usize + cur_size] - .to_vec(), - kind: AccessAdapterRecordKind::Split, - }); - - let half_size = cur_size / 2; - let half_size_u32 = half_size as u32; - let mid_ptr = cur_ptr + half_size_u32; - - if query <= mid_ptr { - // The right is finalized; add it to the partition. - let block = BlockData { - pointer: mid_ptr, - size: half_size, - timestamp, - }; - self.block_data - .set_range(&(address_space, mid_ptr), half_size, block); - } - if query >= cur_ptr + half_size_u32 { - // The left is finalized; add it to the partition. - let block = BlockData { - pointer: cur_ptr, - size: half_size, - timestamp, - }; - self.block_data - .set_range(&(address_space, cur_ptr), half_size, block); - } - if mid_ptr <= query { - cur_ptr = mid_ptr; - } - if cur_ptr == query { - break; - } - cur_size = half_size; - } - } - - fn access_updating_timestamp( - &mut self, - address_space: u32, - pointer: u32, - size: usize, - records: &mut AccessAdapterInventory, - ) -> u32 { - self.access(address_space, pointer, size, records); - - let mut prev_timestamp = None; - - let mut i = 0; - while i < size as u32 { - let block = self.block_data.get_mut(&(address_space, pointer + i)); - debug_assert!(i == 0 || prev_timestamp == Some(block.timestamp)); - prev_timestamp = Some(block.timestamp); - block.timestamp = self.timestamp; - i = block.pointer + block.size as u32; - } - prev_timestamp.unwrap() - } - - fn access( - &mut self, - address_space: u32, - pointer: u32, - size: usize, - records: &mut AccessAdapterInventory, - ) { - self.split_to_make_boundary(address_space, pointer, records); - self.split_to_make_boundary(address_space, pointer + size as u32, records); - - let block_data = self.block_containing(address_space, pointer); - - if block_data.pointer == pointer && block_data.size == size { - return; - } - assert!(size > 1); - - // Now recursively access left and right blocks to ensure they are in the partition. - let half_size = size / 2; - self.access(address_space, pointer, half_size, records); - self.access( - address_space, - pointer + half_size as u32, - half_size, - records, - ); - - self.merge_block_with_next(address_space, pointer, records); - } - - /// Merges the two adjacent blocks starting at (address_space, pointer). - /// - /// Panics if there is no block starting at (address_space, pointer) or if the two blocks - /// do not have the same size. - fn merge_block_with_next( - &mut self, - address_space: u32, - pointer: u32, - records: &mut AccessAdapterInventory, - ) { - let left_block = self.block_data.get(&(address_space, pointer)); - - let left_timestamp = left_block.timestamp; - let size = left_block.size; - - let right_timestamp = self - .block_data - .get(&(address_space, pointer + size as u32)) - .timestamp; - - let timestamp = max(left_timestamp, right_timestamp); - self.block_data.set_range( - &(address_space, pointer), - 2 * size, - BlockData { - pointer, - size: 2 * size, - timestamp, - }, - ); - records.add_record(AccessAdapterRecord { - timestamp, - address_space: F::from_canonical_u32(address_space), - start_index: F::from_canonical_u32(pointer), - data: self.range_vec(address_space, pointer, 2 * size), - kind: AccessAdapterRecordKind::Merge { - left_timestamp, - right_timestamp, - }, - }); - } - - fn block_containing(&mut self, address_space: u32, pointer: u32) -> BlockData { - self.block_data - .get_without_adding(&(address_space, pointer)) - } - - pub fn get(&self, address_space: u32, pointer: u32) -> F { - self.data[(address_space - self.as_offset) as usize] - .get(pointer as usize) - .cloned() - .unwrap_or_default() - } - - fn range_array(&self, address_space: u32, pointer: u32) -> [F; N] { - array::from_fn(|i| self.get(address_space, pointer + i as u32)) - } - - fn range_vec(&self, address_space: u32, pointer: u32, len: usize) -> Vec { - let pointer = pointer as usize; - self.data[(address_space - self.as_offset) as usize].range_vec(pointer..pointer + len) - } - - pub fn aux_cols_factory(&self) -> MemoryAuxColsFactory { - let range_bus = self.range_checker.bus(); - MemoryAuxColsFactory { - range_checker: self.range_checker.clone(), - timestamp_lt_air: AssertLtSubAir::new(range_bus, self.timestamp_max_bits), - _marker: Default::default(), - } - } - - // just for unit testing - #[cfg(test)] - fn last_record(&self) -> &MemoryRecord { - self.log.last().unwrap().as_ref().unwrap() - } -} - -#[cfg(test)] -mod tests { - use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, - }; - use openvm_stark_backend::p3_field::FieldAlgebra; - use openvm_stark_sdk::p3_baby_bear::BabyBear; - - use super::{BlockData, MemoryRecord, OfflineMemory}; - use crate::{ - arch::MemoryConfig, - system::memory::{ - adapter::{AccessAdapterInventory, AccessAdapterRecord, AccessAdapterRecordKind}, - offline_checker::MemoryBus, - paged_vec::AddressMap, - MemoryImage, TimestampedValues, - }, - }; - - macro_rules! bb { - ($x:expr) => { - BabyBear::from_canonical_u32($x) - }; - } - - macro_rules! bba { - [$($x:expr),*] => { - [$(BabyBear::from_canonical_u32($x)),*] - } - } - - macro_rules! bbvec { - [$($x:expr),*] => { - vec![$(BabyBear::from_canonical_u32($x)),*] - } - } - - fn setup_test( - initial_memory: MemoryImage, - initial_block_size: usize, - ) -> (OfflineMemory, AccessAdapterInventory) { - let memory_bus = MemoryBus::new(0); - let range_checker = - SharedVariableRangeCheckerChip::new(VariableRangeCheckerBus::new(1, 29)); - let mem_config = MemoryConfig { - as_offset: initial_memory.as_offset, - ..Default::default() - }; - let memory = OfflineMemory::new( - initial_memory, - initial_block_size, - memory_bus, - range_checker.clone(), - mem_config, - ); - let access_adapter_inventory = AccessAdapterInventory::new( - range_checker, - memory_bus, - mem_config.clk_max_bits, - mem_config.max_access_adapter_n, - ); - (memory, access_adapter_inventory) - } - - #[test] - fn test_partition() { - let initial_memory = AddressMap::new(0, 1, 16); - let (mut memory, _) = setup_test(initial_memory, 8); - assert_eq!( - memory.block_containing(1, 13), - BlockData { - pointer: 8, - size: 8, - timestamp: 0, - } - ); - - assert_eq!( - memory.block_containing(1, 8), - BlockData { - pointer: 8, - size: 8, - timestamp: 0, - } - ); - - assert_eq!( - memory.block_containing(1, 15), - BlockData { - pointer: 8, - size: 8, - timestamp: 0, - } - ); - - assert_eq!( - memory.block_containing(1, 16), - BlockData { - pointer: 16, - size: 8, - timestamp: 0, - } - ); - } - - #[test] - fn test_write_read_initial_block_len_1() { - let (mut memory, mut access_adapters) = setup_test(MemoryImage::default(), 1); - let address_space = 1; - - memory.write(address_space, 0, bbvec![1, 2, 3, 4], &mut access_adapters); - - memory.read(address_space, 0, 2, &mut access_adapters); - let read_record = memory.last_record(); - assert_eq!(read_record.data, bba![1, 2]); - - memory.write(address_space, 2, bbvec![100], &mut access_adapters); - - memory.read(address_space, 0, 4, &mut access_adapters); - let read_record = memory.last_record(); - assert_eq!(read_record.data, bba![1, 2, 100, 4]); - } - - #[test] - fn test_records_initial_block_len_1() { - let (mut memory, mut adapter_records) = setup_test(MemoryImage::default(), 1); - - memory.write(1, 0, bbvec![1, 2, 3, 4], &mut adapter_records); - - // Above write first causes merge of [0:1] and [1:2] into [0:2]. - assert_eq!( - adapter_records.records_for_n(2)[0], - AccessAdapterRecord { - timestamp: 0, - address_space: bb!(1), - start_index: bb!(0), - data: bbvec![0, 0], - kind: AccessAdapterRecordKind::Merge { - left_timestamp: 0, - right_timestamp: 0, - }, - } - ); - // then merge [2:3] and [3:4] into [2:4]. - assert_eq!( - adapter_records.records_for_n(2)[1], - AccessAdapterRecord { - timestamp: 0, - address_space: bb!(1), - start_index: bb!(2), - data: bbvec![0, 0], - kind: AccessAdapterRecordKind::Merge { - left_timestamp: 0, - right_timestamp: 0, - }, - } - ); - // then merge [0:2] and [2:4] into [0:4]. - assert_eq!( - adapter_records.records_for_n(4)[0], - AccessAdapterRecord { - timestamp: 0, - address_space: bb!(1), - start_index: bb!(0), - data: bbvec![0, 0, 0, 0], - kind: AccessAdapterRecordKind::Merge { - left_timestamp: 0, - right_timestamp: 0, - }, - } - ); - // At time 1 we write [0:4]. - let write_record = memory.last_record(); - assert_eq!( - write_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 1, - prev_timestamp: 0, - data: bbvec![1, 2, 3, 4], - prev_data: Some(bbvec![0, 0, 0, 0]), - } - ); - assert_eq!(memory.timestamp(), 2); - assert_eq!(adapter_records.total_records(), 3); - - memory.read(1, 0, 4, &mut adapter_records); - let read_record = memory.last_record(); - // At time 2 we read [0:4]. - assert_eq!(adapter_records.total_records(), 3); - assert_eq!( - read_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 2, - prev_timestamp: 1, - data: bbvec![1, 2, 3, 4], - prev_data: None, - } - ); - assert_eq!(memory.timestamp(), 3); - - memory.write(1, 0, bbvec![10, 11], &mut adapter_records); - let write_record = memory.last_record(); - // write causes split [0:4] into [0:2] and [2:4] (to prepare for write to [0:2]). - assert_eq!(adapter_records.total_records(), 4); - assert_eq!( - adapter_records.records_for_n(4).last().unwrap(), - &AccessAdapterRecord { - timestamp: 2, - address_space: bb!(1), - start_index: bb!(0), - data: bbvec![1, 2, 3, 4], - kind: AccessAdapterRecordKind::Split, - } - ); - - // At time 3 we write [10, 11] into [0, 2]. - assert_eq!( - write_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 3, - prev_timestamp: 2, - data: bbvec![10, 11], - prev_data: Some(bbvec![1, 2]), - } - ); - - memory.read(1, 0, 4, &mut adapter_records); - let read_record = memory.last_record(); - assert_eq!(adapter_records.total_records(), 5); - assert_eq!( - adapter_records.records_for_n(4).last().unwrap(), - &AccessAdapterRecord { - timestamp: 3, - address_space: bb!(1), - start_index: bb!(0), - data: bbvec![10, 11, 3, 4], - kind: AccessAdapterRecordKind::Merge { - left_timestamp: 3, - right_timestamp: 2 - }, - } - ); - // At time 9 we read [0:4]. - assert_eq!( - read_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 4, - prev_timestamp: 3, - data: bbvec![10, 11, 3, 4], - prev_data: None, - } - ); - } - - #[test] - fn test_records_initial_block_len_8() { - let (mut memory, mut adapter_records) = setup_test(MemoryImage::default(), 8); - - memory.write(1, 0, bbvec![1, 2, 3, 4], &mut adapter_records); - let write_record = memory.last_record(); - - // Above write first causes split of [0:8] into [0:4] and [4:8]. - assert_eq!(adapter_records.total_records(), 1); - assert_eq!( - adapter_records.records_for_n(8).last().unwrap(), - &AccessAdapterRecord { - timestamp: 0, - address_space: bb!(1), - start_index: bb!(0), - data: bbvec![0, 0, 0, 0, 0, 0, 0, 0], - kind: AccessAdapterRecordKind::Split, - } - ); - // At time 1 we write [0:4]. - assert_eq!( - write_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 1, - prev_timestamp: 0, - data: bbvec![1, 2, 3, 4], - prev_data: Some(bbvec![0, 0, 0, 0]), - } - ); - assert_eq!(memory.timestamp(), 2); - - memory.read(1, 0, 4, &mut adapter_records); - let read_record = memory.last_record(); - // At time 2 we read [0:4]. - assert_eq!(adapter_records.total_records(), 1); - assert_eq!( - read_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 2, - prev_timestamp: 1, - data: bbvec![1, 2, 3, 4], - prev_data: None, - } - ); - assert_eq!(memory.timestamp(), 3); - - memory.write(1, 0, bbvec![10, 11], &mut adapter_records); - let write_record = memory.last_record(); - // write causes split [0:4] into [0:2] and [2:4] (to prepare for write to [0:2]). - assert_eq!(adapter_records.total_records(), 2); - assert_eq!( - adapter_records.records_for_n(4).last().unwrap(), - &AccessAdapterRecord { - timestamp: 2, - address_space: bb!(1), - start_index: bb!(0), - data: bbvec![1, 2, 3, 4], - kind: AccessAdapterRecordKind::Split, - } - ); - - // At time 3 we write [10, 11] into [0, 2]. - assert_eq!( - write_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 3, - prev_timestamp: 2, - data: bbvec![10, 11], - prev_data: Some(bbvec![1, 2]), - } - ); - - memory.read(1, 0, 4, &mut adapter_records); - let read_record = memory.last_record(); - assert_eq!(adapter_records.total_records(), 3); - assert_eq!( - adapter_records.records_for_n(4).last().unwrap(), - &AccessAdapterRecord { - timestamp: 3, - address_space: bb!(1), - start_index: bb!(0), - data: bbvec![10, 11, 3, 4], - kind: AccessAdapterRecordKind::Merge { - left_timestamp: 3, - right_timestamp: 2 - }, - } - ); - // At time 9 we read [0:4]. - assert_eq!( - read_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 4, - prev_timestamp: 3, - data: bbvec![10, 11, 3, 4], - prev_data: None, - } - ); - } - - #[test] - fn test_get_initial_block_len_1() { - let (mut memory, mut adapter_records) = setup_test(MemoryImage::default(), 1); - - memory.write(2, 0, bbvec![4, 3, 2, 1], &mut adapter_records); - - assert_eq!(memory.get(2, 0), BabyBear::from_canonical_u32(4)); - assert_eq!(memory.get(2, 1), BabyBear::from_canonical_u32(3)); - assert_eq!(memory.get(2, 2), BabyBear::from_canonical_u32(2)); - assert_eq!(memory.get(2, 3), BabyBear::from_canonical_u32(1)); - assert_eq!(memory.get(2, 5), BabyBear::ZERO); - - assert_eq!(memory.get(1, 0), BabyBear::ZERO); - } - - #[test] - fn test_get_initial_block_len_8() { - let (mut memory, mut adapter_records) = setup_test(MemoryImage::default(), 8); - - memory.write(2, 0, bbvec![4, 3, 2, 1], &mut adapter_records); - - assert_eq!(memory.get(2, 0), BabyBear::from_canonical_u32(4)); - assert_eq!(memory.get(2, 1), BabyBear::from_canonical_u32(3)); - assert_eq!(memory.get(2, 2), BabyBear::from_canonical_u32(2)); - assert_eq!(memory.get(2, 3), BabyBear::from_canonical_u32(1)); - assert_eq!(memory.get(2, 5), BabyBear::ZERO); - assert_eq!(memory.get(2, 9), BabyBear::ZERO); - assert_eq!(memory.get(1, 0), BabyBear::ZERO); - } - - #[test] - fn test_finalize_empty() { - let (mut memory, mut adapter_records) = setup_test(MemoryImage::default(), 4); - - let memory = memory.finalize::<4>(&mut adapter_records); - assert_eq!(memory.len(), 0); - assert_eq!(adapter_records.total_records(), 0); - } - - #[test] - fn test_finalize_block_len_8() { - let (mut memory, mut adapter_records) = setup_test(MemoryImage::default(), 8); - // Make block 0:4 in address space 1 active. - memory.write(1, 0, bbvec![1, 2, 3, 4], &mut adapter_records); - - // Make block 16:32 in address space 1 active. - memory.write( - 1, - 16, - bbvec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - &mut adapter_records, - ); - - // Make block 64:72 in address space 2 active. - memory.write(2, 64, bbvec![8, 7, 6, 5, 4, 3, 2, 1], &mut adapter_records); - - let num_records_before_finalize = adapter_records.total_records(); - - // Finalize to a partition of size 8. - let final_memory = memory.finalize::<8>(&mut adapter_records); - assert_eq!(final_memory.len(), 4); - assert_eq!( - final_memory.get(&(1, 0)), - Some(&TimestampedValues { - values: bba![1, 2, 3, 4, 0, 0, 0, 0], - timestamp: 1, - }) - ); - // start_index = 16 corresponds to label = 2 - assert_eq!( - final_memory.get(&(1, 2)), - Some(&TimestampedValues { - values: bba![1, 1, 1, 1, 1, 1, 1, 1], - timestamp: 2, - }) - ); - // start_index = 24 corresponds to label = 3 - assert_eq!( - final_memory.get(&(1, 3)), - Some(&TimestampedValues { - values: bba![1, 1, 1, 1, 1, 1, 1, 1], - timestamp: 2, - }) - ); - // start_index = 64 corresponds to label = 8 - assert_eq!( - final_memory.get(&(2, 8)), - Some(&TimestampedValues { - values: bba![8, 7, 6, 5, 4, 3, 2, 1], - timestamp: 3, - }) - ); - - // We need to do 1 + 1 + 0 = 2 adapters. - assert_eq!( - adapter_records.total_records() - num_records_before_finalize, - 2 - ); - } - - #[test] - fn test_write_read_initial_block_len_8_initial_memory() { - type F = BabyBear; - - // Initialize initial memory with blocks at indices 0 and 2 - let mut initial_memory = MemoryImage::default(); - for i in 0..8 { - initial_memory.insert(&(1, i), F::from_canonical_u32(i + 1)); - initial_memory.insert(&(1, 16 + i), F::from_canonical_u32(i + 1)); - } - - let (mut memory, mut adapter_records) = setup_test(initial_memory, 8); - - // Verify initial state of block 0 (pointers 0–8) - memory.read(1, 0, 8, &mut adapter_records); - let initial_read_record_0 = memory.last_record(); - assert_eq!(initial_read_record_0.data, bbvec![1, 2, 3, 4, 5, 6, 7, 8]); - - // Verify initial state of block 2 (pointers 16–24) - memory.read(1, 16, 8, &mut adapter_records); - let initial_read_record_2 = memory.last_record(); - assert_eq!(initial_read_record_2.data, bbvec![1, 2, 3, 4, 5, 6, 7, 8]); - - // Test: Write a partial block to block 0 (pointer 0) and read back partially and fully - memory.write(1, 0, bbvec![9, 9, 9, 9], &mut adapter_records); - memory.read(1, 0, 2, &mut adapter_records); - let partial_read_record = memory.last_record(); - assert_eq!(partial_read_record.data, bbvec![9, 9]); - - memory.read(1, 0, 8, &mut adapter_records); - let full_read_record_0 = memory.last_record(); - assert_eq!(full_read_record_0.data, bbvec![9, 9, 9, 9, 5, 6, 7, 8]); - - // Test: Write a single element to pointer 2 and verify read in different lengths - memory.write(1, 2, bbvec![100], &mut adapter_records); - memory.read(1, 1, 4, &mut adapter_records); - let read_record_4 = memory.last_record(); - assert_eq!(read_record_4.data, bbvec![9, 100, 9, 5]); - - memory.read(1, 2, 8, &mut adapter_records); - let full_read_record_2 = memory.last_record(); - assert_eq!(full_read_record_2.data, bba![100, 9, 5, 6, 7, 8, 0, 0]); - - // Test: Write and read at the last pointer in block 2 (pointer 23, part of key (1, 2)) - memory.write(1, 23, bbvec![77], &mut adapter_records); - memory.read(1, 23, 2, &mut adapter_records); - let boundary_read_record = memory.last_record(); - assert_eq!(boundary_read_record.data, bba![77, 0]); // Last byte modified, ensuring boundary check - - // Test: Reading from an uninitialized block (should default to 0) - memory.read(1, 10, 4, &mut adapter_records); - let default_read_record = memory.last_record(); - assert_eq!(default_read_record.data, bba![0, 0, 0, 0]); - - memory.read(1, 100, 4, &mut adapter_records); - let default_read_record = memory.last_record(); - assert_eq!(default_read_record.data, bba![0, 0, 0, 0]); - - // Test: Overwrite entire memory pointer 16–24 and verify - memory.write( - 1, - 16, - bbvec![50, 50, 50, 50, 50, 50, 50, 50], - &mut adapter_records, - ); - memory.read(1, 16, 8, &mut adapter_records); - let overwrite_read_record = memory.last_record(); - assert_eq!( - overwrite_read_record.data, - bba![50, 50, 50, 50, 50, 50, 50, 50] - ); // Verify entire block overwrite - } -} diff --git a/crates/vm/src/system/memory/offline_checker/bridge.rs b/crates/vm/src/system/memory/offline_checker/bridge.rs index 2c7e180cfb..367e1344d7 100644 --- a/crates/vm/src/system/memory/offline_checker/bridge.rs +++ b/crates/vm/src/system/memory/offline_checker/bridge.rs @@ -1,3 +1,4 @@ +use getset::CopyGetters; use openvm_circuit_primitives::{ assert_less_than::{AssertLessThanIo, AssertLtSubAir}, is_zero::{IsZeroIo, IsZeroSubAir}, @@ -19,9 +20,9 @@ use crate::system::memory::{ /// AUX_LEN is the number of auxiliary columns (aka the number of limbs that the input numbers will /// be decomposed into) for the `AssertLtSubAir` in the `MemoryOfflineChecker`. -/// Warning: This requires that (clk_max_bits + decomp - 1) / decomp = AUX_LEN +/// Warning: This requires that (timestamp_max_bits + decomp - 1) / decomp = AUX_LEN /// in MemoryOfflineChecker (or whenever AssertLtSubAir is used) -pub(crate) const AUX_LEN: usize = 2; +pub const AUX_LEN: usize = 2; /// The [MemoryBridge] is used within AIR evaluation functions to constrain logical memory /// operations (read/write). It adds all necessary constraints and interactions. @@ -34,14 +35,22 @@ impl MemoryBridge { /// Create a new [MemoryBridge] with the provided offline_checker. pub fn new( memory_bus: MemoryBus, - clk_max_bits: usize, + timestamp_max_bits: usize, range_bus: VariableRangeCheckerBus, ) -> Self { Self { - offline_checker: MemoryOfflineChecker::new(memory_bus, clk_max_bits, range_bus), + offline_checker: MemoryOfflineChecker::new(memory_bus, timestamp_max_bits, range_bus), } } + pub fn memory_bus(&self) -> MemoryBus { + self.offline_checker.memory_bus + } + + pub fn range_bus(&self) -> VariableRangeCheckerBus { + self.offline_checker.timestamp_lt_air.bus + } + /// Prepare a logical memory read operation. #[must_use] pub fn read<'a, T, V, const N: usize>( @@ -256,17 +265,23 @@ impl, const N: usize> MemoryWriteOperation<'_ } } -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, CopyGetters)] struct MemoryOfflineChecker { + #[get_copy = "pub"] memory_bus: MemoryBus, + #[get_copy = "pub"] timestamp_lt_air: AssertLtSubAir, } impl MemoryOfflineChecker { - fn new(memory_bus: MemoryBus, clk_max_bits: usize, range_bus: VariableRangeCheckerBus) -> Self { + fn new( + memory_bus: MemoryBus, + timestamp_max_bits: usize, + range_bus: VariableRangeCheckerBus, + ) -> Self { Self { memory_bus, - timestamp_lt_air: AssertLtSubAir::new(range_bus, clk_max_bits), + timestamp_lt_air: AssertLtSubAir::new(range_bus, timestamp_max_bits), } } diff --git a/crates/vm/src/system/memory/offline_checker/columns.rs b/crates/vm/src/system/memory/offline_checker/columns.rs index 5a27b3e433..ef9821f859 100644 --- a/crates/vm/src/system/memory/offline_checker/columns.rs +++ b/crates/vm/src/system/memory/offline_checker/columns.rs @@ -9,37 +9,28 @@ use crate::system::memory::offline_checker::bridge::AUX_LEN; // repr(C) is needed to make sure that the compiler does not reorder the fields // we assume the order of the fields when using borrow or borrow_mut -#[repr(C)] /// Base structure for auxiliary memory columns. +#[repr(C)] #[derive(Clone, Copy, Debug, AlignedBorrow)] pub struct MemoryBaseAuxCols { /// The previous timestamps in which the cells were accessed. - pub(in crate::system::memory) prev_timestamp: T, + pub prev_timestamp: T, /// The auxiliary columns to perform the less than check. - pub(in crate::system::memory) timestamp_lt_aux: LessThanAuxCols, + pub timestamp_lt_aux: LessThanAuxCols, +} + +impl MemoryBaseAuxCols { + #[inline(always)] + pub fn set_prev(&mut self, prev_timestamp: F) { + self.prev_timestamp = prev_timestamp; + } } #[repr(C)] #[derive(Clone, Copy, Debug, AlignedBorrow)] pub struct MemoryWriteAuxCols { - pub(in crate::system::memory) base: MemoryBaseAuxCols, - pub(in crate::system::memory) prev_data: [T; N], -} - -impl MemoryWriteAuxCols { - pub(in crate::system::memory) fn new( - prev_data: [T; N], - prev_timestamp: T, - lt_aux: LessThanAuxCols, - ) -> Self { - Self { - base: MemoryBaseAuxCols { - prev_timestamp, - timestamp_lt_aux: lt_aux, - }, - prev_data, - } - } + pub base: MemoryBaseAuxCols, + pub prev_data: [T; N], } impl MemoryWriteAuxCols { @@ -47,13 +38,21 @@ impl MemoryWriteAuxCols { Self { base, prev_data } } + #[inline(always)] pub fn get_base(self) -> MemoryBaseAuxCols { self.base } + #[inline(always)] pub fn prev_data(&self) -> &[T; N] { &self.prev_data } + + /// Sets the previous data **without** updating the less than auxiliary columns. + #[inline(always)] + pub fn set_prev_data(&mut self, data: [T; N]) { + self.prev_data = data; + } } /// The auxiliary columns for a memory read operation with block size `N`. @@ -67,10 +66,7 @@ pub struct MemoryReadAuxCols { } impl MemoryReadAuxCols { - pub(in crate::system::memory) fn new( - prev_timestamp: u32, - timestamp_lt_aux: LessThanAuxCols, - ) -> Self { + pub fn new(prev_timestamp: u32, timestamp_lt_aux: LessThanAuxCols) -> Self { Self { base: MemoryBaseAuxCols { prev_timestamp: F::from_canonical_u32(prev_timestamp), @@ -79,17 +75,24 @@ impl MemoryReadAuxCols { } } + #[inline(always)] pub fn get_base(self) -> MemoryBaseAuxCols { self.base } + + /// Sets the previous timestamp **without** updating the less than auxiliary columns. + #[inline(always)] + pub fn set_prev(&mut self, timestamp: F) { + self.base.prev_timestamp = timestamp; + } } #[repr(C)] #[derive(Clone, Debug, AlignedBorrow)] pub struct MemoryReadOrImmediateAuxCols { - pub(crate) base: MemoryBaseAuxCols, - pub(crate) is_immediate: T, - pub(crate) is_zero_aux: T, + pub base: MemoryBaseAuxCols, + pub is_immediate: T, + pub is_zero_aux: T, } impl AsRef> for MemoryWriteAuxCols { @@ -102,3 +105,21 @@ impl AsRef> for MemoryWriteAuxCols unsafe { &*(self as *const MemoryWriteAuxCols as *const MemoryReadAuxCols) } } } + +impl AsMut> for MemoryWriteAuxCols { + fn as_mut(&mut self) -> &mut MemoryBaseAuxCols { + &mut self.base + } +} + +impl AsMut> for MemoryReadAuxCols { + fn as_mut(&mut self) -> &mut MemoryBaseAuxCols { + &mut self.base + } +} + +impl AsMut> for MemoryReadOrImmediateAuxCols { + fn as_mut(&mut self) -> &mut MemoryBaseAuxCols { + &mut self.base + } +} diff --git a/crates/vm/src/system/memory/offline_checker/mod.rs b/crates/vm/src/system/memory/offline_checker/mod.rs index ac9f32dc18..8b15328185 100644 --- a/crates/vm/src/system/memory/offline_checker/mod.rs +++ b/crates/vm/src/system/memory/offline_checker/mod.rs @@ -5,3 +5,18 @@ mod columns; pub use bridge::*; pub use bus::*; pub use columns::*; + +#[repr(C)] +#[derive(Debug, Clone)] +pub struct MemoryReadAuxRecord { + pub prev_timestamp: u32, +} + +#[repr(C)] +#[derive(Debug, Clone)] +pub struct MemoryWriteAuxRecord { + pub prev_timestamp: u32, + pub prev_data: [T; NUM_LIMBS], +} + +pub type MemoryWriteBytesAuxRecord = MemoryWriteAuxRecord; diff --git a/crates/vm/src/system/memory/online.rs b/crates/vm/src/system/memory/online.rs index a5bf663e4c..7302d17827 100644 --- a/crates/vm/src/system/memory/online.rs +++ b/crates/vm/src/system/memory/online.rs @@ -1,151 +1,1110 @@ -use std::fmt::Debug; +use std::{array::from_fn, fmt::Debug, num::NonZero}; -use openvm_stark_backend::p3_field::PrimeField32; -use serde::{Deserialize, Serialize}; +use getset::Getters; +use itertools::zip_eq; +use openvm_instructions::exe::SparseMemoryImage; +use openvm_stark_backend::{ + p3_field::{Field, PrimeField32}, + p3_maybe_rayon::prelude::*, + p3_util::log2_strict_usize, +}; +use tracing::instrument; -use super::paged_vec::{AddressMap, PAGE_SIZE}; use crate::{ - arch::MemoryConfig, - system::memory::{offline::INITIAL_TIMESTAMP, MemoryImage, RecordId}, + arch::{ + AddressSpaceHostConfig, AddressSpaceHostLayout, DenseRecordArena, MemoryConfig, + RecordArena, MAX_CELL_BYTE_SIZE, + }, + system::{ + memory::{ + adapter::records::{AccessLayout, AccessRecordHeader, MERGE_AND_NOT_SPLIT_FLAG}, + MemoryAddress, TimestampedEquipartition, TimestampedValues, CHUNK, + }, + TouchedMemory, + }, + utils::slice_as_bytes, }; -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum MemoryLogEntry { - Read { - address_space: u32, - pointer: u32, +mod basic; +#[cfg(any(unix, windows))] +mod memmap; +mod paged_vec; + +#[cfg(not(any(unix, windows)))] +pub use basic::*; +#[cfg(any(unix, windows))] +pub use memmap::*; +pub use paged_vec::PagedVec; + +#[cfg(all(any(unix, windows), not(feature = "basic-memory")))] +pub type MemoryBackend = memmap::MmapMemory; +#[cfg(any(not(any(unix, windows)), feature = "basic-memory"))] +pub type MemoryBackend = basic::BasicMemory; + +pub const INITIAL_TIMESTAMP: u32 = 0; +/// Default mmap page size. Change this if using THB. +pub const PAGE_SIZE: usize = 4096; + +// Memory access constraints +const MAX_BLOCK_SIZE: usize = 32; +const MIN_ALIGN: usize = 1; +const MAX_SEGMENTS: usize = MAX_BLOCK_SIZE / MIN_ALIGN; + +/// (address_space, pointer) +pub type Address = (u32, u32); + +/// API for any memory implementation that allocates a contiguous region of memory. +pub trait LinearMemory { + /// Create instance of `Self` with `size` bytes. + fn new(size: usize) -> Self; + /// Allocated size of the memory in bytes. + fn size(&self) -> usize; + /// Returns the entire memory as a raw byte slice. + fn as_slice(&self) -> &[u8]; + /// Returns the entire memory as a raw byte slice. + fn as_mut_slice(&mut self) -> &mut [u8]; + /// Read `BLOCK` from `self` at `from` address without moving it. + /// + /// Panics or segfaults if `from..from + size_of::()` is out of bounds. + /// + /// # Safety + /// - `BLOCK` should be "plain old data" (see [`Pod`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html)). + /// We do not add a trait bound due to Plonky3 types not implementing the trait. + /// - See [`core::ptr::read`] for similar considerations. + /// - Memory at `from` must be properly aligned for `BLOCK`. Use [`Self::read_unaligned`] if + /// alignment is not guaranteed. + unsafe fn read(&self, from: usize) -> BLOCK; + /// Read `BLOCK` from `self` at `from` address without moving it. + /// Same as [`Self::read`] except that it does not require alignment. + /// + /// Panics or segfaults if `from..from + size_of::()` is out of bounds. + /// + /// # Safety + /// - `BLOCK` should be "plain old data" (see [`Pod`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html)). + /// We do not add a trait bound due to Plonky3 types not implementing the trait. + /// - See [`core::ptr::read`] for similar considerations. + unsafe fn read_unaligned(&self, from: usize) -> BLOCK; + /// Write `BLOCK` to `self` at `start` address without reading the old value. Does not drop + /// `values`. Semantically, `values` is moved into the location pointed to by `start`. + /// + /// Panics or segfaults if `start..start + size_of::()` is out of bounds. + /// + /// # Safety + /// - See [`core::ptr::write`] for similar considerations. + /// - Memory at `start` must be properly aligned for `BLOCK`. Use [`Self::write_unaligned`] if + /// alignment is not guaranteed. + unsafe fn write(&mut self, start: usize, values: BLOCK); + /// Write `BLOCK` to `self` at `start` address without reading the old value. Does not drop + /// `values`. Semantically, `values` is moved into the location pointed to by `start`. + /// Same as [`Self::write`] but without alignment requirement. + /// + /// Panics or segfaults if `start..start + size_of::()` is out of bounds. + /// + /// # Safety + /// - See [`core::ptr::write`] for similar considerations. + unsafe fn write_unaligned(&mut self, start: usize, values: BLOCK); + /// Swaps `values` with memory at `start..start + size_of::()`. + /// + /// Panics or segfaults if `start..start + size_of::()` is out of bounds. + /// + /// # Safety + /// - `BLOCK` should be "plain old data" (see [`Pod`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html)). + /// We do not add a trait bound due to Plonky3 types not implementing the trait. + /// - Memory at `start` must be properly aligned for `BLOCK`. + /// - The data in `values` should not overlap with memory in `self`. + unsafe fn swap(&mut self, start: usize, values: &mut BLOCK); + /// Copies `data` into memory at `to` address. + /// + /// Panics or segfaults if `to..to + size_of_val(data)` is out of bounds. + /// + /// # Safety + /// - `T` should be "plain old data" (see [`Pod`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html)). + /// We do not add a trait bound due to Plonky3 types not implementing the trait. + /// - The underlying memory of `data` should not overlap with `self`. + /// - The starting pointer of `self` should be aligned to `T`. + /// - The memory pointer at `to` should be aligned to `T`. + unsafe fn copy_nonoverlapping(&mut self, to: usize, data: &[T]); + /// Returns a slice `&[T]` for the memory region `start..start + len`. + /// + /// Panics or segfaults if `start..start + len * size_of::()` is out of bounds. + /// + /// # Safety + /// - `T` should be "plain old data" (see [`Pod`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html)). + /// We do not add a trait bound due to Plonky3 types not implementing the trait. + /// - Memory at `start` must be properly aligned for `T`. + unsafe fn get_aligned_slice(&self, start: usize, len: usize) -> &[T]; +} + +/// Map from address space to linear memory. +/// The underlying memory is typeless, stored as raw bytes, but usage implicitly assumes that each +/// address space has memory cells of a fixed type (e.g., `u8, F`). We do not use a typemap for +/// performance reasons, and it is up to the user to enforce types. Needless to say, this is a very +/// `unsafe` API. +#[derive(Debug, Clone)] +pub struct AddressMap { + /// Underlying memory data. + pub mem: Vec, + /// Host configuration for each address space. + pub config: Vec, +} + +impl Default for AddressMap { + fn default() -> Self { + Self::from_mem_config(&MemoryConfig::default()) + } +} + +impl AddressMap { + pub fn new(config: Vec) -> Self { + assert_eq!(config[0].num_cells, 0, "Address space 0 must have 0 cells"); + let mem = config + .iter() + .map(|config| M::new(config.num_cells.checked_mul(config.layout.size()).unwrap())) + .collect(); + Self { mem, config } + } + + pub fn from_mem_config(mem_config: &MemoryConfig) -> Self { + Self::new(mem_config.addr_spaces.clone()) + } + + #[inline(always)] + pub fn get_memory(&self) -> &Vec { + &self.mem + } + + #[inline(always)] + pub fn get_memory_mut(&mut self) -> &mut Vec { + &mut self.mem + } + + /// # Safety + /// - Assumes `addr_space` is within the configured memory and not out of bounds + pub unsafe fn get_f(&self, addr_space: u32, ptr: u32) -> F { + let layout = &self.config.get_unchecked(addr_space as usize).layout; + let start = ptr as usize * layout.size(); + let bytes = self.get_u8_slice(addr_space, start, layout.size()); + layout.to_field(bytes) + } + + /// # Safety + /// - `T` **must** be the correct type for a single memory cell for `addr_space` + /// - Assumes `addr_space` is within the configured memory and not out of bounds + pub unsafe fn get(&self, (addr_space, ptr): Address) -> T { + debug_assert_eq!( + size_of::(), + self.config[addr_space as usize].layout.size() + ); + // SAFETY: + // - alignment is automatic since we multiply by `size_of::()` + self.mem + .get_unchecked(addr_space as usize) + .read((ptr as usize) * size_of::()) + } + + /// Panics or segfaults if `ptr..ptr + len` is out of bounds + /// + /// # Safety + /// - `T` **must** be the correct type for a single memory cell for `addr_space` + /// - Assumes `addr_space` is within the configured memory and not out of bounds + pub unsafe fn get_slice( + &self, + (addr_space, ptr): Address, len: usize, - }, - Write { - address_space: u32, - pointer: u32, - data: Vec, - }, - IncrementTimestampBy(u32), + ) -> &[T] { + debug_assert_eq!( + size_of::(), + self.config[addr_space as usize].layout.size() + ); + let start = (ptr as usize) * size_of::(); + let mem = self.mem.get_unchecked(addr_space as usize); + // SAFETY: + // - alignment is automatic since we multiply by `size_of::()` + mem.get_aligned_slice(start, len) + } + + /// Reads the slice at **byte** addresses `start..start + len` from address space `addr_space` + /// linear memory. Panics or segfaults if `start..start + len` is out of bounds + /// + /// # Safety + /// - Assumes `addr_space` is within the configured memory and not out of bounds + pub unsafe fn get_u8_slice(&self, addr_space: u32, start: usize, len: usize) -> &[u8] { + let mem = self.mem.get_unchecked(addr_space as usize); + mem.get_aligned_slice(start, len) + } + + /// Copies `data` into the memory at `(addr_space, ptr)`. + /// + /// Panics or segfaults if `ptr + size_of_val(data)` is out of bounds. + /// + /// # Safety + /// - `T` **must** be the correct type for a single memory cell for `addr_space` + /// - The linear memory in `addr_space` is aligned to `T`. + pub unsafe fn copy_slice_nonoverlapping( + &mut self, + (addr_space, ptr): Address, + data: &[T], + ) { + let start = (ptr as usize) * size_of::(); + // SAFETY: + // - Linear memory is aligned to `T` and `start` is multiple of `size_of::()` so + // alignment is satisfied. + // - `data` and `self.mem` are non-overlapping + self.mem + .get_unchecked_mut(addr_space as usize) + .copy_nonoverlapping(start, data); + } + + // TODO[jpw]: stabilize the boundary memory image format and how to construct + /// # Safety + /// - `T` **must** be the correct type for a single memory cell for `addr_space` + /// - Assumes `addr_space` is within the configured memory and not out of bounds + pub fn from_sparse(config: Vec, sparse_map: SparseMemoryImage) -> Self { + let mut vec = Self::new(config); + for ((addr_space, index), data_byte) in sparse_map.into_iter() { + // SAFETY: + // - safety assumptions in function doc comments + unsafe { + vec.mem + .get_unchecked_mut(addr_space as usize) + .write_unaligned(index as usize, data_byte); + } + } + vec + } } -/// A simple data structure to read to/write from memory. -/// -/// Stores a log of memory accesses to reconstruct aspects of memory state for trace generation. -#[derive(Debug)] -pub struct Memory { - pub(super) data: AddressMap, - pub(super) log: Vec>, - timestamp: u32, +/// API for guest memory conforming to OpenVM ISA +// @dev Note we don't make this a trait because phantom executors currently need a concrete type for +// guest memory +#[derive(Debug, Clone)] +pub struct GuestMemory { + pub memory: AddressMap, } -impl Memory { - pub fn new(mem_config: &MemoryConfig) -> Self { +impl GuestMemory { + pub fn new(addr: AddressMap) -> Self { + Self { memory: addr } + } + + /// Returns `[pointer:BLOCK_SIZE]_{address_space}` + /// + /// # Safety + /// The type `T` must be stack-allocated `repr(C)` or `repr(transparent)`, + /// and it must be the exact type used to represent a single memory cell in + /// address space `address_space`. For standard usage, + /// `T` is either `u8` or `F` where `F` is the base field of the ZK backend. + #[inline(always)] + pub unsafe fn read( + &self, + addr_space: u32, + ptr: u32, + ) -> [T; BLOCK_SIZE] + where + T: Copy + Debug, + { + self.debug_assert_cell_type::(addr_space); + // SAFETY: + // - `T` should be "plain old data" + // - alignment for `[T; BLOCK_SIZE]` is automatic since we multiply by `size_of::()` + self.memory + .get_memory() + .get_unchecked(addr_space as usize) + .read((ptr as usize) * size_of::()) + } + + /// Writes `values` to `[pointer:BLOCK_SIZE]_{address_space}` + /// + /// # Safety + /// See [`GuestMemory::read`]. + #[inline(always)] + pub unsafe fn write( + &mut self, + addr_space: u32, + ptr: u32, + values: [T; BLOCK_SIZE], + ) where + T: Copy + Debug, + { + self.debug_assert_cell_type::(addr_space); + // SAFETY: + // - alignment for `[T; BLOCK_SIZE]` is automatic since we multiply by `size_of::()` + self.memory + .get_memory_mut() + .get_unchecked_mut(addr_space as usize) + .write((ptr as usize) * size_of::(), values); + } + + /// Swaps `values` with `[pointer:BLOCK_SIZE]_{address_space}`. + /// + /// # Safety + /// See [`GuestMemory::read`] and [`LinearMemory::swap`]. + #[inline(always)] + pub unsafe fn swap( + &mut self, + addr_space: u32, + ptr: u32, + values: &mut [T; BLOCK_SIZE], + ) where + T: Copy + Debug, + { + self.debug_assert_cell_type::(addr_space); + // SAFETY: + // - alignment for `[T; BLOCK_SIZE]` is automatic since we multiply by `size_of::()` + self.memory + .get_memory_mut() + .get_unchecked_mut(addr_space as usize) + .swap((ptr as usize) * size_of::(), values); + } + + #[inline(always)] + #[allow(clippy::missing_safety_doc)] + pub unsafe fn get_slice(&self, addr_space: u32, ptr: u32, len: usize) -> &[T] { + self.memory.get_slice((addr_space, ptr), len) + } + + #[inline(always)] + fn debug_assert_cell_type(&self, addr_space: u32) { + debug_assert_eq!( + size_of::(), + self.memory.config[addr_space as usize].layout.size() + ); + } +} + +#[repr(C)] +#[derive(Clone, Copy, PartialEq, Eq, Debug, Default)] +pub struct AccessMetadata { + /// Packed timestamp (29 bits) and log2(block_size) (3 bits) + pub timestamp_and_log_block_size: u32, + /// Offset to block start (in ALIGN units). + pub offset_to_start: u8, +} + +impl AccessMetadata { + const TIMESTAMP_MASK: u32 = (1 << 29) - 1; + const LOG_BLOCK_SIZE_SHIFT: u32 = 29; + + pub fn new(timestamp: u32, block_size: u8, offset_to_start: u8) -> Self { + debug_assert!(timestamp < (1 << 29), "Timestamp must be less than 2^29"); + debug_assert!( + block_size == 0 || (block_size.is_power_of_two() && block_size <= MAX_BLOCK_SIZE as u8), + "Block size must be 0 or power of 2 and <= {}", + MAX_BLOCK_SIZE + ); + + let encoded_block_size = if block_size == 0 { + 0 + } else { + // SAFETY: We already asserted that block_size is non-zero in this branch + unsafe { NonZero::new_unchecked(block_size) }.ilog2() + 1 + }; + let packed = timestamp | (encoded_block_size << Self::LOG_BLOCK_SIZE_SHIFT); + Self { - data: AddressMap::from_mem_config(mem_config), - timestamp: INITIAL_TIMESTAMP + 1, - log: Vec::with_capacity(mem_config.access_capacity), + timestamp_and_log_block_size: packed, + offset_to_start, } } - /// Instantiates a new `Memory` data structure from an image. - pub fn from_image(image: MemoryImage, access_capacity: usize) -> Self { + pub fn timestamp(&self) -> u32 { + self.timestamp_and_log_block_size & Self::TIMESTAMP_MASK + } + + pub fn block_size(&self) -> u8 { + let encoded = self.timestamp_and_log_block_size >> Self::LOG_BLOCK_SIZE_SHIFT; + if encoded == 0 { + 0 + } else { + 1 << (encoded - 1) + } + } +} + +/// Online memory that stores additional information for trace generation purposes. +/// In particular, keeps track of timestamp. +#[derive(Getters)] +pub struct TracingMemory { + pub timestamp: u32, + /// The initial block size -- this depends on the type of boundary chip. + initial_block_size: usize, + /// The underlying data memory, with memory cells typed by address space: see [AddressMap]. + #[getset(get = "pub")] + pub data: GuestMemory, + /// Maps addr_space to (ptr / min_block_size[addr_space] -> AccessMetadata) for latest access + /// metadata. Uses paged storage for memory efficiency. AccessMetadata stores offset_to_start + /// (in ALIGN units), block_size, and timestamp (latter two only valid at offset_to_start == + /// 0). + pub(super) meta: Vec>, + /// For each `addr_space`, the minimum block size allowed for memory accesses. In other words, + /// all memory accesses in `addr_space` must be aligned to this block size. + pub min_block_size: Vec, + pub access_adapter_records: DenseRecordArena, +} + +// min_block_size * cell_size never exceeds 8 +const INITIAL_CELL_BUFFER: &[u8] = &[0u8; 8]; +// min_block_size never exceeds 8 +const INITIAL_TIMESTAMP_BUFFER: &[u32] = &[INITIAL_TIMESTAMP; 8]; + +impl TracingMemory { + pub fn new( + mem_config: &MemoryConfig, + initial_block_size: usize, + access_adapter_arena_size_bound: usize, + ) -> Self { + let image = GuestMemory::new(AddressMap::from_mem_config(mem_config)); + Self::from_image(image, initial_block_size, access_adapter_arena_size_bound) + } + + /// Constructor from pre-existing memory image. + pub fn from_image( + image: GuestMemory, + initial_block_size: usize, + access_adapter_arena_size_bound: usize, + ) -> Self { + let (meta, min_block_size): (Vec<_>, Vec<_>) = + zip_eq(image.memory.get_memory(), &image.memory.config) + .map(|(mem, addr_sp)| { + let num_cells = mem.size() / addr_sp.layout.size(); + let min_block_size = addr_sp.min_block_size; + let total_metadata_len = num_cells.div_ceil(min_block_size); + (PagedVec::new(total_metadata_len), min_block_size as u32) + }) + .unzip(); + let access_adapter_records = + DenseRecordArena::with_byte_capacity(access_adapter_arena_size_bound); Self { data: image, + meta, + min_block_size, timestamp: INITIAL_TIMESTAMP + 1, - log: Vec::with_capacity(access_capacity), + initial_block_size, + access_adapter_records, } } - fn last_record_id(&self) -> RecordId { - RecordId(self.log.len() - 1) + #[inline(always)] + fn assert_alignment(&self, block_size: usize, align: usize, addr_space: u32, ptr: u32) { + debug_assert!(block_size.is_power_of_two()); + debug_assert_eq!(block_size % align, 0); + debug_assert_ne!(addr_space, 0); + debug_assert_eq!(align as u32, self.min_block_size[addr_space as usize]); + assert_eq!( + ptr % (align as u32), + 0, + "pointer={ptr} not aligned to {align}" + ); } - /// Writes an array of values to the memory at the specified address space and start index. - /// - /// Returns the `RecordId` for the memory record and the previous data. - pub fn write( + /// Get block metadata by jumping to the start of the block. + /// Returns (block_start_pointer, block_metadata). + #[inline(always)] + fn get_block_metadata( &mut self, - address_space: u32, - pointer: u32, - values: [F; N], - ) -> (RecordId, [F; N]) { - assert!(N.is_power_of_two()); + address_space: usize, + pointer: usize, + ) -> (u32, AccessMetadata) { + let ptr_index = pointer / ALIGN; + let meta_page = unsafe { self.meta.get_unchecked_mut(address_space) }; + let current_meta = meta_page.get(ptr_index); - let prev_data = self.data.set_range(&(address_space, pointer), &values); + let (block_start_index, block_metadata) = if current_meta.offset_to_start == 0 { + (ptr_index, current_meta) + } else { + let offset = current_meta.offset_to_start; + let start_idx = ptr_index - offset as usize; + let start_meta = meta_page.get(start_idx); + (start_idx, start_meta) + }; - self.log.push(MemoryLogEntry::Write { - address_space, - pointer, - data: values.to_vec(), - }); - self.timestamp += 1; + let block_start_pointer = (block_start_index * ALIGN) as u32; - (self.last_record_id(), prev_data) + (block_start_pointer, block_metadata) } - /// Reads an array of values from the memory at the specified address space and start index. - pub fn read(&mut self, address_space: u32, pointer: u32) -> (RecordId, [F; N]) { - assert!(N.is_power_of_two()); + #[inline(always)] + fn get_timestamp(&mut self, address_space: usize, pointer: usize) -> u32 { + let ptr_index = pointer / ALIGN; + let meta_page = unsafe { self.meta.get_unchecked_mut(address_space) }; + let current_meta = meta_page.get(ptr_index); + + if current_meta.offset_to_start == 0 { + current_meta.timestamp() + } else { + let offset = current_meta.offset_to_start; + let block_start_index = ptr_index - offset as usize; + meta_page.get(block_start_index).timestamp() + } + } + + /// Updates the metadata with the given block. + /// Stores timestamp and block_size only at block start, offsets elsewhere. + #[inline(always)] + fn set_meta_block( + &mut self, + address_space: usize, + pointer: usize, + timestamp: u32, + ) { + let ptr = pointer / ALIGN; + // SAFETY: address_space is assumed to be valid and within bounds + let meta_page = unsafe { self.meta.get_unchecked_mut(address_space) }; + + // Store full metadata at the block start + meta_page.set(ptr, AccessMetadata::new(timestamp, BLOCK_SIZE as u8, 0)); + + // Store offsets for other positions in the block + for i in 1..(BLOCK_SIZE / ALIGN) { + meta_page.set(ptr + i, AccessMetadata::new(0, 0, i as u8)); + } + } + + pub(crate) fn add_split_record(&mut self, header: AccessRecordHeader) { + if header.block_size == header.lowest_block_size { + return; + } + let data_slice = unsafe { + self.data.memory.get_u8_slice( + header.address_space, + (header.pointer * header.type_size) as usize, + (header.block_size * header.type_size) as usize, + ) + }; + + let record_mut = self + .access_adapter_records + .alloc(AccessLayout::from_record_header(&header)); + *record_mut.header = header; + record_mut.data.copy_from_slice(data_slice); + // we don't mind garbage values in prev_* + } - self.log.push(MemoryLogEntry::Read { - address_space, - pointer, - len: N, + /// `data_slice` is the underlying data of the record in raw host memory format. + pub(crate) fn add_merge_record( + &mut self, + header: AccessRecordHeader, + data_slice: &[u8], + prev_ts: &[u32], + ) { + if header.block_size == header.lowest_block_size { + return; + } + + let record_mut = self + .access_adapter_records + .alloc(AccessLayout::from_record_header(&header)); + *record_mut.header = header; + record_mut.header.timestamp_and_mask |= MERGE_AND_NOT_SPLIT_FLAG; + record_mut.data.copy_from_slice(data_slice); + record_mut.timestamps.copy_from_slice(prev_ts); + } + + /// Calculate splits and merges needed for a memory access. + /// Returns Some((splits, merge)) or None if no operations needed. + #[inline(always)] + #[allow(clippy::type_complexity)] + fn calculate_splits_and_merges( + &mut self, + address_space: usize, + pointer: usize, + ) -> Option<(Vec<(usize, usize)>, (usize, usize))> { + // Skip adapters if this is a repeated access to the same location with same size + let (start_ptr, block_meta) = self.get_block_metadata::(address_space, pointer); + if block_meta.block_size() == BLOCK_SIZE as u8 && start_ptr == pointer as u32 { + return None; + } + + // Split intersecting blocks to align bytes + let mut splits_buf = [(0usize, 0usize); MAX_SEGMENTS]; + let mut splits_count = 0; + let mut current_ptr = pointer; + let end_ptr = pointer + BLOCK_SIZE; + + while current_ptr < end_ptr { + let (start_ptr, block_metadata) = + self.get_block_metadata::(address_space, current_ptr); + + if block_metadata.block_size() == 0 { + current_ptr += ALIGN; + continue; + } + + if block_metadata.block_size() > ALIGN as u8 { + // SAFETY: splits_count < MAX_SEGMENTS by construction since we iterate over + // at most BLOCK_SIZE/ALIGN segments and BLOCK_SIZE <= MAX_BLOCK_SIZE + unsafe { + *splits_buf.get_unchecked_mut(splits_count) = + (start_ptr as usize, block_metadata.block_size() as usize); + } + splits_count += 1; + } + + // Skip to the next segment after this block ends + current_ptr = start_ptr as usize + block_metadata.block_size() as usize; + } + + let merge = (pointer, BLOCK_SIZE); + + Some((splits_buf[..splits_count].to_vec(), merge)) + } + + #[inline(always)] + fn split_by_meta( + &mut self, + start_ptr: u32, + timestamp: u32, + block_size: u8, + address_space: usize, + ) { + if block_size == MIN_BLOCK_SIZE as u8 { + return; + } + let begin = start_ptr as usize / MIN_BLOCK_SIZE; + let meta_page = unsafe { self.meta.get_unchecked_mut(address_space) }; + + for i in 0..(block_size as usize / MIN_BLOCK_SIZE) { + // Each split piece becomes its own block start + meta_page.set( + begin + i, + AccessMetadata::new(timestamp, MIN_BLOCK_SIZE as u8, 0), + ); + } + self.add_split_record(AccessRecordHeader { + timestamp_and_mask: timestamp, + address_space: address_space as u32, + pointer: start_ptr, + block_size: block_size as u32, + lowest_block_size: MIN_BLOCK_SIZE as u32, + type_size: size_of::() as u32, }); + } - let values = if address_space == 0 { - assert_eq!(N, 1, "cannot batch read from address space 0"); - [F::from_canonical_u32(pointer); N] + /// Returns the timestamp of the previous access to `[pointer:BLOCK_SIZE]_{address_space}`. + /// + /// Caller must ensure alignment (e.g. via `assert_alignment`) prior to calling this function. + #[inline(always)] + fn prev_access_time( + &mut self, + address_space: usize, + pointer: usize, + prev_values: &[T; BLOCK_SIZE], + ) -> u32 { + debug_assert_eq!(ALIGN, self.data.memory.config[address_space].min_block_size); + debug_assert_eq!( + unsafe { + self.data + .memory + .config + .get_unchecked(address_space) + .layout + .size() + }, + size_of::() + ); + // Calculate what splits and merges are needed for this memory access + let result = if let Some((splits, (merge_ptr, merge_size))) = + self.calculate_splits_and_merges::(address_space, pointer) + { + // Process all splits first + for (split_ptr, split_size) in splits { + let (_, block_metadata) = + self.get_block_metadata::(address_space, split_ptr); + let timestamp = block_metadata.timestamp(); + self.split_by_meta::( + split_ptr as u32, + timestamp, + split_size as u8, + address_space, + ); + } + + // Process merge + let mut prev_ts_buf = [0u32; MAX_SEGMENTS]; + + let mut max_timestamp = INITIAL_TIMESTAMP; + + let mut ptr = merge_ptr; + let end_ptr = merge_ptr + merge_size; + let mut seg_idx = 0; + while ptr < end_ptr { + let (_, block_metadata) = self.get_block_metadata::(address_space, ptr); + + let timestamp = if block_metadata.block_size() > 0 { + block_metadata.timestamp() + } else { + self.handle_uninitialized_memory::(address_space, ptr); + INITIAL_TIMESTAMP + }; + + // SAFETY: seg_idx < MAX_SEGMENTS since we iterate at most merge_size/ALIGN times + // and merge_size <= BLOCK_SIZE <= MAX_BLOCK_SIZE + unsafe { + *prev_ts_buf.get_unchecked_mut(seg_idx) = timestamp; + } + max_timestamp = max_timestamp.max(timestamp); + ptr += ALIGN; + seg_idx += 1; + } + + // Create the merge record + self.add_merge_record( + AccessRecordHeader { + timestamp_and_mask: max_timestamp, + address_space: address_space as u32, + pointer: merge_ptr as u32, + block_size: merge_size as u32, + lowest_block_size: ALIGN as u32, + type_size: size_of::() as u32, + }, + // SAFETY: T is plain old data + unsafe { slice_as_bytes(prev_values) }, + &prev_ts_buf[..seg_idx], + ); + + max_timestamp } else { - self.range_array::(address_space, pointer) + self.get_timestamp::(address_space, pointer) }; + + // Update the metadata for this access + self.set_meta_block::(address_space, pointer, self.timestamp); + result + } + + /// Handle uninitialized memory by creating appropriate split or merge records. + #[inline(always)] + fn handle_uninitialized_memory( + &mut self, + address_space: usize, + pointer: usize, + ) { + if self.initial_block_size >= ALIGN { + // Split the initial block into chunks + let segment_index = pointer / ALIGN; + let block_start = segment_index & !(self.initial_block_size / ALIGN - 1); + let start_ptr = (block_start * ALIGN) as u32; + self.split_by_meta::( + start_ptr, + INITIAL_TIMESTAMP, + self.initial_block_size as u8, + address_space, + ); + } else { + // Create a merge record for single-byte initialization + debug_assert_eq!(self.initial_block_size, 1); + self.add_merge_record( + AccessRecordHeader { + timestamp_and_mask: INITIAL_TIMESTAMP, + address_space: address_space as u32, + pointer: pointer as u32, + block_size: ALIGN as u32, + lowest_block_size: self.initial_block_size as u32, + type_size: size_of::() as u32, + }, + &INITIAL_CELL_BUFFER[..ALIGN], + &INITIAL_TIMESTAMP_BUFFER[..ALIGN], + ); + } + } + + /// Atomic read operation which increments the timestamp by 1. + /// Returns `(t_prev, [pointer:BLOCK_SIZE]_{address_space})` where `t_prev` is the + /// timestamp of the last memory access. + /// + /// The previous memory access is treated as atomic even if previous accesses were for + /// a smaller block size. This is made possible by internal memory access adapters + /// that split/merge memory blocks. More specifically, the last memory access corresponding + /// to `t_prev` may refer to an atomic access inserted by the memory access adapters. + /// + /// # Assumptions + /// The `BLOCK_SIZE` is a multiple of `ALIGN`, which must equal the minimum block size + /// of `address_space`. + /// + /// # Safety + /// The type `T` must be stack-allocated `repr(C)` or `repr(transparent)`, + /// plain old data, and it must be the exact type used to represent a single memory cell in + /// address space `address_space`. For standard usage, + /// `T` is either `u8` or `F` where `F` is the base field of the ZK backend. + /// + /// In addition: + /// - `address_space` must be valid. + #[inline(always)] + pub unsafe fn read( + &mut self, + address_space: u32, + pointer: u32, + ) -> (u32, [T; BLOCK_SIZE]) + where + T: Copy + Debug, + { + self.assert_alignment(BLOCK_SIZE, ALIGN, address_space, pointer); + let values = self.data.read(address_space, pointer); + let t_prev = self.prev_access_time::( + address_space as usize, + pointer as usize, + &values, + ); + self.timestamp += 1; + + (t_prev, values) + } + + /// Atomic write operation that writes `values` into `[pointer:BLOCK_SIZE]_{address_space}` and + /// then increments the timestamp by 1. Returns `(t_prev, values_prev)` which equal the + /// timestamp and value `[pointer:BLOCK_SIZE]_{address_space}` of the last memory access. + /// + /// The previous memory access is treated as atomic even if previous accesses were for + /// a smaller block size. This is made possible by internal memory access adapters + /// that split/merge memory blocks. More specifically, the last memory access corresponding + /// to `t_prev` may refer to an atomic access inserted by the memory access adapters. + /// + /// # Assumptions + /// The `BLOCK_SIZE` is a multiple of `ALIGN`, which must equal the minimum block size + /// of `address_space`. + /// + /// # Safety + /// The type `T` must be stack-allocated `repr(C)` or `repr(transparent)`, + /// and it must be the exact type used to represent a single memory cell in + /// address space `address_space`. For standard usage, + /// `T` is either `u8` or `F` where `F` is the base field of the ZK backend. + /// + /// In addition: + /// - `address_space` must be valid. + #[inline(always)] + pub unsafe fn write( + &mut self, + address_space: u32, + pointer: u32, + values: [T; BLOCK_SIZE], + ) -> (u32, [T; BLOCK_SIZE]) + where + T: Copy + Debug, + { + self.assert_alignment(BLOCK_SIZE, ALIGN, address_space, pointer); + let values_prev = self.data.read(address_space, pointer); + let t_prev = self.prev_access_time::( + address_space as usize, + pointer as usize, + &values_prev, + ); + self.data.write(address_space, pointer, values); + self.timestamp += 1; + + (t_prev, values_prev) + } + + pub fn increment_timestamp(&mut self) { self.timestamp += 1; - (self.last_record_id(), values) } pub fn increment_timestamp_by(&mut self, amount: u32) { self.timestamp += amount; - self.log.push(MemoryLogEntry::IncrementTimestampBy(amount)) } pub fn timestamp(&self) -> u32 { self.timestamp } - #[inline(always)] - pub fn get(&self, address_space: u32, pointer: u32) -> F { - *self.data.get(&(address_space, pointer)).unwrap_or(&F::ZERO) + /// Finalize the boundary and merkle chips. + #[instrument(name = "memory_finalize", skip_all)] + pub fn finalize(&mut self, is_persistent: bool) -> TouchedMemory { + let touched_blocks = self.touched_blocks(); + + match is_persistent { + false => TouchedMemory::Volatile( + self.touched_blocks_to_equipartition::(touched_blocks), + ), + true => TouchedMemory::Persistent( + self.touched_blocks_to_equipartition::(touched_blocks), + ), + } } - #[inline(always)] - fn range_array(&self, address_space: u32, pointer: u32) -> [F; N] { - self.data.get_range(&(address_space, pointer)) + /// Returns the list of all touched blocks. The list is sorted by address. + fn touched_blocks(&self) -> Vec<(Address, AccessMetadata)> { + assert_eq!(self.meta.len(), self.min_block_size.len()); + self.meta + .par_iter() + .zip(self.min_block_size.par_iter()) + .enumerate() + .flat_map(|(addr_space, (meta_page, &align))| { + meta_page + .par_iter() + .filter_map(move |(idx, metadata)| { + let ptr = idx as u32 * align; + if metadata.offset_to_start == 0 && metadata.block_size() != 0 { + Some(((addr_space as u32, ptr), metadata)) + } else { + None + } + }) + .collect::>() + }) + .collect() } -} -#[cfg(test)] -mod tests { - use openvm_stark_backend::p3_field::FieldAlgebra; - use openvm_stark_sdk::p3_baby_bear::BabyBear; + /// Returns the equipartition of the touched blocks. + /// Modifies records and adds new to account for the initial/final segments. + fn touched_blocks_to_equipartition( + &mut self, + touched_blocks: Vec<((u32, u32), AccessMetadata)>, + ) -> TimestampedEquipartition { + // [perf] We can `.with_capacity()` if we keep track of the number of segments we initialize + let mut final_memory = Vec::new(); - use super::Memory; - use crate::arch::MemoryConfig; + debug_assert!(touched_blocks.is_sorted_by_key(|(addr, _)| addr)); + self.handle_touched_blocks::(&mut final_memory, touched_blocks); - macro_rules! bba { - [$($x:expr),*] => { - [$(BabyBear::from_canonical_u32($x)),*] - } + debug_assert!(final_memory.is_sorted_by_key(|(key, _)| *key)); + final_memory } - #[test] - fn test_write_read() { - let mut memory = Memory::new(&MemoryConfig::default()); - let address_space = 1; - - memory.write(address_space, 0, bba![1, 2, 3, 4]); + fn handle_touched_blocks( + &mut self, + final_memory: &mut Vec<((u32, u32), TimestampedValues)>, + touched_blocks: Vec<((u32, u32), AccessMetadata)>, + ) { + let mut current_values = vec![0u8; MAX_CELL_BYTE_SIZE * CHUNK]; + let mut current_cnt = 0; + let mut current_address = MemoryAddress::new(0, 0); + let mut current_timestamps = vec![0; CHUNK]; + for ((addr_space, ptr), access_metadata) in touched_blocks { + // SAFETY: addr_space of touched blocks are all in bounds + let addr_space_config = + unsafe { *self.data.memory.config.get_unchecked(addr_space as usize) }; + let min_block_size = addr_space_config.min_block_size; + let cell_size = addr_space_config.layout.size(); + let timestamp = access_metadata.timestamp(); + let block_size = access_metadata.block_size(); + assert!( + current_cnt == 0 + || (current_address.address_space == addr_space + && current_address.pointer + current_cnt as u32 == ptr), + "The union of all touched blocks must consist of blocks with sizes divisible by `CHUNK`" + ); + debug_assert!(block_size >= min_block_size as u8); + debug_assert!(ptr % min_block_size as u32 == 0); - let (_, data) = memory.read::<2>(address_space, 0); - assert_eq!(data, bba![1, 2]); + if current_cnt == 0 { + assert_eq!( + ptr & (CHUNK as u32 - 1), + 0, + "The union of all touched blocks must consist of `CHUNK`-aligned blocks" + ); + current_address = MemoryAddress::new(addr_space, ptr); + } - memory.write(address_space, 2, bba![100]); + if block_size > min_block_size as u8 { + self.add_split_record(AccessRecordHeader { + timestamp_and_mask: timestamp, + address_space: addr_space, + pointer: ptr, + block_size: block_size as u32, + lowest_block_size: min_block_size as u32, + type_size: cell_size as u32, + }); + } + if min_block_size > CHUNK { + assert_eq!(current_cnt, 0); + for i in (0..block_size as u32).step_by(min_block_size) { + self.add_split_record(AccessRecordHeader { + timestamp_and_mask: timestamp, + address_space: addr_space, + pointer: ptr + i, + block_size: min_block_size as u32, + lowest_block_size: CHUNK as u32, + type_size: cell_size as u32, + }); + } + // SAFETY: touched blocks are in bounds + let values = unsafe { + self.data.memory.get_u8_slice( + addr_space, + ptr as usize * cell_size, + block_size as usize * cell_size, + ) + }; + for i in (0..block_size as u32).step_by(CHUNK) { + final_memory.push(( + (addr_space, ptr + i), + TimestampedValues { + timestamp, + values: from_fn(|j| { + let byte_idx = (i as usize + j) * cell_size; + // SAFETY: block_size is multiple of CHUNK and we are reading chunks + // of cells within bounds + unsafe { + addr_space_config + .layout + .to_field(&values[byte_idx..byte_idx + cell_size]) + } + }), + }, + )); + } + } else { + for i in 0..block_size as u32 { + // SAFETY: getting cell data + let cell_data = unsafe { + self.data.memory.get_u8_slice( + addr_space, + (ptr + i) as usize * cell_size, + cell_size, + ) + }; + current_values[current_cnt * cell_size..current_cnt * cell_size + cell_size] + .copy_from_slice(cell_data); + if current_cnt & (min_block_size - 1) == 0 { + // SAFETY: current_cnt / min_block_size < CHUNK / min_block_size <= CHUNK + unsafe { + *current_timestamps.get_unchecked_mut(current_cnt / min_block_size) = + timestamp; + } + } + current_cnt += 1; + if current_cnt == CHUNK { + let timestamp = *current_timestamps[..CHUNK / min_block_size] + .iter() + .max() + .unwrap(); + self.add_merge_record( + AccessRecordHeader { + timestamp_and_mask: timestamp, + address_space: addr_space, + pointer: current_address.pointer, + block_size: CHUNK as u32, + lowest_block_size: min_block_size as u32, + type_size: cell_size as u32, + }, + ¤t_values[..CHUNK * cell_size], + ¤t_timestamps[..CHUNK / min_block_size], + ); + final_memory.push(( + (current_address.address_space, current_address.pointer), + TimestampedValues { + timestamp, + values: from_fn(|i| unsafe { + // SAFETY: cell_size is correct, and alignment is guaranteed + addr_space_config.layout.to_field( + ¤t_values[i * cell_size..i * cell_size + cell_size], + ) + }), + }, + )); + current_address.pointer += current_cnt as u32; + current_cnt = 0; + } + } + } + } + assert_eq!(current_cnt, 0, "The union of all touched blocks must consist of blocks with sizes divisible by `CHUNK`"); + } - let (_, data) = memory.read::<4>(address_space, 0); - assert_eq!(data, bba![1, 2, 100, 4]); + pub fn address_space_alignment(&self) -> Vec { + self.min_block_size + .iter() + .map(|&x| log2_strict_usize(x as usize) as u8) + .collect() } } diff --git a/crates/vm/src/system/memory/online/basic.rs b/crates/vm/src/system/memory/online/basic.rs new file mode 100644 index 0000000000..b5cddeb775 --- /dev/null +++ b/crates/vm/src/system/memory/online/basic.rs @@ -0,0 +1,243 @@ +use std::{ + alloc::{alloc_zeroed, dealloc, Layout}, + ptr::NonNull, +}; + +use crate::system::memory::online::{LinearMemory, PAGE_SIZE}; + +pub struct BasicMemory { + ptr: NonNull, + size: usize, + layout: Layout, +} + +impl BasicMemory { + #[inline(always)] + pub fn as_ptr(&self) -> *const u8 { + self.ptr.as_ptr() + } + + #[inline(always)] + pub fn as_mut_ptr(&mut self) -> *mut u8 { + self.ptr.as_ptr() + } +} + +impl Drop for BasicMemory { + fn drop(&mut self) { + if self.size > 0 { + unsafe { + dealloc(self.ptr.as_ptr(), self.layout); + } + } + } +} + +impl Clone for BasicMemory { + fn clone(&self) -> Self { + if self.size == 0 { + // Ensure we maintain the same aligned pointer for zero-size + let aligned_ptr = PAGE_SIZE as *mut u8; + let ptr = unsafe { NonNull::new_unchecked(aligned_ptr) }; + return Self { + ptr, + size: 0, + layout: self.layout, + }; + } + + let layout = self.layout; + let ptr = unsafe { + let new_ptr = alloc_zeroed(layout); + if new_ptr.is_null() { + std::alloc::handle_alloc_error(layout); + } + std::ptr::copy_nonoverlapping(self.ptr.as_ptr(), new_ptr, self.size); + NonNull::new_unchecked(new_ptr) + }; + Self { + ptr, + size: self.size, + layout, + } + } +} + +impl std::fmt::Debug for BasicMemory { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BasicMemory") + .field("size", &self.size) + .field("alignment", &self.layout.align()) + .finish() + } +} + +impl LinearMemory for BasicMemory { + fn new(size: usize) -> Self { + if size == 0 { + // For zero-size allocation, use a dangling pointer with proper alignment + // We need to ensure the pointer is aligned to PAGE_SIZE + let aligned_ptr = PAGE_SIZE as *mut u8; + let ptr = unsafe { NonNull::new_unchecked(aligned_ptr) }; + let layout = Layout::from_size_align(0, PAGE_SIZE) + .expect("Failed to create layout with PAGE_SIZE alignment"); + return Self { + ptr, + size: 0, + layout, + }; + } + + // Use PAGE_SIZE alignment for consistency with MmapMemory + // This also ensures good alignment for any type we might store + let layout = Layout::from_size_align(size, PAGE_SIZE) + .expect("Failed to create layout with PAGE_SIZE alignment"); + + let ptr = unsafe { + let raw_ptr = alloc_zeroed(layout); + if raw_ptr.is_null() { + std::alloc::handle_alloc_error(layout); + } + NonNull::new_unchecked(raw_ptr) + }; + + Self { ptr, size, layout } + } + + fn size(&self) -> usize { + self.size + } + + fn as_slice(&self) -> &[u8] { + unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.size) } + } + + fn as_mut_slice(&mut self) -> &mut [u8] { + unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.size) } + } + + #[inline(always)] + unsafe fn read(&self, from: usize) -> BLOCK { + let size = std::mem::size_of::(); + assert!( + from + size <= self.size, + "read from={from} of size={size} out of bounds: memory size={}", + self.size + ); + + let src = self.as_ptr().add(from) as *const BLOCK; + // SAFETY: + // - Bounds check is done via assert above + // - We assume `src` is aligned to `BLOCK` + // - We assume `BLOCK` is "plain old data" so the underlying `src` bytes is valid to read as + // an initialized value of `BLOCK` + core::ptr::read(src) + } + + #[inline(always)] + unsafe fn read_unaligned(&self, from: usize) -> BLOCK { + let size = std::mem::size_of::(); + assert!( + from + size <= self.size, + "read_unaligned from={from} of size={size} out of bounds: memory size={}", + self.size + ); + + let src = self.as_ptr().add(from) as *const BLOCK; + // SAFETY: + // - Bounds check is done via assert above + // - We assume `BLOCK` is "plain old data" so the underlying `src` bytes is valid to read as + // an initialized value of `BLOCK` + core::ptr::read_unaligned(src) + } + + #[inline(always)] + unsafe fn write(&mut self, start: usize, values: BLOCK) { + let size = std::mem::size_of::(); + assert!( + start + size <= self.size, + "write start={start} of size={size} out of bounds: memory size={}", + self.size + ); + + let dst = self.as_mut_ptr().add(start) as *mut BLOCK; + // SAFETY: + // - Bounds check is done via assert above + // - We assume `dst` is aligned to `BLOCK` + core::ptr::write(dst, values); + } + + #[inline(always)] + unsafe fn write_unaligned(&mut self, start: usize, values: BLOCK) { + let size = std::mem::size_of::(); + assert!( + start + size <= self.size, + "write_unaligned start={start} of size={size} out of bounds: memory size={}", + self.size + ); + + // Use slice's copy_from_slice for safe byte-level copy + let src_bytes = std::slice::from_raw_parts(&values as *const BLOCK as *const u8, size); + self.as_mut_slice()[start..start + size].copy_from_slice(src_bytes); + } + + #[inline(always)] + unsafe fn swap(&mut self, start: usize, values: &mut BLOCK) { + let size = std::mem::size_of::(); + assert!( + start + size <= self.size, + "swap start={start} of size={size} out of bounds: memory size={}", + self.size + ); + + // SAFETY: + // - Bounds check is done via assert above + // - We assume `start` is aligned to `BLOCK` + core::ptr::swap( + self.as_mut_ptr().add(start) as *mut BLOCK, + values as *mut BLOCK, + ); + } + + #[inline(always)] + unsafe fn copy_nonoverlapping(&mut self, to: usize, data: &[T]) { + let byte_len = std::mem::size_of_val(data); + assert!( + to + byte_len <= self.size, + "copy_nonoverlapping to={to} of size={byte_len} out of bounds: memory size={}", + self.size + ); + + // Use slice's copy_from_slice for safe byte-level copy + let src_bytes = std::slice::from_raw_parts(data.as_ptr() as *const u8, byte_len); + self.as_mut_slice()[to..to + byte_len].copy_from_slice(src_bytes); + } + + #[inline(always)] + unsafe fn get_aligned_slice(&self, start: usize, len: usize) -> &[T] { + let byte_len = len * std::mem::size_of::(); + assert!( + start + byte_len <= self.size, + "get_aligned_slice start={start} of size={byte_len} out of bounds: memory size={}", + self.size + ); + assert!( + start % std::mem::align_of::() == 0, + "get_aligned_slice: misaligned start" + ); + + let data = self.as_ptr().add(start) as *const T; + // SAFETY: + // - Bounds check is done via assert above + // - Alignment check is done via assert above + // - `T` is "plain old data" (POD), so conversion from underlying bytes is properly + // initialized + // - `self` will not be mutated while borrowed + core::slice::from_raw_parts(data, len) + } +} + +// SAFETY: BasicMemory properly manages its allocation and can be sent between threads +unsafe impl Send for BasicMemory {} +// SAFETY: BasicMemory has no interior mutability and can be shared between threads +unsafe impl Sync for BasicMemory {} diff --git a/crates/vm/src/system/memory/online/memmap.rs b/crates/vm/src/system/memory/online/memmap.rs new file mode 100644 index 0000000000..3b2155906a --- /dev/null +++ b/crates/vm/src/system/memory/online/memmap.rs @@ -0,0 +1,173 @@ +use std::fmt::Debug; + +use memmap2::MmapMut; + +use super::{LinearMemory, PAGE_SIZE}; + +pub const CELL_STRIDE: usize = 1; + +/// Mmap-backed linear memory. OS-memory pages are paged in on-demand and zero-initialized. +#[derive(Debug)] +pub struct MmapMemory { + mmap: MmapMut, +} + +impl Clone for MmapMemory { + fn clone(&self) -> Self { + let mut new_mmap = MmapMut::map_anon(self.mmap.len()).unwrap(); + new_mmap.copy_from_slice(&self.mmap); + Self { mmap: new_mmap } + } +} + +impl MmapMemory { + #[inline(always)] + pub fn as_ptr(&self) -> *const u8 { + self.mmap.as_ptr() + } + + #[inline(always)] + pub fn as_mut_ptr(&mut self) -> *mut u8 { + self.mmap.as_mut_ptr() + } +} + +impl LinearMemory for MmapMemory { + /// Create a new MmapMemory with the given `size` in bytes. + /// We round `size` up to be a multiple of the mmap page size (4kb by default) so that OS-level + /// MMU protection corresponds to out of bounds protection. + fn new(mut size: usize) -> Self { + size = size.div_ceil(PAGE_SIZE) * PAGE_SIZE; + // anonymous mapping means pages are zero-initialized on first use + Self { + mmap: MmapMut::map_anon(size).unwrap(), + } + } + + fn size(&self) -> usize { + self.mmap.len() + } + + fn as_slice(&self) -> &[u8] { + &self.mmap + } + + fn as_mut_slice(&mut self) -> &mut [u8] { + &mut self.mmap + } + + #[inline(always)] + unsafe fn read(&self, from: usize) -> BLOCK { + debug_assert!( + from + size_of::() <= self.size(), + "read from={from} of size={} out of bounds: memory size={}", + size_of::(), + self.size() + ); + let src = self.as_ptr().add(from) as *const BLOCK; + // SAFETY: + // - MMU will segfault if `src` access is out of bounds. + // - We assume `src` is aligned to `BLOCK` + // - We assume `BLOCK` is "plain old data" so the underlying `src` bytes is valid to read as + // an initialized value of `BLOCK` + core::ptr::read(src) + } + + #[inline(always)] + unsafe fn read_unaligned(&self, from: usize) -> BLOCK { + debug_assert!( + from + size_of::() <= self.size(), + "read_unaligned from={from} of size={} out of bounds: memory size={}", + size_of::(), + self.size() + ); + let src = self.as_ptr().add(from) as *const BLOCK; + // SAFETY: + // - MMU will segfault if `src` access is out of bounds. + // - We assume `BLOCK` is "plain old data" so the underlying `src` bytes is valid to read as + // an initialized value of `BLOCK` + core::ptr::read_unaligned(src) + } + + #[inline(always)] + unsafe fn write(&mut self, start: usize, values: BLOCK) { + debug_assert!( + start + size_of::() <= self.size(), + "write start={start} of size={} out of bounds: memory size={}", + size_of::(), + self.size() + ); + let dst = self.as_mut_ptr().add(start) as *mut BLOCK; + // SAFETY: + // - MMU will segfault if `dst` access is out of bounds. + // - We assume `dst` is aligned to `BLOCK` + core::ptr::write(dst, values); + } + + #[inline(always)] + unsafe fn write_unaligned(&mut self, start: usize, values: BLOCK) { + debug_assert!( + start + size_of::() <= self.size(), + "write_unaligned start={start} of size={} out of bounds: memory size={}", + size_of::(), + self.size() + ); + let dst = self.as_mut_ptr().add(start) as *mut BLOCK; + // SAFETY: + // - MMU will segfault if `dst` access is out of bounds. + core::ptr::write_unaligned(dst, values); + } + + #[inline(always)] + unsafe fn swap(&mut self, start: usize, values: &mut BLOCK) { + debug_assert!( + start + size_of::() <= self.size(), + "swap start={start} of size={} out of bounds: memory size={}", + size_of::(), + self.size() + ); + // SAFETY: + // - MMU will segfault if `start` access is out of bounds. + // - We assume `start` is aligned to `BLOCK` + core::ptr::swap( + self.as_mut_ptr().add(start) as *mut BLOCK, + values as *mut BLOCK, + ); + } + + #[inline(always)] + unsafe fn copy_nonoverlapping(&mut self, to: usize, data: &[T]) { + debug_assert!( + to + size_of_val(data) <= self.size(), + "copy_nonoverlapping to={to} of size={} out of bounds: memory size={}", + size_of_val(data), + self.size() + ); + debug_assert_eq!(PAGE_SIZE % align_of::(), 0); + let src = data.as_ptr(); + let dst = self.as_mut_ptr().add(to) as *mut T; + // SAFETY: + // - MMU will segfault if `dst..dst + size_of_val(data)` is out of bounds. + // - Assumes `to` is aligned to `T` and `self.as_mut_ptr()` is aligned to `T`, which implies + // the same for `dst`. + core::ptr::copy_nonoverlapping::(src, dst, data.len()); + } + + #[inline(always)] + unsafe fn get_aligned_slice(&self, start: usize, len: usize) -> &[T] { + debug_assert!( + start + len * size_of::() <= self.size(), + "get_aligned_slice start={start} of size={} out of bounds: memory size={}", + len * size_of::(), + self.size() + ); + let data = self.as_ptr().add(start) as *const T; + // SAFETY: + // - MMU will segfault if `data..data + len * size_of::()` is out of bounds. + // - Assumes `data` is aligned to `T` + // - `T` is "plain old data" (POD), so conversion from underlying bytes is properly + // initialized + // - `self` will not be mutated while borrowed + core::slice::from_raw_parts(data, len) + } +} diff --git a/crates/vm/src/system/memory/online/paged_vec.rs b/crates/vm/src/system/memory/online/paged_vec.rs new file mode 100644 index 0000000000..30fe77297d --- /dev/null +++ b/crates/vm/src/system/memory/online/paged_vec.rs @@ -0,0 +1,93 @@ +use std::fmt::Debug; + +use openvm_stark_backend::p3_maybe_rayon::prelude::*; + +#[derive(Debug, Clone)] +pub struct PagedVec { + pages: Vec>>, +} + +unsafe impl Send for PagedVec {} +unsafe impl Sync for PagedVec {} + +impl PagedVec { + #[inline] + /// `total_size` is the capacity of elements of type `T`. + pub fn new(total_size: usize) -> Self { + let num_pages = total_size.div_ceil(PAGE_SIZE); + Self { + pages: vec![None; num_pages], + } + } + + #[cold] + #[inline(never)] + fn create_zeroed_page() -> Box<[T; PAGE_SIZE]> { + unsafe { + let layout = std::alloc::Layout::array::(PAGE_SIZE).unwrap(); + let ptr = std::alloc::alloc_zeroed(layout) as *mut [T; PAGE_SIZE]; + Box::from_raw(ptr) + } + } + + /// Get value at index without allocating new pages. + /// Panics if index is out of bounds. Returns default value if page doesn't exist. + #[inline] + pub fn get(&self, index: usize) -> T { + let page_idx = index / PAGE_SIZE; + let offset = index % PAGE_SIZE; + + self.pages[page_idx] + .as_ref() + .map(|page| unsafe { *page.get_unchecked(offset) }) + .unwrap_or_default() + } + + /// Panics if the index is out of bounds. Creates new page before write when necessary. + #[inline] + pub fn set(&mut self, index: usize, value: T) { + let page_idx = index / PAGE_SIZE; + let offset = index % PAGE_SIZE; + + let page = self.pages[page_idx].get_or_insert_with(Self::create_zeroed_page); + + // SAFETY: offset < PAGE_SIZE by construction + unsafe { + *page.get_unchecked_mut(offset) = value; + } + } + + pub fn par_iter(&self) -> impl ParallelIterator + '_ + where + T: Send + Sync, + { + self.pages + .par_iter() + .enumerate() + .filter_map(move |(page_idx, page)| { + page.as_ref().map(move |p| { + p.par_iter() + .enumerate() + .map(move |(offset, &value)| (page_idx * PAGE_SIZE + offset, value)) + }) + }) + .flatten() + } + + pub fn iter(&self) -> impl Iterator + '_ + where + T: Send + Sync, + { + self.pages + .iter() + .enumerate() + .filter_map(move |(page_idx, page)| { + page.as_ref().map(move |p| { + p.iter() + .enumerate() + .map(move |(offset, &value)| (page_idx * PAGE_SIZE + offset, value)) + }) + }) + .flatten() + } +} diff --git a/crates/vm/src/system/memory/paged_vec.rs b/crates/vm/src/system/memory/paged_vec.rs deleted file mode 100644 index 8a8b030970..0000000000 --- a/crates/vm/src/system/memory/paged_vec.rs +++ /dev/null @@ -1,447 +0,0 @@ -use std::{mem::MaybeUninit, ops::Range, ptr}; - -use serde::{Deserialize, Serialize}; - -use crate::arch::MemoryConfig; - -/// (address_space, pointer) -pub type Address = (u32, u32); -pub const PAGE_SIZE: usize = 1 << 12; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PagedVec { - pub pages: Vec>>, -} - -// ------------------------------------------------------------------ -// Common Helper Functions -// These functions encapsulate the common logic for copying ranges -// across pages, both for read-only and read-write (set) cases. -impl PagedVec { - // Copies a range of length `len` starting at index `start` - // into the memory pointed to by `dst`. If the relevant page is not - // initialized, fills that portion with T::default(). - fn read_range_generic(&self, start: usize, len: usize, dst: *mut T) { - let start_page = start / PAGE_SIZE; - let end_page = (start + len - 1) / PAGE_SIZE; - unsafe { - if start_page == end_page { - let offset = start % PAGE_SIZE; - if let Some(page) = self.pages[start_page].as_ref() { - ptr::copy_nonoverlapping(page.as_ptr().add(offset), dst, len); - } else { - std::slice::from_raw_parts_mut(dst, len).fill(T::default()); - } - } else { - let offset = start % PAGE_SIZE; - let first_part = PAGE_SIZE - offset; - if let Some(page) = self.pages[start_page].as_ref() { - ptr::copy_nonoverlapping(page.as_ptr().add(offset), dst, first_part); - } else { - std::slice::from_raw_parts_mut(dst, first_part).fill(T::default()); - } - let second_part = len - first_part; - if let Some(page) = self.pages[end_page].as_ref() { - ptr::copy_nonoverlapping(page.as_ptr(), dst.add(first_part), second_part); - } else { - std::slice::from_raw_parts_mut(dst.add(first_part), second_part) - .fill(T::default()); - } - } - } - } - - // Updates a range of length `len` starting at index `start` with new values. - // It copies the current values into the memory pointed to by `dst` - // and then writes the new values into the underlying pages, - // allocating pages (with defaults) if necessary. - fn set_range_generic(&mut self, start: usize, len: usize, new: *const T, dst: *mut T) { - let start_page = start / PAGE_SIZE; - let end_page = (start + len - 1) / PAGE_SIZE; - unsafe { - if start_page == end_page { - let offset = start % PAGE_SIZE; - let page = - self.pages[start_page].get_or_insert_with(|| vec![T::default(); PAGE_SIZE]); - ptr::copy_nonoverlapping(page.as_ptr().add(offset), dst, len); - ptr::copy_nonoverlapping(new, page.as_mut_ptr().add(offset), len); - } else { - let offset = start % PAGE_SIZE; - let first_part = PAGE_SIZE - offset; - { - let page = - self.pages[start_page].get_or_insert_with(|| vec![T::default(); PAGE_SIZE]); - ptr::copy_nonoverlapping(page.as_ptr().add(offset), dst, first_part); - ptr::copy_nonoverlapping(new, page.as_mut_ptr().add(offset), first_part); - } - let second_part = len - first_part; - { - let page = - self.pages[end_page].get_or_insert_with(|| vec![T::default(); PAGE_SIZE]); - ptr::copy_nonoverlapping(page.as_ptr(), dst.add(first_part), second_part); - ptr::copy_nonoverlapping(new.add(first_part), page.as_mut_ptr(), second_part); - } - } - } - } -} - -// ------------------------------------------------------------------ -// Implementation for types requiring Default + Clone -impl PagedVec { - pub fn new(num_pages: usize) -> Self { - Self { - pages: vec![None; num_pages], - } - } - - pub fn get(&self, index: usize) -> Option<&T> { - let page_idx = index / PAGE_SIZE; - self.pages[page_idx] - .as_ref() - .map(|page| &page[index % PAGE_SIZE]) - } - - pub fn get_mut(&mut self, index: usize) -> Option<&mut T> { - let page_idx = index / PAGE_SIZE; - self.pages[page_idx] - .as_mut() - .map(|page| &mut page[index % PAGE_SIZE]) - } - - pub fn set(&mut self, index: usize, value: T) -> Option { - let page_idx = index / PAGE_SIZE; - if let Some(page) = self.pages[page_idx].as_mut() { - Some(std::mem::replace(&mut page[index % PAGE_SIZE], value)) - } else { - let page = self.pages[page_idx].get_or_insert_with(|| vec![T::default(); PAGE_SIZE]); - page[index % PAGE_SIZE] = value; - None - } - } - - #[inline(always)] - pub fn range_vec(&self, range: Range) -> Vec { - let len = range.end - range.start; - // Create a vector for uninitialized values. - let mut result: Vec> = Vec::with_capacity(len); - // SAFETY: We set the length and then initialize every element via read_range_generic. - unsafe { - result.set_len(len); - self.read_range_generic(range.start, len, result.as_mut_ptr() as *mut T); - std::mem::transmute::>, Vec>(result) - } - } - - pub fn set_range(&mut self, range: Range, values: &[T]) -> Vec { - let len = range.end - range.start; - assert_eq!(values.len(), len); - let mut result: Vec> = Vec::with_capacity(len); - // SAFETY: We will write to every element in result via set_range_generic. - unsafe { - result.set_len(len); - self.set_range_generic( - range.start, - len, - values.as_ptr(), - result.as_mut_ptr() as *mut T, - ); - std::mem::transmute::>, Vec>(result) - } - } - - pub fn memory_size(&self) -> usize { - self.pages.len() * PAGE_SIZE - } - - pub fn is_empty(&self) -> bool { - self.pages.iter().all(|page| page.is_none()) - } -} - -// ------------------------------------------------------------------ -// Implementation for types requiring Default + Copy -impl PagedVec { - #[inline(always)] - pub fn range_array(&self, from: usize) -> [T; N] { - // Create an uninitialized array of MaybeUninit - let mut result: [MaybeUninit; N] = unsafe { - // SAFETY: An uninitialized `[MaybeUninit; N]` is valid. - MaybeUninit::uninit().assume_init() - }; - self.read_range_generic(from, N, result.as_mut_ptr() as *mut T); - // SAFETY: All elements have been initialized. - unsafe { ptr::read(&result as *const _ as *const [T; N]) } - } - - #[inline(always)] - pub fn set_range_array(&mut self, from: usize, values: &[T; N]) -> [T; N] { - // Create an uninitialized array for old values. - let mut result: [MaybeUninit; N] = unsafe { MaybeUninit::uninit().assume_init() }; - self.set_range_generic(from, N, values.as_ptr(), result.as_mut_ptr() as *mut T); - unsafe { ptr::read(&result as *const _ as *const [T; N]) } - } -} - -impl PagedVec { - pub fn iter(&self) -> PagedVecIter<'_, T, PAGE_SIZE> { - PagedVecIter { - vec: self, - current_page: 0, - current_index_in_page: 0, - } - } -} - -pub struct PagedVecIter<'a, T, const PAGE_SIZE: usize> { - vec: &'a PagedVec, - current_page: usize, - current_index_in_page: usize, -} - -impl Iterator for PagedVecIter<'_, T, PAGE_SIZE> { - type Item = (usize, T); - - fn next(&mut self) -> Option { - while self.current_page < self.vec.pages.len() - && self.vec.pages[self.current_page].is_none() - { - self.current_page += 1; - debug_assert_eq!(self.current_index_in_page, 0); - self.current_index_in_page = 0; - } - if self.current_page >= self.vec.pages.len() { - return None; - } - let global_index = self.current_page * PAGE_SIZE + self.current_index_in_page; - - let page = self.vec.pages[self.current_page].as_ref()?; - let value = page[self.current_index_in_page].clone(); - - self.current_index_in_page += 1; - if self.current_index_in_page == PAGE_SIZE { - self.current_page += 1; - self.current_index_in_page = 0; - } - Some((global_index, value)) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AddressMap { - pub paged_vecs: Vec>, - pub as_offset: u32, -} - -impl Default for AddressMap { - fn default() -> Self { - Self::from_mem_config(&MemoryConfig::default()) - } -} - -impl AddressMap { - pub fn new(as_offset: u32, as_cnt: usize, mem_size: usize) -> Self { - Self { - paged_vecs: vec![PagedVec::new(mem_size.div_ceil(PAGE_SIZE)); as_cnt], - as_offset, - } - } - pub fn from_mem_config(mem_config: &MemoryConfig) -> Self { - Self::new( - mem_config.as_offset, - 1 << mem_config.as_height, - 1 << mem_config.pointer_max_bits, - ) - } - pub fn items(&self) -> impl Iterator + '_ { - self.paged_vecs - .iter() - .enumerate() - .flat_map(move |(as_idx, page)| { - page.iter() - .map(move |(ptr_idx, x)| ((as_idx as u32 + self.as_offset, ptr_idx as u32), x)) - }) - } - pub fn get(&self, address: &Address) -> Option<&T> { - self.paged_vecs[(address.0 - self.as_offset) as usize].get(address.1 as usize) - } - pub fn get_mut(&mut self, address: &Address) -> Option<&mut T> { - self.paged_vecs[(address.0 - self.as_offset) as usize].get_mut(address.1 as usize) - } - pub fn insert(&mut self, address: &Address, data: T) -> Option { - self.paged_vecs[(address.0 - self.as_offset) as usize].set(address.1 as usize, data) - } - pub fn is_empty(&self) -> bool { - self.paged_vecs.iter().all(|page| page.is_empty()) - } - - pub fn from_iter( - as_offset: u32, - as_cnt: usize, - mem_size: usize, - iter: impl IntoIterator, - ) -> Self { - let mut vec = Self::new(as_offset, as_cnt, mem_size); - for (address, data) in iter { - vec.insert(&address, data); - } - vec - } -} - -impl AddressMap { - pub fn get_range(&self, address: &Address) -> [T; N] { - self.paged_vecs[(address.0 - self.as_offset) as usize].range_array(address.1 as usize) - } - pub fn set_range(&mut self, address: &Address, values: &[T; N]) -> [T; N] { - self.paged_vecs[(address.0 - self.as_offset) as usize] - .set_range_array(address.1 as usize, values) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_basic_get_set() { - let mut v = PagedVec::<_, 4>::new(3); - assert_eq!(v.get(0), None); - v.set(0, 42); - assert_eq!(v.get(0), Some(&42)); - } - - #[test] - fn test_cross_page_operations() { - let mut v = PagedVec::<_, 4>::new(3); - v.set(3, 10); // Last element of first page - v.set(4, 20); // First element of second page - assert_eq!(v.get(3), Some(&10)); - assert_eq!(v.get(4), Some(&20)); - } - - #[test] - fn test_page_boundaries() { - let mut v = PagedVec::<_, 4>::new(2); - // Fill first page - v.set(0, 1); - v.set(1, 2); - v.set(2, 3); - v.set(3, 4); - // Fill second page - v.set(4, 5); - v.set(5, 6); - v.set(6, 7); - v.set(7, 8); - - // Verify all values - assert_eq!(v.range_vec(0..8), [1, 2, 3, 4, 5, 6, 7, 8]); - } - - #[test] - fn test_range_cross_page_boundary() { - let mut v = PagedVec::<_, 4>::new(2); - v.set_range(2..8, &[10, 11, 12, 13, 14, 15]); - assert_eq!(v.range_vec(2..8), [10, 11, 12, 13, 14, 15]); - } - - #[test] - fn test_large_indices() { - let mut v = PagedVec::<_, 4>::new(100); - let large_index = 399; - v.set(large_index, 42); - assert_eq!(v.get(large_index), Some(&42)); - } - - #[test] - fn test_range_operations_with_defaults() { - let mut v = PagedVec::<_, 4>::new(3); - v.set(2, 5); - v.set(5, 10); - - // Should include both set values and defaults - assert_eq!(v.range_vec(1..7), [0, 5, 0, 0, 10, 0]); - } - - #[test] - fn test_non_zero_default_type() { - let mut v: PagedVec = PagedVec::new(2); - assert_eq!(v.get(0), None); // bool's default - v.set(0, true); - assert_eq!(v.get(0), Some(&true)); - assert_eq!(v.get(1), Some(&false)); // because we created the page - } - - #[test] - fn test_set_range_overlapping_pages() { - let mut v = PagedVec::<_, 4>::new(3); - let test_data = [1, 2, 3, 4, 5, 6]; - v.set_range(2..8, &test_data); - - // Verify first page - assert_eq!(v.get(2), Some(&1)); - assert_eq!(v.get(3), Some(&2)); - - // Verify second page - assert_eq!(v.get(4), Some(&3)); - assert_eq!(v.get(5), Some(&4)); - assert_eq!(v.get(6), Some(&5)); - assert_eq!(v.get(7), Some(&6)); - } - - #[test] - fn test_overlapping_set_ranges() { - let mut v = PagedVec::<_, 4>::new(3); - - // Initial set_range - v.set_range(0..5, &[1, 2, 3, 4, 5]); - assert_eq!(v.range_vec(0..5), [1, 2, 3, 4, 5]); - - // Overlap from beginning - v.set_range(0..3, &[10, 20, 30]); - assert_eq!(v.range_vec(0..5), [10, 20, 30, 4, 5]); - - // Overlap in middle - v.set_range(2..4, &[42, 43]); - assert_eq!(v.range_vec(0..5), [10, 20, 42, 43, 5]); - - // Overlap at end - v.set_range(4..6, &[91, 92]); - assert_eq!(v.range_vec(0..6), [10, 20, 42, 43, 91, 92]); - } - - #[test] - fn test_overlapping_set_ranges_cross_pages() { - let mut v = PagedVec::<_, 4>::new(3); - - // Fill across first two pages - v.set_range(0..8, &[1, 2, 3, 4, 5, 6, 7, 8]); - - // Overlap end of first page and start of second - v.set_range(2..6, &[21, 22, 23, 24]); - assert_eq!(v.range_vec(0..8), [1, 2, 21, 22, 23, 24, 7, 8]); - - // Overlap multiple pages - v.set_range(1..7, &[31, 32, 33, 34, 35, 36]); - assert_eq!(v.range_vec(0..8), [1, 31, 32, 33, 34, 35, 36, 8]); - } - - #[test] - fn test_iterator() { - let mut v = PagedVec::<_, 4>::new(3); - - v.set_range(4..10, &[1, 2, 3, 4, 5, 6]); - let contents: Vec<_> = v.iter().collect(); - assert_eq!(contents.len(), 8); // two pages - - contents - .iter() - .take(6) - .enumerate() - .for_each(|(i, &(idx, val))| { - assert_eq!((idx, val), (4 + i, 1 + i)); - }); - assert_eq!(contents[6], (10, 0)); - assert_eq!(contents[7], (11, 0)); - } -} diff --git a/crates/vm/src/system/memory/persistent.rs b/crates/vm/src/system/memory/persistent.rs index 55a178be4d..eeb22cbfd6 100644 --- a/crates/vm/src/system/memory/persistent.rs +++ b/crates/vm/src/system/memory/persistent.rs @@ -13,18 +13,19 @@ use openvm_stark_backend::{ p3_field::{FieldAlgebra, PrimeField32}, p3_matrix::{dense::RowMajorMatrix, Matrix}, p3_maybe_rayon::prelude::*, - prover::types::AirProofInput, + prover::{cpu::CpuBackend, types::AirProvingContext}, rap::{BaseAirWithPublicValues, PartitionedBaseAir}, - AirRef, Chip, ChipUsageGetter, + Chip, ChipUsageGetter, }; use rustc_hash::FxHashSet; +use tracing::instrument; -use super::merkle::SerialReceiver; +use super::{merkle::SerialReceiver, online::INITIAL_TIMESTAMP, TimestampedValues}; use crate::{ - arch::hasher::Hasher, + arch::{hasher::Hasher, ADDR_SPACE_OFFSET}, system::memory::{ dimensions::MemoryDimensions, offline_checker::MemoryBus, MemoryAddress, MemoryImage, - TimestampedEquipartition, INITIAL_TIMESTAMP, + TimestampedEquipartition, }, }; @@ -92,7 +93,7 @@ impl Air for PersistentBoundaryA // direction = -1 => is_final = 1 local.expand_direction.into(), AB::Expr::ZERO, - local.address_space - AB::F::from_canonical_u32(self.memory_dims.as_offset), + local.address_space - AB::F::from_canonical_u32(ADDR_SPACE_OFFSET), local.leaf_label.into(), ]; expand_fields.extend(local.hash.map(Into::into)); @@ -123,18 +124,18 @@ impl Air for PersistentBoundaryA pub struct PersistentBoundaryChip { pub air: PersistentBoundaryAir, - touched_labels: TouchedLabels, + pub touched_labels: TouchedLabels, overridden_height: Option, } #[derive(Debug)] -enum TouchedLabels { +pub enum TouchedLabels { Running(FxHashSet<(u32, u32)>), Final(Vec>), } #[derive(Debug)] -struct FinalTouchedLabel { +pub struct FinalTouchedLabel { address_space: u32, label: u32, init_values: [F; CHUNK], @@ -159,7 +160,15 @@ impl TouchedLabels { _ => panic!("Cannot touch after finalization"), } } - fn len(&self) -> usize { + + pub fn is_empty(&self) -> bool { + match self { + TouchedLabels::Running(touched_labels) => touched_labels.is_empty(), + TouchedLabels::Final(touched_labels) => touched_labels.is_empty(), + } + } + + pub fn len(&self) -> usize { match self { TouchedLabels::Running(touched_labels) => touched_labels.len(), TouchedLabels::Final(touched_labels) => touched_labels.len(), @@ -198,59 +207,51 @@ impl PersistentBoundaryChip { } } - pub fn finalize( + #[instrument(name = "boundary_finalize", level = "debug", skip_all)] + pub(crate) fn finalize( &mut self, - initial_memory: &MemoryImage, + initial_memory: &MemoryImage, + // Only touched stuff final_memory: &TimestampedEquipartition, - hasher: &mut H, + hasher: &H, ) where H: Hasher + Sync + for<'a> SerialReceiver<&'a [F]>, { - match &mut self.touched_labels { - TouchedLabels::Running(touched_labels) => { - let final_touched_labels: Vec<_> = touched_labels - .par_iter() - .map(|&(address_space, label)| { - let pointer = label * CHUNK as u32; - let init_values = array::from_fn(|i| { - *initial_memory - .get(&(address_space, pointer + i as u32)) - .unwrap_or(&F::ZERO) - }); - let initial_hash = hasher.hash(&init_values); - let timestamped_values = final_memory.get(&(address_space, label)).unwrap(); - let final_hash = hasher.hash(×tamped_values.values); - FinalTouchedLabel { - address_space, - label, - init_values, - final_values: timestamped_values.values, - init_hash: initial_hash, - final_hash, - final_timestamp: timestamped_values.timestamp, - } - }) - .collect(); - for l in &final_touched_labels { - hasher.receive(&l.init_values); - hasher.receive(&l.final_values); + let final_touched_labels: Vec<_> = final_memory + .par_iter() + .map(|&((addr_space, ptr), ts_values)| { + let TimestampedValues { timestamp, values } = ts_values; + // SAFETY: addr_space from `final_memory` are all in bounds + let init_values = array::from_fn(|i| unsafe { + initial_memory.get_f::(addr_space, ptr + i as u32) + }); + let initial_hash = hasher.hash(&init_values); + let final_hash = hasher.hash(&values); + FinalTouchedLabel { + address_space: addr_space, + label: ptr / CHUNK as u32, + init_values, + final_values: values, + init_hash: initial_hash, + final_hash, + final_timestamp: timestamp, } - self.touched_labels = TouchedLabels::Final(final_touched_labels); - } - _ => panic!("Cannot finalize after finalization"), + }) + .collect(); + for l in &final_touched_labels { + hasher.receive(&l.init_values); + hasher.receive(&l.final_values); } + self.touched_labels = TouchedLabels::Final(final_touched_labels); } } -impl Chip for PersistentBoundaryChip, CHUNK> +impl Chip> for PersistentBoundaryChip, CHUNK> where + SC: StarkGenericConfig, Val: PrimeField32, { - fn air(&self) -> AirRef { - Arc::new(self.air.clone()) - } - - fn generate_air_proof_input(self) -> AirProofInput { + fn generate_proving_ctx(&self, _: RA) -> AirProvingContext> { let trace = { let width = PersistentBoundaryCols::, CHUNK>::width(); // Boundary AIR should always present in order to fix the AIR ID of merkle AIR. @@ -265,13 +266,13 @@ where } let mut rows = Val::::zero_vec(height * width); - let touched_labels = match self.touched_labels { + let touched_labels = match &self.touched_labels { TouchedLabels::Final(touched_labels) => touched_labels, _ => panic!("Cannot generate trace before finalization"), }; rows.par_chunks_mut(2 * width) - .zip(touched_labels.into_par_iter()) + .zip(touched_labels.par_iter()) .for_each(|(row, touched_label)| { let (initial_row, final_row) = row.split_at_mut(width); *initial_row.borrow_mut() = PersistentBoundaryCols { @@ -292,9 +293,9 @@ where timestamp: Val::::from_canonical_u32(touched_label.final_timestamp), }; }); - RowMajorMatrix::new(rows, width) + Arc::new(RowMajorMatrix::new(rows, width)) }; - AirProofInput::simple_no_pis(trace) + AirProvingContext::simple_no_pis(trace) } } diff --git a/crates/vm/src/system/memory/tests.rs b/crates/vm/src/system/memory/tests.rs index 9ebb9306aa..ae630c7f3e 100644 --- a/crates/vm/src/system/memory/tests.rs +++ b/crates/vm/src/system/memory/tests.rs @@ -1,329 +1,99 @@ -use std::{ - array, - borrow::{Borrow, BorrowMut}, - sync::Arc, -}; +use std::array; -use itertools::Itertools; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_poseidon2_air::Poseidon2Config; -use openvm_stark_backend::{ - interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, - p3_air::{Air, BaseAir}, - p3_field::{FieldAlgebra, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, - prover::types::AirProofInput, - rap::{BaseAirWithPublicValues, PartitionedBaseAir}, - Chip, -}; -use openvm_stark_sdk::{ - config::baby_bear_poseidon2::BabyBearPoseidon2Engine, engine::StarkFriEngine, - p3_baby_bear::BabyBear, utils::create_seeded_rng, -}; -use rand::{ - prelude::{SliceRandom, StdRng}, - Rng, -}; +use openvm_stark_backend::p3_field::FieldAlgebra; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use rand::{thread_rng, Rng}; -use super::MemoryController; use crate::{ - arch::{ - testing::{memory::gen_pointer, MEMORY_BUS, MEMORY_MERKLE_BUS, POSEIDON2_DIRECT_BUS}, - MemoryConfig, - }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryBus, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, OfflineMemory, RecordId, - }, - poseidon2::Poseidon2PeripheryChip, - }, + arch::{testing::VmChipTestBuilder, MemoryConfig}, + system::memory::online::TracingMemory, }; -const MAX: usize = 32; -const RANGE_CHECKER_BUS: BusIndex = 3; - -#[repr(C)] -#[derive(AlignedBorrow)] -struct MemoryRequesterCols { - address_space: T, - pointer: T, - data_1: [T; 1], - data_4: [T; 4], - data_max: [T; MAX], - timestamp: T, - write_1_aux: MemoryWriteAuxCols, - write_4_aux: MemoryWriteAuxCols, - read_1_aux: MemoryReadAuxCols, - read_4_aux: MemoryReadAuxCols, - read_max_aux: MemoryReadAuxCols, - is_write_1: T, - is_write_4: T, - is_read_1: T, - is_read_4: T, - is_read_max: T, -} - -struct MemoryRequesterAir { - memory_bridge: MemoryBridge, -} +type F = BabyBear; -impl BaseAirWithPublicValues for MemoryRequesterAir {} -impl PartitionedBaseAir for MemoryRequesterAir {} -impl BaseAir for MemoryRequesterAir { - fn width(&self) -> usize { - MemoryRequesterCols::::width() - } -} - -impl Air for MemoryRequesterAir { - fn eval(&self, builder: &mut AB) { - let main = builder.main(); - let local = main.row_slice(0); - let local: &MemoryRequesterCols = (*local).borrow(); - - let flags = [ - local.is_read_1, - local.is_write_1, - local.is_read_4, - local.is_write_4, - local.is_read_max, - ]; +fn test_memory_write_by_tester(mut tester: VmChipTestBuilder) { + let mut rng = create_seeded_rng(); - let mut sum = AB::Expr::ZERO; - for flag in flags { - builder.assert_bool(flag); - sum += flag.into(); + // The point here is to have a lot of equal + // and intersecting/overlapping blocks, + // by limiting the space of valid pointers. + let max_ptr = 20; + let aligns = [4, 4, 4, 1]; + let value_bounds = [256, 256, 256, (1 << 30)]; + let max_log_block_size = 4; + let its = 1000; + for _ in 0..its { + let addr_sp = rng.gen_range(1..=aligns.len()); + let align: usize = aligns[addr_sp - 1]; + let value_bound: u32 = value_bounds[addr_sp - 1]; + let ptr = rng.gen_range(0..max_ptr / align) * align; + let log_len = rng.gen_range(align.trailing_zeros()..=max_log_block_size); + match log_len { + 0 => tester.write::<1>( + addr_sp, + ptr, + array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..value_bound))), + ), + 1 => tester.write::<2>( + addr_sp, + ptr, + array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..value_bound))), + ), + 2 => tester.write::<4>( + addr_sp, + ptr, + array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..value_bound))), + ), + 3 => tester.write::<8>( + addr_sp, + ptr, + array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..value_bound))), + ), + 4 => tester.write::<16>( + addr_sp, + ptr, + array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..value_bound))), + ), + _ => unreachable!(), } - builder.assert_one(sum); - - self.memory_bridge - .read( - MemoryAddress::new(local.address_space, local.pointer), - local.data_1, - local.timestamp, - &local.read_1_aux, - ) - .eval(builder, local.is_read_1); - - self.memory_bridge - .read( - MemoryAddress::new(local.address_space, local.pointer), - local.data_4, - local.timestamp, - &local.read_4_aux, - ) - .eval(builder, local.is_read_4); - - self.memory_bridge - .write( - MemoryAddress::new(local.address_space, local.pointer), - local.data_1, - local.timestamp, - &local.write_1_aux, - ) - .eval(builder, local.is_write_1); - - self.memory_bridge - .write( - MemoryAddress::new(local.address_space, local.pointer), - local.data_4, - local.timestamp, - &local.write_4_aux, - ) - .eval(builder, local.is_write_4); - - self.memory_bridge - .read( - MemoryAddress::new(local.address_space, local.pointer), - local.data_max, - local.timestamp, - &local.read_max_aux, - ) - .eval(builder, local.is_read_max); } -} - -fn generate_trace( - records: Vec, - offline_memory: &OfflineMemory, -) -> RowMajorMatrix { - let height = records.len().next_power_of_two(); - let width = MemoryRequesterCols::::width(); - let mut values = F::zero_vec(height * width); - let aux_factory = offline_memory.aux_cols_factory(); - - for (row, record_id) in values.chunks_mut(width).zip(records) { - let record = offline_memory.record_by_id(record_id).clone(); - - let row: &mut MemoryRequesterCols = row.borrow_mut(); - row.address_space = record.address_space; - row.pointer = record.pointer; - row.timestamp = F::from_canonical_u32(record.timestamp); - - match (record.data_slice().len(), &record.prev_data_slice()) { - (1, &None) => { - aux_factory.generate_read_aux(&record, &mut row.read_1_aux); - row.data_1 = record.data_slice().try_into().unwrap(); - row.is_read_1 = F::ONE; - } - (1, &Some(_)) => { - aux_factory.generate_write_aux(&record, &mut row.write_1_aux); - row.data_1 = record.data_slice().try_into().unwrap(); - row.is_write_1 = F::ONE; - } - (4, &None) => { - aux_factory.generate_read_aux(&record, &mut row.read_4_aux); - row.data_4 = record.data_slice().try_into().unwrap(); - row.is_read_4 = F::ONE; - } - (4, &Some(_)) => { - aux_factory.generate_write_aux(&record, &mut row.write_4_aux); - row.data_4 = record.data_slice().try_into().unwrap(); - row.is_write_4 = F::ONE; - } - (MAX, &None) => { - aux_factory.generate_read_aux(&record, &mut row.read_max_aux); - row.data_max = record.data_slice().try_into().unwrap(); - row.is_read_max = F::ONE; - } - _ => panic!("unexpected pattern"), - } - } - RowMajorMatrix::new(values, width) + let tester = tester.build().finalize(); + tester.simple_test().expect("Verification failed"); } -/// Simple integration test for memory chip. -/// -/// Creates a bunch of random read/write records, used to generate a trace for [MemoryRequesterAir], -/// which sends reads/writes over [MemoryBridge]. #[test] -fn test_memory_controller() { - let memory_bus = MemoryBus::new(MEMORY_BUS); - let memory_config = MemoryConfig::default(); - let range_bus = VariableRangeCheckerBus::new(RANGE_CHECKER_BUS, memory_config.decomp); - let range_checker = SharedVariableRangeCheckerChip::new(range_bus); - - let mut memory_controller = - MemoryController::with_volatile_memory(memory_bus, memory_config, range_checker.clone()); - - let mut rng = create_seeded_rng(); - let records = make_random_accesses(&mut memory_controller, &mut rng); - let memory_requester_air = Arc::new(MemoryRequesterAir { - memory_bridge: memory_controller.memory_bridge(), - }); - - memory_controller.finalize(None::<&mut Poseidon2PeripheryChip>); - - let memory_requester_trace = { - let offline_memory = memory_controller.offline_memory(); - let trace = generate_trace(records, &offline_memory.lock().unwrap()); - trace - }; - - let mut airs = memory_controller.airs(); - let mut air_proof_inputs = memory_controller.generate_air_proof_inputs(); - airs.push(memory_requester_air); - air_proof_inputs.push(AirProofInput::simple_no_pis(memory_requester_trace)); - airs.push(range_checker.air()); - air_proof_inputs.push(range_checker.generate_air_proof_input()); - - BabyBearPoseidon2Engine::run_test_fast(airs, air_proof_inputs).expect("Verification failed"); +fn test_memory_write_volatile() { + test_memory_write_by_tester(VmChipTestBuilder::::volatile(MemoryConfig::default())); } #[test] -fn test_memory_controller_persistent() { - let memory_bus = MemoryBus::new(MEMORY_BUS); - let merkle_bus = PermutationCheckBus::new(MEMORY_MERKLE_BUS); - let compression_bus = PermutationCheckBus::new(POSEIDON2_DIRECT_BUS); - let memory_config = MemoryConfig::default(); - let range_bus = VariableRangeCheckerBus::new(RANGE_CHECKER_BUS, memory_config.decomp); - let range_checker = SharedVariableRangeCheckerChip::new(range_bus); - - let mut memory_controller = MemoryController::with_persistent_memory( - memory_bus, - memory_config, - range_checker.clone(), - merkle_bus, - compression_bus, - ); - - let mut rng = create_seeded_rng(); - let records = make_random_accesses(&mut memory_controller, &mut rng); - - let memory_requester_air = MemoryRequesterAir { - memory_bridge: memory_controller.memory_bridge(), - }; - - let mut poseidon_chip = - Poseidon2PeripheryChip::new(Poseidon2Config::default(), POSEIDON2_DIRECT_BUS, 3); - - memory_controller.finalize(Some(&mut poseidon_chip)); - - let memory_requester_trace = { - let offline_memory = memory_controller.offline_memory(); - let trace = generate_trace(records, &offline_memory.lock().unwrap()); - trace - }; - - let mut airs = memory_controller.airs(); - let mut air_proof_inputs = memory_controller.generate_air_proof_inputs(); - airs.extend([ - Arc::new(memory_requester_air), - poseidon_chip.air(), - range_checker.air(), - ]); - air_proof_inputs.extend([ - AirProofInput::simple_no_pis(memory_requester_trace), - poseidon_chip.generate_air_proof_input(), - range_checker.generate_air_proof_input(), - ]); - - BabyBearPoseidon2Engine::run_test_fast(airs, air_proof_inputs).expect("Verification failed"); +fn test_memory_write_persistent() { + test_memory_write_by_tester(VmChipTestBuilder::::persistent(MemoryConfig::default())); } -fn make_random_accesses( - memory_controller: &mut MemoryController, - mut rng: &mut StdRng, -) -> Vec { - (0..1024) - .map(|_| { - let address_space = F::from_canonical_u32(*[1, 2].choose(&mut rng).unwrap()); - - match rng.gen_range(0..5) { - 0 => { - let pointer = F::from_canonical_usize(gen_pointer(rng, 1)); - let data = F::from_canonical_u32(rng.gen_range(0..1 << 30)); - let (record_id, _) = memory_controller.write(address_space, pointer, [data]); - record_id - } - 1 => { - let pointer = F::from_canonical_usize(gen_pointer(rng, 1)); - let (record_id, _) = memory_controller.read::<1>(address_space, pointer); - record_id - } - 2 => { - let pointer = F::from_canonical_usize(gen_pointer(rng, 4)); - let (record_id, _) = memory_controller.read::<4>(address_space, pointer); - record_id - } - 3 => { - let pointer = F::from_canonical_usize(gen_pointer(rng, 4)); - let data = array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..1 << 30))); - let (record_id, _) = memory_controller.write::<4>(address_space, pointer, data); - record_id - } - 4 => { - let pointer = F::from_canonical_usize(gen_pointer(rng, MAX)); - let (record_id, _) = memory_controller.read::(address_space, pointer); - record_id - } - _ => unreachable!(), +#[test] +fn test_no_adapter_records_for_singleton_accesses() { + let memory_config = MemoryConfig::default(); + let mut memory = TracingMemory::new(&memory_config, 1, 0); + + let mut rng = thread_rng(); + for _ in 0..1000 { + // TODO[jpw]: test other address spaces? + let address_space = 4u32; + let pointer = rng.gen_range(0..1 << memory_config.pointer_max_bits); + + if rng.gen_bool(0.5) { + let data = F::from_canonical_u32(rng.gen_range(0..1 << 30)); + // address space is 4 so cell type is `F` + unsafe { + memory.write::(address_space, pointer, [data]); } - }) - .collect_vec() + } else { + unsafe { + memory.read::(address_space, pointer); + } + } + } + assert!(memory.access_adapter_records.allocated().is_empty()); } diff --git a/crates/vm/src/system/memory/tree/mod.rs b/crates/vm/src/system/memory/tree/mod.rs deleted file mode 100644 index fcdb86d8ee..0000000000 --- a/crates/vm/src/system/memory/tree/mod.rs +++ /dev/null @@ -1,177 +0,0 @@ -pub mod public_values; - -use std::{ops::Range, sync::Arc}; - -use openvm_stark_backend::{p3_field::PrimeField32, p3_maybe_rayon::prelude::*}; -use MemoryNode::*; - -use super::controller::dimensions::MemoryDimensions; -use crate::{ - arch::hasher::{Hasher, HasherChip}, - system::memory::MemoryImage, -}; - -#[derive(Clone, Debug, PartialEq)] -pub enum MemoryNode { - Leaf { - values: [F; CHUNK], - }, - NonLeaf { - hash: [F; CHUNK], - left: Arc>, - right: Arc>, - }, -} - -impl MemoryNode { - pub fn hash(&self) -> [F; CHUNK] { - match self { - Leaf { values: hash } => *hash, - NonLeaf { hash, .. } => *hash, - } - } - - pub fn new_leaf(values: [F; CHUNK]) -> Self { - Leaf { values } - } - - pub fn new_nonleaf( - left: Arc>, - right: Arc>, - hasher: &mut impl HasherChip, - ) -> Self { - NonLeaf { - hash: hasher.compress_and_record(&left.hash(), &right.hash()), - left, - right, - } - } - - /// Returns a tree of height `height` with all leaves set to `leaf_value`. - pub fn construct_uniform( - height: usize, - leaf_value: [F; CHUNK], - hasher: &impl Hasher, - ) -> MemoryNode { - if height == 0 { - Self::new_leaf(leaf_value) - } else { - let child = Arc::new(Self::construct_uniform(height - 1, leaf_value, hasher)); - NonLeaf { - hash: hasher.compress(&child.hash(), &child.hash()), - left: child.clone(), - right: child, - } - } - } - - fn from_memory( - memory: &[(u64, F)], - lookup_range: Range, - length: u64, - from: u64, - hasher: &(impl Hasher + Sync), - zero_leaf: &MemoryNode, - ) -> MemoryNode { - if length == CHUNK as u64 { - if lookup_range.is_empty() { - zero_leaf.clone() - } else { - debug_assert_eq!(memory[lookup_range.start].0, from); - let mut values = [F::ZERO; CHUNK]; - for (index, value) in memory[lookup_range].iter() { - values[(index % CHUNK as u64) as usize] = *value; - } - MemoryNode::new_leaf(hasher.hash(&values)) - } - } else if lookup_range.is_empty() { - let leaf_value = hasher.hash(&[F::ZERO; CHUNK]); - MemoryNode::construct_uniform( - (length / CHUNK as u64).trailing_zeros() as usize, - leaf_value, - hasher, - ) - } else { - let midpoint = from + length / 2; - let mid = { - let mut left = lookup_range.start; - let mut right = lookup_range.end; - if memory[left].0 >= midpoint { - left - } else { - while left + 1 < right { - let mid = left + (right - left) / 2; - if memory[mid].0 < midpoint { - left = mid; - } else { - right = mid; - } - } - right - } - }; - let (left, right) = join( - || { - Self::from_memory( - memory, - lookup_range.start..mid, - length >> 1, - from, - hasher, - zero_leaf, - ) - }, - || { - Self::from_memory( - memory, - mid..lookup_range.end, - length >> 1, - midpoint, - hasher, - zero_leaf, - ) - }, - ); - NonLeaf { - hash: hasher.compress(&left.hash(), &right.hash()), - left: Arc::new(left), - right: Arc::new(right), - } - } - } - - pub fn tree_from_memory( - memory_dimensions: MemoryDimensions, - memory: &MemoryImage, - hasher: &(impl Hasher + Sync), - ) -> MemoryNode { - // Construct a Vec that includes the address space in the label calculation, - // representing the entire memory tree. - let memory_items = memory - .items() - .filter(|((_, ptr), _)| *ptr as usize / CHUNK < (1 << memory_dimensions.address_height)) - .map(|((address_space, pointer), value)| { - ( - memory_dimensions.label_to_index((address_space, pointer / CHUNK as u32)) - * CHUNK as u64 - + (pointer % CHUNK as u32) as u64, - value, - ) - }) - .collect::>(); - debug_assert!(memory_items.is_sorted_by_key(|(addr, _)| addr)); - debug_assert!( - memory_items.last().map_or(0, |(addr, _)| *addr) - < ((CHUNK as u64) << memory_dimensions.overall_height()) - ); - let zero_leaf = MemoryNode::new_leaf(hasher.hash(&[F::ZERO; CHUNK])); - Self::from_memory( - &memory_items, - 0..memory_items.len(), - (CHUNK as u64) << memory_dimensions.overall_height(), - 0, - hasher, - &zero_leaf, - ) - } -} diff --git a/crates/vm/src/system/memory/volatile/mod.rs b/crates/vm/src/system/memory/volatile/mod.rs index e01162c789..9296c91247 100644 --- a/crates/vm/src/system/memory/volatile/mod.rs +++ b/crates/vm/src/system/memory/volatile/mod.rs @@ -21,11 +21,12 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, p3_matrix::{dense::RowMajorMatrix, Matrix}, p3_maybe_rayon::prelude::*, - prover::types::AirProofInput, + prover::{cpu::CpuBackend, types::AirProvingContext}, rap::{BaseAirWithPublicValues, PartitionedBaseAir}, - AirRef, Chip, ChipUsageGetter, + Chip, ChipUsageGetter, }; use static_assertions::const_assert; +use tracing::instrument; use super::TimestampedEquipartition; use crate::system::memory::{ @@ -183,7 +184,7 @@ pub struct VolatileBoundaryChip { pub air: VolatileBoundaryAir, range_checker: SharedVariableRangeCheckerChip, overridden_height: Option, - final_memory: Option>, + pub final_memory: Option>, addr_space_max_bits: usize, pointer_max_bits: usize, } @@ -218,27 +219,26 @@ impl VolatileBoundaryChip { } /// Volatile memory requires the starting and final memory to be in equipartition with block /// size `1`. When block size is `1`, then the `label` is the same as the address pointer. + #[instrument(name = "boundary_finalize", level = "debug", skip_all)] pub fn finalize(&mut self, final_memory: TimestampedEquipartition) { self.final_memory = Some(final_memory); } } -impl Chip for VolatileBoundaryChip> +impl Chip> for VolatileBoundaryChip> where Val: PrimeField32, { - fn air(&self) -> AirRef { - Arc::new(self.air.clone()) - } - - fn generate_air_proof_input(self) -> AirProofInput { + fn generate_proving_ctx(&self, _: RA) -> AirProvingContext> { // Volatile memory requires the starting and final memory to be in equipartition with block // size `1`. When block size is `1`, then the `label` is the same as the address // pointer. let width = self.trace_width(); - let air = Arc::new(self.air); + let addr_lt_air = &self.air.addr_lt_air; + // TEMP[jpw]: clone let final_memory = self .final_memory + .clone() .expect("Trace generation should be called after finalize"); let trace_height = if let Some(height) = self.overridden_height { assert!( @@ -279,7 +279,7 @@ where if i != memory_len - 1 { let (next_addr_space, next_ptr) = sorted_final_memory[i + 1].0; let mut out = Val::::ZERO; - air.addr_lt_air.0.generate_subrow( + addr_lt_air.0.generate_subrow( ( self.range_checker.as_ref(), &[ @@ -300,7 +300,7 @@ where if memory_len > 0 { let mut out = Val::::ZERO; let row: &mut VolatileBoundaryCols<_> = rows[width * (trace_height - 1)..].borrow_mut(); - air.addr_lt_air.0.generate_subrow( + addr_lt_air.0.generate_subrow( ( self.range_checker.as_ref(), &[Val::::ZERO, Val::::ZERO], @@ -310,8 +310,8 @@ where ); } - let trace = RowMajorMatrix::new(rows, width); - AirProofInput::simple_no_pis(trace) + let trace = Arc::new(RowMajorMatrix::new(rows, width)); + AirProvingContext::simple_no_pis(trace) } } diff --git a/crates/vm/src/system/memory/volatile/tests.rs b/crates/vm/src/system/memory/volatile/tests.rs index 29917d219d..a0c484793b 100644 --- a/crates/vm/src/system/memory/volatile/tests.rs +++ b/crates/vm/src/system/memory/volatile/tests.rs @@ -1,11 +1,12 @@ use std::{collections::HashSet, iter, sync::Arc}; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; +use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; use openvm_stark_backend::{ - interaction::BusIndex, p3_field::FieldAlgebra, p3_matrix::dense::RowMajorMatrix, - prover::types::AirProofInput, Chip, + interaction::BusIndex, + p3_field::FieldAlgebra, + p3_matrix::dense::RowMajorMatrix, + prover::{cpu::CpuBackend, types::AirProvingContext}, + AirRef, Chip, }; use openvm_stark_sdk::{ config::baby_bear_poseidon2::{BabyBearPoseidon2Config, BabyBearPoseidon2Engine}, @@ -45,7 +46,7 @@ fn boundary_air_test() { } let range_bus = VariableRangeCheckerBus::new(RANGE_CHECKER_BUS, DECOMP); - let range_checker = SharedVariableRangeCheckerChip::new(range_bus); + let range_checker = Arc::new(VariableRangeCheckerChip::new(range_bus)); let mut boundary_chip = VolatileBoundaryChip::new(memory_bus, 2, LIMB_BITS, range_checker.clone()); @@ -55,21 +56,22 @@ fn boundary_air_test() { let final_data = Val::from_canonical_u32(rng.gen_range(0..MAX_VAL)); let final_clk = rng.gen_range(1..MAX_VAL) as u32; - final_memory.insert( + final_memory.push(( (addr_space, pointer), TimestampedValues { values: [final_data], timestamp: final_clk, }, - ); + )); } + final_memory.sort_by_key(|(key, _)| *key); let diff_height = num_addresses.next_power_of_two() - num_addresses; let init_memory_dummy_air = DummyInteractionAir::new(4, false, MEMORY_BUS); let final_memory_dummy_air = DummyInteractionAir::new(4, true, MEMORY_BUS); - let init_memory_trace = RowMajorMatrix::new( + let init_memory_trace = Arc::new(RowMajorMatrix::new( distinct_addresses .iter() .flat_map(|(addr_space, pointer)| { @@ -84,13 +86,16 @@ fn boundary_air_test() { .chain(iter::repeat_n(Val::ZERO, 5 * diff_height)) .collect(), 5, - ); + )); - let final_memory_trace = RowMajorMatrix::new( + let final_memory_trace = Arc::new(RowMajorMatrix::new( distinct_addresses .iter() .flat_map(|(addr_space, pointer)| { - let timestamped_value = final_memory.get(&(*addr_space, *pointer)).unwrap(); + let timestamped_value = final_memory[final_memory + .binary_search_by(|(key, _)| key.cmp(&(*addr_space, *pointer))) + .unwrap()] + .1; vec![ Val::ONE, @@ -103,24 +108,24 @@ fn boundary_air_test() { .chain(iter::repeat_n(Val::ZERO, 5 * diff_height)) .collect(), 5, - ); + )); boundary_chip.finalize(final_memory.clone()); - let boundary_air = boundary_chip.air(); - let boundary_api: AirProofInput = - boundary_chip.generate_air_proof_input(); + let boundary_air = Arc::new(boundary_chip.air.clone()) as AirRef<_>; + let boundary_ctx: AirProvingContext> = + boundary_chip.generate_proving_ctx(()); // test trace height override { - let overridden_height = boundary_api.main_trace_height() * 2; - let range_checker = SharedVariableRangeCheckerChip::new(range_bus); + let overridden_height = boundary_ctx.main_trace_height() * 2; + let range_checker = Arc::new(VariableRangeCheckerChip::new(range_bus)); let mut boundary_chip = VolatileBoundaryChip::new(memory_bus, 2, LIMB_BITS, range_checker.clone()); boundary_chip.set_overridden_height(overridden_height); boundary_chip.finalize(final_memory.clone()); - let boundary_api: AirProofInput = - boundary_chip.generate_air_proof_input(); + let boundary_ctx: AirProvingContext> = + boundary_chip.generate_proving_ctx(()); assert_eq!( - boundary_api.main_trace_height(), + boundary_ctx.main_trace_height(), overridden_height.next_power_of_two() ); } @@ -128,15 +133,15 @@ fn boundary_air_test() { BabyBearPoseidon2Engine::run_test_fast( vec![ boundary_air, - range_checker.air(), + Arc::new(range_checker.air), Arc::new(init_memory_dummy_air), Arc::new(final_memory_dummy_air), ], vec![ - boundary_api, - range_checker.generate_air_proof_input(), - AirProofInput::simple_no_pis(init_memory_trace), - AirProofInput::simple_no_pis(final_memory_trace), + boundary_ctx, + range_checker.generate_proving_ctx(()), + AirProvingContext::simple_no_pis(init_memory_trace), + AirProvingContext::simple_no_pis(final_memory_trace), ], ) .expect("Verification failed"); diff --git a/crates/vm/src/system/mod.rs b/crates/vm/src/system/mod.rs index a1038ac86a..20ae065f4a 100644 --- a/crates/vm/src/system/mod.rs +++ b/crates/vm/src/system/mod.rs @@ -1,11 +1,623 @@ +use std::sync::Arc; + +use derive_more::derive::From; +use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor}; +use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerAir, VariableRangeCheckerBus, + VariableRangeCheckerChip, +}; +use openvm_instructions::{ + LocalOpcode, PhantomDiscriminant, PublishOpcode, SysPhantom, SystemOpcode, +}; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + engine::StarkEngine, + interaction::{LookupBus, PermutationCheckBus}, + p3_field::{Field, PrimeField32}, + prover::{ + cpu::{CpuBackend, CpuDevice}, + hal::{MatrixDimensions, ProverBackend}, + types::{AirProvingContext, CommittedTraceData}, + }, + AirRef, Chip, +}; +use rustc_hash::FxHashMap; + +use self::{connector::VmConnectorAir, program::ProgramAir, public_values::PublicValuesAir}; +use crate::{ + arch::{ + vm_poseidon2_config, AirInventory, AirInventoryError, BusIndexManager, ChipInventory, + ChipInventoryError, DenseRecordArena, ExecutionBridge, ExecutionBus, ExecutionState, + ExecutorInventory, ExecutorInventoryError, MatrixRecordArena, PhantomSubExecutor, + RowMajorMatrixArena, SystemConfig, VmAirWrapper, VmBuilder, VmChipComplex, VmChipWrapper, + VmCircuitConfig, VmExecutionConfig, CONNECTOR_AIR_ID, PROGRAM_AIR_ID, PUBLIC_VALUES_AIR_ID, + }, + system::{ + connector::VmConnectorChip, + memory::{ + interface::MemoryInterfaceAirs, + offline_checker::{MemoryBridge, MemoryBus}, + online::GuestMemory, + MemoryAirInventory, MemoryController, TimestampedEquipartition, CHUNK, + }, + native_adapter::{NativeAdapterAir, NativeAdapterExecutor}, + phantom::{ + CycleEndPhantomExecutor, CycleStartPhantomExecutor, NopPhantomExecutor, PhantomAir, + PhantomChip, PhantomExecutor, PhantomFiller, + }, + poseidon2::{ + air::Poseidon2PeripheryAir, new_poseidon2_periphery_air, Poseidon2PeripheryChip, + }, + program::{ProgramBus, ProgramChip}, + public_values::{PublicValuesChip, PublicValuesCoreAir, PublicValuesExecutor}, + }, +}; + pub mod connector; pub mod memory; +// Necessary for the PublicValuesChip pub mod native_adapter; -/// Chip to handle phantom instructions. -/// The Air will always constrain a NOP which advances pc by DEFAULT_PC_STEP. -/// The runtime executor will execute different phantom instructions that may -/// affect trace generation based on the operand. pub mod phantom; pub mod poseidon2; pub mod program; pub mod public_values; + +/// **If** internal poseidon2 chip exists, then its insertion index is 1. +const POSEIDON2_INSERTION_IDX: usize = 1; +/// **If** public values chip exists, then its executor index is 0. +pub(crate) const PV_EXECUTOR_IDX: usize = 0; + +/// Trait for trace generation of all system AIRs. The system chip complex is special because we may +/// not exactly following the exact matching between `Air` and `Chip`. Moreover we may require more +/// flexibility than what is provided through the trait object [`AnyChip`]. +/// +/// The [SystemChipComplex] is meant to be constructible once the VM configuration is known, and it +/// can be loaded with arbitrary programs supported by the instruction set available to its +/// configuration. The [SystemChipComplex] is meant to persistent between instances of proof +/// generation. +pub trait SystemChipComplex { + /// Loads the program in the form of a cached trace with prover data. + fn load_program(&mut self, cached_program_trace: CommittedTraceData); + + /// Transport the initial memory state to device. This may be called before preflight execution + /// begins and start async device processes in parallel to execution. + fn transport_init_memory_to_device(&mut self, memory: &GuestMemory); + + /// The caller must guarantee that `record_arenas` has length equal to the number of system + /// AIRs, although some arenas may be empty if they are unused. + fn generate_proving_ctx( + &mut self, + system_records: SystemRecords, + record_arenas: Vec, + ) -> Vec>; + + /// This function is only used for metric collection purposes and custom implementations are + /// free to ignore it. + /// + /// Since system chips (primarily memory) will only have all information needed to compute the + /// true used trace heights after `generate_proving_ctx` is called, this method will be called + /// after `generate_proving_ctx` on the trace `heights` of all AIRs (including non-system AIRs) + /// in the AIR ID order. + /// + /// The default implementation does nothing. + #[cfg(feature = "metrics")] + fn finalize_trace_heights(&self, _heights: &mut [usize]) {} +} + +/// Trait meant to be implemented on a SystemChipComplex. +pub trait SystemWithFixedTraceHeights { + /// `heights` will have length equal to number of system AIRs, in AIR ID order. This function + /// must guarantee that the system trace matrices generated have the required heights. + fn override_trace_heights(&mut self, heights: &[u32]); +} + +pub struct SystemRecords { + pub from_state: ExecutionState, + pub to_state: ExecutionState, + pub exit_code: Option, + /// `i` -> frequency of instruction in `i`th row of trace matrix. This requires filtering + /// `program.instructions_and_debug_infos` to remove gaps. + pub filtered_exec_frequencies: Vec, + // We always use a [DenseRecordArena] here, regardless of the generic `RA` used for other + // execution records. + pub access_adapter_records: DenseRecordArena, + // Perf[jpw]: this should be computed on-device and changed to just touched blocks + pub touched_memory: TouchedMemory, + /// The public values of the [PublicValuesChip]. These should only be non-empty if + /// continuations are disabled. + pub public_values: Vec, +} + +pub enum TouchedMemory { + Persistent(TimestampedEquipartition), + Volatile(TimestampedEquipartition), +} + +#[derive(Clone, AnyEnum, Executor, MeteredExecutor, PreflightExecutor, From)] +pub enum SystemExecutor { + PublicValues(PublicValuesExecutor), + Phantom(PhantomExecutor), +} + +/// SystemPort combines system resources needed by most extensions +#[derive(Clone, Copy)] +pub struct SystemPort { + pub execution_bus: ExecutionBus, + pub program_bus: ProgramBus, + pub memory_bridge: MemoryBridge, +} + +#[derive(Clone)] +pub struct SystemAirInventory { + pub program: ProgramAir, + pub connector: VmConnectorAir, + pub memory: MemoryAirInventory, + /// Public values AIR exists if and only if continuations is disabled and `num_public_values` + /// is greater than 0. + pub public_values: Option, +} + +impl SystemAirInventory { + pub fn new( + config: &SystemConfig, + port: SystemPort, + merkle_compression_buses: Option<(PermutationCheckBus, PermutationCheckBus)>, + ) -> Self { + let SystemPort { + execution_bus, + program_bus, + memory_bridge, + } = port; + let range_bus = memory_bridge.range_bus(); + let program = ProgramAir::new(program_bus); + let connector = VmConnectorAir::new( + execution_bus, + program_bus, + range_bus, + config.memory_config.timestamp_max_bits, + ); + assert_eq!( + config.continuation_enabled, + merkle_compression_buses.is_some() + ); + + let memory = MemoryAirInventory::new( + memory_bridge, + &config.memory_config, + range_bus, + merkle_compression_buses, + ); + + let public_values = if config.has_public_values_chip() { + let air = VmAirWrapper::new( + NativeAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + PublicValuesCoreAir::new( + config.num_public_values, + config.max_constraint_degree as u32 - 1, + ), + ); + Some(air) + } else { + None + }; + + Self { + program, + connector, + memory, + public_values, + } + } + + pub fn port(&self) -> SystemPort { + SystemPort { + memory_bridge: self.memory.bridge, + program_bus: self.program.bus, + execution_bus: self.connector.execution_bus, + } + } + + pub fn into_airs(self) -> Vec> { + let mut airs: Vec> = Vec::new(); + airs.push(Arc::new(self.program)); + airs.push(Arc::new(self.connector)); + if let Some(public_values) = self.public_values { + airs.push(Arc::new(public_values)); + } + airs.extend(self.memory.into_airs()); + airs + } +} + +impl VmExecutionConfig for SystemConfig { + type Executor = SystemExecutor; + + /// The only way to create an [ExecutorInventory] is from a [SystemConfig]. This will add an + /// executor for [PublicValuesExecutor] if continuations is disabled. It will always add an + /// executor for [PhantomChip], which handles all phantom sub-executors. + fn create_executors( + &self, + ) -> Result, ExecutorInventoryError> { + let mut inventory = ExecutorInventory::new(self.clone()); + // PublicValuesChip is required when num_public_values > 0 in single segment mode. + if self.has_public_values_chip() { + assert_eq!(inventory.executors().len(), PV_EXECUTOR_IDX); + + let public_values = PublicValuesExecutor::new( + NativeAdapterExecutor::default(), + self.num_public_values, + (self.max_constraint_degree as u32).checked_sub(1).unwrap(), + ); + inventory.add_executor(public_values, [PublishOpcode::PUBLISH.global_opcode()])?; + } + let phantom_opcode = SystemOpcode::PHANTOM.global_opcode(); + let mut phantom_executors: FxHashMap>> = + FxHashMap::default(); + // Use NopPhantomExecutor so the discriminant is set but `DebugPanic` is handled specially. + phantom_executors.insert( + PhantomDiscriminant(SysPhantom::DebugPanic as u16), + Arc::new(NopPhantomExecutor), + ); + phantom_executors.insert( + PhantomDiscriminant(SysPhantom::Nop as u16), + Arc::new(NopPhantomExecutor), + ); + phantom_executors.insert( + PhantomDiscriminant(SysPhantom::CtStart as u16), + Arc::new(CycleStartPhantomExecutor), + ); + phantom_executors.insert( + PhantomDiscriminant(SysPhantom::CtEnd as u16), + Arc::new(CycleEndPhantomExecutor), + ); + let phantom = PhantomExecutor::new(phantom_executors, phantom_opcode); + inventory.add_executor(phantom, [phantom_opcode])?; + + Ok(inventory) + } +} + +impl VmCircuitConfig for SystemConfig { + /// Every VM circuit within the OpenVM circuit architecture **must** be initialized from the + /// [SystemConfig]. + fn create_airs(&self) -> Result, AirInventoryError> { + let mut bus_idx_mgr = BusIndexManager::new(); + let execution_bus = ExecutionBus::new(bus_idx_mgr.new_bus_idx()); + let memory_bus = MemoryBus::new(bus_idx_mgr.new_bus_idx()); + let program_bus = ProgramBus::new(bus_idx_mgr.new_bus_idx()); + let range_bus = + VariableRangeCheckerBus::new(bus_idx_mgr.new_bus_idx(), self.memory_config.decomp); + + let merkle_compression_buses = if self.continuation_enabled { + let merkle_bus = PermutationCheckBus::new(bus_idx_mgr.new_bus_idx()); + let compression_bus = PermutationCheckBus::new(bus_idx_mgr.new_bus_idx()); + Some((merkle_bus, compression_bus)) + } else { + None + }; + let memory_bridge = + MemoryBridge::new(memory_bus, self.memory_config.timestamp_max_bits, range_bus); + let system_port = SystemPort { + execution_bus, + program_bus, + memory_bridge, + }; + let system = SystemAirInventory::new(self, system_port, merkle_compression_buses); + + let mut inventory = AirInventory::new(self.clone(), system, bus_idx_mgr); + + let range_checker = VariableRangeCheckerAir::new(range_bus); + // Range checker is always the first AIR in the inventory + inventory.add_air(range_checker); + + if self.continuation_enabled { + assert_eq!(inventory.ext_airs().len(), POSEIDON2_INSERTION_IDX); + // Add direct poseidon2 AIR for persistent memory. + // Currently we never use poseidon2 opcodes when continuations is enabled: we will need + // special handling when that happens + let (_, compression_bus) = merkle_compression_buses.unwrap(); + let direct_bus_idx = compression_bus.index; + let air = new_poseidon2_periphery_air( + vm_poseidon2_config(), + LookupBus::new(direct_bus_idx), + self.max_constraint_degree, + ); + inventory.add_air_ref(air); + } + let execution_bridge = ExecutionBridge::new(execution_bus, program_bus); + let phantom = PhantomAir { + execution_bridge, + phantom_opcode: SystemOpcode::PHANTOM.global_opcode(), + }; + inventory.add_air(phantom); + + Ok(inventory) + } +} + +// =================== CPU Backend Specific System Chip Complex Constructor ================== + +/// Base system chips for CPU backend. These chips must exactly correspond to the AIRs in +/// [SystemAirInventory]. +pub struct SystemChipInventory { + pub program_chip: ProgramChip, + pub connector_chip: VmConnectorChip>, + /// Contains all memory chips + pub memory_controller: MemoryController>, + pub public_values_chip: Option>>, +} + +// Note[jpw]: We could get rid of the `mem_inventory` input because `MemoryController` doesn't need +// the buses for tracegen. We leave it to use old interfaces. +impl SystemChipInventory +where + Val: PrimeField32, +{ + pub fn new( + config: &SystemConfig, + mem_inventory: &MemoryAirInventory, + range_checker: SharedVariableRangeCheckerChip, + hasher_chip: Option>>>, + ) -> Self { + // We create an empty program chip: the program should be loaded later (and can be swapped + // out). The execution frequencies are supplied only after execution. + let program_chip = ProgramChip::unloaded(); + let connector_chip = VmConnectorChip::>::new( + range_checker.clone(), + config.memory_config.timestamp_max_bits, + ); + let memory_bus = mem_inventory.bridge.memory_bus(); + let memory_controller = match &mem_inventory.interface { + MemoryInterfaceAirs::Persistent { + boundary: _, + merkle, + } => { + assert!(config.continuation_enabled); + MemoryController::>::with_persistent_memory( + memory_bus, + config.memory_config.clone(), + range_checker.clone(), + merkle.merkle_bus, + merkle.compression_bus, + hasher_chip.unwrap(), + ) + } + MemoryInterfaceAirs::Volatile { boundary: _ } => { + assert!(!config.continuation_enabled); + MemoryController::with_volatile_memory( + memory_bus, + config.memory_config.clone(), + range_checker.clone(), + ) + } + }; + + let public_values_chip = config.has_public_values_chip().then(|| { + VmChipWrapper::new( + PublicValuesExecutor::new( + NativeAdapterExecutor::default(), + config.num_public_values, + config.max_constraint_degree as u32 - 1, + ), + memory_controller.helper(), + ) + }); + + Self { + program_chip, + connector_chip, + memory_controller, + public_values_chip, + } + } +} + +impl SystemChipComplex> for SystemChipInventory +where + RA: RowMajorMatrixArena>, + SC: StarkGenericConfig, + Val: PrimeField32, +{ + fn load_program(&mut self, cached_program_trace: CommittedTraceData>) { + let _ = self.program_chip.cached.replace(cached_program_trace); + } + + fn transport_init_memory_to_device(&mut self, memory: &GuestMemory) { + self.memory_controller + .set_initial_memory(memory.memory.clone()); + } + + fn generate_proving_ctx( + &mut self, + system_records: SystemRecords>, + mut record_arenas: Vec, + ) -> Vec>> { + let SystemRecords { + from_state, + to_state, + exit_code, + filtered_exec_frequencies, + access_adapter_records, + touched_memory, + public_values, + } = system_records; + + if let Some(chip) = &mut self.public_values_chip { + chip.inner.set_public_values(&public_values); + } + self.program_chip.filtered_exec_frequencies = filtered_exec_frequencies; + let program_ctx = self.program_chip.generate_proving_ctx(()); + self.connector_chip.begin(from_state); + self.connector_chip.end(to_state, exit_code); + let connector_ctx = self.connector_chip.generate_proving_ctx(()); + + let pv_ctx = self.public_values_chip.as_ref().map(|chip| { + let arena = record_arenas.remove(PUBLIC_VALUES_AIR_ID); + chip.generate_proving_ctx(arena) + }); + + let memory_ctxs = self + .memory_controller + .generate_proving_ctx(access_adapter_records, touched_memory); + + [program_ctx, connector_ctx] + .into_iter() + .chain(pv_ctx) + .chain(memory_ctxs) + .collect() + } + + #[cfg(feature = "metrics")] + fn finalize_trace_heights(&self, heights: &mut [usize]) { + use openvm_stark_backend::ChipUsageGetter; + + use crate::system::memory::interface::MemoryInterface; + + let boundary_idx = PUBLIC_VALUES_AIR_ID + usize::from(self.public_values_chip.is_some()); + let mut access_adapter_offset = boundary_idx + 1; + match &self.memory_controller.interface_chip { + MemoryInterface::Volatile { boundary_chip } => { + let boundary_height = boundary_chip + .final_memory + .as_ref() + .map(|m| m.len()) + .unwrap_or(0); + heights[boundary_idx] = boundary_height; + } + MemoryInterface::Persistent { + boundary_chip, + merkle_chip, + .. + } => { + let boundary_height = 2 * boundary_chip.touched_labels.len(); + heights[boundary_idx] = boundary_height; + heights[boundary_idx + 1] = merkle_chip.current_height; + access_adapter_offset += 1; + + // Poseidon2Periphery height also varies based on memory, so set it now even though + // it's not a system chip: + let poseidon_chip = self.memory_controller.hasher_chip.as_ref().unwrap(); + let poseidon_height = poseidon_chip.current_trace_height(); + // We know the chip insertion index, which starts from *the end* of the the AIR + // ordering + let poseidon_idx = heights.len() - 1 - POSEIDON2_INSERTION_IDX; + heights[poseidon_idx] = poseidon_height; + } + } + let access_heights = &self + .memory_controller + .access_adapter_inventory + .trace_heights; + heights[access_adapter_offset..access_adapter_offset + access_heights.len()] + .copy_from_slice(access_heights); + } +} + +#[derive(Clone)] +pub struct SystemCpuBuilder; + +impl VmBuilder for SystemCpuBuilder +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + Val: PrimeField32, +{ + type VmConfig = SystemConfig; + type RecordArena = MatrixRecordArena>; + type SystemChipInventory = SystemChipInventory; + + fn create_chip_complex( + &self, + config: &SystemConfig, + airs: AirInventory, + ) -> Result< + VmChipComplex>, CpuBackend, SystemChipInventory>, + ChipInventoryError, + > { + let range_bus = airs.range_checker().bus; + let range_checker = Arc::new(VariableRangeCheckerChip::new(range_bus)); + + let mut inventory = ChipInventory::new(airs); + // PublicValuesChip is required when num_public_values > 0 in single segment mode. + if config.has_public_values_chip() { + assert_eq!( + inventory.executor_idx_to_insertion_idx.len(), + PV_EXECUTOR_IDX + ); + // We set insertion_idx so that air_idx = num_airs - (insertion_idx + 1) = + // PUBLIC_VALUES_AIR_ID in `VmChipComplex::executor_idx_to_air_idx`. We need to do this + // because this chip is special and not part of the normal inventory. + let insertion_idx = inventory + .airs() + .num_airs() + .checked_sub(1 + PUBLIC_VALUES_AIR_ID) + .unwrap(); + inventory.executor_idx_to_insertion_idx.push(insertion_idx); + } + inventory.next_air::()?; + inventory.add_periphery_chip(range_checker.clone()); + + let hasher_chip = if config.continuation_enabled { + assert_eq!(inventory.chips().len(), POSEIDON2_INSERTION_IDX); + // ATTENTION: The threshold 7 here must match the one in `new_poseidon2_periphery_air` + let direct_bus = if config.max_constraint_degree >= 7 { + inventory + .next_air::, 0>>()? + .bus + } else { + inventory + .next_air::, 1>>()? + .bus + }; + let chip = Arc::new(Poseidon2PeripheryChip::new( + vm_poseidon2_config(), + direct_bus.index, + config.max_constraint_degree, + )); + inventory.add_periphery_chip(chip.clone()); + Some(chip) + } else { + None + }; + let system = SystemChipInventory::new( + config, + &inventory.airs().system().memory, + range_checker, + hasher_chip, + ); + + let phantom_chip = PhantomChip::new(PhantomFiller, system.memory_controller.helper()); + inventory.add_executor_chip(phantom_chip); + + Ok(VmChipComplex { system, inventory }) + } +} + +impl SystemWithFixedTraceHeights for SystemChipInventory +where + Val: PrimeField32, +{ + /// Warning: this does not set the override for the PublicValuesChip. The PublicValuesChip + /// override must be set via the RecordArena. + fn override_trace_heights(&mut self, heights: &[u32]) { + assert_eq!( + heights[PROGRAM_AIR_ID] as usize, + self.program_chip + .cached + .as_ref() + .expect("program not loaded") + .trace + .height() + ); + assert_eq!(heights[CONNECTOR_AIR_ID], 2); + let mut memory_start_idx = PUBLIC_VALUES_AIR_ID; + if self.public_values_chip.is_some() { + memory_start_idx += 1; + } + self.memory_controller + .set_override_trace_heights(&heights[memory_start_idx..]); + } +} diff --git a/crates/vm/src/system/native_adapter/mod.rs b/crates/vm/src/system/native_adapter/mod.rs index 95c2c7c4a4..7dcf8c1b2e 100644 --- a/crates/vm/src/system/native_adapter/mod.rs +++ b/crates/vm/src/system/native_adapter/mod.rs @@ -1,3 +1,5 @@ +pub mod util; + use std::{ borrow::{Borrow, BorrowMut}, marker::PhantomData, @@ -5,86 +7,31 @@ use std::{ use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + AdapterAirContext, BasicAdapterInterface, ExecutionBridge, ExecutionState, + MinimalInstruction, VmAdapterAir, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols}, + MemoryAddress, }, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_IMM_AS, NATIVE_AS, +}; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; - -use crate::system::memory::{OfflineMemory, RecordId}; - -/// R reads(R<=2), W writes(W<=1). -/// Operands: b for the first read, c for the second read, a for the first write. -/// If an operand is not used, its address space and pointer should be all 0. -#[derive(Debug)] -pub struct NativeAdapterChip { - pub air: NativeAdapterAir, - _phantom: PhantomData, -} - -impl NativeAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: NativeAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _phantom: PhantomData, - } - } -} +use util::{tracing_read_or_imm_native, tracing_write_native}; -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct NativeReadRecord { - #[serde(with = "BigArray")] - pub reads: [(RecordId, [F; 1]); R], -} - -impl NativeReadRecord { - pub fn b(&self) -> &[F; 1] { - &self.reads[0].1 - } - - pub fn c(&self) -> &[F; 1] { - &self.reads[1].1 - } -} - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct NativeWriteRecord { - pub from_state: ExecutionState, - #[serde(with = "BigArray")] - pub writes: [(RecordId, [F; 1]); W], -} - -impl NativeWriteRecord { - pub fn a(&self) -> &[F; 1] { - &self.writes[0].1 - } -} +use super::memory::{online::TracingMemory, MemoryAuxColsFactory}; +use crate::{ + arch::{get_record_from_slice, AdapterTraceExecutor, AdapterTraceFiller}, + system::memory::offline_checker::{MemoryReadAuxRecord, MemoryWriteAuxRecord}, +}; #[repr(C)] #[derive(AlignedBorrow)] @@ -205,101 +152,160 @@ impl VmAdapterAir } } -impl VmAdapterChip - for NativeAdapterChip -{ - type ReadRecord = NativeReadRecord; - type WriteRecord = NativeWriteRecord; - type Air = NativeAdapterAir; - type Interface = BasicAdapterInterface, R, W, 1, 1>; +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct NativeAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - assert!(R <= 2); - let Instruction { b, c, e, f, .. } = *instruction; + // These are either a pointer to native memory or an immediate value + pub read_ptr_or_imm: [F; R], + // Will set prev_timestamp to `u32::MAX` if the read is from RV32_IMM_AS + pub reads_aux: [MemoryReadAuxRecord; R], + pub write_ptr: [F; W], + pub writes_aux: [MemoryWriteAuxRecord; W], +} - let mut reads = Vec::with_capacity(R); - if R >= 1 { - reads.push(memory.read::<1>(e, b)); - } - if R >= 2 { - reads.push(memory.read::<1>(f, c)); +/// R reads(R<=2), W writes(W<=1). +/// Operands: b for the first read, c for the second read, a for the first write. +/// If an operand is not used, its address space and pointer should be all 0. +#[derive(Clone, Debug)] +pub struct NativeAdapterExecutor { + _phantom: PhantomData, +} + +impl Default for NativeAdapterExecutor { + fn default() -> Self { + Self { + _phantom: PhantomData, } - let i_reads: [_; R] = std::array::from_fn(|i| reads[i].1); + } +} - Ok(( - i_reads, - Self::ReadRecord { - reads: reads.try_into().unwrap(), - }, - )) +impl AdapterTraceExecutor for NativeAdapterExecutor +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = [[F; 1]; R]; + type WriteData = [[F; 1]; W]; + type RecordMut<'a> = &'a mut NativeAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - assert!(W <= 1); - let Instruction { a, d, .. } = *instruction; - let mut writes = Vec::with_capacity(W); - if W >= 1 { - let (record_id, _) = memory.write(d, a, output.writes[0]); - writes.push((record_id, output.writes[0])); - } + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + debug_assert!(R <= 2); + let &Instruction { b, c, e, f, .. } = instruction; - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { - from_state, - writes: writes.try_into().unwrap(), - }, - )) + let mut reads = [[F::ZERO; 1]; R]; + record + .read_ptr_or_imm + .iter_mut() + .enumerate() + .zip(record.reads_aux.iter_mut()) + .for_each(|((i, ptr_or_imm), read_aux)| { + *ptr_or_imm = if i == 0 { b } else { c }; + let addr_space = if i == 0 { e } else { f }; + reads[i][0] = tracing_read_or_imm_native( + memory, + addr_space, + *ptr_or_imm, + &mut read_aux.prev_timestamp, + ); + }); + reads } - fn generate_trace_row( + #[inline(always)] + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, ) { - let row_slice: &mut NativeAdapterCols<_, R, W> = row_slice.borrow_mut(); - let aux_cols_factory = memory.aux_cols_factory(); - - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); + let &Instruction { a, d, .. } = instruction; + debug_assert!(W <= 1); + debug_assert_eq!(d.as_canonical_u32(), NATIVE_AS); - for (i, read) in read_record.reads.iter().enumerate() { - let (id, _) = read; - let record = memory.record_by_id(*id); - aux_cols_factory - .generate_read_or_immediate_aux(record, &mut row_slice.reads_aux[i].read_aux); - row_slice.reads_aux[i].address = - MemoryAddress::new(record.address_space, record.pointer); + if W >= 1 { + record.write_ptr[0] = a; + tracing_write_native( + memory, + a.as_canonical_u32(), + data[0], + &mut record.writes_aux[0].prev_timestamp, + &mut record.writes_aux[0].prev_data, + ); } + } +} + +impl AdapterTraceFiller + for NativeAdapterExecutor +{ + const WIDTH: usize = size_of::>(); - for (i, write) in write_record.writes.iter().enumerate() { - let (id, _) = write; - let record = memory.record_by_id(*id); - aux_cols_factory.generate_write_aux(record, &mut row_slice.writes_aux[i].write_aux); - row_slice.writes_aux[i].address = - MemoryAddress::new(record.address_space, record.pointer); + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &NativeAdapterRecord = + unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_row: &mut NativeAdapterCols<_, R, W> = adapter_row.borrow_mut(); + // Writing in reverse order to avoid overwriting the `record` + if W >= 1 { + adapter_row.writes_aux[0] + .write_aux + .set_prev_data(record.writes_aux[0].prev_data); + mem_helper.fill( + record.writes_aux[0].prev_timestamp, + record.from_timestamp + R as u32, + adapter_row.writes_aux[0].write_aux.as_mut(), + ); + adapter_row.writes_aux[0].address.pointer = record.write_ptr[0]; + adapter_row.writes_aux[0].address.address_space = F::from_canonical_u32(NATIVE_AS); } - } - fn air(&self) -> &Self::Air { - &self.air + adapter_row + .reads_aux + .iter_mut() + .enumerate() + .zip(record.reads_aux.iter().zip(record.read_ptr_or_imm.iter())) + .rev() + .for_each(|((i, read_cols), (read_record, ptr_or_imm))| { + if read_record.prev_timestamp == u32::MAX { + read_cols.read_aux.is_zero_aux = F::ZERO; + read_cols.read_aux.is_immediate = F::ONE; + mem_helper.fill( + 0, + record.from_timestamp + i as u32, + read_cols.read_aux.as_mut(), + ); + read_cols.address.pointer = *ptr_or_imm; + read_cols.address.address_space = F::from_canonical_u32(RV32_IMM_AS); + } else { + read_cols.read_aux.is_zero_aux = F::from_canonical_u32(NATIVE_AS).inverse(); + read_cols.read_aux.is_immediate = F::ZERO; + mem_helper.fill( + read_record.prev_timestamp, + record.from_timestamp + i as u32, + read_cols.read_aux.as_mut(), + ); + read_cols.address.pointer = *ptr_or_imm; + read_cols.address.address_space = F::from_canonical_u32(NATIVE_AS); + } + }); + + adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/crates/vm/src/system/native_adapter/util.rs b/crates/vm/src/system/native_adapter/util.rs new file mode 100644 index 0000000000..8a081ee4eb --- /dev/null +++ b/crates/vm/src/system/native_adapter/util.rs @@ -0,0 +1,198 @@ +use openvm_circuit::system::memory::online::TracingMemory; +use openvm_instructions::{riscv::RV32_IMM_AS, NATIVE_AS}; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{ + arch::{execution_mode::E1ExecutionCtx, VmStateMut}, + system::memory::{offline_checker::MemoryWriteAuxCols, online::GuestMemory}, +}; + +#[inline(always)] +pub fn memory_read_native(memory: &GuestMemory, ptr: u32) -> [F; N] +where + F: PrimeField32, +{ + // SAFETY: + // - address space `NATIVE_AS` will always have cell type `F` and minimum alignment of `1` + unsafe { memory.read::(NATIVE_AS, ptr) } +} + +#[inline(always)] +pub fn memory_read_or_imm_native(memory: &GuestMemory, addr_space: u32, ptr_or_imm: F) -> F +where + F: PrimeField32, +{ + debug_assert!(addr_space == RV32_IMM_AS || addr_space == NATIVE_AS); + + if addr_space == NATIVE_AS { + let [result]: [F; 1] = memory_read_native(memory, ptr_or_imm.as_canonical_u32()); + result + } else { + ptr_or_imm + } +} + +#[inline(always)] +pub fn memory_write_native(memory: &mut GuestMemory, ptr: u32, data: [F; N]) +where + F: PrimeField32, +{ + // SAFETY: + // - address space `NATIVE_AS` will always have cell type `F` and minimum alignment of `1` + unsafe { memory.write::(NATIVE_AS, ptr, data) } +} + +#[inline(always)] +pub fn memory_read_native_from_state( + state: &mut VmStateMut, + ptr: u32, +) -> [F; N] +where + F: PrimeField32, + Ctx: E1ExecutionCtx, +{ + state.ctx.on_memory_operation(NATIVE_AS, ptr, N as u32); + + memory_read_native(state.memory, ptr) +} + +#[inline(always)] +pub fn memory_read_or_imm_native_from_state( + state: &mut VmStateMut, + addr_space: u32, + ptr_or_imm: F, +) -> F +where + F: PrimeField32, + Ctx: E1ExecutionCtx, +{ + debug_assert!(addr_space == RV32_IMM_AS || addr_space == NATIVE_AS); + + if addr_space == NATIVE_AS { + let [result]: [F; 1] = memory_read_native_from_state(state, ptr_or_imm.as_canonical_u32()); + result + } else { + ptr_or_imm + } +} + +#[inline(always)] +pub fn memory_write_native_from_state( + state: &mut VmStateMut, + ptr: u32, + data: [F; N], +) where + F: PrimeField32, + Ctx: E1ExecutionCtx, +{ + state.ctx.on_memory_operation(NATIVE_AS, ptr, N as u32); + + memory_write_native(state.memory, ptr, data) +} + +/// Atomic read operation which increments the timestamp by 1. +/// Returns `(t_prev, [ptr:BLOCK_SIZE]_4)` where `t_prev` is the timestamp of the last memory +/// access. +#[inline(always)] +pub fn timed_read_native( + memory: &mut TracingMemory, + ptr: u32, +) -> (u32, [F; BLOCK_SIZE]) +where + F: PrimeField32, +{ + // SAFETY: + // - address space `Native` will always have cell type `F` and minimum alignment of `1` + unsafe { memory.read::(NATIVE_AS, ptr) } +} + +#[inline(always)] +pub fn timed_write_native( + memory: &mut TracingMemory, + ptr: u32, + vals: [F; BLOCK_SIZE], +) -> (u32, [F; BLOCK_SIZE]) +where + F: PrimeField32, +{ + // SAFETY: + // - address space `Native` will always have cell type `F` and minimum alignment of `1` + unsafe { memory.write::(NATIVE_AS, ptr, vals) } +} + +/// Reads register value at `ptr` from memory and records the previous timestamp. +/// Reads are only done from address space [NATIVE_AS]. +#[inline(always)] +pub fn tracing_read_native( + memory: &mut TracingMemory, + ptr: u32, + prev_timestamp: &mut u32, +) -> [F; BLOCK_SIZE] +where + F: PrimeField32, +{ + let (t_prev, data) = timed_read_native(memory, ptr); + *prev_timestamp = t_prev; + data +} + +/// Writes `ptr, vals` into memory and records the previous timestamp and data. +/// Writes are only done to address space [NATIVE_AS]. +#[inline(always)] +pub fn tracing_write_native( + memory: &mut TracingMemory, + ptr: u32, + vals: [F; BLOCK_SIZE], + prev_timestamp: &mut u32, + prev_data: &mut [F; BLOCK_SIZE], +) where + F: PrimeField32, +{ + let (t_prev, data_prev) = timed_write_native(memory, ptr, vals); + *prev_timestamp = t_prev; + *prev_data = data_prev; +} + +/// Writes `ptr, vals` into memory and records the previous timestamp and data. +#[inline(always)] +pub fn tracing_write_native_inplace( + memory: &mut TracingMemory, + ptr: u32, + vals: [F; BLOCK_SIZE], + cols: &mut MemoryWriteAuxCols, +) where + F: PrimeField32, +{ + let (t_prev, data_prev) = timed_write_native(memory, ptr, vals); + cols.base.set_prev(F::from_canonical_u32(t_prev)); + cols.prev_data = data_prev; +} + +/// Reads value at `_ptr` from memory and records the previous timestamp. +/// If the read is an immediate, the previous timestamp will be set to `u32::MAX`. +#[inline(always)] +pub fn tracing_read_or_imm_native( + memory: &mut TracingMemory, + addr_space: F, + ptr_or_imm: F, + prev_timestamp: &mut u32, +) -> F +where + F: PrimeField32, +{ + debug_assert!( + addr_space == F::ZERO || addr_space == F::from_canonical_u32(NATIVE_AS), + "addr_space={} is not valid", + addr_space + ); + + if addr_space == F::ZERO { + *prev_timestamp = u32::MAX; + memory.increment_timestamp(); + ptr_or_imm + } else { + let data: [F; 1] = + tracing_read_native(memory, ptr_or_imm.as_canonical_u32(), prev_timestamp); + data[0] + } +} diff --git a/crates/vm/src/system/phantom/execution.rs b/crates/vm/src/system/phantom/execution.rs new file mode 100644 index 0000000000..e9142a536b --- /dev/null +++ b/crates/vm/src/system/phantom/execution.rs @@ -0,0 +1,192 @@ +use std::borrow::{Borrow, BorrowMut}; + +use openvm_circuit_primitives_derive::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, PhantomDiscriminant, SysPhantom, +}; +use openvm_stark_backend::p3_field::PrimeField32; +use rand::rngs::StdRng; + +use crate::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + E2PreCompute, ExecuteFunc, ExecutionError, Executor, MeteredExecutor, PhantomSubExecutor, + StaticProgramError, Streams, VmExecState, + }, + system::{memory::online::GuestMemory, phantom::PhantomExecutor}, +}; + +#[derive(Clone, AlignedBytesBorrow)] +#[repr(C)] +pub(super) struct PhantomOperands { + pub(super) a: u32, + pub(super) b: u32, + pub(super) c: u32, +} + +#[derive(Clone, AlignedBytesBorrow)] +#[repr(C)] +struct PhantomPreCompute { + operands: PhantomOperands, + sub_executor: *const dyn PhantomSubExecutor, +} + +impl Executor for PhantomExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::>() + } + #[inline(always)] + fn pre_compute( + &self, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E1ExecutionCtx, + { + let data: &mut PhantomPreCompute = data.borrow_mut(); + self.pre_compute_impl(inst, data); + Ok(execute_e1_impl) + } +} + +pub(super) struct PhantomStateMut<'a, F> { + pub(super) pc: &'a mut u32, + pub(super) memory: &'a mut GuestMemory, + pub(super) streams: &'a mut Streams, + pub(super) rng: &'a mut StdRng, +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &PhantomPreCompute, + vm_state: &mut VmExecState, +) { + let sub_executor = &*pre_compute.sub_executor; + if let Err(e) = execute_impl( + PhantomStateMut { + pc: &mut vm_state.vm_state.pc, + memory: &mut vm_state.vm_state.memory, + streams: &mut vm_state.vm_state.streams, + rng: &mut vm_state.vm_state.rng, + }, + &pre_compute.operands, + sub_executor, + ) { + vm_state.exit_code = Err(e); + return; + } + vm_state.pc += DEFAULT_PC_STEP; + vm_state.instret += 1; +} + +#[inline(always)] +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &PhantomPreCompute = pre_compute.borrow(); + execute_e12_impl(pre_compute, vm_state); +} + +#[inline(always)] +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute> = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl(&pre_compute.data, vm_state); +} + +#[inline(always)] +fn execute_impl( + state: PhantomStateMut, + operands: &PhantomOperands, + sub_executor: &dyn PhantomSubExecutor, +) -> Result<(), ExecutionError> { + let &PhantomOperands { a, b, c } = operands; + + let discriminant = PhantomDiscriminant(c as u16); + // SysPhantom::{CtStart, CtEnd} are only handled in Preflight Execution, so the only SysPhantom + // to handle here is DebugPanic. + if let Some(discr) = SysPhantom::from_repr(discriminant.0) { + if discr == SysPhantom::DebugPanic { + return Err(ExecutionError::Fail { + pc: *state.pc, + msg: "DebugPanic", + }); + } + } + sub_executor + .phantom_execute( + state.memory, + state.streams, + state.rng, + discriminant, + a, + b, + (c >> 16) as u16, + ) + .map_err(|e| ExecutionError::Phantom { + pc: *state.pc, + discriminant, + inner: e, + })?; + + Ok(()) +} + +impl PhantomExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_impl(&self, inst: &Instruction, data: &mut PhantomPreCompute) { + let c = inst.c.as_canonical_u32(); + *data = PhantomPreCompute { + operands: PhantomOperands { + a: inst.a.as_canonical_u32(), + b: inst.b.as_canonical_u32(), + c, + }, + sub_executor: self + .phantom_executors + .get(&PhantomDiscriminant(c as u16)) + .unwrap_or_else(|| panic!("Phantom executor not found for insn {inst:?}")) + .as_ref(), + }; + } +} + +impl MeteredExecutor for PhantomExecutor +where + F: PrimeField32, +{ + fn metered_pre_compute_size(&self) -> usize { + size_of::>>() + } + + fn metered_pre_compute( + &self, + chip_idx: usize, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let e2_data: &mut E2PreCompute> = data.borrow_mut(); + e2_data.chip_idx = chip_idx as u32; + self.pre_compute_impl(inst, &mut e2_data.data); + Ok(execute_e2_impl) + } +} diff --git a/crates/vm/src/system/phantom/mod.rs b/crates/vm/src/system/phantom/mod.rs index 28977fe2cd..7fbfe5adaf 100644 --- a/crates/vm/src/system/phantom/mod.rs +++ b/crates/vm/src/system/phantom/mod.rs @@ -1,37 +1,41 @@ +//! Chip to handle phantom instructions. +//! The Air will always constrain a NOP which advances pc by DEFAULT_PC_STEP. +//! The runtime executor will execute different phantom instructions that may +//! affect trace generation based on the operand. use std::{ borrow::{Borrow, BorrowMut}, - sync::{Arc, Mutex, OnceLock}, + sync::Arc, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ - instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode, PhantomDiscriminant, - SysPhantom, SystemOpcode, VmOpcode, + instruction::Instruction, program::DEFAULT_PC_STEP, PhantomDiscriminant, SysPhantom, + SystemOpcode, VmOpcode, }; use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, interaction::InteractionBuilder, p3_air::{Air, AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, - p3_maybe_rayon::prelude::*, - prover::types::AirProofInput, - rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, - AirRef, Chip, ChipUsageGetter, + p3_matrix::Matrix, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, }; +use rand::rngs::StdRng; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use serde_big_array::BigArray; -use super::memory::MemoryController; +use super::memory::online::{GuestMemory, TracingMemory}; use crate::{ arch::{ - ExecutionBridge, ExecutionBus, ExecutionError, ExecutionState, InstructionExecutor, - PcIncOrSet, PhantomSubExecutor, Streams, + get_record_from_slice, EmptyMultiRowLayout, ExecutionBridge, ExecutionError, + ExecutionState, PcIncOrSet, PhantomSubExecutor, PreflightExecutor, RecordArena, Streams, + TraceFiller, VmChipWrapper, VmStateMut, }, - system::program::ProgramBus, + system::memory::MemoryAuxColsFactory, }; +mod execution; #[cfg(test)] mod tests; @@ -88,95 +92,105 @@ impl Air for PhantomAir { } } -pub struct PhantomChip { - pub air: PhantomAir, - pub rows: Vec>, - streams: OnceLock>>>, - phantom_executors: FxHashMap>>, +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug, Clone)] +pub struct PhantomRecord { + pub pc: u32, + pub operands: [u32; NUM_PHANTOM_OPERANDS], + pub timestamp: u32, } -impl PhantomChip { - pub fn new(execution_bus: ExecutionBus, program_bus: ProgramBus, offset: usize) -> Self { - Self { - air: PhantomAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - phantom_opcode: VmOpcode::from_usize(offset + SystemOpcode::PHANTOM.local_usize()), - }, - rows: vec![], - streams: OnceLock::new(), - phantom_executors: FxHashMap::default(), - } - } - - pub fn set_streams(&mut self, streams: Arc>>) { - if self.streams.set(streams).is_err() { - panic!("Streams should only be set once"); - } - } - - pub(crate) fn add_sub_executor + 'static>( - &mut self, - sub_executor: P, - discriminant: PhantomDiscriminant, - ) -> Option>> { - self.phantom_executors - .insert(discriminant, Box::new(sub_executor)) - } +/// `PhantomChip` is a special executor because it is stateful and stores all the phantom +/// sub-executors. +#[derive(Clone, derive_new::new)] +pub struct PhantomExecutor { + pub(crate) phantom_executors: FxHashMap>>, + phantom_opcode: VmOpcode, } -impl InstructionExecutor for PhantomChip { +pub struct PhantomFiller; +pub type PhantomChip = VmChipWrapper; + +impl PreflightExecutor for PhantomExecutor +where + F: PrimeField32, + for<'buf> RA: RecordArena<'buf, EmptyMultiRowLayout, &'buf mut PhantomRecord>, +{ fn execute( &mut self, - memory: &mut MemoryController, + state: VmStateMut, instruction: &Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { - let &Instruction { - opcode, a, b, c, .. - } = instruction; - assert_eq!(opcode, self.air.phantom_opcode); - - let c_u32 = c.as_canonical_u32(); - let discriminant = PhantomDiscriminant(c_u32 as u16); - // If not a system phantom sub-instruction (which is handled in - // ExecutionSegment), look for a phantom sub-executor to handle it. - if SysPhantom::from_repr(discriminant.0).is_none() { - let sub_executor = self - .phantom_executors - .get_mut(&discriminant) - .ok_or_else(|| ExecutionError::PhantomNotFound { - pc: from_state.pc, - discriminant, - })?; - let mut streams = self.streams.get().unwrap().lock().unwrap(); + ) -> Result<(), ExecutionError> { + let record: &mut PhantomRecord = state.ctx.alloc(EmptyMultiRowLayout::default()); + let pc = *state.pc; + record.pc = pc; + record.timestamp = state.memory.timestamp; + let [a, b, c] = [instruction.a, instruction.b, instruction.c].map(|x| x.as_canonical_u32()); + record.operands = [a, b, c]; + + debug_assert_eq!(instruction.opcode, self.phantom_opcode); + let discriminant = PhantomDiscriminant(c as u16); + if let Some(sys) = SysPhantom::from_repr(discriminant.0) { + tracing::trace!("pc: {pc:#x} | system phantom: {sys:?}"); + match sys { + SysPhantom::DebugPanic => { + #[cfg(all( + feature = "metrics", + any(debug_assertions, feature = "perf-metrics") + ))] + { + let metrics = state.metrics; + metrics.update_backtrace(pc); + if let Some(mut backtrace) = metrics.prev_backtrace.take() { + backtrace.resolve(); + eprintln!("openvm program failure; backtrace:\n{:?}", backtrace); + } else { + eprintln!("openvm program failure; no backtrace"); + } + } + return Err(ExecutionError::Fail { + pc, + msg: "DebugPanic", + }); + } + #[cfg(feature = "perf-metrics")] + SysPhantom::CtStart => { + let metrics = state.metrics; + if let Some(info) = metrics.debug_infos.get(pc) { + metrics.cycle_tracker.start(info.dsl_instruction.clone()); + } + } + #[cfg(feature = "perf-metrics")] + SysPhantom::CtEnd => { + let metrics = state.metrics; + if let Some(info) = metrics.debug_infos.get(pc) { + metrics.cycle_tracker.end(info.dsl_instruction.clone()); + } + } + _ => {} + } + } else { + let sub_executor = self.phantom_executors.get(&discriminant).unwrap(); sub_executor - .as_mut() .phantom_execute( - memory, - &mut streams, + &state.memory.data, + state.streams, + state.rng, discriminant, a, b, - (c_u32 >> 16) as u16, + (c >> 16) as u16, ) - .map_err(|e| ExecutionError::Phantom { - pc: from_state.pc, + .map_err(|err| ExecutionError::Phantom { + pc, discriminant, - inner: e, + inner: err, })?; } + *state.pc += DEFAULT_PC_STEP; + state.memory.increment_timestamp(); - self.rows.push(PhantomCols { - pc: F::from_canonical_u32(from_state.pc), - operands: [a, b, c], - timestamp: F::from_canonical_u32(from_state.timestamp), - is_valid: F::ONE, - }); - memory.increment_timestamp(); - Ok(ExecutionState::new( - from_state.pc + DEFAULT_PC_STEP, - from_state.timestamp + 1, - )) + Ok(()) } fn get_opcode_name(&self, _: usize) -> String { @@ -184,41 +198,72 @@ impl InstructionExecutor for PhantomChip { } } -impl ChipUsageGetter for PhantomChip { - fn air_name(&self) -> String { - get_air_name(&self.air) - } - fn current_trace_height(&self) -> usize { - self.rows.len() - } - fn trace_width(&self) -> usize { - PhantomCols::::width() +impl TraceFiller for PhantomFiller { + fn fill_trace_row(&self, _mem_helper: &MemoryAuxColsFactory, mut row_slice: &mut [F]) { + // SAFETY: assume that row has size PhantomCols::::width() + let record: &PhantomRecord = unsafe { get_record_from_slice(&mut row_slice, ()) }; + let row: &mut PhantomCols = row_slice.borrow_mut(); + // SAFETY: must assign in reverse order of column struct to prevent overwriting + // borrowed data + row.is_valid = F::ONE; + row.timestamp = F::from_canonical_u32(record.timestamp); + row.operands[2] = F::from_canonical_u32(record.operands[2]); + row.operands[1] = F::from_canonical_u32(record.operands[1]); + row.operands[0] = F::from_canonical_u32(record.operands[0]); + row.pc = F::from_canonical_u32(record.pc) } - fn current_trace_cells(&self) -> usize { - self.trace_width() * self.current_trace_height() +} + +pub struct NopPhantomExecutor; +pub struct CycleStartPhantomExecutor; +pub struct CycleEndPhantomExecutor; + +impl PhantomSubExecutor for NopPhantomExecutor { + #[inline(always)] + fn phantom_execute( + &self, + _memory: &GuestMemory, + _streams: &mut Streams, + _rng: &mut StdRng, + _discriminant: PhantomDiscriminant, + _a: u32, + _b: u32, + _c_upper: u16, + ) -> eyre::Result<()> { + Ok(()) } } -impl Chip for PhantomChip> -where - Val: PrimeField32, -{ - fn air(&self) -> AirRef { - Arc::new(self.air.clone()) +impl PhantomSubExecutor for CycleStartPhantomExecutor { + #[inline(always)] + fn phantom_execute( + &self, + _memory: &GuestMemory, + _streams: &mut Streams, + _rng: &mut StdRng, + _discriminant: PhantomDiscriminant, + _a: u32, + _b: u32, + _c_upper: u16, + ) -> eyre::Result<()> { + // Cycle tracking is implemented separately only in Preflight Execution + Ok(()) } +} - fn generate_air_proof_input(self) -> AirProofInput { - let correct_height = self.rows.len().next_power_of_two(); - let width = PhantomCols::>::width(); - let mut rows = Val::::zero_vec(width * correct_height); - rows.par_chunks_mut(width) - .zip(&self.rows) - .for_each(|(row, row_record)| { - let row: &mut PhantomCols<_> = row.borrow_mut(); - *row = *row_record; - }); - let trace = RowMajorMatrix::new(rows, width); - - AirProofInput::simple(trace, vec![]) +impl PhantomSubExecutor for CycleEndPhantomExecutor { + #[inline(always)] + fn phantom_execute( + &self, + _memory: &GuestMemory, + _streams: &mut Streams, + _rng: &mut StdRng, + _discriminant: PhantomDiscriminant, + _a: u32, + _b: u32, + _c_upper: u16, + ) -> eyre::Result<()> { + // Cycle tracking is implemented separately only in Preflight Execution + Ok(()) } } diff --git a/crates/vm/src/system/phantom/tests.rs b/crates/vm/src/system/phantom/tests.rs index 7a0b068d36..14eec85e3c 100644 --- a/crates/vm/src/system/phantom/tests.rs +++ b/crates/vm/src/system/phantom/tests.rs @@ -1,34 +1,41 @@ -use std::sync::{Arc, Mutex}; - use openvm_instructions::{instruction::Instruction, SystemOpcode}; use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32}; use openvm_stark_sdk::p3_baby_bear::BabyBear; -use super::PhantomChip; -use crate::arch::{instructions::LocalOpcode, testing::VmChipTestBuilder, ExecutionState}; +use super::PhantomExecutor; +use crate::{ + arch::{ + instructions::LocalOpcode, + testing::{TestChipHarness, VmChipTestBuilder}, + ExecutionState, VmChipWrapper, + }, + system::phantom::{PhantomAir, PhantomFiller}, +}; type F = BabyBear; #[test] fn test_nops_and_terminate() { let mut tester = VmChipTestBuilder::default(); - let mut chip = PhantomChip::::new( - tester.execution_bus(), - tester.program_bus(), - SystemOpcode::CLASS_OFFSET, - ); - chip.set_streams(Arc::new(Mutex::new(Default::default()))); + let phantom_opcode = SystemOpcode::PHANTOM.global_opcode(); + let executor = PhantomExecutor::::new(Default::default(), phantom_opcode); + let air = PhantomAir { + execution_bridge: tester.execution_bridge(), + phantom_opcode, + }; + let chip = VmChipWrapper::new(PhantomFiller, tester.memory_helper()); + let num_nops = 5; + let mut harness = TestChipHarness::with_capacity(executor, air, chip, num_nops); - let nop = Instruction::from_isize(SystemOpcode::PHANTOM.global_opcode(), 0, 0, 0, 0, 0); + let nop = Instruction::from_isize(phantom_opcode, 0, 0, 0, 0, 0); let mut state: ExecutionState = ExecutionState::new(F::ZERO, F::ONE); - let num_nops = 5; for _ in 0..num_nops { - tester.execute_with_pc(&mut chip, &nop, state.pc.as_canonical_u32()); + tester.execute_with_pc(&mut harness, &nop, state.pc.as_canonical_u32()); let new_state = tester.execution.records.last().unwrap().final_state; assert_eq!(state.pc + F::from_canonical_usize(4), new_state.pc); assert_eq!(state.timestamp + F::ONE, new_state.timestamp); state = new_state; } - let tester = tester.build().load(chip).finalize(); + let tester = tester.build().load(harness).finalize(); tester.simple_test().expect("Verification failed"); } diff --git a/crates/vm/src/system/poseidon2/air.rs b/crates/vm/src/system/poseidon2/air.rs index 99769d253d..81b99148e5 100644 --- a/crates/vm/src/system/poseidon2/air.rs +++ b/crates/vm/src/system/poseidon2/air.rs @@ -22,7 +22,7 @@ use super::columns::Poseidon2PeripheryCols; #[derive(Clone, new, Debug)] pub struct Poseidon2PeripheryAir { pub(super) subair: Arc>, - pub(super) bus: LookupBus, + pub bus: LookupBus, } impl BaseAirWithPublicValues diff --git a/crates/vm/src/system/poseidon2/chip.rs b/crates/vm/src/system/poseidon2/chip.rs index e0059f1ce1..f7053edcb5 100644 --- a/crates/vm/src/system/poseidon2/chip.rs +++ b/crates/vm/src/system/poseidon2/chip.rs @@ -1,14 +1,18 @@ use std::{ array, - sync::{atomic::AtomicU32, Arc}, + sync::{ + atomic::{AtomicBool, AtomicU32}, + Arc, + }, }; +use dashmap::DashMap; use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubChip}; use openvm_stark_backend::{ interaction::{BusIndex, LookupBus}, - p3_field::PrimeField32, + p3_field::{Field, PrimeField32}, }; -use rustc_hash::FxHashMap; +use rustc_hash::FxBuildHasher; use super::{ air::Poseidon2PeripheryAir, PERIPHERY_POSEIDON2_CHUNK_SIZE, PERIPHERY_POSEIDON2_WIDTH, @@ -16,10 +20,11 @@ use super::{ use crate::arch::hasher::{Hasher, HasherChip}; #[derive(Debug)] -pub struct Poseidon2PeripheryBaseChip { +pub struct Poseidon2PeripheryBaseChip { pub air: Arc>, pub subchip: Poseidon2SubChip, - pub records: FxHashMap<[F; PERIPHERY_POSEIDON2_WIDTH], AtomicU32>, + pub records: DashMap<[F; PERIPHERY_POSEIDON2_WIDTH], AtomicU32, FxBuildHasher>, + pub nonempty: AtomicBool, } impl Poseidon2PeripheryBaseChip { @@ -31,7 +36,8 @@ impl Poseidon2PeripheryBaseChip HasherChip [F; PERIPHERY_POSEIDON2_CHUNK_SIZE] { @@ -73,6 +79,8 @@ impl HasherChip { +#[derive(Chip)] +#[chip(where = "F: Field")] +pub enum Poseidon2PeripheryChip { Register0(Poseidon2PeripheryBaseChip), Register1(Poseidon2PeripheryBaseChip), } @@ -49,22 +56,21 @@ impl Poseidon2PeripheryChip { } } -impl Chip for Poseidon2PeripheryChip> -where - Val: PrimeField32, -{ - fn air(&self) -> AirRef { - match self { - Poseidon2PeripheryChip::Register0(chip) => chip.air(), - Poseidon2PeripheryChip::Register1(chip) => chip.air(), - } - } - - fn generate_air_proof_input(self) -> AirProofInput { - match self { - Poseidon2PeripheryChip::Register0(chip) => chip.generate_air_proof_input(), - Poseidon2PeripheryChip::Register1(chip) => chip.generate_air_proof_input(), - } +pub fn new_poseidon2_periphery_air( + poseidon2_config: Poseidon2Config>, + direct_bus: LookupBus, + max_constraint_degree: usize, +) -> AirRef { + if max_constraint_degree >= 7 { + Arc::new(Poseidon2PeripheryAir::, 0>::new( + Arc::new(Poseidon2SubAir::new(poseidon2_config.constants.into())), + direct_bus, + )) + } else { + Arc::new(Poseidon2PeripheryAir::, 1>::new( + Arc::new(Poseidon2SubAir::new(poseidon2_config.constants.into())), + direct_bus, + )) } } @@ -106,7 +112,7 @@ impl Hasher for Poseidon2Per impl HasherChip for Poseidon2PeripheryChip { fn compress_and_record( - &mut self, + &self, lhs: &[F; PERIPHERY_POSEIDON2_CHUNK_SIZE], rhs: &[F; PERIPHERY_POSEIDON2_CHUNK_SIZE], ) -> [F; PERIPHERY_POSEIDON2_CHUNK_SIZE] { diff --git a/crates/vm/src/system/poseidon2/tests.rs b/crates/vm/src/system/poseidon2/tests.rs index 095c8acba4..31cbe9ff47 100644 --- a/crates/vm/src/system/poseidon2/tests.rs +++ b/crates/vm/src/system/poseidon2/tests.rs @@ -1,5 +1,9 @@ use openvm_poseidon2_air::Poseidon2Config; -use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32}; +use openvm_stark_backend::{ + interaction::LookupBus, + p3_field::{FieldAlgebra, PrimeField32}, + AirRef, +}; use openvm_stark_sdk::{ dummy_airs::interaction::dummy_interaction_air::{DummyInteractionChip, DummyInteractionData}, p3_baby_bear::BabyBear, @@ -10,13 +14,28 @@ use rand::RngCore; use crate::{ arch::{ hasher::{Hasher, HasherChip}, - testing::{VmChipTestBuilder, POSEIDON2_DIRECT_BUS}, + testing::{TestSC, VmChipTestBuilder, POSEIDON2_DIRECT_BUS}, }, system::poseidon2::{ - Poseidon2PeripheryChip, PERIPHERY_POSEIDON2_CHUNK_SIZE, PERIPHERY_POSEIDON2_WIDTH, + new_poseidon2_periphery_air, Poseidon2PeripheryChip, PERIPHERY_POSEIDON2_CHUNK_SIZE, + PERIPHERY_POSEIDON2_WIDTH, }, }; +fn create_test_chip() -> (AirRef, Poseidon2PeripheryChip) { + let chip = Poseidon2PeripheryChip::::new( + Poseidon2Config::default(), + POSEIDON2_DIRECT_BUS, + 3, + ); + let air = new_poseidon2_periphery_air::( + Poseidon2Config::default(), + LookupBus::new(POSEIDON2_DIRECT_BUS), + 3, + ); + (air, chip) +} + /// Test that the direct bus interactions work. #[test] fn poseidon2_periphery_direct_test() { @@ -31,12 +50,7 @@ fn poseidon2_periphery_direct_test() { std::array::from_fn(|_| BabyBear::from_canonical_u32(rng.next_u32() % (1 << 30))), ) }); - - let mut chip = Poseidon2PeripheryChip::::new( - Poseidon2Config::default(), - POSEIDON2_DIRECT_BUS, - 3, - ); + let (air, chip) = create_test_chip(); let outs: [[BabyBear; PERIPHERY_POSEIDON2_CHUNK_SIZE]; NUM_OPS] = std::array::from_fn(|i| chip.compress_and_record(&hashes[i].0, &hashes[i].1)); @@ -65,8 +79,8 @@ fn poseidon2_periphery_direct_test() { let tester = VmChipTestBuilder::default(); let tester = tester .build() - .load(dummy_interaction_chip) - .load(chip) + .load_periphery((dummy_interaction_chip.air, dummy_interaction_chip)) + .load_periphery_ref((air, chip)) .finalize(); tester.simple_test().expect("Verification failed"); } @@ -86,11 +100,7 @@ fn poseidon2_periphery_duplicate_hashes_test() { }); let counts: [u32; NUM_OPS] = std::array::from_fn(|_| rng.next_u32() % 20); - let mut chip = Poseidon2PeripheryChip::::new( - Poseidon2Config::default(), - POSEIDON2_DIRECT_BUS, - 3, - ); + let (air, chip) = create_test_chip(); let outs: [[BabyBear; PERIPHERY_POSEIDON2_CHUNK_SIZE]; NUM_OPS] = std::array::from_fn(|i| { for _ in 0..counts[i] { @@ -123,7 +133,7 @@ fn poseidon2_periphery_duplicate_hashes_test() { let tester = VmChipTestBuilder::default(); tester .build() - .load(chip) - .load(dummy_interaction_chip) + .load_periphery_ref((air, chip)) + .load_periphery((dummy_interaction_chip.air, dummy_interaction_chip)) .finalize(); } diff --git a/crates/vm/src/system/poseidon2/trace.rs b/crates/vm/src/system/poseidon2/trace.rs index 2b6f3e6b0b..26eb198ea2 100644 --- a/crates/vm/src/system/poseidon2/trace.rs +++ b/crates/vm/src/system/poseidon2/trace.rs @@ -1,4 +1,4 @@ -use std::borrow::BorrowMut; +use std::{borrow::BorrowMut, sync::Arc}; use openvm_circuit_primitives::utils::next_power_of_two_or_zero; use openvm_stark_backend::{ @@ -7,32 +7,33 @@ use openvm_stark_backend::{ p3_field::{FieldAlgebra, PrimeField32}, p3_matrix::dense::RowMajorMatrix, p3_maybe_rayon::prelude::*, - prover::types::AirProofInput, - rap::get_air_name, - AirRef, Chip, ChipUsageGetter, + prover::{cpu::CpuBackend, types::AirProvingContext}, + Chip, ChipUsageGetter, }; use super::{columns::*, Poseidon2PeripheryBaseChip, PERIPHERY_POSEIDON2_WIDTH}; -impl Chip +impl Chip> for Poseidon2PeripheryBaseChip, SBOX_REGISTERS> where Val: PrimeField32, { - fn air(&self) -> AirRef { - self.air.clone() - } - - fn generate_air_proof_input(self) -> AirProofInput { + /// Generates trace and clears internal records state. + fn generate_proving_ctx(&self, _: RA) -> AirProvingContext> { let height = next_power_of_two_or_zero(self.current_trace_height()); let width = self.trace_width(); let mut inputs = Vec::with_capacity(height); let mut multiplicities = Vec::with_capacity(height); - let (actual_inputs, actual_multiplicities): (Vec<_>, Vec<_>) = self - .records - .into_par_iter() - .map(|(input, mult)| (input, mult.load(std::sync::atomic::Ordering::Relaxed))) + #[cfg(feature = "parallel")] + let records_iter = self.records.par_iter(); + #[cfg(not(feature = "parallel"))] + let records_iter = self.records.iter(); + let (actual_inputs, actual_multiplicities): (Vec<_>, Vec<_>) = records_iter + .map(|r| { + let (input, mult) = r.pair(); + (*input, mult.load(std::sync::atomic::Ordering::Relaxed)) + }) .unzip(); inputs.extend(actual_inputs); multiplicities.extend(actual_multiplicities); @@ -54,8 +55,9 @@ where let cols: &mut Poseidon2PeripheryCols, SBOX_REGISTERS> = row.borrow_mut(); cols.mult = Val::::from_canonical_u32(mult); }); + self.records.clear(); - AirProofInput::simple_no_pis(RowMajorMatrix::new(values, width)) + AirProvingContext::simple_no_pis(Arc::new(RowMajorMatrix::new(values, width))) } } @@ -63,11 +65,16 @@ impl ChipUsageGetter for Poseidon2PeripheryBaseChip { fn air_name(&self) -> String { - get_air_name(&self.air) + format!("Poseidon2PeripheryAir", SBOX_REGISTERS) } fn current_trace_height(&self) -> usize { - self.records.len() + if self.nonempty.load(std::sync::atomic::Ordering::Relaxed) { + // Not to call `DashMap::len` too often + self.records.len() + } else { + 0 + } } fn trace_width(&self) -> usize { diff --git a/crates/vm/src/system/program/air.rs b/crates/vm/src/system/program/air.rs index 6336d7ad59..7d085877f8 100644 --- a/crates/vm/src/system/program/air.rs +++ b/crates/vm/src/system/program/air.rs @@ -32,7 +32,7 @@ pub struct ProgramExecutionCols { pub g: T, } -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, derive_new::new)] pub struct ProgramAir { pub bus: ProgramBus, } diff --git a/crates/vm/src/system/program/mod.rs b/crates/vm/src/system/program/mod.rs index 3d68632d0f..63abd00d84 100644 --- a/crates/vm/src/system/program/mod.rs +++ b/crates/vm/src/system/program/mod.rs @@ -1,10 +1,16 @@ use openvm_instructions::{ - instruction::{DebugInfo, Instruction}, - program::Program, + instruction::Instruction, + program::{Program, DEFAULT_PC_STEP}, + LocalOpcode, SystemOpcode, +}; +use openvm_stark_backend::{ + config::StarkGenericConfig, + p3_field::Field, + p3_maybe_rayon::prelude::*, + prover::{cpu::CpuBackend, types::CommittedTraceData}, }; -use openvm_stark_backend::{p3_field::PrimeField64, ChipUsageGetter}; -use crate::{arch::ExecutionError, system::program::trace::padding_instruction}; +use crate::arch::{ExecutionError, ExecutorId, ExecutorInventory, StaticProgramError}; #[cfg(test)] pub mod tests; @@ -18,88 +24,153 @@ pub use bus::*; const EXIT_CODE_FAIL: usize = 1; -#[derive(Debug)] -pub struct ProgramChip { - pub air: ProgramAir, - pub program: Program, - pub true_program_length: usize, - pub execution_frequencies: Vec, +#[repr(C)] +pub struct PcEntry { + // TODO[jpw]: revisit storing only smaller `precompute` for better cache locality. Currently + // VmOpcode is usize so align=8 and there are 7 u32 operands so we store ExecutorId(u32) after + // to avoid padding. This means PcEntry has align=8 and size=40 bytes, which is too big + pub insn: Instruction, + pub executor_idx: ExecutorId, } -impl ProgramChip { - pub fn new(bus: ProgramBus) -> Self { +impl PcEntry { + pub fn is_some(&self) -> bool { + self.executor_idx != u32::MAX + } +} + +impl PcEntry { + fn undefined() -> Self { Self { - execution_frequencies: vec![], - program: Program::default(), - true_program_length: 0, - air: ProgramAir { bus }, + insn: Instruction::default(), + executor_idx: u32::MAX, } } +} - pub fn new_with_program(program: Program, bus: ProgramBus) -> Self { - let mut ret = Self::new(bus); - ret.set_program(program); - ret - } +// pc_handler, execution_frequencies, debug_infos will all have the same length, which equals +// `Program::len()` +pub struct ProgramHandler { + pub(crate) executors: Vec, + /// This is a map from (pc - pc_base) / pc_step -> [PcEntry]. + /// We will map to `u32::MAX` if the program has no instruction at that pc. + // Perf[jpw/ayush]: We could map directly to the raw pointer(u64) for executor, but storing the + // u32 may be better for cache efficiency. + pc_handler: Vec>, + execution_frequencies: Vec, + pc_base: u32, +} - pub fn set_program(&mut self, mut program: Program) { - let true_program_length = program.len(); - let mut number_actual_instructions = program.num_defined_instructions(); - while !number_actual_instructions.is_power_of_two() { - program.push_instruction(padding_instruction()); - number_actual_instructions += 1; +impl ProgramHandler { + /// Rewrite the program into compiled handlers. + /// + /// ## Assumption + /// There are less than `u32::MAX` total AIRs. + // @dev: We need to clone the executors because they are not completely stateless + pub fn new( + program: &Program, + inventory: &ExecutorInventory, + ) -> Result + where + E: Clone, + { + if inventory.executors().len() > u32::MAX as usize { + // This would mean we cannot use u32::MAX as an "undefined" executor index + return Err(StaticProgramError::TooManyExecutors); } - self.true_program_length = true_program_length; - self.execution_frequencies = vec![0; program.len()]; - self.program = program; + let len = program.instructions_and_debug_infos.len(); + let mut pc_handler = Vec::with_capacity(len); + for insn_and_debug_info in &program.instructions_and_debug_infos { + if let Some((insn, _)) = insn_and_debug_info { + let insn = insn.clone(); + let executor_idx = if insn.opcode == SystemOpcode::TERMINATE.global_opcode() { + // The execution loop will always branch to terminate before using this executor + 0 + } else { + *inventory.instruction_lookup.get(&insn.opcode).ok_or( + StaticProgramError::ExecutorNotFound { + opcode: insn.opcode, + }, + )? + }; + assert!( + (executor_idx as usize) < inventory.executors.len(), + "ExecutorInventory ensures executor_idx is in bounds" + ); + let pc_entry = PcEntry { insn, executor_idx }; + pc_handler.push(pc_entry); + } else { + pc_handler.push(PcEntry::undefined()); + } + } + let executors = inventory.executors.clone(); + + Ok(Self { + execution_frequencies: vec![0u32; len], + executors, + pc_handler, + pc_base: program.pc_base, + }) } - fn get_pc_index(&self, pc: u32) -> Result { - let step = self.program.step; - let pc_base = self.program.pc_base; - let pc_index = ((pc - pc_base) / step) as usize; - if !(0..self.true_program_length).contains(&pc_index) { - return Err(ExecutionError::PcOutOfBounds { - pc, - step, - pc_base, - program_len: self.true_program_length, - }); - } - Ok(pc_index) + #[inline(always)] + fn get_pc_index(&self, pc: u32) -> usize { + let pc_base = self.pc_base; + ((pc - pc_base) / DEFAULT_PC_STEP) as usize } - pub fn get_instruction( - &mut self, - pc: u32, - ) -> Result<&(Instruction, Option), ExecutionError> { - let pc_index = self.get_pc_index(pc)?; - self.execution_frequencies[pc_index] += 1; - self.program - .get_instruction_and_debug_info(pc_index) - .ok_or(ExecutionError::PcNotFound { + /// Returns `(executor, pc_entry, pc_idx)`. + #[inline(always)] + pub fn get_executor(&mut self, pc: u32) -> Result<(&mut E, &PcEntry), ExecutionError> { + let pc_idx = self.get_pc_index(pc); + let entry = self + .pc_handler + .get(pc_idx) + .ok_or_else(|| ExecutionError::PcOutOfBounds { pc, - step: self.program.step, - pc_base: self.program.pc_base, - program_len: self.program.len(), - }) - } -} + pc_base: self.pc_base, + program_len: self.pc_handler.len(), + })?; + // SAFETY: `execution_frequencies` has the same length as `pc_handler` so `get_pc_entry` + // already does the bounds check + unsafe { + *self.execution_frequencies.get_unchecked_mut(pc_idx) += 1; + }; + // SAFETY: the `executor_idx` comes from ExecutorInventory, which ensures that + // `executor_idx` is within bounds + let executor = unsafe { + self.executors + .get_unchecked_mut(entry.executor_idx as usize) + }; -impl ChipUsageGetter for ProgramChip { - fn air_name(&self) -> String { - "ProgramChip".to_string() + Ok((executor, entry)) } - fn constant_trace_height(&self) -> Option { - Some(self.true_program_length.next_power_of_two()) + pub fn filtered_execution_frequencies(&self) -> Vec + where + E: Sync, + { + self.pc_handler + .par_iter() + .enumerate() + .filter_map(|(i, entry)| entry.is_some().then(|| self.execution_frequencies[i])) + .collect() } +} - fn current_trace_height(&self) -> usize { - self.true_program_length - } +// For CPU backend only +pub struct ProgramChip { + /// `i` -> frequency of instruction in `i`th row of trace matrix. This requires filtering + /// `program.instructions_and_debug_infos` to remove gaps. + pub(super) filtered_exec_frequencies: Vec, + pub(super) cached: Option>>, +} - fn trace_width(&self) -> usize { - 1 +impl ProgramChip { + pub(super) fn unloaded() -> Self { + Self { + filtered_exec_frequencies: Vec::new(), + cached: None, + } } } diff --git a/crates/vm/src/system/program/tests/mod.rs b/crates/vm/src/system/program/tests/mod.rs index 4a0293b348..bbc60187e8 100644 --- a/crates/vm/src/system/program/tests/mod.rs +++ b/crates/vm/src/system/program/tests/mod.rs @@ -1,6 +1,7 @@ -use std::iter; +use std::{iter, sync::Arc}; use openvm_instructions::{ + exe::VmExe, instruction::Instruction, program::{Program, DEFAULT_PC_STEP}, LocalOpcode, @@ -10,15 +11,19 @@ use openvm_native_compiler::{ }; use openvm_rv32im_transpiler::BranchEqualOpcode::*; use openvm_stark_backend::{ + config::StarkGenericConfig, + engine::StarkEngine, p3_field::FieldAlgebra, p3_matrix::{dense::RowMajorMatrix, Matrix}, - prover::types::AirProofInput, + prover::types::AirProvingContext, + Chip, }; use openvm_stark_sdk::{ any_rap_arc_vec, config::{ baby_bear_poseidon2::{BabyBearPoseidon2Config, BabyBearPoseidon2Engine}, baby_bear_poseidon2_root::BabyBearPoseidon2RootConfig, + FriParameters, }, dummy_airs::interaction::dummy_interaction_air::DummyInteractionAir, engine::StarkFriEngine, @@ -29,30 +34,48 @@ use static_assertions::assert_impl_all; use crate::{ arch::{instructions::SystemOpcode::*, testing::READ_INSTRUCTION_BUS}, - system::program::{trace::VmCommittedExe, ProgramBus, ProgramChip}, + system::program::{trace::VmCommittedExe, ProgramAir, ProgramBus, ProgramChip}, }; assert_impl_all!(VmCommittedExe: Serialize, DeserializeOwned); assert_impl_all!(VmCommittedExe: Serialize, DeserializeOwned); fn interaction_test(program: Program, execution: Vec) { - let bus = ProgramBus::new(READ_INSTRUCTION_BUS); - let mut chip = ProgramChip::new_with_program(program.clone(), bus); let mut execution_frequencies = vec![0; program.len()]; for pc_idx in execution { execution_frequencies[pc_idx as usize] += 1; - chip.get_instruction(pc_idx * DEFAULT_PC_STEP).unwrap(); } - let program_air = chip.air; - let program_proof_input = chip.generate_air_proof_input(None); + let filtered_exec_frequencies: Vec<_> = program + .instructions_and_debug_infos + .iter() + .enumerate() + .filter(|(_, entry)| entry.is_some()) + .map(|(i, _)| execution_frequencies[i]) + .collect(); + let original_height = filtered_exec_frequencies.len(); + + let bus = ProgramBus::new(READ_INSTRUCTION_BUS); + let program_air = ProgramAir::new(bus); + + let engine = BabyBearPoseidon2Engine::new(FriParameters::new_for_testing(1)); + let exe = VmExe::new(program); + let committed_exe = + VmCommittedExe::::commit(exe, engine.config().pcs()); + let cached = committed_exe.get_committed_trace(); + let chip = ProgramChip { + filtered_exec_frequencies, + cached: Some(cached), + }; + let ctx = chip.generate_proving_ctx(()); let counter_air = DummyInteractionAir::new(9, true, bus.inner.index); let mut program_cells = vec![]; + let program = committed_exe.exe.program; for (index, frequency) in execution_frequencies.into_iter().enumerate() { let option = program.get_instruction_and_debug_info(index); if let Some((instruction, _)) = option { program_cells.extend([ - BabyBear::from_canonical_usize(frequency), // hacky: we should switch execution_frequencies into hashmap + BabyBear::from_canonical_u32(frequency), BabyBear::from_canonical_usize(index * (DEFAULT_PC_STEP as usize)), instruction.opcode.to_field(), instruction.a, @@ -68,23 +91,20 @@ fn interaction_test(program: Program, execution: Vec) { // Pad program cells with zeroes to make height a power of two. let width = 10; - let original_height = program.num_defined_instructions(); let desired_height = original_height.next_power_of_two(); let cells_to_add = (desired_height - original_height) * width; program_cells.extend(iter::repeat_n(BabyBear::ZERO, cells_to_add)); - let counter_trace = RowMajorMatrix::new(program_cells, 10); + let counter_trace = Arc::new(RowMajorMatrix::new(program_cells, 10)); println!("trace height = {}", original_height); println!("counter trace height = {}", counter_trace.height()); - BabyBearPoseidon2Engine::run_test_fast( - any_rap_arc_vec!(program_air, counter_air), - vec![ - program_proof_input, - AirProofInput::simple_no_pis(counter_trace), - ], - ) - .expect("Verification failed"); + engine + .run_test( + any_rap_arc_vec!(program_air, counter_air), + vec![ctx, AirProvingContext::simple_no_pis(counter_trace)], + ) + .expect("Verification failed"); } #[test] @@ -178,21 +198,25 @@ fn test_program_negative() { ]; let bus = ProgramBus::new(READ_INSTRUCTION_BUS); let program = Program::from_instructions(&instructions); + let program_air = ProgramAir::new(bus); - let mut chip = ProgramChip::new_with_program(program, bus); let execution_frequencies = vec![1; instructions.len()]; - for pc_idx in 0..instructions.len() { - chip.get_instruction(pc_idx as u32 * DEFAULT_PC_STEP) - .unwrap(); - } - let program_air = chip.air; - let program_proof_input = chip.generate_air_proof_input(None); + let engine = BabyBearPoseidon2Engine::new(FriParameters::new_for_testing(1)); + let exe = VmExe::new(program); + let committed_exe = + VmCommittedExe::::commit(exe, engine.config().pcs()); + let cached = committed_exe.get_committed_trace(); + let chip = ProgramChip { + filtered_exec_frequencies: execution_frequencies.clone(), + cached: Some(cached), + }; + let ctx = chip.generate_proving_ctx(()); let counter_air = DummyInteractionAir::new(7, true, bus.inner.index); let mut program_rows = vec![]; for (pc_idx, instruction) in instructions.iter().enumerate() { program_rows.extend(vec![ - BabyBear::from_canonical_usize(execution_frequencies[pc_idx]), + BabyBear::from_canonical_u32(execution_frequencies[pc_idx]), BabyBear::from_canonical_usize(pc_idx * DEFAULT_PC_STEP as usize), instruction.opcode.to_field(), instruction.a, @@ -204,15 +228,14 @@ fn test_program_negative() { } let mut counter_trace = RowMajorMatrix::new(program_rows, 8); counter_trace.row_mut(1)[1] = BabyBear::ZERO; + let counter_trace = Arc::new(counter_trace); - BabyBearPoseidon2Engine::run_test_fast( - any_rap_arc_vec!(program_air, counter_air), - vec![ - program_proof_input, - AirProofInput::simple_no_pis(counter_trace), - ], - ) - .expect("Verification failed"); + engine + .run_test( + any_rap_arc_vec!(program_air, counter_air), + vec![ctx, AirProvingContext::simple_no_pis(counter_trace)], + ) + .expect("Verification failed"); } #[test] @@ -265,7 +288,7 @@ fn test_program_with_undefined_instructions() { )), ]; - let program = Program::new_without_debug_infos_with_option(&instructions, DEFAULT_PC_STEP, 0); + let program = Program::new_without_debug_infos_with_option(&instructions, 0); interaction_test(program, vec![0, 2, 5]); } diff --git a/crates/vm/src/system/program/trace.rs b/crates/vm/src/system/program/trace.rs index d9e2abd956..bf8a9da868 100644 --- a/crates/vm/src/system/program/trace.rs +++ b/crates/vm/src/system/program/trace.rs @@ -3,63 +3,83 @@ use std::{borrow::BorrowMut, sync::Arc}; use derivative::Derivative; use itertools::Itertools; use openvm_circuit::arch::hasher::poseidon2::Poseidon2Hasher; -use openvm_instructions::{exe::VmExe, program::Program, LocalOpcode, SystemOpcode}; +use openvm_instructions::{ + exe::VmExe, + program::{Program, DEFAULT_PC_STEP}, + LocalOpcode, SystemOpcode, +}; use openvm_stark_backend::{ - config::{Com, Domain, StarkGenericConfig, Val}, - p3_commit::{Pcs, PolynomialSpace}, - p3_field::{Field, FieldAlgebra, PrimeField32, PrimeField64}, + config::{Com, PcsProverData, StarkGenericConfig, Val}, + p3_commit::Pcs, + p3_field::{Field, FieldAlgebra, PrimeField32}, p3_matrix::{dense::RowMajorMatrix, Matrix}, p3_maybe_rayon::prelude::*, + p3_util::log2_strict_usize, prover::{ - helper::AirProofInputTestHelper, - types::{AirProofInput, AirProofRawInput, CommittedTraceData}, + cpu::{self, CpuBackend}, + types::{AirProvingContext, CommittedTraceData}, }, + Chip, }; use serde::{Deserialize, Serialize}; -use super::{Instruction, ProgramChip, ProgramExecutionCols, EXIT_CODE_FAIL}; +use super::{Instruction, ProgramExecutionCols, EXIT_CODE_FAIL}; use crate::{ arch::{ hasher::{poseidon2::vm_poseidon2_hasher, Hasher}, MemoryConfig, }, - system::memory::{tree::MemoryNode, AddressMap, CHUNK}, + system::{ + memory::{merkle::MerkleTree, AddressMap, CHUNK}, + program::ProgramChip, + }, }; +// TODO[jpw]: separate so we can have Arc separate from CommittedTraceData +/// **Note**: this struct stores the program ROM twice: once in [VmExe] and once as a cached trace +/// matrix `trace`. #[derive(Serialize, Deserialize, Derivative)] #[serde(bound( - serialize = "VmExe>: Serialize, CommittedTraceData: Serialize", - deserialize = "VmExe>: Deserialize<'de>, CommittedTraceData: Deserialize<'de>" + serialize = "VmExe>: Serialize, Com: Serialize, PcsProverData: Serialize", + deserialize = "VmExe>: Deserialize<'de>, Com: Deserialize<'de>, PcsProverData: Deserialize<'de>" ))] #[derivative(Clone(bound = "Com: Clone"))] pub struct VmCommittedExe { /// Raw executable. pub exe: VmExe>, - /// Committed program trace. - pub committed_program: CommittedTraceData, + pub commitment: Com, + /// Program ROM as cached trace matrix. + pub trace: Arc>>, + pub prover_data: Arc>, } -impl VmCommittedExe -where - Val: PrimeField32, -{ +impl VmCommittedExe { /// Creates [VmCommittedExe] from [VmExe] by using `pcs` to commit to the /// program code as a _cached trace_ matrix. pub fn commit(exe: VmExe>, pcs: &SC::Pcs) -> Self { - let cached_trace = generate_cached_trace(&exe.program); - let domain = pcs.natural_domain_for_degree(cached_trace.height()); - let (commitment, pcs_data) = pcs.commit(vec![(domain, cached_trace.clone())]); + let trace = generate_cached_trace(&exe.program); + let domain = pcs.natural_domain_for_degree(trace.height()); + + let (commitment, data) = pcs.commit(vec![(domain, trace.clone())]); Self { - committed_program: CommittedTraceData { - trace: Arc::new(cached_trace), - commitment, - pcs_data: Arc::new(pcs_data), - }, exe, + commitment, + trace: Arc::new(trace), + prover_data: Arc::new(data), } } pub fn get_program_commit(&self) -> Com { - self.committed_program.commitment.clone() + self.commitment.clone() + } + + pub fn get_committed_trace(&self) -> CommittedTraceData> { + let log_trace_height: u8 = log2_strict_usize(self.trace.height()).try_into().unwrap(); + let data = cpu::PcsData::new(self.prover_data.clone(), vec![log_trace_height]); + CommittedTraceData { + commitment: self.commitment.clone(), + trace: self.trace.clone(), + data, + } } /// Computes a commitment to [VmCommittedExe]. This is a Merklelized hash of: @@ -77,22 +97,16 @@ where pub fn compute_exe_commit(&self, memory_config: &MemoryConfig) -> Com where Com: AsRef<[Val; CHUNK]> + From<[Val; CHUNK]>, + Val: PrimeField32, { let hasher = vm_poseidon2_hasher(); let memory_dimensions = memory_config.memory_dimensions(); - let app_program_commit: &[Val; CHUNK] = self.committed_program.commitment.as_ref(); + let app_program_commit: &[Val; CHUNK] = self.commitment.as_ref(); let mem_config = memory_config; - let init_memory_commit = MemoryNode::tree_from_memory( - memory_dimensions, - &AddressMap::from_iter( - mem_config.as_offset, - 1 << mem_config.as_height, - 1 << mem_config.pointer_max_bits, - self.exe.init_memory.clone(), - ), - &hasher, - ) - .hash(); + let memory_image = + AddressMap::from_sparse(mem_config.addr_spaces.clone(), self.exe.init_memory.clone()); + let init_memory_commit = + MerkleTree::from_memory(&memory_image, &memory_dimensions, &hasher).root(); Com::::from(compute_exe_commit( &hasher, app_program_commit, @@ -102,37 +116,25 @@ where } } -impl ProgramChip { - pub fn generate_air_proof_input( - self, - cached: Option>, - ) -> AirProofInput - where - Domain: PolynomialSpace, - { - let common_trace = RowMajorMatrix::new_col( - self.execution_frequencies - .into_iter() - .zip_eq(self.program.instructions_and_debug_infos.iter()) - .filter_map(|(frequency, option)| { - option.as_ref().map(|_| F::from_canonical_usize(frequency)) - }) - .collect::>(), - ); - if let Some(cached) = cached { - AirProofInput { - cached_mains_pdata: vec![(cached.commitment, cached.pcs_data)], - raw: AirProofRawInput { - cached_mains: vec![cached.trace], - common_main: Some(common_trace), - public_values: vec![], - }, - } - } else { - AirProofInput::cached_traces_no_pis( - vec![generate_cached_trace(&self.program)], - common_trace, - ) +impl Chip> for ProgramChip { + /// The cached program trace is cloned and left for future use. The clone is cheap because the + /// cached trace is behind smart pointers. The execution frequencies are left unchanged. + fn generate_proving_ctx(&self, _: RA) -> AirProvingContext> { + let cached = self + .cached + .clone() + .expect("cached program trace must be loaded"); + assert!(self.filtered_exec_frequencies.len() <= cached.trace.height()); + let mut freqs = Val::::zero_vec(cached.trace.height()); + freqs + .par_iter_mut() + .zip(self.filtered_exec_frequencies.par_iter()) + .for_each(|(f, x)| *f = Val::::from_canonical_u32(*x)); + let common_trace = RowMajorMatrix::new_col(freqs); + AirProvingContext { + cached_mains: vec![cached], + common_main: Some(Arc::new(common_trace)), + public_values: vec![], } } } @@ -158,7 +160,7 @@ pub fn compute_exe_commit( hasher.compress(&hasher.compress(&program_hash, &memory_hash), &pc_hash) } -pub(crate) fn generate_cached_trace(program: &Program) -> RowMajorMatrix { +pub(crate) fn generate_cached_trace(program: &Program) -> RowMajorMatrix { let width = ProgramExecutionCols::::width(); let mut instructions = program .enumerate_by_pc() @@ -169,7 +171,7 @@ pub(crate) fn generate_cached_trace(program: &Program) -> Ro let padding = padding_instruction(); while !instructions.len().is_power_of_two() { instructions.push(( - program.pc_base + instructions.len() as u32 * program.step, + program.pc_base + instructions.len() as u32 * DEFAULT_PC_STEP, padding.clone(), )); } diff --git a/crates/vm/src/system/public_values/core.rs b/crates/vm/src/system/public_values/core.rs index de189f101b..b0970e1cad 100644 --- a/crates/vm/src/system/public_values/core.rs +++ b/crates/vm/src/system/public_values/core.rs @@ -1,8 +1,16 @@ -use std::sync::Mutex; +use std::{ + borrow::{Borrow, BorrowMut}, + sync::Mutex, +}; -use openvm_circuit_primitives::{encoder::Encoder, SubAir}; +use openvm_circuit_primitives::{encoder::Encoder, AlignedBytesBorrow, SubAir}; use openvm_instructions::{ - instruction::Instruction, LocalOpcode, PublishOpcode, PublishOpcode::PUBLISH, + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::RV32_IMM_AS, + LocalOpcode, + PublishOpcode::{self, PUBLISH}, + NATIVE_AS, }; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -10,17 +18,27 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; use crate::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, MinimalInstruction, - Result, VmAdapterInterface, VmCoreAir, VmCoreChip, + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller, + BasicAdapterInterface, E2PreCompute, EmptyAdapterCoreLayout, ExecuteFunc, ExecutionError, + Executor, MeteredExecutor, MinimalInstruction, PreflightExecutor, RecordArena, + StaticProgramError, TraceFiller, VmCoreAir, VmExecState, VmStateMut, + }, + system::{ + memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, + native_adapter::NativeAdapterExecutor, + public_values::columns::PublicValuesCoreColsView, }, - system::public_values::columns::PublicValuesCoreColsView, + utils::{transmute_field_to_u32, transmute_u32_to_field}, }; + pub(crate) type AdapterInterface = BasicAdapterInterface, 2, 0, 1, 1>; -pub(crate) type AdapterInterfaceReads = as VmAdapterInterface>::Reads; #[derive(Clone, Debug)] pub struct PublicValuesCoreAir { @@ -99,95 +117,297 @@ impl VmCoreAir { - value: F, - index: F, + pub value: F, + pub index: F, } /// ATTENTION: If a specific public value is not provided, a default 0 will be used when generating /// the proof but in the perspective of constraints, it could be any value. -pub struct PublicValuesCoreChip { - air: PublicValuesCoreAir, +pub struct PublicValuesExecutor> { + adapter: A, + encoder: Encoder, // Mutex is to make the struct Sync. But it actually won't be accessed by multiple threads. - custom_pvs: Mutex>>, + pub(crate) custom_pvs: Mutex>>, } -impl PublicValuesCoreChip { +impl PublicValuesExecutor { /// **Note:** `max_degree` is the maximum degree of the constraint polynomials to represent the /// flags. If you want the overall AIR's constraint degree to be `<= max_constraint_degree`, /// then typically you should set `max_degree` to `max_constraint_degree - 1`. - pub fn new(num_custom_pvs: usize, max_degree: u32) -> Self { + pub fn new(adapter: A, num_custom_pvs: usize, max_degree: u32) -> Self { Self { - air: PublicValuesCoreAir::new(num_custom_pvs, max_degree), + adapter, + encoder: Encoder::new(num_custom_pvs, max_degree, true), custom_pvs: Mutex::new(vec![None; num_custom_pvs]), } } - pub fn get_custom_public_values(&self) -> Vec> { - self.custom_pvs.lock().unwrap().clone() + + pub(crate) fn set_public_values(&mut self, public_values: &[F]) { + let mut custom_pvs = self.custom_pvs.lock().unwrap(); + assert_eq!(public_values.len(), custom_pvs.len()); + for (pv_mut, value) in custom_pvs.iter_mut().zip(public_values) { + *pv_mut = Some(value.clone()); + } } } -impl VmCoreChip> for PublicValuesCoreChip { - type Record = PublicValuesRecord; - type Air = PublicValuesCoreAir; +// We clone when we want to run a new instance of the program, so we reset the custom public values. +impl Clone for PublicValuesExecutor { + fn clone(&self) -> Self { + Self { + adapter: self.adapter.clone(), + encoder: self.encoder.clone(), + custom_pvs: Mutex::new(vec![None; self.custom_pvs.lock().unwrap().len()]), + } + } +} - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, - _instruction: &Instruction, - _from_pc: u32, - reads: AdapterInterfaceReads, - ) -> Result<(AdapterRuntimeContext>, Self::Record)> { - let [[value], [index]] = reads; +impl PreflightExecutor for PublicValuesExecutor +where + F: PrimeField32, + A: 'static + Clone + AdapterTraceExecutor, + for<'buf> RA: RecordArena< + 'buf, + EmptyAdapterCoreLayout, + (A::RecordMut<'buf>, &'buf mut PublicValuesRecord), + >, +{ + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + PublishOpcode::from_usize(opcode - PublishOpcode::CLASS_OFFSET) + ) + } + + fn execute( + &mut self, + state: VmStateMut, + instruction: &Instruction, + ) -> Result<(), ExecutionError> { + let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + [[core_record.value], [core_record.index]] = + self.adapter + .read(state.memory, instruction, &mut adapter_record); { - let idx: usize = index.as_canonical_u32() as usize; + let idx: usize = core_record.index.as_canonical_u32() as usize; let mut custom_pvs = self.custom_pvs.lock().unwrap(); if custom_pvs[idx].is_none() { - custom_pvs[idx] = Some(value); + custom_pvs[idx] = Some(core_record.value); } else { // Not a hard constraint violation when publishing the same value twice but the // program should avoid that. panic!("Custom public value {} already set", idx); } } - let output = AdapterRuntimeContext { - to_pc: None, - writes: [], - }; - let record = Self::Record { value, index }; - Ok((output, record)) - } - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - PublishOpcode::from_usize(opcode - PublishOpcode::CLASS_OFFSET) - ) + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) } +} + +impl TraceFiller for PublicValuesExecutor +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + let record: &PublicValuesRecord = unsafe { get_record_from_slice(&mut core_row, ()) }; + let cols = PublicValuesCoreColsView::<_, &mut F>::borrow_mut(core_row); - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let mut cols = PublicValuesCoreColsView::<_, &mut F>::borrow_mut(row_slice); - debug_assert_eq!(cols.width(), BaseAir::::width(&self.air)); - *cols.is_valid = F::ONE; - *cols.value = record.value; - *cols.index = record.index; let idx: usize = record.index.as_canonical_u32() as usize; - let pt = self.air.encoder.get_flag_pt(idx); - for (i, var) in cols.custom_pv_vars.iter_mut().enumerate() { - **var = F::from_canonical_u32(pt[i]); - } + let pt = self.encoder.get_flag_pt(idx); + + cols.custom_pv_vars + .into_iter() + .zip(pt.iter()) + .for_each(|(var, &val)| { + *var = F::from_canonical_u32(val); + }); + + *cols.index = record.index; + *cols.value = record.value; + *cols.is_valid = F::ONE; } fn generate_public_values(&self) -> Vec { - self.get_custom_public_values() - .into_iter() - .map(|x| x.unwrap_or(F::ZERO)) - .collect() + let custom_pvs = self.custom_pvs.lock().unwrap(); + custom_pvs.iter().map(|&x| x.unwrap_or(F::ZERO)).collect() + } +} + +#[derive(AlignedBytesBorrow)] +#[repr(C)] +struct PublicValuesPreCompute { + b_or_imm: u32, + c_or_imm: u32, + pvs: *const Mutex>>, +} + +impl Executor for PublicValuesExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::>() + } + + #[inline(always)] + fn pre_compute( + &self, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E1ExecutionCtx, + { + let data: &mut PublicValuesPreCompute = data.borrow_mut(); + let (b_is_imm, c_is_imm) = self.pre_compute_impl(inst, data); + + let fn_ptr = match (b_is_imm, c_is_imm) { + (true, true) => execute_e1_impl::<_, _, true, true>, + (true, false) => execute_e1_impl::<_, _, true, false>, + (false, true) => execute_e1_impl::<_, _, false, true>, + (false, false) => execute_e1_impl::<_, _, false, false>, + }; + Ok(fn_ptr) + } +} + +impl MeteredExecutor for PublicValuesExecutor +where + F: PrimeField32, +{ + fn metered_pre_compute_size(&self) -> usize { + size_of::>>() + } + + fn metered_pre_compute( + &self, + chip_idx: usize, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute> = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let (b_is_imm, c_is_imm) = self.pre_compute_impl(inst, &mut data.data); + + let fn_ptr = match (b_is_imm, c_is_imm) { + (true, true) => execute_e2_impl::<_, _, true, true>, + (true, false) => execute_e2_impl::<_, _, true, false>, + (false, true) => execute_e2_impl::<_, _, false, true>, + (false, false) => execute_e2_impl::<_, _, false, false>, + }; + Ok(fn_ptr) + } +} + +#[inline(always)] +unsafe fn execute_e1_impl( + pre_compute: &[u8], + state: &mut VmExecState, +) where + CTX: E1ExecutionCtx, +{ + let pre_compute: &PublicValuesPreCompute = pre_compute.borrow(); + execute_e12_impl::<_, _, B_IS_IMM, C_IS_IMM>(pre_compute, state); +} + +#[inline(always)] +unsafe fn execute_e2_impl( + pre_compute: &[u8], + state: &mut VmExecState, +) where + CTX: E2ExecutionCtx, +{ + let pre_compute: &E2PreCompute> = pre_compute.borrow(); + state.ctx.on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::<_, _, B_IS_IMM, C_IS_IMM>(&pre_compute.data, state); +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &PublicValuesPreCompute, + state: &mut VmExecState, +) where + CTX: E1ExecutionCtx, +{ + let value = if B_IS_IMM { + transmute_u32_to_field(&pre_compute.b_or_imm) + } else { + state.vm_read::(NATIVE_AS, pre_compute.b_or_imm)[0] + }; + let index = if C_IS_IMM { + transmute_u32_to_field(&pre_compute.c_or_imm) + } else { + state.vm_read::(NATIVE_AS, pre_compute.c_or_imm)[0] + }; + + let idx: usize = index.as_canonical_u32() as usize; + { + let custom_pvs = unsafe { &*pre_compute.pvs }; + let mut custom_pvs = custom_pvs.lock().unwrap(); + + if custom_pvs[idx].is_none() { + custom_pvs[idx] = Some(value); + } else { + // Not a hard constraint violation when publishing the same value twice but the + // program should avoid that. + panic!("Custom public value {} already set", idx); + } } + state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + state.instret += 1; +} + +impl PublicValuesExecutor +where + F: PrimeField32, +{ + fn pre_compute_impl( + &self, + inst: &Instruction, + data: &mut PublicValuesPreCompute, + ) -> (bool, bool) { + let &Instruction { b, c, e, f, .. } = inst; + + let e = e.as_canonical_u32(); + let f = f.as_canonical_u32(); + + let b_is_imm = e == RV32_IMM_AS; + let c_is_imm = f == RV32_IMM_AS; + + let b_or_imm = if b_is_imm { + transmute_field_to_u32(&b) + } else { + b.as_canonical_u32() + }; + let c_or_imm = if c_is_imm { + transmute_field_to_u32(&c) + } else { + c.as_canonical_u32() + }; + + *data = PublicValuesPreCompute { + b_or_imm, + c_or_imm, + pvs: &self.custom_pvs, + }; - fn air(&self) -> &Self::Air { - &self.air + (b_is_imm, c_is_imm) } } diff --git a/crates/vm/src/system/public_values/mod.rs b/crates/vm/src/system/public_values/mod.rs index 918606497b..e21b5416fe 100644 --- a/crates/vm/src/system/public_values/mod.rs +++ b/crates/vm/src/system/public_values/mod.rs @@ -1,18 +1,15 @@ use crate::{ arch::{VmAirWrapper, VmChipWrapper}, - system::{ - native_adapter::{NativeAdapterAir, NativeAdapterChip}, - public_values::core::{PublicValuesCoreAir, PublicValuesCoreChip}, - }, + system::native_adapter::NativeAdapterAir, }; mod columns; /// Chip to publish custom public values from VM programs. -pub mod core; +mod core; +pub use core::*; #[cfg(test)] mod tests; pub type PublicValuesAir = VmAirWrapper, PublicValuesCoreAir>; -pub type PublicValuesChip = - VmChipWrapper, PublicValuesCoreChip>; +pub type PublicValuesChip = VmChipWrapper>; diff --git a/crates/vm/src/system/public_values/tests.rs b/crates/vm/src/system/public_values/tests.rs index dbf9dc217d..bf7e59eedb 100644 --- a/crates/vm/src/system/public_values/tests.rs +++ b/crates/vm/src/system/public_values/tests.rs @@ -5,7 +5,7 @@ use openvm_stark_backend::{ p3_air::{Air, AirBuilderWithPublicValues}, p3_field::{Field, FieldAlgebra}, p3_matrix::{dense::RowMajorMatrix, Matrix}, - prover::types::AirProofInput, + prover::types::AirProvingContext, rap::PartitionedBaseAir, utils::disable_debug_builder, verifier::VerificationError, @@ -51,8 +51,11 @@ fn public_values_happy_path_1() { let trace = RowMajorMatrix::new_row(cols.flatten()); let pvs = to_field_vec(vec![0, 0, 12]); - BabyBearPoseidon2Engine::run_test_fast(vec![air], vec![AirProofInput::simple(trace, pvs)]) - .expect("Verification failed"); + BabyBearPoseidon2Engine::run_test_fast( + vec![air], + vec![AirProvingContext::simple(Arc::new(trace), pvs)], + ) + .expect("Verification failed"); } #[test] @@ -70,8 +73,11 @@ fn public_values_neg_pv_not_match() { disable_debug_builder(); assert_eq!( - BabyBearPoseidon2Engine::run_test_fast(vec![air], vec![AirProofInput::simple(trace, pvs)]) - .err(), + BabyBearPoseidon2Engine::run_test_fast( + vec![air], + vec![AirProvingContext::simple(Arc::new(trace), pvs)] + ) + .err(), Some(VerificationError::OodEvaluationMismatch) ); } @@ -91,8 +97,11 @@ fn public_values_neg_index_out_of_bound() { disable_debug_builder(); assert_eq!( - BabyBearPoseidon2Engine::run_test_fast(vec![air], vec![AirProofInput::simple(trace, pvs)]) - .err(), + BabyBearPoseidon2Engine::run_test_fast( + vec![air], + vec![AirProvingContext::simple(Arc::new(trace), pvs)] + ) + .err(), Some(VerificationError::OodEvaluationMismatch) ); } @@ -129,8 +138,11 @@ fn public_values_neg_double_publish_impl(actual_pv: u32) { disable_debug_builder(); assert_eq!( - BabyBearPoseidon2Engine::run_test_fast(vec![air], vec![AirProofInput::simple(trace, pvs)]) - .err(), + BabyBearPoseidon2Engine::run_test_fast( + vec![air], + vec![AirProvingContext::simple(Arc::new(trace), pvs)] + ) + .err(), Some(VerificationError::OodEvaluationMismatch) ); } diff --git a/crates/vm/src/utils/mod.rs b/crates/vm/src/utils/mod.rs index 7b4823c53a..0d86c280d9 100644 --- a/crates/vm/src/utils/mod.rs +++ b/crates/vm/src/utils/mod.rs @@ -1,10 +1,59 @@ #[cfg(any(test, feature = "test-utils"))] mod stark_utils; #[cfg(any(test, feature = "test-utils"))] -mod test_utils; +pub mod test_utils; + +use std::mem::size_of_val; pub use openvm_circuit_primitives::utils::next_power_of_two_or_zero; +use openvm_stark_backend::p3_field::PrimeField32; #[cfg(any(test, feature = "test-utils"))] pub use stark_utils::*; #[cfg(any(test, feature = "test-utils"))] pub use test_utils::*; + +#[inline(always)] +pub fn transmute_field_to_u32(field: &F) -> u32 { + debug_assert_eq!( + std::mem::size_of::(), + std::mem::size_of::(), + "Field type F must have the same size as u32" + ); + debug_assert_eq!( + std::mem::align_of::(), + std::mem::align_of::(), + "Field type F must have the same alignment as u32" + ); + // SAFETY: This assumes that F has the same memory layout as u32. + // This is only safe for field types that are guaranteed to be represented + // as a single u32 internally + unsafe { *(field as *const F as *const u32) } +} + +#[inline(always)] +pub fn transmute_u32_to_field(value: &u32) -> F { + debug_assert_eq!( + std::mem::size_of::(), + std::mem::size_of::(), + "Field type F must have the same size as u32" + ); + debug_assert_eq!( + std::mem::align_of::(), + std::mem::align_of::(), + "Field type F must have the same alignment as u32" + ); + // SAFETY: This assumes that F has the same memory layout as u32. + // This is only safe for field types that are guaranteed to be represented + // as a single u32 internally + unsafe { *(value as *const u32 as *const F) } +} + +/// # Safety +/// The type `T` should be plain old data so there is no worry about [Drop] behavior in the +/// transmutation. +#[inline(always)] +pub unsafe fn slice_as_bytes(slice: &[T]) -> &[u8] { + let len = size_of_val(slice); + // SAFETY: length and alignment are correct. + unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const u8, len) } +} diff --git a/crates/vm/src/utils/stark_utils.rs b/crates/vm/src/utils/stark_utils.rs index d940be5c75..9c8855c905 100644 --- a/crates/vm/src/utils/stark_utils.rs +++ b/crates/vm/src/utils/stark_utils.rs @@ -1,170 +1,171 @@ -use itertools::multiunzip; -use openvm_instructions::{exe::VmExe, program::Program}; +use openvm_instructions::exe::VmExe; use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, + config::{Com, Val}, + engine::VerificationData, p3_field::PrimeField32, - verifier::VerificationError, - Chip, }; use openvm_stark_sdk::{ config::{ baby_bear_poseidon2::{BabyBearPoseidon2Config, BabyBearPoseidon2Engine}, setup_tracing, FriParameters, }, - engine::{StarkEngine, StarkFriEngine, VerificationDataWithFriParams}, + engine::{StarkFriEngine, VerificationDataWithFriParams}, p3_baby_bear::BabyBear, - utils::ProofInputForTest, }; -use crate::arch::{ - vm::{VirtualMachine, VmExecutor}, - Streams, VmConfig, VmMemoryState, +use crate::{ + arch::{ + debug_proving_ctx, execution_mode::metered::Segment, vm::VirtualMachine, Executor, + ExitCode, MatrixRecordArena, MeteredExecutor, PreflightExecutionOutput, PreflightExecutor, + Streams, VmBuilder, VmCircuitConfig, VmConfig, VmExecutionConfig, + }, + system::memory::{MemoryImage, CHUNK}, }; -pub fn air_test(config: VC, exe: impl Into>) +// NOTE on trait bounds: the compiler cannot figure out Val=BabyBear without the +// VmExecutionConfig and VmCircuitConfig bounds even though VmProverBuilder already includes them. +// The compiler also seems to need the extra VC even though VC=VB::VmConfig +pub fn air_test(builder: VB, config: VC, exe: impl Into>) where - VC: VmConfig, - VC::Executor: Chip, - VC::Periphery: Chip, + VB: VmBuilder< + BabyBearPoseidon2Engine, + VmConfig = VC, + RecordArena = MatrixRecordArena, + >, + VC: VmExecutionConfig + + VmCircuitConfig + + VmConfig, + >::Executor: Executor + + MeteredExecutor + + PreflightExecutor>, { - air_test_with_min_segments(config, exe, Streams::default(), 1); + air_test_with_min_segments(builder, config, exe, Streams::default(), 1); } /// Executes and proves the VM and returns the final memory state. -pub fn air_test_with_min_segments( +pub fn air_test_with_min_segments( + builder: VB, config: VC, exe: impl Into>, input: impl Into>, min_segments: usize, -) -> Option> +) -> Option where - VC: VmConfig, - VC::Executor: Chip, - VC::Periphery: Chip, + VB: VmBuilder< + BabyBearPoseidon2Engine, + VmConfig = VC, + RecordArena = MatrixRecordArena, + >, + VC: VmExecutionConfig + + VmCircuitConfig + + VmConfig, + >::Executor: Executor + + MeteredExecutor + + PreflightExecutor>, { - air_test_impl(config, exe, input, min_segments, true) + let mut log_blowup = 1; + while config.as_ref().max_constraint_degree > (1 << log_blowup) + 1 { + log_blowup += 1; + } + let fri_params = FriParameters::new_for_testing(log_blowup); + let (final_memory, _) = air_test_impl::( + fri_params, + builder, + config, + exe, + input, + min_segments, + true, + ) + .unwrap(); + final_memory } /// Executes and proves the VM and returns the final memory state. /// If `debug` is true, runs the debug prover. -pub fn air_test_impl( - config: VC, - exe: impl Into>, - input: impl Into>, +// +// Same implementation as VmLocalProver, but we need to do something special to run the debug prover +#[allow(clippy::type_complexity)] +pub fn air_test_impl( + fri_params: FriParameters, + builder: VB, + config: VB::VmConfig, + exe: impl Into>>, + input: impl Into>>, min_segments: usize, debug: bool, -) -> Option> +) -> eyre::Result<( + Option, + Vec>, +)> where - VC: VmConfig, - VC::Executor: Chip, - VC::Periphery: Chip, + E: StarkFriEngine, + Val: PrimeField32, + VB: VmBuilder, + >>::Executor: Executor> + + MeteredExecutor> + + PreflightExecutor, VB::RecordArena>, + Com: AsRef<[Val; CHUNK]> + From<[Val; CHUNK]>, { setup_tracing(); - let mut log_blowup = 1; - while config.system().max_constraint_degree > (1 << log_blowup) + 1 { - log_blowup += 1; - } - let engine = BabyBearPoseidon2Engine::new(FriParameters::new_for_testing(log_blowup)); - let vm = VirtualMachine::new(engine, config); - let pk = vm.keygen(); - let mut result = vm.execute_and_generate(exe, input).unwrap(); - let final_memory = Option::take(&mut result.final_memory); - let global_airs = vm.config().create_chip_complex().unwrap().airs(); - if debug { - for proof_input in &result.per_segment { - let (airs, pks, air_proof_inputs): (Vec<_>, Vec<_>, Vec<_>) = - multiunzip(proof_input.per_air.iter().map(|(air_id, air_proof_input)| { - ( - global_airs[*air_id].clone(), - pk.per_air[*air_id].clone(), - air_proof_input.clone(), - ) - })); - vm.engine.debug(&airs, &pks, &air_proof_inputs); - } - } - let proofs = vm.prove(&pk, result); + let engine = E::new(fri_params); + let (mut vm, pk) = VirtualMachine::::new_with_keygen(engine, builder, config)?; + let vk = pk.get_vk(); + let exe = exe.into(); + let input = input.into(); + let metered_ctx = vm.build_metered_ctx(); + let executor_idx_to_air_idx = vm.executor_idx_to_air_idx(); + let interpreter = vm + .executor() + .metered_instance(&exe, &executor_idx_to_air_idx)?; + let (segments, _) = interpreter.execute_metered(input.clone(), metered_ctx)?; + let committed_exe = vm.commit_exe(exe); + let cached_program_trace = vm.transport_committed_exe_to_device(&committed_exe); + vm.load_program(cached_program_trace); + let exe = committed_exe.exe; - assert!(proofs.len() >= min_segments); - vm.verify(&pk.get_vk(), proofs) - .expect("segment proofs should verify"); - final_memory -} + let mut state = Some(vm.create_initial_state(&exe, input)); + let mut proofs = Vec::new(); + let mut exit_code = None; + for segment in segments { + let Segment { + instret_start, + num_insns, + trace_heights, + } = segment; + assert_eq!(state.as_ref().unwrap().instret, instret_start); + let from_state = Option::take(&mut state).unwrap(); + vm.transport_init_memory_to_device(&from_state.memory); + let PreflightExecutionOutput { + system_records, + record_arenas, + to_state, + } = vm.execute_preflight(&exe, from_state, Some(num_insns), &trace_heights)?; + state = Some(to_state); + exit_code = system_records.exit_code; -/// Generates the VM STARK circuit, in the form of AIRs and traces, but does not -/// do any proving. Output is the payload of everything the prover needs. -/// -/// The output AIRs and traces are sorted by height in descending order. -pub fn gen_vm_program_test_proof_input( - program: Program>, - input_stream: impl Into>> + Clone, - #[allow(unused_mut)] mut config: VC, -) -> ProofInputForTest -where - Val: PrimeField32, - VC: VmConfig> + Clone, - VC::Executor: Chip, - VC::Periphery: Chip, -{ - cfg_if::cfg_if! { - if #[cfg(feature = "bench-metrics")] { - // Run once with metrics collection enabled, which can improve runtime performance - config.system_mut().profiling = true; - { - let executor = VmExecutor::, VC>::new(config.clone()); - executor.execute(program.clone(), input_stream.clone()).unwrap(); - } - // Run again with metrics collection disabled and measure trace generation time - config.system_mut().profiling = false; - let start = std::time::Instant::now(); + let ctx = vm.generate_proving_ctx(system_records, record_arenas)?; + if debug { + debug_proving_ctx(&vm, &pk, &ctx); } + let proof = vm.engine.prove(vm.pk(), ctx); + proofs.push(proof); } - - let airs = config.create_chip_complex().unwrap().airs(); - let executor = VmExecutor::, VC>::new(config); - - let mut result = executor - .execute_and_generate(program, input_stream) - .unwrap(); - assert_eq!( - result.per_segment.len(), - 1, - "only proving one segment for now" - ); - - let result = result.per_segment.pop().unwrap(); - #[cfg(feature = "bench-metrics")] - metrics::gauge!("execute_and_trace_gen_time_ms").set(start.elapsed().as_millis() as f64); - // Filter out unused AIRS (where trace is empty) - let (used_airs, per_air) = result - .per_air + assert!(proofs.len() >= min_segments); + vm.verify(&vk, &proofs) + .expect("segment proofs should verify"); + let state = state.unwrap(); + let final_memory = (exit_code == Some(ExitCode::Success as u32)).then_some(state.memory.memory); + let vdata = proofs .into_iter() - .map(|(air_id, x)| (airs[air_id].clone(), x)) - .unzip(); - ProofInputForTest { - airs: used_airs, - per_air, - } -} - -type ExecuteAndProveResult = Result, VerificationError>; + .map(|proof| VerificationDataWithFriParams { + data: VerificationData { + vk: vk.clone(), + proof, + }, + fri_params: vm.engine.fri_params(), + }) + .collect(); -/// Executes program and runs simple STARK prover test (keygen, prove, verify). -pub fn execute_and_prove_program, VC>( - program: Program>, - input_stream: impl Into>> + Clone, - config: VC, - engine: &E, -) -> ExecuteAndProveResult -where - Val: PrimeField32, - VC: VmConfig> + Clone, - VC::Executor: Chip, - VC::Periphery: Chip, -{ - let span = tracing::info_span!("execute_and_prove_program").entered(); - let test_proof_input = gen_vm_program_test_proof_input(program, input_stream, config); - let vparams = test_proof_input.run_test(engine)?; - span.exit(); - Ok(vparams) + Ok((final_memory, vdata)) } diff --git a/crates/vm/src/utils/test_utils.rs b/crates/vm/src/utils/test_utils.rs index 9449aff5b8..9ab3f3891b 100644 --- a/crates/vm/src/utils/test_utils.rs +++ b/crates/vm/src/utils/test_utils.rs @@ -1,8 +1,15 @@ use std::array; +use openvm_circuit::arch::{MemoryConfig, SystemConfig}; +use openvm_instructions::{ + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + NATIVE_AS, +}; use openvm_stark_backend::p3_field::PrimeField32; use rand::{rngs::StdRng, Rng}; +use crate::system::memory::{merkle::public_values::PUBLIC_VALUES_AS, online::PAGE_SIZE}; + pub fn i32_to_f(val: i32) -> F { if val.signum() == -1 { -F::from_canonical_u32(val.unsigned_abs()) @@ -31,3 +38,26 @@ pub fn u32_sign_extend(num: u32) -> u32 { num } } + +pub fn test_system_config() -> SystemConfig { + let mut addr_spaces = MemoryConfig::empty_address_space_configs(5); + addr_spaces[RV32_REGISTER_AS as usize].num_cells = PAGE_SIZE; + addr_spaces[RV32_MEMORY_AS as usize].num_cells = 1 << 22; + addr_spaces[PUBLIC_VALUES_AS as usize].num_cells = PAGE_SIZE; + addr_spaces[NATIVE_AS as usize].num_cells = 1 << 25; + SystemConfig::new(3, MemoryConfig::new(2, addr_spaces, 29, 29, 17, 32), 32) +} + +// Testing config when native address space is not needed +pub fn test_system_config_with_continuations() -> SystemConfig { + let mut config = test_system_config(); + config.memory_config.addr_spaces[NATIVE_AS as usize].num_cells = 0; + config.with_continuations() +} + +/// Generate a random message of a given length in bytes +pub fn get_random_message(rng: &mut StdRng, len: usize) -> Vec { + let mut random_message: Vec = vec![0u8; len]; + rng.fill(&mut random_message[..]); + random_message +} diff --git a/crates/vm/tests/integration_test.rs b/crates/vm/tests/integration_test.rs index 168d756111..1e749f1e16 100644 --- a/crates/vm/tests/integration_test.rs +++ b/crates/vm/tests/integration_test.rs @@ -1,21 +1,22 @@ use std::{ collections::{BTreeMap, VecDeque}, - iter::zip, + mem::transmute, sync::Arc, }; +use itertools::Itertools; use openvm_circuit::{ arch::{ + execution_mode::metered::{ + ctx::DEFAULT_SEGMENT_CHECK_INSNS, segment_ctx::SegmentationLimits, + }, hasher::{poseidon2::vm_poseidon2_hasher, Hasher}, - ChipId, ExecutionSegment, MemoryConfig, SingleSegmentVmExecutor, SystemConfig, - SystemTraceHeights, VirtualMachine, VmComplexTraceHeights, VmConfig, - VmInventoryTraceHeights, + verify_segments, verify_single, AirInventory, ContinuationVmProver, + PreflightExecutionOutput, RowMajorMatrixArena, SingleSegmentVmProver, VirtualMachine, + VmCircuitConfig, VmExecutor, VmLocalProver, PUBLIC_VALUES_AIR_ID, }, - system::{ - memory::{MemoryTraceHeights, VolatileMemoryTraceHeights, CHUNK}, - program::trace::VmCommittedExe, - }, - utils::{air_test, air_test_with_min_segments}, + system::{memory::CHUNK, program::trace::VmCommittedExe, SystemCpuBuilder}, + utils::{air_test, air_test_with_min_segments, test_system_config}, }; use openvm_instructions::{ exe::VmExe, @@ -26,10 +27,18 @@ use openvm_instructions::{ SysPhantom, SystemOpcode::*, }; -use openvm_native_circuit::NativeConfig; +use openvm_native_circuit::{ + execute_program, test_native_config, test_native_continuations_config, + test_rv32_with_kernels_config, NativeConfig, NativeCpuBuilder, +}; use openvm_native_compiler::{ - FieldArithmeticOpcode::*, FieldExtensionOpcode::*, NativeBranchEqualOpcode, NativeJalOpcode::*, - NativeLoadStoreOpcode::*, NativePhantom, + CastfOpcode, + FieldArithmeticOpcode::*, + FieldExtensionOpcode::*, + FriOpcode, NativeBranchEqualOpcode, + NativeJalOpcode::{self, *}, + NativeLoadStoreOpcode::*, + NativePhantom, NativeRangeCheckOpcode, Poseidon2Opcode, }; use openvm_rv32im_transpiler::BranchEqualOpcode::*; use openvm_stark_backend::{ @@ -55,19 +64,6 @@ where rng.gen_range(0..MAX_MEMORY - len) / len * len } -fn test_native_config() -> NativeConfig { - NativeConfig { - system: SystemConfig::new(3, MemoryConfig::new(2, 1, 16, 29, 15, 32, 1024), 0), - native: Default::default(), - } -} - -fn test_native_continuations_config() -> NativeConfig { - let mut config = test_native_config(); - config.system = config.system.with_continuations(); - config -} - #[test] fn test_vm_1() { let n = 6; @@ -108,11 +104,12 @@ fn test_vm_1() { let program = Program::from_instructions(&instructions); - air_test(test_native_config(), program); + air_test(NativeCpuBuilder, test_native_config(), program); } +// See crates/sdk/src/prover/root.rs for intended usage #[test] -fn test_vm_override_executor_height() { +fn test_vm_override_trace_heights() -> eyre::Result<()> { let e = BabyBearPoseidon2Engine::new(FriParameters::standard_fast()); let program = Program::::from_instructions(&[ Instruction::large_from_isize(ADD.global_opcode(), 0, 4, 0, 4, 0, 0, 0), @@ -122,175 +119,122 @@ fn test_vm_override_executor_height() { program.into(), e.config().pcs(), )); + // It's hard to define the mapping semantically. Please recompute the following magical AIR + // heights by hands whenever something changes. + let fixed_air_heights = vec![ + 2, 2, 16, 1, 8, 4, 2, 8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 262144, + ]; // Test getting heights. let vm_config = NativeConfig::aggregation(8, 3); - - let executor = SingleSegmentVmExecutor::new(vm_config.clone()); - let res = executor - .execute_and_compute_heights(committed_exe.exe.clone(), vec![]) - .unwrap(); - // Memory trace heights are not computed during execution. + let (mut vm, pk) = VirtualMachine::new_with_keygen(e, NativeCpuBuilder, vm_config)?; + let vk = pk.get_vk(); + + let state = vm.create_initial_state(&committed_exe.exe, vec![]); + vm.transport_init_memory_to_device(&state.memory); + let cached_program_trace = vm.transport_committed_exe_to_device(&committed_exe); + vm.load_program(cached_program_trace); + let PreflightExecutionOutput { + system_records, + mut record_arenas, + .. + } = vm.execute_preflight(&committed_exe.exe, state, None, &fixed_air_heights)?; + + let mut expected_actual_heights = vec![0; vk.inner.per_air.len()]; + let executor_idx_to_air_idx = vm.executor_idx_to_air_idx(); + expected_actual_heights[executor_idx_to_air_idx[6]] = 1; // corresponds to FieldArithmeticChip assert_eq!( - res.vm_heights.system, - SystemTraceHeights { - memory: MemoryTraceHeights::Volatile(VolatileMemoryTraceHeights { - boundary: 1, - access_adapters: vec![0, 0, 0], - }), - } - ); - assert_eq!( - res.vm_heights.inventory, - VmInventoryTraceHeights { - chips: vec![ - (ChipId::Executor(0), 0), - (ChipId::Executor(1), 0), - (ChipId::Executor(2), 0), - (ChipId::Executor(3), 0), - (ChipId::Executor(4), 0), - (ChipId::Executor(5), 0), - (ChipId::Executor(6), 1), // corresponds to FieldArithmeticChip - (ChipId::Executor(7), 0), - (ChipId::Executor(8), 0), - (ChipId::Executor(9), 0), - ] - .into_iter() - .collect(), - } + record_arenas + .iter() + .map(|ra| ra.trace_offset() / ra.width()) + .collect_vec(), + expected_actual_heights ); + for ra in &mut record_arenas { + ra.force_matrix_dimensions(); + } + vm.override_system_trace_heights(&fixed_air_heights); - // Test overriding heights. - let system_overridden_heights = SystemTraceHeights { - memory: MemoryTraceHeights::Volatile(VolatileMemoryTraceHeights { - boundary: 1, - access_adapters: vec![8, 4, 2], - }), - }; - let inventory_overridden_heights = VmInventoryTraceHeights { - chips: vec![ - (ChipId::Executor(0), 16), - (ChipId::Executor(1), 32), - (ChipId::Executor(2), 64), - (ChipId::Executor(3), 128), - (ChipId::Executor(4), 256), - (ChipId::Executor(5), 512), - (ChipId::Executor(6), 1024), - (ChipId::Executor(7), 2048), - (ChipId::Executor(8), 4096), - (ChipId::Executor(9), 8192), - ] - .into_iter() - .collect(), - }; - let overridden_heights = VmComplexTraceHeights::new( - system_overridden_heights.clone(), - inventory_overridden_heights.clone(), - ); - let executor = SingleSegmentVmExecutor::new_with_overridden_trace_heights( - vm_config, - Some(overridden_heights), - ); - let proof_input = executor - .execute_and_generate(committed_exe, vec![]) - .unwrap(); - let air_heights: Vec<_> = proof_input + let ctx = vm.generate_proving_ctx(system_records, record_arenas)?; + let air_heights: Vec<_> = ctx .per_air .iter() - .map(|(_, api)| api.main_trace_height()) + .map(|(_, air_ctx)| air_ctx.main_trace_height() as u32) .collect(); - // It's hard to define the mapping semantically. Please recompute the following magical AIR - // heights by hands whenever something changes. - assert_eq!( - air_heights, - vec![2, 2, 16, 1, 8, 4, 2, 8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 262144] - ); + assert_eq!(air_heights, fixed_air_heights); + Ok(()) } #[test] -fn test_vm_1_optional_air() { +fn test_vm_1_optional_air() -> eyre::Result<()> { // Aggregation VmConfig has Core/Poseidon2/FieldArithmetic/FieldExtension chips. The program // only uses Core and FieldArithmetic. All other chips should not have AIR proof inputs. let config = NativeConfig::aggregation(4, 3); let engine = BabyBearPoseidon2Engine::new(standard_fri_params_with_100_bits_conjectured_security(3)); - let vm = VirtualMachine::new(engine, config); - let pk = vm.keygen(); + let (vm, pk) = VirtualMachine::new_with_keygen(engine, NativeCpuBuilder, config)?; let num_airs = pk.per_air.len(); - { - let n = 6; - let instructions = vec![ - Instruction::large_from_isize(ADD.global_opcode(), 0, n, 0, 4, 0, 0, 0), - Instruction::large_from_isize(SUB.global_opcode(), 0, 0, 1, 4, 4, 0, 0), - Instruction::from_isize( - NativeBranchEqualOpcode(BNE).global_opcode(), - 0, - 0, - -(DEFAULT_PC_STEP as isize), - 4, - 0, - ), - Instruction::from_isize(TERMINATE.global_opcode(), 0, 0, 0, 0, 0), - ]; - - let program = Program::from_instructions(&instructions); - let result = vm - .execute_and_generate(program, vec![]) - .expect("Failed to execute VM"); - assert_eq!(result.per_segment.len(), 1); - let proof_input = result.per_segment.last().unwrap(); - assert!( - proof_input.per_air.len() < num_airs, - "Expect less used AIRs" - ); - let proofs = vm.prove(&pk, result); - assert_eq!(proofs.len(), 1); - vm.verify(&pk.get_vk(), proofs) - .expect("Verification failed"); - } + let n = 6; + let instructions = vec![ + Instruction::large_from_isize(ADD.global_opcode(), 0, n, 0, 4, 0, 0, 0), + Instruction::large_from_isize(SUB.global_opcode(), 0, 0, 1, 4, 4, 0, 0), + Instruction::from_isize( + NativeBranchEqualOpcode(BNE).global_opcode(), + 0, + 0, + -(DEFAULT_PC_STEP as isize), + 4, + 0, + ), + Instruction::from_isize(TERMINATE.global_opcode(), 0, 0, 0, 0, 0), + ]; + + let com_exe = vm.commit_exe(VmExe::new(Program::from_instructions(&instructions))); + let cached_program_trace = vm.transport_committed_exe_to_device(&com_exe); + let mut prover = VmLocalProver::new(vm, com_exe.exe, cached_program_trace); + let proof = SingleSegmentVmProver::prove(&mut prover, vec![], &vec![256; num_airs])?; + assert!(proof.per_air.len() < num_airs, "Expect less used AIRs"); + verify_single(&prover.vm.engine, &pk.get_vk(), &proof)?; + Ok(()) } #[test] -fn test_vm_public_values() { +fn test_vm_public_values() -> eyre::Result<()> { setup_tracing(); let num_public_values = 100; - let config = SystemConfig::default().with_public_values(num_public_values); + let config = test_system_config().with_public_values(num_public_values); + assert!(!config.continuation_enabled); let engine = BabyBearPoseidon2Engine::new(standard_fri_params_with_100_bits_conjectured_security(3)); - let vm = VirtualMachine::new(engine, config.clone()); - let pk = vm.keygen(); + let (vm, pk) = VirtualMachine::new_with_keygen(engine, SystemCpuBuilder, config)?; - { - let instructions = vec![ - Instruction::from_usize(PUBLISH.global_opcode(), [0, 12, 2, 0, 0, 0]), - Instruction::from_isize(TERMINATE.global_opcode(), 0, 0, 0, 0, 0), - ]; - - let program = Program::from_instructions(&instructions); - let committed_exe = Arc::new(VmCommittedExe::commit( - program.clone().into(), - vm.engine.config.pcs(), - )); - let single_vm = SingleSegmentVmExecutor::new(config); - let exe_result = single_vm - .execute_and_compute_heights(program, vec![]) - .unwrap(); - assert_eq!( - exe_result.public_values, - [ - vec![None, None, Some(BabyBear::from_canonical_u32(12))], - vec![None; num_public_values - 3] - ] - .concat(), - ); - let proof_input = single_vm - .execute_and_generate(committed_exe, vec![]) - .unwrap(); - vm.engine - .prove_then_verify(&pk, proof_input) - .expect("Verification failed"); - } + let instructions = vec![ + Instruction::from_usize(PUBLISH.global_opcode(), [0, 12, 2, 0, 0, 0]), + Instruction::from_isize(TERMINATE.global_opcode(), 0, 0, 0, 0, 0), + ]; + let com_exe = vm.commit_exe(VmExe::new(Program::from_instructions(&instructions))); + let cached_program_trace = vm.transport_committed_exe_to_device(&com_exe); + let mut prover = VmLocalProver::new(vm, com_exe.exe, cached_program_trace); + let proof = SingleSegmentVmProver::prove(&mut prover, vec![], &vec![256; pk.per_air.len()])?; + assert_eq!( + proof.per_air[PUBLIC_VALUES_AIR_ID].air_id, + PUBLIC_VALUES_AIR_ID + ); + assert_eq!( + proof.per_air[PUBLIC_VALUES_AIR_ID].public_values, + [ + vec![ + BabyBear::ZERO, + BabyBear::ZERO, + BabyBear::from_canonical_u32(12) + ], + vec![BabyBear::ZERO; num_public_values - 3] + ] + .concat(), + ); + verify_single(&prover.vm.engine, &pk.get_vk(), &proof)?; + Ok(()) } #[test] @@ -316,9 +260,8 @@ fn test_vm_initial_memory() { Instruction::::from_isize(TERMINATE.global_opcode(), 0, 0, 0, 0, 0), ]); - let init_memory: BTreeMap<_, _> = [((4, 7), BabyBear::from_canonical_u32(101))] - .into_iter() - .collect(); + let raw = unsafe { transmute::(BabyBear::from_canonical_u32(101)) }; + let init_memory = BTreeMap::from_iter((0..4).map(|i| ((4u32, 7u32 * 4 + i), raw[i as usize]))); let config = test_native_continuations_config(); let exe = VmExe { @@ -327,21 +270,18 @@ fn test_vm_initial_memory() { init_memory, fn_bounds: Default::default(), }; - air_test(config, exe); + air_test(NativeCpuBuilder, config, exe); } #[test] -fn test_vm_1_persistent() { +fn test_vm_1_persistent() -> eyre::Result<()> { let engine = BabyBearPoseidon2Engine::new(FriParameters::standard_fast()); let config = test_native_continuations_config(); + let merkle_air_idx = config.system.memory_boundary_air_id() + 1; let ptr_max_bits = config.system.memory_config.pointer_max_bits; - let as_height = config.system.memory_config.as_height; - let airs = VmConfig::::create_chip_complex(&config) - .unwrap() - .airs::(); + let addr_space_height = config.system.memory_config.addr_space_height; - let vm = VirtualMachine::new(engine, config); - let pk = vm.keygen(); + let (vm, pk) = VirtualMachine::new_with_keygen(engine, NativeCpuBuilder, config)?; let n = 6; let instructions = vec![ @@ -358,39 +298,34 @@ fn test_vm_1_persistent() { Instruction::from_isize(TERMINATE.global_opcode(), 0, 0, 0, 0, 0), ]; - let program = Program::from_instructions(&instructions); + let com_exe = vm.commit_exe(VmExe::new(Program::from_instructions(&instructions))); + let cached_program_trace = vm.transport_committed_exe_to_device(&com_exe); + let mut prover = VmLocalProver::new(vm, com_exe.exe, cached_program_trace); + let proof = ContinuationVmProver::prove(&mut prover, vec![])?; - let result = vm.execute_and_generate(program.clone(), vec![]).unwrap(); { - let proof_input = result.per_segment.into_iter().next().unwrap(); - - let ((_, merkle_air_proof_input), _) = zip(&proof_input.per_air, &airs) - .find(|(_, air)| air.name() == "MemoryMerkleAir<8>") - .unwrap(); - assert_eq!(merkle_air_proof_input.raw.public_values.len(), 16); - assert_eq!( - merkle_air_proof_input.raw.public_values[..8], - merkle_air_proof_input.raw.public_values[8..] - ); + assert_eq!(proof.per_segment.len(), 1); + let public_values = proof.per_segment[0].per_air[merkle_air_idx] + .public_values + .clone(); + assert_eq!(public_values.len(), 16); + assert_eq!(public_values[..8], public_values[8..]); let mut digest = [BabyBear::ZERO; CHUNK]; let compression = vm_poseidon2_hasher(); - for _ in 0..ptr_max_bits + as_height - 2 { + for _ in 0..ptr_max_bits + addr_space_height - 2 { digest = compression.compress(&digest, &digest); } assert_eq!( - merkle_air_proof_input.raw.public_values[..8], + public_values[..8], // The value when you start with zeros and repeatedly hash the value with itself - // ptr_max_bits + as_height - 2 times. - // The height of the tree is ptr_max_bits + as_height - log2(8). The leaf also must be - // hashed once with padding for security. + // ptr_max_bits + addr_space_height - 2 times. + // The height of the tree is ptr_max_bits + addr_space_height - log2(8). The leaf also + // must be hashed once with padding for security. digest ); } - - let result_for_proof = vm.execute_and_generate(program, vec![]).unwrap(); - let proofs = vm.prove(&pk, result_for_proof); - vm.verify(&pk.get_vk(), proofs) - .expect("Verification failed"); + verify_segments(&prover.vm.engine, &pk.get_vk(), &proof.per_segment)?; + Ok(()) } #[test] @@ -438,7 +373,7 @@ fn test_vm_without_field_arithmetic() { let program = Program::from_instructions(&instructions); - air_test(test_native_config(), program); + air_test(NativeCpuBuilder, test_native_config(), program); } #[test] @@ -485,7 +420,7 @@ fn test_vm_fibonacci_old() { let program = Program::from_instructions(&instructions); - air_test(test_native_config(), program); + air_test(NativeCpuBuilder, test_native_config(), program); } #[test] @@ -544,7 +479,7 @@ fn test_vm_fibonacci_old_cycle_tracker() { let program = Program::from_instructions(&instructions); - air_test(test_native_config(), program); + air_test(NativeCpuBuilder, test_native_config(), program); } #[test] @@ -568,7 +503,7 @@ fn test_vm_field_extension_arithmetic() { let program = Program::from_instructions(&instructions); - air_test(test_native_config(), program); + air_test(NativeCpuBuilder, test_native_config(), program); } #[test] @@ -594,23 +529,30 @@ fn test_vm_max_access_adapter_8() { let mut config = test_native_config(); { - let chip_complex1 = config.create_chip_complex().unwrap(); - let mem_ctrl1 = chip_complex1.base.memory_controller; + let num_sys_airs1 = config.system.num_airs(); + let inventory1: AirInventory = config.create_airs().unwrap(); + let num_ext_airs = inventory1.ext_airs().len(); + let mem_inv1 = &inventory1.system().memory; config.system.memory_config.max_access_adapter_n = 8; - let chip_complex2 = config.create_chip_complex().unwrap(); - let mem_ctrl2 = chip_complex2.base.memory_controller; + let num_sys_airs2 = config.system.num_airs(); + let inventory2: AirInventory = config.create_airs().unwrap(); + let mem_inv2 = &inventory2.system().memory; // AccessAdapterAir with N=16/32 are disabled. - assert_eq!(mem_ctrl1.air_names().len(), mem_ctrl2.air_names().len() + 2); assert_eq!( - mem_ctrl1.airs::().len(), - mem_ctrl2.airs::().len() + 2 + mem_inv1.access_adapters.len(), + mem_inv2.access_adapters.len() + 2 + ); + assert_eq!(num_sys_airs1, num_sys_airs2 + 2); + assert_eq!( + inventory1.into_airs().collect_vec().len(), + num_sys_airs1 + num_ext_airs ); assert_eq!( - mem_ctrl1.current_trace_heights().len(), - mem_ctrl2.current_trace_heights().len() + 2 + inventory2.into_airs().collect_vec().len(), + num_sys_airs2 + num_ext_airs ); } - air_test(config, program); + air_test(NativeCpuBuilder, test_native_config(), program); } #[test] @@ -634,7 +576,7 @@ fn test_vm_field_extension_arithmetic_persistent() { let program = Program::from_instructions(&instructions); let config = test_native_continuations_config(); - air_test(config, program); + air_test(NativeCpuBuilder, config, program); } #[test] @@ -656,7 +598,7 @@ fn test_vm_hint() { Instruction::from_isize(LOADW.global_opcode(), 38, 0, 32, 4, 4), Instruction::large_from_isize(ADD.global_opcode(), 44, 20, 0, 4, 4, 0, 0), Instruction::from_isize(MUL.global_opcode(), 24, 38, 1, 4, 4), - Instruction::large_from_isize(ADD.global_opcode(), 20, 20, 24, 4, 4, 1, 0), + Instruction::large_from_isize(ADD.global_opcode(), 20, 20, 24, 4, 4, 4, 0), Instruction::large_from_isize(ADD.global_opcode(), 50, 16, 0, 4, 4, 0, 0), Instruction::from_isize( JAL.global_opcode(), @@ -694,8 +636,8 @@ fn test_vm_hint() { type F = BabyBear; let input_stream: Vec> = vec![vec![F::TWO]]; - let config = NativeConfig::new(SystemConfig::default(), Default::default()); - air_test_with_min_segments(config, program, input_stream, 1); + let config = test_native_config(); + air_test_with_min_segments(NativeCpuBuilder, config, program, input_stream, 1); } #[test] @@ -712,17 +654,10 @@ fn test_hint_load_1() { ]; let program = Program::from_instructions(&instructions); + let input = vec![vec![F::ONE, F::TWO]]; - let mut segment = ExecutionSegment::new( - &test_native_config(), - program, - vec![vec![F::ONE, F::TWO]].into(), - None, - vec![], - Default::default(), - ); - segment.execute_from_pc(0).unwrap(); - let streams = segment.chip_complex.take_streams(); + let state = execute_program(program, input); + let streams = state.streams; assert!(streams.input_stream.is_empty()); assert_eq!(streams.hint_stream, VecDeque::from(vec![F::ZERO])); assert_eq!(streams.hint_space, vec![vec![F::ONE, F::TWO]]); @@ -749,24 +684,12 @@ fn test_hint_load_2() { ]; let program = Program::from_instructions(&instructions); + let input = vec![vec![F::ONE, F::TWO], vec![F::TWO, F::ONE]]; - let mut segment = ExecutionSegment::new( - &test_native_config(), - program, - vec![vec![F::ONE, F::TWO], vec![F::TWO, F::ONE]].into(), - None, - vec![], - Default::default(), - ); - segment.execute_from_pc(0).unwrap(); - assert_eq!( - segment - .chip_complex - .memory_controller() - .unsafe_read_cell(F::from_canonical_usize(4), F::from_canonical_usize(32)), - F::ZERO - ); - let streams = segment.chip_complex.take_streams(); + let state = execute_program(program, input); + let [read] = unsafe { state.memory.read::(4, 32) }; + assert_eq!(read, F::ZERO); + let streams = state.streams; assert!(streams.input_stream.is_empty()); assert_eq!(streams.hint_stream, VecDeque::from(vec![F::ONE])); assert_eq!( @@ -774,3 +697,218 @@ fn test_hint_load_2() { vec![vec![F::ONE, F::TWO], vec![F::TWO, F::ONE]] ); } + +#[test] +fn test_vm_pure_execution_non_continuation() { + type F = BabyBear; + let n = 6; + /* + Instruction 0 assigns word[0]_4 to n. + Instruction 4 terminates + The remainder is a loop that decrements word[0]_4 until it reaches 0, then terminates. + Instruction 1 checks if word[0]_4 is 0 yet, and if so sets pc to 5 in order to terminate + Instruction 2 decrements word[0]_4 (using word[1]_4) + Instruction 3 uses JAL as a simple jump to go back to instruction 1 (repeating the loop). + */ + let instructions: Vec> = vec![ + // word[0]_4 <- word[n]_0 + Instruction::large_from_isize(ADD.global_opcode(), 0, n, 0, 4, 0, 0, 0), + // if word[0]_4 == 0 then pc += 3 * DEFAULT_PC_STEP + Instruction::from_isize( + NativeBranchEqualOpcode(BEQ).global_opcode(), + 0, + 0, + 3 * DEFAULT_PC_STEP as isize, + 4, + 0, + ), + // word[0]_4 <- word[0]_4 - word[1]_4 + Instruction::large_from_isize(SUB.global_opcode(), 0, 0, 1, 4, 4, 0, 0), + // word[2]_4 <- pc + DEFAULT_PC_STEP, pc -= 2 * DEFAULT_PC_STEP + Instruction::from_isize( + JAL.global_opcode(), + 2, + -2 * DEFAULT_PC_STEP as isize, + 0, + 4, + 0, + ), + // terminate + Instruction::from_isize(TERMINATE.global_opcode(), 0, 0, 0, 0, 0), + ]; + + let exe = VmExe::new(Program::from_instructions(&instructions)); + let executor = VmExecutor::new(test_native_config()).unwrap(); + let instance = executor.instance(&exe).unwrap(); + instance.execute(vec![], None).expect("Failed to execute"); +} + +#[test] +fn test_vm_pure_execution_continuation() { + type F = BabyBear; + let instructions: Vec> = vec![ + Instruction::large_from_isize(ADD.global_opcode(), 0, 0, 1, 4, 0, 0, 0), + Instruction::large_from_isize(ADD.global_opcode(), 1, 0, 2, 4, 0, 0, 0), + Instruction::large_from_isize(ADD.global_opcode(), 2, 0, 1, 4, 0, 0, 0), + Instruction::large_from_isize(ADD.global_opcode(), 3, 0, 2, 4, 0, 0, 0), + Instruction::large_from_isize(ADD.global_opcode(), 4, 0, 2, 4, 0, 0, 0), + Instruction::large_from_isize(ADD.global_opcode(), 5, 0, 1, 4, 0, 0, 0), + Instruction::large_from_isize(ADD.global_opcode(), 6, 0, 1, 4, 0, 0, 0), + Instruction::large_from_isize(ADD.global_opcode(), 7, 0, 2, 4, 0, 0, 0), + Instruction::from_isize(FE4ADD.global_opcode(), 8, 0, 4, 4, 4), + Instruction::from_isize(FE4ADD.global_opcode(), 8, 0, 4, 4, 4), + Instruction::from_isize(FE4SUB.global_opcode(), 12, 0, 4, 4, 4), + Instruction::from_isize(BBE4MUL.global_opcode(), 12, 0, 4, 4, 4), + Instruction::from_isize(BBE4DIV.global_opcode(), 12, 0, 4, 4, 4), + Instruction::from_isize(TERMINATE.global_opcode(), 0, 0, 0, 0, 0), + ]; + + let exe = VmExe::new(Program::from_instructions(&instructions)); + let executor = VmExecutor::new(test_native_continuations_config()).unwrap(); + let instance = executor.instance(&exe).unwrap(); + instance.execute(vec![], None).expect("Failed to execute"); +} + +#[test] +fn test_vm_e1_native_chips() { + type F = BabyBear; + + let instructions = vec![ + // Field Arithmetic operations (FieldArithmeticChip) + Instruction::large_from_isize(ADD.global_opcode(), 0, 0, 1, 4, 0, 0, 0), + Instruction::large_from_isize(SUB.global_opcode(), 1, 10, 2, 4, 0, 0, 0), + Instruction::large_from_isize(MUL.global_opcode(), 2, 3, 4, 4, 0, 0, 0), + Instruction::large_from_isize(DIV.global_opcode(), 3, 20, 5, 4, 0, 0, 0), + // Field Extension operations (FieldExtensionChip) + Instruction::from_isize(FE4ADD.global_opcode(), 8, 0, 4, 4, 4), + Instruction::from_isize(FE4SUB.global_opcode(), 12, 8, 4, 4, 4), + Instruction::from_isize(BBE4MUL.global_opcode(), 16, 12, 8, 4, 4), + Instruction::from_isize(BBE4DIV.global_opcode(), 20, 16, 12, 4, 4), + // Branch operations (NativeBranchEqChip) + Instruction::from_isize( + NativeBranchEqualOpcode(BEQ).global_opcode(), + 0, + 0, + DEFAULT_PC_STEP as isize, + 4, + 4, + ), + Instruction::from_isize( + NativeBranchEqualOpcode(BNE).global_opcode(), + 1, + 2, + DEFAULT_PC_STEP as isize, + 4, + 4, + ), + // JAL operation (JalRangeCheckChip) + Instruction::from_isize( + NativeJalOpcode::JAL.global_opcode(), + 24, + DEFAULT_PC_STEP as isize, + 0, + 4, + 0, + ), + // Range check operation (JalRangeCheckChip) + Instruction::from_isize( + NativeRangeCheckOpcode::RANGE_CHECK.global_opcode(), + 0, + 10, + 8, + 4, + 0, + ), + // Load/Store operations (NativeLoadStoreChip) + Instruction::from_isize(STOREW.global_opcode(), 0, 0, 28, 4, 4), + Instruction::from_isize(LOADW.global_opcode(), 32, 0, 28, 4, 4), + Instruction::from_isize( + PHANTOM.global_opcode(), + 0, + 0, + NativePhantom::HintInput as isize, + 0, + 0, + ), + Instruction::from_isize(HINT_STOREW.global_opcode(), 32, 0, 0, 4, 4), + // Cast to field operation (CastFChip) + Instruction::from_usize(CastfOpcode::CASTF.global_opcode(), [36, 40, 0, 2, 4]), + // Poseidon2 operations (Poseidon2Chip) + Instruction::new( + Poseidon2Opcode::PERM_POS2.global_opcode(), + F::from_canonical_usize(44), + F::from_canonical_usize(48), + F::ZERO, + F::from_canonical_usize(4), + F::from_canonical_usize(4), + F::ZERO, + F::ZERO, + ), + Instruction::new( + Poseidon2Opcode::COMP_POS2.global_opcode(), + F::from_canonical_usize(52), + F::from_canonical_usize(44), + F::from_canonical_usize(48), + F::from_canonical_usize(4), + F::from_canonical_usize(4), + F::ZERO, + F::ZERO, + ), + // FRI operation (FriReducedOpeningChip) + Instruction::large_from_isize(ADD.global_opcode(), 60, 64, 0, 4, 4, 0, 0), /* a_pointer_pointer, */ + Instruction::large_from_isize(ADD.global_opcode(), 64, 68, 0, 4, 4, 0, 0), /* b_pointer_pointer, */ + Instruction::large_from_isize(ADD.global_opcode(), 68, 2, 0, 4, 0, 0, 0), /* length_pointer (value 2), */ + Instruction::large_from_isize(ADD.global_opcode(), 72, 1, 0, 4, 0, 0, 0), //alpha_pointer + Instruction::large_from_isize(ADD.global_opcode(), 76, 80, 0, 4, 4, 0, 0), /* result_pointer, */ + Instruction::large_from_isize(ADD.global_opcode(), 80, 1, 0, 4, 0, 0, 0), /* is_init (value 1) , */ + Instruction::from_usize( + FriOpcode::FRI_REDUCED_OPENING.global_opcode(), + [60, 64, 68, 72, 76, 0, 80], + ), + // Terminate + Instruction::from_isize(TERMINATE.global_opcode(), 0, 0, 0, 0, 0), + ]; + + let exe = VmExe::new(Program::from_instructions(&instructions)); + let input_stream: Vec> = vec![vec![]]; + + let executor = VmExecutor::new(test_rv32_with_kernels_config()).unwrap(); + let instance = executor.instance(&exe).unwrap(); + instance + .execute(input_stream, None) + .expect("Failed to execute"); +} + +// This test ensures that metered execution never segments when continuations is disabled +#[test] +fn test_single_segment_executor_no_segmentation() { + setup_tracing(); + + let mut config = test_native_config(); + config + .system + .set_segmentation_limits(SegmentationLimits::default().with_max_trace_height(1)); + + let engine = BabyBearPoseidon2Engine::new(FriParameters::new_for_testing(3)); + let (vm, _) = VirtualMachine::new_with_keygen(engine, NativeCpuBuilder, config).unwrap(); + let instructions: Vec<_> = (0..2 * DEFAULT_SEGMENT_CHECK_INSNS) + .map(|_| Instruction::large_from_isize(ADD.global_opcode(), 0, 0, 1, 4, 0, 0, 0)) + .chain(std::iter::once(Instruction::from_isize( + TERMINATE.global_opcode(), + 0, + 0, + 0, + 0, + 0, + ))) + .collect(); + + let exe = VmExe::new(Program::from_instructions(&instructions)); + let executor_idx_to_air_idx = vm.executor_idx_to_air_idx(); + let metered_ctx = vm.build_metered_ctx(); + vm.executor() + .metered_instance(&exe, &executor_idx_to_air_idx) + .unwrap() + .execute_metered(vec![], metered_ctx) + .unwrap(); +} diff --git a/docs/crates/benchmarks.md b/docs/crates/benchmarks.md index 6a259cb10e..39c8bc838f 100644 --- a/docs/crates/benchmarks.md +++ b/docs/crates/benchmarks.md @@ -120,10 +120,10 @@ for more detailed profiling we generate special flamegraphs that visualize VM-sp The benchmark must be run with special configuration so that additional metrics are collected for profiling. Note that the additional metric collection will slow down the benchmark. To run a benchmark with the additional profiling, run the following command: ```bash -OUTPUT_PATH="metrics.json" GUEST_SYMBOLS_PATH="guest.syms" cargo run --release --bin --features profiling -- --profiling +OUTPUT_PATH="metrics.json" GUEST_SYMBOLS_PATH="guest.syms" cargo run --release --bin --features perf-metrics -- --profiling ``` -Add `--features aggregation,profiling` to run with leaf aggregation. The `profiling` feature tells the VM to run with additional metric collection. The `--profiling` CLI argument tells the script to build the guest program with `profile=profiling` so that the guest program is compiled without stripping debug symbols. When the `profiling` feature is enabled, the `GUEST_SYMBOLS_PATH` environment variable must be set to the file path where function symbols of the guest program will be exported. Those symbols are then used to annotate the flamegraph with function names. +Add `--features aggregation,perf-metrics` to run with leaf aggregation. The `perf-metrics` feature tells the VM to run with additional metric collection. The `--profiling` CLI argument tells the script to build the guest program with `profile=profiling` so that the guest program is compiled without stripping debug symbols. When the `perf-metrics` feature is enabled, the `GUEST_SYMBOLS_PATH` environment variable must be set to the file path where function symbols of the guest program will be exported. Those symbols are then used to annotate the flamegraph with function names. After the collected metrics are written to `$OUTPUT_PATH`, these flamegraphs can be generated if you have [inferno-flamegraph](https://crates.io/crates/inferno) installed. Install via @@ -181,19 +181,19 @@ For execution benchmarks, the ELF files need to be compiled before running the b ```bash # Build all benchmark ELFs -cargo run --package openvm-benchmarks-utils --bin build-elfs --features build-binaries +cargo run --package openvm-benchmarks-utils --bin build-elfs --features build-elfs # Build specific benchmark ELFs -cargo run --package openvm-benchmarks-utils --bin build-elfs --features build-binaries -- fibonacci_recursive fibonacci_iterative +cargo run --package openvm-benchmarks-utils --bin build-elfs --features build-elfs -- fibonacci_recursive fibonacci_iterative # Skip specific programs -cargo run --package openvm-benchmarks-utils --bin build-elfs --features build-binaries -- --skip keccak256 sha256 +cargo run --package openvm-benchmarks-utils --bin build-elfs --features build-elfs -- --skip keccak256 sha256 # Force rebuild even if ELFs already exist (overwrite) -cargo run --package openvm-benchmarks-utils --bin build-elfs --features build-binaries -- --force +cargo run --package openvm-benchmarks-utils --bin build-elfs --features build-elfs -- --force # Set build profile (debug or release) -cargo run --package openvm-benchmarks-utils --bin build-elfs --features build-binaries -- --profile debug +cargo run --package openvm-benchmarks-utils --bin build-elfs --features build-elfs -- --profile debug ``` ## Profiling Execution diff --git a/docs/crates/metrics.md b/docs/crates/metrics.md index 362bce47e0..6fe3072add 100644 --- a/docs/crates/metrics.md +++ b/docs/crates/metrics.md @@ -2,21 +2,25 @@ We use the [`metrics`](https://docs.rs/metrics/latest/metrics/) crate to collect metrics for the STARK prover. We refer to [reth docs](https://github.com/paradigmxyz/reth/blob/main/docs/design/metrics.md) for more guidelines on how to use metrics. -Metrics will only be collected if the `bench-metrics` feature is enabled. +Metrics will only be collected if the `metrics` feature is enabled. We describe the metrics that are collected for a single VM circuit proof, which corresponds to a single execution segment. To scope metrics from different proofs, we use the [`metrics_tracing_context`](https://docs.rs/metrics-tracing-context/latest/metrics_tracing_context/) crate to provide context-dependent labels. With the exception of the `segment` label, all other labels must be set by the caller. -For a single segment proof, the following metrics are collected: +For a segment proof, the following metrics are collected: -- `execute_time_ms` (gauge): The runtime execution time of the segment in milliseconds. +- `execute_metered_time_ms` (gauge): The metered execution time of the segment in milliseconds. This is timed across **all** segments in the group. +- `execute_preflight_time_ms` (gauge): The preflight execution time of the segment in milliseconds. - If this is a segment in a VM with continuations enabled, a `segment: segment_idx` label is added to the metric. + - `memory_finalize_time_ms` (gauge): The time at the end of preflight execution spent on memory finalization. - `trace_gen_time_ms` (gauge): The time to generate non-cached trace matrices from execution records. - If this is a segment in a VM with continuations enabled, a `segment: segment_idx` label is added to the metric. - All metrics collected by [`openvm-stark-backend`](https://github.com/openvm-org/stark-backend/blob/main/docs/metrics.md), in particular `stark_prove_excluding_trace_time_ms` (gauge). - - The total proving time of the proof is the sum of `execute_time_ms + trace_gen_time_ms + stark_prove_excluding_trace_time_ms`. -- `total_cycles` (counter): The total number of cycles in the segment. +- The `total_proof_time_ms` of the proof is instrumented directly when possible. Otherwise, it is calculated as: + - The sum `execute_preflight_time_ms + trace_gen_time_ms + stark_prove_excluding_trace_time_ms`. The `execute_metered_time_ms` is excluded for app proofs because it is not run on a per-segment basis. +- `insns` (counter): The total number of instructions executed in the segment. - `main_cells_used` (counter): The total number of main trace cells used by all chips in the segment. This does not include cells needed to pad rows to power-of-two matrix heights. Only main trace cells, not preprocessed or permutation trace cells, are counted. +- `total_cells_used` (counter): The total number of preprocessed, main, and permutation trace cells used by all chips in the segment. This does not include cells needed to pad rows to power-of-two matrix heights. ## Scoping diff --git a/docs/crates/vm-extensions.md b/docs/crates/vm-extensions.md index 490bac08c5..a27f56fca3 100644 --- a/docs/crates/vm-extensions.md +++ b/docs/crates/vm-extensions.md @@ -2,7 +2,7 @@ ```rust pub trait VmExtension { - type Executor: InstructionExecutor + AnyEnum; + type Executor: PreflightExecutor + AnyEnum; type Periphery: AnyEnum; fn build( @@ -17,7 +17,7 @@ by them. This data is collected into a `VmInventory` struct, which is returned. To handle previous chip dependencies necessary for chip construction and also automatic bus index management, we provide a `VmInventoryBuilder` api. -Due to strong types, we have **two** associated trait types `Executor, Periphery`. It is expected that `Executor` is an enum of all types implementing `InstructionExecutor + Chip` that this extension will construct. It is expected that `Periphery` is an enum of all types that implement `Chip` **but are not InstructionExecutor**. In general, it is always OK for the enum to have more kinds than necessary. For easy downcasting and enum wrangling, we also have an `AnyEnum` trait, which can always be derived by a macro. +Due to strong types, we have **two** associated trait types `Executor, Periphery`. It is expected that `Executor` is an enum of all types implementing `PreflightExecutor + Chip` that this extension will construct. It is expected that `Periphery` is an enum of all types that implement `Chip` **but are not PreflightExecutor**. In general, it is always OK for the enum to have more kinds than necessary. For easy downcasting and enum wrangling, we also have an `AnyEnum` trait, which can always be derived by a macro. ### `VmInventory` @@ -90,7 +90,7 @@ We have trait `VmConfig`: ```rust pub trait VmConfig { - type Executor: InstructionExecutor + AnyEnum + ChipUsageGetter; + type Executor: PreflightExecutor + AnyEnum + ChipUsageGetter; type Periphery: AnyEnum + ChipUsageGetter; /// Must contain system config diff --git a/docs/crates/vm.md b/docs/crates/vm.md index 989e9c8a88..a814cf010b 100644 --- a/docs/crates/vm.md +++ b/docs/crates/vm.md @@ -1,12 +1,12 @@ # VM Architecture and Chips -### `InstructionExecutor` Trait +### `PreflightExecutor` Trait We define an **instruction** to be an **opcode** combined with the **operands** for the opcode. Running the instrumented runtime for an opcode is encapsulated in the following trait: ```rust -pub trait InstructionExecutor { +pub trait PreflightExecutor { /// Runtime execution of the instruction, if the instruction is owned by the /// current instance. May internally store records of this call for later trace generation. fn execute( @@ -26,14 +26,14 @@ Opcodes are partitioned into groups, each of which is handled by a single **chip type `C` and associated Air of type `A` which satisfy the following trait bounds: ```rust -C: Chip + InstructionExecutor +C: Chip + PreflightExecutor A: Air + BaseAir + BaseAirWithPublicValues ``` Together, these provide the following functionalities: - **Keygen:** Performed via the `Air::::eval()` function. -- **Trace Generation:** This is done by calling `InstructionExecutor::::execute()` which computes and stores +- **Trace Generation:** This is done by calling `PreflightExecutor::::execute()` which computes and stores execution records and then `Chip::::generate_air_proof_input()` which generates the trace using the corresponding records. @@ -59,6 +59,7 @@ pub trait PhantomSubExecutor { &mut self, memory: &MemoryController, streams: &mut Streams, + rng: &mut StdRng, discriminant: PhantomDiscriminant, a: F, b: F, @@ -88,7 +89,7 @@ The engine type `E` should be `openvm_stark_backend::engine::StarkEngine `an ```rust pub trait VmConfig: Clone + Serialize + DeserializeOwned { - type Executor: InstructionExecutor + AnyEnum + ChipUsageGetter; + type Executor: PreflightExecutor + AnyEnum + ChipUsageGetter; type Periphery: AnyEnum + ChipUsageGetter; /// Must contain system config @@ -318,7 +319,7 @@ pub struct VmAirWrapper { They implement the following traits: -- `InstructionExecutor` is implemented on `VmChipWrapper`, where the `execute()` function: +- `PreflightExecutor` is implemented on `VmChipWrapper`, where the `execute()` function: - calls `preprocess()` on `A` with `memory` and the raw `instruction` - calls `execute_instruction()` on `C` with the raw `instruction`, `from_pc`, and `reads` from `preprocess()` - calls `postprocess()` on `A` with the raw `instruction`, `from_state`, the `output: AdapterRuntimeContext` from `execute_instruction()`, and the `read_record` diff --git a/docs/specs/ISA.md b/docs/specs/ISA.md index 2755243cdf..fc6b242e39 100644 --- a/docs/specs/ISA.md +++ b/docs/specs/ISA.md @@ -31,8 +31,8 @@ OpenVM depends on the following parameters, some of which are fixed and some of | `PC_BITS` | The number of bits in the program counter. | Fixed to 30. | | `DEFAULT_PC_STEP` | The default program counter step size. | Fixed to 4. | | `LIMB_BITS` | The number of bits in a limb for RISC-V memory emulation. | Fixed to 8. | -| `as_offset` | The index of the first writable address space. | Fixed to 1. | -| `as_height` | The base 2 log of the number of writable address spaces supported. | Configurable, must satisfy `as_height <= F::bits() - 2` | +| `ADDR_SPACE_OFFSET` | The index of the first writable address space. | Fixed to 1. | +| `addr_space_height` | The base 2 log of the number of writable address spaces supported. | Configurable, must satisfy `addr_space_height <= F::bits() - 2` | | `pointer_max_bits` | The maximum number of bits in a pointer. | Configurable, must satisfy `pointer_max_bits <= F::bits() - 2` | | `num_public_values` | The number of user public values. | Configurable. If continuation is enabled, it must equal `8` times a power of two(which is nonzero). | @@ -113,12 +113,12 @@ Data memory is a random access memory (RAM) which supports read and write operat cells which represent a single field element indexed by **address space** and **pointer**. The number of supported address spaces and the size of each address space are configurable constants. -- Valid address spaces not used for immediates lie in `[1, 1 + 2^as_height)` for configuration constant `as_height`. +- Valid address spaces not used for immediates lie in `[1, 1 + 2^addr_space_height)` for configuration constant `addr_space_height`. - Valid pointers are field elements that lie in `[0, 2^pointer_max_bits)`, for configuration constant `pointer_max_bits`. When accessing an address out of `[0, 2^pointer_max_bits)`, the VM should panic. - -These configuration constants must satisfy `as_height, pointer_max_bits <= F::bits() - 2`. We use the following notation +- For the register address space (address space `1`), valid pointers lie in `[0, 128)`, corresponding to 32 registers with 4 byte limbs each. +These configuration constants must satisfy `addr_space_height, pointer_max_bits <= F::bits() - 2`. We use the following notation to denote cells in memory: - `[a]_d` denotes the single-cell value at pointer location `a` in address space `d`. This is a single @@ -133,9 +133,15 @@ to. Address space `0` is considered a read-only array with `[a]_0 = a` for any ` #### Memory Accesses and Block Accesses VM instructions can access (read or write) a contiguous list of cells (called a **block**) in a single address space. -The block size must be in the set `{1, 2, 4, 8, 16, 32}`, and the access does not need to be aligned, meaning that -it can start from any pointer address, even those not divisible by the block size. An access is called a **block access -** if it has size greater than 1. Block accesses are not supported for address space `0`. +The block size must be in the set `{1, 2, 4, 8, 16, 32}`, and each address space has a minimum block size that is +configurable. All block accesses must be at pointers that are a multiple of the minimum block size. For address +spaces `1`, `2`, and `3`, the minimum block size is 4, meaning all accesses must be at pointer addresses that are +divisible by 4. However, RISC-V instructions like `lb`, `lh`, `sb`, and `sh` still work despite having minimum +block size requirements equal to the size of the access (1 byte for `lb`/`sb`, 2 bytes for `lh`/`sh`) because these +instructions are implemented by doing a block access of size 4. For the native address space (`4`), the minimum +block size is 1, so accesses can start from any pointer address. For address spaces beyond `4`, the minimum +block size defaults to 1 but can be configured. +Block accesses are not supported for address space `0`. #### Address Spaces @@ -171,7 +177,7 @@ structures during runtime execution: - `hint_space`: a vector of vectors of field elements used to store hints during runtime execution via [phantom sub-instructions](#phantom-sub-instructions) such as `NativeHintLoad`. The outer `hint_space` vector is append-only, but each internal `hint_space[hint_id]` vector may be mutated, including deletions, by the host. -- `kv_store`: a read-only key-value store for hints. Executors(e.g. `Rv32HintLoadByKey`) can read data from `kv_store` +- `kv_store`: a read-only key-value store for hints. Executors(e.g. `Rv32HintLoadByKey`) can read data from `kv_store` at runtime. `kv_store` is designed for general purposes so both key and value are byte arrays. Encoding of key/value are decided by each executor. Users need to use the corresponding encoding when adding data to `kv_store`. @@ -327,7 +333,7 @@ unsigned integer, and convert to field element. In the instructions below, `[c:4 #### Load/Store For all load/store instructions, we assume the operand `c` is in `[0, 2^16)`, and we fix address spaces `d = 1`. -The address space `e` can be `0`, `1`, or `2` for load instructions, and `2`, `3`, or `4` for store instructions. +The address space `e` is `2` for load instructions, and can be `2`, `3`, or `4` for store instructions. The operand `g` must be a boolean. We let `sign_extend(decompose(c)[0:2], g)` denote the `i32` defined by first taking the unsigned integer encoding of `c` as 16 bits, then sign extending it to 32 bits using the sign bit `g`, and considering the 32 bits as the 2's complement of an `i32`. @@ -454,7 +460,7 @@ reads but not allowed for writes. When using immediates, we interpret `[a]_0` as | STOREW | `a,b,c,4,4` | Set `[[c]_4 + b]_4 = [a]_4`. | | LOADW4 | `a,b,c,4,4` | Set `[a:4]_4 = [[c]_4 + b:4]_4`. | | STOREW4 | `a,b,c,4,4` | Set `[[c]_4 + b:4]_4 = [a:4]_4`. | -| JAL | `a,b,_,4` | Jump to address and link: set `[a]_4 = (pc + DEFAULT_PC_STEP)` and `pc = pc + b`. | +| JAL | `a,b,_,4` | Jump to address and link: set `[a]_4 = (pc + DEFAU````LT````_PC_STEP)` and `pc = pc + b`. | | RANGE_CHECK | `a,b,c,4` | Assert that `[a]_4 = x + y * 2^16` for some `x < 2^b` and `y < 2^c`. `b` must be in [0,16] and `c` must be in [0, 14]. | | BEQ | `a,b,c,d,e` | If `[a]_d == [b]_e`, then set `pc = pc + c`. | | BNE | `a,b,c,d,e` | If `[a]_d != [b]_e`, then set `pc = pc + c`. | @@ -464,7 +470,7 @@ reads but not allowed for writes. When using immediates, we interpret `[a]_0` as #### Field Arithmetic -This instruction set does native field operations. Below, `e,f` may be any address space. +This instruction set does native field operations. Below, `e,f` must be either `0` or `4`. When either `e` or `f` is zero, `[b]_0` and `[c]_0` should be interpreted as the immediates `b` and `c`, respectively. @@ -677,12 +683,13 @@ r32_fp2(a) -> Fp2 { ### Elliptic Curve Extension -The elliptic curve extension supports arithmetic over elliptic curves `C` in Weierstrass form given by -equation `C: y^2 = x^3 + C::A * x + C::B` where `C::A` and `C::B` are constants in the coordinate field. We note that -the definitions of the -curve arithmetic operations do not depend on `C::B`. The VM configuration will specify a list of supported curves. For -each Weierstrass curve `C` there will be associated configuration parameters `C::COORD_SIZE` and `C::BLOCK_SIZE` ( -defined below). The extension operates on address spaces `1` and `2`, meaning all memory cells are constrained to be +The elliptic curve extension supports arithmetic over elliptic curves `C` in the following forms: +- in short Weierstrass form given by equation `C: y^2 = x^3 + C::A * x + C::B` where `C::A` and `C::B` are constants in the coordinate field +- in twisted Edwards form given by equation `C: C::A * x^2 + y^2 = 1 + C::D * x^2 * y^2` where `C::A` and `C::D` are constants in the coordinate field + +We note that +the definitions of the curve arithmetic operations for short Weierstrass curves do not depend on `C::B`. The VM configuration will specify a list of supported curves. For +each curve `C` (of either form) there will be associated configuration parameters `C::COORD_SIZE` and `C::BLOCK_SIZE` (defined below). The extension operates on address spaces `1` and `2`, meaning all memory cells are constrained to be bytes. An affine curve point `EcPoint(x, y)` is a pair of `x,y` where each element is an array of `C::COORD_SIZE` elements each @@ -700,12 +707,16 @@ r32_ec_point(a) -> EcPoint { } ``` +The instructions that have prefix `SW_` perform short Weierstrass curve operations, and those with prefix `TE_` perform twisted Edwards curve operations. + | Name | Operands | Description | | -------------------- | ----------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| EC_ADD_NE\ | `a,b,c,1,2` | Set `r32_ec_point(a) = r32_ec_point(b) + r32_ec_point(c)` (curve addition). Assumes that `r32_ec_point(b), r32_ec_point(c)` both lie on the curve and are not the identity point. Further assumes that `r32_ec_point(b).x, r32_ec_point(c).x` are not equal in the coordinate field. | -| SETUP_EC_ADD_NE\ | `a,b,c,1,2` | `assert(r32_ec_point(b).x == C::MODULUS)` in the chip for EC ADD. For the sake of implementation convenience it also writes something (can be anything) into `[r32{0}(a): 2*C::COORD_SIZE]_2`. It is required for proper functionality that `assert(r32_ec_point(b).x != r32_ec_point(c).x)` | -| EC_DOUBLE\ | `a,b,_,1,2` | Set `r32_ec_point(a) = 2 * r32_ec_point(b)`. This doubles the input point. Assumes that `r32_ec_point(b)` lies on the curve and is not the identity point. | -| SETUP_EC_DOUBLE\ | `a,b,_,1,2` | `assert(r32_ec_point(b).x == C::MODULUS)` in the chip for EC DOUBLE. For the sake of implementation convenience it also writes something (can be anything) into `[r32{0}(a): 2*C::COORD_SIZE]_2`. It is required for proper functionality that `assert(r32_ec_point(b).y != 0 mod C::MODULUS)` | +| SW_ADD_NE\ | `a,b,c,1,2` | Set `r32_ec_point(a) = r32_ec_point(b) + r32_ec_point(c)` (curve addition). Assumes that `r32_ec_point(b), r32_ec_point(c)` both lie on the curve and are not the identity point. Further assumes that `r32_ec_point(b).x, r32_ec_point(c).x` are not equal in the coordinate field. | +| SETUP_SW_ADD_NE\ | `a,b,c,1,2` | `assert(r32_ec_point(b).x == C::MODULUS)` in the chip for SW ADD. For the sake of implementation convenience it also writes something (can be anything) into `[r32{0}(a): 2*C::COORD_SIZE]_2`. It is required for proper functionality that `assert(r32_ec_point(b).x != r32_ec_point(c).x)` | +| SW_DOUBLE\ | `a,b,_,1,2` | Set `r32_ec_point(a) = 2 * r32_ec_point(b)`. This doubles the input point. Assumes that `r32_ec_point(b)` lies on the curve and is not the identity point. | +| SETUP_SW_DOUBLE\ | `a,b,_,1,2` | `assert(r32_ec_point(b).x == C::MODULUS && r32_ec_point(b).y == C::A)` in the chip for SW DOUBLE. For the sake of implementation convenience it also writes something (can be anything) into `[r32{0}(a): 2*C::COORD_SIZE]_2`. It is required for proper functionality that `assert(r32_ec_point(b).y != 0 mod C::MODULUS)` | +| TE_ADD\ | `a,b,c,1,2` | Set `r32_ec_point(a) = r32_ec_point(b) + r32_ec_point(c)` (curve addition). Assumes that `r32_ec_point(b), r32_ec_point(c)` both lie on the curve. | +| SETUP_TE_ADD\ | `a,b,c,1,2` | `assert(r32_ec_point(b).x == C::MODULUS && r32_ec_point(b).y == C::A && r32_ec_point(c).x == C::D)` in the chip for TE ADD. For the sake of implementation convenience it also writes something (can be anything) into `[r32{0}(a): 2*C::COORD_SIZE]_2`. | ### Pairing Extension diff --git a/docs/specs/RISCV.md b/docs/specs/RISCV.md index 32d0cc63fa..31f716e3c0 100644 --- a/docs/specs/RISCV.md +++ b/docs/specs/RISCV.md @@ -76,7 +76,7 @@ the guest must take care to validate all data and account for behavior in cases |--------------|-----|-------------|---------|--------|------------------------------------------------------------------------------------------------------------------------------| | nativestorew | R | 0001011 | 111 | 0x2 | Stores the 4-byte word `rs1` at address `rd` in native address space. The address `rd` must be aligned to a 4-byte boundary. | -`nativestorew` connects RV32 address space and native address space. We put it in RV32 extension because its +`nativestorew` connects RV32 address space and native address space. We put it in RV32 extension because its implementation is here. But we use `funct3 = 111` because the native extension has an available slot. ## Keccak Extension @@ -176,13 +176,16 @@ Complex extension field arithmetic over `Fp2` depends on `Fp` where `-1` is not ## Elliptic Curve Extension -The elliptic curve extension supports arithmetic over short Weierstrass curves, which requires specification of the elliptic curve `C`. The extension must be configured to support a fixed ordered list of supported curves. We use `config.curve_idx(C)` to denote the index of `C` in this list. In the list below, `idx` denotes `config.curve_idx(C)`. +The elliptic curve extension supports arithmetic over short Weierstrass curves and twisted Edwards curves, which requires specification of the elliptic curve `C`. The extension must be configured to support two fixed ordered lists of supported curves: one list of short Weierstrass curves and one list of twisted Edwards curves. Instructions prefixed with `sw_` are for short Weierstrass curves and instructions prefixed with `te_` are for twisted Edwards curves. We use `config.curve_idx(C)` to denote the index of `C` in the appropriate list. In the list below, `idx` denotes `config.curve_idx(C)`. | RISC-V Inst | FMT | opcode[6:0] | funct3 | funct7 | RISC-V description and notes | | --------------- | --- | ----------- | ------ | --------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | sw_add_ne\ | R | 0101011 | 001 | `idx*8` | `EcPoint([rd:2*C::COORD_SIZE]_2) = EcPoint([rs1:2*C::COORD_SIZE]_2) + EcPoint([rs2:2*C::COORD_SIZE]_2)`. Assumes that input affine points are not identity and do not have same x-coordinate. | | sw_double\ | R | 0101011 | 001 | `idx*8+1` | `EcPoint([rd:2*C::COORD_SIZE]_2) = 2 * EcPoint([rs1:2*C::COORD_SIZE]_2)`. Assumes that input affine point is not identity. `rs2` is unused and must be set to `x0`. | -| setup\ | R | 0101011 | 001 | `idx*8+2` | `assert([rs1: C::COORD_SIZE]_2 == C::MODULUS)` in the chip defined by the register index of `rs2`. For the sake of implementation convenience it also writes an unconstrained value into `[rd: 2*C::COORD_SIZE]_2`. If `ind(rs2) != 0`, then this instruction is setup for `sw_add_ne`. Otherwise it is setup for `sw_double`. When `ind(rs2) != 0` (add_ne), it is required for proper functionality that `[rs2: C::COORD_SIZE]_2 != [rs1: C::COORD_SIZE]_2`; otherwise (double), it is required that `[rs1 + C::COORD_SIZE: C::COORD_SIZE]_2 != C::Fp::ZERO` | +| sw_setup\ | R | 0101011 | 001 | `idx*8+2` | If `ind(rs2) != 0`, then this instruction is setup for `sw_add_ne`. Otherwise it is setup for `sw_double`. If setup for `sw_add_ne`, it checks `assert([rs1: C::COORD_SIZE]_2 == C::MODULUS)`, and if setup for `sw_double`, checks `assert([rs1: 2*C::COORD_SIZE]_2 == [C::MODULUS, CURVE_A])`. For the sake of implementation convenience it also writes an unconstrained value into `[rd: 2*C::COORD_SIZE]_2`. When `ind(rs2) != 0` (add_ne), it is required for proper functionality that `[rs2: C::COORD_SIZE]_2 != [rs1: C::COORD_SIZE]_2`; otherwise (double), it is required that `[rs1 + C::COORD_SIZE: C::COORD_SIZE]_2 != C::Fp::ZERO` | +| te_add\ | R | 0101011 | 100 | `idx*8` | `EcPoint([rd:2*C::COORD_SIZE]_2) = EcPoint([rs1:2*C::COORD_SIZE]_2) + EcPoint([rs2:2*C::COORD_SIZE]_2)`. | +| te_setup\ | R | 0101011 | 100 | `idx*8+1` | `assert([rs1: 2*C::COORD_SIZE]_2 == [C::MODULUS, C::CURVE_A] && [rs2: C::COORD_SIZE]_2 == C::CURVE_D])`. For the sake of implementation convenience it also writes an unconstrained value into `[rd: 2*C::COORD_SIZE]_2`. | + Since `funct7` is 7-bits, up to 16 curves can be supported simultaneously. We use `idx*8` to leave some room for future expansion. diff --git a/docs/specs/circuit.md b/docs/specs/circuit.md index 4238c7c27b..46412e9ebe 100644 --- a/docs/specs/circuit.md +++ b/docs/specs/circuit.md @@ -39,7 +39,7 @@ When the program is being run, in the simple scenario, the following happens at - This instruction executor returns the new execution state (and maybe reports that the execution is finished). The timestamp and program counter change accordingly. There are limitations to how many interactions/trace rows/etc. we can have in total; see [soundness criteria](https://github.com/openvm-org/stark-backend/blob/main/docs/Soundness_of_Interactions_via_LogUp.pdf). If executing the full program would lead us to overflowing these limits, the program needs to be executed in several segments. Then the process is slightly different: -- After executing an instruction, we may decide (based on `SegmentationStrategy`, which looks at the traces) to _segment_ our execution at this point. In this case, the execution will be split into several segments. +- After executing an instruction, we may decide to _segment_ our execution at this point. In this case, the execution will be split into several segments. - The timestamp resets on segmentation. - Each segment is going to be proven separately. Of course, this means that adjacent segments need to agree on some things (mainly memory state). See [Continuations](./continuations.md) for full details. diff --git a/docs/specs/continuations.md b/docs/specs/continuations.md index 87d83e5f19..233d2f3506 100644 --- a/docs/specs/continuations.md +++ b/docs/specs/continuations.md @@ -270,9 +270,9 @@ multiple accesses. Persistent memory requires three chips: the `PersistentBoundaryChip`, the `MemoryMerkleChip`, and a chip to assist in hashing, which is by default the `Poseidon2Chip`. To simplify the discussion, define constants `C` equal to the number -of field elements in a hash value, `L` where the addresses in an address space are $0..2^L$, `M` and `AS_OFFSET` where -the address spaces are `AS_OFFSET..AS_OFFSET + 2^M`, and `H = M + L - log2(C)`. `H` is the height of the Merkle tree in -the sense that the leaves are at distance `H` from the root. We define the following interactions: +of field elements in a hash value, `L` where the addresses in an address space are $0..2^L$, `M` and `ADDR_SPACE_OFFSET` +where the address spaces are `ADDR_SPACE_OFFSET..ADDR_SPACE_OFFSET + 2^M`, and `H = M + L - log2(C)`. `H` is the height +of the Merkle tree in the sense that the leaves are at distance `H` from the root. We define the following interactions: On the MERKLE_BUS, we have interactions of the form **(expand_direction: {-1, 0, 1}, height: F, labels: (F, F), hash: [F; C])**, where @@ -309,8 +309,8 @@ The `PersistentBoundaryChip` has rows of the form `(expand_direction, address_space, leaf_label, values, hash, timestamp)` and has the following interactions on the MERKLE_BUS: -- Send **(1, 0, (as - AS_OFFSET) \* 2^L, node\*label, hash_initial)** -- Receive **(-1, 0, (as - AS_OFFSET) \* 2^L, node_label, hash_final)** +- Send **(1, 0, (as - ADDR_SPACE_OFFSET) \* 2^L, node\*label, hash_initial)** +- Receive **(-1, 0, (as - ADDR_SPACE_OFFSET) \* 2^L, node_label, hash_final)** It receives `values` from the `MEMORY_BUS` and constrains `hash = compress(values, 0)` via the `POSEIDON2_DIRECT_BUS`. The aggregation program takes a variable number of consecutive segment proofs and consolidates them into a single proof diff --git a/docs/specs/memory.md b/docs/specs/memory.md index 3d8ea1dc1d..87a7ed1316 100644 --- a/docs/specs/memory.md +++ b/docs/specs/memory.md @@ -163,7 +163,7 @@ Both boundary chips perform, for every subsegment ever existed in our nice set, The following invariants **must** be maintained by the memory architecture: 1. In the MEMORY_BUS, the `timestamp` is always in range `[0, 2^timestamp_max_bits)` where `timestamp_max_bits <= F::bits() - 2` is a configuration constant. -2. In the MEMORY_BUS, the `address_space` is always in range `[0, 1 + 2^as_height)` where `as_height` is a configuration constant satisfying `as_height < F::bits() - 2`. (Our current implementation only supports `as_height` less than the max bits supported by the VariableRangeCheckerBus). +2. In the MEMORY_BUS, the `address_space` is always in range `[0, 1 + 2^addr_space_height)` where `addr_space_height` is a configuration constant satisfying `addr_space_height < F::bits() - 2`. (Our current implementation only supports `addr_space_height` less than the max bits supported by the VariableRangeCheckerBus). 3. In the MEMORY_BUS, the `pointer` is always in range `[0, 2^pointer_max_bits)` where `pointer_max_bits <= F::bits() - 2` is a configuration constant. Invariant 1 is guaranteed by [time goes forward](#time-goes-forward) under the [assumption](./circuit.md#instruction-executors) that the timestamp increase during instruction execution is bounded by the number of AIR interactions. diff --git a/docs/specs/transpiler.md b/docs/specs/transpiler.md index fded65b6d8..b2e585a1e8 100644 --- a/docs/specs/transpiler.md +++ b/docs/specs/transpiler.md @@ -205,7 +205,9 @@ Each VM extension's behavior is specified below. | --------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | | sw_add_ne\ | EC_ADD_NE_RV32\ `ind(rd), ind(rs1), ind(rs2), 1, 2` | | sw_double\ | EC_DOUBLE_RV32\ `ind(rd), ind(rs1), 0, 1, 2` | -| setup\ | SETUP_EC_ADD_NE_RV32\ `ind(rd), ind(rs1), ind(rs2), 1, 2` if `ind(rs2) != 0`, SETUP_EC_DOUBLE_RV32\ `ind(rd), ind(rs1), ind(rs2), 1, 2` if `ind(rs2) = 0` | +| sw_setup\ | SETUP_EC_ADD_NE_RV32\ `ind(rd), ind(rs1), ind(rs2), 1, 2` if `ind(rs2) != 0`, SETUP_EC_DOUBLE_RV32\ `ind(rd), ind(rs1), ind(rs2), 1, 2` if `ind(rs2) = 0` | +| te_add\ | TE_ADD_RV32\ `ind(rd), ind(rs1), ind(rs2), 1, 2` | +| te_setup\ | SETUP_TE_ADD_RV32\ `ind(rd), ind(rs1), ind(rs2), 1, 2` | ### Pairing Extension diff --git a/examples/algebra/Cargo.toml b/examples/algebra/Cargo.toml index 9d369431c8..f9a0350c9f 100644 --- a/examples/algebra/Cargo.toml +++ b/examples/algebra/Cargo.toml @@ -17,3 +17,8 @@ num-bigint = { version = "0.4.6", features = ["serde"] } [features] default = [] + +# remove this if copying example outside of monorepo +[patch."https://github.com/openvm-org/openvm.git"] +openvm = { path = "../../crates/toolchain/openvm" } +openvm-algebra-guest = { path = "../../extensions/algebra/guest" } diff --git a/examples/algebra/openvm/app.vmexe b/examples/algebra/openvm/app.vmexe new file mode 100644 index 0000000000..801ce82638 Binary files /dev/null and b/examples/algebra/openvm/app.vmexe differ diff --git a/examples/algebra/openvm_init.rs b/examples/algebra/openvm_init.rs index 1de98b97a1..75324d5509 100644 --- a/examples/algebra/openvm_init.rs +++ b/examples/algebra/openvm_init.rs @@ -1,3 +1,3 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "998244353", "1000000007" } -openvm_algebra_guest::complex_macros::complex_init! { Complex1 { mod_idx = 0 }, Complex2 { mod_idx = 1 } } +openvm_algebra_guest::complex_macros::complex_init! { "Complex1" { mod_idx = 0 }, "Complex2" { mod_idx = 1 } } diff --git a/examples/algebra/src/main.rs b/examples/algebra/src/main.rs index e2265a2c97..3a45d1c463 100644 --- a/examples/algebra/src/main.rs +++ b/examples/algebra/src/main.rs @@ -28,7 +28,7 @@ moduli_init! { // The order of these structs does not matter, // given that we specify the `mod_idx` parameters properly. openvm_algebra_complex_macros::complex_init! { - Complex1 { mod_idx = 0 }, Complex2 { mod_idx = 1 }, + "Complex1" { mod_idx = 0 }, "Complex2" { mod_idx = 1 }, } */ diff --git a/examples/ecc/Cargo.toml b/examples/ecc/Cargo.toml index 3e0dcdbcfc..0206cba6a0 100644 --- a/examples/ecc/Cargo.toml +++ b/examples/ecc/Cargo.toml @@ -11,9 +11,11 @@ openvm = { git = "https://github.com/openvm-org/openvm.git", features = [ "std", ] } openvm-algebra-guest = { git = "https://github.com/openvm-org/openvm.git" } -openvm-ecc-guest = { git = "https://github.com/openvm-org/openvm.git" } +openvm-ecc-guest = { git = "https://github.com/openvm-org/openvm.git", features = ["ed25519"]} openvm-k256 = { git = "https://github.com/openvm-org/openvm.git", package = "k256" } hex-literal = { version = "0.4.1", default-features = false } +serde = { version = "1.0", default-features = false, features = [ "derive" ] } +num-bigint = { version = "0.4.6", default-features = false } [features] default = [] diff --git a/examples/ecc/openvm.toml b/examples/ecc/openvm.toml index 1dc6cf25f2..db8e420efc 100644 --- a/examples/ecc/openvm.toml +++ b/examples/ecc/openvm.toml @@ -2,11 +2,22 @@ [app_vm_config.rv32m] [app_vm_config.io] [app_vm_config.modular] -supported_moduli = ["115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337"] +supported_moduli = ["115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337", "57896044618658097711785492504343953926634992332820282019728792003956564819949"] -[[app_vm_config.ecc.supported_curves]] +[[app_vm_config.ecc.supported_sw_curves]] struct_name = "Secp256k1Point" modulus = "115792089237316195423570985008687907853269984665640564039457584007908834671663" scalar = "115792089237316195423570985008687907852837564279074904382605163141518161494337" + +[app_vm_config.ecc.supported_sw_curves.coeffs] a = "0" -b = "7" \ No newline at end of file +b = "7" + +[[app_vm_config.ecc.supported_te_curves]] +struct_name = "Ed25519Point" +modulus = "57896044618658097711785492504343953926634992332820282019728792003956564819949" +scalar = "7237005577332262213973186563042994240857116359379907606001950938285454250989" + +[app_vm_config.ecc.supported_te_curves.coeffs] +a = "57896044618658097711785492504343953926634992332820282019728792003956564819948" +d = "37095705934669439343138083508754565189542113879843219016388785533085940283555" diff --git a/examples/ecc/openvm/app.vmexe b/examples/ecc/openvm/app.vmexe new file mode 100644 index 0000000000..910f3a4efd Binary files /dev/null and b/examples/ecc/openvm/app.vmexe differ diff --git a/examples/ecc/openvm_init.rs b/examples/ecc/openvm_init.rs index bec9f527e9..a2ffd7cabd 100644 --- a/examples/ecc/openvm_init.rs +++ b/examples/ecc/openvm_init.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. -openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337" } -openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point } +openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337", "57896044618658097711785492504343953926634992332820282019728792003956564819949" } +openvm_ecc_guest::sw_macros::sw_init! { "Secp256k1Point" } +openvm_ecc_guest::te_macros::te_init! { "Ed25519Point" } diff --git a/examples/ecc/src/main.rs b/examples/ecc/src/main.rs index f95b6272ad..7e0f0817e0 100644 --- a/examples/ecc/src/main.rs +++ b/examples/ecc/src/main.rs @@ -1,7 +1,11 @@ // ANCHOR: imports use hex_literal::hex; -use openvm_algebra_guest::IntMod; -use openvm_ecc_guest::weierstrass::WeierstrassPoint; +use openvm_ecc_guest::{ + algebra::IntMod, + ed25519::{Ed25519Coord, Ed25519Point}, + edwards::TwistedEdwardsPoint, + weierstrass::WeierstrassPoint, +}; use openvm_k256::{Secp256k1Coord, Secp256k1Point}; // ANCHOR_END: imports @@ -9,13 +13,11 @@ use openvm_k256::{Secp256k1Coord, Secp256k1Point}; openvm::init!(); /* The init! macro will expand to the following openvm_algebra_guest::moduli_macros::moduli_init! { - "0xFFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE FFFFFC2F", - "0xFFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE BAAEDCE6 AF48A03B BFD25E8C D0364141" -} - -openvm_ecc_guest::sw_macros::sw_init! { - Secp256k1Point, +"115792089237316195423570985008687907853269984665640564039457584007908834671663", +"115792089237316195423570985008687907852837564279074904382605163141518161494337" } +openvm_ecc_guest::sw_macros::sw_init! { "Secp256k1Point" } +openvm_ecc_guest::te_macros::te_init! { "Ed25519Point" } */ // ANCHOR_END: init @@ -35,5 +37,22 @@ pub fn main() { #[allow(clippy::op_ref)] let _p3 = &p1 + &p2; + + let x1 = Ed25519Coord::from_be_bytes_unchecked(&hex!( + "216936D3CD6E53FEC0A4E231FDD6DC5C692CC7609525A7B2C9562D608F25D51A" + )); + let y1 = Ed25519Coord::from_be_bytes_unchecked(&hex!( + "6666666666666666666666666666666666666666666666666666666666666658" + )); + let p1 = Ed25519Point::from_xy(x1, y1).unwrap(); + + let x2 = Ed25519Coord::from_u32(2); + let y2 = Ed25519Coord::from_be_bytes_unchecked(&hex!( + "1A43BF127BDDC4D71FF910403C11DDB5BA2BCDD2815393924657EF111E712631" + )); + let p2 = Ed25519Point::from_xy(x2, y2).unwrap(); + + #[allow(clippy::op_ref)] + let _p3 = &p1 + &p2; } // ANCHOR_END: main diff --git a/examples/i256/openvm/app.vmexe b/examples/i256/openvm/app.vmexe new file mode 100644 index 0000000000..e45a699ef3 Binary files /dev/null and b/examples/i256/openvm/app.vmexe differ diff --git a/examples/i256/src/main.rs b/examples/i256/src/main.rs index 8f008f40a0..ec911bc1cd 100644 --- a/examples/i256/src/main.rs +++ b/examples/i256/src/main.rs @@ -1,4 +1,6 @@ #![allow(clippy::needless_range_loop)] +openvm::entry!(main); + use core::array; use alloy_primitives::I256; diff --git a/examples/keccak/Cargo.toml b/examples/keccak/Cargo.toml index 74f15e6234..3c5cd8a26e 100644 --- a/examples/keccak/Cargo.toml +++ b/examples/keccak/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" members = [] [dependencies] -openvm = { git = "https://github.com/openvm-org/openvm.git", features = [ +openvm = { git = "https://github.com/openvm-org/openvm.git", branch = "develop", features = [ "std", ] } openvm-keccak256 = { git = "https://github.com/openvm-org/openvm.git" } diff --git a/examples/keccak/src/main.rs b/examples/keccak/src/main.rs index 7b98d36ed1..0d138d5694 100644 --- a/examples/keccak/src/main.rs +++ b/examples/keccak/src/main.rs @@ -1,3 +1,5 @@ +openvm::entry!(main); + // ANCHOR: imports use core::hint::black_box; diff --git a/examples/pairing/openvm_init.rs b/examples/pairing/openvm_init.rs index 991d1237fc..f25d37c998 100644 --- a/examples/pairing/openvm_init.rs +++ b/examples/pairing/openvm_init.rs @@ -1,3 +1,3 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787" } -openvm_algebra_guest::complex_macros::complex_init! { Bls12_381Fp2 { mod_idx = 0 } } +openvm_algebra_guest::complex_macros::complex_init! { "Bls12_381Fp2" { mod_idx = 0 } } diff --git a/examples/pairing/src/main.rs b/examples/pairing/src/main.rs index 1a85b2e271..681527ca99 100644 --- a/examples/pairing/src/main.rs +++ b/examples/pairing/src/main.rs @@ -19,7 +19,7 @@ openvm_algebra_moduli_macros::moduli_init! { } openvm_algebra_complex_macros::complex_init! { - Bls12_381Fp2 { mod_idx = 0 }, + "Bls12_381Fp2" { mod_idx = 0 }, } */ // ANCHOR_END: init diff --git a/examples/sha256/src/main.rs b/examples/sha256/src/main.rs index a6195390a4..6389aaa1dc 100644 --- a/examples/sha256/src/main.rs +++ b/examples/sha256/src/main.rs @@ -1,3 +1,5 @@ +openvm::entry!(main); + // ANCHOR: imports use core::hint::black_box; diff --git a/examples/u256/src/main.rs b/examples/u256/src/main.rs index 75b80afd3d..05319a2a17 100644 --- a/examples/u256/src/main.rs +++ b/examples/u256/src/main.rs @@ -1,4 +1,6 @@ #![allow(clippy::needless_range_loop)] +openvm::entry!(main); + use core::array; use openvm_ruint::aliases::U256; diff --git a/extensions/algebra/circuit/Cargo.toml b/extensions/algebra/circuit/Cargo.toml index 258bff450b..1df5dd0107 100644 --- a/extensions/algebra/circuit/Cargo.toml +++ b/extensions/algebra/circuit/Cargo.toml @@ -20,21 +20,24 @@ openvm-rv32im-circuit = { workspace = true } openvm-rv32-adapters = { workspace = true } openvm-algebra-transpiler = { workspace = true } +halo2curves-axiom = { workspace = true } itertools = { workspace = true } num-bigint = { workspace = true, features = ["serde"] } num-traits = { workspace = true } rand = { workspace = true } -derive_more = { workspace = true, features = ["from"] } +derive_more = { workspace = true, features = ["from", "deref", "deref_mut"] } strum = { workspace = true } derive-new = { workspace = true } serde.workspace = true serde_with = { workspace = true } -serde-big-array = { workspace = true } eyre = { workspace = true } [dev-dependencies] -halo2curves-axiom = { workspace = true } openvm-mod-circuit-builder = { workspace = true, features = ["test-utils"] } openvm-circuit = { workspace = true, features = ["test-utils"] } openvm-rv32-adapters = { workspace = true, features = ["test-utils"] } openvm-pairing-guest = { workspace = true, features = ["halo2curves"] } +test-case = { workspace = true } + +[package.metadata.cargo-shear] +ignored = ["derive_more"] diff --git a/extensions/algebra/circuit/src/config.rs b/extensions/algebra/circuit/src/config.rs index 5b43163b77..fbe580261d 100644 --- a/extensions/algebra/circuit/src/config.rs +++ b/extensions/algebra/circuit/src/config.rs @@ -1,15 +1,33 @@ +use std::result::Result; + use num_bigint::BigUint; -use openvm_circuit::arch::{InitFileGenerator, SystemConfig}; +use openvm_circuit::{ + arch::{ + AirInventory, ChipInventoryError, InitFileGenerator, MatrixRecordArena, SystemConfig, + VmBuilder, VmChipComplex, VmProverExtension, + }, + system::{SystemChipInventory, SystemCpuBuilder, SystemExecutor}, +}; use openvm_circuit_derive::VmConfig; -use openvm_rv32im_circuit::*; -use openvm_stark_backend::p3_field::PrimeField32; +use openvm_rv32im_circuit::{ + Rv32I, Rv32IExecutor, Rv32ImCpuProverExt, Rv32Io, Rv32IoExecutor, Rv32M, Rv32MExecutor, +}; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + p3_field::PrimeField32, + prover::cpu::{CpuBackend, CpuDevice}, +}; +use openvm_stark_sdk::engine::StarkEngine; use serde::{Deserialize, Serialize}; -use super::*; +use crate::{ + AlgebraCpuProverExt, Fp2Extension, Fp2ExtensionExecutor, ModularExtension, + ModularExtensionExecutor, +}; #[derive(Clone, Debug, VmConfig, Serialize, Deserialize)] pub struct Rv32ModularConfig { - #[system] + #[config(executor = "SystemExecutor")] pub system: SystemConfig, #[extension] pub base: Rv32I, @@ -44,16 +62,8 @@ impl Rv32ModularConfig { #[derive(Clone, Debug, VmConfig, Serialize, Deserialize)] pub struct Rv32ModularWithFp2Config { - #[system] - pub system: SystemConfig, - #[extension] - pub base: Rv32I, - #[extension] - pub mul: Rv32M, - #[extension] - pub io: Rv32Io, - #[extension] - pub modular: ModularExtension, + #[config(generics = true)] + pub modular: Rv32ModularConfig, #[extension] pub fp2: Fp2Extension, } @@ -65,11 +75,7 @@ impl Rv32ModularWithFp2Config { .map(|(_, modulus)| modulus.clone()) .collect(); Self { - system: SystemConfig::default().with_continuations(), - base: Default::default(), - mul: Default::default(), - io: Default::default(), - modular: ModularExtension::new(moduli), + modular: Rv32ModularConfig::new(moduli), fp2: Fp2Extension::new(moduli_with_names), } } @@ -79,8 +85,73 @@ impl InitFileGenerator for Rv32ModularWithFp2Config { fn generate_init_file_contents(&self) -> Option { Some(format!( "// This file is automatically generated by cargo openvm. Do not rename or edit.\n{}\n{}\n", - self.modular.generate_moduli_init(), - self.fp2.generate_complex_init(&self.modular) + self.modular.modular.generate_moduli_init(), + self.fp2.generate_complex_init(&self.modular.modular) )) } } + +#[derive(Clone)] +pub struct Rv32ModularCpuBuilder; + +impl VmBuilder for Rv32ModularCpuBuilder +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + Val: PrimeField32, +{ + type VmConfig = Rv32ModularConfig; + type SystemChipInventory = SystemChipInventory; + type RecordArena = MatrixRecordArena>; + + fn create_chip_complex( + &self, + config: &Rv32ModularConfig, + circuit: AirInventory, + ) -> Result< + VmChipComplex, + ChipInventoryError, + > { + let mut chip_complex = + VmBuilder::::create_chip_complex(&SystemCpuBuilder, &config.system, circuit)?; + let inventory = &mut chip_complex.inventory; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.base, inventory)?; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.mul, inventory)?; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.io, inventory)?; + VmProverExtension::::extend_prover( + &AlgebraCpuProverExt, + &config.modular, + inventory, + )?; + Ok(chip_complex) + } +} + +#[derive(Clone)] +pub struct Rv32ModularWithFp2CpuBuilder; + +impl VmBuilder for Rv32ModularWithFp2CpuBuilder +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + Val: PrimeField32, +{ + type VmConfig = Rv32ModularWithFp2Config; + type SystemChipInventory = SystemChipInventory; + type RecordArena = MatrixRecordArena>; + + fn create_chip_complex( + &self, + config: &Rv32ModularWithFp2Config, + circuit: AirInventory, + ) -> Result< + VmChipComplex, + ChipInventoryError, + > { + let mut chip_complex = + VmBuilder::::create_chip_complex(&Rv32ModularCpuBuilder, &config.modular, circuit)?; + let inventory = &mut chip_complex.inventory; + VmProverExtension::::extend_prover(&AlgebraCpuProverExt, &config.fp2, inventory)?; + Ok(chip_complex) + } +} diff --git a/extensions/algebra/circuit/src/fields.rs b/extensions/algebra/circuit/src/fields.rs new file mode 100644 index 0000000000..fdba871da7 --- /dev/null +++ b/extensions/algebra/circuit/src/fields.rs @@ -0,0 +1,376 @@ +use halo2curves_axiom::ff::PrimeField; +use num_bigint::BigUint; +use num_traits::Num; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FieldType { + K256Coordinate = 0, + K256Scalar = 1, + P256Coordinate = 2, + P256Scalar = 3, + BN254Coordinate = 4, + BN254Scalar = 5, + BLS12_381Coordinate = 6, + BLS12_381Scalar = 7, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Operation { + Add = 0, + Sub = 1, + Mul = 2, + Div = 3, +} + +fn get_modulus_as_bigint() -> BigUint { + BigUint::from_str_radix(F::MODULUS.trim_start_matches("0x"), 16).unwrap() +} + +pub fn get_field_type(modulus: &BigUint) -> Option { + if modulus == &get_modulus_as_bigint::() { + return Some(FieldType::K256Coordinate); + } + + if modulus == &get_modulus_as_bigint::() { + return Some(FieldType::K256Scalar); + } + + if modulus == &get_modulus_as_bigint::() { + return Some(FieldType::P256Coordinate); + } + + if modulus == &get_modulus_as_bigint::() { + return Some(FieldType::P256Scalar); + } + + if modulus == &get_modulus_as_bigint::() { + return Some(FieldType::BN254Coordinate); + } + + if modulus == &get_modulus_as_bigint::() { + return Some(FieldType::BN254Scalar); + } + + if modulus == &get_modulus_as_bigint::() { + return Some(FieldType::BLS12_381Coordinate); + } + + if modulus == &get_modulus_as_bigint::() { + return Some(FieldType::BLS12_381Scalar); + } + + None +} + +pub fn get_fp2_field_type(modulus: &BigUint) -> Option { + if modulus == &get_modulus_as_bigint::() { + return Some(FieldType::BN254Coordinate); + } + + if modulus == &get_modulus_as_bigint::() { + return Some(FieldType::BLS12_381Coordinate); + } + + None +} + +#[inline(always)] +pub fn field_operation< + const FIELD: u8, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const OP: u8, +>( + input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2], +) -> [[u8; BLOCK_SIZE]; BLOCKS] { + match FIELD { + x if x == FieldType::K256Coordinate as u8 => { + field_operation_256bit::( + input_data, + ) + } + x if x == FieldType::K256Scalar as u8 => { + field_operation_256bit::( + input_data, + ) + } + x if x == FieldType::P256Coordinate as u8 => { + field_operation_256bit::( + input_data, + ) + } + x if x == FieldType::P256Scalar as u8 => { + field_operation_256bit::( + input_data, + ) + } + x if x == FieldType::BN254Coordinate as u8 => { + field_operation_256bit::( + input_data, + ) + } + x if x == FieldType::BN254Scalar as u8 => { + field_operation_256bit::( + input_data, + ) + } + x if x == FieldType::BLS12_381Coordinate as u8 => { + field_operation_bls12_381_coordinate::(input_data) + } + x if x == FieldType::BLS12_381Scalar as u8 => { + field_operation_256bit::( + input_data, + ) + } + _ => panic!("Unsupported field type: {}", FIELD), + } +} + +#[inline(always)] +pub fn fp2_operation< + const FIELD: u8, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const OP: u8, +>( + input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2], +) -> [[u8; BLOCK_SIZE]; BLOCKS] { + match FIELD { + x if x == FieldType::BN254Coordinate as u8 => { + fp2_operation_bn254::(input_data) + } + x if x == FieldType::BLS12_381Coordinate as u8 => { + fp2_operation_bls12_381::(input_data) + } + _ => panic!("Unsupported field type for Fp2: {}", FIELD), + } +} + +#[inline(always)] +fn field_operation_256bit< + F: PrimeField, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const OP: u8, +>( + input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2], +) -> [[u8; BLOCK_SIZE]; BLOCKS] { + let a = blocks_to_field_element::(input_data[0].as_flattened()); + let b = blocks_to_field_element::(input_data[1].as_flattened()); + let c = match OP { + x if x == Operation::Add as u8 => a + b, + x if x == Operation::Sub as u8 => a - b, + x if x == Operation::Mul as u8 => a * b, + x if x == Operation::Div as u8 => a * b.invert().unwrap(), + _ => panic!("Unsupported operation: {}", OP), + }; + + let mut output = [[0u8; BLOCK_SIZE]; BLOCKS]; + field_element_to_blocks(&c, &mut output); + output +} + +#[inline(always)] +fn field_operation_bls12_381_coordinate< + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const OP: u8, +>( + input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2], +) -> [[u8; BLOCK_SIZE]; BLOCKS] { + let a = blocks_to_field_element_bls12_381_coordinate(input_data[0].as_flattened()); + let b = blocks_to_field_element_bls12_381_coordinate(input_data[1].as_flattened()); + let c = match OP { + x if x == Operation::Add as u8 => a + b, + x if x == Operation::Sub as u8 => a - b, + x if x == Operation::Mul as u8 => a * b, + x if x == Operation::Div as u8 => a * b.invert().unwrap(), + _ => panic!("Unsupported operation: {}", OP), + }; + + let mut output = [[0u8; BLOCK_SIZE]; BLOCKS]; + field_element_to_blocks_bls12_381_coordinate(&c, &mut output); + output +} + +#[inline(always)] +fn fp2_operation_bn254( + input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2], +) -> [[u8; BLOCK_SIZE]; BLOCKS] { + let a = blocks_to_fp2_bn254::(input_data[0].as_ref()); + let b = blocks_to_fp2_bn254::(input_data[1].as_ref()); + let c = match OP { + x if x == Operation::Add as u8 => a + b, + x if x == Operation::Sub as u8 => a - b, + x if x == Operation::Mul as u8 => a * b, + x if x == Operation::Div as u8 => a * b.invert().unwrap(), + _ => panic!("Unsupported operation: {}", OP), + }; + + let mut output = [[0u8; BLOCK_SIZE]; BLOCKS]; + fp2_to_blocks_bn254(&c, &mut output); + output +} + +#[inline(always)] +fn fp2_operation_bls12_381( + input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2], +) -> [[u8; BLOCK_SIZE]; BLOCKS] { + let a = blocks_to_fp2_bls12_381::(input_data[0].as_ref()); + let b = blocks_to_fp2_bls12_381::(input_data[1].as_ref()); + let c = match OP { + x if x == Operation::Add as u8 => a + b, + x if x == Operation::Sub as u8 => a - b, + x if x == Operation::Mul as u8 => a * b, + x if x == Operation::Div as u8 => a * b.invert().unwrap(), + _ => panic!("Unsupported operation: {}", OP), + }; + + let mut output = [[0u8; BLOCK_SIZE]; BLOCKS]; + fp2_to_blocks_bls12_381(&c, &mut output); + output +} + +#[inline(always)] +fn from_repr_with_reduction>(bytes: [u8; 32]) -> F { + F::from_repr_vartime(bytes).unwrap_or_else(|| { + // Reduce modulo the field's modulus for non-canonical representations + let modulus = get_modulus_as_bigint::(); + let value = BigUint::from_bytes_le(&bytes); + let reduced = value % modulus; + + let reduced_le_bytes = reduced.to_bytes_le(); + let mut reduced_bytes = [0u8; 32]; + reduced_bytes[..reduced_le_bytes.len()] + .copy_from_slice(&reduced_le_bytes[..reduced_le_bytes.len()]); + + F::from_repr_vartime(reduced_bytes).unwrap() + }) +} + +#[inline(always)] +fn from_repr_with_reduction_bls12_381_coordinate( + bytes: [u8; 48], +) -> halo2curves_axiom::bls12_381::Fq { + halo2curves_axiom::bls12_381::Fq::from_bytes(&bytes).unwrap_or_else(|| { + // Reduce modulo the field's modulus for non-canonical representations + let modulus = get_modulus_as_bigint::(); + let value = BigUint::from_bytes_le(&bytes); + let reduced = value % modulus; + + let reduced_le_bytes = reduced.to_bytes_le(); + let mut reduced_bytes = [0u8; 48]; + reduced_bytes[..reduced_le_bytes.len()] + .copy_from_slice(&reduced_le_bytes[..reduced_le_bytes.len()]); + + halo2curves_axiom::bls12_381::Fq::from_bytes(&reduced_bytes).unwrap() + }) +} + +#[inline(always)] +pub fn blocks_to_field_element>(blocks: &[u8]) -> F { + debug_assert!(blocks.len() == 32); + let mut bytes = [0u8; 32]; + bytes[..blocks.len()].copy_from_slice(&blocks[..blocks.len()]); + + from_repr_with_reduction::(bytes) +} + +#[inline(always)] +pub fn field_element_to_blocks, const BLOCK_SIZE: usize>( + field_element: &F, + output: &mut [[u8; BLOCK_SIZE]], +) { + debug_assert!(output.len() * BLOCK_SIZE == 32); + let bytes = field_element.to_repr(); + let mut byte_idx = 0; + + for block in output.iter_mut() { + for byte in block.iter_mut() { + *byte = if byte_idx < bytes.len() { + bytes[byte_idx] + } else { + 0 + }; + byte_idx += 1; + } + } +} + +#[inline(always)] +pub fn blocks_to_field_element_bls12_381_coordinate( + blocks: &[u8], +) -> halo2curves_axiom::bls12_381::Fq { + debug_assert!(blocks.len() == 48); + let mut bytes = [0u8; 48]; + bytes[..blocks.len()].copy_from_slice(&blocks[..blocks.len()]); + + from_repr_with_reduction_bls12_381_coordinate(bytes) +} + +#[inline(always)] +pub fn field_element_to_blocks_bls12_381_coordinate( + field_element: &halo2curves_axiom::bls12_381::Fq, + output: &mut [[u8; BLOCK_SIZE]], +) { + debug_assert!(output.len() * BLOCK_SIZE == 48); + let bytes = field_element.to_bytes(); + let mut byte_idx = 0; + + for block in output.iter_mut() { + for byte in block.iter_mut() { + *byte = if byte_idx < bytes.len() { + bytes[byte_idx] + } else { + 0 + }; + byte_idx += 1; + } + } +} + +#[inline(always)] +fn blocks_to_fp2_bn254( + blocks: &[[u8; BLOCK_SIZE]], +) -> halo2curves_axiom::bn256::Fq2 { + let c0 = blocks_to_field_element::( + blocks[..BLOCKS / 2].as_flattened(), + ); + let c1 = blocks_to_field_element::( + blocks[BLOCKS / 2..].as_flattened(), + ); + halo2curves_axiom::bn256::Fq2::new(c0, c1) +} + +#[inline(always)] +fn fp2_to_blocks_bn254( + fp2: &halo2curves_axiom::bn256::Fq2, + output: &mut [[u8; BLOCK_SIZE]; BLOCKS], +) { + field_element_to_blocks::( + &fp2.c0, + &mut output[..BLOCKS / 2], + ); + field_element_to_blocks::( + &fp2.c1, + &mut output[BLOCKS / 2..], + ); +} + +#[inline(always)] +fn blocks_to_fp2_bls12_381( + blocks: &[[u8; BLOCK_SIZE]], +) -> halo2curves_axiom::bls12_381::Fq2 { + let c0 = blocks_to_field_element_bls12_381_coordinate(blocks[..BLOCKS / 2].as_flattened()); + let c1 = blocks_to_field_element_bls12_381_coordinate(blocks[BLOCKS / 2..].as_flattened()); + halo2curves_axiom::bls12_381::Fq2 { c0, c1 } +} + +#[inline(always)] +fn fp2_to_blocks_bls12_381( + fp2: &halo2curves_axiom::bls12_381::Fq2, + output: &mut [[u8; BLOCK_SIZE]; BLOCKS], +) { + field_element_to_blocks_bls12_381_coordinate(&fp2.c0, &mut output[..BLOCKS / 2]); + field_element_to_blocks_bls12_381_coordinate(&fp2.c1, &mut output[BLOCKS / 2..]); +} diff --git a/extensions/algebra/circuit/src/fp2_chip/addsub.rs b/extensions/algebra/circuit/src/fp2_chip/addsub.rs index 4eca1ad102..4fd1510fc9 100644 --- a/extensions/algebra/circuit/src/fp2_chip/addsub.rs +++ b/extensions/algebra/circuit/src/fp2_chip/addsub.rs @@ -1,62 +1,25 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; +use std::{cell::RefCell, rc::Rc}; use openvm_algebra_transpiler::Fp2Opcode; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +use openvm_circuit::{ + arch::ExecutionBridge, + system::memory::{offline_checker::MemoryBridge, SharedMemoryHelper}, +}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_instructions::riscv::RV32_CELL_BITS; use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, + ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreAir, FieldExpressionExecutor, + FieldExpressionFiller, +}; +use openvm_rv32_adapters::{ + Rv32VecHeapAdapterAir, Rv32VecHeapAdapterExecutor, Rv32VecHeapAdapterFiller, }; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; -use crate::Fp2; - -// Input: Fp2 * 2 -// Output: Fp2 -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct Fp2AddSubChip( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl - Fp2AddSubChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - config: ExprBuilderConfig, - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let (expr, is_add_flag, is_sub_flag) = fp2_addsub_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![ - Fp2Opcode::ADD as usize, - Fp2Opcode::SUB as usize, - Fp2Opcode::SETUP_ADDSUB as usize, - ], - vec![is_add_flag, is_sub_flag], - range_checker, - "Fp2AddSub", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} +use super::{Fp2Air, Fp2Chip, Fp2Executor}; +use crate::{FieldExprVecHeapExecutor, Fp2}; pub fn fp2_addsub_expr( config: ExprBuilderConfig, @@ -85,15 +48,98 @@ pub fn fp2_addsub_expr( ) } +// Input: Fp2 * 2 +// Output: Fp2 +fn gen_base_expr( + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, +) -> (FieldExpr, Vec, Vec) { + let (expr, is_add_flag, is_sub_flag) = fp2_addsub_expr(config, range_checker_bus); + + let local_opcode_idx = vec![ + Fp2Opcode::ADD as usize, + Fp2Opcode::SUB as usize, + Fp2Opcode::SETUP_ADDSUB as usize, + ]; + let opcode_flag_idx = vec![is_add_flag, is_sub_flag]; + + (expr, local_opcode_idx, opcode_flag_idx) +} + +pub fn get_fp2_addsub_air( + exec_bridge: ExecutionBridge, + mem_bridge: MemoryBridge, + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, + bitwise_lookup_bus: BitwiseOperationLookupBus, + pointer_max_bits: usize, + offset: usize, +) -> Fp2Air { + let (expr, local_opcode_idx, opcode_flag_idx) = gen_base_expr(config, range_checker_bus); + Fp2Air::new( + Rv32VecHeapAdapterAir::new( + exec_bridge, + mem_bridge, + bitwise_lookup_bus, + pointer_max_bits, + ), + FieldExpressionCoreAir::new(expr, offset, local_opcode_idx, opcode_flag_idx), + ) +} + +pub fn get_fp2_addsub_step( + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, + pointer_max_bits: usize, + offset: usize, +) -> Fp2Executor { + let (expr, local_opcode_idx, opcode_flag_idx) = gen_base_expr(config, range_checker_bus); + + FieldExprVecHeapExecutor(FieldExpressionExecutor::new( + Rv32VecHeapAdapterExecutor::new(pointer_max_bits), + expr, + offset, + local_opcode_idx, + opcode_flag_idx, + "Fp2AddSub", + )) +} + +pub fn get_fp2_addsub_chip( + config: ExprBuilderConfig, + mem_helper: SharedMemoryHelper, + range_checker: SharedVariableRangeCheckerChip, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pointer_max_bits: usize, +) -> Fp2Chip { + let (expr, local_opcode_idx, opcode_flag_idx) = gen_base_expr(config, range_checker.bus()); + Fp2Chip::new( + FieldExpressionFiller::new( + Rv32VecHeapAdapterFiller::new(pointer_max_bits, bitwise_lookup_chip), + expr, + local_opcode_idx, + opcode_flag_idx, + range_checker, + false, + ), + mem_helper, + ) +} + #[cfg(test)] mod tests { + use std::sync::Arc; + use halo2curves_axiom::{bn256::Fq2, ff::Field}; use itertools::Itertools; + use num_bigint::BigUint; use openvm_algebra_transpiler::Fp2Opcode; - use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; + use openvm_circuit::arch::testing::{ + TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, + }; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, + BitwiseOperationLookupBus, BitwiseOperationLookupChip, }; use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; use openvm_mod_circuit_builder::{ @@ -101,63 +147,49 @@ mod tests { ExprBuilderConfig, }; use openvm_pairing_guest::bn254::BN254_MODULUS; - use openvm_rv32_adapters::{rv32_write_heap_default, Rv32VecHeapAdapterChip}; + use openvm_rv32_adapters::rv32_write_heap_default; use openvm_stark_backend::p3_field::FieldAlgebra; - use openvm_stark_sdk::p3_baby_bear::BabyBear; - use rand::{rngs::StdRng, SeedableRng}; + use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; - use super::Fp2AddSubChip; + use crate::fp2_chip::{ + get_fp2_addsub_air, get_fp2_addsub_chip, get_fp2_addsub_step, Fp2Air, Fp2Chip, Fp2Executor, + }; const NUM_LIMBS: usize = 32; const LIMB_BITS: usize = 8; + const MAX_INS_CAPACITY: usize = 128; + const OFFSET: usize = Fp2Opcode::CLASS_OFFSET; type F = BabyBear; + type Harness = TestChipHarness< + F, + Fp2Executor<2, NUM_LIMBS>, + Fp2Air<2, NUM_LIMBS>, + Fp2Chip, + >; - #[test] - fn test_fp2_addsub() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let modulus = BN254_MODULUS.clone(); - let config = ExprBuilderConfig { - modulus: modulus.clone(), - num_limbs: NUM_LIMBS, - limb_bits: LIMB_BITS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = Fp2AddSubChip::new( - adapter, - config, - Fp2Opcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - - let mut rng = StdRng::seed_from_u64(42); + fn set_and_execute_rand( + tester: &mut VmChipTestBuilder, + harness: &mut Harness, + modulus: &BigUint, + ) { + let mut rng = create_seeded_rng(); let x = Fq2::random(&mut rng); let y = Fq2::random(&mut rng); let inputs = [x.c0, x.c1, y.c0, y.c1].map(bn254_fq_to_biguint); let expected_sum = bn254_fq2_to_biguint_vec(x + y); - let r_sum = chip - .0 - .core - .expr() + let r_sum = harness + .executor + .expr .execute_with_output(inputs.to_vec(), vec![true, false]); assert_eq!(r_sum.len(), 2); assert_eq!(r_sum[0], expected_sum[0]); assert_eq!(r_sum[1], expected_sum[1]); let expected_sub = bn254_fq2_to_biguint_vec(x - y); - let r_sub = chip - .0 - .core - .expr() + let r_sub = harness + .executor + .expr .execute_with_output(inputs.to_vec(), vec![false, true]); assert_eq!(r_sub.len(), 2); assert_eq!(r_sub[0], expected_sub[0]); @@ -177,31 +209,76 @@ mod tests { .map(BabyBear::from_canonical_u32) }) .collect_vec(); - let modulus = - biguint_to_limbs::(modulus, LIMB_BITS).map(BabyBear::from_canonical_u32); + let modulus = biguint_to_limbs::(modulus.clone(), LIMB_BITS) + .map(BabyBear::from_canonical_u32); let zero = [BabyBear::ZERO; NUM_LIMBS]; let setup_instruction = rv32_write_heap_default( - &mut tester, + tester, vec![modulus, zero], vec![zero; 2], - chip.0.core.air.offset + Fp2Opcode::SETUP_ADDSUB as usize, + OFFSET + Fp2Opcode::SETUP_ADDSUB as usize, ); let instruction1 = rv32_write_heap_default( - &mut tester, + tester, x_limbs.clone(), y_limbs.clone(), - chip.0.core.air.offset + Fp2Opcode::ADD as usize, + OFFSET + Fp2Opcode::ADD as usize, + ); + let instruction2 = + rv32_write_heap_default(tester, x_limbs, y_limbs, OFFSET + Fp2Opcode::SUB as usize); + + tester.execute(harness, &setup_instruction); + tester.execute(harness, &instruction1); + tester.execute(harness, &instruction2); + } + + #[test] + fn test_fp2_addsub() { + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); + let modulus = BN254_MODULUS.clone(); + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: NUM_LIMBS, + limb_bits: LIMB_BITS, + }; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + + let air = get_fp2_addsub_air( + tester.execution_bridge(), + tester.memory_bridge(), + config.clone(), + tester.range_checker().bus(), + bitwise_bus, + tester.address_bits(), + OFFSET, + ); + let executor = get_fp2_addsub_step( + config.clone(), + tester.range_checker().bus(), + tester.address_bits(), + OFFSET, ); - let instruction2 = rv32_write_heap_default( - &mut tester, - x_limbs, - y_limbs, - chip.0.core.air.offset + Fp2Opcode::SUB as usize, + let chip = get_fp2_addsub_chip( + config, + tester.memory_helper(), + tester.range_checker(), + bitwise_chip.clone(), + tester.address_bits(), ); - tester.execute(&mut chip, &setup_instruction); - tester.execute(&mut chip, &instruction1); - tester.execute(&mut chip, &instruction2); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + let mut harness = Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + + let num_ops = 10; + for _ in 0..num_ops { + set_and_execute_rand(&mut tester, &mut harness, &modulus); + } + let tester = tester + .build() + .load(harness) + .load_periphery((bitwise_chip.air, bitwise_chip)) + .finalize(); tester.simple_test().expect("Verification failed"); } } diff --git a/extensions/algebra/circuit/src/fp2_chip/mod.rs b/extensions/algebra/circuit/src/fp2_chip/mod.rs index cd316fd70c..6a44aa96b3 100644 --- a/extensions/algebra/circuit/src/fp2_chip/mod.rs +++ b/extensions/algebra/circuit/src/fp2_chip/mod.rs @@ -1,5 +1,24 @@ +use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; +use openvm_mod_circuit_builder::{FieldExpressionCoreAir, FieldExpressionFiller}; +use openvm_rv32_adapters::{Rv32VecHeapAdapterAir, Rv32VecHeapAdapterFiller}; + +use crate::FieldExprVecHeapExecutor; + mod addsub; pub use addsub::*; mod muldiv; pub use muldiv::*; + +pub type Fp2Air = VmAirWrapper< + Rv32VecHeapAdapterAir<2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>, + FieldExpressionCoreAir, +>; + +pub type Fp2Executor = + FieldExprVecHeapExecutor; + +pub type Fp2Chip = VmChipWrapper< + F, + FieldExpressionFiller>, +>; diff --git a/extensions/algebra/circuit/src/fp2_chip/muldiv.rs b/extensions/algebra/circuit/src/fp2_chip/muldiv.rs index 83ef9565f3..5aed4941f4 100644 --- a/extensions/algebra/circuit/src/fp2_chip/muldiv.rs +++ b/extensions/algebra/circuit/src/fp2_chip/muldiv.rs @@ -1,62 +1,25 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; +use std::{cell::RefCell, rc::Rc}; use openvm_algebra_transpiler::Fp2Opcode; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +use openvm_circuit::{ + arch::ExecutionBridge, + system::memory::{offline_checker::MemoryBridge, SharedMemoryHelper}, +}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_instructions::riscv::RV32_CELL_BITS; use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, SymbolicExpr, + ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreAir, FieldExpressionExecutor, + FieldExpressionFiller, SymbolicExpr, +}; +use openvm_rv32_adapters::{ + Rv32VecHeapAdapterAir, Rv32VecHeapAdapterExecutor, Rv32VecHeapAdapterFiller, }; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; -use crate::Fp2; - -// Input: Fp2 * 2 -// Output: Fp2 -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct Fp2MulDivChip( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl - Fp2MulDivChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - config: ExprBuilderConfig, - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let (expr, is_mul_flag, is_div_flag) = fp2_muldiv_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![ - Fp2Opcode::MUL as usize, - Fp2Opcode::DIV as usize, - Fp2Opcode::SETUP_MULDIV as usize, - ], - vec![is_mul_flag, is_div_flag], - range_checker, - "Fp2MulDiv", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} +use super::{Fp2Air, Fp2Chip, Fp2Executor}; +use crate::{FieldExprVecHeapExecutor, Fp2}; pub fn fp2_muldiv_expr( config: ExprBuilderConfig, @@ -124,15 +87,99 @@ pub fn fp2_muldiv_expr( ) } +// Input: Fp2 * 2 +// Output: Fp2 + +fn gen_base_expr( + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, +) -> (FieldExpr, Vec, Vec) { + let (expr, is_mul_flag, is_div_flag) = fp2_muldiv_expr(config, range_checker_bus); + + let local_opcode_idx = vec![ + Fp2Opcode::MUL as usize, + Fp2Opcode::DIV as usize, + Fp2Opcode::SETUP_MULDIV as usize, + ]; + let opcode_flag_idx = vec![is_mul_flag, is_div_flag]; + + (expr, local_opcode_idx, opcode_flag_idx) +} + +pub fn get_fp2_muldiv_air( + exec_bridge: ExecutionBridge, + mem_bridge: MemoryBridge, + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, + bitwise_lookup_bus: BitwiseOperationLookupBus, + pointer_max_bits: usize, + offset: usize, +) -> Fp2Air { + let (expr, local_opcode_idx, opcode_flag_idx) = gen_base_expr(config, range_checker_bus); + Fp2Air::new( + Rv32VecHeapAdapterAir::new( + exec_bridge, + mem_bridge, + bitwise_lookup_bus, + pointer_max_bits, + ), + FieldExpressionCoreAir::new(expr, offset, local_opcode_idx, opcode_flag_idx), + ) +} + +pub fn get_fp2_muldiv_step( + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, + pointer_max_bits: usize, + offset: usize, +) -> Fp2Executor { + let (expr, local_opcode_idx, opcode_flag_idx) = gen_base_expr(config, range_checker_bus); + + FieldExprVecHeapExecutor(FieldExpressionExecutor::new( + Rv32VecHeapAdapterExecutor::new(pointer_max_bits), + expr, + offset, + local_opcode_idx, + opcode_flag_idx, + "Fp2MulDiv", + )) +} + +pub fn get_fp2_muldiv_chip( + config: ExprBuilderConfig, + mem_helper: SharedMemoryHelper, + range_checker: SharedVariableRangeCheckerChip, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pointer_max_bits: usize, +) -> Fp2Chip { + let (expr, local_opcode_idx, opcode_flag_idx) = gen_base_expr(config, range_checker.bus()); + Fp2Chip::new( + FieldExpressionFiller::new( + Rv32VecHeapAdapterFiller::new(pointer_max_bits, bitwise_lookup_chip), + expr, + local_opcode_idx, + opcode_flag_idx, + range_checker, + false, + ), + mem_helper, + ) +} + #[cfg(test)] mod tests { + use std::sync::Arc; + use halo2curves_axiom::{bn256::Fq2, ff::Field}; use itertools::Itertools; + use num_bigint::BigUint; use openvm_algebra_transpiler::Fp2Opcode; - use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; + use openvm_circuit::arch::testing::{ + TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, + }; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, + BitwiseOperationLookupBus, BitwiseOperationLookupChip, }; use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; use openvm_mod_circuit_builder::{ @@ -140,68 +187,49 @@ mod tests { ExprBuilderConfig, }; use openvm_pairing_guest::bn254::BN254_MODULUS; - use openvm_rv32_adapters::{rv32_write_heap_default, Rv32VecHeapAdapterChip}; + use openvm_rv32_adapters::rv32_write_heap_default; use openvm_stark_backend::p3_field::FieldAlgebra; - use openvm_stark_sdk::p3_baby_bear::BabyBear; - use rand::{rngs::StdRng, SeedableRng}; + use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; - use super::Fp2MulDivChip; + use crate::fp2_chip::{ + get_fp2_muldiv_air, get_fp2_muldiv_chip, get_fp2_muldiv_step, Fp2Air, Fp2Chip, Fp2Executor, + }; const NUM_LIMBS: usize = 32; const LIMB_BITS: usize = 8; + const OFFSET: usize = Fp2Opcode::CLASS_OFFSET; + const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; + type Harness = TestChipHarness< + F, + Fp2Executor<2, NUM_LIMBS>, + Fp2Air<2, NUM_LIMBS>, + Fp2Chip, + >; - #[test] - fn test_fp2_muldiv() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let modulus = BN254_MODULUS.clone(); - let config = ExprBuilderConfig { - modulus: modulus.clone(), - num_limbs: NUM_LIMBS, - limb_bits: LIMB_BITS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = Fp2MulDivChip::new( - adapter, - config, - Fp2Opcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - assert_eq!( - chip.0.core.expr().builder.num_variables, - 2, - "Fp2MulDiv should only introduce new z Fp2 variable (2 Fp var)" - ); - - let mut rng = StdRng::seed_from_u64(42); + fn set_and_execute_rand( + tester: &mut VmChipTestBuilder, + harness: &mut Harness, + modulus: &BigUint, + ) { + let mut rng = create_seeded_rng(); let x = Fq2::random(&mut rng); let y = Fq2::random(&mut rng); let inputs = [x.c0, x.c1, y.c0, y.c1].map(bn254_fq_to_biguint); let expected_mul = bn254_fq2_to_biguint_vec(x * y); - let r_mul = chip - .0 - .core - .expr() + let r_mul = harness + .executor + .expr .execute_with_output(inputs.to_vec(), vec![true, false]); assert_eq!(r_mul.len(), 2); assert_eq!(r_mul[0], expected_mul[0]); assert_eq!(r_mul[1], expected_mul[1]); let expected_div = bn254_fq2_to_biguint_vec(x * y.invert().unwrap()); - let r_div = chip - .0 - .core - .expr() + let r_div = harness + .executor + .expr .execute_with_output(inputs.to_vec(), vec![false, true]); assert_eq!(r_div.len(), 2); assert_eq!(r_div[0], expected_div[0]); @@ -221,31 +249,81 @@ mod tests { .map(BabyBear::from_canonical_u32) }) .collect_vec(); - let modulus = - biguint_to_limbs::(modulus, LIMB_BITS).map(BabyBear::from_canonical_u32); + let modulus = biguint_to_limbs::(modulus.clone(), LIMB_BITS) + .map(BabyBear::from_canonical_u32); let zero = [BabyBear::ZERO; NUM_LIMBS]; let setup_instruction = rv32_write_heap_default( - &mut tester, + tester, vec![modulus, zero], vec![zero; 2], - chip.0.core.air.offset + Fp2Opcode::SETUP_MULDIV as usize, + OFFSET + Fp2Opcode::SETUP_MULDIV as usize, ); let instruction1 = rv32_write_heap_default( - &mut tester, + tester, x_limbs.clone(), y_limbs.clone(), - chip.0.core.air.offset + Fp2Opcode::MUL as usize, + OFFSET + Fp2Opcode::MUL as usize, + ); + let instruction2 = + rv32_write_heap_default(tester, x_limbs, y_limbs, OFFSET + Fp2Opcode::DIV as usize); + tester.execute(harness, &setup_instruction); + tester.execute(harness, &instruction1); + tester.execute(harness, &instruction2); + } + + #[test] + fn test_fp2_muldiv() { + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); + let modulus = BN254_MODULUS.clone(); + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: NUM_LIMBS, + limb_bits: LIMB_BITS, + }; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + + let air = get_fp2_muldiv_air( + tester.execution_bridge(), + tester.memory_bridge(), + config.clone(), + tester.range_checker().bus(), + bitwise_bus, + tester.address_bits(), + OFFSET, + ); + let executor = get_fp2_muldiv_step( + config.clone(), + tester.range_checker().bus(), + tester.address_bits(), + OFFSET, ); - let instruction2 = rv32_write_heap_default( - &mut tester, - x_limbs, - y_limbs, - chip.0.core.air.offset + Fp2Opcode::DIV as usize, + let chip = get_fp2_muldiv_chip( + config, + tester.memory_helper(), + tester.range_checker(), + bitwise_chip.clone(), + tester.address_bits(), ); - tester.execute(&mut chip, &setup_instruction); - tester.execute(&mut chip, &instruction1); - tester.execute(&mut chip, &instruction2); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + let mut harness = Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + + assert_eq!( + harness.executor.expr.builder.num_variables, 2, + "Fp2MulDiv should only introduce new z Fp2 variable (2 Fp var)" + ); + + let num_ops = 10; + for _ in 0..num_ops { + set_and_execute_rand(&mut tester, &mut harness, &modulus); + } + + let tester = tester + .build() + .load(harness) + .load_periphery((bitwise_chip.air, bitwise_chip)) + .finalize(); tester.simple_test().expect("Verification failed"); } } diff --git a/extensions/algebra/circuit/src/fp2_extension.rs b/extensions/algebra/circuit/src/fp2_extension.rs index 37968081bd..3081c88565 100644 --- a/extensions/algebra/circuit/src/fp2_extension.rs +++ b/extensions/algebra/circuit/src/fp2_extension.rs @@ -1,26 +1,41 @@ -use derive_more::derive::From; +use std::sync::Arc; + use num_bigint::BigUint; use openvm_algebra_transpiler::Fp2Opcode; use openvm_circuit::{ - arch::{SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError}, - system::phantom::PhantomChip, + arch::{ + AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, ExecutionBridge, + ExecutorInventoryBuilder, ExecutorInventoryError, RowMajorMatrixArena, VmCircuitExtension, + VmExecutionExtension, VmProverExtension, + }, + system::{memory::SharedMemoryHelper, SystemPort}, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, +use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::{ + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, + }, + var_range::VariableRangeCheckerBus, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::{LocalOpcode, VmOpcode}; use openvm_mod_circuit_builder::ExprBuilderConfig; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + p3_field::PrimeField32, + prover::cpu::{CpuBackend, CpuDevice}, +}; +use openvm_stark_sdk::engine::StarkEngine; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DisplayFromStr}; use strum::EnumCount; use crate::{ - fp2_chip::{Fp2AddSubChip, Fp2MulDivChip}, - ModularExtension, + fp2_chip::{ + get_fp2_addsub_air, get_fp2_addsub_chip, get_fp2_addsub_step, get_fp2_muldiv_air, + get_fp2_muldiv_chip, get_fp2_muldiv_step, Fp2Air, Fp2Executor, + }, + AlgebraCpuProverExt, ModularExtension, }; #[serde_as] @@ -47,7 +62,7 @@ impl Fp2Extension { .iter() .map(|(name, modulus)| { format!( - "{} {{ mod_idx = {} }}", + "\"{}\" {{ mod_idx = {} }}", name, get_index_of_modulus(modulus, modular_config) ) @@ -59,144 +74,287 @@ impl Fp2Extension { } } -#[derive(ChipUsageGetter, Chip, InstructionExecutor, AnyEnum, From)] -pub enum Fp2ExtensionExecutor { +#[derive(Clone, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)] +pub enum Fp2ExtensionExecutor { // 32 limbs prime - Fp2AddSubRv32_32(Fp2AddSubChip), - Fp2MulDivRv32_32(Fp2MulDivChip), + Fp2AddSubRv32_32(Fp2Executor<2, 32>), // Fp2AddSub + Fp2MulDivRv32_32(Fp2Executor<2, 32>), // Fp2MulDiv // 48 limbs prime - Fp2AddSubRv32_48(Fp2AddSubChip), - Fp2MulDivRv32_48(Fp2MulDivChip), -} - -#[derive(ChipUsageGetter, Chip, AnyEnum, From)] -pub enum Fp2ExtensionPeriphery { - BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), - // We put this only to get the generic to work - Phantom(PhantomChip), + Fp2AddSubRv32_48(Fp2Executor<6, 16>), // Fp2AddSub + Fp2MulDivRv32_48(Fp2Executor<6, 16>), // Fp2MulDiv } -impl VmExtension for Fp2Extension { - type Executor = Fp2ExtensionExecutor; - type Periphery = Fp2ExtensionPeriphery; +impl VmExecutionExtension for Fp2Extension { + type Executor = Fp2ExtensionExecutor; - fn build( + fn extend_execution( &self, - builder: &mut VmInventoryBuilder, - ) -> Result, VmInventoryError> { - let mut inventory = VmInventory::new(); - let SystemPort { - execution_bus, - program_bus, - memory_bridge, - } = builder.system_port(); - let bitwise_lu_chip = if let Some(&chip) = builder - .find_chip::>() - .first() - { - chip.clone() - } else { - let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); - let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); - inventory.add_periphery_chip(chip.clone()); - chip - }; - let offline_memory = builder.system_base().offline_memory(); - let range_checker = builder.system_base().range_checker_chip.clone(); - let address_bits = builder.system_config().memory_config.pointer_max_bits; - - let addsub_opcodes = (Fp2Opcode::ADD as usize)..=(Fp2Opcode::SETUP_ADDSUB as usize); - let muldiv_opcodes = (Fp2Opcode::MUL as usize)..=(Fp2Opcode::SETUP_MULDIV as usize); - + inventory: &mut ExecutorInventoryBuilder, + ) -> Result<(), ExecutorInventoryError> { + let pointer_max_bits = inventory.pointer_max_bits(); + // TODO: somehow get the range checker bus from `ExecutorInventory` + let dummy_range_checker_bus = VariableRangeCheckerBus::new(u16::MAX, 16); for (i, (_, modulus)) in self.supported_moduli.iter().enumerate() { // determine the number of bytes needed to represent a prime field element let bytes = modulus.bits().div_ceil(8); let start_offset = Fp2Opcode::CLASS_OFFSET + i * Fp2Opcode::COUNT; - let config32 = ExprBuilderConfig { - modulus: modulus.clone(), - num_limbs: 32, - limb_bits: 8, - }; - let config48 = ExprBuilderConfig { - modulus: modulus.clone(), - num_limbs: 48, - limb_bits: 8, - }; - let adapter_chip_32 = Rv32VecHeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), - ); - let adapter_chip_48 = Rv32VecHeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), - ); - if bytes <= 32 { - let addsub_chip = Fp2AddSubChip::new( - adapter_chip_32.clone(), - config32.clone(), + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: 32, + limb_bits: 8, + }; + let addsub = get_fp2_addsub_step( + config.clone(), + dummy_range_checker_bus, + pointer_max_bits, start_offset, - range_checker.clone(), - offline_memory.clone(), ); + inventory.add_executor( - Fp2ExtensionExecutor::Fp2AddSubRv32_32(addsub_chip), - addsub_opcodes - .clone() + Fp2ExtensionExecutor::Fp2AddSubRv32_32(addsub), + ((Fp2Opcode::ADD as usize)..=(Fp2Opcode::SETUP_ADDSUB as usize)) .map(|x| VmOpcode::from_usize(x + start_offset)), )?; - let muldiv_chip = Fp2MulDivChip::new( - adapter_chip_32.clone(), - config32.clone(), + + let muldiv = get_fp2_muldiv_step( + config, + dummy_range_checker_bus, + pointer_max_bits, start_offset, - range_checker.clone(), - offline_memory.clone(), ); + inventory.add_executor( - Fp2ExtensionExecutor::Fp2MulDivRv32_32(muldiv_chip), - muldiv_opcodes - .clone() + Fp2ExtensionExecutor::Fp2MulDivRv32_32(muldiv), + ((Fp2Opcode::MUL as usize)..=(Fp2Opcode::SETUP_MULDIV as usize)) .map(|x| VmOpcode::from_usize(x + start_offset)), )?; } else if bytes <= 48 { - let addsub_chip = Fp2AddSubChip::new( - adapter_chip_48.clone(), - config48.clone(), + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: 48, + limb_bits: 8, + }; + let addsub = get_fp2_addsub_step( + config.clone(), + dummy_range_checker_bus, + pointer_max_bits, start_offset, - range_checker.clone(), - offline_memory.clone(), ); + inventory.add_executor( - Fp2ExtensionExecutor::Fp2AddSubRv32_48(addsub_chip), - addsub_opcodes - .clone() + Fp2ExtensionExecutor::Fp2AddSubRv32_48(addsub), + ((Fp2Opcode::ADD as usize)..=(Fp2Opcode::SETUP_ADDSUB as usize)) .map(|x| VmOpcode::from_usize(x + start_offset)), )?; - let muldiv_chip = Fp2MulDivChip::new( - adapter_chip_48.clone(), - config48.clone(), + + let muldiv = get_fp2_muldiv_step( + config, + dummy_range_checker_bus, + pointer_max_bits, start_offset, - range_checker.clone(), - offline_memory.clone(), ); + inventory.add_executor( - Fp2ExtensionExecutor::Fp2MulDivRv32_48(muldiv_chip), - muldiv_opcodes - .clone() + Fp2ExtensionExecutor::Fp2MulDivRv32_48(muldiv), + ((Fp2Opcode::MUL as usize)..=(Fp2Opcode::SETUP_MULDIV as usize)) .map(|x| VmOpcode::from_usize(x + start_offset)), )?; } else { panic!("Modulus too large"); } } + Ok(()) + } +} + +impl VmCircuitExtension for Fp2Extension { + fn extend_circuit(&self, inventory: &mut AirInventory) -> Result<(), AirInventoryError> { + let SystemPort { + execution_bus, + program_bus, + memory_bridge, + } = inventory.system().port(); + + let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); + let range_checker_bus = inventory.range_checker().bus; + let pointer_max_bits = inventory.pointer_max_bits(); + + let bitwise_lu = { + // A trick to get around Rust's borrow rules + let existing_air = inventory.find_air::>().next(); + if let Some(air) = existing_air { + air.bus + } else { + let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx()); + let air = BitwiseOperationLookupAir::<8>::new(bus); + inventory.add_air(air); + air.bus + } + }; + for (i, (_, modulus)) in self.supported_moduli.iter().enumerate() { + // determine the number of bytes needed to represent a prime field element + let bytes = modulus.bits().div_ceil(8); + let start_offset = Fp2Opcode::CLASS_OFFSET + i * Fp2Opcode::COUNT; + + if bytes <= 32 { + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: 32, + limb_bits: 8, + }; + + let addsub = get_fp2_addsub_air::<2, 32>( + exec_bridge, + memory_bridge, + config.clone(), + range_checker_bus, + bitwise_lu, + pointer_max_bits, + start_offset, + ); + inventory.add_air(addsub); + + let muldiv = get_fp2_muldiv_air::<2, 32>( + exec_bridge, + memory_bridge, + config, + range_checker_bus, + bitwise_lu, + pointer_max_bits, + start_offset, + ); + inventory.add_air(muldiv); + } else if bytes <= 48 { + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: 48, + limb_bits: 8, + }; + + let addsub = get_fp2_addsub_air::<6, 16>( + exec_bridge, + memory_bridge, + config.clone(), + range_checker_bus, + bitwise_lu, + pointer_max_bits, + start_offset, + ); + inventory.add_air(addsub); + + let muldiv = get_fp2_muldiv_air::<6, 16>( + exec_bridge, + memory_bridge, + config, + range_checker_bus, + bitwise_lu, + pointer_max_bits, + start_offset, + ); + inventory.add_air(muldiv); + } else { + panic!("Modulus too large"); + } + } + + Ok(()) + } +} + +// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker, +// BitwiseOperationLookupChip) are specific to CpuBackend. +impl VmProverExtension for AlgebraCpuProverExt +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + RA: RowMajorMatrixArena>, + Val: PrimeField32, +{ + fn extend_prover( + &self, + extension: &Fp2Extension, + inventory: &mut ChipInventory>, + ) -> Result<(), ChipInventoryError> { + let range_checker = inventory.range_checker()?.clone(); + let timestamp_max_bits = inventory.timestamp_max_bits(); + let pointer_max_bits = inventory.airs().pointer_max_bits(); + let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits); + let bitwise_lu = { + let existing_chip = inventory + .find_chip::>() + .next(); + if let Some(chip) = existing_chip { + chip.clone() + } else { + let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?; + let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus)); + inventory.add_periphery_chip(chip.clone()); + chip + } + }; + for (_, modulus) in extension.supported_moduli.iter() { + // determine the number of bytes needed to represent a prime field element + let bytes = modulus.bits().div_ceil(8); + + if bytes <= 32 { + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: 32, + limb_bits: 8, + }; + + inventory.next_air::>()?; + let addsub = get_fp2_addsub_chip::, 2, 32>( + config.clone(), + mem_helper.clone(), + range_checker.clone(), + bitwise_lu.clone(), + pointer_max_bits, + ); + inventory.add_executor_chip(addsub); + + inventory.next_air::>()?; + let muldiv = get_fp2_muldiv_chip::, 2, 32>( + config, + mem_helper.clone(), + range_checker.clone(), + bitwise_lu.clone(), + pointer_max_bits, + ); + inventory.add_executor_chip(muldiv); + } else if bytes <= 48 { + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: 48, + limb_bits: 8, + }; + + inventory.next_air::>()?; + let addsub = get_fp2_addsub_chip::, 6, 16>( + config.clone(), + mem_helper.clone(), + range_checker.clone(), + bitwise_lu.clone(), + pointer_max_bits, + ); + inventory.add_executor_chip(addsub); + + inventory.next_air::>()?; + let muldiv = get_fp2_muldiv_chip::, 6, 16>( + config, + mem_helper.clone(), + range_checker.clone(), + bitwise_lu.clone(), + pointer_max_bits, + ); + inventory.add_executor_chip(muldiv); + } else { + panic!("Modulus too large"); + } + } - Ok(inventory) + Ok(()) } } diff --git a/extensions/algebra/circuit/src/lib.rs b/extensions/algebra/circuit/src/lib.rs index ffddacc61a..cbb9cc106f 100644 --- a/extensions/algebra/circuit/src/lib.rs +++ b/extensions/algebra/circuit/src/lib.rs @@ -1,3 +1,83 @@ +use std::{ + array::from_fn, + borrow::{Borrow, BorrowMut}, +}; + +use derive_more::derive::{Deref, DerefMut}; +use num_bigint::BigUint; +use openvm_algebra_transpiler::{Fp2Opcode, Rv32ModularArithmeticOpcode}; +use openvm_circuit::{ + arch::*, + system::memory::{online::GuestMemory, POINTER_MAX_BITS}, +}; +use openvm_circuit_derive::PreflightExecutor; +use openvm_circuit_primitives::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, +}; +use openvm_mod_circuit_builder::{ + run_field_expression_precomputed, FieldExpr, FieldExpressionExecutor, +}; +use openvm_rv32_adapters::Rv32VecHeapAdapterExecutor; +use openvm_stark_backend::p3_field::PrimeField32; + +use self::fields::{ + field_operation, fp2_operation, get_field_type, get_fp2_field_type, FieldType, Operation, +}; + +macro_rules! generate_field_dispatch { + ( + $field_type:expr, + $op:expr, + $blocks:expr, + $block_size:expr, + $execute_fn:ident, + [$(($curve:ident, $operation:ident)),* $(,)?] + ) => { + match ($field_type, $op) { + $( + (FieldType::$curve, Operation::$operation) => Ok($execute_fn::< + _, + _, + $blocks, + $block_size, + false, + { FieldType::$curve as u8 }, + { Operation::$operation as u8 }, + >), + )* + } + }; +} + +macro_rules! generate_fp2_dispatch { + ( + $field_type:expr, + $op:expr, + $blocks:expr, + $block_size:expr, + $execute_fn:ident, + [$(($curve:ident, $operation:ident)),* $(,)?] + ) => { + match ($field_type, $op) { + $( + (FieldType::$curve, Operation::$operation) => Ok($execute_fn::< + _, + _, + $blocks, + $block_size, + true, + { FieldType::$curve as u8 }, + { Operation::$operation as u8 }, + >), + )* + _ => panic!("Unsupported fp2 field") + } + }; +} + pub mod fp2_chip; pub mod modular_chip; @@ -9,3 +89,535 @@ mod fp2_extension; pub use fp2_extension::*; mod config; pub use config::*; +pub mod fields; + +pub struct AlgebraCpuProverExt; + +#[derive(Clone, PreflightExecutor, Deref, DerefMut)] +pub struct FieldExprVecHeapExecutor< + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const IS_FP2: bool, +>(FieldExpressionExecutor>); + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct FieldExpressionPreCompute<'a> { + expr: &'a FieldExpr, + rs_addrs: [u8; 2], + a: u8, + flag_idx: u8, +} + +impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize, const IS_FP2: bool> + FieldExprVecHeapExecutor +{ + fn pre_compute_impl( + &'a self, + pc: u32, + inst: &Instruction, + data: &mut FieldExpressionPreCompute<'a>, + ) -> Result, StaticProgramError> { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); + let c = c.as_canonical_u32(); + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + let local_opcode = opcode.local_opcode_idx(self.0.offset); + + let needs_setup = self.0.expr.needs_setup(); + let mut flag_idx = self.0.expr.num_flags() as u8; + if needs_setup { + if let Some(opcode_position) = self + .0 + .local_opcode_idx + .iter() + .position(|&idx| idx == local_opcode) + { + if opcode_position < self.0.opcode_flag_idx.len() { + flag_idx = self.0.opcode_flag_idx[opcode_position] as u8; + } + } + } + + let rs_addrs = from_fn(|i| if i == 0 { b } else { c } as u8); + *data = FieldExpressionPreCompute { + a: a as u8, + rs_addrs, + expr: &self.0.expr, + flag_idx, + }; + + if IS_FP2 { + let is_setup = local_opcode == Fp2Opcode::SETUP_ADDSUB as usize + || local_opcode == Fp2Opcode::SETUP_MULDIV as usize; + + let op = if is_setup { + None + } else { + match local_opcode { + x if x == Fp2Opcode::ADD as usize => Some(Operation::Add), + x if x == Fp2Opcode::SUB as usize => Some(Operation::Sub), + x if x == Fp2Opcode::MUL as usize => Some(Operation::Mul), + x if x == Fp2Opcode::DIV as usize => Some(Operation::Div), + _ => unreachable!(), + } + }; + + Ok(op) + } else { + let is_setup = local_opcode == Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize + || local_opcode == Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize; + + let op = if is_setup { + None + } else { + match local_opcode { + x if x == Rv32ModularArithmeticOpcode::ADD as usize => Some(Operation::Add), + x if x == Rv32ModularArithmeticOpcode::SUB as usize => Some(Operation::Sub), + x if x == Rv32ModularArithmeticOpcode::MUL as usize => Some(Operation::Mul), + x if x == Rv32ModularArithmeticOpcode::DIV as usize => Some(Operation::Div), + _ => unreachable!(), + } + }; + + Ok(op) + } + } +} + +impl Executor + for FieldExprVecHeapExecutor +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + std::mem::size_of::() + } + + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E1ExecutionCtx, + { + let pre_compute: &mut FieldExpressionPreCompute = data.borrow_mut(); + + let op = self.pre_compute_impl(pc, inst, pre_compute)?; + + if let Some(op) = op { + let modulus = &pre_compute.expr.prime; + if IS_FP2 { + if let Some(field_type) = get_fp2_field_type(modulus) { + generate_fp2_dispatch!( + field_type, + op, + BLOCKS, + BLOCK_SIZE, + execute_e1_impl, + [ + (BN254Coordinate, Add), + (BN254Coordinate, Sub), + (BN254Coordinate, Mul), + (BN254Coordinate, Div), + (BLS12_381Coordinate, Add), + (BLS12_381Coordinate, Sub), + (BLS12_381Coordinate, Mul), + (BLS12_381Coordinate, Div), + ] + ) + } else { + Ok(execute_e1_generic_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>) + } + } else if let Some(field_type) = get_field_type(modulus) { + generate_field_dispatch!( + field_type, + op, + BLOCKS, + BLOCK_SIZE, + execute_e1_impl, + [ + (K256Coordinate, Add), + (K256Coordinate, Sub), + (K256Coordinate, Mul), + (K256Coordinate, Div), + (K256Scalar, Add), + (K256Scalar, Sub), + (K256Scalar, Mul), + (K256Scalar, Div), + (P256Coordinate, Add), + (P256Coordinate, Sub), + (P256Coordinate, Mul), + (P256Coordinate, Div), + (P256Scalar, Add), + (P256Scalar, Sub), + (P256Scalar, Mul), + (P256Scalar, Div), + (BN254Coordinate, Add), + (BN254Coordinate, Sub), + (BN254Coordinate, Mul), + (BN254Coordinate, Div), + (BN254Scalar, Add), + (BN254Scalar, Sub), + (BN254Scalar, Mul), + (BN254Scalar, Div), + (BLS12_381Coordinate, Add), + (BLS12_381Coordinate, Sub), + (BLS12_381Coordinate, Mul), + (BLS12_381Coordinate, Div), + (BLS12_381Scalar, Add), + (BLS12_381Scalar, Sub), + (BLS12_381Scalar, Mul), + (BLS12_381Scalar, Div), + ] + ) + } else { + Ok(execute_e1_generic_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>) + } + } else { + Ok(execute_e1_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>) + } + } +} + +impl + MeteredExecutor for FieldExprVecHeapExecutor +{ + #[inline(always)] + fn metered_pre_compute_size(&self) -> usize { + std::mem::size_of::>() + } + + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let op = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + if let Some(op) = op { + let modulus = &pre_compute.data.expr.prime; + if IS_FP2 { + if let Some(field_type) = get_fp2_field_type(modulus) { + generate_fp2_dispatch!( + field_type, + op, + BLOCKS, + BLOCK_SIZE, + execute_e2_impl, + [ + (BN254Coordinate, Add), + (BN254Coordinate, Sub), + (BN254Coordinate, Mul), + (BN254Coordinate, Div), + (BLS12_381Coordinate, Add), + (BLS12_381Coordinate, Sub), + (BLS12_381Coordinate, Mul), + (BLS12_381Coordinate, Div), + ] + ) + } else { + Ok(execute_e2_generic_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>) + } + } else if let Some(field_type) = get_field_type(modulus) { + generate_field_dispatch!( + field_type, + op, + BLOCKS, + BLOCK_SIZE, + execute_e2_impl, + [ + (K256Coordinate, Add), + (K256Coordinate, Sub), + (K256Coordinate, Mul), + (K256Coordinate, Div), + (K256Scalar, Add), + (K256Scalar, Sub), + (K256Scalar, Mul), + (K256Scalar, Div), + (P256Coordinate, Add), + (P256Coordinate, Sub), + (P256Coordinate, Mul), + (P256Coordinate, Div), + (P256Scalar, Add), + (P256Scalar, Sub), + (P256Scalar, Mul), + (P256Scalar, Div), + (BN254Coordinate, Add), + (BN254Coordinate, Sub), + (BN254Coordinate, Mul), + (BN254Coordinate, Div), + (BN254Scalar, Add), + (BN254Scalar, Sub), + (BN254Scalar, Mul), + (BN254Scalar, Div), + (BLS12_381Coordinate, Add), + (BLS12_381Coordinate, Sub), + (BLS12_381Coordinate, Mul), + (BLS12_381Coordinate, Div), + (BLS12_381Scalar, Add), + (BLS12_381Scalar, Sub), + (BLS12_381Scalar, Mul), + (BLS12_381Scalar, Div), + ] + ) + } else { + Ok(execute_e2_generic_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>) + } + } else { + Ok(execute_e2_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>) + } + } +} +unsafe fn execute_e1_setup_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const IS_FP2: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &FieldExpressionPreCompute = pre_compute.borrow(); + execute_e12_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>(pre_compute, vm_state); +} + +unsafe fn execute_e2_setup_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const IS_FP2: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>(&pre_compute.data, vm_state); +} + +unsafe fn execute_e1_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const IS_FP2: bool, + const FIELD_TYPE: u8, + const OP: u8, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &FieldExpressionPreCompute = pre_compute.borrow(); + execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2, FIELD_TYPE, OP>(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const IS_FP2: bool, + const FIELD_TYPE: u8, + const OP: u8, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2, FIELD_TYPE, OP>( + &pre_compute.data, + vm_state, + ); +} + +unsafe fn execute_e1_generic_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const IS_FP2: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &FieldExpressionPreCompute = pre_compute.borrow(); + execute_e12_generic_impl::<_, _, BLOCKS, BLOCK_SIZE>(pre_compute, vm_state); +} + +unsafe fn execute_e2_generic_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const IS_FP2: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_generic_impl::<_, _, BLOCKS, BLOCK_SIZE>(&pre_compute.data, vm_state); +} + +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const IS_FP2: bool, + const FIELD_TYPE: u8, + const OP: u8, +>( + pre_compute: &FieldExpressionPreCompute, + vm_state: &mut VmExecState, +) { + let rs_vals = pre_compute + .rs_addrs + .map(|addr| u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, addr as u32))); + + let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| { + debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS)); + from_fn(|i| vm_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32)) + }); + + let output_data = if IS_FP2 { + fp2_operation::(read_data) + } else { + field_operation::(read_data) + }; + + let rd_val = u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32)); + debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS)); + + for (i, block) in output_data.into_iter().enumerate() { + vm_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block); + } + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} + +unsafe fn execute_e12_generic_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const BLOCKS: usize, + const BLOCK_SIZE: usize, +>( + pre_compute: &FieldExpressionPreCompute, + vm_state: &mut VmExecState, +) { + let rs_vals = pre_compute + .rs_addrs + .map(|addr| u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, addr as u32))); + + let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| { + debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS)); + from_fn(|i| vm_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32)) + }); + let read_data_dyn: DynArray = read_data.into(); + + let writes = run_field_expression_precomputed::( + pre_compute.expr, + pre_compute.flag_idx as usize, + &read_data_dyn.0, + ); + + let rd_val = u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32)); + debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS)); + + let data: [[u8; BLOCK_SIZE]; BLOCKS] = writes.into(); + for (i, block) in data.into_iter().enumerate() { + vm_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block); + } + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} + +unsafe fn execute_e12_setup_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const IS_FP2: bool, +>( + pre_compute: &FieldExpressionPreCompute, + vm_state: &mut VmExecState, +) { + // Read the first input (which should be the prime) + let rs_vals = pre_compute + .rs_addrs + .map(|addr| u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, addr as u32))); + let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| { + debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS)); + from_fn(|i| vm_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32)) + }); + + // Extract first field element as the prime + let input_prime = if IS_FP2 { + BigUint::from_bytes_le(read_data[0][..BLOCKS / 2].as_flattened()) + } else { + BigUint::from_bytes_le(read_data[0].as_flattened()) + }; + + if input_prime != pre_compute.expr.prime { + vm_state.exit_code = Err(ExecutionError::Fail { + pc: vm_state.pc, + msg: "ModularSetup: mismatched prime", + }); + return; + } + + let read_data_dyn: DynArray = read_data.into(); + + let writes = run_field_expression_precomputed::( + pre_compute.expr, + pre_compute.flag_idx as usize, + &read_data_dyn.0, + ); + + let rd_val = u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32)); + debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS)); + + let data: [[u8; BLOCK_SIZE]; BLOCKS] = writes.into(); + for (i, block) in data.into_iter().enumerate() { + vm_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block); + } + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} diff --git a/extensions/algebra/circuit/src/modular_chip/addsub.rs b/extensions/algebra/circuit/src/modular_chip/addsub.rs index 34bede150f..abde6ea6dd 100644 --- a/extensions/algebra/circuit/src/modular_chip/addsub.rs +++ b/extensions/algebra/circuit/src/modular_chip/addsub.rs @@ -1,21 +1,25 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; +use std::{cell::RefCell, rc::Rc}; use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +use openvm_circuit::{ + arch::ExecutionBridge, + system::memory::{offline_checker::MemoryBridge, SharedMemoryHelper}, +}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_instructions::riscv::RV32_CELL_BITS; use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, FieldVariable, + ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreAir, FieldExpressionExecutor, + FieldExpressionFiller, FieldVariable, +}; +use openvm_rv32_adapters::{ + Rv32VecHeapAdapterAir, Rv32VecHeapAdapterExecutor, Rv32VecHeapAdapterFiller, }; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; + +use super::{ModularAir, ModularChip, ModularExecutor}; +use crate::FieldExprVecHeapExecutor; pub fn addsub_expr( config: ExprBuilderConfig, @@ -29,12 +33,12 @@ pub fn addsub_expr( let x2 = ExprBuilder::new_input(builder.clone()); let x3 = x1.clone() + x2.clone(); let x4 = x1.clone() - x2.clone(); - let is_add_flag = builder.borrow_mut().new_flag(); - let is_sub_flag = builder.borrow_mut().new_flag(); + let is_add_flag = (*builder).borrow_mut().new_flag(); + let is_sub_flag = (*builder).borrow_mut().new_flag(); let x5 = FieldVariable::select(is_sub_flag, &x4, &x1); let mut x6 = FieldVariable::select(is_add_flag, &x3, &x5); x6.save_output(); - let builder = builder.borrow().clone(); + let builder = (*builder).borrow().clone(); ( FieldExpr::new(builder, range_bus, true), @@ -43,39 +47,78 @@ pub fn addsub_expr( ) } -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct ModularAddSubChip( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); +fn gen_base_expr( + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, +) -> (FieldExpr, Vec, Vec) { + let (expr, is_add_flag, is_sub_flag) = addsub_expr(config, range_checker_bus); + + let local_opcode_idx = vec![ + Rv32ModularArithmeticOpcode::ADD as usize, + Rv32ModularArithmeticOpcode::SUB as usize, + Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize, + ]; + let opcode_flag_idx = vec![is_add_flag, is_sub_flag]; + + (expr, local_opcode_idx, opcode_flag_idx) +} + +pub fn get_modular_addsub_air( + exec_bridge: ExecutionBridge, + mem_bridge: MemoryBridge, + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, + bitwise_lookup_bus: BitwiseOperationLookupBus, + pointer_max_bits: usize, + offset: usize, +) -> ModularAir { + let (expr, local_opcode_idx, opcode_flag_idx) = gen_base_expr(config, range_checker_bus); + ModularAir::new( + Rv32VecHeapAdapterAir::new( + exec_bridge, + mem_bridge, + bitwise_lookup_bus, + pointer_max_bits, + ), + FieldExpressionCoreAir::new(expr, offset, local_opcode_idx, opcode_flag_idx), + ) +} + +pub fn get_modular_addsub_step( + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, + pointer_max_bits: usize, + offset: usize, +) -> ModularExecutor { + let (expr, local_opcode_idx, opcode_flag_idx) = gen_base_expr(config, range_checker_bus); -impl - ModularAddSubChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - config: ExprBuilderConfig, - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let (expr, is_add_flag, is_sub_flag) = addsub_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( + FieldExprVecHeapExecutor(FieldExpressionExecutor::new( + Rv32VecHeapAdapterExecutor::new(pointer_max_bits), + expr, + offset, + local_opcode_idx, + opcode_flag_idx, + "ModularAddSub", + )) +} + +pub fn get_modular_addsub_chip( + config: ExprBuilderConfig, + mem_helper: SharedMemoryHelper, + range_checker: SharedVariableRangeCheckerChip, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pointer_max_bits: usize, +) -> ModularChip { + let (expr, local_opcode_idx, opcode_flag_idx) = gen_base_expr(config, range_checker.bus()); + ModularChip::new( + FieldExpressionFiller::new( + Rv32VecHeapAdapterFiller::new(pointer_max_bits, bitwise_lookup_chip), expr, - offset, - vec![ - Rv32ModularArithmeticOpcode::ADD as usize, - Rv32ModularArithmeticOpcode::SUB as usize, - Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize, - ], - vec![is_add_flag, is_sub_flag], + local_opcode_idx, + opcode_flag_idx, range_checker, - "ModularAddSub", false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } + ), + mem_helper, + ) } diff --git a/extensions/algebra/circuit/src/modular_chip/is_eq.rs b/extensions/algebra/circuit/src/modular_chip/is_eq.rs index fe91585466..e9b71b8036 100644 --- a/extensions/algebra/circuit/src/modular_chip/is_eq.rs +++ b/extensions/algebra/circuit/src/modular_chip/is_eq.rs @@ -5,32 +5,41 @@ use std::{ use num_bigint::BigUint; use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::*, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, POINTER_MAX_BITS, + }, }; use openvm_circuit_primitives::{ bigint::utils::big_uint_to_limbs, bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, is_equal_array::{IsEqArrayIo, IsEqArraySubAir}, - SubAir, TraceSubRowGenerator, + AlignedBytesBorrow, SubAir, TraceSubRowGenerator, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; +use openvm_rv32_adapters::Rv32IsEqualModAdapterExecutor; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::{AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; + +use crate::modular_chip::VmModularIsEqualExecutor; // Given two numbers b and c, we want to prove that a) b == c or b != c, depending on // result of cmp_result and b) b, c < N for some modulus N that is passed into the AIR // at runtime (i.e. when chip is instantiated). #[repr(C)] -#[derive(AlignedBorrow)] +#[derive(AlignedBorrow, Debug)] pub struct ModularIsEqualCoreCols { pub is_valid: T, pub is_setup: T, @@ -278,155 +287,395 @@ where } #[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct ModularIsEqualCoreRecord { - #[serde(with = "BigArray")] - pub b: [T; READ_LIMBS], - #[serde(with = "BigArray")] - pub c: [T; READ_LIMBS], - pub cmp_result: T, - #[serde(with = "BigArray")] - pub eq_marker: [T; READ_LIMBS], - pub b_diff_idx: usize, - pub c_diff_idx: usize, +#[derive(AlignedBytesBorrow, Debug)] +pub struct ModularIsEqualRecord { pub is_setup: bool, + pub b: [u8; READ_LIMBS], + pub c: [u8; READ_LIMBS], } -pub struct ModularIsEqualCoreChip< +#[derive(derive_new::new, Clone)] +pub struct ModularIsEqualExecutor< + A, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize, > { - pub air: ModularIsEqualCoreAir, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + adapter: A, + pub offset: usize, + pub modulus_limbs: [u8; READ_LIMBS], } -impl - ModularIsEqualCoreChip -{ - pub fn new( - modulus: BigUint, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - offset: usize, - ) -> Self { - Self { - air: ModularIsEqualCoreAir::new(modulus, bitwise_lookup_chip.bus(), offset), - bitwise_lookup_chip, - } - } +#[derive(derive_new::new, Clone)] +pub struct ModularIsEqualFiller< + A, + const READ_LIMBS: usize, + const WRITE_LIMBS: usize, + const LIMB_BITS: usize, +> { + adapter: A, + pub offset: usize, + pub modulus_limbs: [u8; READ_LIMBS], + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } -impl< - F: PrimeField32, - I: VmAdapterInterface, - const READ_LIMBS: usize, - const WRITE_LIMBS: usize, - const LIMB_BITS: usize, - > VmCoreChip for ModularIsEqualCoreChip +impl + PreflightExecutor for ModularIsEqualExecutor where - I::Reads: Into<[[F; READ_LIMBS]; 2]>, - I::Writes: From<[[F; WRITE_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + AdapterTraceExecutor< + F, + ReadData: Into<[[u8; READ_LIMBS]; 2]>, + WriteData: From<[u8; WRITE_LIMBS]>, + >, + for<'buf> RA: RecordArena< + 'buf, + EmptyAdapterCoreLayout, + ( + A::RecordMut<'buf>, + &'buf mut ModularIsEqualRecord, + ), + >, { - type Record = ModularIsEqualCoreRecord; - type Air = ModularIsEqualCoreAir; - - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let data: [[F; READ_LIMBS]; 2] = reads.into(); - let b = data[0].map(|x| x.as_canonical_u32()); - let c = data[1].map(|y| y.as_canonical_u32()); - let (b_cmp, b_diff_idx) = run_unsigned_less_than::(&b, &self.air.modulus_limbs); - let (c_cmp, c_diff_idx) = run_unsigned_less_than::(&c, &self.air.modulus_limbs); - let is_setup = instruction.opcode.local_opcode_idx(self.air.offset) + ) -> Result<(), ExecutionError> { + let Instruction { opcode, .. } = instruction; + + let local_opcode = + Rv32ModularArithmeticOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + matches!( + local_opcode, + Rv32ModularArithmeticOpcode::IS_EQ | Rv32ModularArithmeticOpcode::SETUP_ISEQ + ); + + let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + [core_record.b, core_record.c] = self + .adapter + .read(state.memory, instruction, &mut adapter_record) + .into(); + + core_record.is_setup = instruction.opcode.local_opcode_idx(self.offset) == Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize; - if !is_setup { - assert!(b_cmp, "{:?} >= {:?}", b, self.air.modulus_limbs); - } - assert!(c_cmp, "{:?} >= {:?}", c, self.air.modulus_limbs); - if !is_setup { - self.bitwise_lookup_chip.request_range( - self.air.modulus_limbs[b_diff_idx] - b[b_diff_idx] - 1, - self.air.modulus_limbs[c_diff_idx] - c[c_diff_idx] - 1, - ); - } + let mut write_data = [0u8; WRITE_LIMBS]; + write_data[0] = (core_record.b == core_record.c) as u8; - let mut eq_marker = [F::ZERO; READ_LIMBS]; - let mut cmp_result = F::ZERO; - self.air - .subair - .generate_subrow((&data[0], &data[1]), (&mut eq_marker, &mut cmp_result)); - - let mut writes = [F::ZERO; WRITE_LIMBS]; - writes[0] = cmp_result; - - let output = AdapterRuntimeContext::without_pc([writes]); - let record = ModularIsEqualCoreRecord { - is_setup, - b: data[0], - c: data[1], - cmp_result, - eq_marker, - b_diff_idx, - c_diff_idx, - }; + self.adapter.write( + state.memory, + instruction, + write_data.into(), + &mut adapter_record, + ); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - Ok((output, record)) + Ok(()) } fn get_opcode_name(&self, opcode: usize) -> String { format!( "{:?}", - Rv32ModularArithmeticOpcode::from_usize(opcode - self.air.offset) + Rv32ModularArithmeticOpcode::from_usize(opcode - self.offset) ) } +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut ModularIsEqualCoreCols<_, READ_LIMBS> = row_slice.borrow_mut(); - row_slice.is_valid = F::ONE; - row_slice.is_setup = F::from_bool(record.is_setup); - row_slice.b = record.b; - row_slice.c = record.c; - row_slice.cmp_result = record.cmp_result; - - row_slice.eq_marker = record.eq_marker; +impl TraceFiller + for ModularIsEqualFiller +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = row_slice.split_at_mut(A::WIDTH); + self.adapter.fill_trace_row(mem_helper, adapter_row); + let record: &ModularIsEqualRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; + let cols: &mut ModularIsEqualCoreCols = core_row.borrow_mut(); + let (b_cmp, b_diff_idx) = + run_unsigned_less_than::(&record.b, &self.modulus_limbs); + let (c_cmp, c_diff_idx) = + run_unsigned_less_than::(&record.c, &self.modulus_limbs); if !record.is_setup { - row_slice.b_lt_diff = F::from_canonical_u32(self.air.modulus_limbs[record.b_diff_idx]) - - record.b[record.b_diff_idx]; + assert!(b_cmp, "{:?} >= {:?}", record.b, self.modulus_limbs); } - row_slice.c_lt_diff = F::from_canonical_u32(self.air.modulus_limbs[record.c_diff_idx]) - - record.c[record.c_diff_idx]; - row_slice.c_lt_mark = if record.b_diff_idx == record.c_diff_idx { + assert!(c_cmp, "{:?} >= {:?}", record.c, self.modulus_limbs); + + // Writing in reverse order + cols.c_lt_mark = if b_diff_idx == c_diff_idx { F::ONE } else { - F::from_canonical_u8(2) + F::TWO }; - row_slice.lt_marker = from_fn(|i| { - if i == record.b_diff_idx { + + cols.c_lt_diff = + F::from_canonical_u8(self.modulus_limbs[c_diff_idx] - record.c[c_diff_idx]); + if !record.is_setup { + cols.b_lt_diff = + F::from_canonical_u8(self.modulus_limbs[b_diff_idx] - record.b[b_diff_idx]); + self.bitwise_lookup_chip.request_range( + (self.modulus_limbs[b_diff_idx] - record.b[b_diff_idx] - 1) as u32, + (self.modulus_limbs[c_diff_idx] - record.c[c_diff_idx] - 1) as u32, + ); + } else { + cols.b_lt_diff = F::ZERO; + } + + cols.lt_marker = from_fn(|i| { + if i == b_diff_idx { F::ONE - } else if i == record.c_diff_idx { - row_slice.c_lt_mark + } else if i == c_diff_idx { + cols.c_lt_mark } else { F::ZERO } }); + + cols.c = record.c.map(F::from_canonical_u8); + cols.b = record.b.map(F::from_canonical_u8); + let sub_air = IsEqArraySubAir::; + sub_air.generate_subrow( + (&cols.b, &cols.c), + (&mut cols.eq_marker, &mut cols.cmp_result), + ); + + cols.is_setup = F::from_bool(record.is_setup); + cols.is_valid = F::ONE; } +} - fn air(&self) -> &Self::Air { - &self.air +impl + VmModularIsEqualExecutor +{ + pub fn new( + adapter: Rv32IsEqualModAdapterExecutor<2, NUM_LANES, LANE_SIZE, TOTAL_LIMBS>, + offset: usize, + modulus_limbs: [u8; TOTAL_LIMBS], + ) -> Self { + Self(ModularIsEqualExecutor::new(adapter, offset, modulus_limbs)) + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct ModularIsEqualPreCompute { + a: u8, + rs_addrs: [u8; 2], + modulus_limbs: [u8; READ_LIMBS], +} + +impl + VmModularIsEqualExecutor +{ + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut ModularIsEqualPreCompute, + ) -> Result { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + + let local_opcode = + Rv32ModularArithmeticOpcode::from_usize(opcode.local_opcode_idx(self.0.offset)); + + // Validate instruction format + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); + let c = c.as_canonical_u32(); + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + if !matches!( + local_opcode, + Rv32ModularArithmeticOpcode::IS_EQ | Rv32ModularArithmeticOpcode::SETUP_ISEQ + ) { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + let rs_addrs = from_fn(|i| if i == 0 { b } else { c } as u8); + *data = ModularIsEqualPreCompute { + a: a as u8, + rs_addrs, + modulus_limbs: self.0.modulus_limbs, + }; + + let is_setup = local_opcode == Rv32ModularArithmeticOpcode::SETUP_ISEQ; + + Ok(is_setup) } } +impl Executor + for VmModularIsEqualExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + std::mem::size_of::>() + } + + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut ModularIsEqualPreCompute = data.borrow_mut(); + + let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?; + let fn_ptr = if is_setup { + execute_e1_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, true> + } else { + execute_e1_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, false> + }; + + Ok(fn_ptr) + } +} + +impl + MeteredExecutor for VmModularIsEqualExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn metered_pre_compute_size(&self) -> usize { + std::mem::size_of::>>() + } + + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut E2PreCompute> = + data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let is_setup = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + let fn_ptr = if is_setup { + execute_e2_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, true> + } else { + execute_e2_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, false> + }; + + Ok(fn_ptr) + } +} + +unsafe fn execute_e1_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const NUM_LANES: usize, + const LANE_SIZE: usize, + const TOTAL_READ_SIZE: usize, + const IS_SETUP: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &ModularIsEqualPreCompute = pre_compute.borrow(); + + execute_e12_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, IS_SETUP>( + pre_compute, + vm_state, + ); +} + +unsafe fn execute_e2_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const NUM_LANES: usize, + const LANE_SIZE: usize, + const TOTAL_READ_SIZE: usize, + const IS_SETUP: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute> = + pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, IS_SETUP>( + &pre_compute.data, + vm_state, + ); +} + +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const NUM_LANES: usize, + const LANE_SIZE: usize, + const TOTAL_READ_SIZE: usize, + const IS_SETUP: bool, +>( + pre_compute: &ModularIsEqualPreCompute, + vm_state: &mut VmExecState, +) { + // Read register values + let rs_vals = pre_compute + .rs_addrs + .map(|addr| u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, addr as u32))); + + // Read memory values + let [b, c]: [[u8; TOTAL_READ_SIZE]; 2] = rs_vals.map(|address| { + debug_assert!(address as usize + TOTAL_READ_SIZE - 1 < (1 << POINTER_MAX_BITS)); + from_fn::<_, NUM_LANES, _>(|i| { + vm_state.vm_read::<_, LANE_SIZE>(RV32_MEMORY_AS, address + (i * LANE_SIZE) as u32) + }) + .concat() + .try_into() + .unwrap() + }); + + if !IS_SETUP { + let (b_cmp, _) = run_unsigned_less_than::(&b, &pre_compute.modulus_limbs); + debug_assert!(b_cmp, "{:?} >= {:?}", b, pre_compute.modulus_limbs); + } + + let (c_cmp, _) = run_unsigned_less_than::(&c, &pre_compute.modulus_limbs); + debug_assert!(c_cmp, "{:?} >= {:?}", c, pre_compute.modulus_limbs); + + // Compute result + let mut write_data = [0u8; RV32_REGISTER_NUM_LIMBS]; + write_data[0] = (b == c) as u8; + + // Write result to register + vm_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &write_data); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} + // Returns (cmp_result, diff_idx) +#[inline(always)] pub(super) fn run_unsigned_less_than( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], ) -> (bool, usize) { for i in (0..NUM_LIMBS).rev() { if x[i] != y[i] { diff --git a/extensions/algebra/circuit/src/modular_chip/mod.rs b/extensions/algebra/circuit/src/modular_chip/mod.rs index 2dd9838206..24c0aa8b1a 100644 --- a/extensions/algebra/circuit/src/modular_chip/mod.rs +++ b/extensions/algebra/circuit/src/modular_chip/mod.rs @@ -1,17 +1,61 @@ -mod addsub; -pub use addsub::*; +use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; +use openvm_circuit_derive::PreflightExecutor; +use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use openvm_mod_circuit_builder::{FieldExpressionCoreAir, FieldExpressionFiller}; +use openvm_rv32_adapters::{ + Rv32IsEqualModAdapterAir, Rv32IsEqualModAdapterExecutor, Rv32IsEqualModAdapterFiller, + Rv32VecHeapAdapterAir, Rv32VecHeapAdapterFiller, +}; + +use crate::FieldExprVecHeapExecutor; + mod is_eq; pub use is_eq::*; +mod addsub; +pub use addsub::*; mod muldiv; pub use muldiv::*; -use openvm_circuit::arch::VmChipWrapper; -use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; -use openvm_rv32_adapters::Rv32IsEqualModAdapterChip; #[cfg(test)] mod tests; +pub type ModularAir = VmAirWrapper< + Rv32VecHeapAdapterAir<2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>, + FieldExpressionCoreAir, +>; + +pub type ModularExecutor = + FieldExprVecHeapExecutor; + +pub type ModularChip = VmChipWrapper< + F, + FieldExpressionFiller>, +>; + // Must have TOTAL_LIMBS = NUM_LANES * LANE_SIZE +pub type ModularIsEqualAir< + const NUM_LANES: usize, + const LANE_SIZE: usize, + const TOTAL_LIMBS: usize, +> = VmAirWrapper< + Rv32IsEqualModAdapterAir<2, NUM_LANES, LANE_SIZE, TOTAL_LIMBS>, + ModularIsEqualCoreAir, +>; + +#[derive(Clone, PreflightExecutor)] +pub struct VmModularIsEqualExecutor< + const NUM_LANES: usize, + const LANE_SIZE: usize, + const TOTAL_LIMBS: usize, +>( + ModularIsEqualExecutor< + Rv32IsEqualModAdapterExecutor<2, NUM_LANES, LANE_SIZE, TOTAL_LIMBS>, + TOTAL_LIMBS, + RV32_REGISTER_NUM_LIMBS, + RV32_CELL_BITS, + >, +); + pub type ModularIsEqualChip< F, const NUM_LANES: usize, @@ -19,6 +63,10 @@ pub type ModularIsEqualChip< const TOTAL_LIMBS: usize, > = VmChipWrapper< F, - Rv32IsEqualModAdapterChip, - ModularIsEqualCoreChip, + ModularIsEqualFiller< + Rv32IsEqualModAdapterFiller<2, NUM_LANES, LANE_SIZE, TOTAL_LIMBS>, + TOTAL_LIMBS, + RV32_REGISTER_NUM_LIMBS, + RV32_CELL_BITS, + >, >; diff --git a/extensions/algebra/circuit/src/modular_chip/muldiv.rs b/extensions/algebra/circuit/src/modular_chip/muldiv.rs index 30f063e2b1..fef9d1e9a7 100644 --- a/extensions/algebra/circuit/src/modular_chip/muldiv.rs +++ b/extensions/algebra/circuit/src/modular_chip/muldiv.rs @@ -1,21 +1,25 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; +use std::{cell::RefCell, rc::Rc}; use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +use openvm_circuit::{ + arch::ExecutionBridge, + system::memory::{offline_checker::MemoryBridge, SharedMemoryHelper}, +}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_instructions::riscv::RV32_CELL_BITS; use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, FieldVariable, SymbolicExpr, + ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreAir, FieldExpressionExecutor, + FieldExpressionFiller, FieldVariable, SymbolicExpr, +}; +use openvm_rv32_adapters::{ + Rv32VecHeapAdapterAir, Rv32VecHeapAdapterExecutor, Rv32VecHeapAdapterFiller, }; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; + +use super::{ModularAir, ModularChip, ModularExecutor}; +use crate::FieldExprVecHeapExecutor; pub fn muldiv_expr( config: ExprBuilderConfig, @@ -26,17 +30,19 @@ pub fn muldiv_expr( let builder = Rc::new(RefCell::new(builder)); let x = ExprBuilder::new_input(builder.clone()); let y = ExprBuilder::new_input(builder.clone()); - let (z_idx, z) = builder.borrow_mut().new_var(); + let (z_idx, z) = (*builder).borrow_mut().new_var(); let mut z = FieldVariable::from_var(builder.clone(), z); - let is_mul_flag = builder.borrow_mut().new_flag(); - let is_div_flag = builder.borrow_mut().new_flag(); + let is_mul_flag = (*builder).borrow_mut().new_flag(); + let is_div_flag = (*builder).borrow_mut().new_flag(); // constraint is x * y = z, or z * y = x let lvar = FieldVariable::select(is_mul_flag, &x, &z); let rvar = FieldVariable::select(is_mul_flag, &z, &x); // When it's SETUP op, x = p == 0, y = 0, both flags are false, and it still works: z * 0 - x = // 0, whatever z is. let constraint = lvar * y.clone() - rvar; - builder.borrow_mut().set_constraint(z_idx, constraint.expr); + (*builder) + .borrow_mut() + .set_constraint(z_idx, constraint.expr); let compute = SymbolicExpr::Select( is_mul_flag, Box::new(x.expr.clone() * y.expr.clone()), @@ -46,10 +52,10 @@ pub fn muldiv_expr( Box::new(x.expr.clone()), )), ); - builder.borrow_mut().set_compute(z_idx, compute); + (*builder).borrow_mut().set_compute(z_idx, compute); z.save_output(); - let builder = builder.borrow().clone(); + let builder = (*builder).borrow().clone(); ( FieldExpr::new(builder, range_bus, true), @@ -58,39 +64,78 @@ pub fn muldiv_expr( ) } -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct ModularMulDivChip( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); +fn gen_base_expr( + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, +) -> (FieldExpr, Vec, Vec) { + let (expr, is_mul_flag, is_div_flag) = muldiv_expr(config, range_checker_bus); + + let local_opcode_idx = vec![ + Rv32ModularArithmeticOpcode::MUL as usize, + Rv32ModularArithmeticOpcode::DIV as usize, + Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize, + ]; + let opcode_flag_idx = vec![is_mul_flag, is_div_flag]; + + (expr, local_opcode_idx, opcode_flag_idx) +} + +pub fn get_modular_muldiv_air( + exec_bridge: ExecutionBridge, + mem_bridge: MemoryBridge, + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, + bitwise_lookup_bus: BitwiseOperationLookupBus, + pointer_max_bits: usize, + offset: usize, +) -> ModularAir { + let (expr, local_opcode_idx, opcode_flag_idx) = gen_base_expr(config, range_checker_bus); + ModularAir::new( + Rv32VecHeapAdapterAir::new( + exec_bridge, + mem_bridge, + bitwise_lookup_bus, + pointer_max_bits, + ), + FieldExpressionCoreAir::new(expr, offset, local_opcode_idx, opcode_flag_idx), + ) +} + +pub fn get_modular_muldiv_step( + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, + pointer_max_bits: usize, + offset: usize, +) -> ModularExecutor { + let (expr, local_opcode_idx, opcode_flag_idx) = gen_base_expr(config, range_checker_bus); -impl - ModularMulDivChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - config: ExprBuilderConfig, - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let (expr, is_mul_flag, is_div_flag) = muldiv_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( + FieldExprVecHeapExecutor(FieldExpressionExecutor::new( + Rv32VecHeapAdapterExecutor::new(pointer_max_bits), + expr, + offset, + local_opcode_idx, + opcode_flag_idx, + "ModularMulDiv", + )) +} + +pub fn get_modular_muldiv_chip( + config: ExprBuilderConfig, + mem_helper: SharedMemoryHelper, + range_checker: SharedVariableRangeCheckerChip, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pointer_max_bits: usize, +) -> ModularChip { + let (expr, local_opcode_idx, opcode_flag_idx) = gen_base_expr(config, range_checker.bus()); + ModularChip::new( + FieldExpressionFiller::new( + Rv32VecHeapAdapterFiller::new(pointer_max_bits, bitwise_lookup_chip), expr, - offset, - vec![ - Rv32ModularArithmeticOpcode::MUL as usize, - Rv32ModularArithmeticOpcode::DIV as usize, - Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize, - ], - vec![is_mul_flag, is_div_flag], + local_opcode_idx, + opcode_flag_idx, range_checker, - "ModularMulDiv", false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } + ), + mem_helper, + ) } diff --git a/extensions/algebra/circuit/src/modular_chip/tests.rs b/extensions/algebra/circuit/src/modular_chip/tests.rs index 1ad3310f76..c599d40b05 100644 --- a/extensions/algebra/circuit/src/modular_chip/tests.rs +++ b/extensions/algebra/circuit/src/modular_chip/tests.rs @@ -1,122 +1,135 @@ -use std::{array::from_fn, borrow::BorrowMut}; +use std::{array::from_fn, borrow::BorrowMut, sync::Arc}; use num_bigint::BigUint; use num_traits::Zero; use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode; use openvm_circuit::arch::{ instructions::LocalOpcode, - testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - AdapterRuntimeContext, Result, VmAdapterInterface, VmChipWrapper, VmCoreChip, + testing::{TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, + Arena, DenseRecordArena, MatrixRecordArena, PreflightExecutor, }; use openvm_circuit_primitives::{ bigint::utils::{big_uint_to_limbs, secp256k1_coord_prime, secp256k1_scalar_prime}, - bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + bitwise_op_lookup::{ + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, + }, }; use openvm_instructions::{instruction::Instruction, riscv::RV32_CELL_BITS, VmOpcode}; use openvm_mod_circuit_builder::{ test_utils::{biguint_to_limbs, generate_field_element}, - ExprBuilderConfig, + ExprBuilderConfig, FieldExpressionCoreRecordMut, }; use openvm_pairing_guest::bls12_381::BLS12_381_MODULUS; -use openvm_rv32_adapters::{ - rv32_write_heap_default, write_ptr_reg, Rv32IsEqualModAdapterChip, Rv32VecHeapAdapterChip, -}; +use openvm_rv32_adapters::{rv32_write_heap_default, write_ptr_reg, Rv32VecHeapAdapterRecord}; use openvm_rv32im_circuit::adapters::RV32_REGISTER_NUM_LIMBS; -use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32}; +use openvm_stark_backend::p3_field::FieldAlgebra; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::Rng; +use rand::{rngs::StdRng, Rng}; -use super::{ - ModularAddSubChip, ModularIsEqualChip, ModularIsEqualCoreAir, ModularIsEqualCoreChip, - ModularIsEqualCoreCols, ModularIsEqualCoreRecord, ModularMulDivChip, +use crate::modular_chip::{ + get_modular_addsub_air, get_modular_addsub_chip, get_modular_addsub_step, + get_modular_muldiv_air, get_modular_muldiv_chip, get_modular_muldiv_step, ModularAir, + ModularChip, ModularExecutor, ModularIsEqualAir, ModularIsEqualChip, ModularIsEqualCoreAir, + ModularIsEqualCoreCols, ModularIsEqualFiller, VmModularIsEqualExecutor, }; const NUM_LIMBS: usize = 32; const LIMB_BITS: usize = 8; -const BLOCK_SIZE: usize = 32; +const _BLOCK_SIZE: usize = 32; +const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; -const ADD_LOCAL: usize = Rv32ModularArithmeticOpcode::ADD as usize; -const MUL_LOCAL: usize = Rv32ModularArithmeticOpcode::MUL as usize; +#[cfg(test)] +mod addsubtests { + use test_case::test_case; -#[test] -fn test_coord_addsub() { - let opcode_offset = 0; - let modulus = secp256k1_coord_prime(); - test_addsub(opcode_offset, modulus); -} + use super::*; -#[test] -fn test_scalar_addsub() { - let opcode_offset = 4; - let modulus = secp256k1_scalar_prime(); - test_addsub(opcode_offset, modulus); -} + type Harness = TestChipHarness< + F, + ModularExecutor<1, NUM_LIMBS>, + ModularAir<1, NUM_LIMBS>, + ModularChip, + RA, + >; + const ADD_LOCAL: usize = Rv32ModularArithmeticOpcode::ADD as usize; -fn test_addsub(opcode_offset: usize, modulus: BigUint) { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: modulus.clone(), - num_limbs: NUM_LIMBS, - limb_bits: LIMB_BITS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - // doing 1xNUM_LIMBS reads and writes - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = ModularAddSubChip::new( - adapter, - config, - Rv32ModularArithmeticOpcode::CLASS_OFFSET + opcode_offset, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - let mut rng = create_seeded_rng(); - let num_tests = 50; - let mut all_ops = vec![ADD_LOCAL + 2]; // setup - let mut all_a = vec![modulus.clone()]; - let mut all_b = vec![BigUint::zero()]; - - // First loop: generate all random test data. - for _ in 0..num_tests { - let a_digits: Vec<_> = (0..NUM_LIMBS) - .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) - .collect(); - let mut a = BigUint::new(a_digits.clone()); - let b_digits: Vec<_> = (0..NUM_LIMBS) - .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) - .collect(); - let mut b = BigUint::new(b_digits.clone()); - - let op = rng.gen_range(0..2) + ADD_LOCAL; // 0 for add, 1 for sub - a %= &modulus; - b %= &modulus; - - all_ops.push(op); - all_a.push(a); - all_b.push(b); + fn create_test_chip( + tester: &VmChipTestBuilder, + config: ExprBuilderConfig, + offset: usize, + ) -> ( + Harness, + ( + BitwiseOperationLookupAir, + SharedBitwiseOperationLookupChip, + ), + ) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + + let air = get_modular_addsub_air( + tester.execution_bridge(), + tester.memory_bridge(), + config.clone(), + tester.range_checker().bus(), + bitwise_bus, + tester.address_bits(), + offset, + ); + let executor = get_modular_addsub_step( + config.clone(), + tester.range_checker().bus(), + tester.address_bits(), + offset, + ); + let chip = get_modular_addsub_chip( + config, + tester.memory_helper(), + tester.range_checker(), + bitwise_chip.clone(), + tester.address_bits(), + ); + let harness = Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + + (harness, (bitwise_chip.air, bitwise_chip)) } - // Second loop: actually run the tests. - for i in 0..=num_tests { - let op = all_ops[i]; - let a = all_a[i].clone(); - let b = all_b[i].clone(); - if i > 0 { - // if not setup - assert!(a < modulus); - assert!(b < modulus); - } + fn set_and_execute_addsub( + tester: &mut VmChipTestBuilder, + harness: &mut Harness, + modulus: &BigUint, + is_setup: bool, + offset: usize, + ) where + ModularExecutor<1, NUM_LIMBS>: PreflightExecutor, + { + let mut rng = create_seeded_rng(); + + let (a, b, op) = if is_setup { + (modulus.clone(), BigUint::zero(), ADD_LOCAL + 2) + } else { + let a_digits: Vec<_> = (0..NUM_LIMBS) + .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) + .collect(); + let mut a = BigUint::new(a_digits.clone()); + let b_digits: Vec<_> = (0..NUM_LIMBS) + .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) + .collect(); + let mut b = BigUint::new(b_digits.clone()); + + let op = rng.gen_range(0..2) + ADD_LOCAL; // 0 for add, 1 for sub + a %= modulus; + b %= modulus; + (a, b, op) + }; + let expected_answer = match op - ADD_LOCAL { - 0 => (&a + &b) % &modulus, - 1 => (&a + &modulus - &b) % &modulus, - 2 => a.clone() % &modulus, + 0 => (&a + &b) % modulus, + 1 => (&a + modulus - &b) % modulus, + 2 => a.clone() % modulus, _ => panic!(), }; @@ -133,11 +146,11 @@ fn test_addsub(opcode_offset: usize, modulus: BigUint) { let data_as = 2; let address1 = 0u32; let address2 = 128u32; - let address3 = (1 << 28) + 1234; // a large memory address to test heap adapter + let address3 = (1 << 28) + 1228; // a large memory address to test heap adapter - write_ptr_reg(&mut tester, ptr_as, addr_ptr1, address1); - write_ptr_reg(&mut tester, ptr_as, addr_ptr2, address2); - write_ptr_reg(&mut tester, ptr_as, addr_ptr3, address3); + write_ptr_reg(tester, ptr_as, addr_ptr1, address1); + write_ptr_reg(tester, ptr_as, addr_ptr2, address2); + write_ptr_reg(tester, ptr_as, addr_ptr3, address3); let a_limbs: [BabyBear; NUM_LIMBS] = biguint_to_limbs(a.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); @@ -147,105 +160,133 @@ fn test_addsub(opcode_offset: usize, modulus: BigUint) { tester.write(data_as, address2 as usize, b_limbs); let instruction = Instruction::from_isize( - VmOpcode::from_usize(chip.0.core.air.offset + op), + VmOpcode::from_usize(offset + op), addr_ptr3 as isize, addr_ptr1 as isize, addr_ptr2 as isize, ptr_as as isize, data_as as isize, ); - tester.execute(&mut chip, &instruction); + tester.execute(harness, &instruction); let expected_limbs = biguint_to_limbs::(expected_answer, LIMB_BITS); - for (i, expected) in expected_limbs.into_iter().enumerate() { - let address = address3 as usize + i; - let read_val = tester.read_cell(data_as, address); - assert_eq!(BabyBear::from_canonical_u32(expected), read_val); - } + let read_vals = tester.read::(data_as, address3 as usize); + assert_eq!(read_vals, expected_limbs.map(F::from_canonical_u32)); } - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} + #[test_case(0, secp256k1_coord_prime(), 50)] + #[test_case(4, secp256k1_scalar_prime(), 50)] + fn test_addsub(opcode_offset: usize, modulus: BigUint, num_ops: usize) { + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); + let offset = Rv32ModularArithmeticOpcode::CLASS_OFFSET + opcode_offset; + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: NUM_LIMBS, + limb_bits: LIMB_BITS, + }; -#[test] -fn test_coord_muldiv() { - let opcode_offset = 0; - let modulus = secp256k1_coord_prime(); - test_muldiv(opcode_offset, modulus); -} + // doing 1xNUM_LIMBS reads and writes + let (mut harness, bitwise) = + create_test_chip::>(&tester, config, offset); -#[test] -fn test_scalar_muldiv() { - let opcode_offset = 4; - let modulus = secp256k1_scalar_prime(); - test_muldiv(opcode_offset, modulus); -} + for i in 0..num_ops { + set_and_execute_addsub(&mut tester, &mut harness, &modulus, i == 0, offset); + } -fn test_muldiv(opcode_offset: usize, modulus: BigUint) { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: modulus.clone(), - num_limbs: NUM_LIMBS, - limb_bits: LIMB_BITS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - // doing 1xNUM_LIMBS reads and writes - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = ModularMulDivChip::new( - adapter, - config, - Rv32ModularArithmeticOpcode::CLASS_OFFSET + opcode_offset, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - let mut rng = create_seeded_rng(); - let num_tests = 50; - let mut all_ops = vec![MUL_LOCAL + 2]; - let mut all_a = vec![modulus.clone()]; - let mut all_b = vec![BigUint::zero()]; - - // First loop: generate all random test data. - for _ in 0..num_tests { - let a_digits: Vec<_> = (0..NUM_LIMBS) - .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) - .collect(); - let mut a = BigUint::new(a_digits.clone()); - let b_digits: Vec<_> = (0..NUM_LIMBS) - .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) - .collect(); - let mut b = BigUint::new(b_digits.clone()); - - // let op = rng.gen_range(2..4); // 2 for mul, 3 for div - let op = MUL_LOCAL; - a %= &modulus; - b %= &modulus; - - all_ops.push(op); - all_a.push(a); - all_b.push(b); + let tester = tester + .build() + .load(harness) + .load_periphery(bitwise) + .finalize(); + tester.simple_test().expect("Verification failed"); } - // Second loop: actually run the tests. - for i in 0..=num_tests { - let op = all_ops[i]; - let a = all_a[i].clone(); - let b = all_b[i].clone(); - if i > 0 { - // if not setup - assert!(a < modulus); - assert!(b < modulus); + + #[test_case(0, secp256k1_coord_prime(), 50)] + #[test_case(4, secp256k1_scalar_prime(), 50)] + fn dense_record_arena_test(opcode_offset: usize, modulus: BigUint, num_ops: usize) { + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: NUM_LIMBS, + limb_bits: LIMB_BITS, + }; + let offset = Rv32ModularArithmeticOpcode::CLASS_OFFSET + opcode_offset; + + let (mut sparse_harness, bitwise) = + create_test_chip::>(&tester, config.clone(), offset); + + { + // doing 1xNUM_LIMBS reads and writes + let mut dense_harness = create_test_chip::(&tester, config, offset).0; + + for i in 0..num_ops { + set_and_execute_addsub(&mut tester, &mut dense_harness, &modulus, i == 0, offset); + } + + type Record<'a> = ( + &'a mut Rv32VecHeapAdapterRecord<2, 1, 1, NUM_LIMBS, NUM_LIMBS>, + FieldExpressionCoreRecordMut<'a>, + ); + let mut record_interpreter = dense_harness.arena.get_record_seeker::(); + record_interpreter.transfer_to_matrix_arena( + &mut sparse_harness.arena, + dense_harness.executor.get_record_layout::(), + ); } + + let tester = tester + .build() + .load(sparse_harness) + .load_periphery(bitwise) + .finalize(); + tester.simple_test().expect("Verification failed"); + } +} + +#[cfg(test)] +mod muldivtests { + use test_case::test_case; + + use super::*; + + const MUL_LOCAL: usize = Rv32ModularArithmeticOpcode::MUL as usize; + type Harness = TestChipHarness< + F, + ModularExecutor<1, NUM_LIMBS>, + ModularAir<1, NUM_LIMBS>, + ModularChip, + >; + + fn set_and_execute_muldiv( + tester: &mut VmChipTestBuilder, + harness: &mut Harness, + modulus: &BigUint, + is_setup: bool, + ) { + let mut rng = create_seeded_rng(); + + let (a, b, op) = if is_setup { + (modulus.clone(), BigUint::zero(), MUL_LOCAL + 2) + } else { + let a_digits: Vec<_> = (0..NUM_LIMBS) + .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) + .collect(); + let mut a = BigUint::new(a_digits.clone()); + let b_digits: Vec<_> = (0..NUM_LIMBS) + .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) + .collect(); + let mut b = BigUint::new(b_digits.clone()); + + let op = rng.gen_range(0..2) + MUL_LOCAL; // 0 for add, 1 for sub + a %= modulus; + b %= modulus; + (a, b, op) + }; + let expected_answer = match op - MUL_LOCAL { - 0 => (&a * &b) % &modulus, - 1 => (&a * b.modinv(&modulus).unwrap()) % &modulus, - 2 => a.clone() % &modulus, + 0 => (&a * &b) % modulus, + 1 => (&a * b.modinv(modulus).unwrap()) % modulus, + 2 => a.clone() % modulus, _ => panic!(), }; @@ -264,307 +305,411 @@ fn test_muldiv(opcode_offset: usize, modulus: BigUint) { let address2 = 128; let address3 = 256; - write_ptr_reg(&mut tester, ptr_as, addr_ptr1, address1); - write_ptr_reg(&mut tester, ptr_as, addr_ptr2, address2); - write_ptr_reg(&mut tester, ptr_as, addr_ptr3, address3); + write_ptr_reg(tester, ptr_as, addr_ptr1, address1); + write_ptr_reg(tester, ptr_as, addr_ptr2, address2); + write_ptr_reg(tester, ptr_as, addr_ptr3, address3); - let a_limbs: [BabyBear; NUM_LIMBS] = - biguint_to_limbs(a.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); + let a_limbs: [F; NUM_LIMBS] = + biguint_to_limbs(a.clone(), LIMB_BITS).map(F::from_canonical_u32); tester.write(data_as, address1 as usize, a_limbs); - let b_limbs: [BabyBear; NUM_LIMBS] = - biguint_to_limbs(b.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); + let b_limbs: [F; NUM_LIMBS] = + biguint_to_limbs(b.clone(), LIMB_BITS).map(F::from_canonical_u32); tester.write(data_as, address2 as usize, b_limbs); let instruction = Instruction::from_isize( - VmOpcode::from_usize(chip.0.core.air.offset + op), + VmOpcode::from_usize(harness.executor.offset + op), addr_ptr3 as isize, addr_ptr1 as isize, addr_ptr2 as isize, ptr_as as isize, data_as as isize, ); - tester.execute(&mut chip, &instruction); + tester.execute(harness, &instruction); let expected_limbs = biguint_to_limbs::(expected_answer, LIMB_BITS); - for (i, expected) in expected_limbs.into_iter().enumerate() { - let address = address3 as usize + i; - let read_val = tester.read_cell(data_as, address); - assert_eq!(BabyBear::from_canonical_u32(expected), read_val); - } + let read_vals = tester.read::(data_as, address3 as usize); + assert_eq!(read_vals, expected_limbs.map(F::from_canonical_u32)); } - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - - tester.simple_test().expect("Verification failed"); -} -fn test_is_equal( - opcode_offset: usize, - modulus: BigUint, - num_tests: usize, -) { - let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = ModularIsEqualChip::::new( - Rv32IsEqualModAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), + #[test_case(0, secp256k1_coord_prime(), 50)] + #[test_case(4, secp256k1_scalar_prime(), 50)] + fn test_muldiv(opcode_offset: usize, modulus: BigUint, num_ops: usize) { + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: NUM_LIMBS, + limb_bits: LIMB_BITS, + }; + let offset = Rv32ModularArithmeticOpcode::CLASS_OFFSET + opcode_offset; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + // doing 1xNUM_LIMBS reads and writes + let air = get_modular_muldiv_air( + tester.execution_bridge(), tester.memory_bridge(), + config.clone(), + tester.range_checker().bus(), + bitwise_bus, tester.address_bits(), + offset, + ); + let executor = get_modular_muldiv_step( + config.clone(), + tester.range_checker().bus(), + tester.address_bits(), + offset, + ); + let chip = get_modular_muldiv_chip( + config, + tester.memory_helper(), + tester.range_checker(), bitwise_chip.clone(), - ), - ModularIsEqualCoreChip::new(modulus.clone(), bitwise_chip.clone(), opcode_offset), - tester.offline_memory_mutex_arc(), - ); + tester.address_bits(), + ); - { - let vec = big_uint_to_limbs(&modulus, LIMB_BITS); - let modulus_limbs: [F; TOTAL_LIMBS] = std::array::from_fn(|i| { - if i < vec.len() { - F::from_canonical_usize(vec[i]) - } else { - F::ZERO - } - }); + let mut harness = Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); - let setup_instruction = rv32_write_heap_default::( - &mut tester, - vec![modulus_limbs], - vec![[F::ZERO; TOTAL_LIMBS]], - opcode_offset + Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize, + for i in 0..num_ops { + set_and_execute_muldiv(&mut tester, &mut harness, &modulus, i == 0); + } + let tester = tester + .build() + .load(harness) + .load_periphery((bitwise_chip.air, bitwise_chip)) + .finalize(); + + tester.simple_test().expect("Verification failed"); + } +} + +#[cfg(test)] +mod is_equal_tests { + use openvm_rv32_adapters::{ + Rv32IsEqualModAdapterAir, Rv32IsEqualModAdapterExecutor, Rv32IsEqualModAdapterFiller, + }; + use openvm_stark_backend::{ + p3_air::BaseAir, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, + utils::disable_debug_builder, + verifier::VerificationError, + }; + + use super::*; + + type Harness = + TestChipHarness< + F, + VmModularIsEqualExecutor, + ModularIsEqualAir, + ModularIsEqualChip, + >; + + fn create_test_chips< + const NUM_LANES: usize, + const LANE_SIZE: usize, + const TOTAL_LIMBS: usize, + >( + tester: &mut VmChipTestBuilder, + modulus: &BigUint, + modulus_limbs: [u8; TOTAL_LIMBS], + offset: usize, + ) -> ( + Harness, + ( + BitwiseOperationLookupAir, + SharedBitwiseOperationLookupChip, + ), + ) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new(bitwise_bus)); + + let air = ModularIsEqualAir::new( + Rv32IsEqualModAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + ), + ModularIsEqualCoreAir::new(modulus.clone(), bitwise_bus, offset), ); - tester.execute(&mut chip, &setup_instruction); + let executor = VmModularIsEqualExecutor::new( + Rv32IsEqualModAdapterExecutor::new(tester.address_bits()), + offset, + modulus_limbs, + ); + let chip = ModularIsEqualChip::::new( + ModularIsEqualFiller::new( + Rv32IsEqualModAdapterFiller::new(tester.address_bits(), bitwise_chip.clone()), + offset, + modulus_limbs, + bitwise_chip.clone(), + ), + tester.memory_helper(), + ); + let harness = Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + + (harness, (bitwise_chip.air, bitwise_chip)) } - for _ in 0..num_tests { - let b = generate_field_element::(&modulus, &mut rng); - let c = if rng.gen_bool(0.5) { - b + + #[allow(clippy::too_many_arguments)] + fn set_and_execute_is_equal< + const NUM_LANES: usize, + const LANE_SIZE: usize, + const TOTAL_LIMBS: usize, + >( + tester: &mut VmChipTestBuilder, + harness: &mut Harness, + rng: &mut StdRng, + modulus: &BigUint, + offset: usize, + modulus_limbs: [F; TOTAL_LIMBS], + is_setup: bool, + b: Option<[F; TOTAL_LIMBS]>, + c: Option<[F; TOTAL_LIMBS]>, + ) { + let instruction = if is_setup { + rv32_write_heap_default::( + tester, + vec![modulus_limbs], + vec![[F::ZERO; TOTAL_LIMBS]], + offset + Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize, + ) } else { - generate_field_element::(&modulus, &mut rng) + let b = b.unwrap_or( + generate_field_element::(modulus, rng) + .map(F::from_canonical_u32), + ); + let c = c.unwrap_or(if rng.gen_bool(0.5) { + b + } else { + generate_field_element::(modulus, rng) + .map(F::from_canonical_u32) + }); + + rv32_write_heap_default::( + tester, + vec![b], + vec![c], + offset + Rv32ModularArithmeticOpcode::IS_EQ as usize, + ) }; + tester.execute(harness, &instruction); + } - let instruction = rv32_write_heap_default::( - &mut tester, - vec![b.map(F::from_canonical_u32)], - vec![c.map(F::from_canonical_u32)], - opcode_offset + Rv32ModularArithmeticOpcode::IS_EQ as usize, - ); - tester.execute(&mut chip, &instruction); + ////////////////////////////////////////////////////////////////////////////////////// + // POSITIVE TESTS + // + // Randomly generate computations and execute, ensuring that the generated trace + // passes all constraints. + ////////////////////////////////////////////////////////////////////////////////////// + + #[test] + fn test_modular_is_equal_1x32() { + test_is_equal::<1, 32, 32>(17, secp256k1_coord_prime(), 100); } - // Special case where b == c are close to the prime - let b_vec = big_uint_to_limbs(&modulus, LIMB_BITS); - let mut b = from_fn(|i| if i < b_vec.len() { b_vec[i] as u32 } else { 0 }); - b[0] -= 1; - let instruction = rv32_write_heap_default::( - &mut tester, - vec![b.map(F::from_canonical_u32)], - vec![b.map(F::from_canonical_u32)], - opcode_offset + Rv32ModularArithmeticOpcode::IS_EQ as usize, - ); - tester.execute(&mut chip, &instruction); - - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} + #[test] + fn test_modular_is_equal_3x16() { + test_is_equal::<3, 16, 48>(17, BLS12_381_MODULUS.clone(), 100); + } -#[test] -fn test_modular_is_equal_1x32() { - test_is_equal::<1, 32, 32>(17, secp256k1_coord_prime(), 100); -} + fn test_is_equal( + opcode_offset: usize, + modulus: BigUint, + num_tests: usize, + ) { + let mut rng = create_seeded_rng(); + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); -#[test] -fn test_modular_is_equal_3x16() { - test_is_equal::<3, 16, 48>(17, BLS12_381_MODULUS.clone(), 100); -} + let vec = big_uint_to_limbs(&modulus, LIMB_BITS); + let modulus_limbs: [u8; TOTAL_LIMBS] = + from_fn(|i| if i < vec.len() { vec[i] as u8 } else { 0 }); -// Wrapper chip for testing a bad setup row -type BadModularIsEqualChip< - F, - const NUM_LANES: usize, - const LANE_SIZE: usize, - const TOTAL_LIMBS: usize, -> = VmChipWrapper< - F, - Rv32IsEqualModAdapterChip, - BadModularIsEqualCoreChip, ->; - -// Wrapper chip for testing a bad setup row -struct BadModularIsEqualCoreChip< - const READ_LIMBS: usize, - const WRITE_LIMBS: usize, - const LIMB_BITS: usize, -> { - chip: ModularIsEqualCoreChip, -} + let (mut harness, bitwise) = create_test_chips::( + &mut tester, + &modulus, + modulus_limbs, + opcode_offset, + ); -impl - BadModularIsEqualCoreChip -{ - pub fn new( - modulus: BigUint, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - offset: usize, - ) -> Self { - Self { - chip: ModularIsEqualCoreChip::new(modulus, bitwise_lookup_chip, offset), + let modulus_limbs = modulus_limbs.map(F::from_canonical_u8); + + for i in 0..num_tests { + set_and_execute_is_equal( + &mut tester, + &mut harness, + &mut rng, + &modulus, + opcode_offset, + modulus_limbs, + i == 0, // the first test is a setup test + None, + None, + ); } + + // Special case where b == c are close to the prime + let mut b = modulus_limbs; + b[0] -= F::ONE; + set_and_execute_is_equal( + &mut tester, + &mut harness, + &mut rng, + &modulus, + opcode_offset, + modulus_limbs, + false, + Some(b), + Some(b), + ); + + let tester = tester + .build() + .load(harness) + .load_periphery(bitwise) + .finalize(); + tester.simple_test().expect("Verification failed"); } -} -impl< - F: PrimeField32, - I: VmAdapterInterface, + ////////////////////////////////////////////////////////////////////////////////////// + // NEGATIVE TESTS + // + // Given a fake trace of a single operation, setup a chip and run the test. We replace + // part of the trace and check that the chip throws the expected error. + ////////////////////////////////////////////////////////////////////////////////////// + + /// Negative tests test for 3 "type" of errors determined by the value of b[0]: + fn run_negative_is_equal_test< + const NUM_LANES: usize, + const LANE_SIZE: usize, const READ_LIMBS: usize, - const WRITE_LIMBS: usize, - const LIMB_BITS: usize, - > VmCoreChip for BadModularIsEqualCoreChip -where - I::Reads: Into<[[F; READ_LIMBS]; 2]>, - I::Writes: From<[[F; WRITE_LIMBS]; 1]>, -{ - type Record = ModularIsEqualCoreRecord; - type Air = ModularIsEqualCoreAir; - - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, - instruction: &Instruction, - from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - // Override the b_diff_idx to be out of bounds. - // This will cause lt_marker to be all zeros except a 2. - // There was a bug in this case which allowed b to be less than N. - self.chip.execute_instruction(instruction, from_pc, reads) - } + >( + modulus: BigUint, + opcode_offset: usize, + test_case: usize, + expected_error: VerificationError, + ) { + let mut rng = create_seeded_rng(); + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - fn get_opcode_name(&self, opcode: usize) -> String { - as VmCoreChip>::get_opcode_name(&self.chip, opcode) - } + let vec = big_uint_to_limbs(&modulus, LIMB_BITS); + let modulus_limbs: [u8; READ_LIMBS] = + from_fn(|i| if i < vec.len() { vec[i] as u8 } else { 0 }); - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - as VmCoreChip>::generate_trace_row(&self.chip, row_slice, record.clone()); - let row_slice: &mut ModularIsEqualCoreCols<_, READ_LIMBS> = row_slice.borrow_mut(); - // decide which bug to test based on b[0] - if record.b[0] == F::ONE { - // test the constraint that c_lt_mark = 2 when is_setup = 1 - row_slice.c_lt_mark = F::ONE; - row_slice.lt_marker = [F::ZERO; READ_LIMBS]; - row_slice.lt_marker[READ_LIMBS - 1] = F::ONE; - row_slice.c_lt_diff = - F::from_canonical_u32(self.chip.air.modulus_limbs[READ_LIMBS - 1]) - - record.c[READ_LIMBS - 1]; - row_slice.b_lt_diff = - F::from_canonical_u32(self.chip.air.modulus_limbs[READ_LIMBS - 1]) - - record.b[READ_LIMBS - 1]; - } else if record.b[0] == F::from_canonical_u32(2) { - // test the constraint that b[i] = N[i] for all i when prefix_sum is not 1 or - // lt_marker_sum - is_setup - row_slice.c_lt_mark = F::from_canonical_u8(2); - row_slice.lt_marker = [F::ZERO; READ_LIMBS]; - row_slice.lt_marker[READ_LIMBS - 1] = F::from_canonical_u8(2); - row_slice.c_lt_diff = - F::from_canonical_u32(self.chip.air.modulus_limbs[READ_LIMBS - 1]) - - record.c[READ_LIMBS - 1]; - } else if record.b[0] == F::from_canonical_u32(3) { - // test the constraint that sum_i lt_marker[i] = 2 when is_setup = 1 - row_slice.c_lt_mark = F::from_canonical_u8(2); - row_slice.lt_marker = [F::ZERO; READ_LIMBS]; - row_slice.lt_marker[READ_LIMBS - 1] = F::from_canonical_u8(2); - row_slice.lt_marker[0] = F::ONE; - row_slice.b_lt_diff = - F::from_canonical_u32(self.chip.air.modulus_limbs[0]) - record.b[0]; - row_slice.c_lt_diff = - F::from_canonical_u32(self.chip.air.modulus_limbs[READ_LIMBS - 1]) - - record.c[READ_LIMBS - 1]; - } - } + let (mut harness, bitwise) = create_test_chips::( + &mut tester, + &modulus, + modulus_limbs, + opcode_offset, + ); - fn air(&self) -> &Self::Air { - as VmCoreChip>::air( - &self.chip, - ) - } -} + let modulus_limbs = modulus_limbs.map(F::from_canonical_u8); -// Test that passes the wrong modulus in the setup instruction. -// This proof should fail to verify. -fn test_is_equal_setup_bad< - const NUM_LANES: usize, - const LANE_SIZE: usize, - const TOTAL_LIMBS: usize, ->( - opcode_offset: usize, - modulus: BigUint, - b_val: u32, /* used to select which bug to test. currently only 1, 2, and 3 are supported - * (because there are three bugs to test) */ -) { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = BadModularIsEqualChip::::new( - Rv32IsEqualModAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ), - BadModularIsEqualCoreChip::new(modulus.clone(), bitwise_chip.clone(), opcode_offset), - tester.offline_memory_mutex_arc(), - ); - - let mut b_limbs = [F::ZERO; TOTAL_LIMBS]; - b_limbs[0] = F::from_canonical_u32(b_val); - let setup_instruction = rv32_write_heap_default::( - &mut tester, - vec![b_limbs], - vec![[F::ZERO; TOTAL_LIMBS]], - opcode_offset + Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize, - ); - tester.execute(&mut chip, &setup_instruction); - - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} + set_and_execute_is_equal( + &mut tester, + &mut harness, + &mut rng, + &modulus, + opcode_offset, + modulus_limbs, + true, + None, + None, + ); -#[should_panic] -#[test] -fn test_modular_is_equal_setup_bad_1_1x32() { - test_is_equal_setup_bad::<1, 32, 32>(17, secp256k1_coord_prime(), 1); -} + let adapter_width = BaseAir::::width(&harness.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { + let mut trace_row = trace.row_slice(0).to_vec(); + let cols: &mut ModularIsEqualCoreCols<_, READ_LIMBS> = + trace_row.split_at_mut(adapter_width).1.borrow_mut(); + if test_case == 1 { + // test the constraint that c_lt_mark = 2 when is_setup = 1 + cols.b[0] = F::from_canonical_u32(1); + cols.c_lt_mark = F::ONE; + cols.lt_marker = [F::ZERO; READ_LIMBS]; + cols.lt_marker[READ_LIMBS - 1] = F::ONE; + cols.c_lt_diff = modulus_limbs[READ_LIMBS - 1] - cols.c[READ_LIMBS - 1]; + cols.b_lt_diff = modulus_limbs[READ_LIMBS - 1] - cols.b[READ_LIMBS - 1]; + } else if test_case == 2 { + // test the constraint that b[i] = N[i] for all i when prefix_sum is not 1 or + // lt_marker_sum - is_setup + cols.b[0] = F::from_canonical_u32(2); + cols.c_lt_mark = F::from_canonical_u8(2); + cols.lt_marker = [F::ZERO; READ_LIMBS]; + cols.lt_marker[READ_LIMBS - 1] = F::from_canonical_u8(2); + cols.c_lt_diff = modulus_limbs[READ_LIMBS - 1] - cols.c[READ_LIMBS - 1]; + } else if test_case == 3 { + // test the constraint that sum_i lt_marker[i] = 2 when is_setup = 1 + cols.b[0] = F::from_canonical_u32(3); + cols.c_lt_mark = F::from_canonical_u8(2); + cols.lt_marker = [F::ZERO; READ_LIMBS]; + cols.lt_marker[READ_LIMBS - 1] = F::from_canonical_u8(2); + cols.lt_marker[0] = F::ONE; + cols.b_lt_diff = modulus_limbs[0] - cols.b[0]; + cols.c_lt_diff = modulus_limbs[READ_LIMBS - 1] - cols.c[READ_LIMBS - 1]; + } + *trace = RowMajorMatrix::new(trace_row, trace.width()); + }; -#[should_panic] -#[test] -fn test_modular_is_equal_setup_bad_2_1x32_2() { - test_is_equal_setup_bad::<1, 32, 32>(17, secp256k1_coord_prime(), 2); -} + disable_debug_builder(); + let tester = tester + .build() + .load_and_prank_trace(harness, modify_trace) + .load_periphery(bitwise) + .finalize(); + tester.simple_test_with_expected_error(expected_error); + } -#[should_panic] -#[test] -fn test_modular_is_equal_setup_bad_3_1x32() { - test_is_equal_setup_bad::<1, 32, 32>(17, secp256k1_coord_prime(), 3); -} + #[test] + fn negative_test_modular_is_equal_1x32() { + run_negative_is_equal_test::<1, 32, 32>( + secp256k1_coord_prime(), + 17, + 1, + VerificationError::OodEvaluationMismatch, + ); -#[should_panic] -#[test] -fn test_modular_is_equal_setup_bad_1_3x16() { - test_is_equal_setup_bad::<3, 16, 48>(17, BLS12_381_MODULUS.clone(), 1); -} + run_negative_is_equal_test::<1, 32, 32>( + secp256k1_coord_prime(), + 17, + 2, + VerificationError::OodEvaluationMismatch, + ); -#[should_panic] -#[test] -fn test_modular_is_equal_setup_bad_2_3x16() { - test_is_equal_setup_bad::<3, 16, 48>(17, BLS12_381_MODULUS.clone(), 2); -} + run_negative_is_equal_test::<1, 32, 32>( + secp256k1_coord_prime(), + 17, + 3, + VerificationError::OodEvaluationMismatch, + ); + } + + #[test] + fn negative_test_modular_is_equal_3x16() { + run_negative_is_equal_test::<3, 16, 48>( + BLS12_381_MODULUS.clone(), + 17, + 1, + VerificationError::OodEvaluationMismatch, + ); -#[should_panic] -#[test] -fn test_modular_is_equal_setup_bad_3_3x16() { - test_is_equal_setup_bad::<3, 16, 48>(17, BLS12_381_MODULUS.clone(), 3); + run_negative_is_equal_test::<3, 16, 48>( + BLS12_381_MODULUS.clone(), + 17, + 2, + VerificationError::OodEvaluationMismatch, + ); + + run_negative_is_equal_test::<3, 16, 48>( + BLS12_381_MODULUS.clone(), + 17, + 3, + VerificationError::OodEvaluationMismatch, + ); + } } diff --git a/extensions/algebra/circuit/src/modular_extension.rs b/extensions/algebra/circuit/src/modular_extension.rs index 99632d6ce3..8946daa9c3 100644 --- a/extensions/algebra/circuit/src/modular_extension.rs +++ b/extensions/algebra/circuit/src/modular_extension.rs @@ -1,28 +1,50 @@ -use derive_more::derive::From; +use std::{array, sync::Arc}; + use num_bigint::{BigUint, RandBigInt}; use num_traits::{FromPrimitive, One}; use openvm_algebra_transpiler::{ModularPhantom, Rv32ModularArithmeticOpcode}; use openvm_circuit::{ self, - arch::{SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError}, - system::phantom::PhantomChip, + arch::{ + AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, ExecutionBridge, + ExecutorInventoryBuilder, ExecutorInventoryError, RowMajorMatrixArena, VmCircuitExtension, + VmExecutionExtension, VmProverExtension, + }, + system::{memory::SharedMemoryHelper, SystemPort}, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, +use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor}; +use openvm_circuit_primitives::{ + bigint::utils::big_uint_to_limbs, + bitwise_op_lookup::{ + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, + }, + var_range::VariableRangeCheckerBus, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::{LocalOpcode, PhantomDiscriminant, VmOpcode}; use openvm_mod_circuit_builder::ExprBuilderConfig; -use openvm_rv32_adapters::{Rv32IsEqualModAdapterChip, Rv32VecHeapAdapterChip}; -use openvm_stark_backend::p3_field::PrimeField32; +use openvm_rv32_adapters::{ + Rv32IsEqualModAdapterAir, Rv32IsEqualModAdapterExecutor, Rv32IsEqualModAdapterFiller, +}; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + p3_field::PrimeField32, + prover::cpu::{CpuBackend, CpuDevice}, +}; +use openvm_stark_sdk::engine::StarkEngine; use rand::Rng; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DisplayFromStr}; use strum::EnumCount; -use crate::modular_chip::{ - ModularAddSubChip, ModularIsEqualChip, ModularIsEqualCoreChip, ModularMulDivChip, +use crate::{ + modular_chip::{ + get_modular_addsub_air, get_modular_addsub_chip, get_modular_addsub_step, + get_modular_muldiv_air, get_modular_muldiv_chip, get_modular_muldiv_step, ModularAir, + ModularExecutor, ModularIsEqualAir, ModularIsEqualChip, ModularIsEqualCoreAir, + ModularIsEqualFiller, VmModularIsEqualExecutor, + }, + AlgebraCpuProverExt, }; #[serde_as] @@ -46,205 +68,415 @@ impl ModularExtension { } } -#[derive(ChipUsageGetter, Chip, InstructionExecutor, AnyEnum, From)] -pub enum ModularExtensionExecutor { +#[derive(Clone, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)] +pub enum ModularExtensionExecutor { // 32 limbs prime - ModularAddSubRv32_32(ModularAddSubChip), - ModularMulDivRv32_32(ModularMulDivChip), - ModularIsEqualRv32_32(ModularIsEqualChip), + ModularAddSubRv32_32(ModularExecutor<1, 32>), // ModularAddSub + ModularMulDivRv32_32(ModularExecutor<1, 32>), // ModularMulDiv + ModularIsEqualRv32_32(VmModularIsEqualExecutor<1, 32, 32>), // ModularIsEqual // 48 limbs prime - ModularAddSubRv32_48(ModularAddSubChip), - ModularMulDivRv32_48(ModularMulDivChip), - ModularIsEqualRv32_48(ModularIsEqualChip), -} - -#[derive(ChipUsageGetter, Chip, AnyEnum, From)] -pub enum ModularExtensionPeriphery { - BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), - // We put this only to get the generic to work - Phantom(PhantomChip), + ModularAddSubRv32_48(ModularExecutor<3, 16>), // ModularAddSub + ModularMulDivRv32_48(ModularExecutor<3, 16>), // ModularMulDiv + ModularIsEqualRv32_48(VmModularIsEqualExecutor<3, 16, 48>), // ModularIsEqual } -impl VmExtension for ModularExtension { - type Executor = ModularExtensionExecutor; - type Periphery = ModularExtensionPeriphery; +impl VmExecutionExtension for ModularExtension { + type Executor = ModularExtensionExecutor; - fn build( + fn extend_execution( &self, - builder: &mut VmInventoryBuilder, - ) -> Result, VmInventoryError> { - let mut inventory = VmInventory::new(); - let SystemPort { - execution_bus, - program_bus, - memory_bridge, - } = builder.system_port(); - let range_checker = builder.system_base().range_checker_chip.clone(); - let bitwise_lu_chip = if let Some(&chip) = builder - .find_chip::>() - .first() - { - chip.clone() - } else { - let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); - let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); - inventory.add_periphery_chip(chip.clone()); - chip - }; - let offline_memory = builder.system_base().offline_memory(); - let address_bits = builder.system_config().memory_config.pointer_max_bits; - - let addsub_opcodes = (Rv32ModularArithmeticOpcode::ADD as usize) - ..=(Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize); - let muldiv_opcodes = (Rv32ModularArithmeticOpcode::MUL as usize) - ..=(Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize); - let iseq_opcodes = (Rv32ModularArithmeticOpcode::IS_EQ as usize) - ..=(Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize); - + inventory: &mut ExecutorInventoryBuilder, + ) -> Result<(), ExecutorInventoryError> { + let pointer_max_bits = inventory.pointer_max_bits(); + // TODO: somehow get the range checker bus from `ExecutorInventory` + let dummy_range_checker_bus = VariableRangeCheckerBus::new(u16::MAX, 16); for (i, modulus) in self.supported_moduli.iter().enumerate() { // determine the number of bytes needed to represent a prime field element let bytes = modulus.bits().div_ceil(8); let start_offset = Rv32ModularArithmeticOpcode::CLASS_OFFSET + i * Rv32ModularArithmeticOpcode::COUNT; - - let config32 = ExprBuilderConfig { - modulus: modulus.clone(), - num_limbs: 32, - limb_bits: 8, - }; - let config48 = ExprBuilderConfig { - modulus: modulus.clone(), - num_limbs: 48, - limb_bits: 8, - }; - let adapter_chip_32 = Rv32VecHeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), - ); - let adapter_chip_48 = Rv32VecHeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), - ); - + let modulus_limbs = big_uint_to_limbs(modulus, 8); if bytes <= 32 { - let addsub_chip = ModularAddSubChip::new( - adapter_chip_32.clone(), - config32.clone(), + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: 32, + limb_bits: 8, + }; + let addsub = get_modular_addsub_step( + config.clone(), + dummy_range_checker_bus, + pointer_max_bits, start_offset, - range_checker.clone(), - offline_memory.clone(), ); + inventory.add_executor( - ModularExtensionExecutor::ModularAddSubRv32_32(addsub_chip), - addsub_opcodes - .clone() + ModularExtensionExecutor::ModularAddSubRv32_32(addsub), + ((Rv32ModularArithmeticOpcode::ADD as usize) + ..=(Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize)) .map(|x| VmOpcode::from_usize(x + start_offset)), )?; - let muldiv_chip = ModularMulDivChip::new( - adapter_chip_32.clone(), - config32.clone(), + + let muldiv = get_modular_muldiv_step( + config, + dummy_range_checker_bus, + pointer_max_bits, start_offset, - range_checker.clone(), - offline_memory.clone(), ); + inventory.add_executor( - ModularExtensionExecutor::ModularMulDivRv32_32(muldiv_chip), - muldiv_opcodes - .clone() + ModularExtensionExecutor::ModularMulDivRv32_32(muldiv), + ((Rv32ModularArithmeticOpcode::MUL as usize) + ..=(Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize)) .map(|x| VmOpcode::from_usize(x + start_offset)), )?; - let isequal_chip = ModularIsEqualChip::new( - Rv32IsEqualModAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), - ), - ModularIsEqualCoreChip::new( - modulus.clone(), - bitwise_lu_chip.clone(), - start_offset, - ), - offline_memory.clone(), + + let modulus_limbs = array::from_fn(|i| { + if i < modulus_limbs.len() { + modulus_limbs[i] as u8 + } else { + 0 + } + }); + + let is_eq = VmModularIsEqualExecutor::new( + Rv32IsEqualModAdapterExecutor::new(pointer_max_bits), + start_offset, + modulus_limbs, ); + inventory.add_executor( - ModularExtensionExecutor::ModularIsEqualRv32_32(isequal_chip), - iseq_opcodes - .clone() + ModularExtensionExecutor::ModularIsEqualRv32_32(is_eq), + ((Rv32ModularArithmeticOpcode::IS_EQ as usize) + ..=(Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize)) .map(|x| VmOpcode::from_usize(x + start_offset)), )?; } else if bytes <= 48 { - let addsub_chip = ModularAddSubChip::new( - adapter_chip_48.clone(), - config48.clone(), + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: 48, + limb_bits: 8, + }; + let addsub = get_modular_addsub_step( + config.clone(), + dummy_range_checker_bus, + pointer_max_bits, start_offset, - range_checker.clone(), - offline_memory.clone(), ); + inventory.add_executor( - ModularExtensionExecutor::ModularAddSubRv32_48(addsub_chip), - addsub_opcodes - .clone() + ModularExtensionExecutor::ModularAddSubRv32_48(addsub), + ((Rv32ModularArithmeticOpcode::ADD as usize) + ..=(Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize)) .map(|x| VmOpcode::from_usize(x + start_offset)), )?; - let muldiv_chip = ModularMulDivChip::new( - adapter_chip_48.clone(), - config48.clone(), + + let muldiv = get_modular_muldiv_step( + config, + dummy_range_checker_bus, + pointer_max_bits, start_offset, - range_checker.clone(), - offline_memory.clone(), ); + inventory.add_executor( - ModularExtensionExecutor::ModularMulDivRv32_48(muldiv_chip), - muldiv_opcodes - .clone() + ModularExtensionExecutor::ModularMulDivRv32_48(muldiv), + ((Rv32ModularArithmeticOpcode::MUL as usize) + ..=(Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize)) .map(|x| VmOpcode::from_usize(x + start_offset)), )?; - let isequal_chip = ModularIsEqualChip::new( - Rv32IsEqualModAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), - ), - ModularIsEqualCoreChip::new( - modulus.clone(), - bitwise_lu_chip.clone(), - start_offset, - ), - offline_memory.clone(), + + let modulus_limbs = array::from_fn(|i| { + if i < modulus_limbs.len() { + modulus_limbs[i] as u8 + } else { + 0 + } + }); + + let is_eq = VmModularIsEqualExecutor::new( + Rv32IsEqualModAdapterExecutor::new(pointer_max_bits), + start_offset, + modulus_limbs, ); + inventory.add_executor( - ModularExtensionExecutor::ModularIsEqualRv32_48(isequal_chip), - iseq_opcodes - .clone() + ModularExtensionExecutor::ModularIsEqualRv32_48(is_eq), + ((Rv32ModularArithmeticOpcode::IS_EQ as usize) + ..=(Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize)) .map(|x| VmOpcode::from_usize(x + start_offset)), )?; } else { panic!("Modulus too large"); } } + let non_qr_hint_sub_ex = phantom::NonQrHintSubEx::new(self.supported_moduli.clone()); - builder.add_phantom_sub_executor( + inventory.add_phantom_sub_executor( non_qr_hint_sub_ex.clone(), PhantomDiscriminant(ModularPhantom::HintNonQr as u16), )?; let sqrt_hint_sub_ex = phantom::SqrtHintSubEx::new(non_qr_hint_sub_ex); - builder.add_phantom_sub_executor( + inventory.add_phantom_sub_executor( sqrt_hint_sub_ex, PhantomDiscriminant(ModularPhantom::HintSqrt as u16), )?; - Ok(inventory) + Ok(()) + } +} + +impl VmCircuitExtension for ModularExtension { + fn extend_circuit(&self, inventory: &mut AirInventory) -> Result<(), AirInventoryError> { + let SystemPort { + execution_bus, + program_bus, + memory_bridge, + } = inventory.system().port(); + + let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); + let range_checker_bus = inventory.range_checker().bus; + let pointer_max_bits = inventory.pointer_max_bits(); + + let bitwise_lu = { + // A trick to get around Rust's borrow rules + let existing_air = inventory.find_air::>().next(); + if let Some(air) = existing_air { + air.bus + } else { + let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx()); + let air = BitwiseOperationLookupAir::<8>::new(bus); + inventory.add_air(air); + air.bus + } + }; + for (i, modulus) in self.supported_moduli.iter().enumerate() { + // determine the number of bytes needed to represent a prime field element + let bytes = modulus.bits().div_ceil(8); + let start_offset = + Rv32ModularArithmeticOpcode::CLASS_OFFSET + i * Rv32ModularArithmeticOpcode::COUNT; + + if bytes <= 32 { + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: 32, + limb_bits: 8, + }; + + let addsub = get_modular_addsub_air::<1, 32>( + exec_bridge, + memory_bridge, + config.clone(), + range_checker_bus, + bitwise_lu, + pointer_max_bits, + start_offset, + ); + inventory.add_air(addsub); + + let muldiv = get_modular_muldiv_air::<1, 32>( + exec_bridge, + memory_bridge, + config, + range_checker_bus, + bitwise_lu, + pointer_max_bits, + start_offset, + ); + inventory.add_air(muldiv); + + let is_eq = ModularIsEqualAir::<1, 32, 32>::new( + Rv32IsEqualModAdapterAir::new( + exec_bridge, + memory_bridge, + bitwise_lu, + pointer_max_bits, + ), + ModularIsEqualCoreAir::new(modulus.clone(), bitwise_lu, start_offset), + ); + inventory.add_air(is_eq); + } else if bytes <= 48 { + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: 48, + limb_bits: 8, + }; + + let addsub = get_modular_addsub_air::<3, 16>( + exec_bridge, + memory_bridge, + config.clone(), + range_checker_bus, + bitwise_lu, + pointer_max_bits, + start_offset, + ); + inventory.add_air(addsub); + + let muldiv = get_modular_muldiv_air::<3, 16>( + exec_bridge, + memory_bridge, + config, + range_checker_bus, + bitwise_lu, + pointer_max_bits, + start_offset, + ); + inventory.add_air(muldiv); + + let is_eq = ModularIsEqualAir::<3, 16, 48>::new( + Rv32IsEqualModAdapterAir::new( + exec_bridge, + memory_bridge, + bitwise_lu, + pointer_max_bits, + ), + ModularIsEqualCoreAir::new(modulus.clone(), bitwise_lu, start_offset), + ); + inventory.add_air(is_eq); + } else { + panic!("Modulus too large"); + } + } + + Ok(()) + } +} + +// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker, +// BitwiseOperationLookupChip) are specific to CpuBackend. +impl VmProverExtension for AlgebraCpuProverExt +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + RA: RowMajorMatrixArena>, + Val: PrimeField32, +{ + fn extend_prover( + &self, + extension: &ModularExtension, + inventory: &mut ChipInventory>, + ) -> Result<(), ChipInventoryError> { + let range_checker = inventory.range_checker()?.clone(); + let timestamp_max_bits = inventory.timestamp_max_bits(); + let pointer_max_bits = inventory.airs().pointer_max_bits(); + let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits); + let bitwise_lu = { + let existing_chip = inventory + .find_chip::>() + .next(); + if let Some(chip) = existing_chip { + chip.clone() + } else { + let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?; + let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus)); + inventory.add_periphery_chip(chip.clone()); + chip + } + }; + for (i, modulus) in extension.supported_moduli.iter().enumerate() { + // determine the number of bytes needed to represent a prime field element + let bytes = modulus.bits().div_ceil(8); + let start_offset = + Rv32ModularArithmeticOpcode::CLASS_OFFSET + i * Rv32ModularArithmeticOpcode::COUNT; + + let modulus_limbs = big_uint_to_limbs(modulus, 8); + + if bytes <= 32 { + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: 32, + limb_bits: 8, + }; + + inventory.next_air::>()?; + let addsub = get_modular_addsub_chip::, 1, 32>( + config.clone(), + mem_helper.clone(), + range_checker.clone(), + bitwise_lu.clone(), + pointer_max_bits, + ); + inventory.add_executor_chip(addsub); + + inventory.next_air::>()?; + let muldiv = get_modular_muldiv_chip::, 1, 32>( + config, + mem_helper.clone(), + range_checker.clone(), + bitwise_lu.clone(), + pointer_max_bits, + ); + inventory.add_executor_chip(muldiv); + + let modulus_limbs = array::from_fn(|i| { + if i < modulus_limbs.len() { + modulus_limbs[i] as u8 + } else { + 0 + } + }); + inventory.next_air::>()?; + let is_eq = ModularIsEqualChip::, 1, 32, 32>::new( + ModularIsEqualFiller::new( + Rv32IsEqualModAdapterFiller::new(pointer_max_bits, bitwise_lu.clone()), + start_offset, + modulus_limbs, + bitwise_lu.clone(), + ), + mem_helper.clone(), + ); + inventory.add_executor_chip(is_eq); + } else if bytes <= 48 { + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: 48, + limb_bits: 8, + }; + + inventory.next_air::>()?; + let addsub = get_modular_addsub_chip::, 3, 16>( + config.clone(), + mem_helper.clone(), + range_checker.clone(), + bitwise_lu.clone(), + pointer_max_bits, + ); + inventory.add_executor_chip(addsub); + + inventory.next_air::>()?; + let muldiv = get_modular_muldiv_chip::, 3, 16>( + config, + mem_helper.clone(), + range_checker.clone(), + bitwise_lu.clone(), + pointer_max_bits, + ); + inventory.add_executor_chip(muldiv); + + let modulus_limbs = array::from_fn(|i| { + if i < modulus_limbs.len() { + modulus_limbs[i] as u8 + } else { + 0 + } + }); + inventory.next_air::>()?; + let is_eq = ModularIsEqualChip::, 3, 16, 48>::new( + ModularIsEqualFiller::new( + Rv32IsEqualModAdapterFiller::new(pointer_max_bits, bitwise_lu.clone()), + start_offset, + modulus_limbs, + bitwise_lu.clone(), + ), + mem_helper.clone(), + ); + inventory.add_executor_chip(is_eq); + } else { + panic!("Modulus too large"); + } + } + + Ok(()) } } @@ -258,10 +490,10 @@ pub(crate) mod phantom { use num_bigint::BigUint; use openvm_circuit::{ arch::{PhantomSubExecutor, Streams}, - system::memory::MemoryController, + system::memory::online::GuestMemory, }; use openvm_instructions::{riscv::RV32_MEMORY_AS, PhantomDiscriminant}; - use openvm_rv32im_circuit::adapters::unsafe_read_rv32_register; + use openvm_rv32im_circuit::adapters::read_rv32_register; use openvm_stark_backend::p3_field::PrimeField32; use rand::{rngs::StdRng, SeedableRng}; @@ -282,12 +514,13 @@ pub(crate) mod phantom { // Note that non_qr is fixed for each modulus. impl PhantomSubExecutor for SqrtHintSubEx { fn phantom_execute( - &mut self, - memory: &MemoryController, + &self, + memory: &GuestMemory, streams: &mut Streams, + _: &mut StdRng, _: PhantomDiscriminant, - a: F, - _: F, + a: u32, + _: u32, c_upper: u16, ) -> eyre::Result<()> { let mod_idx = c_upper as usize; @@ -306,15 +539,12 @@ pub(crate) mod phantom { bail!("Modulus too large") }; - let rs1 = unsafe_read_rv32_register(memory, a); - let mut x_limbs: Vec = Vec::with_capacity(num_limbs); - for i in 0..num_limbs { - let limb = memory.unsafe_read_cell( - F::from_canonical_u32(RV32_MEMORY_AS), - F::from_canonical_u32(rs1 + i as u32), - ); - x_limbs.push(limb.as_canonical_u32() as u8); - } + let rs1 = read_rv32_register(memory, a); + // SAFETY: + // - MEMORY_AS consists of `u8`s + // - MEMORY_AS is in bounds + let x_limbs: Vec = + unsafe { memory.memory.get_slice((RV32_MEMORY_AS, rs1), num_limbs) }.to_vec(); let x = BigUint::from_bytes_le(&x_limbs); let (success, sqrt) = match mod_sqrt(&x, modulus, &self.non_qrs[mod_idx]) { @@ -372,12 +602,13 @@ pub(crate) mod phantom { impl PhantomSubExecutor for NonQrHintSubEx { fn phantom_execute( - &mut self, - _: &MemoryController, + &self, + _: &GuestMemory, streams: &mut Streams, + _: &mut StdRng, _: PhantomDiscriminant, - _: F, - _: F, + _: u32, + _: u32, c_upper: u16, ) -> eyre::Result<()> { let mod_idx = c_upper as usize; diff --git a/extensions/algebra/complex-macros/README.md b/extensions/algebra/complex-macros/README.md index 8536b9c371..fb510736c4 100644 --- a/extensions/algebra/complex-macros/README.md +++ b/extensions/algebra/complex-macros/README.md @@ -22,7 +22,7 @@ openvm_algebra_moduli_macros::moduli_init!( ); openvm_algebra_complex_macros::complex_init! { - Complex { mod_idx = 0 }, + "Complex" { mod_idx = 0 }, } */ @@ -39,7 +39,7 @@ The crate provides two macros: `complex_declare!` and `complex_init!`. The signa - `complex_declare!` receives comma-separated list of moduli classes descriptions. Each description looks like `ComplexStruct { mod_type = ModulusName }`. Here `ModulusName` is the name of any struct that implements `trait IntMod` -- in particular, the ones created by `moduli_declare!` do, and `ComplexStruct` is the name for the complex arithmetic struct to create. -- `complex_init!` receives comma-separated list of struct descriptions. Each description looks like `ComplexStruct { mod_idx = idx }`. Here `ComplexStruct` is the name of the complex struct used in `complex_declare!`, and `idx` is the index of the modulus **in the `moduli_init!` macro**. +- `complex_init!` receives comma-separated list of struct descriptions. Each description looks like `"ComplexStruct" { mod_idx = idx }`. Here `ComplexStruct` is the name of the complex struct used in `complex_declare!`, and `idx` is the index of the modulus **in the `moduli_init!` macro**. What happens under the hood: @@ -96,7 +96,7 @@ complex_declare! { pub type Fp2 = Bn254Fp2; complex_init! { - Fp2 { mod_idx = 0 }, + "Fp2" { mod_idx = 0 }, } ``` diff --git a/extensions/algebra/complex-macros/src/lib.rs b/extensions/algebra/complex-macros/src/lib.rs index ba25ee7279..dfd31f66fa 100644 --- a/extensions/algebra/complex-macros/src/lib.rs +++ b/extensions/algebra/complex-macros/src/lib.rs @@ -1,10 +1,12 @@ extern crate proc_macro; -use openvm_macros_common::MacroArgs; +use openvm_macros_common::{MacroArgs, Param}; use proc_macro::TokenStream; use syn::{ parse::{Parse, ParseStream}, - parse_macro_input, Expr, ExprPath, Path, Token, + parse_macro_input, + punctuated::Punctuated, + Expr, ExprPath, LitStr, Path, Token, }; /// This macro is used to declare the complex extension fields. @@ -529,6 +531,38 @@ pub fn complex_declare(input: TokenStream) -> TokenStream { TokenStream::from_iter(output) } +// Override the MacroArgs struct to use LitStr for item names instead of Ident. +// This removes the need to import the complex struct when using the complex_init macro. +struct ComplexInitArgs { + pub items: Vec, +} + +struct ComplexInitItem { + pub name: LitStr, + pub params: Punctuated, +} + +impl Parse for ComplexInitArgs { + fn parse(input: ParseStream) -> syn::Result { + Ok(ComplexInitArgs { + items: input + .parse_terminated(ComplexInitItem::parse, Token![,])? + .into_iter() + .collect(), + }) + } +} + +impl Parse for ComplexInitItem { + fn parse(input: ParseStream) -> syn::Result { + let name = input.parse()?; + let content; + syn::braced!(content in input); + let params = content.parse_terminated(Param::parse, Token![,])?; + Ok(ComplexInitItem { name, params }) + } +} + /// This macro is used to initialize the complex extension fields. /// It must be called after `moduli_init!` is called. /// @@ -543,14 +577,14 @@ pub fn complex_declare(input: TokenStream) -> TokenStream { /// the `moduli_init!` macro (not `moduli_declare!`). #[proc_macro] pub fn complex_init(input: TokenStream) -> TokenStream { - let MacroArgs { items } = parse_macro_input!(input as MacroArgs); + let ComplexInitArgs { items } = parse_macro_input!(input as ComplexInitArgs); let mut externs = Vec::new(); let span = proc_macro::Span::call_site(); for (complex_idx, item) in items.into_iter().enumerate() { - let struct_name = item.name.to_string(); + let struct_name = item.name.value(); let struct_name = syn::Ident::new(&struct_name, span.into()); let mut intmod_idx: Option = None; for param in item.params { diff --git a/extensions/algebra/moduli-macros/src/lib.rs b/extensions/algebra/moduli-macros/src/lib.rs index fc30341195..0dc7128588 100644 --- a/extensions/algebra/moduli-macros/src/lib.rs +++ b/extensions/algebra/moduli-macros/src/lib.rs @@ -965,7 +965,6 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { let ModuliDefine { items } = parse_macro_input!(input as ModuliDefine); let mut externs = Vec::new(); - let mut openvm_section = Vec::new(); // List of all modular limbs in one (that is, with a compile-time known size) array. let mut two_modular_limbs_flattened_list = Vec::::new(); @@ -976,8 +975,6 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { for (mod_idx, item) in items.into_iter().enumerate() { let modulus = item.value(); - println!("[init] modulus #{} = {}", mod_idx, modulus); - let modulus_bytes = string_to_bytes(&modulus); let mut limbs = modulus_bytes.len(); let mut block_size = 32; @@ -1012,31 +1009,11 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { .collect::>() .join(""); - let serialized_modulus = - core::iter::once(1) // 1 for "modulus" - .chain(core::iter::once(mod_idx as u8)) // mod_idx is u8 for now (can make it u32), because we don't know the order of - // variables in the elf - .chain((modulus_bytes.len() as u32).to_le_bytes().iter().copied()) - .chain(modulus_bytes.iter().copied()) - .collect::>(); - let serialized_name = syn::Ident::new( - &format!("OPENVM_SERIALIZED_MODULUS_{}", mod_idx), - span.into(), - ); - let serialized_len = serialized_modulus.len(); let setup_extern_func = syn::Ident::new( &format!("moduli_setup_extern_func_{}", modulus_hex), span.into(), ); - openvm_section.push(quote::quote_spanned! { span.into() => - #[cfg(target_os = "zkvm")] - #[link_section = ".openvm"] - #[no_mangle] - #[used] - static #serialized_name: [u8; #serialized_len] = [#(#serialized_modulus),*]; - }); - for op_type in ["add", "sub", "mul", "div"] { let func_name = syn::Ident::new( &format!("{}_extern_func_{}", op_type, modulus_hex), @@ -1126,19 +1103,12 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { extern "C" fn #setup_extern_func() { #[cfg(target_os = "zkvm")] { - let mut ptr = 0; - assert_eq!(super::#serialized_name[ptr], 1); - ptr += 1; - assert_eq!(super::#serialized_name[ptr], #mod_idx as u8); - ptr += 1; - assert_eq!(super::#serialized_name[ptr..ptr+4].iter().rev().fold(0, |acc, &x| acc * 256 + x as usize), #limbs); - ptr += 4; - let remaining = &super::#serialized_name[ptr..]; - // To avoid importing #struct_name, we create a placeholder struct with the same size and alignment. #[repr(C, align(#block_size))] struct AlignedPlaceholder([u8; #limbs]); + const MODULUS_BYTES: AlignedPlaceholder = AlignedPlaceholder([#(#modulus_bytes),*]); + // We are going to use the numeric representation of the `rs2` register to distinguish the chip to setup. // The transpiler will transform this instruction, based on whether `rs2` is `x0`, `x1` or `x2`, into a `SETUP_ADDSUB`, `SETUP_MULDIV` or `SETUP_ISEQ` instruction. let mut uninit: core::mem::MaybeUninit = core::mem::MaybeUninit::uninit(); @@ -1149,7 +1119,7 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize), rd = In uninit.as_mut_ptr(), - rs1 = In remaining.as_ptr(), + rs1 = In MODULUS_BYTES.0.as_ptr(), rs2 = Const "x0" // will be parsed as 0 and therefore transpiled to SETUP_ADDMOD ); openvm::platform::custom_insn_r!( @@ -1159,7 +1129,7 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize), rd = In uninit.as_mut_ptr(), - rs1 = In remaining.as_ptr(), + rs1 = In MODULUS_BYTES.0.as_ptr(), rs2 = Const "x1" // will be parsed as 1 and therefore transpiled to SETUP_MULDIV ); unsafe { @@ -1172,7 +1142,7 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize), rd = InOut tmp, - rs1 = In remaining.as_ptr(), + rs1 = In MODULUS_BYTES.0.as_ptr(), rs2 = Const "x2" // will be parsed as 2 and therefore transpiled to SETUP_ISEQ ); // rd = inout(reg) is necessary because this instruction will write to `rd` register @@ -1185,7 +1155,6 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { let total_limbs_cnt = two_modular_limbs_flattened_list.len(); let cnt_limbs_list_len = limb_list_borders.len(); TokenStream::from(quote::quote_spanned! { span.into() => - #(#openvm_section)* #[allow(non_snake_case)] #[cfg(target_os = "zkvm")] mod openvm_intrinsics_ffi { diff --git a/extensions/algebra/tests/programs/openvm_init_complex_redundant_modulus.rs b/extensions/algebra/tests/programs/openvm_init_complex_redundant_modulus.rs index c32e692510..523c9d5c6f 100644 --- a/extensions/algebra/tests/programs/openvm_init_complex_redundant_modulus.rs +++ b/extensions/algebra/tests/programs/openvm_init_complex_redundant_modulus.rs @@ -1,3 +1,3 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "998244353", "1000000007", "1000000009", "987898789" } -openvm_algebra_guest::complex_macros::complex_init! { Complex2 { mod_idx = 2 } } +openvm_algebra_guest::complex_macros::complex_init! { "Complex2" { mod_idx = 2 } } diff --git a/extensions/algebra/tests/programs/openvm_init_complex_secp256k1.rs b/extensions/algebra/tests/programs/openvm_init_complex_secp256k1.rs index af98350ae4..938958f392 100644 --- a/extensions/algebra/tests/programs/openvm_init_complex_secp256k1.rs +++ b/extensions/algebra/tests/programs/openvm_init_complex_secp256k1.rs @@ -1,3 +1,3 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663" } -openvm_algebra_guest::complex_macros::complex_init! { Complex { mod_idx = 0 } } +openvm_algebra_guest::complex_macros::complex_init! { "Complex" { mod_idx = 0 } } diff --git a/extensions/algebra/tests/programs/openvm_init_complex_two_moduli.rs b/extensions/algebra/tests/programs/openvm_init_complex_two_moduli.rs index 1de98b97a1..75324d5509 100644 --- a/extensions/algebra/tests/programs/openvm_init_complex_two_moduli.rs +++ b/extensions/algebra/tests/programs/openvm_init_complex_two_moduli.rs @@ -1,3 +1,3 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "998244353", "1000000007" } -openvm_algebra_guest::complex_macros::complex_init! { Complex1 { mod_idx = 0 }, Complex2 { mod_idx = 1 } } +openvm_algebra_guest::complex_macros::complex_init! { "Complex1" { mod_idx = 0 }, "Complex2" { mod_idx = 1 } } diff --git a/extensions/algebra/tests/src/lib.rs b/extensions/algebra/tests/src/lib.rs index 181f592544..6901018359 100644 --- a/extensions/algebra/tests/src/lib.rs +++ b/extensions/algebra/tests/src/lib.rs @@ -5,10 +5,11 @@ mod tests { use eyre::Result; use num_bigint::BigUint; use openvm_algebra_circuit::{ - Fp2Extension, ModularExtension, Rv32ModularConfig, Rv32ModularWithFp2Config, + Fp2Extension, Rv32ModularConfig, Rv32ModularCpuBuilder, Rv32ModularWithFp2Config, + Rv32ModularWithFp2CpuBuilder, }; use openvm_algebra_transpiler::{Fp2TranspilerExtension, ModularTranspilerExtension}; - use openvm_circuit::{arch::SystemConfig, utils::air_test}; + use openvm_circuit::utils::{air_test, test_system_config_with_continuations}; use openvm_ecc_circuit::SECP256K1_CONFIG; use openvm_instructions::exe::VmExe; use openvm_rv32im_transpiler::{ @@ -20,11 +21,27 @@ mod tests { type F = BabyBear; + #[cfg(test)] + fn test_rv32modular_config(moduli: Vec) -> Rv32ModularConfig { + let mut config = Rv32ModularConfig::new(moduli); + config.system = test_system_config_with_continuations(); + config + } + + #[cfg(test)] + fn test_rv32modularwithfp2_config( + moduli_with_names: Vec<(String, BigUint)>, + ) -> Rv32ModularWithFp2Config { + let mut config = Rv32ModularWithFp2Config::new(moduli_with_names); + *config.as_mut() = test_system_config_with_continuations(); + config + } + #[test] fn test_moduli_setup() -> Result<()> { let moduli = ["4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787", "1000000000000000003", "2305843009213693951"] .map(|s| BigUint::from_str(s).unwrap()); - let config = Rv32ModularConfig::new(moduli.to_vec()); + let config = test_rv32modular_config(moduli.to_vec()); let elf = build_example_program_at_path(get_programs_dir!(), "moduli_setup", &config)?; let openvm_exe = VmExe::from_elf( elf, @@ -35,13 +52,13 @@ mod tests { .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32ModularCpuBuilder, config, openvm_exe); Ok(()) } #[test] fn test_modular() -> Result<()> { - let config = Rv32ModularConfig::new(vec![SECP256K1_CONFIG.modulus.clone()]); + let config = test_rv32modular_config(vec![SECP256K1_CONFIG.modulus.clone()]); let elf = build_example_program_at_path(get_programs_dir!(), "little", &config)?; let openvm_exe = VmExe::from_elf( elf, @@ -51,13 +68,13 @@ mod tests { .with_extension(Rv32IoTranspilerExtension) .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32ModularCpuBuilder, config, openvm_exe); Ok(()) } #[test] fn test_complex_two_moduli() -> Result<()> { - let config = Rv32ModularWithFp2Config::new(vec![ + let config = test_rv32modularwithfp2_config(vec![ ( "Complex1".to_string(), BigUint::from_str("998244353").unwrap(), @@ -78,18 +95,14 @@ mod tests { .with_extension(Fp2TranspilerExtension) .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32ModularWithFp2CpuBuilder, config, openvm_exe); Ok(()) } #[test] fn test_complex_redundant_modulus() -> Result<()> { let config = Rv32ModularWithFp2Config { - system: SystemConfig::default().with_continuations(), - base: Default::default(), - mul: Default::default(), - io: Default::default(), - modular: ModularExtension::new(vec![ + modular: test_rv32modular_config(vec![ BigUint::from_str("998244353").unwrap(), BigUint::from_str("1000000007").unwrap(), BigUint::from_str("1000000009").unwrap(), @@ -114,13 +127,13 @@ mod tests { .with_extension(Fp2TranspilerExtension) .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32ModularWithFp2CpuBuilder, config, openvm_exe); Ok(()) } #[test] fn test_complex() -> Result<()> { - let config = Rv32ModularWithFp2Config::new(vec![( + let config = test_rv32modularwithfp2_config(vec![( "Complex".to_string(), SECP256K1_CONFIG.modulus.clone(), )]); @@ -134,14 +147,14 @@ mod tests { .with_extension(Fp2TranspilerExtension) .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32ModularWithFp2CpuBuilder, config, openvm_exe); Ok(()) } #[test] #[should_panic] fn test_invalid_setup() { - let config = Rv32ModularConfig::new(vec![ + let config = test_rv32modular_config(vec![ BigUint::from_str("998244353").unwrap(), BigUint::from_str("1000000007").unwrap(), ]); @@ -163,12 +176,12 @@ mod tests { .with_extension(ModularTranspilerExtension), ) .unwrap(); - air_test(config, openvm_exe); + air_test(Rv32ModularCpuBuilder, config, openvm_exe); } #[test] fn test_sqrt() -> Result<()> { - let config = Rv32ModularConfig::new(vec![SECP256K1_CONFIG.modulus.clone()]); + let config = test_rv32modular_config(vec![SECP256K1_CONFIG.modulus.clone()]); let elf = build_example_program_at_path(get_programs_dir!(), "sqrt", &config)?; let openvm_exe = VmExe::from_elf( elf, @@ -178,7 +191,7 @@ mod tests { .with_extension(Rv32IoTranspilerExtension) .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32ModularCpuBuilder, config, openvm_exe); Ok(()) } } diff --git a/extensions/bigint/circuit/Cargo.toml b/extensions/bigint/circuit/Cargo.toml index 09d68a9d1b..aa9114c34a 100644 --- a/extensions/bigint/circuit/Cargo.toml +++ b/extensions/bigint/circuit/Cargo.toml @@ -29,6 +29,8 @@ serde.workspace = true openvm-stark-sdk = { workspace = true } openvm-circuit = { workspace = true, features = ["test-utils"] } openvm-rv32-adapters = { workspace = true, features = ["test-utils"] } +test-case.workspace = true +alloy-primitives = { version = "1.2.1" } [features] default = ["parallel", "jemalloc"] diff --git a/extensions/bigint/circuit/src/base_alu.rs b/extensions/bigint/circuit/src/base_alu.rs new file mode 100644 index 0000000000..55685115ab --- /dev/null +++ b/extensions/bigint/circuit/src/base_alu.rs @@ -0,0 +1,239 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + mem::transmute, +}; + +use openvm_bigint_transpiler::Rv32BaseAlu256Opcode; +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; +use openvm_circuit_primitives_derive::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; +use openvm_rv32_adapters::Rv32HeapAdapterExecutor; +use openvm_rv32im_circuit::BaseAluExecutor; +use openvm_rv32im_transpiler::BaseAluOpcode; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{Rv32BaseAlu256Executor, INT256_NUM_LIMBS}; + +type AdapterExecutor = Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>; + +impl Rv32BaseAlu256Executor { + pub fn new(adapter: AdapterExecutor, offset: usize) -> Self { + Self(BaseAluExecutor::new(adapter, offset)) + } +} + +#[derive(AlignedBytesBorrow)] +struct BaseAluPreCompute { + a: u8, + b: u8, + c: u8, +} + +impl Executor for Rv32BaseAlu256Executor { + fn pre_compute_size(&self) -> usize { + size_of::() + } + + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E1ExecutionCtx, + { + let data: &mut BaseAluPreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, data)?; + let fn_ptr = match local_opcode { + BaseAluOpcode::ADD => execute_e1_impl::<_, _, AddOp>, + BaseAluOpcode::SUB => execute_e1_impl::<_, _, SubOp>, + BaseAluOpcode::XOR => execute_e1_impl::<_, _, XorOp>, + BaseAluOpcode::OR => execute_e1_impl::<_, _, OrOp>, + BaseAluOpcode::AND => execute_e1_impl::<_, _, AndOp>, + }; + Ok(fn_ptr) + } +} + +impl MeteredExecutor for Rv32BaseAlu256Executor { + fn metered_pre_compute_size(&self) -> usize { + size_of::>() + } + + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; + let fn_ptr = match local_opcode { + BaseAluOpcode::ADD => execute_e2_impl::<_, _, AddOp>, + BaseAluOpcode::SUB => execute_e2_impl::<_, _, SubOp>, + BaseAluOpcode::XOR => execute_e2_impl::<_, _, XorOp>, + BaseAluOpcode::OR => execute_e2_impl::<_, _, OrOp>, + BaseAluOpcode::AND => execute_e2_impl::<_, _, AndOp>, + }; + Ok(fn_ptr) + } +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &BaseAluPreCompute, + vm_state: &mut VmExecState, +) { + let rs1_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let rs2_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.c as u32); + let rd_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.a as u32); + let rs1 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); + let rs2 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); + let rd = ::compute(rs1, rs2); + vm_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd); + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &BaseAluPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl Rv32BaseAlu256Executor { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut BaseAluPreCompute, + ) -> Result { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + *data = BaseAluPreCompute { + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + c: c.as_canonical_u32() as u8, + }; + let local_opcode = + BaseAluOpcode::from_usize(opcode.local_opcode_idx(Rv32BaseAlu256Opcode::CLASS_OFFSET)); + Ok(local_opcode) + } +} + +trait AluOp { + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS]; +} +struct AddOp; +struct SubOp; +struct XorOp; +struct OrOp; +struct AndOp; +impl AluOp for AddOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] { + let rs1_u64: [u64; 4] = unsafe { transmute(rs1) }; + let rs2_u64: [u64; 4] = unsafe { transmute(rs2) }; + let mut rd_u64 = [0u64; 4]; + let (res, mut carry) = rs1_u64[0].overflowing_add(rs2_u64[0]); + rd_u64[0] = res; + for i in 1..4 { + let (res1, c1) = rs1_u64[i].overflowing_add(rs2_u64[i]); + let (res2, c2) = res1.overflowing_add(carry as u64); + carry = c1 || c2; + rd_u64[i] = res2; + } + unsafe { transmute(rd_u64) } + } +} +impl AluOp for SubOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] { + let rs1_u64: [u64; 4] = unsafe { transmute(rs1) }; + let rs2_u64: [u64; 4] = unsafe { transmute(rs2) }; + let mut rd_u64 = [0u64; 4]; + let (res, mut borrow) = rs1_u64[0].overflowing_sub(rs2_u64[0]); + rd_u64[0] = res; + for i in 1..4 { + let (res1, c1) = rs1_u64[i].overflowing_sub(rs2_u64[i]); + let (res2, c2) = res1.overflowing_sub(borrow as u64); + borrow = c1 || c2; + rd_u64[i] = res2; + } + unsafe { transmute(rd_u64) } + } +} +impl AluOp for XorOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] { + let rs1_u64: [u64; 4] = unsafe { transmute(rs1) }; + let rs2_u64: [u64; 4] = unsafe { transmute(rs2) }; + let mut rd_u64 = [0u64; 4]; + // Compiler will expand this loop. + for i in 0..4 { + rd_u64[i] = rs1_u64[i] ^ rs2_u64[i]; + } + unsafe { transmute(rd_u64) } + } +} +impl AluOp for OrOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] { + let rs1_u64: [u64; 4] = unsafe { transmute(rs1) }; + let rs2_u64: [u64; 4] = unsafe { transmute(rs2) }; + let mut rd_u64 = [0u64; 4]; + // Compiler will expand this loop. + for i in 0..4 { + rd_u64[i] = rs1_u64[i] | rs2_u64[i]; + } + unsafe { transmute(rd_u64) } + } +} +impl AluOp for AndOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] { + let rs1_u64: [u64; 4] = unsafe { transmute(rs1) }; + let rs2_u64: [u64; 4] = unsafe { transmute(rs2) }; + let mut rd_u64 = [0u64; 4]; + // Compiler will expand this loop. + for i in 0..4 { + rd_u64[i] = rs1_u64[i] & rs2_u64[i]; + } + unsafe { transmute(rd_u64) } + } +} diff --git a/extensions/bigint/circuit/src/branch_eq.rs b/extensions/bigint/circuit/src/branch_eq.rs new file mode 100644 index 0000000000..782af76a6a --- /dev/null +++ b/extensions/bigint/circuit/src/branch_eq.rs @@ -0,0 +1,170 @@ +use std::borrow::{Borrow, BorrowMut}; + +use openvm_bigint_transpiler::Rv32BranchEqual256Opcode; +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; +use openvm_circuit_primitives_derive::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; +use openvm_rv32_adapters::Rv32HeapBranchAdapterExecutor; +use openvm_rv32im_circuit::BranchEqualExecutor; +use openvm_rv32im_transpiler::BranchEqualOpcode; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{Rv32BranchEqual256Executor, INT256_NUM_LIMBS}; + +type AdapterExecutor = Rv32HeapBranchAdapterExecutor<2, INT256_NUM_LIMBS>; + +impl Rv32BranchEqual256Executor { + pub fn new(adapter_step: AdapterExecutor, offset: usize, pc_step: u32) -> Self { + Self(BranchEqualExecutor::new(adapter_step, offset, pc_step)) + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct BranchEqPreCompute { + imm: isize, + a: u8, + b: u8, +} + +impl Executor for Rv32BranchEqual256Executor { + fn pre_compute_size(&self) -> usize { + size_of::() + } + + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E1ExecutionCtx, + { + let data: &mut BranchEqPreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, data)?; + let fn_ptr = match local_opcode { + BranchEqualOpcode::BEQ => execute_e1_impl::<_, _, false>, + BranchEqualOpcode::BNE => execute_e1_impl::<_, _, true>, + }; + Ok(fn_ptr) + } +} + +impl MeteredExecutor for Rv32BranchEqual256Executor { + fn metered_pre_compute_size(&self) -> usize { + size_of::>() + } + + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; + let fn_ptr = match local_opcode { + BranchEqualOpcode::BEQ => execute_e2_impl::<_, _, false>, + BranchEqualOpcode::BNE => execute_e2_impl::<_, _, true>, + }; + Ok(fn_ptr) + } +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &BranchEqPreCompute, + vm_state: &mut VmExecState, +) { + let rs1_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.a as u32); + let rs2_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let rs1 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); + let rs2 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); + let cmp_result = u256_eq(rs1, rs2); + if cmp_result ^ IS_NE { + vm_state.pc = (vm_state.pc as isize + pre_compute.imm) as u32; + } else { + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + } + + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &BranchEqPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl Rv32BranchEqual256Executor { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut BranchEqPreCompute, + ) -> Result { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + let c = c.as_canonical_u32(); + let imm = if F::ORDER_U32 - c < c { + -((F::ORDER_U32 - c) as isize) + } else { + c as isize + }; + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + *data = BranchEqPreCompute { + imm, + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + }; + let local_opcode = BranchEqualOpcode::from_usize( + opcode.local_opcode_idx(Rv32BranchEqual256Opcode::CLASS_OFFSET), + ); + Ok(local_opcode) + } +} + +fn u256_eq(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool { + let rs1_u64: [u64; 4] = unsafe { std::mem::transmute(rs1) }; + let rs2_u64: [u64; 4] = unsafe { std::mem::transmute(rs2) }; + for i in 0..4 { + if rs1_u64[i] != rs2_u64[i] { + return false; + } + } + true +} diff --git a/extensions/bigint/circuit/src/branch_lt.rs b/extensions/bigint/circuit/src/branch_lt.rs new file mode 100644 index 0000000000..da93d7ca22 --- /dev/null +++ b/extensions/bigint/circuit/src/branch_lt.rs @@ -0,0 +1,198 @@ +use std::borrow::{Borrow, BorrowMut}; + +use openvm_bigint_transpiler::Rv32BranchLessThan256Opcode; +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; +use openvm_circuit_primitives_derive::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; +use openvm_rv32_adapters::Rv32HeapBranchAdapterExecutor; +use openvm_rv32im_circuit::BranchLessThanExecutor; +use openvm_rv32im_transpiler::BranchLessThanOpcode; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{ + common::{i256_lt, u256_lt}, + Rv32BranchLessThan256Executor, INT256_NUM_LIMBS, +}; + +type AdapterExecutor = Rv32HeapBranchAdapterExecutor<2, INT256_NUM_LIMBS>; + +impl Rv32BranchLessThan256Executor { + pub fn new(adapter: AdapterExecutor, offset: usize) -> Self { + Self(BranchLessThanExecutor::new(adapter, offset)) + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct BranchLtPreCompute { + imm: isize, + a: u8, + b: u8, +} + +impl Executor for Rv32BranchLessThan256Executor { + fn pre_compute_size(&self) -> usize { + size_of::() + } + + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E1ExecutionCtx, + { + let data: &mut BranchLtPreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, data)?; + let fn_ptr = match local_opcode { + BranchLessThanOpcode::BLT => execute_e1_impl::<_, _, BltOp>, + BranchLessThanOpcode::BLTU => execute_e1_impl::<_, _, BltuOp>, + BranchLessThanOpcode::BGE => execute_e1_impl::<_, _, BgeOp>, + BranchLessThanOpcode::BGEU => execute_e1_impl::<_, _, BgeuOp>, + }; + Ok(fn_ptr) + } +} + +impl MeteredExecutor for Rv32BranchLessThan256Executor { + fn metered_pre_compute_size(&self) -> usize { + size_of::>() + } + + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; + let fn_ptr = match local_opcode { + BranchLessThanOpcode::BLT => execute_e2_impl::<_, _, BltOp>, + BranchLessThanOpcode::BLTU => execute_e2_impl::<_, _, BltuOp>, + BranchLessThanOpcode::BGE => execute_e2_impl::<_, _, BgeOp>, + BranchLessThanOpcode::BGEU => execute_e2_impl::<_, _, BgeuOp>, + }; + Ok(fn_ptr) + } +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &BranchLtPreCompute, + vm_state: &mut VmExecState, +) { + let rs1_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.a as u32); + let rs2_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let rs1 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); + let rs2 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); + let cmp_result = OP::compute(rs1, rs2); + if cmp_result { + vm_state.pc = (vm_state.pc as isize + pre_compute.imm) as u32; + } else { + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + } + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &BranchLtPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl Rv32BranchLessThan256Executor { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut BranchLtPreCompute, + ) -> Result { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + let c = c.as_canonical_u32(); + let imm = if F::ORDER_U32 - c < c { + -((F::ORDER_U32 - c) as isize) + } else { + c as isize + }; + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + *data = BranchLtPreCompute { + imm, + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + }; + let local_opcode = BranchLessThanOpcode::from_usize( + opcode.local_opcode_idx(Rv32BranchLessThan256Opcode::CLASS_OFFSET), + ); + Ok(local_opcode) + } +} + +trait BranchLessThanOp { + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool; +} +struct BltOp; +struct BltuOp; +struct BgeOp; +struct BgeuOp; + +impl BranchLessThanOp for BltOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool { + i256_lt(rs1, rs2) + } +} +impl BranchLessThanOp for BltuOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool { + u256_lt(rs1, rs2) + } +} +impl BranchLessThanOp for BgeOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool { + !i256_lt(rs1, rs2) + } +} +impl BranchLessThanOp for BgeuOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool { + !u256_lt(rs1, rs2) + } +} diff --git a/extensions/bigint/circuit/src/common.rs b/extensions/bigint/circuit/src/common.rs new file mode 100644 index 0000000000..14c49ce68c --- /dev/null +++ b/extensions/bigint/circuit/src/common.rs @@ -0,0 +1,66 @@ +use crate::{INT256_NUM_LIMBS, RV32_CELL_BITS}; + +#[inline(always)] +pub(crate) fn u256_lt(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool { + let rs1_u64: [u64; 4] = unsafe { std::mem::transmute(rs1) }; + let rs2_u64: [u64; 4] = unsafe { std::mem::transmute(rs2) }; + for i in (0..4).rev() { + if rs1_u64[i] != rs2_u64[i] { + return rs1_u64[i] < rs2_u64[i]; + } + } + false +} + +#[inline(always)] +pub(crate) fn i256_lt(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool { + // true for negative. false for positive + let rs1_sign = rs1[INT256_NUM_LIMBS - 1] >> (RV32_CELL_BITS - 1) == 1; + let rs2_sign = rs2[INT256_NUM_LIMBS - 1] >> (RV32_CELL_BITS - 1) == 1; + let rs1_u64: [u64; 4] = unsafe { std::mem::transmute(rs1) }; + let rs2_u64: [u64; 4] = unsafe { std::mem::transmute(rs2) }; + for i in (0..4).rev() { + if rs1_u64[i] != rs2_u64[i] { + return (rs1_u64[i] < rs2_u64[i]) ^ rs1_sign ^ rs2_sign; + } + } + false +} + +#[cfg(test)] +mod tests { + use alloy_primitives::{I256, U256}; + use rand::{prelude::StdRng, Rng, SeedableRng}; + + use crate::{ + common::{i256_lt, u256_lt}, + INT256_NUM_LIMBS, + }; + + #[test] + fn test_u256_lt() { + let mut rng = StdRng::from_seed([42; 32]); + for _ in 0..10000 { + let limbs_a: [u64; 4] = rng.gen(); + let limbs_b: [u64; 4] = rng.gen(); + let a = U256::from_limbs(limbs_a); + let b = U256::from_limbs(limbs_b); + let a_u8: [u8; INT256_NUM_LIMBS] = unsafe { std::mem::transmute(limbs_a) }; + let b_u8: [u8; INT256_NUM_LIMBS] = unsafe { std::mem::transmute(limbs_b) }; + assert_eq!(u256_lt(a_u8, b_u8), a < b); + } + } + #[test] + fn test_i256_lt() { + let mut rng = StdRng::from_seed([42; 32]); + for _ in 0..10000 { + let limbs_a: [u64; 4] = rng.gen(); + let limbs_b: [u64; 4] = rng.gen(); + let a = I256::from_limbs(limbs_a); + let b = I256::from_limbs(limbs_b); + let a_u8: [u8; INT256_NUM_LIMBS] = unsafe { std::mem::transmute(limbs_a) }; + let b_u8: [u8; INT256_NUM_LIMBS] = unsafe { std::mem::transmute(limbs_b) }; + assert_eq!(i256_lt(a_u8, b_u8), a < b); + } + } +} diff --git a/extensions/bigint/circuit/src/extension.rs b/extensions/bigint/circuit/src/extension.rs index 390b79cc63..1bbd9c80d7 100644 --- a/extensions/bigint/circuit/src/extension.rs +++ b/extensions/bigint/circuit/src/extension.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use derive_more::derive::From; use openvm_bigint_transpiler::{ Rv32BaseAlu256Opcode, Rv32BranchEqual256Opcode, Rv32BranchLessThan256Opcode, @@ -5,56 +7,35 @@ use openvm_bigint_transpiler::{ }; use openvm_circuit::{ arch::{ - InitFileGenerator, SystemConfig, SystemPort, VmExtension, VmInventory, VmInventoryBuilder, - VmInventoryError, + AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, ExecutionBridge, + ExecutorInventoryBuilder, ExecutorInventoryError, RowMajorMatrixArena, VmCircuitExtension, + VmExecutionExtension, VmProverExtension, }, - system::phantom::PhantomChip, + system::{memory::SharedMemoryHelper, SystemPort}, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; +use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor}; use openvm_circuit_primitives::{ - bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, - range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, + bitwise_op_lookup::{ + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, + }, + range_tuple::{ + RangeTupleCheckerAir, RangeTupleCheckerBus, RangeTupleCheckerChip, + SharedRangeTupleCheckerChip, + }, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::{program::DEFAULT_PC_STEP, LocalOpcode}; -use openvm_rv32im_circuit::{ - Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, - Rv32MExecutor, Rv32MPeriphery, +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + engine::StarkEngine, + p3_field::PrimeField32, + prover::cpu::{CpuBackend, CpuDevice}, }; -use openvm_stark_backend::p3_field::PrimeField32; use serde::{Deserialize, Serialize}; use crate::*; -#[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] -pub struct Int256Rv32Config { - #[system] - pub system: SystemConfig, - #[extension] - pub rv32i: Rv32I, - #[extension] - pub rv32m: Rv32M, - #[extension] - pub io: Rv32Io, - #[extension] - pub bigint: Int256, -} - -// Default implementation uses no init file -impl InitFileGenerator for Int256Rv32Config {} - -impl Default for Int256Rv32Config { - fn default() -> Self { - Self { - system: SystemConfig::default().with_continuations(), - rv32i: Rv32I, - rv32m: Rv32M::default(), - io: Rv32Io, - bigint: Int256::default(), - } - } -} - +// =================================== VM Extension Implementation ================================= #[derive(Clone, Copy, Debug, Serialize, Deserialize)] pub struct Int256 { #[serde(default = "default_range_tuple_checker_sizes")] @@ -73,172 +54,272 @@ fn default_range_tuple_checker_sizes() -> [u32; 2] { [1 << 8, 32 * (1 << 8)] } -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] -pub enum Int256Executor { - BaseAlu256(Rv32BaseAlu256Chip), - LessThan256(Rv32LessThan256Chip), - BranchEqual256(Rv32BranchEqual256Chip), - BranchLessThan256(Rv32BranchLessThan256Chip), - Multiplication256(Rv32Multiplication256Chip), - Shift256(Rv32Shift256Chip), +#[derive(Clone, From, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)] +pub enum Int256Executor { + BaseAlu256(Rv32BaseAlu256Executor), + LessThan256(Rv32LessThan256Executor), + BranchEqual256(Rv32BranchEqual256Executor), + BranchLessThan256(Rv32BranchLessThan256Executor), + Multiplication256(Rv32Multiplication256Executor), + Shift256(Rv32Shift256Executor), } -#[derive(From, ChipUsageGetter, Chip, AnyEnum)] -pub enum Int256Periphery { - BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), - /// Only needed for multiplication extension - RangeTupleChecker(SharedRangeTupleCheckerChip<2>), - Phantom(PhantomChip), -} +impl VmExecutionExtension for Int256 { + type Executor = Int256Executor; -impl VmExtension for Int256 { - type Executor = Int256Executor; - type Periphery = Int256Periphery; - - fn build( + fn extend_execution( &self, - builder: &mut VmInventoryBuilder, - ) -> Result, VmInventoryError> { - let mut inventory = VmInventory::new(); + inventory: &mut ExecutorInventoryBuilder, + ) -> Result<(), ExecutorInventoryError> { + let pointer_max_bits = inventory.pointer_max_bits(); + + let alu = Rv32BaseAlu256Executor::new( + Rv32HeapAdapterExecutor::new(pointer_max_bits), + Rv32BaseAlu256Opcode::CLASS_OFFSET, + ); + inventory.add_executor(alu, Rv32BaseAlu256Opcode::iter().map(|x| x.global_opcode()))?; + + let lt = Rv32LessThan256Executor::new( + Rv32HeapAdapterExecutor::new(pointer_max_bits), + Rv32LessThan256Opcode::CLASS_OFFSET, + ); + inventory.add_executor(lt, Rv32LessThan256Opcode::iter().map(|x| x.global_opcode()))?; + + let beq = Rv32BranchEqual256Executor::new( + Rv32HeapBranchAdapterExecutor::new(pointer_max_bits), + Rv32BranchEqual256Opcode::CLASS_OFFSET, + DEFAULT_PC_STEP, + ); + inventory.add_executor( + beq, + Rv32BranchEqual256Opcode::iter().map(|x| x.global_opcode()), + )?; + + let blt = Rv32BranchLessThan256Executor::new( + Rv32HeapBranchAdapterExecutor::new(pointer_max_bits), + Rv32BranchLessThan256Opcode::CLASS_OFFSET, + ); + inventory.add_executor( + blt, + Rv32BranchLessThan256Opcode::iter().map(|x| x.global_opcode()), + )?; + + let mult = Rv32Multiplication256Executor::new( + Rv32HeapAdapterExecutor::new(pointer_max_bits), + Rv32Mul256Opcode::CLASS_OFFSET, + ); + inventory.add_executor(mult, Rv32Mul256Opcode::iter().map(|x| x.global_opcode()))?; + + let shift = Rv32Shift256Executor::new( + Rv32HeapAdapterExecutor::new(pointer_max_bits), + Rv32Shift256Opcode::CLASS_OFFSET, + ); + inventory.add_executor(shift, Rv32Shift256Opcode::iter().map(|x| x.global_opcode()))?; + + Ok(()) + } +} + +impl VmCircuitExtension for Int256 { + fn extend_circuit(&self, inventory: &mut AirInventory) -> Result<(), AirInventoryError> { let SystemPort { execution_bus, program_bus, memory_bridge, - } = builder.system_port(); - let range_checker_chip = builder.system_base().range_checker_chip.clone(); - let bitwise_lu_chip = if let Some(&chip) = builder - .find_chip::>() - .first() - { - chip.clone() - } else { - let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); - let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); - inventory.add_periphery_chip(chip.clone()); - chip + } = inventory.system().port(); + + let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); + let range_checker = inventory.range_checker().bus; + let pointer_max_bits = inventory.pointer_max_bits(); + + let bitwise_lu = { + // A trick to get around Rust's borrow rules + let existing_air = inventory.find_air::>().next(); + if let Some(air) = existing_air { + air.bus + } else { + let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx()); + let air = BitwiseOperationLookupAir::<8>::new(bus); + inventory.add_air(air); + air.bus + } }; - let offline_memory = builder.system_base().offline_memory(); - let address_bits = builder.system_config().memory_config.pointer_max_bits; - - let range_tuple_chip = if let Some(chip) = builder - .find_chip::>() - .into_iter() - .find(|c| { - c.bus().sizes[0] >= self.range_tuple_checker_sizes[0] - && c.bus().sizes[1] >= self.range_tuple_checker_sizes[1] - }) { - chip.clone() - } else { - let range_tuple_bus = - RangeTupleCheckerBus::new(builder.new_bus_idx(), self.range_tuple_checker_sizes); - let chip = SharedRangeTupleCheckerChip::new(range_tuple_bus); - inventory.add_periphery_chip(chip.clone()); - chip + + let range_tuple_checker = { + let existing_air = inventory.find_air::>().find(|c| { + c.bus.sizes[0] >= self.range_tuple_checker_sizes[0] + && c.bus.sizes[1] >= self.range_tuple_checker_sizes[1] + }); + if let Some(air) = existing_air { + air.bus + } else { + let bus = RangeTupleCheckerBus::new( + inventory.new_bus_idx(), + self.range_tuple_checker_sizes, + ); + let air = RangeTupleCheckerAir { bus }; + inventory.add_air(air); + air.bus + } }; - let base_alu_chip = Rv32BaseAlu256Chip::new( - Rv32HeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), - ), - BaseAluCoreChip::new(bitwise_lu_chip.clone(), Rv32BaseAlu256Opcode::CLASS_OFFSET), - offline_memory.clone(), + let alu = Rv32BaseAlu256Air::new( + Rv32HeapAdapterAir::new(exec_bridge, memory_bridge, bitwise_lu, pointer_max_bits), + BaseAluCoreAir::new(bitwise_lu, Rv32BaseAlu256Opcode::CLASS_OFFSET), ); - inventory.add_executor( - base_alu_chip, - Rv32BaseAlu256Opcode::iter().map(|x| x.global_opcode()), - )?; + inventory.add_air(alu); + + let lt = Rv32LessThan256Air::new( + Rv32HeapAdapterAir::new(exec_bridge, memory_bridge, bitwise_lu, pointer_max_bits), + LessThanCoreAir::new(bitwise_lu, Rv32LessThan256Opcode::CLASS_OFFSET), + ); + inventory.add_air(lt); + + let beq = Rv32BranchEqual256Air::new( + Rv32HeapBranchAdapterAir::new(exec_bridge, memory_bridge, bitwise_lu, pointer_max_bits), + BranchEqualCoreAir::new(Rv32BranchEqual256Opcode::CLASS_OFFSET, DEFAULT_PC_STEP), + ); + inventory.add_air(beq); + + let blt = Rv32BranchLessThan256Air::new( + Rv32HeapBranchAdapterAir::new(exec_bridge, memory_bridge, bitwise_lu, pointer_max_bits), + BranchLessThanCoreAir::new(bitwise_lu, Rv32BranchLessThan256Opcode::CLASS_OFFSET), + ); + inventory.add_air(blt); + + let mult = Rv32Multiplication256Air::new( + Rv32HeapAdapterAir::new(exec_bridge, memory_bridge, bitwise_lu, pointer_max_bits), + MultiplicationCoreAir::new(range_tuple_checker, Rv32Mul256Opcode::CLASS_OFFSET), + ); + inventory.add_air(mult); + + let shift = Rv32Shift256Air::new( + Rv32HeapAdapterAir::new(exec_bridge, memory_bridge, bitwise_lu, pointer_max_bits), + ShiftCoreAir::new(bitwise_lu, range_checker, Rv32Shift256Opcode::CLASS_OFFSET), + ); + inventory.add_air(shift); + + Ok(()) + } +} + +pub struct Int256CpuProverExt; +// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker, +// BitwiseOperationLookupChip) are specific to CpuBackend. +impl VmProverExtension for Int256CpuProverExt +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + RA: RowMajorMatrixArena>, + Val: PrimeField32, +{ + fn extend_prover( + &self, + extension: &Int256, + inventory: &mut ChipInventory>, + ) -> Result<(), ChipInventoryError> { + let range_checker = inventory.range_checker()?.clone(); + let timestamp_max_bits = inventory.timestamp_max_bits(); + let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits); + let pointer_max_bits = inventory.airs().config().memory_config.pointer_max_bits; + + let bitwise_lu = { + let existing_chip = inventory + .find_chip::>() + .next(); + if let Some(chip) = existing_chip { + chip.clone() + } else { + let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?; + let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus)); + inventory.add_periphery_chip(chip.clone()); + chip + } + }; + + let range_tuple_checker = { + let existing_chip = inventory + .find_chip::>() + .find(|c| { + c.bus().sizes[0] >= extension.range_tuple_checker_sizes[0] + && c.bus().sizes[1] >= extension.range_tuple_checker_sizes[1] + }); + if let Some(chip) = existing_chip { + chip.clone() + } else { + let air: &RangeTupleCheckerAir<2> = inventory.next_air()?; + let chip = SharedRangeTupleCheckerChip::new(RangeTupleCheckerChip::new(air.bus)); + inventory.add_periphery_chip(chip.clone()); + chip + } + }; - let less_than_chip = Rv32LessThan256Chip::new( - Rv32HeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), + inventory.next_air::()?; + let alu = Rv32BaseAlu256Chip::new( + BaseAluFiller::new( + Rv32HeapAdapterFiller::new(pointer_max_bits, bitwise_lu.clone()), + bitwise_lu.clone(), + Rv32BaseAlu256Opcode::CLASS_OFFSET, ), - LessThanCoreChip::new(bitwise_lu_chip.clone(), Rv32LessThan256Opcode::CLASS_OFFSET), - offline_memory.clone(), + mem_helper.clone(), ); - inventory.add_executor( - less_than_chip, - Rv32LessThan256Opcode::iter().map(|x| x.global_opcode()), - )?; + inventory.add_executor_chip(alu); - let branch_equal_chip = Rv32BranchEqual256Chip::new( - Rv32HeapBranchAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), + inventory.next_air::()?; + let lt = Rv32LessThan256Chip::new( + LessThanFiller::new( + Rv32HeapAdapterFiller::new(pointer_max_bits, bitwise_lu.clone()), + bitwise_lu.clone(), + Rv32LessThan256Opcode::CLASS_OFFSET, ), - BranchEqualCoreChip::new(Rv32BranchEqual256Opcode::CLASS_OFFSET, DEFAULT_PC_STEP), - offline_memory.clone(), + mem_helper.clone(), ); - inventory.add_executor( - branch_equal_chip, - Rv32BranchEqual256Opcode::iter().map(|x| x.global_opcode()), - )?; + inventory.add_executor_chip(lt); - let branch_less_than_chip = Rv32BranchLessThan256Chip::new( - Rv32HeapBranchAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), + inventory.next_air::()?; + let beq = Rv32BranchEqual256Chip::new( + BranchEqualFiller::new( + Rv32HeapBranchAdapterFiller::new(pointer_max_bits, bitwise_lu.clone()), + Rv32BranchEqual256Opcode::CLASS_OFFSET, + DEFAULT_PC_STEP, ), - BranchLessThanCoreChip::new( - bitwise_lu_chip.clone(), + mem_helper.clone(), + ); + inventory.add_executor_chip(beq); + + inventory.next_air::()?; + let blt = Rv32BranchLessThan256Chip::new( + BranchLessThanFiller::new( + Rv32HeapBranchAdapterFiller::new(pointer_max_bits, bitwise_lu.clone()), + bitwise_lu.clone(), Rv32BranchLessThan256Opcode::CLASS_OFFSET, ), - offline_memory.clone(), + mem_helper.clone(), ); - inventory.add_executor( - branch_less_than_chip, - Rv32BranchLessThan256Opcode::iter().map(|x| x.global_opcode()), - )?; + inventory.add_executor_chip(blt); - let multiplication_chip = Rv32Multiplication256Chip::new( - Rv32HeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), + inventory.next_air::()?; + let mult = Rv32Multiplication256Chip::new( + MultiplicationFiller::new( + Rv32HeapAdapterFiller::new(pointer_max_bits, bitwise_lu.clone()), + range_tuple_checker.clone(), + Rv32Mul256Opcode::CLASS_OFFSET, ), - MultiplicationCoreChip::new(range_tuple_chip, Rv32Mul256Opcode::CLASS_OFFSET), - offline_memory.clone(), + mem_helper.clone(), ); - inventory.add_executor( - multiplication_chip, - Rv32Mul256Opcode::iter().map(|x| x.global_opcode()), - )?; + inventory.add_executor_chip(mult); - let shift_chip = Rv32Shift256Chip::new( - Rv32HeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), - ), - ShiftCoreChip::new( - bitwise_lu_chip.clone(), - range_checker_chip, + inventory.next_air::()?; + let shift = Rv32Shift256Chip::new( + ShiftFiller::new( + Rv32HeapAdapterFiller::new(pointer_max_bits, bitwise_lu.clone()), + bitwise_lu.clone(), + range_checker.clone(), Rv32Shift256Opcode::CLASS_OFFSET, ), - offline_memory.clone(), + mem_helper.clone(), ); - inventory.add_executor( - shift_chip, - Rv32Shift256Opcode::iter().map(|x| x.global_opcode()), - )?; - - Ok(inventory) + inventory.add_executor_chip(shift); + Ok(()) } } diff --git a/extensions/bigint/circuit/src/less_than.rs b/extensions/bigint/circuit/src/less_than.rs new file mode 100644 index 0000000000..f93a3d8997 --- /dev/null +++ b/extensions/bigint/circuit/src/less_than.rs @@ -0,0 +1,157 @@ +use std::borrow::{Borrow, BorrowMut}; + +use openvm_bigint_transpiler::Rv32LessThan256Opcode; +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; +use openvm_circuit_primitives_derive::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; +use openvm_rv32_adapters::Rv32HeapAdapterExecutor; +use openvm_rv32im_circuit::LessThanExecutor; +use openvm_rv32im_transpiler::LessThanOpcode; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{common, Rv32LessThan256Executor, INT256_NUM_LIMBS}; + +type AdapterExecutor = Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>; + +impl Rv32LessThan256Executor { + pub fn new(adapter: AdapterExecutor, offset: usize) -> Self { + Self(LessThanExecutor::new(adapter, offset)) + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct LessThanPreCompute { + a: u8, + b: u8, + c: u8, +} + +impl Executor for Rv32LessThan256Executor { + fn pre_compute_size(&self) -> usize { + size_of::() + } + + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E1ExecutionCtx, + { + let data: &mut LessThanPreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, data)?; + let fn_ptr = match local_opcode { + LessThanOpcode::SLT => execute_e1_impl::<_, _, false>, + LessThanOpcode::SLTU => execute_e1_impl::<_, _, true>, + }; + Ok(fn_ptr) + } +} + +impl MeteredExecutor for Rv32LessThan256Executor { + fn metered_pre_compute_size(&self) -> usize { + size_of::>() + } + + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; + let fn_ptr = match local_opcode { + LessThanOpcode::SLT => execute_e2_impl::<_, _, false>, + LessThanOpcode::SLTU => execute_e2_impl::<_, _, true>, + }; + Ok(fn_ptr) + } +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &LessThanPreCompute, + vm_state: &mut VmExecState, +) { + let rs1_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let rs2_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.c as u32); + let rd_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.a as u32); + let rs1 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); + let rs2 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); + let cmp_result = if IS_U256 { + common::u256_lt(rs1, rs2) + } else { + common::i256_lt(rs1, rs2) + }; + let mut rd = [0u8; INT256_NUM_LIMBS]; + rd[0] = cmp_result as u8; + vm_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd); + + vm_state.pc += DEFAULT_PC_STEP; + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &LessThanPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl Rv32LessThan256Executor { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut LessThanPreCompute, + ) -> Result { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + *data = LessThanPreCompute { + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + c: c.as_canonical_u32() as u8, + }; + let local_opcode = LessThanOpcode::from_usize( + opcode.local_opcode_idx(Rv32LessThan256Opcode::CLASS_OFFSET), + ); + Ok(local_opcode) + } +} diff --git a/extensions/bigint/circuit/src/lib.rs b/extensions/bigint/circuit/src/lib.rs index 295ef73db2..e3a56a4a26 100644 --- a/extensions/bigint/circuit/src/lib.rs +++ b/extensions/bigint/circuit/src/lib.rs @@ -1,49 +1,230 @@ -use openvm_circuit::{self, arch::VmChipWrapper}; -use openvm_rv32_adapters::{Rv32HeapAdapterChip, Rv32HeapBranchAdapterChip}; +use openvm_circuit::{ + self, + arch::{ + AirInventory, ChipInventoryError, InitFileGenerator, MatrixRecordArena, SystemConfig, + VmAirWrapper, VmBuilder, VmChipComplex, VmChipWrapper, VmProverExtension, + }, + system::{SystemChipInventory, SystemCpuBuilder, SystemExecutor}, +}; +use openvm_circuit_derive::{PreflightExecutor, VmConfig}; +use openvm_rv32_adapters::{ + Rv32HeapAdapterAir, Rv32HeapAdapterExecutor, Rv32HeapAdapterFiller, Rv32HeapBranchAdapterAir, + Rv32HeapBranchAdapterExecutor, Rv32HeapBranchAdapterFiller, +}; use openvm_rv32im_circuit::{ adapters::{INT256_NUM_LIMBS, RV32_CELL_BITS}, - BaseAluCoreChip, BranchEqualCoreChip, BranchLessThanCoreChip, LessThanCoreChip, - MultiplicationCoreChip, ShiftCoreChip, + BaseAluCoreAir, BaseAluExecutor, BaseAluFiller, BranchEqualCoreAir, BranchEqualExecutor, + BranchEqualFiller, BranchLessThanCoreAir, BranchLessThanExecutor, BranchLessThanFiller, + LessThanCoreAir, LessThanExecutor, LessThanFiller, MultiplicationCoreAir, + MultiplicationExecutor, MultiplicationFiller, Rv32I, Rv32IExecutor, Rv32ImCpuProverExt, Rv32Io, + Rv32IoExecutor, Rv32M, Rv32MExecutor, ShiftCoreAir, ShiftExecutor, ShiftFiller, }; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + engine::StarkEngine, + p3_field::PrimeField32, + prover::cpu::{CpuBackend, CpuDevice}, +}; +use serde::{Deserialize, Serialize}; mod extension; pub use extension::*; +mod base_alu; +mod branch_eq; +mod branch_lt; +pub(crate) mod common; +mod less_than; +mod mult; +mod shift; #[cfg(test)] mod tests; +/// BaseAlu256 +pub type Rv32BaseAlu256Air = VmAirWrapper< + Rv32HeapAdapterAir<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + BaseAluCoreAir, +>; +#[derive(Clone, PreflightExecutor)] +pub struct Rv32BaseAlu256Executor( + BaseAluExecutor< + Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + INT256_NUM_LIMBS, + RV32_CELL_BITS, + >, +); pub type Rv32BaseAlu256Chip = VmChipWrapper< F, - Rv32HeapAdapterChip, - BaseAluCoreChip, + BaseAluFiller< + Rv32HeapAdapterFiller<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + INT256_NUM_LIMBS, + RV32_CELL_BITS, + >, >; +/// LessThan256 +pub type Rv32LessThan256Air = VmAirWrapper< + Rv32HeapAdapterAir<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + LessThanCoreAir, +>; +#[derive(Clone, PreflightExecutor)] +pub struct Rv32LessThan256Executor( + LessThanExecutor< + Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + INT256_NUM_LIMBS, + RV32_CELL_BITS, + >, +); pub type Rv32LessThan256Chip = VmChipWrapper< F, - Rv32HeapAdapterChip, - LessThanCoreChip, + LessThanFiller< + Rv32HeapAdapterFiller<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + INT256_NUM_LIMBS, + RV32_CELL_BITS, + >, >; +/// Multiplication256 +pub type Rv32Multiplication256Air = VmAirWrapper< + Rv32HeapAdapterAir<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + MultiplicationCoreAir, +>; +#[derive(Clone, PreflightExecutor)] +pub struct Rv32Multiplication256Executor( + MultiplicationExecutor< + Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + INT256_NUM_LIMBS, + RV32_CELL_BITS, + >, +); pub type Rv32Multiplication256Chip = VmChipWrapper< F, - Rv32HeapAdapterChip, - MultiplicationCoreChip, + MultiplicationFiller< + Rv32HeapAdapterFiller<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + INT256_NUM_LIMBS, + RV32_CELL_BITS, + >, >; +/// Shift256 +pub type Rv32Shift256Air = VmAirWrapper< + Rv32HeapAdapterAir<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + ShiftCoreAir, +>; +#[derive(Clone, PreflightExecutor)] +pub struct Rv32Shift256Executor( + ShiftExecutor< + Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + INT256_NUM_LIMBS, + RV32_CELL_BITS, + >, +); pub type Rv32Shift256Chip = VmChipWrapper< F, - Rv32HeapAdapterChip, - ShiftCoreChip, + ShiftFiller< + Rv32HeapAdapterFiller<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + INT256_NUM_LIMBS, + RV32_CELL_BITS, + >, >; +/// BranchEqual256 +pub type Rv32BranchEqual256Air = VmAirWrapper< + Rv32HeapBranchAdapterAir<2, INT256_NUM_LIMBS>, + BranchEqualCoreAir, +>; +#[derive(Clone, PreflightExecutor)] +pub struct Rv32BranchEqual256Executor( + BranchEqualExecutor, INT256_NUM_LIMBS>, +); pub type Rv32BranchEqual256Chip = VmChipWrapper< F, - Rv32HeapBranchAdapterChip, - BranchEqualCoreChip, + BranchEqualFiller, INT256_NUM_LIMBS>, >; +/// BranchLessThan256 +pub type Rv32BranchLessThan256Air = VmAirWrapper< + Rv32HeapBranchAdapterAir<2, INT256_NUM_LIMBS>, + BranchLessThanCoreAir, +>; +#[derive(Clone, PreflightExecutor)] +pub struct Rv32BranchLessThan256Executor( + BranchLessThanExecutor< + Rv32HeapBranchAdapterExecutor<2, INT256_NUM_LIMBS>, + INT256_NUM_LIMBS, + RV32_CELL_BITS, + >, +); pub type Rv32BranchLessThan256Chip = VmChipWrapper< F, - Rv32HeapBranchAdapterChip, - BranchLessThanCoreChip, + BranchLessThanFiller< + Rv32HeapBranchAdapterFiller<2, INT256_NUM_LIMBS>, + INT256_NUM_LIMBS, + RV32_CELL_BITS, + >, >; + +#[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] +pub struct Int256Rv32Config { + #[config(executor = "SystemExecutor")] + pub system: SystemConfig, + #[extension] + pub rv32i: Rv32I, + #[extension] + pub rv32m: Rv32M, + #[extension] + pub io: Rv32Io, + #[extension] + pub bigint: Int256, +} + +// Default implementation uses no init file +impl InitFileGenerator for Int256Rv32Config {} + +impl Default for Int256Rv32Config { + fn default() -> Self { + Self { + system: SystemConfig::default().with_continuations(), + rv32i: Rv32I, + rv32m: Rv32M::default(), + io: Rv32Io, + bigint: Int256::default(), + } + } +} + +#[derive(Clone)] +pub struct Int256Rv32CpuBuilder; + +impl VmBuilder for Int256Rv32CpuBuilder +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + Val: PrimeField32, +{ + type VmConfig = Int256Rv32Config; + type SystemChipInventory = SystemChipInventory; + type RecordArena = MatrixRecordArena>; + + fn create_chip_complex( + &self, + config: &Int256Rv32Config, + circuit: AirInventory, + ) -> Result< + VmChipComplex, + ChipInventoryError, + > { + let mut chip_complex = + VmBuilder::::create_chip_complex(&SystemCpuBuilder, &config.system, circuit)?; + let inventory = &mut chip_complex.inventory; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.rv32i, inventory)?; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.rv32m, inventory)?; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.io, inventory)?; + VmProverExtension::::extend_prover( + &Int256CpuProverExt, + &config.bigint, + inventory, + )?; + Ok(chip_complex) + } +} diff --git a/extensions/bigint/circuit/src/mult.rs b/extensions/bigint/circuit/src/mult.rs new file mode 100644 index 0000000000..dda5a9c749 --- /dev/null +++ b/extensions/bigint/circuit/src/mult.rs @@ -0,0 +1,184 @@ +use std::borrow::{Borrow, BorrowMut}; + +use openvm_bigint_transpiler::Rv32Mul256Opcode; +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; +use openvm_circuit_primitives_derive::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; +use openvm_rv32_adapters::Rv32HeapAdapterExecutor; +use openvm_rv32im_circuit::MultiplicationExecutor; +use openvm_rv32im_transpiler::MulOpcode; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{Rv32Multiplication256Executor, INT256_NUM_LIMBS}; + +type AdapterExecutor = Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>; + +impl Rv32Multiplication256Executor { + pub fn new(adapter: AdapterExecutor, offset: usize) -> Self { + Self(MultiplicationExecutor::new(adapter, offset)) + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct MultPreCompute { + a: u8, + b: u8, + c: u8, +} + +impl Executor for Rv32Multiplication256Executor { + fn pre_compute_size(&self) -> usize { + size_of::() + } + + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E1ExecutionCtx, + { + let data: &mut MultPreCompute = data.borrow_mut(); + self.pre_compute_impl(pc, inst, data)?; + Ok(execute_e1_impl) + } +} + +impl MeteredExecutor for Rv32Multiplication256Executor { + fn metered_pre_compute_size(&self) -> usize { + size_of::>() + } + + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + self.pre_compute_impl(pc, inst, &mut data.data)?; + Ok(execute_e2_impl) + } +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &MultPreCompute, + vm_state: &mut VmExecState, +) { + let rs1_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let rs2_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.c as u32); + let rd_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.a as u32); + let rs1 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); + let rs2 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); + let rd = u256_mul(rs1, rs2); + vm_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd); + + vm_state.pc += DEFAULT_PC_STEP; + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &MultPreCompute = pre_compute.borrow(); + execute_e12_impl(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl(&pre_compute.data, vm_state); +} + +impl Rv32Multiplication256Executor { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut MultPreCompute, + ) -> Result<(), StaticProgramError> { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + let local_opcode = + MulOpcode::from_usize(opcode.local_opcode_idx(Rv32Mul256Opcode::CLASS_OFFSET)); + assert_eq!(local_opcode, MulOpcode::MUL); + *data = MultPreCompute { + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + c: c.as_canonical_u32() as u8, + }; + Ok(()) + } +} + +#[inline(always)] +pub(crate) fn u256_mul( + rs1: [u8; INT256_NUM_LIMBS], + rs2: [u8; INT256_NUM_LIMBS], +) -> [u8; INT256_NUM_LIMBS] { + let rs1_u64: [u32; 8] = unsafe { std::mem::transmute(rs1) }; + let rs2_u64: [u32; 8] = unsafe { std::mem::transmute(rs2) }; + let mut rd = [0u32; 8]; + for i in 0..8 { + let mut carry = 0u64; + for j in 0..(8 - i) { + let res = rs1_u64[i] as u64 * rs2_u64[j] as u64 + rd[i + j] as u64 + carry; + rd[i + j] = res as u32; + carry = res >> 32; + } + } + unsafe { std::mem::transmute(rd) } +} + +#[cfg(test)] +mod tests { + use alloy_primitives::U256; + use rand::{prelude::StdRng, Rng, SeedableRng}; + + use crate::{mult::u256_mul, INT256_NUM_LIMBS}; + + #[test] + fn test_u256_mul() { + let mut rng = StdRng::from_seed([42; 32]); + for _ in 0..10000 { + let limbs_a: [u64; 4] = rng.gen(); + let limbs_b: [u64; 4] = rng.gen(); + let a = U256::from_limbs(limbs_a); + let b = U256::from_limbs(limbs_b); + let a_u8: [u8; INT256_NUM_LIMBS] = unsafe { std::mem::transmute(limbs_a) }; + let b_u8: [u8; INT256_NUM_LIMBS] = unsafe { std::mem::transmute(limbs_b) }; + assert_eq!(U256::from_le_bytes(u256_mul(a_u8, b_u8)), a.wrapping_mul(b)); + } + } +} diff --git a/extensions/bigint/circuit/src/shift.rs b/extensions/bigint/circuit/src/shift.rs new file mode 100644 index 0000000000..71c8317260 --- /dev/null +++ b/extensions/bigint/circuit/src/shift.rs @@ -0,0 +1,258 @@ +use std::borrow::{Borrow, BorrowMut}; + +use openvm_bigint_transpiler::Rv32Shift256Opcode; +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; +use openvm_circuit_primitives_derive::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; +use openvm_rv32_adapters::Rv32HeapAdapterExecutor; +use openvm_rv32im_circuit::ShiftExecutor; +use openvm_rv32im_transpiler::ShiftOpcode; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{Rv32Shift256Executor, INT256_NUM_LIMBS}; + +type AdapterExecutor = Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>; + +impl Rv32Shift256Executor { + pub fn new(adapter: AdapterExecutor, offset: usize) -> Self { + Self(ShiftExecutor::new(adapter, offset)) + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct ShiftPreCompute { + a: u8, + b: u8, + c: u8, +} + +impl Executor for Rv32Shift256Executor { + fn pre_compute_size(&self) -> usize { + size_of::() + } + + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E1ExecutionCtx, + { + let data: &mut ShiftPreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, data)?; + let fn_ptr = match local_opcode { + ShiftOpcode::SLL => execute_e1_impl::<_, _, SllOp>, + ShiftOpcode::SRA => execute_e1_impl::<_, _, SraOp>, + ShiftOpcode::SRL => execute_e1_impl::<_, _, SrlOp>, + }; + Ok(fn_ptr) + } +} + +impl MeteredExecutor for Rv32Shift256Executor { + fn metered_pre_compute_size(&self) -> usize { + size_of::>() + } + + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; + let fn_ptr = match local_opcode { + ShiftOpcode::SLL => execute_e2_impl::<_, _, SllOp>, + ShiftOpcode::SRA => execute_e2_impl::<_, _, SraOp>, + ShiftOpcode::SRL => execute_e2_impl::<_, _, SrlOp>, + }; + Ok(fn_ptr) + } +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &ShiftPreCompute, + vm_state: &mut VmExecState, +) { + let rs1_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let rs2_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.c as u32); + let rd_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.a as u32); + let rs1 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); + let rs2 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); + let rd = OP::compute(rs1, rs2); + vm_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd); + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &ShiftPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl Rv32Shift256Executor { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut ShiftPreCompute, + ) -> Result { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + *data = ShiftPreCompute { + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + c: c.as_canonical_u32() as u8, + }; + let local_opcode = + ShiftOpcode::from_usize(opcode.local_opcode_idx(Rv32Shift256Opcode::CLASS_OFFSET)); + Ok(local_opcode) + } +} + +trait ShiftOp { + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS]; +} +struct SllOp; +struct SrlOp; +struct SraOp; +impl ShiftOp for SllOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] { + let rs1_u64: [u64; 4] = unsafe { std::mem::transmute(rs1) }; + let rs2_u64: [u64; 4] = unsafe { std::mem::transmute(rs2) }; + let mut rd = [0u64; 4]; + // Only use the first 8 bits. + let shift = (rs2_u64[0] & 0xff) as u32; + let index_offset = (shift / u64::BITS) as usize; + let bit_offset = shift % u64::BITS; + let mut carry = 0u64; + for i in index_offset..4 { + let curr = rs1_u64[i - index_offset]; + rd[i] = (curr << bit_offset) + carry; + if bit_offset > 0 { + carry = curr >> (u64::BITS - bit_offset); + } + } + unsafe { std::mem::transmute(rd) } + } +} +impl ShiftOp for SrlOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] { + // Logical right shift - fill with 0 + shift_right(rs1, rs2, 0) + } +} +impl ShiftOp for SraOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] { + // Arithmetic right shift - fill with sign bit + if rs1[INT256_NUM_LIMBS - 1] & 0x80 > 0 { + shift_right(rs1, rs2, u64::MAX) + } else { + shift_right(rs1, rs2, 0) + } + } +} + +#[inline(always)] +fn shift_right( + rs1: [u8; INT256_NUM_LIMBS], + rs2: [u8; INT256_NUM_LIMBS], + init_value: u64, +) -> [u8; INT256_NUM_LIMBS] { + let rs1_u64: [u64; 4] = unsafe { std::mem::transmute(rs1) }; + let rs2_u64: [u64; 4] = unsafe { std::mem::transmute(rs2) }; + let mut rd = [init_value; 4]; + let shift = (rs2_u64[0] & 0xff) as u32; + let index_offset = (shift / u64::BITS) as usize; + let bit_offset = shift % u64::BITS; + let mut carry = if bit_offset > 0 { + init_value << (u64::BITS - bit_offset) + } else { + 0 + }; + for i in (index_offset..4).rev() { + let curr = rs1_u64[i]; + rd[i - index_offset] = (curr >> bit_offset) + carry; + if bit_offset > 0 { + carry = curr << (u64::BITS - bit_offset); + } + } + unsafe { std::mem::transmute(rd) } +} + +#[cfg(test)] +mod tests { + use alloy_primitives::U256; + use rand::{prelude::StdRng, Rng, SeedableRng}; + + use crate::{ + shift::{ShiftOp, SllOp, SraOp, SrlOp}, + INT256_NUM_LIMBS, + }; + + #[test] + fn test_shift_op() { + let mut rng = StdRng::from_seed([42; 32]); + for _ in 0..10000 { + let limbs_a: [u8; INT256_NUM_LIMBS] = rng.gen(); + let mut limbs_b: [u8; INT256_NUM_LIMBS] = [0; INT256_NUM_LIMBS]; + let shift: u8 = rng.gen(); + limbs_b[0] = shift; + let a = U256::from_le_bytes(limbs_a); + { + let res = SllOp::compute(limbs_a, limbs_b); + assert_eq!(U256::from_le_bytes(res), a << shift); + } + { + let res = SraOp::compute(limbs_a, limbs_b); + assert_eq!(U256::from_le_bytes(res), a.arithmetic_shr(shift as usize)); + } + { + let res = SrlOp::compute(limbs_a, limbs_b); + assert_eq!(U256::from_le_bytes(res), a >> shift); + } + } + } +} diff --git a/extensions/bigint/circuit/src/tests.rs b/extensions/bigint/circuit/src/tests.rs index 0e26352410..f49de57339 100644 --- a/extensions/bigint/circuit/src/tests.rs +++ b/extensions/bigint/circuit/src/tests.rs @@ -1,183 +1,214 @@ +use std::sync::Arc; + use openvm_bigint_transpiler::{ Rv32BaseAlu256Opcode, Rv32BranchEqual256Opcode, Rv32BranchLessThan256Opcode, Rv32LessThan256Opcode, Rv32Mul256Opcode, Rv32Shift256Opcode, }; use openvm_circuit::{ arch::{ - testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, RANGE_TUPLE_CHECKER_BUS}, - InstructionExecutor, + testing::{ + TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, RANGE_TUPLE_CHECKER_BUS, + }, + MatrixRecordArena, PreflightExecutor, }, utils::generate_long_number, }; use openvm_circuit_primitives::{ - bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, - range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, + bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip}, + range_tuple::{RangeTupleCheckerBus, RangeTupleCheckerChip, SharedRangeTupleCheckerChip}, +}; +use openvm_instructions::{ + program::{DEFAULT_PC_STEP, PC_BITS}, + riscv::RV32_CELL_BITS, + LocalOpcode, }; -use openvm_instructions::{program::PC_BITS, riscv::RV32_CELL_BITS, LocalOpcode}; use openvm_rv32_adapters::{ - rv32_heap_branch_default, rv32_write_heap_default, Rv32HeapAdapterChip, - Rv32HeapBranchAdapterChip, + rv32_heap_branch_default, rv32_write_heap_default, Rv32HeapAdapterAir, Rv32HeapAdapterExecutor, + Rv32HeapAdapterFiller, Rv32HeapBranchAdapterAir, Rv32HeapBranchAdapterExecutor, + Rv32HeapBranchAdapterFiller, }; use openvm_rv32im_circuit::{ adapters::{INT256_NUM_LIMBS, RV_B_TYPE_IMM_BITS}, - BaseAluCoreChip, BranchEqualCoreChip, BranchLessThanCoreChip, LessThanCoreChip, - MultiplicationCoreChip, ShiftCoreChip, + BaseAluCoreAir, BaseAluFiller, BranchEqualCoreAir, BranchEqualFiller, BranchLessThanCoreAir, + BranchLessThanFiller, LessThanCoreAir, LessThanFiller, MultiplicationCoreAir, + MultiplicationFiller, ShiftCoreAir, ShiftFiller, }; use openvm_rv32im_transpiler::{ - BaseAluOpcode, BranchEqualOpcode, BranchLessThanOpcode, LessThanOpcode, ShiftOpcode, + BaseAluOpcode, BranchEqualOpcode, BranchLessThanOpcode, LessThanOpcode, MulOpcode, ShiftOpcode, }; use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32}; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::Rng; - -use super::{ - Rv32BaseAlu256Chip, Rv32BranchEqual256Chip, Rv32BranchLessThan256Chip, Rv32LessThan256Chip, - Rv32Multiplication256Chip, Rv32Shift256Chip, +use rand::{rngs::StdRng, Rng}; +use test_case::test_case; + +use crate::{ + Rv32BaseAlu256Air, Rv32BaseAlu256Chip, Rv32BaseAlu256Executor, Rv32BranchEqual256Air, + Rv32BranchEqual256Chip, Rv32BranchEqual256Executor, Rv32BranchLessThan256Air, + Rv32BranchLessThan256Chip, Rv32BranchLessThan256Executor, Rv32LessThan256Air, + Rv32LessThan256Chip, Rv32LessThan256Executor, Rv32Multiplication256Air, + Rv32Multiplication256Chip, Rv32Multiplication256Executor, Rv32Shift256Air, Rv32Shift256Chip, + Rv32Shift256Executor, }; type F = BabyBear; +const MAX_INS_CAPACITY: usize = 128; +const ABS_MAX_BRANCH: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); #[allow(clippy::type_complexity)] -fn run_int_256_rand_execute>( - opcode: usize, - num_ops: usize, - executor: &mut E, +fn set_and_execute_rand( tester: &mut VmChipTestBuilder, + harness: &mut TestChipHarness, + rng: &mut StdRng, + opcode: usize, branch_fn: Option bool>, -) { - const ABS_MAX_BRANCH: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); - - let mut rng = create_seeded_rng(); +) where + STEP: PreflightExecutor>, +{ let branch = branch_fn.is_some(); - for _ in 0..num_ops { - let b = generate_long_number::(&mut rng); - let c = generate_long_number::(&mut rng); - if branch { - let imm = rng.gen_range((-ABS_MAX_BRANCH)..ABS_MAX_BRANCH); - let instruction = rv32_heap_branch_default( - tester, - vec![b.map(F::from_canonical_u32)], - vec![c.map(F::from_canonical_u32)], - imm as isize, - opcode, - ); - - tester.execute_with_pc( - executor, - &instruction, - rng.gen_range((ABS_MAX_BRANCH as u32)..(1 << (PC_BITS - 1))), - ); - - let cmp_result = branch_fn.unwrap()(opcode, &b, &c); - let from_pc = tester.execution.last_from_pc().as_canonical_u32() as i32; - let to_pc = tester.execution.last_to_pc().as_canonical_u32() as i32; - assert_eq!(to_pc, from_pc + if cmp_result { imm } else { 4 }); - } else { - let instruction = rv32_write_heap_default( - tester, - vec![b.map(F::from_canonical_u32)], - vec![c.map(F::from_canonical_u32)], - opcode, - ); - tester.execute(executor, &instruction); - } + let b = generate_long_number::(rng); + let c = generate_long_number::(rng); + if branch { + let imm = rng.gen_range((-ABS_MAX_BRANCH)..ABS_MAX_BRANCH); + let instruction = rv32_heap_branch_default( + tester, + vec![b.map(F::from_canonical_u32)], + vec![c.map(F::from_canonical_u32)], + imm as isize, + opcode, + ); + + tester.execute_with_pc( + harness, + &instruction, + rng.gen_range((ABS_MAX_BRANCH as u32)..(1 << (PC_BITS - 1))), + ); + + let cmp_result = branch_fn.unwrap()(opcode, &b, &c); + let from_pc = tester.execution.last_from_pc().as_canonical_u32() as i32; + let to_pc = tester.execution.last_to_pc().as_canonical_u32() as i32; + assert_eq!(to_pc, from_pc + if cmp_result { imm } else { 4 }); + } else { + let instruction = rv32_write_heap_default( + tester, + vec![b.map(F::from_canonical_u32)], + vec![c.map(F::from_canonical_u32)], + opcode, + ); + tester.execute(harness, &instruction); } } +#[test_case(BaseAluOpcode::ADD, 24)] +#[test_case(BaseAluOpcode::SUB, 24)] +#[test_case(BaseAluOpcode::XOR, 24)] +#[test_case(BaseAluOpcode::OR, 24)] +#[test_case(BaseAluOpcode::AND, 24)] fn run_alu_256_rand_test(opcode: BaseAluOpcode, num_ops: usize) { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); + let offset = Rv32BaseAlu256Opcode::CLASS_OFFSET; - let mut chip = Rv32BaseAlu256Chip::::new( - Rv32HeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + + let air = Rv32BaseAlu256Air::new( + Rv32HeapAdapterAir::new( + tester.execution_bridge(), tester.memory_bridge(), + bitwise_bus, tester.address_bits(), + ), + BaseAluCoreAir::new(bitwise_bus, offset), + ); + let executor = + Rv32BaseAlu256Executor::new(Rv32HeapAdapterExecutor::new(tester.address_bits()), offset); + let chip = Rv32BaseAlu256Chip::new( + BaseAluFiller::new( + Rv32HeapAdapterFiller::new(tester.address_bits(), bitwise_chip.clone()), bitwise_chip.clone(), + offset, ), - BaseAluCoreChip::new(bitwise_chip.clone(), Rv32BaseAlu256Opcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), + tester.memory_helper(), ); + let mut harness = TestChipHarness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); - run_int_256_rand_execute( - opcode.local_usize() + Rv32BaseAlu256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - None, - ); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut harness, + &mut rng, + opcode.local_usize() + offset, + None, + ); + } + let tester = tester + .build() + .load(harness) + .load_periphery((bitwise_chip.air, bitwise_chip)) + .finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn alu_256_add_rand_test() { - run_alu_256_rand_test(BaseAluOpcode::ADD, 24); -} - -#[test] -fn alu_256_sub_rand_test() { - run_alu_256_rand_test(BaseAluOpcode::SUB, 24); -} - -#[test] -fn alu_256_xor_rand_test() { - run_alu_256_rand_test(BaseAluOpcode::XOR, 24); -} - -#[test] -fn alu_256_or_rand_test() { - run_alu_256_rand_test(BaseAluOpcode::OR, 24); -} - -#[test] -fn alu_256_and_rand_test() { - run_alu_256_rand_test(BaseAluOpcode::AND, 24); -} - +#[test_case(LessThanOpcode::SLT, 24)] +#[test_case(LessThanOpcode::SLTU, 24)] fn run_lt_256_rand_test(opcode: LessThanOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let offset = Rv32LessThan256Opcode::CLASS_OFFSET; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); - let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32LessThan256Chip::::new( - Rv32HeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), + let air = Rv32LessThan256Air::new( + Rv32HeapAdapterAir::new( + tester.execution_bridge(), tester.memory_bridge(), + bitwise_bus, tester.address_bits(), - bitwise_chip.clone(), ), - LessThanCoreChip::new(bitwise_chip.clone(), Rv32LessThan256Opcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), + LessThanCoreAir::new(bitwise_bus, offset), ); - run_int_256_rand_execute( - opcode.local_usize() + Rv32LessThan256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - None, + let executor = + Rv32LessThan256Executor::new(Rv32HeapAdapterExecutor::new(tester.address_bits()), offset); + let chip = Rv32LessThan256Chip::new( + LessThanFiller::new( + Rv32HeapAdapterFiller::new(tester.address_bits(), bitwise_chip.clone()), + bitwise_chip.clone(), + offset, + ), + tester.memory_helper(), ); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} + let mut harness = TestChipHarness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); -#[test] -fn lt_256_slt_rand_test() { - run_lt_256_rand_test(LessThanOpcode::SLT, 24); + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut harness, + &mut rng, + opcode.local_usize() + offset, + None, + ); + } + let tester = tester + .build() + .load(harness) + .load_periphery((bitwise_chip.air, bitwise_chip)) + .finalize(); + tester.simple_test().expect("Verification failed"); } -#[test] -fn lt_256_sltu_rand_test() { - run_lt_256_rand_test(LessThanOpcode::SLTU, 24); -} +#[test_case(MulOpcode::MUL, 24)] +fn run_mul_256_rand_test(opcode: MulOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let offset = Rv32Mul256Opcode::CLASS_OFFSET; -fn run_mul_256_rand_test(num_ops: usize) { let range_tuple_bus = RangeTupleCheckerBus::new( RANGE_TUPLE_CHECKER_BUS, [ @@ -185,106 +216,143 @@ fn run_mul_256_rand_test(num_ops: usize) { (INT256_NUM_LIMBS * (1 << RV32_CELL_BITS)) as u32, ], ); - let range_tuple_checker = SharedRangeTupleCheckerChip::new(range_tuple_bus); + let range_tuple_chip = + SharedRangeTupleCheckerChip::new(RangeTupleCheckerChip::<2>::new(range_tuple_bus)); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); - let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32Multiplication256Chip::::new( - Rv32HeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), + let air = Rv32Multiplication256Air::new( + Rv32HeapAdapterAir::new( + tester.execution_bridge(), tester.memory_bridge(), + bitwise_bus, tester.address_bits(), - bitwise_chip.clone(), ), - MultiplicationCoreChip::new(range_tuple_checker.clone(), Rv32Mul256Opcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), + MultiplicationCoreAir::new(range_tuple_bus, offset), ); - - run_int_256_rand_execute( - Rv32Mul256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - None, + let executor = Rv32Multiplication256Executor::new( + Rv32HeapAdapterExecutor::new(tester.address_bits()), + offset, + ); + let chip = Rv32Multiplication256Chip::::new( + MultiplicationFiller::new( + Rv32HeapAdapterFiller::new(tester.address_bits(), bitwise_chip.clone()), + range_tuple_chip.clone(), + offset, + ), + tester.memory_helper(), ); + let mut harness = TestChipHarness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut harness, + &mut rng, + opcode.local_usize() + offset, + None, + ); + } let tester = tester .build() - .load(chip) - .load(range_tuple_checker) - .load(bitwise_chip) + .load(harness) + .load_periphery((range_tuple_chip.air, range_tuple_chip)) + .load_periphery((bitwise_chip.air, bitwise_chip)) .finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn mul_256_rand_test() { - run_mul_256_rand_test(24); -} - +#[test_case(ShiftOpcode::SLL, 24)] +#[test_case(ShiftOpcode::SRL, 24)] +#[test_case(ShiftOpcode::SRA, 24)] fn run_shift_256_rand_test(opcode: ShiftOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let offset = Rv32Shift256Opcode::CLASS_OFFSET; + + let range_checker_chip = tester.range_checker(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); - let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32Shift256Chip::::new( - Rv32HeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), + let air = Rv32Shift256Air::new( + Rv32HeapAdapterAir::new( + tester.execution_bridge(), tester.memory_bridge(), + bitwise_bus, tester.address_bits(), - bitwise_chip.clone(), ), - ShiftCoreChip::new( + ShiftCoreAir::new(bitwise_bus, range_checker_chip.bus(), offset), + ); + let executor = + Rv32Shift256Executor::new(Rv32HeapAdapterExecutor::new(tester.address_bits()), offset); + let chip = Rv32Shift256Chip::new( + ShiftFiller::new( + Rv32HeapAdapterFiller::new(tester.address_bits(), bitwise_chip.clone()), bitwise_chip.clone(), - tester.memory_controller().borrow().range_checker.clone(), - Rv32Shift256Opcode::CLASS_OFFSET, + range_checker_chip.clone(), + offset, ), - tester.offline_memory_mutex_arc(), + tester.memory_helper(), ); - run_int_256_rand_execute( - opcode.local_usize() + Rv32Shift256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - None, - ); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn shift_256_sll_rand_test() { - run_shift_256_rand_test(ShiftOpcode::SLL, 24); -} + let mut harness = TestChipHarness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); -#[test] -fn shift_256_srl_rand_test() { - run_shift_256_rand_test(ShiftOpcode::SRL, 24); -} + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut harness, + &mut rng, + opcode.local_usize() + offset, + None, + ); + } -#[test] -fn shift_256_sra_rand_test() { - run_shift_256_rand_test(ShiftOpcode::SRA, 24); + let tester = tester + .build() + .load(harness) + .load_periphery((bitwise_chip.air, bitwise_chip)) + .finalize(); + tester.simple_test().expect("Verification failed"); } +#[test_case(BranchEqualOpcode::BEQ, 24)] +#[test_case(BranchEqualOpcode::BNE, 24)] fn run_beq_256_rand_test(opcode: BranchEqualOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); + let offset = Rv32BranchEqual256Opcode::CLASS_OFFSET; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut chip = Rv32BranchEqual256Chip::::new( - Rv32HeapBranchAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + let air = Rv32BranchEqual256Air::new( + Rv32HeapBranchAdapterAir::new( + tester.execution_bridge(), tester.memory_bridge(), + bitwise_bus, tester.address_bits(), - bitwise_chip.clone(), ), - BranchEqualCoreChip::new(Rv32BranchEqual256Opcode::CLASS_OFFSET, 4), - tester.offline_memory_mutex_arc(), + BranchEqualCoreAir::new(offset, DEFAULT_PC_STEP), + ); + let executor = Rv32BranchEqual256Executor::new( + Rv32HeapBranchAdapterExecutor::new(tester.address_bits()), + offset, + DEFAULT_PC_STEP, ); + let chip = Rv32BranchEqual256Chip::new( + BranchEqualFiller::new( + Rv32HeapBranchAdapterFiller::new(tester.address_bits(), bitwise_chip.clone()), + offset, + DEFAULT_PC_STEP, + ), + tester.memory_helper(), + ); + let mut harness = TestChipHarness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); let branch_fn = |opcode: usize, x: &[u32; INT256_NUM_LIMBS], y: &[u32; INT256_NUM_LIMBS]| { x.iter() @@ -294,93 +362,94 @@ fn run_beq_256_rand_test(opcode: BranchEqualOpcode, num_ops: usize) { == BranchEqualOpcode::BNE.local_usize() + Rv32BranchEqual256Opcode::CLASS_OFFSET) }; - run_int_256_rand_execute( - opcode.local_usize() + Rv32BranchEqual256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - Some(branch_fn), - ); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut harness, + &mut rng, + opcode.local_usize() + offset, + Some(branch_fn), + ); + } + let tester = tester + .build() + .load(harness) + .load_periphery((bitwise_chip.air, bitwise_chip)) + .finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn beq_256_beq_rand_test() { - run_beq_256_rand_test(BranchEqualOpcode::BEQ, 24); -} - -#[test] -fn beq_256_bne_rand_test() { - run_beq_256_rand_test(BranchEqualOpcode::BNE, 24); -} - +#[test_case(BranchLessThanOpcode::BLT, 24)] +#[test_case(BranchLessThanOpcode::BLTU, 24)] +#[test_case(BranchLessThanOpcode::BGE, 24)] +#[test_case(BranchLessThanOpcode::BGEU, 24)] fn run_blt_256_rand_test(opcode: BranchLessThanOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let offset = Rv32BranchLessThan256Opcode::CLASS_OFFSET; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); - let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32BranchLessThan256Chip::::new( - Rv32HeapBranchAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), + let air = Rv32BranchLessThan256Air::new( + Rv32HeapBranchAdapterAir::new( + tester.execution_bridge(), tester.memory_bridge(), + bitwise_bus, tester.address_bits(), - bitwise_chip.clone(), ), - BranchLessThanCoreChip::new( + BranchLessThanCoreAir::new(bitwise_bus, offset), + ); + let executor = Rv32BranchLessThan256Executor::new( + Rv32HeapBranchAdapterExecutor::new(tester.address_bits()), + offset, + ); + let chip = Rv32BranchLessThan256Chip::new( + BranchLessThanFiller::new( + Rv32HeapBranchAdapterFiller::new(tester.address_bits(), bitwise_chip.clone()), bitwise_chip.clone(), - Rv32BranchLessThan256Opcode::CLASS_OFFSET, + offset, ), - tester.offline_memory_mutex_arc(), + tester.memory_helper(), ); + let mut harness = TestChipHarness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); - let branch_fn = |opcode: usize, x: &[u32; INT256_NUM_LIMBS], y: &[u32; INT256_NUM_LIMBS]| { - let opcode = - BranchLessThanOpcode::from_usize(opcode - Rv32BranchLessThan256Opcode::CLASS_OFFSET); - let (is_ge, is_signed) = match opcode { - BranchLessThanOpcode::BLT => (false, true), - BranchLessThanOpcode::BLTU => (false, false), - BranchLessThanOpcode::BGE => (true, true), - BranchLessThanOpcode::BGEU => (true, false), - }; - let x_sign = x[INT256_NUM_LIMBS - 1] >> (RV32_CELL_BITS - 1) != 0 && is_signed; - let y_sign = y[INT256_NUM_LIMBS - 1] >> (RV32_CELL_BITS - 1) != 0 && is_signed; - for (x, y) in x.iter().rev().zip(y.iter().rev()) { - if x != y { - return (x < y) ^ x_sign ^ y_sign ^ is_ge; + let branch_fn = + |opcode: usize, x: &[u32; INT256_NUM_LIMBS], y: &[u32; INT256_NUM_LIMBS]| -> bool { + let opcode = BranchLessThanOpcode::from_usize( + opcode - Rv32BranchLessThan256Opcode::CLASS_OFFSET, + ); + let (is_ge, is_signed) = match opcode { + BranchLessThanOpcode::BLT => (false, true), + BranchLessThanOpcode::BLTU => (false, false), + BranchLessThanOpcode::BGE => (true, true), + BranchLessThanOpcode::BGEU => (true, false), + }; + let x_sign = x[INT256_NUM_LIMBS - 1] >> (RV32_CELL_BITS - 1) != 0 && is_signed; + let y_sign = y[INT256_NUM_LIMBS - 1] >> (RV32_CELL_BITS - 1) != 0 && is_signed; + for (x, y) in x.iter().rev().zip(y.iter().rev()) { + if x != y { + return (x < y) ^ x_sign ^ y_sign ^ is_ge; + } } - } - is_ge - }; + is_ge + }; - run_int_256_rand_execute( - opcode.local_usize() + Rv32BranchLessThan256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - Some(branch_fn), - ); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut harness, + &mut rng, + opcode.local_usize() + offset, + Some(branch_fn), + ); + } + let tester = tester + .build() + .load(harness) + .load_periphery((bitwise_chip.air, bitwise_chip)) + .finalize(); tester.simple_test().expect("Verification failed"); } - -#[test] -fn blt_256_blt_rand_test() { - run_blt_256_rand_test(BranchLessThanOpcode::BLT, 24); -} - -#[test] -fn blt_256_bltu_rand_test() { - run_blt_256_rand_test(BranchLessThanOpcode::BLTU, 24); -} - -#[test] -fn blt_256_bge_rand_test() { - run_blt_256_rand_test(BranchLessThanOpcode::BGE, 24); -} - -#[test] -fn blt_256_bgeu_rand_test() { - run_blt_256_rand_test(BranchLessThanOpcode::BGEU, 24); -} diff --git a/extensions/ecc/circuit/Cargo.toml b/extensions/ecc/circuit/Cargo.toml index dca4fb91e9..8a083bc922 100644 --- a/extensions/ecc/circuit/Cargo.toml +++ b/extensions/ecc/circuit/Cargo.toml @@ -8,28 +8,30 @@ homepage.workspace = true repository.workspace = true [dependencies] -openvm-circuit-primitives-derive = { workspace = true } openvm-circuit-primitives = { workspace = true } openvm-circuit-derive = { workspace = true } openvm-circuit = { workspace = true } openvm-instructions = { workspace = true } openvm-mod-circuit-builder = { workspace = true } openvm-stark-backend = { workspace = true } -openvm-rv32im-circuit = { workspace = true } openvm-algebra-circuit = { workspace = true } openvm-rv32-adapters = { workspace = true } openvm-ecc-transpiler = { workspace = true } +openvm-ecc-guest = { workspace = true, features = ["ed25519"] } num-bigint = { workspace = true } +num-integer = { workspace = true } num-traits = { workspace = true } strum = { workspace = true } -derive_more = { workspace = true } +derive_more = { workspace = true, features = ["deref", "deref_mut"] } derive-new = { workspace = true } once_cell = { workspace = true, features = ["std"] } +rand = { workspace = true } serde = { workspace = true } serde_with = { workspace = true } lazy_static = { workspace = true } hex-literal = { workspace = true } +halo2curves-axiom = { workspace = true } [dev-dependencies] openvm-stark-sdk = { workspace = true } @@ -37,3 +39,6 @@ openvm-mod-circuit-builder = { workspace = true, features = ["test-utils"] } openvm-circuit = { workspace = true, features = ["test-utils"] } openvm-rv32-adapters = { workspace = true, features = ["test-utils"] } lazy_static = { workspace = true } + +[package.metadata.cargo-shear] +ignored = ["rand"] diff --git a/extensions/ecc/circuit/src/config.rs b/extensions/ecc/circuit/src/config.rs index a959938be9..f3d6be47c9 100644 --- a/extensions/ecc/circuit/src/config.rs +++ b/extensions/ecc/circuit/src/config.rs @@ -1,51 +1,88 @@ -use openvm_algebra_circuit::*; -use openvm_circuit::arch::{InitFileGenerator, SystemConfig}; +use std::result::Result; + +use openvm_algebra_circuit::{Rv32ModularConfig, Rv32ModularConfigExecutor, Rv32ModularCpuBuilder}; +use openvm_circuit::{ + arch::{ + AirInventory, ChipInventoryError, InitFileGenerator, MatrixRecordArena, SystemConfig, + VmBuilder, VmChipComplex, VmProverExtension, + }, + system::SystemChipInventory, +}; use openvm_circuit_derive::VmConfig; -use openvm_rv32im_circuit::*; -use openvm_stark_backend::p3_field::PrimeField32; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + engine::StarkEngine, + p3_field::PrimeField32, + prover::cpu::{CpuBackend, CpuDevice}, +}; use serde::{Deserialize, Serialize}; use super::*; #[derive(Clone, Debug, VmConfig, Serialize, Deserialize)] -pub struct Rv32WeierstrassConfig { - #[system] - pub system: SystemConfig, - #[extension] - pub base: Rv32I, - #[extension] - pub mul: Rv32M, - #[extension] - pub io: Rv32Io, +pub struct Rv32EccConfig { + #[config(generics = true)] + pub modular: Rv32ModularConfig, #[extension] - pub modular: ModularExtension, - #[extension] - pub weierstrass: WeierstrassExtension, + pub ecc: EccExtension, } -impl Rv32WeierstrassConfig { - pub fn new(curves: Vec) -> Self { - let primes: Vec<_> = curves +impl Rv32EccConfig { + pub fn new( + sw_curves: Vec>, + te_curves: Vec>, + ) -> Self { + let sw_primes: Vec<_> = sw_curves .iter() .flat_map(|c| [c.modulus.clone(), c.scalar.clone()]) .collect(); + let te_primes: Vec<_> = te_curves + .iter() + .flat_map(|c| [c.modulus.clone(), c.scalar.clone()]) + .collect(); + let primes = sw_primes.into_iter().chain(te_primes).collect(); Self { - system: SystemConfig::default().with_continuations(), - base: Default::default(), - mul: Default::default(), - io: Default::default(), - modular: ModularExtension::new(primes), - weierstrass: WeierstrassExtension::new(curves), + modular: Rv32ModularConfig::new(primes), + ecc: EccExtension::new(sw_curves, te_curves), } } } -impl InitFileGenerator for Rv32WeierstrassConfig { +impl InitFileGenerator for Rv32EccConfig { fn generate_init_file_contents(&self) -> Option { Some(format!( "// This file is automatically generated by cargo openvm. Do not rename or edit.\n{}\n{}\n", - self.modular.generate_moduli_init(), - self.weierstrass.generate_sw_init() + self.modular.modular.generate_moduli_init(), + self.ecc.generate_ecc_init() )) } } + +#[derive(Clone)] +pub struct Rv32EccCpuBuilder; + +impl VmBuilder for Rv32EccCpuBuilder +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + Val: PrimeField32, +{ + type VmConfig = Rv32EccConfig; + type SystemChipInventory = SystemChipInventory; + type RecordArena = MatrixRecordArena>; + + fn create_chip_complex( + &self, + config: &Self::VmConfig, + circuit: AirInventory, + ) -> Result< + VmChipComplex, + ChipInventoryError, + > { + let mut chip_complex = + VmBuilder::::create_chip_complex(&Rv32ModularCpuBuilder, &config.modular, circuit)?; + let inventory = &mut chip_complex.inventory; + VmProverExtension::::extend_prover(&EccCpuProverExt, &config.ecc, inventory)?; + Ok(chip_complex) + } +} diff --git a/extensions/ecc/circuit/src/ecc_extension.rs b/extensions/ecc/circuit/src/ecc_extension.rs new file mode 100644 index 0000000000..09c5afe43b --- /dev/null +++ b/extensions/ecc/circuit/src/ecc_extension.rs @@ -0,0 +1,570 @@ +use std::sync::Arc; + +use hex_literal::hex; +use lazy_static::lazy_static; +use num_bigint::BigUint; +use num_traits::{FromPrimitive, Zero}; +use once_cell::sync::Lazy; +use openvm_circuit::{ + arch::{ + AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, ExecutionBridge, + ExecutorInventoryBuilder, ExecutorInventoryError, RowMajorMatrixArena, VmCircuitExtension, + VmExecutionExtension, VmProverExtension, + }, + system::{memory::SharedMemoryHelper, SystemPort}, +}; +use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::{ + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, + }, + var_range::VariableRangeCheckerBus, +}; +use openvm_ecc_guest::{ + algebra::IntMod, + ed25519::{CURVE_A as ED25519_A, CURVE_D as ED25519_D, ED25519_MODULUS, ED25519_ORDER}, +}; +use openvm_ecc_transpiler::{Rv32EdwardsOpcode, Rv32WeierstrassOpcode}; +use openvm_instructions::{LocalOpcode, VmOpcode}; +use openvm_mod_circuit_builder::ExprBuilderConfig; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + engine::StarkEngine, + p3_field::PrimeField32, + prover::cpu::{CpuBackend, CpuDevice}, +}; +use serde::{Deserialize, Serialize}; +use serde_with::{serde_as, DisplayFromStr}; +use strum::EnumCount; + +use crate::{ + get_sw_addne_air, get_sw_addne_chip, get_sw_addne_step, get_sw_double_air, get_sw_double_chip, + get_sw_double_step, get_te_add_air, get_te_add_chip, get_te_add_step, EccCpuProverExt, + EdwardsAir, SwAddNeExecutor, SwDoubleExecutor, TeAddExecutor, WeierstrassAir, +}; + +#[serde_as] +#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)] +pub struct CurveConfig { + /// The name of the curve struct as defined by moduli_declare. + pub struct_name: String, + /// The coordinate modulus of the curve. + #[serde_as(as = "DisplayFromStr")] + pub modulus: BigUint, + /// The scalar field modulus of the curve. + #[serde_as(as = "DisplayFromStr")] + pub scalar: BigUint, + // curve-specific coefficients + pub coeffs: T, +} + +#[serde_as] +#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)] +pub struct SwCurveCoeffs { + /// The coefficient a of y^2 = x^3 + ax + b. + #[serde_as(as = "DisplayFromStr")] + pub a: BigUint, + /// The coefficient b of y^2 = x^3 + ax + b. + #[serde_as(as = "DisplayFromStr")] + pub b: BigUint, +} + +#[serde_as] +#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)] +pub struct TeCurveCoeffs { + /// The coefficient a of ax^2 + y^2 = 1 + dx^2y^2 + #[serde_as(as = "DisplayFromStr")] + pub a: BigUint, + /// The coefficient d of ax^2 + y^2 = 1 + dx^2y^2 + #[serde_as(as = "DisplayFromStr")] + pub d: BigUint, +} + +pub static SECP256K1_CONFIG: Lazy> = Lazy::new(|| CurveConfig { + struct_name: "Secp256k1Point".to_string(), + modulus: BigUint::from_bytes_be(&hex!( + "FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE FFFFFC2F" + )), + scalar: BigUint::from_bytes_be(&hex!( + "FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE BAAEDCE6 AF48A03B BFD25E8C D0364141" + )), + coeffs: SwCurveCoeffs { + a: BigUint::zero(), + b: BigUint::from_u8(7u8).unwrap(), + }, +}); + +pub static P256_CONFIG: Lazy> = Lazy::new(|| CurveConfig { + struct_name: "P256Point".to_string(), + modulus: BigUint::from_bytes_be(&hex!( + "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff" + )), + scalar: BigUint::from_bytes_be(&hex!( + "ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551" + )), + coeffs: SwCurveCoeffs { + a: BigUint::from_bytes_le(&hex!( + "fcffffffffffffffffffffff00000000000000000000000001000000ffffffff" + )), + b: BigUint::from_bytes_le(&hex!( + "4b60d2273e3cce3bf6b053ccb0061d65bc86987655bdebb3e7933aaad835c65a" + )), + }, +}); + +pub static ED25519_CONFIG: Lazy> = Lazy::new(|| CurveConfig { + struct_name: "Ed25519Point".to_string(), + modulus: ED25519_MODULUS.clone(), + scalar: ED25519_ORDER.clone(), + coeffs: TeCurveCoeffs { + a: BigUint::from_bytes_le(ED25519_A.as_le_bytes()), + d: BigUint::from_bytes_le(ED25519_D.as_le_bytes()), + }, +}); + +#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)] +pub struct EccExtension { + #[serde(default)] + pub supported_sw_curves: Vec>, + #[serde(default)] + pub supported_te_curves: Vec>, +} + +impl EccExtension { + pub fn generate_ecc_init(&self) -> String { + let supported_sw_curves = self + .supported_sw_curves + .iter() + .map(|curve_config| format!("\"{}\"", curve_config.struct_name)) + .collect::>() + .join(", "); + + let supported_te_curves = self + .supported_te_curves + .iter() + .map(|curve_config| format!("\"{}\"", curve_config.struct_name)) + .collect::>() + .join(", "); + + format!( + "openvm_ecc_guest::sw_macros::sw_init! {{ {supported_sw_curves} }}\nopenvm_ecc_guest::te_macros::te_init! {{ {supported_te_curves} }}" + ) + } +} + +#[derive(Clone, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)] +pub enum EccExtensionExecutor { + // 32 limbs prime + SwAddNeRv32_32(SwAddNeExecutor<2, 32>), + SwDoubleRv32_32(SwDoubleExecutor<2, 32>), + // 48 limbs prime + SwAddNeRv32_48(SwAddNeExecutor<6, 16>), + SwDoubleRv32_48(SwDoubleExecutor<6, 16>), + // 32 limbs prime + TeEcAddRv32_32(TeAddExecutor<2, 32>), +} + +impl VmExecutionExtension for EccExtension { + type Executor = EccExtensionExecutor; + + fn extend_execution( + &self, + inventory: &mut ExecutorInventoryBuilder, + ) -> Result<(), ExecutorInventoryError> { + let pointer_max_bits = inventory.pointer_max_bits(); + // TODO: somehow get the range checker bus from `ExecutorInventory` + let dummy_range_checker_bus = VariableRangeCheckerBus::new(u16::MAX, 16); + + // add the sw curves + for (i, curve) in self.supported_sw_curves.iter().enumerate() { + let start_offset = + Rv32WeierstrassOpcode::CLASS_OFFSET + i * Rv32WeierstrassOpcode::COUNT; + let bytes = curve.modulus.bits().div_ceil(8); + + if bytes <= 32 { + let config = ExprBuilderConfig { + modulus: curve.modulus.clone(), + num_limbs: 32, + limb_bits: 8, + }; + let addne = get_sw_addne_step( + config.clone(), + dummy_range_checker_bus, + pointer_max_bits, + start_offset, + ); + + inventory.add_executor( + EccExtensionExecutor::SwAddNeRv32_32(addne), + ((Rv32WeierstrassOpcode::SW_ADD_NE as usize) + ..=(Rv32WeierstrassOpcode::SETUP_SW_ADD_NE as usize)) + .map(|x| VmOpcode::from_usize(x + start_offset)), + )?; + + let double = get_sw_double_step( + config, + dummy_range_checker_bus, + pointer_max_bits, + start_offset, + curve.coeffs.a.clone(), + ); + + inventory.add_executor( + EccExtensionExecutor::SwDoubleRv32_32(double), + ((Rv32WeierstrassOpcode::SW_DOUBLE as usize) + ..=(Rv32WeierstrassOpcode::SETUP_SW_DOUBLE as usize)) + .map(|x| VmOpcode::from_usize(x + start_offset)), + )?; + } else if bytes <= 48 { + let config = ExprBuilderConfig { + modulus: curve.modulus.clone(), + num_limbs: 48, + limb_bits: 8, + }; + let addne = get_sw_addne_step( + config.clone(), + dummy_range_checker_bus, + pointer_max_bits, + start_offset, + ); + + inventory.add_executor( + EccExtensionExecutor::SwAddNeRv32_48(addne), + ((Rv32WeierstrassOpcode::SW_ADD_NE as usize) + ..=(Rv32WeierstrassOpcode::SETUP_SW_ADD_NE as usize)) + .map(|x| VmOpcode::from_usize(x + start_offset)), + )?; + + let double = get_sw_double_step( + config, + dummy_range_checker_bus, + pointer_max_bits, + start_offset, + curve.coeffs.a.clone(), + ); + + inventory.add_executor( + EccExtensionExecutor::SwDoubleRv32_48(double), + ((Rv32WeierstrassOpcode::SW_DOUBLE as usize) + ..=(Rv32WeierstrassOpcode::SETUP_SW_DOUBLE as usize)) + .map(|x| VmOpcode::from_usize(x + start_offset)), + )?; + } else { + panic!("Modulus too large"); + } + } + + // add the te curves + for (i, curve) in self.supported_te_curves.iter().enumerate() { + let start_offset = Rv32EdwardsOpcode::CLASS_OFFSET + i * Rv32EdwardsOpcode::COUNT; + let bytes = curve.modulus.bits().div_ceil(8); + + if bytes <= 32 { + let config = ExprBuilderConfig { + modulus: curve.modulus.clone(), + num_limbs: 32, + limb_bits: 8, + }; + let add = get_te_add_step( + config.clone(), + dummy_range_checker_bus, + pointer_max_bits, + start_offset, + curve.coeffs.a.clone(), + curve.coeffs.d.clone(), + ); + + inventory.add_executor( + EccExtensionExecutor::TeEcAddRv32_32(add), + ((Rv32EdwardsOpcode::TE_ADD as usize) + ..=(Rv32EdwardsOpcode::SETUP_TE_ADD as usize)) + .map(|x| VmOpcode::from_usize(x + start_offset)), + )?; + } else { + panic!("Modulus too large"); + } + } + + Ok(()) + } +} + +impl VmCircuitExtension for EccExtension { + fn extend_circuit(&self, inventory: &mut AirInventory) -> Result<(), AirInventoryError> { + let SystemPort { + execution_bus, + program_bus, + memory_bridge, + } = inventory.system().port(); + + let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); + let range_checker_bus = inventory.range_checker().bus; + let pointer_max_bits = inventory.pointer_max_bits(); + + let bitwise_lu = { + // A trick to get around Rust's borrow rules + let existing_air = inventory.find_air::>().next(); + if let Some(air) = existing_air { + air.bus + } else { + let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx()); + let air = BitwiseOperationLookupAir::<8>::new(bus); + inventory.add_air(air); + air.bus + } + }; + + for (i, curve) in self.supported_sw_curves.iter().enumerate() { + let start_offset = + Rv32WeierstrassOpcode::CLASS_OFFSET + i * Rv32WeierstrassOpcode::COUNT; + let bytes = curve.modulus.bits().div_ceil(8); + + if bytes <= 32 { + let config = ExprBuilderConfig { + modulus: curve.modulus.clone(), + num_limbs: 32, + limb_bits: 8, + }; + + let addne = get_sw_addne_air::<2, 32>( + exec_bridge, + memory_bridge, + config.clone(), + range_checker_bus, + bitwise_lu, + pointer_max_bits, + start_offset, + ); + inventory.add_air(addne); + + let double = get_sw_double_air::<2, 32>( + exec_bridge, + memory_bridge, + config, + range_checker_bus, + bitwise_lu, + pointer_max_bits, + start_offset, + curve.coeffs.a.clone(), + ); + inventory.add_air(double); + } else if bytes <= 48 { + let config = ExprBuilderConfig { + modulus: curve.modulus.clone(), + num_limbs: 48, + limb_bits: 8, + }; + + let addne = get_sw_addne_air::<6, 16>( + exec_bridge, + memory_bridge, + config.clone(), + range_checker_bus, + bitwise_lu, + pointer_max_bits, + start_offset, + ); + inventory.add_air(addne); + + let double = get_sw_double_air::<6, 16>( + exec_bridge, + memory_bridge, + config, + range_checker_bus, + bitwise_lu, + pointer_max_bits, + start_offset, + curve.coeffs.a.clone(), + ); + inventory.add_air(double); + } else { + panic!("Modulus too large"); + } + } + + for (i, curve) in self.supported_te_curves.iter().enumerate() { + let start_offset = Rv32EdwardsOpcode::CLASS_OFFSET + i * Rv32EdwardsOpcode::COUNT; + let bytes = curve.modulus.bits().div_ceil(8); + + if bytes <= 32 { + let config = ExprBuilderConfig { + modulus: curve.modulus.clone(), + num_limbs: 32, + limb_bits: 8, + }; + + let add = get_te_add_air::<2, 32>( + exec_bridge, + memory_bridge, + config.clone(), + range_checker_bus, + bitwise_lu, + pointer_max_bits, + start_offset, + curve.coeffs.a.clone(), + curve.coeffs.d.clone(), + ); + inventory.add_air(add); + } else { + panic!("Modulus too large"); + } + } + + Ok(()) + } +} + +// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker, +// BitwiseOperationLookupChip) are specific to CpuBackend. +impl VmProverExtension for EccCpuProverExt +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + RA: RowMajorMatrixArena>, + Val: PrimeField32, +{ + fn extend_prover( + &self, + extension: &EccExtension, + inventory: &mut ChipInventory>, + ) -> Result<(), ChipInventoryError> { + let range_checker = inventory.range_checker()?.clone(); + let timestamp_max_bits = inventory.timestamp_max_bits(); + let pointer_max_bits = inventory.airs().pointer_max_bits(); + let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits); + let bitwise_lu = { + let existing_chip = inventory + .find_chip::>() + .next(); + if let Some(chip) = existing_chip { + chip.clone() + } else { + let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?; + let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus)); + inventory.add_periphery_chip(chip.clone()); + chip + } + }; + + for curve in extension.supported_sw_curves.iter() { + let bytes = curve.modulus.bits().div_ceil(8); + + if bytes <= 32 { + let config = ExprBuilderConfig { + modulus: curve.modulus.clone(), + num_limbs: 32, + limb_bits: 8, + }; + + inventory.next_air::>()?; + let addne = get_sw_addne_chip::, 2, 32>( + config.clone(), + mem_helper.clone(), + range_checker.clone(), + bitwise_lu.clone(), + pointer_max_bits, + ); + inventory.add_executor_chip(addne); + + inventory.next_air::>()?; + let double = get_sw_double_chip::, 2, 32>( + config, + mem_helper.clone(), + range_checker.clone(), + bitwise_lu.clone(), + pointer_max_bits, + curve.coeffs.a.clone(), + ); + inventory.add_executor_chip(double); + } else if bytes <= 48 { + let config = ExprBuilderConfig { + modulus: curve.modulus.clone(), + num_limbs: 48, + limb_bits: 8, + }; + + inventory.next_air::>()?; + let addne = get_sw_addne_chip::, 6, 16>( + config.clone(), + mem_helper.clone(), + range_checker.clone(), + bitwise_lu.clone(), + pointer_max_bits, + ); + inventory.add_executor_chip(addne); + + inventory.next_air::>()?; + let double = get_sw_double_chip::, 6, 16>( + config, + mem_helper.clone(), + range_checker.clone(), + bitwise_lu.clone(), + pointer_max_bits, + curve.coeffs.a.clone(), + ); + inventory.add_executor_chip(double); + } else { + panic!("Modulus too large"); + } + } + + for curve in extension.supported_te_curves.iter() { + let bytes = curve.modulus.bits().div_ceil(8); + + if bytes <= 32 { + let config = ExprBuilderConfig { + modulus: curve.modulus.clone(), + num_limbs: 32, + limb_bits: 8, + }; + + inventory.next_air::>()?; + let add = get_te_add_chip::, 2, 32>( + config.clone(), + mem_helper.clone(), + range_checker.clone(), + bitwise_lu.clone(), + pointer_max_bits, + curve.coeffs.a.clone(), + curve.coeffs.d.clone(), + ); + inventory.add_executor_chip(add); + } else { + panic!("Modulus too large"); + } + } + + Ok(()) + } +} + +// Convenience constants for constructors +lazy_static! { + // The constants are taken from: https://en.bitcoin.it/wiki/Secp256k1 + pub static ref SECP256K1_MODULUS: BigUint = BigUint::from_bytes_be(&hex!( + "FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE FFFFFC2F" + )); + pub static ref SECP256K1_ORDER: BigUint = BigUint::from_bytes_be(&hex!( + "FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE BAAEDCE6 AF48A03B BFD25E8C D0364141" + )); +} + +lazy_static! { + // The constants are taken from: https://neuromancer.sk/std/secg/secp256r1 + pub static ref P256_MODULUS: BigUint = BigUint::from_bytes_be(&hex!( + "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff" + )); + pub static ref P256_ORDER: BigUint = BigUint::from_bytes_be(&hex!( + "ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551" + )); +} +// little-endian +pub const P256_A: [u8; 32] = + hex!("fcffffffffffffffffffffff00000000000000000000000001000000ffffffff"); +// little-endian +pub const P256_B: [u8; 32] = + hex!("4b60d2273e3cce3bf6b053ccb0061d65bc86987655bdebb3e7933aaad835c65a"); + +pub const SECP256K1_ECC_STRUCT_NAME: &str = "Secp256k1Point"; +pub const P256_ECC_STRUCT_NAME: &str = "P256Point"; diff --git a/extensions/ecc/circuit/src/edwards_chip/README.md b/extensions/ecc/circuit/src/edwards_chip/README.md new file mode 100644 index 0000000000..24167e062a --- /dev/null +++ b/extensions/ecc/circuit/src/edwards_chip/README.md @@ -0,0 +1,17 @@ +# Twisted Edwards (TE) curve operations + +The `te_add` instruction is implemented in the `edwards_chip` module. + +### 1. `te_add` + +**Assumptions:** + +- Both points `(x1, y1)` and `(x2, y2)` lie on the curve. + +**Circuit statements:** + +- The chip takes two inputs: `(x1, y1)` and `(x2, y2)`, and returns `(x3, y3)` where: + - `x3 = (x1 * y2 + x2 * y1) / (1 + d * x1 * x2 * y1 * y2)` + - `y3 = (y1 * y2 - a * x1 * x2) / (1 - d * x1 * x2 * y1 * y2)` + +- The `TeAddChip` constrains that these field expressions are computed correctly over the field `C::Fp`. The coefficients `a` and `d` are taken from the `CurveConfig`. diff --git a/extensions/ecc/circuit/src/edwards_chip/add.rs b/extensions/ecc/circuit/src/edwards_chip/add.rs new file mode 100644 index 0000000000..6fa1512589 --- /dev/null +++ b/extensions/ecc/circuit/src/edwards_chip/add.rs @@ -0,0 +1,470 @@ +use std::{ + array::from_fn, + borrow::{Borrow, BorrowMut}, + cell::RefCell, + rc::Rc, +}; + +use derive_more::derive::{Deref, DerefMut}; +use num_bigint::BigUint; +use num_traits::One; +use openvm_circuit::{ + arch::{ExecutionBridge, *}, + system::memory::{ + offline_checker::MemoryBridge, online::GuestMemory, SharedMemoryHelper, POINTER_MAX_BITS, + }, +}; +use openvm_circuit_derive::PreflightExecutor; +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + AlignedBytesBorrow, +}; +use openvm_ecc_transpiler::Rv32EdwardsOpcode; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS}, +}; +use openvm_mod_circuit_builder::{ + run_field_expression_precomputed, ExprBuilder, ExprBuilderConfig, FieldExpr, + FieldExpressionCoreAir, FieldExpressionExecutor, FieldExpressionFiller, +}; +use openvm_rv32_adapters::{ + Rv32VecHeapAdapterAir, Rv32VecHeapAdapterExecutor, Rv32VecHeapAdapterFiller, +}; +use openvm_stark_backend::p3_field::PrimeField32; + +use super::{utils::jacobi, EdwardsAir, EdwardsChip}; +use crate::edwards_chip::curves::{get_te_curve_type, te_add, TeCurveType}; + +pub fn te_add_expr( + config: ExprBuilderConfig, // The coordinate field. + range_bus: VariableRangeCheckerBus, + a_biguint: BigUint, + d_biguint: BigUint, +) -> FieldExpr { + config.check_valid(); + let builder = ExprBuilder::new(config, range_bus.range_max_bits); + let builder = Rc::new(RefCell::new(builder)); + + let x1 = ExprBuilder::new_input(builder.clone()); + let y1 = ExprBuilder::new_input(builder.clone()); + let x2 = ExprBuilder::new_input(builder.clone()); + let y2 = ExprBuilder::new_input(builder.clone()); + let a = ExprBuilder::new_const(builder.clone(), a_biguint.clone()); + let d = ExprBuilder::new_const(builder.clone(), d_biguint.clone()); + let one = ExprBuilder::new_const(builder.clone(), BigUint::one()); + + let x1y2 = x1.clone() * y2.clone(); + let x2y1 = x2.clone() * y1.clone(); + let y1y2 = y1 * y2; + let x1x2 = x1 * x2; + let dx1x2y1y2 = d * x1x2.clone() * y1y2.clone(); + + let mut x3 = (x1y2 + x2y1) / (one.clone() + dx1x2y1y2.clone()); + let mut y3 = (y1y2 - a * x1x2) / (one - dx1x2y1y2); + + x3.save_output(); + y3.save_output(); + + let builder = (*builder).borrow().clone(); + + FieldExpr::new_with_setup_values(builder, range_bus, true, vec![a_biguint, d_biguint]) +} + +#[derive(Clone, PreflightExecutor, Deref, DerefMut)] +pub struct TeAddExecutor( + pub(crate) FieldExpressionExecutor< + Rv32VecHeapAdapterExecutor<2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>, + >, +); + +fn gen_base_expr( + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, + a_biguint: BigUint, + d_biguint: BigUint, +) -> (FieldExpr, Vec) { + let expr = te_add_expr(config, range_checker_bus, a_biguint, d_biguint); + + let local_opcode_idx = vec![ + Rv32EdwardsOpcode::TE_ADD as usize, + Rv32EdwardsOpcode::SETUP_TE_ADD as usize, + ]; + + (expr, local_opcode_idx) +} + +#[allow(clippy::too_many_arguments)] +pub fn get_te_add_air( + exec_bridge: ExecutionBridge, + mem_bridge: MemoryBridge, + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, + bitwise_lookup_bus: BitwiseOperationLookupBus, + pointer_max_bits: usize, + offset: usize, + a_biguint: BigUint, + d_biguint: BigUint, +) -> EdwardsAir<2, BLOCKS, BLOCK_SIZE> { + // Ensure that the addition operation is complete + assert!(jacobi(&a_biguint.clone().into(), &config.modulus.clone().into()) == 1); + assert!(jacobi(&d_biguint.clone().into(), &config.modulus.clone().into()) == -1); + + let (expr, local_opcode_idx) = gen_base_expr(config, range_checker_bus, a_biguint, d_biguint); + EdwardsAir::new( + Rv32VecHeapAdapterAir::new( + exec_bridge, + mem_bridge, + bitwise_lookup_bus, + pointer_max_bits, + ), + FieldExpressionCoreAir::new(expr.clone(), offset, local_opcode_idx.clone(), vec![]), + ) +} + +pub fn get_te_add_step( + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, + pointer_max_bits: usize, + offset: usize, + a_biguint: BigUint, + d_biguint: BigUint, +) -> TeAddExecutor { + // Ensure that the addition operation is complete + assert!(jacobi(&a_biguint.clone().into(), &config.modulus.clone().into()) == 1); + assert!(jacobi(&d_biguint.clone().into(), &config.modulus.clone().into()) == -1); + + let (expr, local_opcode_idx) = gen_base_expr(config, range_checker_bus, a_biguint, d_biguint); + TeAddExecutor(FieldExpressionExecutor::new( + Rv32VecHeapAdapterExecutor::new(pointer_max_bits), + expr, + offset, + local_opcode_idx, + vec![], + "TeAdd", + )) +} + +pub fn get_te_add_chip( + config: ExprBuilderConfig, + mem_helper: SharedMemoryHelper, + range_checker: SharedVariableRangeCheckerChip, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pointer_max_bits: usize, + a_biguint: BigUint, + d_biguint: BigUint, +) -> EdwardsChip { + // Ensure that the addition operation is complete + assert!(jacobi(&a_biguint.clone().into(), &config.modulus.clone().into()) == 1); + assert!(jacobi(&d_biguint.clone().into(), &config.modulus.clone().into()) == -1); + + let (expr, local_opcode_idx) = gen_base_expr(config, range_checker.bus(), a_biguint, d_biguint); + EdwardsChip::new( + FieldExpressionFiller::new( + Rv32VecHeapAdapterFiller::new(pointer_max_bits, bitwise_lookup_chip), + expr, + local_opcode_idx, + vec![], + range_checker, + true, + ), + mem_helper, + ) +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct TeAddPreCompute<'a> { + expr: &'a FieldExpr, + rs_addrs: [u8; 2], + a: u8, + flag_idx: u8, +} + +impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize> TeAddExecutor { + fn pre_compute_impl( + &'a self, + pc: u32, + inst: &Instruction, + data: &mut TeAddPreCompute<'a>, + ) -> Result { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + + // Validate instruction format + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); + let c = c.as_canonical_u32(); + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + let local_opcode = opcode.local_opcode_idx(self.offset); + + // Pre-compute flag_idx + let needs_setup = self.expr.needs_setup(); + let mut flag_idx = self.expr.num_flags() as u8; + if needs_setup { + // Find which opcode this is in our local_opcode_idx list + if let Some(opcode_position) = self + .local_opcode_idx + .iter() + .position(|&idx| idx == local_opcode) + { + // If this is NOT the last opcode (setup), get the corresponding flag_idx + if opcode_position < self.opcode_flag_idx.len() { + flag_idx = self.opcode_flag_idx[opcode_position] as u8; + } + } + } + + let rs_addrs = from_fn(|i| if i == 0 { b } else { c } as u8); + *data = TeAddPreCompute { + expr: &self.expr, + rs_addrs, + a: a as u8, + flag_idx, + }; + + let local_opcode = opcode.local_opcode_idx(self.offset); + let is_setup = local_opcode == Rv32EdwardsOpcode::SETUP_TE_ADD as usize; + + Ok(is_setup) + } +} + +impl Executor + for TeAddExecutor +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + std::mem::size_of::() + } + + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E1ExecutionCtx, + { + let pre_compute: &mut TeAddPreCompute = data.borrow_mut(); + + let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?; + + if let Some(curve_type) = { + let modulus = &pre_compute.expr.builder.prime; + let a_coeff = &pre_compute.expr.setup_values[0]; + let d_coeff = &pre_compute.expr.setup_values[1]; + get_te_curve_type(modulus, a_coeff, d_coeff) + } { + match (is_setup, curve_type) { + (true, TeCurveType::ED25519) => Ok(execute_e12_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { TeCurveType::ED25519 as u8 }, + true, + >), + (false, TeCurveType::ED25519) => Ok(execute_e12_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { TeCurveType::ED25519 as u8 }, + false, + >), + } + } else if is_setup { + Ok(execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, true>) + } else { + Ok(execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, false>) + } + } +} + +impl MeteredExecutor + for TeAddExecutor +{ + #[inline(always)] + fn metered_pre_compute_size(&self) -> usize { + std::mem::size_of::>() + } + + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let is_setup = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + if let Some(curve_type) = { + let modulus = &pre_compute.data.expr.builder.prime; + let a_coeff = &pre_compute.data.expr.setup_values[0]; + let d_coeff = &pre_compute.data.expr.setup_values[1]; + get_te_curve_type(modulus, a_coeff, d_coeff) + } { + match (is_setup, curve_type) { + (true, TeCurveType::ED25519) => Ok(execute_e2_setup_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { TeCurveType::ED25519 as u8 }, + >), + (false, TeCurveType::ED25519) => { + Ok(execute_e2_impl::<_, _, BLOCKS, BLOCK_SIZE, { TeCurveType::ED25519 as u8 }>) + } + } + } else if is_setup { + Ok(execute_e2_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }>) + } else { + Ok(execute_e2_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }>) + } + } +} + +unsafe fn execute_e2_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const CURVE_TYPE: u8, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let e2_pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(e2_pre_compute.chip_idx as usize, 1); + let pre_compute = unsafe { + std::slice::from_raw_parts( + &e2_pre_compute.data as *const _ as *const u8, + std::mem::size_of::(), + ) + }; + execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, CURVE_TYPE, false>(pre_compute, vm_state); +} + +unsafe fn execute_e2_setup_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const CURVE_TYPE: u8, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let e2_pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(e2_pre_compute.chip_idx as usize, 1); + let pre_compute = unsafe { + std::slice::from_raw_parts( + &e2_pre_compute.data as *const _ as *const u8, + std::mem::size_of::(), + ) + }; + execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, CURVE_TYPE, true>(pre_compute, vm_state); +} + +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const CURVE_TYPE: u8, + const IS_SETUP: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &TeAddPreCompute = pre_compute.borrow(); + // Read register values + let rs_vals = pre_compute + .rs_addrs + .map(|addr| u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, addr as u32))); + + // Read memory values for both points + let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| { + debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS)); + from_fn(|i| vm_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32)) + }); + + if IS_SETUP { + let input_prime = BigUint::from_bytes_le(read_data[0][..BLOCKS / 2].as_flattened()); + let input_a = BigUint::from_bytes_le(read_data[0][BLOCKS / 2..].as_flattened()); + let input_d = BigUint::from_bytes_le(read_data[1][..BLOCKS / 2].as_flattened()); + + if input_prime != pre_compute.expr.prime { + vm_state.exit_code = Err(ExecutionError::Fail { + pc: vm_state.pc, + msg: "TeAdd: mismatched prime", + }); + return; + } + + if input_a != pre_compute.expr.setup_values[0] { + vm_state.exit_code = Err(ExecutionError::Fail { + pc: vm_state.pc, + msg: "TeAdd: mismatched a", + }); + return; + } + + if input_d != pre_compute.expr.setup_values[1] { + vm_state.exit_code = Err(ExecutionError::Fail { + pc: vm_state.pc, + msg: "TeAdd: mismatched d", + }); + return; + } + } + + let output_data = if CURVE_TYPE == u8::MAX { + let read_data: DynArray = read_data.into(); + run_field_expression_precomputed::( + pre_compute.expr, + pre_compute.flag_idx as usize, + &read_data.0, + ) + .into() + } else { + te_add::(read_data) + }; + + let rd_val = u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32)); + debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS)); + + // Write output data to memory + for (i, block) in output_data.into_iter().enumerate() { + vm_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block); + } + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} diff --git a/extensions/ecc/circuit/src/edwards_chip/curves.rs b/extensions/ecc/circuit/src/edwards_chip/curves.rs new file mode 100644 index 0000000000..db37b5c616 --- /dev/null +++ b/extensions/ecc/circuit/src/edwards_chip/curves.rs @@ -0,0 +1,89 @@ +use halo2curves_axiom::{ed25519::TwistedEdwardsCurveExt, ff::PrimeField}; +use lazy_static::lazy_static; +use num_bigint::BigUint; +use openvm_algebra_circuit::fields::{blocks_to_field_element, field_element_to_blocks}; + +use crate::weierstrass_chip::curves::get_modulus_as_bigint; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TeCurveType { + ED25519 = 0, +} + +pub(super) fn get_te_curve_type( + modulus: &BigUint, + a_coeff: &BigUint, + d_coeff: &BigUint, +) -> Option { + if modulus == &ED25519_CURVE_PARAMS.modulus + && a_coeff == &ED25519_CURVE_PARAMS.a + && d_coeff == &ED25519_CURVE_PARAMS.d + { + return Some(TeCurveType::ED25519); + } + + None +} + +struct CurveParams { + modulus: BigUint, + a: BigUint, + d: BigUint, +} + +lazy_static! { + static ref ED25519_CURVE_PARAMS: CurveParams = CurveParams { + modulus: get_modulus_as_bigint::(), + a: BigUint::from_bytes_le( + &::a().to_repr(), + ), + d: BigUint::from_bytes_le( + &::d().to_repr(), + ), + }; +} + +#[inline(always)] +pub fn te_add( + input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2], +) -> [[u8; BLOCK_SIZE]; BLOCKS] { + match CURVE_TYPE { + x if x == TeCurveType::ED25519 as u8 => { + te_add_256bit::( + input_data, + halo2curves_axiom::ed25519::Ed25519::a(), + halo2curves_axiom::ed25519::Ed25519::d(), + ) + } + _ => panic!("Unsupported curve type: {}", CURVE_TYPE), + } +} + +#[inline(always)] +fn te_add_256bit, const BLOCKS: usize, const BLOCK_SIZE: usize>( + input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2], + a: F, + d: F, +) -> [[u8; BLOCK_SIZE]; BLOCKS] { + let x1 = blocks_to_field_element::(input_data[0][..BLOCKS / 2].as_flattened()); + let y1 = blocks_to_field_element::(input_data[0][BLOCKS / 2..].as_flattened()); + let x2 = blocks_to_field_element::(input_data[1][..BLOCKS / 2].as_flattened()); + let y2 = blocks_to_field_element::(input_data[1][BLOCKS / 2..].as_flattened()); + + let (x3, y3) = te_add_impl::(x1, y1, x2, y2, a, d); + + let mut output = [[0u8; BLOCK_SIZE]; BLOCKS]; + field_element_to_blocks::(&x3, &mut output[..BLOCKS / 2]); + field_element_to_blocks::(&y3, &mut output[BLOCKS / 2..]); + output +} + +#[inline(always)] +pub fn te_add_impl(x1: F, y1: F, x2: F, y2: F, a: F, d: F) -> (F, F) { + println!("te_add_impl called, a: {:?}, d: {:?}", a, d); + let dx1x2y1y2 = d * x1 * x2 * y1 * y2; + let x3 = (x1 * y2 + x2 * y1) * (F::ONE + dx1x2y1y2).invert().unwrap(); + let y3 = (y1 * y2 - a * x1 * x2) * (F::ONE - dx1x2y1y2).invert().unwrap(); + + (x3, y3) +} diff --git a/extensions/ecc/circuit/src/edwards_chip/mod.rs b/extensions/ecc/circuit/src/edwards_chip/mod.rs new file mode 100644 index 0000000000..a396646f88 --- /dev/null +++ b/extensions/ecc/circuit/src/edwards_chip/mod.rs @@ -0,0 +1,30 @@ +mod add; +pub use add::*; + +mod curves; +mod utils; + +#[cfg(test)] +mod tests; + +use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; +use openvm_mod_circuit_builder::{FieldExpressionCoreAir, FieldExpressionFiller}; +use openvm_rv32_adapters::{Rv32VecHeapAdapterAir, Rv32VecHeapAdapterFiller}; + +pub(crate) type EdwardsAir = + VmAirWrapper< + Rv32VecHeapAdapterAir, + FieldExpressionCoreAir, + >; + +pub(crate) type EdwardsChip< + F, + const NUM_READS: usize, + const BLOCKS: usize, + const BLOCK_SIZE: usize, +> = VmChipWrapper< + F, + FieldExpressionFiller< + Rv32VecHeapAdapterFiller, + >, +>; diff --git a/extensions/ecc/circuit/src/edwards_chip/tests.rs b/extensions/ecc/circuit/src/edwards_chip/tests.rs new file mode 100644 index 0000000000..9a24ff6d22 --- /dev/null +++ b/extensions/ecc/circuit/src/edwards_chip/tests.rs @@ -0,0 +1,242 @@ +use std::{str::FromStr, sync::Arc}; + +use num_bigint::BigUint; +use num_traits::FromPrimitive; +use openvm_circuit::arch::{ + testing::{TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, + MatrixRecordArena, +}; +use openvm_circuit_primitives::{ + bigint::utils::big_uint_to_limbs, + bitwise_op_lookup::{ + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, + }, +}; +use openvm_ecc_transpiler::Rv32EdwardsOpcode; +use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; +use openvm_mod_circuit_builder::{test_utils::biguint_to_limbs, ExprBuilderConfig}; +use openvm_rv32_adapters::rv32_write_heap_default; +use openvm_stark_backend::p3_field::FieldAlgebra; +use openvm_stark_sdk::p3_baby_bear::BabyBear; + +use crate::{ + edwards_chip::{ + get_te_add_air, get_te_add_chip, get_te_add_step, EdwardsAir, EdwardsChip, TeAddExecutor, + }, + weierstrass_chip::prime_limbs, +}; + +const NUM_LIMBS: usize = 32; +const LIMB_BITS: usize = 8; +const BLOCK_SIZE: usize = 32; +const MAX_INS_CAPACITY: usize = 128; +type F = BabyBear; + +lazy_static::lazy_static! { + pub static ref SampleEcPoints: Vec<(BigUint, BigUint)> = { + // Base point of edwards25519 + let x1 = BigUint::from_str( + "15112221349535400772501151409588531511454012693041857206046113283949847762202", + ) + .unwrap(); + let y1 = BigUint::from_str( + "46316835694926478169428394003475163141307993866256225615783033603165251855960", + ) + .unwrap(); + + // random point on edwards25519 + let x2 = BigUint::from_u32(2).unwrap(); + let y2 = BigUint::from_str( + "11879831548380997166425477238087913000047176376829905612296558668626594440753", + ) + .unwrap(); + + // This is the sum of (x1, y1) and (x2, y2). + let x3 = BigUint::from_str( + "44969869612046584870714054830543834361257841801051546235130567688769346152934", + ) + .unwrap(); + let y3 = BigUint::from_str( + "50796027728050908782231253190819121962159170739537197094456293084373503699602", + ) + .unwrap(); + + // This is 2 * (x1, y1) + let x4 = BigUint::from_str( + "39226743113244985161159605482495583316761443760287217110659799046557361995496", + ) + .unwrap(); + let y4 = BigUint::from_str( + "12570354238812836652656274015246690354874018829607973815551555426027032771563", + ) + .unwrap(); + + vec![(x1, y1), (x2, y2), (x3, y3), (x4, y4)] + }; + + pub static ref Edwards25519_Prime: BigUint = BigUint::from_str( + "57896044618658097711785492504343953926634992332820282019728792003956564819949", + ) + .unwrap(); + + pub static ref Edwards25519_A: BigUint = BigUint::from_str( + "57896044618658097711785492504343953926634992332820282019728792003956564819948", + ) + .unwrap(); + + pub static ref Edwards25519_D: BigUint = BigUint::from_str( + "37095705934669439343138083508754565189542113879843219016388785533085940283555", + ) + .unwrap(); + + pub static ref Edwards25519_A_LIMBS: [BabyBear; NUM_LIMBS] = + big_uint_to_limbs(&Edwards25519_A, LIMB_BITS) + .into_iter() + .map(BabyBear::from_canonical_usize) + .collect::>() + .try_into() + .unwrap(); + pub static ref Edwards25519_D_LIMBS: [BabyBear; NUM_LIMBS] = + big_uint_to_limbs(&Edwards25519_D, LIMB_BITS) + .into_iter() + .map(BabyBear::from_canonical_usize) + .collect::>() + .try_into() + .unwrap(); +} + +type EdwardsHarness = TestChipHarness< + F, + TeAddExecutor<2, BLOCK_SIZE>, + EdwardsAir<2, 2, BLOCK_SIZE>, + EdwardsChip, + MatrixRecordArena, +>; + +fn create_test_chip( + tester: &VmChipTestBuilder, + config: ExprBuilderConfig, + offset: usize, + a_biguint: BigUint, + d_biguint: BigUint, +) -> ( + EdwardsHarness, + ( + BitwiseOperationLookupAir, + SharedBitwiseOperationLookupChip, + ), +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + let air = get_te_add_air( + tester.execution_bridge(), + tester.memory_bridge(), + config.clone(), + tester.range_checker().bus(), + bitwise_bus, + tester.address_bits(), + offset, + a_biguint.clone(), + d_biguint.clone(), + ); + let executor = get_te_add_step( + config.clone(), + tester.range_checker().bus(), + tester.address_bits(), + offset, + a_biguint.clone(), + d_biguint.clone(), + ); + let chip = get_te_add_chip( + config.clone(), + tester.memory_helper(), + tester.range_checker(), + bitwise_chip.clone(), + tester.address_bits(), + a_biguint, + d_biguint, + ); + let harness = EdwardsHarness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + + (harness, (bitwise_chip.air, bitwise_chip)) +} + +#[test] +fn test_add() { + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); + let config = ExprBuilderConfig { + modulus: Edwards25519_Prime.clone(), + num_limbs: NUM_LIMBS, + limb_bits: LIMB_BITS, + }; + + let (mut harness, bitwise) = create_test_chip( + &tester, + config, + Rv32EdwardsOpcode::CLASS_OFFSET, + Edwards25519_A.clone(), + Edwards25519_D.clone(), + ); + + assert_eq!(harness.executor.expr.builder.num_variables, 12); + + let (p1_x, p1_y) = SampleEcPoints[0].clone(); + let (p2_x, p2_y) = SampleEcPoints[1].clone(); + + let p1_x_limbs = + biguint_to_limbs::(p1_x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); + let p1_y_limbs = + biguint_to_limbs::(p1_y.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); + let p2_x_limbs = + biguint_to_limbs::(p2_x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); + let p2_y_limbs = + biguint_to_limbs::(p2_y.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); + + let r = harness + .executor + .expr + .execute(vec![p1_x, p1_y, p2_x, p2_y], vec![true]); + assert_eq!(r.len(), 12); + + let outputs = harness + .executor + .expr + .output_indices() + .iter() + .map(|i| &r[*i]) + .collect::>(); + assert_eq!(outputs[0], &SampleEcPoints[2].0); + assert_eq!(outputs[1], &SampleEcPoints[2].1); + + let prime_limbs: [BabyBear; NUM_LIMBS] = + prime_limbs(&harness.executor.expr).try_into().unwrap(); + let mut one_limbs = [BabyBear::ZERO; NUM_LIMBS]; + one_limbs[0] = BabyBear::ONE; + let setup_instruction = rv32_write_heap_default( + &mut tester, + vec![prime_limbs, *Edwards25519_A_LIMBS], + vec![*Edwards25519_D_LIMBS], + harness.executor.offset + Rv32EdwardsOpcode::SETUP_TE_ADD as usize, + ); + tester.execute(&mut harness, &setup_instruction); + + let instruction = rv32_write_heap_default( + &mut tester, + vec![p1_x_limbs, p1_y_limbs], + vec![p2_x_limbs, p2_y_limbs], + harness.executor.offset + Rv32EdwardsOpcode::TE_ADD as usize, + ); + + tester.execute(&mut harness, &instruction); + + let tester = tester + .build() + .load(harness) + .load_periphery(bitwise) + .finalize(); + + tester.simple_test().expect("Verification failed"); +} diff --git a/extensions/ecc/circuit/src/edwards_chip/utils.rs b/extensions/ecc/circuit/src/edwards_chip/utils.rs new file mode 100644 index 0000000000..ce7711519f --- /dev/null +++ b/extensions/ecc/circuit/src/edwards_chip/utils.rs @@ -0,0 +1,101 @@ +use num_bigint::BigInt; +use num_integer::Integer; +use num_traits::{sign::Signed, One, Zero}; + +/// Jacobi returns the Jacobi symbol (x/y), either +1, -1, or 0. +/// The y argument must be an odd integer. +pub fn jacobi(x: &BigInt, y: &BigInt) -> isize { + if !y.is_odd() { + panic!( + "invalid arguments, y must be an odd integer,but got {:?}", + y + ); + } + + let mut a = x.clone(); + let mut b = y.clone(); + let mut j = 1; + + if b.is_negative() { + if a.is_negative() { + j = -1; + } + b = -b; + } + + loop { + if b.is_one() { + return j; + } + if a.is_zero() { + return 0; + } + + a = a.mod_floor(&b); + if a.is_zero() { + return 0; + } + + // a > 0 + + // handle factors of 2 in a + let s = a.trailing_zeros().unwrap(); + if s & 1 != 0 { + //let bmod8 = b.get_limb(0) & 7; + let bmod8 = mod_2_to_the_k(&b, 3); + if bmod8 == BigInt::from(3) || bmod8 == BigInt::from(5) { + j = -j; + } + } + + let c = &a >> s; // a = 2^s*c + + // swap numerator and denominator + if mod_2_to_the_k(&b, 2) == BigInt::from(3) && mod_2_to_the_k(&c, 2) == BigInt::from(3) { + j = -j + } + + a = b; + b = c; + } +} + +fn mod_2_to_the_k(x: &BigInt, k: u32) -> BigInt { + x & BigInt::from(2u32.pow(k) - 1) +} +#[cfg(test)] +mod tests { + use num_traits::FromPrimitive; + + use super::*; + + #[test] + fn test_jacobi() { + let cases = [ + [0, 1, 1], + [0, -1, 1], + [1, 1, 1], + [1, -1, 1], + [0, 5, 0], + [1, 5, 1], + [2, 5, -1], + [-2, 5, -1], + [2, -5, -1], + [-2, -5, 1], + [3, 5, -1], + [5, 5, 0], + [-5, 5, 0], + [6, 5, 1], + [6, -5, 1], + [-6, 5, 1], + [-6, -5, -1], + ]; + + for case in cases.iter() { + let x = BigInt::from_i64(case[0]).unwrap(); + let y = BigInt::from_i64(case[1]).unwrap(); + + assert_eq!(case[2] as isize, jacobi(&x, &y), "jacobi({}, {})", x, y); + } + } +} diff --git a/extensions/ecc/circuit/src/lib.rs b/extensions/ecc/circuit/src/lib.rs index c1ec864636..7f466952f9 100644 --- a/extensions/ecc/circuit/src/lib.rs +++ b/extensions/ecc/circuit/src/lib.rs @@ -1,8 +1,13 @@ mod weierstrass_chip; pub use weierstrass_chip::*; -mod weierstrass_extension; -pub use weierstrass_extension::*; +mod ecc_extension; +pub use ecc_extension::*; + +mod edwards_chip; +pub use edwards_chip::*; mod config; pub use config::*; + +pub struct EccCpuProverExt; diff --git a/extensions/ecc/circuit/src/weierstrass_chip/README.md b/extensions/ecc/circuit/src/weierstrass_chip/README.md index 94d8df6847..ba7119b0fc 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/README.md +++ b/extensions/ecc/circuit/src/weierstrass_chip/README.md @@ -1,8 +1,8 @@ # Short Weierstrass (SW) Curve Operations -The `ec_add_ne` and `ec_double` instructions are implemented in the `weierstrass_chip` module. +The `sw_add_ne` and `sw_double` instructions are implemented in the `weierstrass_chip` module. -### 1. `ec_add_ne` +### 1. `sw_add_ne` **Assumptions:** @@ -16,9 +16,9 @@ The `ec_add_ne` and `ec_double` instructions are implemented in the `weierstrass - `x3 = lambda^2 - x1 - x2` - `y3 = lambda * (x1 - x3) - y1` -- The `EcAddNeChip` constrains that these field expressions are computed correctly over the field `C::Fp`. +- The `SwAddNeChip` constrains that these field expressions are computed correctly over the field `C::Fp`. -### 2. `ec_double` +### 2. `sw_double` **Assumptions:** @@ -31,4 +31,4 @@ The `ec_add_ne` and `ec_double` instructions are implemented in the `weierstrass - `x3 = lambda^2 - 2 * x1` - `y3 = lambda * (x1 - x3) - y1` -- The `EcDoubleChip` constrains that these expressions are computed correctly over the field `C::Fp`. The coefficient `a` is taken from the `CurveConfig`. +- The `SwDoubleChip` constrains that these expressions are computed correctly over the field `C::Fp`. The coefficient `a` is taken from the `CurveConfig`. diff --git a/extensions/ecc/circuit/src/weierstrass_chip/add_ne.rs b/extensions/ecc/circuit/src/weierstrass_chip/add_ne.rs index 24bcc52ef3..5248bd8c7f 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/add_ne.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/add_ne.rs @@ -1,11 +1,46 @@ -use std::{cell::RefCell, rc::Rc}; +use std::{ + array::from_fn, + borrow::{Borrow, BorrowMut}, + cell::RefCell, + rc::Rc, +}; -use openvm_circuit_primitives::var_range::VariableRangeCheckerBus; -use openvm_mod_circuit_builder::{ExprBuilder, ExprBuilderConfig, FieldExpr}; +use derive_more::derive::{Deref, DerefMut}; +use num_bigint::BigUint; +use openvm_algebra_circuit::fields::{get_field_type, FieldType}; +use openvm_circuit::{ + arch::*, + system::memory::{ + offline_checker::MemoryBridge, online::GuestMemory, SharedMemoryHelper, POINTER_MAX_BITS, + }, +}; +use openvm_circuit_derive::PreflightExecutor; +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + AlignedBytesBorrow, +}; +use openvm_ecc_transpiler::Rv32WeierstrassOpcode; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS}, +}; +use openvm_mod_circuit_builder::{ + run_field_expression_precomputed, ExprBuilder, ExprBuilderConfig, FieldExpr, + FieldExpressionCoreAir, FieldExpressionExecutor, FieldExpressionFiller, +}; +use openvm_rv32_adapters::{ + Rv32VecHeapAdapterAir, Rv32VecHeapAdapterExecutor, Rv32VecHeapAdapterFiller, +}; +use openvm_stark_backend::p3_field::PrimeField32; + +use super::{WeierstrassAir, WeierstrassChip}; +use crate::weierstrass_chip::curves::sw_add_ne; // Assumes that (x1, y1), (x2, y2) both lie on the curve and are not the identity point. // Further assumes that x1, x2 are not equal in the coordinate field. -pub fn ec_add_ne_expr( +pub fn sw_add_ne_expr( config: ExprBuilderConfig, // The coordinate field. range_bus: VariableRangeCheckerBus, ) -> FieldExpr { @@ -23,6 +58,462 @@ pub fn ec_add_ne_expr( let mut y3 = lambda * (x1 - x3.clone()) - y1; y3.save_output(); - let builder = builder.borrow().clone(); + let builder = (*builder).borrow().clone(); FieldExpr::new(builder, range_bus, true) } + +/// BLOCK_SIZE: how many cells do we read at a time, must be a power of 2. +/// BLOCKS: how many blocks do we need to represent one input or output +/// For example, for bls12_381, BLOCK_SIZE = 16, each element has 3 blocks and with two elements per +/// input AffinePoint, BLOCKS = 6. For secp256k1, BLOCK_SIZE = 32, BLOCKS = 2. +#[derive(Clone, PreflightExecutor, Deref, DerefMut)] +pub struct SwAddNeExecutor( + FieldExpressionExecutor>, +); + +fn gen_base_expr( + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, +) -> (FieldExpr, Vec) { + let expr = sw_add_ne_expr(config, range_checker_bus); + + let local_opcode_idx = vec![ + Rv32WeierstrassOpcode::SW_ADD_NE as usize, + Rv32WeierstrassOpcode::SETUP_SW_ADD_NE as usize, + ]; + + (expr, local_opcode_idx) +} + +pub fn get_sw_addne_air( + exec_bridge: ExecutionBridge, + mem_bridge: MemoryBridge, + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, + bitwise_lookup_bus: BitwiseOperationLookupBus, + pointer_max_bits: usize, + offset: usize, +) -> WeierstrassAir<2, BLOCKS, BLOCK_SIZE> { + let (expr, local_opcode_idx) = gen_base_expr(config, range_checker_bus); + WeierstrassAir::new( + Rv32VecHeapAdapterAir::new( + exec_bridge, + mem_bridge, + bitwise_lookup_bus, + pointer_max_bits, + ), + FieldExpressionCoreAir::new(expr.clone(), offset, local_opcode_idx.clone(), vec![]), + ) +} + +pub fn get_sw_addne_step( + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, + pointer_max_bits: usize, + offset: usize, +) -> SwAddNeExecutor { + let (expr, local_opcode_idx) = gen_base_expr(config, range_checker_bus); + SwAddNeExecutor(FieldExpressionExecutor::new( + Rv32VecHeapAdapterExecutor::new(pointer_max_bits), + expr, + offset, + local_opcode_idx, + vec![], + "SwAddNe", + )) +} + +pub fn get_sw_addne_chip( + config: ExprBuilderConfig, + mem_helper: SharedMemoryHelper, + range_checker: SharedVariableRangeCheckerChip, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pointer_max_bits: usize, +) -> WeierstrassChip { + let (expr, local_opcode_idx) = gen_base_expr(config, range_checker.bus()); + WeierstrassChip::new( + FieldExpressionFiller::new( + Rv32VecHeapAdapterFiller::new(pointer_max_bits, bitwise_lookup_chip), + expr, + local_opcode_idx, + vec![], + range_checker, + false, + ), + mem_helper, + ) +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct SwAddNePreCompute<'a> { + expr: &'a FieldExpr, + rs_addrs: [u8; 2], + a: u8, + flag_idx: u8, +} + +impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize> SwAddNeExecutor { + fn pre_compute_impl( + &'a self, + pc: u32, + inst: &Instruction, + data: &mut SwAddNePreCompute<'a>, + ) -> Result { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + + // Validate instruction format + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); + let c = c.as_canonical_u32(); + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + let local_opcode = opcode.local_opcode_idx(self.offset); + + // Pre-compute flag_idx + let needs_setup = self.expr.needs_setup(); + let mut flag_idx = self.expr.num_flags() as u8; + if needs_setup { + // Find which opcode this is in our local_opcode_idx list + if let Some(opcode_position) = self + .local_opcode_idx + .iter() + .position(|&idx| idx == local_opcode) + { + // If this is NOT the last opcode (setup), get the corresponding flag_idx + if opcode_position < self.opcode_flag_idx.len() { + flag_idx = self.opcode_flag_idx[opcode_position] as u8; + } + } + } + + let rs_addrs = from_fn(|i| if i == 0 { b } else { c } as u8); + *data = SwAddNePreCompute { + expr: &self.expr, + rs_addrs, + a: a as u8, + flag_idx, + }; + + let local_opcode = opcode.local_opcode_idx(self.offset); + let is_setup = local_opcode == Rv32WeierstrassOpcode::SETUP_SW_ADD_NE as usize; + + Ok(is_setup) + } +} + +impl Executor + for SwAddNeExecutor +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + std::mem::size_of::() + } + + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E1ExecutionCtx, + { + let pre_compute: &mut SwAddNePreCompute = data.borrow_mut(); + + let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?; + + if let Some(field_type) = { + let modulus = &pre_compute.expr.builder.prime; + get_field_type(modulus) + } { + match (is_setup, field_type) { + (true, FieldType::K256Coordinate) => Ok(execute_e12_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { FieldType::K256Coordinate as u8 }, + true, + >), + (true, FieldType::P256Coordinate) => Ok(execute_e12_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { FieldType::P256Coordinate as u8 }, + true, + >), + (true, FieldType::BN254Coordinate) => Ok(execute_e12_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { FieldType::BN254Coordinate as u8 }, + true, + >), + (true, FieldType::BLS12_381Coordinate) => Ok(execute_e12_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { FieldType::BLS12_381Coordinate as u8 }, + true, + >), + (false, FieldType::K256Coordinate) => Ok(execute_e12_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { FieldType::K256Coordinate as u8 }, + false, + >), + (false, FieldType::P256Coordinate) => Ok(execute_e12_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { FieldType::P256Coordinate as u8 }, + false, + >), + (false, FieldType::BN254Coordinate) => Ok(execute_e12_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { FieldType::BN254Coordinate as u8 }, + false, + >), + (false, FieldType::BLS12_381Coordinate) => Ok(execute_e12_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { FieldType::BLS12_381Coordinate as u8 }, + false, + >), + _ => panic!("Unsupported field type"), + } + } else if is_setup { + Ok(execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, true>) + } else { + Ok(execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, false>) + } + } +} + +impl MeteredExecutor + for SwAddNeExecutor +{ + #[inline(always)] + fn metered_pre_compute_size(&self) -> usize { + std::mem::size_of::>() + } + + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let is_setup = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + if let Some(field_type) = { + let modulus = &pre_compute.data.expr.builder.prime; + get_field_type(modulus) + } { + if is_setup { + match field_type { + FieldType::K256Coordinate => Ok(execute_e2_setup_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { FieldType::K256Coordinate as u8 }, + >), + FieldType::P256Coordinate => Ok(execute_e2_setup_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { FieldType::P256Coordinate as u8 }, + >), + FieldType::BN254Coordinate => Ok(execute_e2_setup_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { FieldType::BN254Coordinate as u8 }, + >), + FieldType::BLS12_381Coordinate => Ok(execute_e2_setup_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { FieldType::BLS12_381Coordinate as u8 }, + >), + _ => panic!("Unsupported field type"), + } + } else { + match field_type { + FieldType::K256Coordinate => Ok(execute_e2_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { FieldType::K256Coordinate as u8 }, + >), + FieldType::P256Coordinate => Ok(execute_e2_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { FieldType::P256Coordinate as u8 }, + >), + FieldType::BN254Coordinate => Ok(execute_e2_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { FieldType::BN254Coordinate as u8 }, + >), + FieldType::BLS12_381Coordinate => Ok(execute_e2_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { FieldType::BLS12_381Coordinate as u8 }, + >), + _ => panic!("Unsupported field type"), + } + } + } else if is_setup { + Ok(execute_e2_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }>) + } else { + Ok(execute_e2_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }>) + } + } +} + +unsafe fn execute_e2_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const FIELD_TYPE: u8, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let e2_pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(e2_pre_compute.chip_idx as usize, 1); + let pre_compute = unsafe { + std::slice::from_raw_parts( + &e2_pre_compute.data as *const _ as *const u8, + std::mem::size_of::(), + ) + }; + execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, FIELD_TYPE, false>(pre_compute, vm_state); +} + +unsafe fn execute_e2_setup_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const FIELD_TYPE: u8, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let e2_pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(e2_pre_compute.chip_idx as usize, 1); + let pre_compute = unsafe { + std::slice::from_raw_parts( + &e2_pre_compute.data as *const _ as *const u8, + std::mem::size_of::(), + ) + }; + execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, FIELD_TYPE, true>(pre_compute, vm_state); +} + +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const FIELD_TYPE: u8, + const IS_SETUP: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &SwAddNePreCompute = pre_compute.borrow(); + // Read register values + let rs_vals = pre_compute + .rs_addrs + .map(|addr| u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, addr as u32))); + + // Read memory values for both points + let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2] = rs_vals.map(|address| { + debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS)); + from_fn(|i| vm_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32)) + }); + + if IS_SETUP { + let input_prime = BigUint::from_bytes_le(read_data[0][..BLOCKS / 2].as_flattened()); + if input_prime != pre_compute.expr.prime { + vm_state.exit_code = Err(ExecutionError::Fail { + pc: vm_state.pc, + msg: "EcAddNe: mismatched prime", + }); + return; + } + } + + let output_data = if FIELD_TYPE == u8::MAX || IS_SETUP { + let read_data: DynArray = read_data.into(); + run_field_expression_precomputed::( + pre_compute.expr, + pre_compute.flag_idx as usize, + &read_data.0, + ) + .into() + } else { + sw_add_ne::(read_data) + }; + + let rd_val = u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32)); + debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS)); + + // Write output data to memory + for (i, block) in output_data.into_iter().enumerate() { + vm_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block); + } + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} diff --git a/extensions/ecc/circuit/src/weierstrass_chip/curves.rs b/extensions/ecc/circuit/src/weierstrass_chip/curves.rs new file mode 100644 index 0000000000..2340a4cc26 --- /dev/null +++ b/extensions/ecc/circuit/src/weierstrass_chip/curves.rs @@ -0,0 +1,214 @@ +use halo2curves_axiom::ff::PrimeField; +use num_bigint::BigUint; +use num_traits::Num; +use openvm_algebra_circuit::fields::{ + blocks_to_field_element, blocks_to_field_element_bls12_381_coordinate, field_element_to_blocks, + field_element_to_blocks_bls12_381_coordinate, FieldType, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SwCurveType { + K256 = 0, + P256 = 1, + BN254 = 2, + BLS12_381 = 3, +} + +const P256_NEG_A: u64 = 3; + +pub fn get_modulus_as_bigint() -> BigUint { + BigUint::from_str_radix(F::MODULUS.trim_start_matches("0x"), 16).unwrap() +} + +pub(super) fn get_sw_curve_type(modulus: &BigUint, a_coeff: &BigUint) -> Option { + if modulus == &get_modulus_as_bigint::() + && a_coeff == &BigUint::ZERO + { + return Some(SwCurveType::K256); + } + + let coeff_a = (-halo2curves_axiom::secp256r1::Fp::from(P256_NEG_A)).to_bytes(); + if modulus == &get_modulus_as_bigint::() + && a_coeff == &BigUint::from_bytes_le(&coeff_a) + { + return Some(SwCurveType::P256); + } + + if modulus == &get_modulus_as_bigint::() + && a_coeff == &BigUint::ZERO + { + return Some(SwCurveType::BN254); + } + + if modulus == &get_modulus_as_bigint::() + && a_coeff == &BigUint::ZERO + { + return Some(SwCurveType::BLS12_381); + } + + None +} + +#[inline(always)] +pub fn sw_add_ne( + input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2], +) -> [[u8; BLOCK_SIZE]; BLOCKS] { + match FIELD_TYPE { + x if x == FieldType::K256Coordinate as u8 => { + sw_add_ne_256bit::(input_data) + } + x if x == FieldType::P256Coordinate as u8 => { + sw_add_ne_256bit::(input_data) + } + x if x == FieldType::BN254Coordinate as u8 => { + sw_add_ne_256bit::(input_data) + } + x if x == FieldType::BLS12_381Coordinate as u8 => { + sw_add_ne_bls12_381::(input_data) + } + _ => panic!("Unsupported field type: {}", FIELD_TYPE), + } +} + +/// Dispatch elliptic curve point doubling based on const generic curve type +#[inline(always)] +pub fn sw_double( + input_data: [[u8; BLOCK_SIZE]; BLOCKS], +) -> [[u8; BLOCK_SIZE]; BLOCKS] { + match CURVE_TYPE { + x if x == SwCurveType::K256 as u8 => { + sw_double_256bit::(input_data) + } + x if x == SwCurveType::P256 as u8 => { + sw_double_256bit::( + input_data, + ) + } + x if x == SwCurveType::BN254 as u8 => { + sw_double_256bit::(input_data) + } + x if x == SwCurveType::BLS12_381 as u8 => { + sw_double_bls12_381::(input_data) + } + _ => panic!("Unsupported curve type: {}", CURVE_TYPE), + } +} + +#[inline(always)] +fn sw_add_ne_256bit< + F: PrimeField, + const BLOCKS: usize, + const BLOCK_SIZE: usize, +>( + input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2], +) -> [[u8; BLOCK_SIZE]; BLOCKS] { + let x1 = blocks_to_field_element::(input_data[0][..BLOCKS / 2].as_flattened()); + let y1 = blocks_to_field_element::(input_data[0][BLOCKS / 2..].as_flattened()); + let x2 = blocks_to_field_element::(input_data[1][..BLOCKS / 2].as_flattened()); + let y2 = blocks_to_field_element::(input_data[1][BLOCKS / 2..].as_flattened()); + + let (x3, y3) = sw_add_ne_impl::(x1, y1, x2, y2); + + let mut output = [[0u8; BLOCK_SIZE]; BLOCKS]; + field_element_to_blocks::(&x3, &mut output[..BLOCKS / 2]); + field_element_to_blocks::(&y3, &mut output[BLOCKS / 2..]); + output +} + +#[inline(always)] +fn sw_double_256bit< + F: PrimeField, + const NEG_A: u64, + const BLOCKS: usize, + const BLOCK_SIZE: usize, +>( + input_data: [[u8; BLOCK_SIZE]; BLOCKS], +) -> [[u8; BLOCK_SIZE]; BLOCKS] { + let x1 = blocks_to_field_element::(input_data[..BLOCKS / 2].as_flattened()); + let y1 = blocks_to_field_element::(input_data[BLOCKS / 2..].as_flattened()); + + let (x3, y3) = sw_double_impl::(x1, y1); + + let mut output = [[0u8; BLOCK_SIZE]; BLOCKS]; + field_element_to_blocks::(&x3, &mut output[..BLOCKS / 2]); + field_element_to_blocks::(&y3, &mut output[BLOCKS / 2..]); + output +} + +#[inline(always)] +fn sw_add_ne_bls12_381( + input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2], +) -> [[u8; BLOCK_SIZE]; BLOCKS] { + // Extract coordinates + let x1 = + blocks_to_field_element_bls12_381_coordinate(input_data[0][..BLOCKS / 2].as_flattened()); + let y1 = + blocks_to_field_element_bls12_381_coordinate(input_data[0][BLOCKS / 2..].as_flattened()); + let x2 = + blocks_to_field_element_bls12_381_coordinate(input_data[1][..BLOCKS / 2].as_flattened()); + let y2 = + blocks_to_field_element_bls12_381_coordinate(input_data[1][BLOCKS / 2..].as_flattened()); + + let (x3, y3) = sw_add_ne_impl::(x1, y1, x2, y2); + + // Final output + let mut output = [[0u8; BLOCK_SIZE]; BLOCKS]; + field_element_to_blocks_bls12_381_coordinate(&x3, &mut output[..BLOCKS / 2]); + field_element_to_blocks_bls12_381_coordinate(&y3, &mut output[BLOCKS / 2..]); + output +} + +#[inline(always)] +fn sw_double_bls12_381( + input_data: [[u8; BLOCK_SIZE]; BLOCKS], +) -> [[u8; BLOCK_SIZE]; BLOCKS] { + // Extract coordinates + let x1 = blocks_to_field_element_bls12_381_coordinate(input_data[..BLOCKS / 2].as_flattened()); + let y1 = blocks_to_field_element_bls12_381_coordinate(input_data[BLOCKS / 2..].as_flattened()); + + let (x3, y3) = sw_double_impl::(x1, y1); + + // Final output + let mut output = [[0u8; BLOCK_SIZE]; BLOCKS]; + field_element_to_blocks_bls12_381_coordinate(&x3, &mut output[..BLOCKS / 2]); + field_element_to_blocks_bls12_381_coordinate(&y3, &mut output[BLOCKS / 2..]); + output +} + +#[inline(always)] +pub fn sw_add_ne_impl(x1: F, y1: F, x2: F, y2: F) -> (F, F) { + // Calculate lambda = (y2 - y1) / (x2 - x1) + let lambda = (y2 - y1) * (x2 - x1).invert().unwrap(); + + // Calculate x3 = lambda^2 - x1 - x2 + let x3 = lambda.square() - x1 - x2; + + // Calculate y3 = lambda * (x1 - x3) - y1 + let y3 = lambda * (x1 - x3) - y1; + + (x3, y3) +} + +#[inline(always)] +pub fn sw_double_impl(x1: F, y1: F) -> (F, F) { + // Calculate lambda based on curve coefficient 'a' + let x1_squared = x1.square(); + let three_x1_squared = x1_squared + x1_squared.double(); + let two_y1 = y1.double(); + + let lambda = if NEG_A == 0 { + // For a = 0: lambda = (3 * x1^2) / (2 * y1) + three_x1_squared * two_y1.invert().unwrap() + } else { + // lambda = (3 * x1^2 + a) / (2 * y1) + (three_x1_squared - F::from(NEG_A)) * two_y1.invert().unwrap() + }; + + // Calculate x3 = lambda^2 - 2 * x1 + let x3 = lambda.square() - x1.double(); + + // Calculate y3 = lambda * (x1 - x3) - y1 + let y3 = lambda * (x1 - x3) - y1; + + (x3, y3) +} diff --git a/extensions/ecc/circuit/src/weierstrass_chip/double.rs b/extensions/ecc/circuit/src/weierstrass_chip/double.rs index 0ae55f2df7..5578a40030 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/double.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/double.rs @@ -1,11 +1,44 @@ -use std::{cell::RefCell, rc::Rc}; +use std::{ + array::from_fn, + borrow::{Borrow, BorrowMut}, + cell::RefCell, + rc::Rc, +}; +use derive_more::derive::{Deref, DerefMut}; use num_bigint::BigUint; use num_traits::One; -use openvm_circuit_primitives::var_range::VariableRangeCheckerBus; -use openvm_mod_circuit_builder::{ExprBuilder, ExprBuilderConfig, FieldExpr, FieldVariable}; +use openvm_circuit::{ + arch::*, + system::memory::{ + offline_checker::MemoryBridge, online::GuestMemory, SharedMemoryHelper, POINTER_MAX_BITS, + }, +}; +use openvm_circuit_derive::PreflightExecutor; +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + AlignedBytesBorrow, +}; +use openvm_ecc_transpiler::Rv32WeierstrassOpcode; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS}, +}; +use openvm_mod_circuit_builder::{ + run_field_expression_precomputed, ExprBuilder, ExprBuilderConfig, FieldExpr, + FieldExpressionCoreAir, FieldExpressionExecutor, FieldExpressionFiller, FieldVariable, +}; +use openvm_rv32_adapters::{ + Rv32VecHeapAdapterAir, Rv32VecHeapAdapterExecutor, Rv32VecHeapAdapterFiller, +}; +use openvm_stark_backend::p3_field::PrimeField32; -pub fn ec_double_ne_expr( +use super::{curves::get_sw_curve_type, WeierstrassAir, WeierstrassChip}; +use crate::weierstrass_chip::curves::{sw_double, SwCurveType}; + +pub fn sw_double_ne_expr( config: ExprBuilderConfig, // The coordinate field. range_bus: VariableRangeCheckerBus, a_biguint: BigUint, @@ -17,7 +50,7 @@ pub fn ec_double_ne_expr( let mut x1 = ExprBuilder::new_input(builder.clone()); let mut y1 = ExprBuilder::new_input(builder.clone()); let a = ExprBuilder::new_const(builder.clone(), a_biguint.clone()); - let is_double_flag = builder.borrow_mut().new_flag(); + let is_double_flag = (*builder).borrow_mut().new_flag(); // We need to prevent divide by zero when not double flag // (equivalently, when it is the setup opcode) let lambda_denom = FieldVariable::select( @@ -31,6 +64,453 @@ pub fn ec_double_ne_expr( let mut y3 = lambda * (x1 - x3.clone()) - y1; y3.save_output(); - let builder = builder.borrow().clone(); + let builder = (*builder).borrow().clone(); FieldExpr::new_with_setup_values(builder, range_bus, true, vec![a_biguint]) } + +/// BLOCK_SIZE: how many cells do we read at a time, must be a power of 2. +/// BLOCKS: how many blocks do we need to represent one input or output +/// For example, for bls12_381, BLOCK_SIZE = 16, each element has 3 blocks and with two elements per +/// input AffinePoint, BLOCKS = 6. For secp256k1, BLOCK_SIZE = 32, BLOCKS = 2. +#[derive(Clone, PreflightExecutor, Deref, DerefMut)] +pub struct SwDoubleExecutor( + FieldExpressionExecutor>, +); + +fn gen_base_expr( + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, + a_biguint: BigUint, +) -> (FieldExpr, Vec) { + let expr = sw_double_ne_expr(config, range_checker_bus, a_biguint); + + let local_opcode_idx = vec![ + Rv32WeierstrassOpcode::SW_DOUBLE as usize, + Rv32WeierstrassOpcode::SETUP_SW_DOUBLE as usize, + ]; + + (expr, local_opcode_idx) +} + +#[allow(clippy::too_many_arguments)] +pub fn get_sw_double_air( + exec_bridge: ExecutionBridge, + mem_bridge: MemoryBridge, + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, + bitwise_lookup_bus: BitwiseOperationLookupBus, + pointer_max_bits: usize, + offset: usize, + a_biguint: BigUint, +) -> WeierstrassAir<1, BLOCKS, BLOCK_SIZE> { + let (expr, local_opcode_idx) = gen_base_expr(config, range_checker_bus, a_biguint); + WeierstrassAir::new( + Rv32VecHeapAdapterAir::new( + exec_bridge, + mem_bridge, + bitwise_lookup_bus, + pointer_max_bits, + ), + FieldExpressionCoreAir::new(expr.clone(), offset, local_opcode_idx.clone(), vec![]), + ) +} + +pub fn get_sw_double_step( + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, + pointer_max_bits: usize, + offset: usize, + a_biguint: BigUint, +) -> SwDoubleExecutor { + let (expr, local_opcode_idx) = gen_base_expr(config, range_checker_bus, a_biguint); + SwDoubleExecutor(FieldExpressionExecutor::new( + Rv32VecHeapAdapterExecutor::new(pointer_max_bits), + expr, + offset, + local_opcode_idx, + vec![], + "SwDouble", + )) +} + +pub fn get_sw_double_chip( + config: ExprBuilderConfig, + mem_helper: SharedMemoryHelper, + range_checker: SharedVariableRangeCheckerChip, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pointer_max_bits: usize, + a_biguint: BigUint, +) -> WeierstrassChip { + let (expr, local_opcode_idx) = gen_base_expr(config, range_checker.bus(), a_biguint); + WeierstrassChip::new( + FieldExpressionFiller::new( + Rv32VecHeapAdapterFiller::new(pointer_max_bits, bitwise_lookup_chip), + expr, + local_opcode_idx, + vec![], + range_checker, + true, + ), + mem_helper, + ) +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct SwDoublePreCompute<'a> { + expr: &'a FieldExpr, + rs_addrs: [u8; 1], + a: u8, + flag_idx: u8, +} + +impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize> SwDoubleExecutor { + fn pre_compute_impl( + &'a self, + pc: u32, + inst: &Instruction, + data: &mut SwDoublePreCompute<'a>, + ) -> Result { + let Instruction { + opcode, a, b, d, e, .. + } = inst; + + // Validate instruction format + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + let local_opcode = opcode.local_opcode_idx(self.offset); + + // Pre-compute flag_idx + let needs_setup = self.expr.needs_setup(); + let mut flag_idx = self.expr.num_flags() as u8; + if needs_setup { + // Find which opcode this is in our local_opcode_idx list + if let Some(opcode_position) = self + .local_opcode_idx + .iter() + .position(|&idx| idx == local_opcode) + { + // If this is NOT the last opcode (setup), get the corresponding flag_idx + if opcode_position < self.opcode_flag_idx.len() { + flag_idx = self.opcode_flag_idx[opcode_position] as u8; + } + } + } + + let rs_addrs = [b as u8]; + *data = SwDoublePreCompute { + expr: &self.expr, + rs_addrs, + a: a as u8, + flag_idx, + }; + + let local_opcode = opcode.local_opcode_idx(self.offset); + let is_setup = local_opcode == Rv32WeierstrassOpcode::SETUP_SW_DOUBLE as usize; + + Ok(is_setup) + } +} + +impl Executor + for SwDoubleExecutor +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + std::mem::size_of::() + } + + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E1ExecutionCtx, + { + let pre_compute: &mut SwDoublePreCompute = data.borrow_mut(); + + let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?; + + if let Some(curve_type) = { + let modulus = &pre_compute.expr.builder.prime; + let a_coeff = &pre_compute.expr.setup_values[0]; + get_sw_curve_type(modulus, a_coeff) + } { + match (is_setup, curve_type) { + (true, SwCurveType::K256) => Ok(execute_e12_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { SwCurveType::K256 as u8 }, + true, + >), + (true, SwCurveType::P256) => Ok(execute_e12_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { SwCurveType::P256 as u8 }, + true, + >), + (true, SwCurveType::BN254) => Ok(execute_e12_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { SwCurveType::BN254 as u8 }, + true, + >), + (true, SwCurveType::BLS12_381) => Ok(execute_e12_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { SwCurveType::BLS12_381 as u8 }, + true, + >), + (false, SwCurveType::K256) => Ok(execute_e12_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { SwCurveType::K256 as u8 }, + false, + >), + (false, SwCurveType::P256) => Ok(execute_e12_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { SwCurveType::P256 as u8 }, + false, + >), + (false, SwCurveType::BN254) => Ok(execute_e12_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { SwCurveType::BN254 as u8 }, + false, + >), + (false, SwCurveType::BLS12_381) => Ok(execute_e12_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { SwCurveType::BLS12_381 as u8 }, + false, + >), + } + } else if is_setup { + Ok(execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, true>) + } else { + Ok(execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }, false>) + } + } +} + +impl MeteredExecutor + for SwDoubleExecutor +{ + #[inline(always)] + fn metered_pre_compute_size(&self) -> usize { + std::mem::size_of::>() + } + + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let is_setup = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + if let Some(curve_type) = { + let modulus = &pre_compute.data.expr.builder.prime; + let a_coeff = &pre_compute.data.expr.setup_values[0]; + get_sw_curve_type(modulus, a_coeff) + } { + match (is_setup, curve_type) { + (true, SwCurveType::K256) => Ok(execute_e2_setup_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { SwCurveType::K256 as u8 }, + >), + (true, SwCurveType::P256) => Ok(execute_e2_setup_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { SwCurveType::P256 as u8 }, + >), + (true, SwCurveType::BN254) => Ok(execute_e2_setup_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { SwCurveType::BN254 as u8 }, + >), + (true, SwCurveType::BLS12_381) => Ok(execute_e2_setup_impl::< + _, + _, + BLOCKS, + BLOCK_SIZE, + { SwCurveType::BLS12_381 as u8 }, + >), + (false, SwCurveType::K256) => { + Ok(execute_e2_impl::<_, _, BLOCKS, BLOCK_SIZE, { SwCurveType::K256 as u8 }>) + } + (false, SwCurveType::P256) => { + Ok(execute_e2_impl::<_, _, BLOCKS, BLOCK_SIZE, { SwCurveType::P256 as u8 }>) + } + (false, SwCurveType::BN254) => { + Ok(execute_e2_impl::<_, _, BLOCKS, BLOCK_SIZE, { SwCurveType::BN254 as u8 }>) + } + (false, SwCurveType::BLS12_381) => { + Ok( + execute_e2_impl::<_, _, BLOCKS, BLOCK_SIZE, { SwCurveType::BLS12_381 as u8 }>, + ) + } + } + } else if is_setup { + Ok(execute_e2_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }>) + } else { + Ok(execute_e2_impl::<_, _, BLOCKS, BLOCK_SIZE, { u8::MAX }>) + } + } +} + +unsafe fn execute_e2_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const CURVE_TYPE: u8, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let e2_pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(e2_pre_compute.chip_idx as usize, 1); + let pre_compute = unsafe { + std::slice::from_raw_parts( + &e2_pre_compute.data as *const _ as *const u8, + std::mem::size_of::(), + ) + }; + execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, CURVE_TYPE, false>(pre_compute, vm_state); +} + +unsafe fn execute_e2_setup_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const CURVE_TYPE: u8, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let e2_pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(e2_pre_compute.chip_idx as usize, 1); + let pre_compute = unsafe { + std::slice::from_raw_parts( + &e2_pre_compute.data as *const _ as *const u8, + std::mem::size_of::(), + ) + }; + execute_e12_impl::<_, _, BLOCKS, BLOCK_SIZE, CURVE_TYPE, true>(pre_compute, vm_state); +} + +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const CURVE_TYPE: u8, + const IS_SETUP: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &SwDoublePreCompute = pre_compute.borrow(); + // Read register values + let rs_vals = pre_compute + .rs_addrs + .map(|addr| u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, addr as u32))); + + // Read memory values for the point + let read_data: [[u8; BLOCK_SIZE]; BLOCKS] = { + let address = rs_vals[0]; + debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS)); + from_fn(|i| vm_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32)) + }; + + if IS_SETUP { + let input_prime = BigUint::from_bytes_le(read_data[..BLOCKS / 2].as_flattened()); + + if input_prime != pre_compute.expr.builder.prime { + vm_state.exit_code = Err(ExecutionError::Fail { + pc: vm_state.pc, + msg: "SwDouble: mismatched prime", + }); + return; + } + + // Extract second field element as the a coefficient + let input_a = BigUint::from_bytes_le(read_data[BLOCKS / 2..].as_flattened()); + let coeff_a = &pre_compute.expr.setup_values[0]; + if input_a != *coeff_a { + vm_state.exit_code = Err(ExecutionError::Fail { + pc: vm_state.pc, + msg: "SwDouble: mismatched coeff_a", + }); + return; + } + } + + let output_data = if CURVE_TYPE == u8::MAX || IS_SETUP { + let read_data: DynArray = read_data.into(); + run_field_expression_precomputed::( + pre_compute.expr, + pre_compute.flag_idx as usize, + &read_data.0, + ) + .into() + } else { + sw_double::(read_data) + }; + + let rd_val = u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32)); + debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS)); + + // Write output data to memory + for (i, block) in output_data.into_iter().enumerate() { + vm_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block); + } + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} diff --git a/extensions/ecc/circuit/src/weierstrass_chip/mod.rs b/extensions/ecc/circuit/src/weierstrass_chip/mod.rs index 0bcee1facf..9d85397fc8 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/mod.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/mod.rs @@ -1,99 +1,28 @@ mod add_ne; +pub(crate) mod curves; mod double; -use std::sync::Arc; - pub use add_ne::*; pub use double::*; #[cfg(test)] mod tests; +use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; +use openvm_mod_circuit_builder::{FieldExpressionCoreAir, FieldExpressionFiller}; +use openvm_rv32_adapters::{Rv32VecHeapAdapterAir, Rv32VecHeapAdapterFiller}; +#[cfg(test)] +pub use tests::*; -use std::sync::Mutex; - -use num_bigint::BigUint; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::SharedVariableRangeCheckerChip; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_ecc_transpiler::Rv32WeierstrassOpcode; -use openvm_mod_circuit_builder::{ExprBuilderConfig, FieldExpressionCoreChip}; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -/// BLOCK_SIZE: how many cells do we read at a time, must be a power of 2. -/// BLOCKS: how many blocks do we need to represent one input or output -/// For example, for bls12_381, BLOCK_SIZE = 16, each element has 3 blocks and with two elements per -/// input AffinePoint, BLOCKS = 6. For secp256k1, BLOCK_SIZE = 32, BLOCKS = 2. -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct EcAddNeChip( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl - EcAddNeChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - config: ExprBuilderConfig, - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let expr = ec_add_ne_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![ - Rv32WeierstrassOpcode::EC_ADD_NE as usize, - Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize, - ], - vec![], - range_checker, - "EcAddNe", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} +pub type WeierstrassAir = + VmAirWrapper< + Rv32VecHeapAdapterAir, + FieldExpressionCoreAir, + >; -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct EcDoubleChip( - pub VmChipWrapper< +pub type WeierstrassChip = + VmChipWrapper< F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl - EcDoubleChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - range_checker: SharedVariableRangeCheckerChip, - config: ExprBuilderConfig, - offset: usize, - a: BigUint, - offline_memory: Arc>>, - ) -> Self { - let expr = ec_double_ne_expr(config, range_checker.bus(), a); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![ - Rv32WeierstrassOpcode::EC_DOUBLE as usize, - Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize, - ], - vec![], - range_checker, - "EcDouble", - true, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} + FieldExpressionFiller< + Rv32VecHeapAdapterFiller, + >, + >; diff --git a/extensions/ecc/circuit/src/weierstrass_chip/tests.rs b/extensions/ecc/circuit/src/weierstrass_chip/tests.rs index 213918ec2e..fd89c3d5bc 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/tests.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/tests.rs @@ -1,24 +1,34 @@ -use std::str::FromStr; +use std::{str::FromStr, sync::Arc}; use num_bigint::BigUint; use num_traits::{FromPrimitive, Num, Zero}; -use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; +use openvm_circuit::arch::{ + testing::{TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, + MatrixRecordArena, +}; use openvm_circuit_primitives::{ bigint::utils::{secp256k1_coord_prime, secp256r1_coord_prime}, - bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + bitwise_op_lookup::{ + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, + }, }; use openvm_ecc_transpiler::Rv32WeierstrassOpcode; use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; use openvm_mod_circuit_builder::{test_utils::biguint_to_limbs, ExprBuilderConfig, FieldExpr}; -use openvm_rv32_adapters::{rv32_write_heap_default, Rv32VecHeapAdapterChip}; +use openvm_rv32_adapters::rv32_write_heap_default; use openvm_stark_backend::p3_field::FieldAlgebra; use openvm_stark_sdk::p3_baby_bear::BabyBear; -use super::{EcAddNeChip, EcDoubleChip}; +use crate::{ + get_sw_addne_air, get_sw_addne_chip, get_sw_addne_step, get_sw_double_air, get_sw_double_chip, + get_sw_double_step, SwDoubleExecutor, WeierstrassAir, WeierstrassChip, +}; const NUM_LIMBS: usize = 32; const LIMB_BITS: usize = 8; const BLOCK_SIZE: usize = 32; +const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; lazy_static::lazy_static! { @@ -70,13 +80,67 @@ lazy_static::lazy_static! { }; } -fn prime_limbs(expr: &FieldExpr) -> Vec { +pub fn prime_limbs(expr: &FieldExpr) -> Vec { expr.prime_limbs .iter() .map(|n| BabyBear::from_canonical_usize(*n)) .collect::>() } +type WeierstrassHarness = TestChipHarness< + F, + SwDoubleExecutor<2, BLOCK_SIZE>, + WeierstrassAir<1, 2, BLOCK_SIZE>, + WeierstrassChip, + MatrixRecordArena, +>; + +fn create_test_double_chips( + tester: &VmChipTestBuilder, + config: ExprBuilderConfig, + offset: usize, + a_biguint: BigUint, +) -> ( + WeierstrassHarness, + ( + BitwiseOperationLookupAir, + SharedBitwiseOperationLookupChip, + ), +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + let air = get_sw_double_air( + tester.execution_bridge(), + tester.memory_bridge(), + config.clone(), + tester.range_checker().bus(), + bitwise_bus, + tester.address_bits(), + offset, + a_biguint.clone(), + ); + let executor = get_sw_double_step( + config.clone(), + tester.range_checker().bus(), + tester.address_bits(), + offset, + a_biguint.clone(), + ); + let chip = get_sw_double_chip( + config.clone(), + tester.memory_helper(), + tester.range_checker(), + bitwise_chip.clone(), + tester.address_bits(), + a_biguint, + ); + let harness = WeierstrassHarness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + + (harness, (bitwise_chip.air, bitwise_chip)) +} + #[test] fn test_add_ne() { let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); @@ -86,22 +150,36 @@ fn test_add_ne() { limb_bits: LIMB_BITS, }; let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + + let air = get_sw_addne_air::<2, BLOCK_SIZE>( + tester.execution_bridge(), tester.memory_bridge(), + config.clone(), + tester.range_checker().bus(), + bitwise_bus, tester.address_bits(), - bitwise_chip.clone(), + Rv32WeierstrassOpcode::CLASS_OFFSET, ); - let mut chip = EcAddNeChip::new( - adapter, - config, + let executor = get_sw_addne_step::<2, BLOCK_SIZE>( + config.clone(), + tester.range_checker().bus(), + tester.address_bits(), Rv32WeierstrassOpcode::CLASS_OFFSET, + ); + let chip = get_sw_addne_chip::( + config.clone(), + tester.memory_helper(), tester.range_checker(), - tester.offline_memory_mutex_arc(), + bitwise_chip.clone(), + tester.address_bits(), ); - assert_eq!(chip.0.core.expr().builder.num_variables, 3); // lambda, x3, y3 + + let mut harness = TestChipHarness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + + assert_eq!(harness.executor.expr.builder.num_variables, 3); // lambda, x3, y3 let (p1_x, p1_y) = SampleEcPoints[0].clone(); let (p2_x, p2_y) = SampleEcPoints[1].clone(); @@ -115,36 +193,40 @@ fn test_add_ne() { let p2_y_limbs = biguint_to_limbs::(p2_y.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); - let r = chip - .0 - .core - .expr() + let r = harness + .executor + .expr .execute(vec![p1_x, p1_y, p2_x, p2_y], vec![true]); assert_eq!(r.len(), 3); // lambda, x3, y3 assert_eq!(r[1], SampleEcPoints[2].0); assert_eq!(r[2], SampleEcPoints[2].1); - let prime_limbs: [BabyBear; NUM_LIMBS] = prime_limbs(chip.0.core.expr()).try_into().unwrap(); + let prime_limbs: [BabyBear; NUM_LIMBS] = + prime_limbs(&harness.executor.expr).try_into().unwrap(); let mut one_limbs = [BabyBear::ONE; NUM_LIMBS]; one_limbs[0] = BabyBear::ONE; let setup_instruction = rv32_write_heap_default( &mut tester, vec![prime_limbs, one_limbs], // inputs[0] = prime, others doesn't matter vec![one_limbs, one_limbs], - chip.0.core.air.offset + Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize, + harness.executor.offset + Rv32WeierstrassOpcode::SETUP_SW_ADD_NE as usize, ); - tester.execute(&mut chip, &setup_instruction); + tester.execute(&mut harness, &setup_instruction); let instruction = rv32_write_heap_default( &mut tester, vec![p1_x_limbs, p1_y_limbs], vec![p2_x_limbs, p2_y_limbs], - chip.0.core.air.offset + Rv32WeierstrassOpcode::EC_ADD_NE as usize, + harness.executor.offset + Rv32WeierstrassOpcode::SW_ADD_NE as usize, ); - tester.execute(&mut chip, &instruction); + tester.execute(&mut harness, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + let tester = tester + .build() + .load(harness) + .load_periphery((bitwise_chip.air, bitwise_chip)) + .finalize(); tester.simple_test().expect("Verification failed"); } @@ -157,14 +239,12 @@ fn test_double() { num_limbs: NUM_LIMBS, limb_bits: LIMB_BITS, }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), + + let (mut harness, bitwise) = create_test_double_chips( + &tester, + config, + Rv32WeierstrassOpcode::CLASS_OFFSET, + BigUint::zero(), ); let (p1_x, p1_y) = SampleEcPoints[1].clone(); @@ -173,41 +253,38 @@ fn test_double() { let p1_y_limbs = biguint_to_limbs::(p1_y.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); - let mut chip = EcDoubleChip::new( - adapter, - tester.memory_controller().borrow().range_checker.clone(), - config, - Rv32WeierstrassOpcode::CLASS_OFFSET, - BigUint::zero(), - tester.offline_memory_mutex_arc(), - ); - assert_eq!(chip.0.core.air.expr.builder.num_variables, 3); // lambda, x3, y3 + assert_eq!(harness.executor.expr.builder.num_variables, 3); // lambda, x3, y3 - let r = chip.0.core.air.expr.execute(vec![p1_x, p1_y], vec![true]); + let r = harness.executor.expr.execute(vec![p1_x, p1_y], vec![true]); assert_eq!(r.len(), 3); // lambda, x3, y3 assert_eq!(r[1], SampleEcPoints[3].0); assert_eq!(r[2], SampleEcPoints[3].1); - let prime_limbs: [BabyBear; NUM_LIMBS] = prime_limbs(&chip.0.core.air.expr).try_into().unwrap(); + let prime_limbs: [BabyBear; NUM_LIMBS] = + prime_limbs(&harness.executor.expr).try_into().unwrap(); let a_limbs = [BabyBear::ZERO; NUM_LIMBS]; let setup_instruction = rv32_write_heap_default( &mut tester, vec![prime_limbs, a_limbs], /* inputs[0] = prime, inputs[1] = a coeff of weierstrass * equation */ vec![], - chip.0.core.air.offset + Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize, + harness.executor.offset + Rv32WeierstrassOpcode::SETUP_SW_DOUBLE as usize, ); - tester.execute(&mut chip, &setup_instruction); + tester.execute(&mut harness, &setup_instruction); let instruction = rv32_write_heap_default( &mut tester, vec![p1_x_limbs, p1_y_limbs], vec![], - chip.0.core.air.offset + Rv32WeierstrassOpcode::EC_DOUBLE as usize, + harness.executor.offset + Rv32WeierstrassOpcode::SW_DOUBLE as usize, ); - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + tester.execute(&mut harness, &instruction); + let tester = tester + .build() + .load(harness) + .load_periphery(bitwise) + .finalize(); tester.simple_test().expect("Verification failed"); } @@ -225,14 +302,12 @@ fn test_p256_double() { 16, ) .unwrap(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), + + let (mut harness, bitwise) = create_test_double_chips( + &tester, + config, + Rv32WeierstrassOpcode::CLASS_OFFSET, + a.clone(), ); // Testing data from: http://point-at-infinity.org/ecc/nisttv @@ -251,17 +326,9 @@ fn test_p256_double() { let p1_y_limbs = biguint_to_limbs::(p1_y.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); - let mut chip = EcDoubleChip::new( - adapter, - tester.memory_controller().borrow().range_checker.clone(), - config, - Rv32WeierstrassOpcode::CLASS_OFFSET, - a.clone(), - tester.offline_memory_mutex_arc(), - ); - assert_eq!(chip.0.core.air.expr.builder.num_variables, 3); // lambda, x3, y3 + assert_eq!(harness.executor.expr.builder.num_variables, 3); // lambda, x3, y3 - let r = chip.0.core.air.expr.execute(vec![p1_x, p1_y], vec![true]); + let r = harness.executor.expr.execute(vec![p1_x, p1_y], vec![true]); assert_eq!(r.len(), 3); // lambda, x3, y3 let expected_double_x = BigUint::from_str_radix( "7CF27B188D034F7E8A52380304B51AC3C08969E277F21B35A60B48FC47669978", @@ -276,7 +343,8 @@ fn test_p256_double() { assert_eq!(r[1], expected_double_x); assert_eq!(r[2], expected_double_y); - let prime_limbs: [BabyBear; NUM_LIMBS] = prime_limbs(&chip.0.core.air.expr).try_into().unwrap(); + let prime_limbs: [BabyBear; NUM_LIMBS] = + prime_limbs(&harness.executor.expr).try_into().unwrap(); let a_limbs = biguint_to_limbs::(a.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); let setup_instruction = rv32_write_heap_default( @@ -284,19 +352,26 @@ fn test_p256_double() { vec![prime_limbs, a_limbs], /* inputs[0] = prime, inputs[1] = a coeff of weierstrass * equation */ vec![], - chip.0.core.air.offset + Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize, + harness.executor.offset + Rv32WeierstrassOpcode::SETUP_SW_DOUBLE as usize, ); - tester.execute(&mut chip, &setup_instruction); + tester.execute(&mut harness, &setup_instruction); let instruction = rv32_write_heap_default( &mut tester, vec![p1_x_limbs, p1_y_limbs], vec![], - chip.0.core.air.offset + Rv32WeierstrassOpcode::EC_DOUBLE as usize, + harness.executor.offset + Rv32WeierstrassOpcode::SW_DOUBLE as usize, ); - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + tester.execute(&mut harness, &instruction); + // Adding another row to make sure there are dummy rows, and that the dummy row constraints are + // satisfied + tester.execute(&mut harness, &instruction); + let tester = tester + .build() + .load(harness) + .load_periphery(bitwise) + .finalize(); tester.simple_test().expect("Verification failed"); } diff --git a/extensions/ecc/circuit/src/weierstrass_extension.rs b/extensions/ecc/circuit/src/weierstrass_extension.rs deleted file mode 100644 index f0ec35e688..0000000000 --- a/extensions/ecc/circuit/src/weierstrass_extension.rs +++ /dev/null @@ -1,258 +0,0 @@ -use derive_more::derive::From; -use hex_literal::hex; -use lazy_static::lazy_static; -use num_bigint::BigUint; -use num_traits::{FromPrimitive, Zero}; -use once_cell::sync::Lazy; -use openvm_circuit::{ - arch::{SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError}, - system::phantom::PhantomChip, -}; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_ecc_transpiler::Rv32WeierstrassOpcode; -use openvm_instructions::{LocalOpcode, VmOpcode}; -use openvm_mod_circuit_builder::ExprBuilderConfig; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; -use serde::{Deserialize, Serialize}; -use serde_with::{serde_as, DisplayFromStr}; -use strum::EnumCount; - -use super::{EcAddNeChip, EcDoubleChip}; - -#[serde_as] -#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)] -pub struct CurveConfig { - /// The name of the curve struct as defined by moduli_declare. - pub struct_name: String, - /// The coordinate modulus of the curve. - #[serde_as(as = "DisplayFromStr")] - pub modulus: BigUint, - /// The scalar field modulus of the curve. - #[serde_as(as = "DisplayFromStr")] - pub scalar: BigUint, - /// The coefficient a of y^2 = x^3 + ax + b. - #[serde_as(as = "DisplayFromStr")] - pub a: BigUint, - /// The coefficient b of y^2 = x^3 + ax + b. - #[serde_as(as = "DisplayFromStr")] - pub b: BigUint, -} - -pub static SECP256K1_CONFIG: Lazy = Lazy::new(|| CurveConfig { - struct_name: SECP256K1_ECC_STRUCT_NAME.to_string(), - modulus: SECP256K1_MODULUS.clone(), - scalar: SECP256K1_ORDER.clone(), - a: BigUint::zero(), - b: BigUint::from_u8(7u8).unwrap(), -}); - -pub static P256_CONFIG: Lazy = Lazy::new(|| CurveConfig { - struct_name: P256_ECC_STRUCT_NAME.to_string(), - modulus: P256_MODULUS.clone(), - scalar: P256_ORDER.clone(), - a: BigUint::from_bytes_le(&P256_A), - b: BigUint::from_bytes_le(&P256_B), -}); - -#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)] -pub struct WeierstrassExtension { - pub supported_curves: Vec, -} - -impl WeierstrassExtension { - pub fn generate_sw_init(&self) -> String { - let supported_curves = self - .supported_curves - .iter() - .map(|curve_config| curve_config.struct_name.to_string()) - .collect::>() - .join(", "); - - format!("openvm_ecc_guest::sw_macros::sw_init! {{ {supported_curves} }}") - } -} - -#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum)] -pub enum WeierstrassExtensionExecutor { - // 32 limbs prime - EcAddNeRv32_32(EcAddNeChip), - EcDoubleRv32_32(EcDoubleChip), - // 48 limbs prime - EcAddNeRv32_48(EcAddNeChip), - EcDoubleRv32_48(EcDoubleChip), -} - -#[derive(ChipUsageGetter, Chip, AnyEnum, From)] -pub enum WeierstrassExtensionPeriphery { - BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), - Phantom(PhantomChip), -} - -impl VmExtension for WeierstrassExtension { - type Executor = WeierstrassExtensionExecutor; - type Periphery = WeierstrassExtensionPeriphery; - - fn build( - &self, - builder: &mut VmInventoryBuilder, - ) -> Result, VmInventoryError> { - let mut inventory = VmInventory::new(); - let SystemPort { - execution_bus, - program_bus, - memory_bridge, - } = builder.system_port(); - let bitwise_lu_chip = if let Some(&chip) = builder - .find_chip::>() - .first() - { - chip.clone() - } else { - let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); - let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); - inventory.add_periphery_chip(chip.clone()); - chip - }; - let offline_memory = builder.system_base().offline_memory(); - let range_checker = builder.system_base().range_checker_chip.clone(); - let pointer_bits = builder.system_config().memory_config.pointer_max_bits; - let ec_add_ne_opcodes = (Rv32WeierstrassOpcode::EC_ADD_NE as usize) - ..=(Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize); - let ec_double_opcodes = (Rv32WeierstrassOpcode::EC_DOUBLE as usize) - ..=(Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize); - - for (i, curve) in self.supported_curves.iter().enumerate() { - let start_offset = - Rv32WeierstrassOpcode::CLASS_OFFSET + i * Rv32WeierstrassOpcode::COUNT; - let bytes = curve.modulus.bits().div_ceil(8); - let config32 = ExprBuilderConfig { - modulus: curve.modulus.clone(), - num_limbs: 32, - limb_bits: 8, - }; - let config48 = ExprBuilderConfig { - modulus: curve.modulus.clone(), - num_limbs: 48, - limb_bits: 8, - }; - if bytes <= 32 { - let add_ne_chip = EcAddNeChip::new( - Rv32VecHeapAdapterChip::::new( - execution_bus, - program_bus, - memory_bridge, - pointer_bits, - bitwise_lu_chip.clone(), - ), - config32.clone(), - start_offset, - range_checker.clone(), - offline_memory.clone(), - ); - inventory.add_executor( - WeierstrassExtensionExecutor::EcAddNeRv32_32(add_ne_chip), - ec_add_ne_opcodes - .clone() - .map(|x| VmOpcode::from_usize(x + start_offset)), - )?; - let double_chip = EcDoubleChip::new( - Rv32VecHeapAdapterChip::::new( - execution_bus, - program_bus, - memory_bridge, - pointer_bits, - bitwise_lu_chip.clone(), - ), - range_checker.clone(), - config32.clone(), - start_offset, - curve.a.clone(), - offline_memory.clone(), - ); - inventory.add_executor( - WeierstrassExtensionExecutor::EcDoubleRv32_32(double_chip), - ec_double_opcodes - .clone() - .map(|x| VmOpcode::from_usize(x + start_offset)), - )?; - } else if bytes <= 48 { - let add_ne_chip = EcAddNeChip::new( - Rv32VecHeapAdapterChip::::new( - execution_bus, - program_bus, - memory_bridge, - pointer_bits, - bitwise_lu_chip.clone(), - ), - config48.clone(), - start_offset, - range_checker.clone(), - offline_memory.clone(), - ); - inventory.add_executor( - WeierstrassExtensionExecutor::EcAddNeRv32_48(add_ne_chip), - ec_add_ne_opcodes - .clone() - .map(|x| VmOpcode::from_usize(x + start_offset)), - )?; - let double_chip = EcDoubleChip::new( - Rv32VecHeapAdapterChip::::new( - execution_bus, - program_bus, - memory_bridge, - pointer_bits, - bitwise_lu_chip.clone(), - ), - range_checker.clone(), - config48.clone(), - start_offset, - curve.a.clone(), - offline_memory.clone(), - ); - inventory.add_executor( - WeierstrassExtensionExecutor::EcDoubleRv32_48(double_chip), - ec_double_opcodes - .clone() - .map(|x| VmOpcode::from_usize(x + start_offset)), - )?; - } else { - panic!("Modulus too large"); - } - } - - Ok(inventory) - } -} - -// Convenience constants for constructors -lazy_static! { - // The constants are taken from: https://en.bitcoin.it/wiki/Secp256k1 - pub static ref SECP256K1_MODULUS: BigUint = BigUint::from_bytes_be(&hex!( - "FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE FFFFFC2F" - )); - pub static ref SECP256K1_ORDER: BigUint = BigUint::from_bytes_be(&hex!( - "FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE BAAEDCE6 AF48A03B BFD25E8C D0364141" - )); -} - -lazy_static! { - // The constants are taken from: https://neuromancer.sk/std/secg/secp256r1 - pub static ref P256_MODULUS: BigUint = BigUint::from_bytes_be(&hex!( - "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff" - )); - pub static ref P256_ORDER: BigUint = BigUint::from_bytes_be(&hex!( - "ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551" - )); -} -// little-endian -const P256_A: [u8; 32] = hex!("fcffffffffffffffffffffff00000000000000000000000001000000ffffffff"); -// little-endian -const P256_B: [u8; 32] = hex!("4b60d2273e3cce3bf6b053ccb0061d65bc86987655bdebb3e7933aaad835c65a"); - -pub const SECP256K1_ECC_STRUCT_NAME: &str = "Secp256k1Point"; -pub const P256_ECC_STRUCT_NAME: &str = "P256Point"; diff --git a/extensions/ecc/guest/Cargo.toml b/extensions/ecc/guest/Cargo.toml index e5251eb366..a85f8c4597 100644 --- a/extensions/ecc/guest/Cargo.toml +++ b/extensions/ecc/guest/Cargo.toml @@ -16,15 +16,23 @@ elliptic-curve = { workspace = true, features = ["arithmetic", "sec1"] } openvm-custom-insn = { workspace = true } openvm-rv32im-guest = { workspace = true } openvm-algebra-guest = { workspace = true } +openvm-algebra-moduli-macros = { workspace = true } openvm-ecc-sw-macros = { workspace = true } +openvm-ecc-te-macros = { workspace = true } once_cell = { workspace = true, features = ["race", "alloc"] } +num-bigint = { workspace = true } +hex-literal = { workspace = true } # Used for `halo2curves` feature halo2curves-axiom = { workspace = true, optional = true } group = "0.13.0" +[target.'cfg(not(target_os = "zkvm"))'.dependencies] +lazy_static = { workspace = true } + [features] default = [] +ed25519 = [] halo2curves = ["dep:halo2curves-axiom", "openvm-algebra-guest/halo2curves"] std = ["alloc"] alloc = [] diff --git a/extensions/ecc/guest/src/ecdsa.rs b/extensions/ecc/guest/src/ecdsa.rs index 07fc6d44fc..7c60e575d3 100644 --- a/extensions/ecc/guest/src/ecdsa.rs +++ b/extensions/ecc/guest/src/ecdsa.rs @@ -20,10 +20,7 @@ use elliptic_curve::{ }; use openvm_algebra_guest::{DivUnsafe, IntMod, Reduce}; -use crate::{ - weierstrass::{FromCompressed, IntrinsicCurve, WeierstrassPoint}, - CyclicGroup, Group, -}; +use crate::{weierstrass::WeierstrassPoint, CyclicGroup, FromCompressed, Group, IntrinsicCurve}; type Coordinate = <::Point as WeierstrassPoint>::Coordinate; type Scalar = ::Scalar; diff --git a/extensions/ecc/guest/src/ed25519.rs b/extensions/ecc/guest/src/ed25519.rs new file mode 100644 index 0000000000..32f48cba58 --- /dev/null +++ b/extensions/ecc/guest/src/ed25519.rs @@ -0,0 +1,85 @@ +use core::ops::Add; + +use hex_literal::hex; +#[cfg(not(target_os = "zkvm"))] +use lazy_static::lazy_static; +#[cfg(not(target_os = "zkvm"))] +use num_bigint::BigUint; +use openvm_algebra_guest::IntMod; + +use super::group::{CyclicGroup, Group}; +use crate::IntrinsicCurve; + +#[cfg(not(target_os = "zkvm"))] +lazy_static! { + pub static ref ED25519_MODULUS: BigUint = BigUint::from_bytes_be(&hex!( + "7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFED" + )); + pub static ref ED25519_ORDER: BigUint = BigUint::from_bytes_be(&hex!( + "1000000000000000000000000000000014DEF9DEA2F79CD65812631A5CF5D3ED" + )); + pub static ref ED25519_A: BigUint = BigUint::from_bytes_be(&hex!( + "7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEC" + )); + pub static ref ED25519_D: BigUint = BigUint::from_bytes_be(&hex!( + "52036CEE2B6FFE738CC740797779E89800700A4D4141D8AB75EB4DCA135978A3" + )); +} + +openvm_algebra_moduli_macros::moduli_declare! { + Ed25519Coord { modulus = "0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFED" }, + Ed25519Scalar { modulus = "0x1000000000000000000000000000000014DEF9DEA2F79CD65812631A5CF5D3ED" }, +} + +pub const ED25519_NUM_LIMBS: usize = 32; +pub const ED25519_LIMB_BITS: usize = 8; +pub const ED25519_BLOCK_SIZE: usize = 32; +// from_const_bytes is little endian +pub const CURVE_A: Ed25519Coord = Ed25519Coord::from_const_bytes(hex!( + "ECFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF7F" +)); +pub const CURVE_D: Ed25519Coord = Ed25519Coord::from_const_bytes(hex!( + "A3785913CA4DEB75ABD841414D0A700098E879777940C78C73FE6F2BEE6C0352" +)); + +openvm_ecc_te_macros::te_declare! { + Ed25519Point { mod_type = Ed25519Coord, a = CURVE_A, d = CURVE_D }, +} + +impl CyclicGroup for Ed25519Point { + // from_const_bytes is little endian + const GENERATOR: Self = Ed25519Point { + x: Ed25519Coord::from_const_bytes(hex!( + "1AD5258F602D56C9B2A7259560C72C695CDCD6FD31E2A4C0FE536ECDD3366921" + )), + y: Ed25519Coord::from_const_bytes(hex!( + "5866666666666666666666666666666666666666666666666666666666666666" + )), + }; + const NEG_GENERATOR: Self = Ed25519Point { + x: Ed25519Coord::from_const_bytes([ + 211, 42, 218, 112, 159, 210, 169, 54, 77, 88, 218, 106, 159, 56, 211, 150, 163, 35, 41, + 2, 206, 29, 91, 63, 1, 172, 145, 50, 44, 201, 150, 94, + ]), + y: Ed25519Coord::from_const_bytes(hex!( + "5866666666666666666666666666666666666666666666666666666666666666" + )), + }; +} + +impl IntrinsicCurve for Ed25519Point { + type Scalar = Ed25519Scalar; + type Point = Ed25519Point; + + fn msm(coeffs: &[Self::Scalar], bases: &[Self::Point]) -> Self::Point + where + for<'a> &'a Self::Point: Add<&'a Self::Point, Output = Self::Point>, + { + if coeffs.len() < 25 { + let table = crate::edwards::CachedMulTable::::new(bases, 4); + table.windowed_mul(coeffs) + } else { + crate::msm(coeffs, bases) + } + } +} diff --git a/extensions/ecc/guest/src/edwards.rs b/extensions/ecc/guest/src/edwards.rs new file mode 100644 index 0000000000..e9089e49d8 --- /dev/null +++ b/extensions/ecc/guest/src/edwards.rs @@ -0,0 +1,396 @@ +use alloc::vec::Vec; +use core::ops::{AddAssign, Mul}; + +use openvm_algebra_guest::{Field, IntMod}; + +use crate::{Group, IntrinsicCurve}; + +pub trait TwistedEdwardsPoint: Sized { + /// The `a` coefficient in the twisted Edwards curve equation `ax^2 + y^2 = 1 + d x^2 y^2`. + const CURVE_A: Self::Coordinate; + /// The `d` coefficient in the twisted Edwards curve equation `ax^2 + y^2 = 1 + d x^2 y^2`. + const CURVE_D: Self::Coordinate; + const IDENTITY: Self; + + type Coordinate: Field; + + /// The concatenated `x, y` coordinates of the affine point, where + /// coordinates are in little endian. + /// + /// **Warning**: The memory layout of `Self` is expected to pack + /// `x` and `y` contigously with no unallocated space in between. + fn as_le_bytes(&self) -> &[u8]; + + /// Raw constructor without asserting point is on the curve. + fn from_xy_unchecked(x: Self::Coordinate, y: Self::Coordinate) -> Self; + fn into_coords(self) -> (Self::Coordinate, Self::Coordinate); + fn x(&self) -> &Self::Coordinate; + fn y(&self) -> &Self::Coordinate; + fn x_mut(&mut self) -> &mut Self::Coordinate; + fn y_mut(&mut self) -> &mut Self::Coordinate; + + fn add_impl(&self, p2: &Self) -> Self; + + #[inline(always)] + fn from_xy(x: Self::Coordinate, y: Self::Coordinate) -> Option + where + for<'a> &'a Self::Coordinate: Mul<&'a Self::Coordinate, Output = Self::Coordinate>, + { + let lhs = Self::CURVE_A * &x * &x + &y * &y; + let rhs = Self::CURVE_D * &x * &x * &y * &y + &Self::Coordinate::ONE; + if lhs != rhs { + return None; + } + Some(Self::from_xy_unchecked(x, y)) + } +} + +/// Macro to generate a newtype wrapper for [AffinePoint](crate::AffinePoint) +/// that implements elliptic curve operations by using the underlying field operations according to +/// the [formulas](https://en.wikipedia.org/wiki/Twisted_Edwards_curve) for twisted Edwards curves. +/// +/// The following imports are required: +/// ```rust +/// use core::ops::AddAssign; +/// +/// use openvm_algebra_guest::{DivUnsafe, Field}; +/// use openvm_ecc_guest::{edwards::TwistedEdwardsPoint, AffinePoint, Group}; +/// ``` +#[macro_export] +macro_rules! impl_te_affine { + ($struct_name:ident, $field:ty, $a:expr, $d:expr) => { + /// A newtype wrapper for [AffinePoint] that implements elliptic curve operations + /// by using the underlying field operations according to the [formulas](https://en.wikipedia.org/wiki/Twisted_Edwards_curve) for twisted Edwards curves. + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)] + #[repr(transparent)] + pub struct $struct_name(AffinePoint<$field>); + + impl TwistedEdwardsPoint for $struct_name { + const CURVE_A: $field = $a; + const CURVE_D: $field = $d; + const IDENTITY: Self = Self(AffinePoint::new(<$field>::ZERO, <$field>::ONE)); + + type Coordinate = $field; + + /// SAFETY: assumes that [$field] has internal representation in little-endian. + fn as_le_bytes(&self) -> &[u8] { + unsafe { + &*core::ptr::slice_from_raw_parts( + self as *const Self as *const u8, + core::mem::size_of::(), + ) + } + } + fn from_xy_unchecked(x: Self::Coordinate, y: Self::Coordinate) -> Self { + Self(AffinePoint::new(x, y)) + } + fn into_coords(self) -> (Self::Coordinate, Self::Coordinate) { + (self.0.x, self.0.y) + } + fn x(&self) -> &Self::Coordinate { + &self.0.x + } + fn y(&self) -> &Self::Coordinate { + &self.0.y + } + fn x_mut(&mut self) -> &mut Self::Coordinate { + &mut self.0.x + } + fn y_mut(&mut self) -> &mut Self::Coordinate { + &mut self.0.y + } + + fn add_impl(&self, p2: &Self) -> Self { + use ::openvm_algebra_guest::DivUnsafe; + // For twisted Edwards curves: + // x3 = (x1*y2 + y1*x2)/(1 + d*x1*x2*y1*y2) + // y3 = (y1*y2 - a*x1*x2)/(1 - d*x1*x2*y1*y2) + let x1y2 = self.x() * p2.y(); + let y1x2 = self.y() * p2.x(); + let x1x2 = self.x() * p2.x(); + let y1y2 = self.y() * p2.y(); + let dx1x2y1y2 = Self::CURVE_D * x1x2 * y1y2; + + let x3 = (x1y2 + y1x2).div_unsafe(&(Self::Coordinate::ONE + dx1x2y1y2)); + let y3 = (y1y2 - Self::CURVE_A * x1x2).div_unsafe(&(Self::Coordinate::ONE - dx1x2y1y2)); + + Self(AffinePoint::new(x3, y3)) + } + + impl core::ops::Neg for $struct_name { + type Output = Self; + + fn neg(mut self) -> Self::Output { + self.0.x.neg_assign(); + self + } + } + + impl core::ops::Neg for &$struct_name { + type Output = $struct_name; + + fn neg(self) -> Self::Output { + self.clone().neg() + } + } + + impl From<$struct_name> for AffinePoint<$field> { + fn from(value: $struct_name) -> Self { + value.0 + } + } + + impl From> for $struct_name { + fn from(value: AffinePoint<$field>) -> Self { + Self(value) + } + } + } + } +} + +/// Implements `Group` on `$struct_name` assuming that `$struct_name` implements +/// `TwistedEdwardsPoint`. Assumes that `Neg` is implemented for `&$struct_name`. +#[macro_export] +macro_rules! impl_te_group_ops { + ($struct_name:ident, $field:ty) => { + impl Group for $struct_name { + type SelfRef<'a> = &'a Self; + + const IDENTITY: Self = ::IDENTITY; + + fn double(&self) -> Self { + if self.is_identity() { + self.clone() + } else { + self.add_impl(self) + } + } + + fn double_assign(&mut self) { + if !self.is_identity() { + *self = self.add_impl(self) + } + } + + // Note: It was found that implementing `is_identity` in group.rs as a default + // implementation increases the cycle count by 50% on the ecrecover benchmark. For + // this reason, we implement it here instead. We hypothesize that this is due to + // compiler optimizations that are not possible when the `is_identity` function is + // defined in a different source file. + #[inline(always)] + fn is_identity(&self) -> bool { + self == &::IDENTITY + } + } + + impl core::ops::Add<&$struct_name> for $struct_name { + type Output = Self; + + fn add(mut self, p2: &$struct_name) -> Self::Output { + use core::ops::AddAssign; + self.add_assign(p2); + self + } + } + + impl core::ops::Add for $struct_name { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + self.add(&rhs) + } + } + + impl core::ops::Add<&$struct_name> for &$struct_name { + type Output = $struct_name; + + fn add(self, p2: &$struct_name) -> Self::Output { + if self.is_identity() { + p2.clone() + } else if p2.is_identity() { + self.clone() + } else if self.x() + p2.x() == <$field as openvm_algebra_guest::Field>::ZERO + && self.y() == p2.y() + { + <$struct_name as TwistedEdwardsPoint>::IDENTITY + } else { + self.add_impl(p2) + } + } + } + + impl core::ops::AddAssign<&$struct_name> for $struct_name { + fn add_assign(&mut self, p2: &$struct_name) { + if self.is_identity() { + *self = p2.clone(); + } else if p2.is_identity() { + // do nothing + } else if self.x() + p2.x() == <$field as openvm_algebra_guest::Field>::ZERO + && self.y() == p2.y() + { + *self = <$struct_name as TwistedEdwardsPoint>::IDENTITY; + } else { + *self = self.add_impl(p2); + } + } + } + + impl core::ops::AddAssign for $struct_name { + fn add_assign(&mut self, rhs: Self) { + self.add_assign(&rhs); + } + } + + impl core::ops::Sub<&$struct_name> for $struct_name { + type Output = Self; + + fn sub(self, rhs: &$struct_name) -> Self::Output { + core::ops::Sub::sub(&self, rhs) + } + } + + impl core::ops::Sub for $struct_name { + type Output = $struct_name; + + fn sub(self, rhs: Self) -> Self::Output { + self.sub(&rhs) + } + } + + impl core::ops::Sub<&$struct_name> for &$struct_name { + type Output = $struct_name; + + fn sub(self, p2: &$struct_name) -> Self::Output { + use core::ops::Add; + self.add(&-p2) + } + } + + impl core::ops::SubAssign<&$struct_name> for $struct_name { + fn sub_assign(&mut self, p2: &$struct_name) { + use core::ops::AddAssign; + self.add_assign(-p2); + } + } + + impl core::ops::SubAssign for $struct_name { + fn sub_assign(&mut self, rhs: Self) { + self.sub_assign(&rhs); + } + } + }; +} + +// This is the same as the Weierstrass version, but for Edwards curves we use +// TwistedEdwardsPoint::add_impl instead of WeierstrassPoint::add_ne_nonidentity, etc. +// Unlike the Weierstrass version, we do not require the bases to have prime order, since our +// addition formulas are complete. + +// MSM using preprocessed table (windowed method) +// Reference: modified from https://github.com/arkworks-rs/algebra/blob/master/ec/src/scalar_mul/mod.rs + +/// Cached precomputations of scalar multiples of several base points. +/// - `window_bits` is the window size used for the precomputation +/// - `max_scalar_bits` is the maximum size of the scalars that will be multiplied +/// - `table` is the precomputed table +pub struct CachedMulTable<'a, C: IntrinsicCurve> { + /// Window bits. Must be > 0. + /// For alignment, we currently require this to divide 8 (bits in a byte). + pub window_bits: usize, + pub bases: &'a [C::Point], + /// `table[i][j] = (j + 2) * bases[i]` for `j + 2 < 2 ** window_bits` + table: Vec>, + /// Needed to return reference to the identity point. + identity: C::Point, +} + +impl<'a, C: IntrinsicCurve> CachedMulTable<'a, C> +where + C::Point: TwistedEdwardsPoint + Group, + C::Scalar: IntMod, +{ + pub fn new(bases: &'a [C::Point], window_bits: usize) -> Self { + assert!(window_bits > 0); + let window_size = 1 << window_bits; + let table = bases + .iter() + .map(|base| { + if base.is_identity() { + vec![::IDENTITY; window_size - 2] + } else { + let mut multiples = Vec::with_capacity(window_size - 2); + for _ in 0..window_size - 2 { + let multiple = multiples + .last() + .map(|last| TwistedEdwardsPoint::add_impl(last, base)) + .unwrap_or_else(|| base.double()); + multiples.push(multiple); + } + multiples + } + }) + .collect(); + + Self { + window_bits, + bases, + table, + identity: ::IDENTITY, + } + } + + fn get_multiple(&self, base_idx: usize, scalar: usize) -> &C::Point { + if scalar == 0 { + &self.identity + } else if scalar == 1 { + unsafe { self.bases.get_unchecked(base_idx) } + } else { + unsafe { self.table.get_unchecked(base_idx).get_unchecked(scalar - 2) } + } + } + + /// Computes `sum scalars[i] * bases[i]`. + /// + /// For implementation simplicity, currently only implemented when + /// `window_bits` divides 8 (number of bits in a byte). + pub fn windowed_mul(&self, scalars: &[C::Scalar]) -> C::Point { + assert_eq!(8 % self.window_bits, 0); + assert_eq!(scalars.len(), self.bases.len()); + let windows_per_byte = 8 / self.window_bits; + + let num_windows = C::Scalar::NUM_LIMBS * windows_per_byte; + let mask = (1u8 << self.window_bits) - 1; + + // The current byte index (little endian) at the current step of the + // windowed method, across all scalars. + let mut limb_idx = C::Scalar::NUM_LIMBS; + // The current bit (little endian) within the current byte of the windowed + // method. The window will look at bits `bit_idx..bit_idx + window_bits`. + // bit_idx will always be in range [0, 8) + let mut bit_idx = 0; + + let mut res = ::IDENTITY; + for outer in 0..num_windows { + if bit_idx == 0 { + limb_idx -= 1; + bit_idx = 8 - self.window_bits; + } else { + bit_idx -= self.window_bits; + } + + if outer != 0 { + for _ in 0..self.window_bits { + res.double_assign(); + } + } + for (base_idx, scalar) in scalars.iter().enumerate() { + let scalar = (scalar.as_le_bytes()[limb_idx] >> bit_idx) & mask; + let summand = self.get_multiple(base_idx, scalar as usize); + // handles identity + res.add_assign(summand); + } + } + res + } +} diff --git a/extensions/ecc/guest/src/lib.rs b/extensions/ecc/guest/src/lib.rs index c7a9851cfd..aaaa1ee9fc 100644 --- a/extensions/ecc/guest/src/lib.rs +++ b/extensions/ecc/guest/src/lib.rs @@ -6,6 +6,7 @@ extern crate alloc; pub use once_cell; pub use openvm_algebra_guest as algebra; pub use openvm_ecc_sw_macros as sw_macros; +pub use openvm_ecc_te_macros as te_macros; use strum_macros::FromRepr; mod affine_point; @@ -17,11 +18,16 @@ pub use msm::*; /// Optimized ECDSA implementation with the same functional interface as the `ecdsa` crate pub mod ecdsa; +/// Edwards curve traits +pub mod edwards; /// Weierstrass curve traits pub mod weierstrass; +#[cfg(feature = "ed25519")] +pub mod ed25519; + /// This is custom-1 defined in RISC-V spec document -pub const OPCODE: u8 = 0x2b; +pub const SW_OPCODE: u8 = 0x2b; pub const SW_FUNCT3: u8 = 0b001; /// Short Weierstrass curves are configurable. @@ -37,3 +43,46 @@ pub enum SwBaseFunct7 { impl SwBaseFunct7 { pub const SHORT_WEIERSTRASS_MAX_KINDS: u8 = 8; } + +/// This is custom-1 defined in RISC-V spec document +pub const TE_OPCODE: u8 = 0x2b; +pub const TE_FUNCT3: u8 = 0b100; + +#[derive(Debug, Copy, Clone, PartialEq, Eq, FromRepr)] +#[repr(u8)] +pub enum TeBaseFunct7 { + TeAdd = 0, + TeSetup, + TeHintDecompress, + TeHintNonQr, +} + +impl TeBaseFunct7 { + pub const TWISTED_EDWARDS_MAX_KINDS: u8 = 8; +} + +/// A trait for elliptic curves that bridges the openvm types and external types with +/// CurveArithmetic etc. Implement this for external curves with corresponding openvm point and +/// scalar types. +pub trait IntrinsicCurve { + type Scalar: Clone; + type Point: Clone; + + /// Multi-scalar multiplication. + /// The implementation may be specialized to use properties of the curve + /// (e.g., if the curve order is prime). + fn msm(coeffs: &[Self::Scalar], bases: &[Self::Point]) -> Self::Point; +} + +pub trait FromCompressed { + /// Given `x`-coordinate, + /// + /// Decompresses a point from its x-coordinate and a recovery identifier which indicates + /// the parity of the y-coordinate. Given the x-coordinate, this function attempts to find the + /// corresponding y-coordinate that satisfies the elliptic curve equation. If successful, it + /// returns the point as an instance of Self. If the point cannot be decompressed, it returns + /// None. + fn decompress(x: Coordinate, rec_id: &u8) -> Option + where + Self: core::marker::Sized; +} diff --git a/extensions/ecc/guest/src/weierstrass.rs b/extensions/ecc/guest/src/weierstrass.rs index 82d5468b04..3c39cbcfe6 100644 --- a/extensions/ecc/guest/src/weierstrass.rs +++ b/extensions/ecc/guest/src/weierstrass.rs @@ -4,6 +4,7 @@ use core::ops::Mul; use openvm_algebra_guest::{Field, IntMod}; use super::group::Group; +use crate::IntrinsicCurve; /// Short Weierstrass curve affine point. pub trait WeierstrassPoint: Clone + Sized { @@ -113,32 +114,6 @@ pub trait WeierstrassPoint: Clone + Sized { } } -pub trait FromCompressed { - /// Given `x`-coordinate, - /// - /// Decompresses a point from its x-coordinate and a recovery identifier which indicates - /// the parity of the y-coordinate. Given the x-coordinate, this function attempts to find the - /// corresponding y-coordinate that satisfies the elliptic curve equation. If successful, it - /// returns the point as an instance of Self. If the point cannot be decompressed, it returns - /// None. - fn decompress(x: Coordinate, rec_id: &u8) -> Option - where - Self: core::marker::Sized; -} - -/// A trait for elliptic curves that bridges the openvm types and external types with -/// CurveArithmetic etc. Implement this for external curves with corresponding openvm point and -/// scalar types. -pub trait IntrinsicCurve { - type Scalar: Clone; - type Point: Clone; - - /// Multi-scalar multiplication. - /// The implementation may be specialized to use properties of the curve - /// (e.g., if the curve order is prime). - fn msm(coeffs: &[Self::Scalar], bases: &[Self::Point]) -> Self::Point; -} - // MSM using preprocessed table (windowed method) // Reference: modified from https://github.com/arkworks-rs/algebra/blob/master/ec/src/scalar_mul/mod.rs // @@ -476,11 +451,11 @@ macro_rules! impl_sw_group_ops { self.double_assign_impl::(); } - // This implementation is the same as the default implementation in the `Group` trait, - // but it was found that overriding the default implementation reduced the cycle count - // by 50% on the ecrecover benchmark. - // We hypothesize that this is due to compiler optimizations that are not possible when - // the `is_identity` function is defined in a different source file. + // Note: It was found that implementing `is_identity` in group.rs as a default + // implementation increases the cycle count by 50% on the ecrecover benchmark. For + // this reason, we implement it here instead. We hypothesize that this is due to + // compiler optimizations that are not possible when the `is_identity` function is + // defined in a different source file. #[inline(always)] fn is_identity(&self) -> bool { self == &::IDENTITY diff --git a/extensions/ecc/sw-macros/README.md b/extensions/ecc/sw-macros/README.md index 71f8d553f4..994bd5a6d4 100644 --- a/extensions/ecc/sw-macros/README.md +++ b/extensions/ecc/sw-macros/README.md @@ -28,7 +28,7 @@ openvm_algebra_guest::moduli_macros::moduli_init! { } openvm_ecc_guest::sw_macros::sw_init! { - Secp256k1Point, + "Secp256k1Point", } */ @@ -45,7 +45,7 @@ The crate provides two macros: `sw_declare!` and `sw_init!`. The signatures are: - `sw_declare!` receives comma-separated list of moduli classes descriptions. Each description looks like `SwStruct { mod_type = ModulusName, a = a_expr, b = b_expr }`. Here `ModulusName` is the name of a struct that implements `trait IntMod` -- in particular, the ones created by `moduli_declare!` do -- and has `NUM_LIMBS` divisible by 4. Parameters `a` and `b` correspond to the coefficients of the equation defining the curve. They **must be compile-time constants**. The parameter `a` may be omitted, in which case it defaults to `0` (or, more specifically, to `::ZERO`). The parameter `b` is required. -- `sw_init!` receives comma-separated list of struct names. The struct name must exactly match the name in `sw_declare!` -- type defs are not allowed (see point 5 below). +- `sw_init!` receives comma-separated list of struct names as string literals. Each struct name must exactly match the name in `sw_declare!` -- type defs are not allowed (see point 5 below). What happens under the hood: @@ -93,7 +93,7 @@ mod openvm_intrinsics_ffi_2 { 3. Again, if using the Rust bindings, then the `sw_setup_extern_func_*` function for every curve is automatically called on first use of any of the curve's intrinsics. -4. The order of the items in `sw_init!` **must match** the order of the moduli in the chip configuration -- more specifically, in the modular extension parameters (the order of `CurveConfig`s in `WeierstrassExtension::supported_curves`, which is usually defined with the whole `app_vm_config` in the `openvm.toml` file). +4. The order of the items in `sw_init!` **must match** the order of the moduli in the chip configuration -- more specifically, in the modular extension parameters (the order of `CurveConfig`s in `EccExtension::supported_sw_curves`, which is usually defined with the whole `app_vm_config` in the `openvm.toml` file). 5. Note that, due to the nature of function names, the name of the struct used in `sw_init!` must be the same as in `sw_declare!`. To illustrate, the following code will **fail** to compile: @@ -107,7 +107,7 @@ sw_declare! { pub type Sw = Secp256k1Point; sw_init! { - Sw, + "Sw", } ``` diff --git a/extensions/ecc/sw-macros/src/lib.rs b/extensions/ecc/sw-macros/src/lib.rs index 7af9e77daf..92f61634a1 100644 --- a/extensions/ecc/sw-macros/src/lib.rs +++ b/extensions/ecc/sw-macros/src/lib.rs @@ -5,7 +5,7 @@ use proc_macro::TokenStream; use quote::format_ident; use syn::{ parse::{Parse, ParseStream}, - parse_macro_input, Expr, ExprPath, Path, Token, + parse_macro_input, ExprPath, LitStr, Token, }; /// This macro generates the code to setup the elliptic curve for a given modular type. Also it @@ -28,9 +28,8 @@ pub fn sw_declare(input: TokenStream) -> TokenStream { let span = proc_macro::Span::call_site(); for item in items.into_iter() { - let struct_name = item.name.to_string(); - let struct_name = syn::Ident::new(&struct_name, span.into()); - let struct_path: syn::Path = syn::parse_quote!(#struct_name); + let struct_name_str = item.name.to_string(); + let struct_name = syn::Ident::new(&struct_name_str, span.into()); let mut intmod_type: Option = None; let mut const_a: Option = None; let mut const_b: Option = None; @@ -71,16 +70,7 @@ pub fn sw_declare(input: TokenStream) -> TokenStream { macro_rules! create_extern_func { ($name:ident) => { let $name = syn::Ident::new( - &format!( - "{}_{}", - stringify!($name), - struct_path - .segments - .iter() - .map(|x| x.ident.to_string()) - .collect::>() - .join("_") - ), + &format!("{}_{}", stringify!($name), struct_name_str), span.into(), ); }; @@ -89,13 +79,13 @@ pub fn sw_declare(input: TokenStream) -> TokenStream { create_extern_func!(sw_double_extern_func); create_extern_func!(sw_setup_extern_func); - let group_ops_mod_name = format_ident!("{}_ops", struct_name.to_string().to_lowercase()); + let group_ops_mod_name = format_ident!("{}_ops", struct_name_str.to_lowercase()); let result = TokenStream::from(quote::quote_spanned! { span.into() => extern "C" { fn #sw_add_ne_extern_func(rd: usize, rs1: usize, rs2: usize); fn #sw_double_extern_func(rd: usize, rs1: usize); - fn #sw_setup_extern_func(); + fn #sw_setup_extern_func(uninit: *mut core::ffi::c_void, p1: *const u8, p2: *const u8); } #[derive(Eq, PartialEq, Clone, Debug, serde::Serialize, serde::Deserialize)] @@ -196,8 +186,21 @@ pub fn sw_declare(input: TokenStream) -> TokenStream { #[cfg(target_os = "zkvm")] fn set_up_once() { static is_setup: ::openvm_ecc_guest::once_cell::race::OnceBool = ::openvm_ecc_guest::once_cell::race::OnceBool::new(); + is_setup.get_or_init(|| { - unsafe { #sw_setup_extern_func(); } + // p1 is (x1, y1), and x1 must be the modulus. + // y1 can be anything for SetupEcAdd, but must equal `a` for SetupEcDouble + let modulus_bytes = <::Coordinate as openvm_algebra_guest::IntMod>::MODULUS; + let mut one = [0u8; <::Coordinate as openvm_algebra_guest::IntMod>::NUM_LIMBS]; + one[0] = 1; + let curve_a_bytes = openvm_algebra_guest::IntMod::as_le_bytes(&<#struct_name as openvm_ecc_guest::weierstrass::WeierstrassPoint>::CURVE_A); + // p1 should be (p, a) + let p1 = [modulus_bytes.as_ref(), curve_a_bytes.as_ref()].concat(); + // (EcAdd only) p2 is (x2, y2), and x1 - x2 has to be non-zero to avoid division over zero in add. + let p2 = [one.as_ref(), one.as_ref()].concat(); + let mut uninit: core::mem::MaybeUninit<[Self; 2]> = core::mem::MaybeUninit::uninit(); + + unsafe { #sw_setup_extern_func(uninit.as_mut_ptr() as *mut core::ffi::c_void, p1.as_ptr(), p2.as_ptr()); } <#intmod_type as openvm_algebra_guest::IntMod>::set_up_once(); true }); @@ -214,7 +217,7 @@ pub fn sw_declare(input: TokenStream) -> TokenStream { use openvm_algebra_guest::IntMod; // Safety: Self::set_up_once() ensures IntMod::set_up_once() has been called. unsafe { - self.x.eq_impl::(&#intmod_type::ZERO) && self.y.eq_impl::(&#intmod_type::ZERO) + self.x.eq_impl::(&<#intmod_type as IntMod>::ZERO) && self.y.eq_impl::(&<#intmod_type as IntMod>::ZERO) } } } @@ -373,7 +376,7 @@ pub fn sw_declare(input: TokenStream) -> TokenStream { } mod #group_ops_mod_name { - use ::openvm_ecc_guest::{weierstrass::{WeierstrassPoint, FromCompressed}, impl_sw_group_ops, algebra::IntMod}; + use ::openvm_ecc_guest::{weierstrass::{WeierstrassPoint}, FromCompressed, impl_sw_group_ops, algebra::IntMod}; use super::*; impl_sw_group_ops!(#struct_name, #intmod_type); @@ -410,23 +413,14 @@ pub fn sw_declare(input: TokenStream) -> TokenStream { } struct SwDefine { - items: Vec, + items: Vec, } impl Parse for SwDefine { fn parse(input: ParseStream) -> syn::Result { - let items = input.parse_terminated(::parse, Token![,])?; + let items = input.parse_terminated(::parse, Token![,])?; Ok(Self { - items: items - .into_iter() - .map(|e| { - if let Expr::Path(p) = e { - p.path - } else { - panic!("expected path"); - } - }) - .collect(), + items: items.into_iter().map(|e| e.value()).collect(), }) } } @@ -439,25 +433,21 @@ pub fn sw_init(input: TokenStream) -> TokenStream { let span = proc_macro::Span::call_site(); - for (ec_idx, item) in items.into_iter().enumerate() { - let str_path = item - .segments - .iter() - .map(|x| x.ident.to_string()) - .collect::>() - .join("_"); + for (ec_idx, struct_id) in items.into_iter().enumerate() { + // Unique identifier shared by sw_define! and sw_init! used for naming the extern funcs. + // Currently it's just the struct type name. let add_ne_extern_func = - syn::Ident::new(&format!("sw_add_ne_extern_func_{}", str_path), span.into()); + syn::Ident::new(&format!("sw_add_ne_extern_func_{}", struct_id), span.into()); let double_extern_func = - syn::Ident::new(&format!("sw_double_extern_func_{}", str_path), span.into()); + syn::Ident::new(&format!("sw_double_extern_func_{}", struct_id), span.into()); let setup_extern_func = - syn::Ident::new(&format!("sw_setup_extern_func_{}", str_path), span.into()); + syn::Ident::new(&format!("sw_setup_extern_func_{}", struct_id), span.into()); externs.push(quote::quote_spanned! { span.into() => #[no_mangle] extern "C" fn #add_ne_extern_func(rd: usize, rs1: usize, rs2: usize) { openvm::platform::custom_insn_r!( - opcode = OPCODE, + opcode = SW_OPCODE, funct3 = SW_FUNCT3 as usize, funct7 = SwBaseFunct7::SwAddNe as usize + #ec_idx * (SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize), @@ -470,7 +460,7 @@ pub fn sw_init(input: TokenStream) -> TokenStream { #[no_mangle] extern "C" fn #double_extern_func(rd: usize, rs1: usize) { openvm::platform::custom_insn_r!( - opcode = OPCODE, + opcode = SW_OPCODE, funct3 = SW_FUNCT3 as usize, funct7 = SwBaseFunct7::SwDouble as usize + #ec_idx * (SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize), @@ -481,41 +471,31 @@ pub fn sw_init(input: TokenStream) -> TokenStream { } #[no_mangle] - extern "C" fn #setup_extern_func() { + extern "C" fn #setup_extern_func(uninit: *mut core::ffi::c_void, p1: *const u8, p2: *const u8) { #[cfg(target_os = "zkvm")] { - use super::#item; - // p1 is (x1, y1), and x1 must be the modulus. - // y1 can be anything for SetupEcAdd, but must equal `a` for SetupEcDouble - let modulus_bytes = <<#item as openvm_ecc_guest::weierstrass::WeierstrassPoint>::Coordinate as openvm_algebra_guest::IntMod>::MODULUS; - let mut one = [0u8; <<#item as openvm_ecc_guest::weierstrass::WeierstrassPoint>::Coordinate as openvm_algebra_guest::IntMod>::NUM_LIMBS]; - one[0] = 1; - let curve_a_bytes = openvm_algebra_guest::IntMod::as_le_bytes(&<#item as openvm_ecc_guest::weierstrass::WeierstrassPoint>::CURVE_A); - // p1 should be (p, a) - let p1 = [modulus_bytes.as_ref(), curve_a_bytes.as_ref()].concat(); - // (EcAdd only) p2 is (x2, y2), and x1 - x2 has to be non-zero to avoid division over zero in add. - let p2 = [one.as_ref(), one.as_ref()].concat(); - let mut uninit: core::mem::MaybeUninit<[#item; 2]> = core::mem::MaybeUninit::uninit(); openvm::platform::custom_insn_r!( - opcode = ::openvm_ecc_guest::OPCODE, + opcode = ::openvm_ecc_guest::SW_OPCODE, funct3 = ::openvm_ecc_guest::SW_FUNCT3 as usize, funct7 = ::openvm_ecc_guest::SwBaseFunct7::SwSetup as usize + #ec_idx * (::openvm_ecc_guest::SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize), - rd = In uninit.as_mut_ptr(), - rs1 = In p1.as_ptr(), - rs2 = In p2.as_ptr() + rd = In uninit, + rs1 = In p1, + rs2 = In p2 ); openvm::platform::custom_insn_r!( - opcode = ::openvm_ecc_guest::OPCODE, + opcode = ::openvm_ecc_guest::SW_OPCODE, funct3 = ::openvm_ecc_guest::SW_FUNCT3 as usize, funct7 = ::openvm_ecc_guest::SwBaseFunct7::SwSetup as usize + #ec_idx * (::openvm_ecc_guest::SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize), - rd = In uninit.as_mut_ptr(), - rs1 = In p1.as_ptr(), + rd = In uninit, + rs1 = In p1, rs2 = Const "x0" // will be parsed as 0 and therefore transpiled to SETUP_EC_DOUBLE ); + + } } }); @@ -524,8 +504,8 @@ pub fn sw_init(input: TokenStream) -> TokenStream { TokenStream::from(quote::quote_spanned! { span.into() => #[allow(non_snake_case)] #[cfg(target_os = "zkvm")] - mod openvm_intrinsics_ffi_2 { - use ::openvm_ecc_guest::{OPCODE, SW_FUNCT3, SwBaseFunct7}; + mod openvm_intrinsics_ffi_2_sw { + use ::openvm_ecc_guest::{SW_OPCODE, SW_FUNCT3, SwBaseFunct7}; #(#externs)* } diff --git a/extensions/ecc/te-macros/Cargo.toml b/extensions/ecc/te-macros/Cargo.toml new file mode 100644 index 0000000000..de3544ff87 --- /dev/null +++ b/extensions/ecc/te-macros/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "openvm-ecc-te-macros" +version.workspace = true +authors.workspace = true +edition.workspace = true +homepage.workspace = true +repository.workspace = true + +[dependencies] +syn = { version = "2.0", features = ["full"] } +quote = "1.0" +openvm-macros-common = { workspace = true, default-features = false } + +[lib] +proc-macro = true diff --git a/extensions/ecc/te-macros/README.md b/extensions/ecc/te-macros/README.md new file mode 100644 index 0000000000..6de5c50110 --- /dev/null +++ b/extensions/ecc/te-macros/README.md @@ -0,0 +1,125 @@ +# `openvm-ecc-te-macros` + +Procedural macros for use in guest program to generate short twisted Edwards elliptic curve struct with custom intrinsics for compile-time modulus. + +The workflow of this macro is very similar to the [`openvm-algebra-moduli-macros`](../moduli-macros/README.md) crate. We recommend reading it first. + +## Example + +```rust +// ... + +moduli_declare! { + Ed25519Coord { modulus = "0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFED" }, + Ed25519Scalar { modulus = "0x1000000000000000000000000000000014DEF9DEA2F79CD65812631A5CF5D3ED" }, +} + +// Note that from_const_bytes is little endian +pub const CURVE_A: Ed25519Coord = Ed25519Coord::from_const_bytes(hex!( + "ECFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF7F" +)); +pub const CURVE_D: Ed25519Coord = Ed25519Coord::from_const_bytes(hex!( + "A3785913CA4DEB75ABD841414D0A700098E879777940C78C73FE6F2BEE6C0352" +)); + +sw_declare! { + Ed25519Point { mod_type = Ed25519Coord, a = CURVE_A, d = CURVE_D }, +} + +openvm_algebra_guest::moduli_macros::moduli_init! { + "0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFED", + "0x1000000000000000000000000000000014DEF9DEA2F79CD65812631A5CF5D3ED", +} + +openvm_ecc_guest::te_macros::te_init! { + Ed25519Point, +} + +pub fn main() { + setup_all_moduli(); + setup_all_te_curves(); + // ... +} +``` + +## Full story + +Again, the principle is the same as in the [`openvm-algebra-moduli-macros`](../moduli-macros/README.md) crate. Here we emphasize the core differences. + +The crate provides two macros: `te_declare!` and `te_init!`. The signatures are: + +- `te_declare!` receives comma-separated list of moduli classes descriptions. Each description looks like `TeStruct { mod_type = ModulusName, a = a_expr, d = d_expr }`. Here `ModulusName` is the name of any struct that implements `trait IntMod` -- in particular, the ones created by `moduli_declare!` do. Parameters `a` and `d` correspond to the coefficients of the equation defining the curve. They **must be compile-time constants**. Both the parameters `a` and `d` are required. + +- `te_init!` receives comma-separated list of struct names. The struct name must exactly match the name in `te_declare!` -- type defs are not allowed (see point 5 below). + +What happens under the hood: + +1. `te_declare!` macro creates a struct with two field `x` and `y` of type `mod_type`. This struct denotes a point on the corresponding elliptic curve. In the example it would be + +```rust +struct Ed25519Point { + x: Ed25519Coord, + y: Ed25519Coord, +} +``` + +Similar to `moduli_declare!`, this macro also creates extern functions for arithmetic operations -- but in this case they are named after the te type, not after any hexadecimal (since the macro has no way to obtain it from the name of the modulus type anyway): + +```rust +extern "C" { + fn te_add_extern_func_Ed25519Point(rd: usize, rs1: usize, rs2: usize); + fn hint_decompress_extern_func_Ed25519Point(rs1: usize, rs2: usize); +} +``` + +2. Again, `te_init!` macro implements these extern functions and defines the setup functions for the te struct. + +```rust +#[cfg(target_os = "zkvm")] +mod openvm_intrinsics_ffi_2 { + use :openvm_ecc_guest::{OPCODE, TE_FUNCT3, TeBaseFunct7}; + + #[no_mangle] + extern "C" fn te_add_extern_func_Ed25519Point(rd: usize, rs1: usize, rs2: usize) { + // ... + } + // other externs +} +#[allow(non_snake_case)] +pub fn setup_te_Ed25519Point() { + #[cfg(target_os = "zkvm")] + { + // ... + } +} +pub fn setup_all_te_curves() { + setup_te_Ed25519Point(); + // other setups +} +``` + +3. Again, if using the Rust bindings, then the `te_setup_extern_func_*` function for every curve is automatically called on first use of any of the curve's intrinsics. + +4. The order of the items in `te_init!` **must match** the order of the moduli in the chip configuration -- more specifically, in the modular extension parameters (the order of `CurveConfig`s in `EccExtension::supported_te_curves`, which is usually defined with the whole `app_vm_config` in the `openvm.toml` file). + +5. Note that, due to the nature of function names, the name of the struct used in `te_init!` must be the same as in `te_declare!`. To illustrate, the following code will **fail** to compile: + +```rust +// ... + +te_declare! { + Ed25519Point { mod_type = Ed25519Coord, a = CURVE_A, d = CURVE_D }, +} + +pub type Te = Ed25519Point; + +te_init! { + Te, +} +``` + +The reason is that, for example, the function `sw_add_extern_func_Secp256k1Point` remains unimplemented, but we implement `sw_add_extern_func_Sw`. + +6. `cargo openvm build` will automatically generate a call to `te_init!` based on `openvm.toml`. +Note that `openvm.toml` must contain the name of each struct created by `te_declare!` as a string (in the example at the top of this document, its `"Ed25519Point"`). +The SDK also supports this feature. diff --git a/extensions/ecc/te-macros/src/lib.rs b/extensions/ecc/te-macros/src/lib.rs new file mode 100644 index 0000000000..6f1b3af29f --- /dev/null +++ b/extensions/ecc/te-macros/src/lib.rs @@ -0,0 +1,345 @@ +extern crate proc_macro; + +use openvm_macros_common::MacroArgs; +use proc_macro::TokenStream; +use quote::format_ident; +use syn::{ + parse::{Parse, ParseStream}, + parse_macro_input, ExprPath, LitStr, Token, +}; + +/// This macro generates the code to setup a Twisted Edwards elliptic curve for a given modular +/// type. Also it places the curve parameters into a special static variable to be later extracted +/// from the ELF and used by the VM. Usage: +/// ``` +/// te_declare! { +/// [TODO] +/// } +/// ``` +/// +/// For this macro to work, you must import the `elliptic_curve` crate and the `openvm_ecc_guest` +/// crate.. +#[proc_macro] +pub fn te_declare(input: TokenStream) -> TokenStream { + let MacroArgs { items } = parse_macro_input!(input as MacroArgs); + + let mut output = Vec::new(); + + let span = proc_macro::Span::call_site(); + + for item in items.into_iter() { + let struct_name = item.name.to_string(); + let struct_name = syn::Ident::new(&struct_name, span.into()); + let struct_path: syn::Path = syn::parse_quote!(#struct_name); + let mut intmod_type: Option = None; + let mut const_a: Option = None; + let mut const_d: Option = None; + for param in item.params { + match param.name.to_string().as_str() { + "mod_type" => { + if let syn::Expr::Path(ExprPath { path, .. }) = param.value { + intmod_type = Some(path) + } else { + return syn::Error::new_spanned(param.value, "Expected a type") + .to_compile_error() + .into(); + } + } + "a" => { + const_a = Some(param.value); + } + "d" => { + const_d = Some(param.value); + } + _ => { + panic!("Unknown parameter {}", param.name); + } + } + } + + let intmod_type = intmod_type.expect("mod_type parameter is required"); + let const_a = const_a.expect("constant a coefficient is required"); + let const_d = const_d.expect("constant d coefficient is required"); + + macro_rules! create_extern_func { + ($name:ident) => { + let $name = syn::Ident::new( + &format!( + "{}_{}", + stringify!($name), + struct_path + .segments + .iter() + .map(|x| x.ident.to_string()) + .collect::>() + .join("_") + ), + span.into(), + ); + }; + } + create_extern_func!(te_add_extern_func); + create_extern_func!(te_setup_extern_func); + + let group_ops_mod_name = format_ident!("{}_ops", struct_name.to_string().to_lowercase()); + + let result = TokenStream::from(quote::quote_spanned! { span.into() => + extern "C" { + fn #te_add_extern_func(rd: usize, rs1: usize, rs2: usize); + fn #te_setup_extern_func(uninit: *mut core::ffi::c_void, p1: *const u8, p2: *const u8); + } + + #[derive(Eq, PartialEq, Clone, Debug, serde::Serialize, serde::Deserialize)] + #[repr(C)] + pub struct #struct_name { + x: #intmod_type, + y: #intmod_type, + } + + impl #struct_name { + const fn identity() -> Self { + Self { + x: <#intmod_type as openvm_algebra_guest::IntMod>::ZERO, + y: <#intmod_type as openvm_algebra_guest::IntMod>::ONE, + } + } + // Below are wrapper functions for the intrinsic instructions. + // Should not be called directly. + #[inline(always)] + fn add_chip(p1: &#struct_name, p2: &#struct_name) -> #struct_name { + #[cfg(not(target_os = "zkvm"))] + { + use openvm_algebra_guest::DivUnsafe; + + let x1y2 = p1.x.clone() * p2.y.clone(); + let y1x2 = p1.y.clone() * p2.x.clone(); + let x1x2 = p1.x.clone() * p2.x.clone(); + let y1y2 = p1.y.clone() * p2.y.clone(); + let dx1x2y1y2 = ::CURVE_D * &x1x2 * &y1y2; + + let x3 = (x1y2 + y1x2).div_unsafe(&<#intmod_type as openvm_algebra_guest::IntMod>::ONE + &dx1x2y1y2); + let y3 = (y1y2 - ::CURVE_A * x1x2).div_unsafe(&<#intmod_type as openvm_algebra_guest::IntMod>::ONE - &dx1x2y1y2); + + #struct_name { x: x3, y: y3 } + } + #[cfg(target_os = "zkvm")] + { + Self::set_up_once(); + let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit(); + unsafe { + #te_add_extern_func( + uninit.as_mut_ptr() as usize, + p1 as *const #struct_name as usize, + p2 as *const #struct_name as usize + ) + }; + unsafe { uninit.assume_init() } + } + } + + // Helper function to call the setup instruction on first use + #[cfg(target_os = "zkvm")] + #[inline(always)] + fn set_up_once() { + static is_setup: ::openvm_ecc_guest::once_cell::race::OnceBool = ::openvm_ecc_guest::once_cell::race::OnceBool::new(); + is_setup.get_or_init(|| { + let modulus_bytes = <::Coordinate as openvm_algebra_guest::IntMod>::MODULUS; + let mut zero = [0u8; <::Coordinate as openvm_algebra_guest::IntMod>::NUM_LIMBS]; + let curve_a_bytes = openvm_algebra_guest::IntMod::as_le_bytes(&::CURVE_A); + let curve_d_bytes = openvm_algebra_guest::IntMod::as_le_bytes(&::CURVE_D); + let p1 = [modulus_bytes.as_ref(), curve_a_bytes.as_ref()].concat(); + let p2 = [curve_d_bytes.as_ref(), zero.as_ref()].concat(); + let mut uninit: core::mem::MaybeUninit<[Self; 2]> = core::mem::MaybeUninit::uninit(); + + unsafe { #te_setup_extern_func(uninit.as_mut_ptr() as *mut core::ffi::c_void, p1.as_ptr(), p2.as_ptr()); } + <#intmod_type as openvm_algebra_guest::IntMod>::set_up_once(); + true + }); + } + + #[cfg(not(target_os = "zkvm"))] + #[inline(always)] + fn set_up_once() { + // No-op for non-ZKVM targets + } + } + + impl ::openvm_ecc_guest::edwards::TwistedEdwardsPoint for #struct_name { + const CURVE_A: Self::Coordinate = #const_a; + const CURVE_D: Self::Coordinate = #const_d; + + const IDENTITY: Self = Self::identity(); + type Coordinate = #intmod_type; + + /// SAFETY: assumes that #intmod_type has a memory representation + /// such that with repr(C), two coordinates are packed contiguously. + #[inline(always)] + fn as_le_bytes(&self) -> &[u8] { + unsafe { &*core::ptr::slice_from_raw_parts(self as *const Self as *const u8, <#intmod_type as openvm_algebra_guest::IntMod>::NUM_LIMBS * 2) } + } + + #[inline(always)] + fn from_xy_unchecked(x: Self::Coordinate, y: Self::Coordinate) -> Self { + Self { x, y } + } + + #[inline(always)] + fn x(&self) -> &Self::Coordinate { + &self.x + } + + #[inline(always)] + fn y(&self) -> &Self::Coordinate { + &self.y + } + + #[inline(always)] + fn x_mut(&mut self) -> &mut Self::Coordinate { + &mut self.x + } + + #[inline(always)] + fn y_mut(&mut self) -> &mut Self::Coordinate { + &mut self.y + } + + #[inline(always)] + fn into_coords(self) -> (Self::Coordinate, Self::Coordinate) { + (self.x, self.y) + } + + #[inline(always)] + fn add_impl(&self, p2: &Self) -> Self { + Self::add_chip(self, p2) + } + } + + impl core::ops::Neg for #struct_name { + type Output = Self; + + fn neg(self) -> Self::Output { + #struct_name { + x: core::ops::Neg::neg(&self.x), + y: self.y, + } + } + } + + impl core::ops::Neg for &#struct_name { + type Output = #struct_name; + + fn neg(self) -> #struct_name { + #struct_name { + x: core::ops::Neg::neg(&self.x), + y: self.y.clone(), + } + } + } + + mod #group_ops_mod_name { + use ::openvm_ecc_guest::{edwards::TwistedEdwardsPoint, FromCompressed, impl_te_group_ops, algebra::{IntMod, DivUnsafe, DivAssignUnsafe, ExpBytes}}; + use super::*; + + impl_te_group_ops!(#struct_name, #intmod_type); + + impl FromCompressed<#intmod_type> for #struct_name { + fn decompress(y: #intmod_type, rec_id: &u8) -> Option { + use openvm_algebra_guest::{Sqrt, DivUnsafe}; + let x_squared = (<#intmod_type as openvm_algebra_guest::IntMod>::ONE - &y * &y).div_unsafe(<#struct_name as ::openvm_ecc_guest::edwards::TwistedEdwardsPoint>::CURVE_A - &<#struct_name as ::openvm_ecc_guest::edwards::TwistedEdwardsPoint>::CURVE_D * &y * &y); + let x = x_squared.sqrt(); + match x { + None => None, + Some(x) => { + let correct_x = if x.as_le_bytes()[0] & 1 == *rec_id & 1 { + x + } else { + -x + }; + // handle the case where x = 0 + if correct_x.as_le_bytes()[0] & 1 != *rec_id & 1 { + return None; + } + // In order for sqrt() to return Some, we are guaranteed that x * x == x_squared, which already proves (correct_x, y) is on the curve + Some(<#struct_name as ::openvm_ecc_guest::edwards::TwistedEdwardsPoint>::from_xy_unchecked(correct_x, y)) + } + } + } + } + } + }); + output.push(result); + } + + TokenStream::from_iter(output) +} + +struct TeDefine { + items: Vec, +} + +impl Parse for TeDefine { + fn parse(input: ParseStream) -> syn::Result { + let items = input.parse_terminated(::parse, Token![,])?; + Ok(Self { + items: items.into_iter().map(|e| e.value()).collect(), + }) + } +} + +#[proc_macro] +pub fn te_init(input: TokenStream) -> TokenStream { + let TeDefine { items } = parse_macro_input!(input as TeDefine); + + let mut externs = Vec::new(); + + let span = proc_macro::Span::call_site(); + + for (ec_idx, struct_id) in items.into_iter().enumerate() { + let add_extern_func = + syn::Ident::new(&format!("te_add_extern_func_{}", struct_id), span.into()); + let setup_extern_func = + syn::Ident::new(&format!("te_setup_extern_func_{}", struct_id), span.into()); + externs.push(quote::quote_spanned! { span.into() => + #[no_mangle] + extern "C" fn #add_extern_func(rd: usize, rs1: usize, rs2: usize) { + openvm::platform::custom_insn_r!( + opcode = TE_OPCODE, + funct3 = TE_FUNCT3 as usize, + funct7 = TeBaseFunct7::TeAdd as usize + #ec_idx + * (TeBaseFunct7::TWISTED_EDWARDS_MAX_KINDS as usize), + rd = In rd, + rs1 = In rs1, + rs2 = In rs2 + ); + } + + #[no_mangle] + extern "C" fn #setup_extern_func(uninit: *mut core::ffi::c_void, p1: *const u8, p2: *const u8) { + #[cfg(target_os = "zkvm")] + { + + openvm::platform::custom_insn_r!( + opcode = ::openvm_ecc_guest::TE_OPCODE, + funct3 = ::openvm_ecc_guest::TE_FUNCT3 as usize, + funct7 = ::openvm_ecc_guest::TeBaseFunct7::TeSetup as usize + + #ec_idx + * (::openvm_ecc_guest::TeBaseFunct7::TWISTED_EDWARDS_MAX_KINDS as usize), + rd = In uninit, + rs1 = In p1, + rs2 = In p2, + ); + } + } + }); + } + + TokenStream::from(quote::quote_spanned! { span.into() => + #[allow(non_snake_case)] + #[cfg(target_os = "zkvm")] + mod openvm_intrinsics_ffi_2_te { + use ::openvm_ecc_guest::{TE_OPCODE, TE_FUNCT3, TeBaseFunct7}; + + #(#externs)* + } + }) +} diff --git a/extensions/ecc/tests/Cargo.toml b/extensions/ecc/tests/Cargo.toml index 5f90e77fa4..0de9c0249e 100644 --- a/extensions/ecc/tests/Cargo.toml +++ b/extensions/ecc/tests/Cargo.toml @@ -21,8 +21,8 @@ serde.workspace = true serde_with.workspace = true toml.workspace = true eyre.workspace = true -hex-literal.workspace = true num-bigint.workspace = true +hex-literal.workspace = true halo2curves-axiom = { workspace = true } [features] diff --git a/extensions/ecc/tests/programs/Cargo.toml b/extensions/ecc/tests/programs/Cargo.toml index 55fcedd3ee..065dc404f2 100644 --- a/extensions/ecc/tests/programs/Cargo.toml +++ b/extensions/ecc/tests/programs/Cargo.toml @@ -11,6 +11,7 @@ openvm-custom-insn = { path = "../../../../crates/toolchain/custom_insn", defaul openvm-ecc-guest = { path = "../../guest", default-features = false } openvm-ecc-sw-macros = { path = "../../../../extensions/ecc/sw-macros", default-features = false } +openvm-ecc-te-macros = { path = "../../../../extensions/ecc/te-macros", default-features = false } openvm-algebra-guest = { path = "../../../algebra/guest", default-features = false } openvm-algebra-moduli-macros = { path = "../../../algebra/moduli-macros", default-features = false } openvm-rv32im-guest = { path = "../../../../extensions/rv32im/guest", default-features = false } @@ -43,6 +44,7 @@ default = [] std = ["serde/std", "openvm/std"] k256 = ["dep:openvm-k256"] p256 = ["dep:openvm-p256"] +ed25519 = ["openvm-ecc-guest/ed25519"] [profile.release] panic = "abort" @@ -63,7 +65,7 @@ required-features = ["k256", "p256"] [[example]] name = "decompress" -required-features = ["k256"] +required-features = ["k256", "ed25519"] [[example]] name = "ecdsa" @@ -81,6 +83,10 @@ required-features = ["k256"] name = "sec1_decode" required-features = ["k256"] +[[example]] +name = "edwards_ec" +required-features = ["ed25519"] + [[example]] name = "invalid_setup" required-features = ["k256", "p256"] diff --git a/extensions/ecc/tests/programs/examples/decompress.rs b/extensions/ecc/tests/programs/examples/decompress.rs index 0148d5d057..f6e9870a3e 100644 --- a/extensions/ecc/tests/programs/examples/decompress.rs +++ b/extensions/ecc/tests/programs/examples/decompress.rs @@ -7,9 +7,11 @@ extern crate alloc; use hex_literal::hex; use openvm::io::read_vec; use openvm_ecc_guest::{ - algebra::IntMod, - weierstrass::{FromCompressed, WeierstrassPoint}, - Group, + algebra::{Field, IntMod}, + ed25519::{Ed25519Coord, Ed25519Point}, + edwards::TwistedEdwardsPoint, + weierstrass::WeierstrassPoint, + FromCompressed, Group, }; use openvm_k256::{Secp256k1Coord, Secp256k1Point}; @@ -22,7 +24,6 @@ openvm_algebra_moduli_macros::moduli_declare! { Fp1mod4 { modulus = "0xffffffffffffffffffffffffffffffff000000000000000000000001" }, } -// const CURVE_B_5MOD8: Fp5mod8 = Fp5mod8::from_const_u8(3); const CURVE_B_5MOD8: Fp5mod8 = Fp5mod8::from_const_u8(6); const CURVE_A_1MOD4: Fp1mod4 = Fp1mod4::from_const_bytes(hex!( @@ -44,7 +45,7 @@ openvm_ecc_sw_macros::sw_declare! { }, } -openvm::init!("openvm_init_decompress_k256.rs"); +openvm::init!("openvm_init_decompress_k256_ed25519.rs"); // test decompression under an honest host pub fn main() { @@ -53,35 +54,43 @@ pub fn main() { let y = Secp256k1Coord::from_le_bytes_unchecked(&bytes[32..64]); let rec_id = y.as_le_bytes()[0] & 1; - test_possible_decompression::(&x, &y, rec_id); + test_possible_sw_decompression::(&x, &y, rec_id); // x = 5 is not on the x-coordinate of any point on the Secp256k1 curve - test_impossible_decompression::(&Secp256k1Coord::from_u8(5), rec_id); + test_impossible_sw_decompression::(&Secp256k1Coord::from_u8(5), rec_id); let x = Fp5mod8::from_le_bytes_unchecked(&bytes[64..96]); let y = Fp5mod8::from_le_bytes_unchecked(&bytes[96..128]); let rec_id = y.as_le_bytes()[0] & 1; - test_possible_decompression::(&x, &y, rec_id); + test_possible_sw_decompression::(&x, &y, rec_id); // x = 0 is not on the x-coordinate of any point on the CurvePoint5mod8 curve - test_impossible_decompression::(&Fp5mod8::ZERO, rec_id); + test_impossible_sw_decompression::(&::ZERO, rec_id); // this x is such that y^2 = x^3 + 6 = 0 // we want to test the case where y^2 = 0 and rec_id = 1 let x = Fp5mod8::from_le_bytes_unchecked(&hex!( "d634a701c3b9b8cbf7797988be3953b442863b74d2d5c4d5f1a9de3c0c256d90" )); - test_possible_decompression::(&x, &Fp5mod8::ZERO, 0); - test_impossible_decompression::(&x, 1); + test_possible_sw_decompression::(&x, &::ZERO, 0); + test_impossible_sw_decompression::(&x, 1); let x = Fp1mod4::from_le_bytes_unchecked(&bytes[128..160]); let y = Fp1mod4::from_le_bytes_unchecked(&bytes[160..192]); let rec_id = y.as_le_bytes()[0] & 1; - test_possible_decompression::(&x, &y, rec_id); + test_possible_sw_decompression::(&x, &y, rec_id); // x = 1 is not on the x-coordinate of any point on the CurvePoint1mod4 curve - test_impossible_decompression::(&Fp1mod4::from_u8(1), rec_id); + test_impossible_sw_decompression::(&Fp1mod4::from_u8(1), rec_id); + + // ed25519 + let x = Ed25519Coord::from_le_bytes_unchecked(&bytes[192..224]); + let y = Ed25519Coord::from_le_bytes_unchecked(&bytes[224..256]); + let rec_id = x.as_le_bytes()[0] & 1; + test_possible_te_decompression::(&x, &y, rec_id); + // y = 2 is not on the y-coordinate of any point on the Ed25519 curve + test_impossible_te_decompression::(&Ed25519Coord::from_u8(2), rec_id); } -fn test_possible_decompression>( +fn test_possible_sw_decompression>( x: &P::Coordinate, y: &P::Coordinate, rec_id: u8, @@ -91,7 +100,25 @@ fn test_possible_decompression>( +fn test_possible_te_decompression>( + x: &P::Coordinate, + y: &P::Coordinate, + rec_id: u8, +) { + let p = P::decompress(y.clone(), &rec_id).unwrap(); + assert_eq!(p.x(), x); + assert_eq!(p.y(), y); +} + +fn test_impossible_sw_decompression>( + x: &P::Coordinate, + rec_id: u8, +) { + let p = P::decompress(x.clone(), &rec_id); + assert!(p.is_none()); +} + +fn test_impossible_te_decompression>( x: &P::Coordinate, rec_id: u8, ) { diff --git a/extensions/ecc/tests/programs/examples/ec.rs b/extensions/ecc/tests/programs/examples/ec.rs index 1b63057c30..71c1194463 100644 --- a/extensions/ecc/tests/programs/examples/ec.rs +++ b/extensions/ecc/tests/programs/examples/ec.rs @@ -2,8 +2,7 @@ #![cfg_attr(not(feature = "std"), no_std)] use hex_literal::hex; -use openvm_algebra_guest::IntMod; -use openvm_ecc_guest::{msm, weierstrass::WeierstrassPoint, Group}; +use openvm_ecc_guest::{algebra::IntMod, msm, weierstrass::WeierstrassPoint, Group}; use openvm_k256::{Secp256k1Coord, Secp256k1Point, Secp256k1Scalar}; openvm::init!("openvm_init_ec_k256.rs"); diff --git a/extensions/ecc/tests/programs/examples/ec_nonzero_a.rs b/extensions/ecc/tests/programs/examples/ec_nonzero_a.rs index 41db1ececc..854641c4bf 100644 --- a/extensions/ecc/tests/programs/examples/ec_nonzero_a.rs +++ b/extensions/ecc/tests/programs/examples/ec_nonzero_a.rs @@ -2,8 +2,7 @@ #![cfg_attr(not(feature = "std"), no_std)] use hex_literal::hex; -use openvm_algebra_guest::IntMod; -use openvm_ecc_guest::{weierstrass::WeierstrassPoint, CyclicGroup, Group}; +use openvm_ecc_guest::{algebra::IntMod, weierstrass::WeierstrassPoint, CyclicGroup, Group}; use openvm_p256::{P256Coord, P256Point}; openvm::entry!(main); diff --git a/extensions/ecc/tests/programs/examples/ec_two_curves.rs b/extensions/ecc/tests/programs/examples/ec_two_curves.rs index 6412e3184f..681f1c9fe4 100644 --- a/extensions/ecc/tests/programs/examples/ec_two_curves.rs +++ b/extensions/ecc/tests/programs/examples/ec_two_curves.rs @@ -2,8 +2,7 @@ #![cfg_attr(not(feature = "std"), no_std)] use hex_literal::hex; -use openvm_algebra_guest::IntMod; -use openvm_ecc_guest::{msm, weierstrass::WeierstrassPoint, Group}; +use openvm_ecc_guest::{algebra::IntMod, msm, weierstrass::WeierstrassPoint, Group}; use openvm_k256::{Secp256k1Coord, Secp256k1Point, Secp256k1Scalar}; use openvm_p256::{P256Coord, P256Point}; diff --git a/extensions/ecc/tests/programs/examples/edwards_ec.rs b/extensions/ecc/tests/programs/examples/edwards_ec.rs new file mode 100644 index 0000000000..c3236cd428 --- /dev/null +++ b/extensions/ecc/tests/programs/examples/edwards_ec.rs @@ -0,0 +1,64 @@ +#![cfg_attr(not(feature = "std"), no_main)] +#![cfg_attr(not(feature = "std"), no_std)] + +use hex_literal::hex; +use openvm_algebra_guest::moduli_macros::moduli_init; +use openvm_ecc_guest::{ + algebra::IntMod, + ed25519::{Ed25519Coord, Ed25519Point}, + edwards::TwistedEdwardsPoint, + te_macros::te_init, + CyclicGroup, Group, +}; + +openvm::init!("openvm_init_edwards_ec_ed25519.rs"); + +openvm::entry!(main); + +pub fn main() { + // Base point of edwards25519 + let mut p1 = Ed25519Point::GENERATOR; + + // random point on edwards25519 + let x2 = Ed25519Coord::from_u32(2); + let y2 = Ed25519Coord::from_be_bytes_unchecked(&hex!( + "1A43BF127BDDC4D71FF910403C11DDB5BA2BCDD2815393924657EF111E712631" + )); + let mut p2 = Ed25519Point::from_xy(x2, y2).unwrap(); + + // This is the sum of (x1, y1) and (x2, y2). + let x3 = Ed25519Coord::from_be_bytes_unchecked(&hex!( + "636C0B519B2C5B1E0D3BFD213F45AFD5DAEE3CECC9B68CF88615101BC78329E6" + )); + let y3 = Ed25519Coord::from_be_bytes_unchecked(&hex!( + "704D8868CB335A7B609D04B9CD619511675691A78861F1DFF7A5EBC389C7EA92" + )); + + // This is 2 * (x1, y1) + let x4 = Ed25519Coord::from_be_bytes_unchecked(&hex!( + "56B98CC045559AD2BBC45CAB58D842ECEE264DB9395F6014B772501B62BB7EE8" + )); + let y4 = Ed25519Coord::from_be_bytes_unchecked(&hex!( + "1BCA918096D89C83A15105DF343DC9F7510494407750226DAC0A7620ACE77BEB" + )); + + // Generic add can handle equal or unequal points. + let p3 = &p1 + &p2; + if p3.x() != &x3 || p3.y() != &y3 { + panic!(); + } + let p4 = &p2 + &p2; + if p4.x() != &x4 || p4.y() != &y4 { + panic!(); + } + + // Add assign and double assign + p1 += &p2; + if p1.x() != &x3 || p1.y() != &y3 { + panic!(); + } + p2.double_assign(); + if p2.x() != &x4 || p2.y() != &y4 { + panic!(); + } +} diff --git a/extensions/ecc/tests/programs/examples/invalid_setup.rs b/extensions/ecc/tests/programs/examples/invalid_setup.rs index 6ee2ba4f58..f32c553b68 100644 --- a/extensions/ecc/tests/programs/examples/invalid_setup.rs +++ b/extensions/ecc/tests/programs/examples/invalid_setup.rs @@ -14,8 +14,8 @@ openvm_algebra_moduli_macros::moduli_init! { // the order of the curves here does not match the order in supported_curves openvm_ecc_sw_macros::sw_init! { - P256Point, - Secp256k1Point, + "P256Point", + "Secp256k1Point", } openvm::entry!(main); diff --git a/extensions/ecc/tests/programs/openvm_ed25519.toml b/extensions/ecc/tests/programs/openvm_ed25519.toml new file mode 100644 index 0000000000..58701a8203 --- /dev/null +++ b/extensions/ecc/tests/programs/openvm_ed25519.toml @@ -0,0 +1,15 @@ +[app_vm_config.rv32i] +[app_vm_config.rv32m] +[app_vm_config.io] + +[app_vm_config.modular] +supported_moduli = ["57896044618658097711785492504343953926634992332820282019728792003956564819949", "7237005577332262213973186563042994240857116359379907606001950938285454250989"] + +[[app_vm_config.ecc.supported_te_curves]] +struct_name = "Ed25519Point" +modulus = "57896044618658097711785492504343953926634992332820282019728792003956564819949" +scalar = "7237005577332262213973186563042994240857116359379907606001950938285454250989" + +[app_vm_config.ecc.supported_te_curves.coeffs] +a = "57896044618658097711785492504343953926634992332820282019728792003956564819948" +d = "37095705934669439343138083508754565189542113879843219016388785533085940283555" diff --git a/extensions/ecc/tests/programs/openvm_init_decompress_k256.rs b/extensions/ecc/tests/programs/openvm_init_decompress_k256_ed25519.rs similarity index 61% rename from extensions/ecc/tests/programs/openvm_init_decompress_k256.rs rename to extensions/ecc/tests/programs/openvm_init_decompress_k256_ed25519.rs index b6137ae9ee..f1251f7039 100644 --- a/extensions/ecc/tests/programs/openvm_init_decompress_k256.rs +++ b/extensions/ecc/tests/programs/openvm_init_decompress_k256_ed25519.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. -openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337", "115792089237316195423570985008687907853269984665640564039457584007913129639501", "1000000007", "26959946667150639794667015087019630673557916260026308143510066298881", "26959946667150639794667015087019625940457807714424391721682722368061" } -openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point, CurvePoint5mod8, CurvePoint1mod4 } +openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337", "115792089237316195423570985008687907853269984665640564039457584007913129639501", "1000000007", "26959946667150639794667015087019630673557916260026308143510066298881", "26959946667150639794667015087019625940457807714424391721682722368061", "57896044618658097711785492504343953926634992332820282019728792003956564819949", "7237005577332262213973186563042994240857116359379907606001950938285454250989" } +openvm_ecc_guest::sw_macros::sw_init! { "Secp256k1Point", "CurvePoint5mod8", "CurvePoint1mod4" } +openvm_ecc_guest::te_macros::te_init! { "Ed25519Point" } diff --git a/extensions/ecc/tests/programs/openvm_init_ec_k256.rs b/extensions/ecc/tests/programs/openvm_init_ec_k256.rs index bec9f527e9..0905f21c53 100644 --- a/extensions/ecc/tests/programs/openvm_init_ec_k256.rs +++ b/extensions/ecc/tests/programs/openvm_init_ec_k256.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337" } -openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point } +openvm_ecc_guest::sw_macros::sw_init! { "Secp256k1Point" } +openvm_ecc_guest::te_macros::te_init! {} diff --git a/extensions/ecc/tests/programs/openvm_init_ec_nonzero_a_p256.rs b/extensions/ecc/tests/programs/openvm_init_ec_nonzero_a_p256.rs index 02f8b5c05d..fc26ca238d 100644 --- a/extensions/ecc/tests/programs/openvm_init_ec_nonzero_a_p256.rs +++ b/extensions/ecc/tests/programs/openvm_init_ec_nonzero_a_p256.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089210356248762697446949407573530086143415290314195533631308867097853951", "115792089210356248762697446949407573529996955224135760342422259061068512044369" } -openvm_ecc_guest::sw_macros::sw_init! { P256Point } +openvm_ecc_guest::sw_macros::sw_init! { "P256Point" } +openvm_ecc_guest::te_macros::te_init! {} diff --git a/extensions/ecc/tests/programs/openvm_init_ec_two_curves_k256_p256.rs b/extensions/ecc/tests/programs/openvm_init_ec_two_curves_k256_p256.rs index 8689190544..331836d7a1 100644 --- a/extensions/ecc/tests/programs/openvm_init_ec_two_curves_k256_p256.rs +++ b/extensions/ecc/tests/programs/openvm_init_ec_two_curves_k256_p256.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337", "115792089210356248762697446949407573530086143415290314195533631308867097853951", "115792089210356248762697446949407573529996955224135760342422259061068512044369" } -openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point, P256Point } +openvm_ecc_guest::sw_macros::sw_init! { "Secp256k1Point", "P256Point" } +openvm_ecc_guest::te_macros::te_init! {} diff --git a/extensions/ecc/tests/programs/openvm_init_ecdsa_k256.rs b/extensions/ecc/tests/programs/openvm_init_ecdsa_k256.rs index bec9f527e9..0905f21c53 100644 --- a/extensions/ecc/tests/programs/openvm_init_ecdsa_k256.rs +++ b/extensions/ecc/tests/programs/openvm_init_ecdsa_k256.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337" } -openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point } +openvm_ecc_guest::sw_macros::sw_init! { "Secp256k1Point" } +openvm_ecc_guest::te_macros::te_init! {} diff --git a/extensions/ecc/tests/programs/openvm_init_edwards_ec_ed25519.rs b/extensions/ecc/tests/programs/openvm_init_edwards_ec_ed25519.rs new file mode 100644 index 0000000000..1a21c687cc --- /dev/null +++ b/extensions/ecc/tests/programs/openvm_init_edwards_ec_ed25519.rs @@ -0,0 +1,4 @@ +// This file is automatically generated by cargo openvm. Do not rename or edit. +openvm_algebra_guest::moduli_macros::moduli_init! { "57896044618658097711785492504343953926634992332820282019728792003956564819949", "7237005577332262213973186563042994240857116359379907606001950938285454250989" } +openvm_ecc_guest::sw_macros::sw_init! { } +openvm_ecc_guest::te_macros::te_init! { "Ed25519Point" } diff --git a/extensions/ecc/tests/programs/openvm_k256.toml b/extensions/ecc/tests/programs/openvm_k256.toml index 571fdb895c..2fa80a5af3 100644 --- a/extensions/ecc/tests/programs/openvm_k256.toml +++ b/extensions/ecc/tests/programs/openvm_k256.toml @@ -8,9 +8,11 @@ supported_moduli = [ "115792089237316195423570985008687907852837564279074904382605163141518161494337", ] -[[app_vm_config.ecc.supported_curves]] +[[app_vm_config.ecc.supported_sw_curves]] struct_name = "Secp256k1Point" modulus = "115792089237316195423570985008687907853269984665640564039457584007908834671663" scalar = "115792089237316195423570985008687907852837564279074904382605163141518161494337" + +[app_vm_config.ecc.supported_sw_curves.coeffs] a = "0" b = "7" diff --git a/extensions/ecc/tests/programs/openvm_k256_keccak.toml b/extensions/ecc/tests/programs/openvm_k256_keccak.toml index c1261ee458..4dc77ccd80 100644 --- a/extensions/ecc/tests/programs/openvm_k256_keccak.toml +++ b/extensions/ecc/tests/programs/openvm_k256_keccak.toml @@ -9,9 +9,11 @@ supported_moduli = [ "115792089237316195423570985008687907852837564279074904382605163141518161494337", ] -[[app_vm_config.ecc.supported_curves]] +[[app_vm_config.ecc.supported_sw_curves]] struct_name = "Secp256k1Point" modulus = "115792089237316195423570985008687907853269984665640564039457584007908834671663" scalar = "115792089237316195423570985008687907852837564279074904382605163141518161494337" + +[app_vm_config.ecc.supported_sw_curves.coeffs] a = "0" b = "7" diff --git a/extensions/ecc/tests/programs/openvm_p256.toml b/extensions/ecc/tests/programs/openvm_p256.toml index 0035cd83da..2cc5bd92c3 100644 --- a/extensions/ecc/tests/programs/openvm_p256.toml +++ b/extensions/ecc/tests/programs/openvm_p256.toml @@ -7,9 +7,11 @@ supported_moduli = [ "115792089210356248762697446949407573529996955224135760342422259061068512044369", ] -[[app_vm_config.ecc.supported_curves]] +[[app_vm_config.ecc.supported_sw_curves]] struct_name = "P256Point" modulus = "115792089210356248762697446949407573530086143415290314195533631308867097853951" scalar = "115792089210356248762697446949407573529996955224135760342422259061068512044369" + +[app_vm_config.ecc.supported_sw_curves.coeffs] a = "115792089210356248762697446949407573530086143415290314195533631308867097853948" b = "41058363725152142129326129780047268409114441015993725554835256314039467401291" diff --git a/extensions/ecc/tests/src/lib.rs b/extensions/ecc/tests/src/lib.rs index 1bd01eb936..a29e4a049e 100644 --- a/extensions/ecc/tests/src/lib.rs +++ b/extensions/ecc/tests/src/lib.rs @@ -11,15 +11,18 @@ mod tests { use openvm_algebra_transpiler::ModularTranspilerExtension; use openvm_circuit::{ arch::instructions::exe::VmExe, - utils::{air_test, air_test_with_min_segments}, + utils::{air_test, air_test_with_min_segments, test_system_config_with_continuations}, + }; + use openvm_ecc_circuit::{ + CurveConfig, Rv32EccConfig, Rv32EccCpuBuilder, SwCurveCoeffs, TeCurveCoeffs, + ED25519_CONFIG, P256_CONFIG, SECP256K1_CONFIG, }; - use openvm_ecc_circuit::{CurveConfig, Rv32WeierstrassConfig, P256_CONFIG, SECP256K1_CONFIG}; use openvm_ecc_transpiler::EccTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; use openvm_sdk::{ - config::{AppConfig, SdkVmConfig}, + config::{AppConfig, SdkVmConfig, SdkVmCpuBuilder}, StdIn, }; use openvm_stark_backend::p3_field::FieldAlgebra; @@ -35,9 +38,19 @@ mod tests { type F = BabyBear; + #[cfg(test)] + fn test_rv32ecc_config( + sw_curves: Vec>, + te_curves: Vec>, + ) -> Rv32EccConfig { + let mut config = Rv32EccConfig::new(sw_curves, te_curves); + *config.as_mut() = test_system_config_with_continuations(); + config + } + #[test] fn test_ec() -> Result<()> { - let config = Rv32WeierstrassConfig::new(vec![SECP256K1_CONFIG.clone()]); + let config = test_rv32ecc_config(vec![SECP256K1_CONFIG.clone()], vec![]); let elf = build_example_program_at_path_with_features( get_programs_dir!(), "ec", @@ -53,13 +66,13 @@ mod tests { .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32EccCpuBuilder, config, openvm_exe); Ok(()) } #[test] - fn test_ec_nonzero_a() -> Result<()> { - let config = Rv32WeierstrassConfig::new(vec![P256_CONFIG.clone()]); + fn test_nonzero_a() -> Result<()> { + let config = test_rv32ecc_config(vec![P256_CONFIG.clone()], vec![]); let elf = build_example_program_at_path_with_features( get_programs_dir!(), "ec_nonzero_a", @@ -75,14 +88,14 @@ mod tests { .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32EccCpuBuilder, config, openvm_exe); Ok(()) } #[test] - fn test_ec_two_curves() -> Result<()> { + fn test_two_curves() -> Result<()> { let config = - Rv32WeierstrassConfig::new(vec![SECP256K1_CONFIG.clone(), P256_CONFIG.clone()]); + test_rv32ecc_config(vec![SECP256K1_CONFIG.clone(), P256_CONFIG.clone()], vec![]); let elf = build_example_program_at_path_with_features( get_programs_dir!(), "ec_two_curves", @@ -98,16 +111,15 @@ mod tests { .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32EccCpuBuilder, config, openvm_exe); Ok(()) } #[test] fn test_decompress() -> Result<()> { - use halo2curves_axiom::{group::Curve, secp256k1::Secp256k1Affine}; + use halo2curves_axiom::{ed25519::Ed25519Affine, group::Curve, secp256k1::Secp256k1Affine}; - let config = - Rv32WeierstrassConfig::new(vec![SECP256K1_CONFIG.clone(), + let config = test_rv32ecc_config(vec![SECP256K1_CONFIG.clone(), CurveConfig { struct_name: "CurvePoint5mod8".to_string(), modulus: BigUint::from_str("115792089237316195423570985008687907853269984665640564039457584007913129639501") @@ -115,8 +127,10 @@ mod tests { // unused, set to 10e9 + 7 scalar: BigUint::from_str("1000000007") .unwrap(), - a: BigUint::ZERO, - b: BigUint::from_str("6").unwrap(), + coeffs: SwCurveCoeffs { + a: BigUint::ZERO, + b: BigUint::from_str("6").unwrap(), + }, }, CurveConfig { struct_name: "CurvePoint1mod4".to_string(), @@ -124,19 +138,24 @@ mod tests { .unwrap(), scalar: BigUint::from_radix_be(&hex!("ffffffffffffffffffffffffffff16a2e0b8f03e13dd29455c5c2a3d"), 256) .unwrap(), - a: BigUint::from_radix_be(&hex!("fffffffffffffffffffffffffffffffefffffffffffffffffffffffe"), 256) - .unwrap(), - b: BigUint::from_radix_be(&hex!("b4050a850c04b3abf54132565044b0b7d7bfd8ba270b39432355ffb4"), 256) - .unwrap(), + coeffs: SwCurveCoeffs { + a: BigUint::from_radix_be(&hex!("fffffffffffffffffffffffffffffffefffffffffffffffffffffffe"), 256) + .unwrap(), + b: BigUint::from_radix_be(&hex!("b4050a850c04b3abf54132565044b0b7d7bfd8ba270b39432355ffb4"), 256) + .unwrap(), + }, }, - ]); + ], + vec![ED25519_CONFIG.clone()], + ); let elf = build_example_program_at_path_with_features( get_programs_dir!(), "decompress", - ["k256"], + ["k256", "ed25519"], &config, )?; + let openvm_exe = VmExe::from_elf( elf, Transpiler::::default() @@ -149,8 +168,7 @@ mod tests { let p = Secp256k1Affine::generator(); let p = (p + p + p).to_affine(); - println!("decompressed: {:?}", p); - + println!("secp256k1 decompressed: {:?}", p); let q_x: [u8; 32] = hex!("0100000000000000000000000000000000000000000000000000000000000000"); let q_y: [u8; 32] = @@ -159,13 +177,25 @@ mod tests { hex!("211D5C11D68032342211C256D3C1034AB99013327FBFB46BBD0C0EB700000000"); let r_y: [u8; 32] = hex!("347E00859981D5446447075AA07543CDE6DF224CFB23F7B5886337BD00000000"); + let s = Ed25519Affine::generator(); + let s = (s + s + s).to_affine(); + + let coords = [ + p.x.to_bytes(), + p.y.to_bytes(), + q_x, + q_y, + r_x, + r_y, + s.x.to_bytes(), + s.y.to_bytes(), + ] + .concat() + .into_iter() + .map(FieldAlgebra::from_canonical_u8) + .collect(); - let coords = [p.x.to_bytes(), p.y.to_bytes(), q_x, q_y, r_x, r_y] - .concat() - .into_iter() - .map(FieldAlgebra::from_canonical_u8) - .collect(); - air_test_with_min_segments(config, openvm_exe, vec![coords], 1); + air_test_with_min_segments(Rv32EccCpuBuilder, config, openvm_exe, vec![coords], 1); Ok(()) } @@ -182,7 +212,7 @@ mod tests { &config, )?; let openvm_exe = VmExe::from_elf(elf, config.transpiler())?; - air_test(config, openvm_exe); + air_test(SdkVmCpuBuilder, config, openvm_exe); Ok(()) } @@ -200,7 +230,7 @@ mod tests { let openvm_exe = VmExe::from_elf(elf, config.transpiler())?; let mut input = StdIn::default(); input.write(&P256_RECOVERY_TEST_VECTORS.to_vec()); - air_test_with_min_segments(config, openvm_exe, input, 1); + air_test_with_min_segments(SdkVmCpuBuilder, config, openvm_exe, input, 1); Ok(()) } @@ -218,7 +248,7 @@ mod tests { let openvm_exe = VmExe::from_elf(elf, config.transpiler())?; let mut input = StdIn::default(); input.write(&K256_RECOVERY_TEST_VECTORS.to_vec()); - air_test_with_min_segments(config, openvm_exe, input, 1); + air_test_with_min_segments(SdkVmCpuBuilder, config, openvm_exe, input, 1); Ok(()) } @@ -236,7 +266,32 @@ mod tests { let openvm_exe = VmExe::from_elf(elf, config.transpiler())?; let mut input = StdIn::default(); input.write(&k256_sec1_decoding_test_vectors()); - air_test_with_min_segments(config, openvm_exe, input, 1); + air_test_with_min_segments(SdkVmCpuBuilder, config, openvm_exe, input, 1); + Ok(()) + } + + #[test] + fn test_edwards_ec() -> Result<()> { + let config = toml::from_str::>(include_str!( + "../programs/openvm_ed25519.toml" + ))? + .app_vm_config; + let elf = build_example_program_at_path_with_features::<&str>( + get_programs_dir!(), + "edwards_ec", + ["ed25519"], + &config, + )?; + let openvm_exe = VmExe::from_elf( + elf, + Transpiler::::default() + .with_extension(Rv32ITranspilerExtension) + .with_extension(Rv32MTranspilerExtension) + .with_extension(Rv32IoTranspilerExtension) + .with_extension(EccTranspilerExtension) + .with_extension(ModularTranspilerExtension), + )?; + air_test(SdkVmCpuBuilder, config, openvm_exe); Ok(()) } @@ -261,7 +316,7 @@ mod tests { ) .unwrap(); let config = - Rv32WeierstrassConfig::new(vec![SECP256K1_CONFIG.clone(), P256_CONFIG.clone()]); - air_test(config, openvm_exe); + test_rv32ecc_config(vec![SECP256K1_CONFIG.clone(), P256_CONFIG.clone()], vec![]); + air_test(Rv32EccCpuBuilder, config, openvm_exe); } } diff --git a/extensions/ecc/transpiler/src/lib.rs b/extensions/ecc/transpiler/src/lib.rs index 462e95dbdd..469868d3ae 100644 --- a/extensions/ecc/transpiler/src/lib.rs +++ b/extensions/ecc/transpiler/src/lib.rs @@ -1,4 +1,4 @@ -use openvm_ecc_guest::{SwBaseFunct7, OPCODE, SW_FUNCT3}; +use openvm_ecc_guest::{SwBaseFunct7, TeBaseFunct7, SW_FUNCT3, SW_OPCODE, TE_FUNCT3, TE_OPCODE}; use openvm_instructions::{ instruction::Instruction, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode, VmOpcode, }; @@ -15,10 +15,21 @@ use strum::{EnumCount, EnumIter, FromRepr}; #[allow(non_camel_case_types)] #[repr(usize)] pub enum Rv32WeierstrassOpcode { - EC_ADD_NE, - SETUP_EC_ADD_NE, - EC_DOUBLE, - SETUP_EC_DOUBLE, + SW_ADD_NE, + SETUP_SW_ADD_NE, + SW_DOUBLE, + SETUP_SW_DOUBLE, +} + +#[derive( + Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, +)] +#[opcode_offset = 0x680] +#[allow(non_camel_case_types)] +#[repr(usize)] +pub enum Rv32EdwardsOpcode { + TE_ADD, + SETUP_TE_ADD, } #[derive(Default)] @@ -26,6 +37,67 @@ pub struct EccTranspilerExtension; impl TranspilerExtension for EccTranspilerExtension { fn process_custom(&self, instruction_stream: &[u32]) -> Option> { + self.process_weierstrass_instruction(instruction_stream) + .or(self.process_edwards_instruction(instruction_stream)) + } +} + +impl EccTranspilerExtension { + fn process_edwards_instruction( + &self, + instruction_stream: &[u32], + ) -> Option> { + if instruction_stream.is_empty() { + return None; + } + let instruction_u32 = instruction_stream[0]; + let opcode = (instruction_u32 & 0x7f) as u8; + let funct3 = ((instruction_u32 >> 12) & 0b111) as u8; + + if opcode != TE_OPCODE { + return None; + } + if funct3 != TE_FUNCT3 { + return None; + } + + let instruction = { + // twisted edwards ec + assert!(Rv32EdwardsOpcode::COUNT <= TeBaseFunct7::TWISTED_EDWARDS_MAX_KINDS as usize); + let dec_insn = RType::new(instruction_u32); + let base_funct7 = (dec_insn.funct7 as u8) % TeBaseFunct7::TWISTED_EDWARDS_MAX_KINDS; + let curve_idx = + ((dec_insn.funct7 as u8) / TeBaseFunct7::TWISTED_EDWARDS_MAX_KINDS) as usize; + let curve_idx_shift = curve_idx * Rv32EdwardsOpcode::COUNT; + + if base_funct7 == TeBaseFunct7::TeSetup as u8 { + let local_opcode = Rv32EdwardsOpcode::SETUP_TE_ADD; + Some(Instruction::new( + VmOpcode::from_usize(local_opcode.global_opcode().as_usize() + curve_idx_shift), + F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rd), + F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1), + F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs2), + F::ONE, // d_as = 1 + F::TWO, // e_as = 2 + F::ZERO, + F::ZERO, + )) + } else { + let global_opcode = match TeBaseFunct7::from_repr(base_funct7) { + Some(TeBaseFunct7::TeAdd) => Rv32EdwardsOpcode::TE_ADD.global_opcode(), + _ => unimplemented!(), + }; + let global_opcode = global_opcode.as_usize() + curve_idx_shift; + Some(from_r_type(global_opcode, 2, &dec_insn, true)) + } + }; + instruction.map(TranspilerOutput::one_to_one) + } + + fn process_weierstrass_instruction( + &self, + instruction_stream: &[u32], + ) -> Option> { if instruction_stream.is_empty() { return None; } @@ -33,7 +105,7 @@ impl TranspilerExtension for EccTranspilerExtension { let opcode = (instruction_u32 & 0x7f) as u8; let funct3 = ((instruction_u32 >> 12) & 0b111) as u8; - if opcode != OPCODE { + if opcode != SW_OPCODE { return None; } if funct3 != SW_FUNCT3 { @@ -52,8 +124,8 @@ impl TranspilerExtension for EccTranspilerExtension { let curve_idx_shift = curve_idx * Rv32WeierstrassOpcode::COUNT; if base_funct7 == SwBaseFunct7::SwSetup as u8 { let local_opcode = match dec_insn.rs2 { - 0 => Rv32WeierstrassOpcode::SETUP_EC_DOUBLE, - _ => Rv32WeierstrassOpcode::SETUP_EC_ADD_NE, + 0 => Rv32WeierstrassOpcode::SETUP_SW_DOUBLE, + _ => Rv32WeierstrassOpcode::SETUP_SW_ADD_NE, }; Some(Instruction::new( VmOpcode::from_usize(local_opcode.global_opcode().as_usize() + curve_idx_shift), @@ -67,18 +139,14 @@ impl TranspilerExtension for EccTranspilerExtension { )) } else { let global_opcode = match SwBaseFunct7::from_repr(base_funct7) { - Some(SwBaseFunct7::SwAddNe) => { - Rv32WeierstrassOpcode::EC_ADD_NE as usize - + Rv32WeierstrassOpcode::CLASS_OFFSET - } + Some(SwBaseFunct7::SwAddNe) => Rv32WeierstrassOpcode::SW_ADD_NE.global_opcode(), Some(SwBaseFunct7::SwDouble) => { assert!(dec_insn.rs2 == 0); - Rv32WeierstrassOpcode::EC_DOUBLE as usize - + Rv32WeierstrassOpcode::CLASS_OFFSET + Rv32WeierstrassOpcode::SW_DOUBLE.global_opcode() } _ => unimplemented!(), }; - let global_opcode = global_opcode + curve_idx_shift; + let global_opcode = global_opcode.as_usize() + curve_idx_shift; Some(from_r_type(global_opcode, 2, &dec_insn, true)) } }; diff --git a/extensions/keccak256/circuit/Cargo.toml b/extensions/keccak256/circuit/Cargo.toml index 941303ab39..2299a0599a 100644 --- a/extensions/keccak256/circuit/Cargo.toml +++ b/extensions/keccak256/circuit/Cargo.toml @@ -23,12 +23,10 @@ p3-keccak-air = { workspace = true } strum.workspace = true tiny-keccak.workspace = true itertools.workspace = true -tracing.workspace = true derive-new.workspace = true derive_more = { workspace = true, features = ["from"] } rand.workspace = true serde.workspace = true -serde-big-array.workspace = true [dev-dependencies] openvm-stark-sdk = { workspace = true } diff --git a/extensions/keccak256/circuit/src/extension.rs b/extensions/keccak256/circuit/src/extension.rs index 5993f69eda..9c44503299 100644 --- a/extensions/keccak256/circuit/src/extension.rs +++ b/extensions/keccak256/circuit/src/extension.rs @@ -1,20 +1,32 @@ +use std::{result::Result, sync::Arc}; + use derive_more::derive::From; use openvm_circuit::{ arch::{ - InitFileGenerator, SystemConfig, SystemPort, VmExtension, VmInventory, VmInventoryBuilder, - VmInventoryError, + AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, + ExecutorInventoryBuilder, ExecutorInventoryError, InitFileGenerator, MatrixRecordArena, + RowMajorMatrixArena, SystemConfig, VmBuilder, VmChipComplex, VmCircuitExtension, + VmExecutionExtension, VmProverExtension, + }, + system::{ + memory::SharedMemoryHelper, SystemChipInventory, SystemCpuBuilder, SystemExecutor, + SystemPort, }, - system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; -use openvm_circuit_primitives::bitwise_op_lookup::BitwiseOperationLookupBus; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor, VmConfig}; +use openvm_circuit_primitives::bitwise_op_lookup::{ + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, +}; use openvm_instructions::*; use openvm_rv32im_circuit::{ - Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, - Rv32MExecutor, Rv32MPeriphery, + Rv32I, Rv32IExecutor, Rv32ImCpuProverExt, Rv32Io, Rv32IoExecutor, Rv32M, Rv32MExecutor, }; -use openvm_stark_backend::p3_field::PrimeField32; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + p3_field::PrimeField32, + prover::cpu::{CpuBackend, CpuDevice}, +}; +use openvm_stark_sdk::engine::StarkEngine; use serde::{Deserialize, Serialize}; use strum::IntoEnumIterator; @@ -22,7 +34,7 @@ use crate::*; #[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] pub struct Keccak256Rv32Config { - #[system] + #[config(executor = "SystemExecutor")] pub system: SystemConfig, #[extension] pub rv32i: Rv32I, @@ -49,62 +61,146 @@ impl Default for Keccak256Rv32Config { // Default implementation uses no init file impl InitFileGenerator for Keccak256Rv32Config {} -#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] -pub struct Keccak256; +#[derive(Clone)] +pub struct Keccak256Rv32CpuBuilder; + +impl VmBuilder for Keccak256Rv32CpuBuilder +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + Val: PrimeField32, +{ + type VmConfig = Keccak256Rv32Config; + type SystemChipInventory = SystemChipInventory; + type RecordArena = MatrixRecordArena>; -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] -pub enum Keccak256Executor { - Keccak256(KeccakVmChip), + fn create_chip_complex( + &self, + config: &Keccak256Rv32Config, + circuit: AirInventory, + ) -> Result< + VmChipComplex, + ChipInventoryError, + > { + let mut chip_complex = + VmBuilder::::create_chip_complex(&SystemCpuBuilder, &config.system, circuit)?; + let inventory = &mut chip_complex.inventory; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.rv32i, inventory)?; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.rv32m, inventory)?; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.io, inventory)?; + VmProverExtension::::extend_prover( + &Keccak256CpuProverExt, + &config.keccak, + inventory, + )?; + Ok(chip_complex) + } } -#[derive(From, ChipUsageGetter, Chip, AnyEnum)] -pub enum Keccak256Periphery { - BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), - Phantom(PhantomChip), +// =================================== VM Extension Implementation ================================= +#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] +pub struct Keccak256; + +#[derive(Clone, Copy, From, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)] +pub enum Keccak256Executor { + Keccak256(KeccakVmExecutor), } -impl VmExtension for Keccak256 { - type Executor = Keccak256Executor; - type Periphery = Keccak256Periphery; +impl VmExecutionExtension for Keccak256 { + type Executor = Keccak256Executor; - fn build( + fn extend_execution( &self, - builder: &mut VmInventoryBuilder, - ) -> Result, VmInventoryError> { - let mut inventory = VmInventory::new(); + inventory: &mut ExecutorInventoryBuilder, + ) -> Result<(), ExecutorInventoryError> { + let pointer_max_bits = inventory.pointer_max_bits(); + let keccak_step = KeccakVmExecutor::new(Rv32KeccakOpcode::CLASS_OFFSET, pointer_max_bits); + inventory.add_executor( + keccak_step, + Rv32KeccakOpcode::iter().map(|x| x.global_opcode()), + )?; + + Ok(()) + } +} + +impl VmCircuitExtension for Keccak256 { + fn extend_circuit(&self, inventory: &mut AirInventory) -> Result<(), AirInventoryError> { let SystemPort { execution_bus, program_bus, memory_bridge, - } = builder.system_port(); - let bitwise_lu_chip = if let Some(&chip) = builder - .find_chip::>() - .first() - { - chip.clone() - } else { - let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); - let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); - inventory.add_periphery_chip(chip.clone()); - chip + } = inventory.system().port(); + + let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); + let pointer_max_bits = inventory.pointer_max_bits(); + + let bitwise_lu = { + let existing_air = inventory.find_air::>().next(); + if let Some(air) = existing_air { + air.bus + } else { + let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx()); + let air = BitwiseOperationLookupAir::<8>::new(bus); + inventory.add_air(air); + air.bus + } }; - let offline_memory = builder.system_base().offline_memory(); - let address_bits = builder.system_config().memory_config.pointer_max_bits; - let keccak_chip = KeccakVmChip::new( - execution_bus, - program_bus, + let keccak = KeccakVmAir::new( + exec_bridge, memory_bridge, - address_bits, - bitwise_lu_chip, + bitwise_lu, + pointer_max_bits, Rv32KeccakOpcode::CLASS_OFFSET, - offline_memory, ); - inventory.add_executor( - keccak_chip, - Rv32KeccakOpcode::iter().map(|x| x.global_opcode()), - )?; + inventory.add_air(keccak); + + Ok(()) + } +} + +pub struct Keccak256CpuProverExt; +// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker, +// BitwiseOperationLookupChip) are specific to CpuBackend. +impl VmProverExtension for Keccak256CpuProverExt +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + RA: RowMajorMatrixArena>, + Val: PrimeField32, +{ + fn extend_prover( + &self, + _: &Keccak256, + inventory: &mut ChipInventory>, + ) -> Result<(), ChipInventoryError> { + let range_checker = inventory.range_checker()?.clone(); + let timestamp_max_bits = inventory.timestamp_max_bits(); + let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits); + let pointer_max_bits = inventory.airs().pointer_max_bits(); + + let bitwise_lu = { + let existing_chip = inventory + .find_chip::>() + .next(); + if let Some(chip) = existing_chip { + chip.clone() + } else { + let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?; + let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus)); + inventory.add_periphery_chip(chip.clone()); + chip + } + }; + + inventory.next_air::()?; + let keccak = KeccakVmChip::new( + KeccakVmFiller::new(bitwise_lu, pointer_max_bits), + mem_helper, + ); + inventory.add_executor_chip(keccak); - Ok(inventory) + Ok(()) } } diff --git a/extensions/keccak256/circuit/src/lib.rs b/extensions/keccak256/circuit/src/lib.rs index c9fd1c9f5a..d167a59ae5 100644 --- a/extensions/keccak256/circuit/src/lib.rs +++ b/extensions/keccak256/circuit/src/lib.rs @@ -1,17 +1,11 @@ //! Stateful keccak256 hasher. Handles full keccak sponge (padding, absorb, keccak-f) on //! variable length inputs read from VM memory. -use std::{ - array::from_fn, - cmp::min, - sync::{Arc, Mutex}, -}; + +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit_primitives::bitwise_op_lookup::SharedBitwiseOperationLookupChip; use openvm_stark_backend::p3_field::PrimeField32; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; -use tiny_keccak::{Hasher, Keccak}; -use utils::num_keccak_f; +use p3_keccak_air::NUM_ROUNDS; pub mod air; pub mod columns; @@ -19,24 +13,21 @@ pub mod trace; pub mod utils; mod extension; -pub use extension::*; - #[cfg(test)] mod tests; - pub use air::KeccakVmAir; -use openvm_circuit::{ - arch::{ExecutionBridge, ExecutionBus, ExecutionError, ExecutionState, InstructionExecutor}, - system::{ - memory::{offline_checker::MemoryBridge, MemoryController, OfflineMemory, RecordId}, - program::ProgramBus, - }, -}; +pub use extension::*; +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; +use openvm_circuit_primitives_derive::AlignedBytesBorrow; use openvm_instructions::{ - instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode, + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, }; use openvm_keccak256_transpiler::Rv32KeccakOpcode; -use openvm_rv32im_circuit::adapters::read_rv32_register; + +use crate::utils::{keccak256, num_keccak_f}; // ==== Constants for register/memory adapter ==== /// Register reads to get dst, src, len @@ -69,76 +60,136 @@ pub const KECCAK_DIGEST_BYTES: usize = 32; /// Number of 64-bit digest limbs. pub const KECCAK_DIGEST_U64S: usize = KECCAK_DIGEST_BYTES / 8; -pub struct KeccakVmChip { - pub air: KeccakVmAir, - /// IO and memory data necessary for each opcode call - pub records: Vec>, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, +pub type KeccakVmChip = VmChipWrapper; - offset: usize, +#[derive(derive_new::new, Clone, Copy)] +pub struct KeccakVmExecutor { + pub offset: usize, + pub pointer_max_bits: usize, +} - offline_memory: Arc>>, +#[derive(derive_new::new)] +pub struct KeccakVmFiller { + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, + pub pointer_max_bits: usize, } -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct KeccakRecord { - pub pc: F, - pub dst_read: RecordId, - pub src_read: RecordId, - pub len_read: RecordId, - pub input_blocks: Vec, - pub digest_writes: [RecordId; KECCAK_DIGEST_WRITES], +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct KeccakPreCompute { + a: u8, + b: u8, + c: u8, } -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct KeccakInputBlock { - /// Memory reads for non-padding bytes in this block. Length is at most [KECCAK_RATE_BYTES / - /// KECCAK_WORD_SIZE]. - pub reads: Vec, - /// Index in `reads` of the memory read for < KECCAK_WORD_SIZE bytes, if any. - pub partial_read_idx: Option, - /// Bytes with padding. Can be derived from `bytes_read` but we store for convenience. - #[serde(with = "BigArray")] - pub padded_bytes: [u8; KECCAK_RATE_BYTES], - pub remaining_len: usize, - pub src: usize, - pub is_new_start: bool, +impl Executor for KeccakVmExecutor { + fn pre_compute_size(&self) -> usize { + size_of::() + } + + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E1ExecutionCtx, + { + let data: &mut KeccakPreCompute = data.borrow_mut(); + self.pre_compute_impl(pc, inst, data)?; + Ok(execute_e1_impl::<_, _>) + } } -impl KeccakVmChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - address_bits: usize, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, - offset: usize, - offline_memory: Arc>>, - ) -> Self { - Self { - air: KeccakVmAir::new( - ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bitwise_lookup_chip.bus(), - address_bits, - offset, - ), - bitwise_lookup_chip, - records: Vec::new(), - offset, - offline_memory, - } +impl MeteredExecutor for KeccakVmExecutor { + fn metered_pre_compute_size(&self) -> usize { + size_of::>() + } + + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + self.pre_compute_impl(pc, inst, &mut data.data)?; + Ok(execute_e2_impl::<_, _>) } } -impl InstructionExecutor for KeccakVmChip { - fn execute( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { - let &Instruction { +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &KeccakPreCompute, + vm_state: &mut VmExecState, +) -> u32 { + let dst = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32); + let src = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32); + let len = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32); + let dst_u32 = u32::from_le_bytes(dst); + let src_u32 = u32::from_le_bytes(src); + let len_u32 = u32::from_le_bytes(len); + + let (output, height) = if IS_E1 { + // SAFETY: RV32_MEMORY_AS is memory address space of type u8 + let message = vm_state.vm_read_slice(RV32_MEMORY_AS, src_u32, len_u32 as usize); + let output = keccak256(message); + (output, 0) + } else { + let num_reads = (len_u32 as usize).div_ceil(KECCAK_WORD_SIZE); + let message: Vec<_> = (0..num_reads) + .flat_map(|i| { + vm_state.vm_read::( + RV32_MEMORY_AS, + src_u32 + (i * KECCAK_WORD_SIZE) as u32, + ) + }) + .collect(); + let output = keccak256(&message[..len_u32 as usize]); + let height = (num_keccak_f(len_u32 as usize) * NUM_ROUNDS) as u32; + (output, height) + }; + vm_state.vm_write(RV32_MEMORY_AS, dst_u32, &output); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; + + height +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &KeccakPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + let height = execute_e12_impl::(&pre_compute.data, vm_state); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height); +} + +impl KeccakVmExecutor { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut KeccakPreCompute, + ) -> Result<(), StaticProgramError> { + let Instruction { opcode, a, b, @@ -146,141 +197,17 @@ impl InstructionExecutor for KeccakVmChip { d, e, .. - } = instruction; - let local_opcode = Rv32KeccakOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - debug_assert_eq!(local_opcode, Rv32KeccakOpcode::KECCAK256); - - let mut timestamp_delta = 3; - let (dst_read, dst) = read_rv32_register(memory, d, a); - let (src_read, src) = read_rv32_register(memory, d, b); - let (len_read, len) = read_rv32_register(memory, d, c); - #[cfg(debug_assertions)] - { - assert!(dst < (1 << self.air.ptr_max_bits)); - assert!(src < (1 << self.air.ptr_max_bits)); - assert!(len < (1 << self.air.ptr_max_bits)); + } = inst; + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); } - - let mut remaining_len = len as usize; - let num_blocks = num_keccak_f(remaining_len); - let mut input_blocks = Vec::with_capacity(num_blocks); - let mut hasher = Keccak::v256(); - let mut src = src as usize; - - for block_idx in 0..num_blocks { - if block_idx != 0 { - memory.increment_timestamp_by(KECCAK_REGISTER_READS as u32); - timestamp_delta += KECCAK_REGISTER_READS as u32; - } - let mut reads = Vec::with_capacity(KECCAK_RATE_BYTES); - - let mut partial_read_idx = None; - let mut bytes = [0u8; KECCAK_RATE_BYTES]; - for i in (0..KECCAK_RATE_BYTES).step_by(KECCAK_WORD_SIZE) { - if i < remaining_len { - let read = - memory.read::(e, F::from_canonical_usize(src + i)); - - let chunk = read.1.map(|x| { - x.as_canonical_u32() - .try_into() - .expect("Memory cell not a byte") - }); - let copy_len = min(KECCAK_WORD_SIZE, remaining_len - i); - if copy_len != KECCAK_WORD_SIZE { - partial_read_idx = Some(reads.len()); - } - bytes[i..i + copy_len].copy_from_slice(&chunk[..copy_len]); - reads.push(read.0); - } else { - memory.increment_timestamp(); - } - timestamp_delta += 1; - } - - let mut block = KeccakInputBlock { - reads, - partial_read_idx, - padded_bytes: bytes, - remaining_len, - src, - is_new_start: block_idx == 0, - }; - if block_idx != num_blocks - 1 { - src += KECCAK_RATE_BYTES; - remaining_len -= KECCAK_RATE_BYTES; - hasher.update(&block.padded_bytes); - } else { - // handle padding here since it is convenient - debug_assert!(remaining_len < KECCAK_RATE_BYTES); - hasher.update(&block.padded_bytes[..remaining_len]); - - if remaining_len == KECCAK_RATE_BYTES - 1 { - block.padded_bytes[remaining_len] = 0b1000_0001; - } else { - block.padded_bytes[remaining_len] = 0x01; - block.padded_bytes[KECCAK_RATE_BYTES - 1] = 0x80; - } - } - input_blocks.push(block); - } - let mut output = [0u8; 32]; - hasher.finalize(&mut output); - let dst = dst as usize; - let digest_writes: [_; KECCAK_DIGEST_WRITES] = from_fn(|i| { - timestamp_delta += 1; - memory - .write::( - e, - F::from_canonical_usize(dst + i * KECCAK_WORD_SIZE), - from_fn(|j| F::from_canonical_u8(output[i * KECCAK_WORD_SIZE + j])), - ) - .0 - }); - tracing::trace!("[runtime] keccak256 output: {:?}", output); - - let record = KeccakRecord { - pc: F::from_canonical_u32(from_state.pc), - dst_read, - src_read, - len_read, - input_blocks, - digest_writes, + *data = KeccakPreCompute { + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + c: c.as_canonical_u32() as u8, }; - - // Add the events to chip state for later trace generation usage - self.records.push(record); - - // NOTE: Check this is consistent with KeccakVmAir::timestamp_change (we don't use it to - // avoid unnecessary conversions here) - let total_timestamp_delta = - len + (KECCAK_REGISTER_READS + KECCAK_ABSORB_READS + KECCAK_DIGEST_WRITES) as u32; - memory.increment_timestamp_by(total_timestamp_delta - timestamp_delta); - - Ok(ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: from_state.timestamp + total_timestamp_delta, - }) - } - - fn get_opcode_name(&self, _: usize) -> String { - "KECCAK256".to_string() - } -} - -impl Default for KeccakInputBlock { - fn default() -> Self { - // Padding for empty byte array so padding constraints still hold - let mut padded_bytes = [0u8; KECCAK_RATE_BYTES]; - padded_bytes[0] = 0x01; - *padded_bytes.last_mut().unwrap() = 0x80; - Self { - padded_bytes, - partial_read_idx: None, - remaining_len: 0, - is_new_start: true, - reads: Vec::new(), - src: 0, - } + assert_eq!(&Rv32KeccakOpcode::KECCAK256.global_opcode(), opcode); + Ok(()) } } diff --git a/extensions/keccak256/circuit/src/tests.rs b/extensions/keccak256/circuit/src/tests.rs index 65a34491b8..b41c9abdef 100644 --- a/extensions/keccak256/circuit/src/tests.rs +++ b/extensions/keccak256/circuit/src/tests.rs @@ -1,104 +1,242 @@ -use std::borrow::BorrowMut; +use std::{array, borrow::BorrowMut, sync::Arc}; use hex::FromHex; -use openvm_circuit::arch::testing::{VmChipTestBuilder, VmChipTester, BITWISE_OP_LOOKUP_BUS}; +use openvm_circuit::{ + arch::{ + testing::{memory::gen_pointer, TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, + Arena, DenseRecordArena, PreflightExecutor, + }, + utils::get_random_message, +}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, }; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_keccak256_transpiler::Rv32KeccakOpcode; +use openvm_instructions::{instruction::Instruction, riscv::RV32_CELL_BITS, LocalOpcode}; +use openvm_keccak256_transpiler::Rv32KeccakOpcode::{self, *}; use openvm_stark_backend::{ - p3_field::FieldAlgebra, utils::disable_debug_builder, verifier::VerificationError, -}; -use openvm_stark_sdk::{ - config::baby_bear_blake3::BabyBearBlake3Config, p3_baby_bear::BabyBear, - utils::create_seeded_rng, + p3_field::FieldAlgebra, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, + utils::disable_debug_builder, + verifier::VerificationError, }; -use p3_keccak_air::NUM_ROUNDS; -use rand::Rng; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use rand::{rngs::StdRng, Rng}; use tiny_keccak::Hasher; -use super::{columns::KeccakVmCols, utils::num_keccak_f, KeccakVmChip}; +use super::{columns::KeccakVmCols, KeccakVmChip}; +use crate::{ + trace::KeccakVmRecordLayout, utils::keccak256, KeccakVmAir, KeccakVmExecutor, KeccakVmFiller, +}; type F = BabyBear; -// io is vector of (input, expected_output, prank_output) where prank_output is Some if the trace -// will be replaced -#[allow(clippy::type_complexity)] -fn build_keccak256_test( - io: Vec<(Vec, Option<[u8; 32]>, Option<[u8; 32]>)>, -) -> VmChipTester { +const MAX_INS_CAPACITY: usize = 8192; +type Harness = TestChipHarness, RA>; + +fn create_test_chips( + tester: &mut VmChipTestBuilder, +) -> ( + Harness, + ( + BitwiseOperationLookupAir, + SharedBitwiseOperationLookupChip, + ), +) { let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::<8>::new(bitwise_bus); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); - let mut tester = VmChipTestBuilder::default(); - let mut chip = KeccakVmChip::new( - tester.execution_bus(), - tester.program_bus(), + let air = KeccakVmAir::new( + tester.execution_bridge(), tester.memory_bridge(), + bitwise_chip.bus(), tester.address_bits(), - bitwise_chip.clone(), Rv32KeccakOpcode::CLASS_OFFSET, - tester.offline_memory_mutex_arc(), ); - let mut dst = 0; - let src = 0; + let executor = KeccakVmExecutor::new(Rv32KeccakOpcode::CLASS_OFFSET, tester.address_bits()); + let chip = KeccakVmChip::new( + KeccakVmFiller::new(bitwise_chip.clone(), tester.address_bits()), + tester.memory_helper(), + ); + + let harness = Harness::::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + + (harness, (bitwise_chip.air, bitwise_chip)) +} + +fn set_and_execute( + tester: &mut VmChipTestBuilder, + harness: &mut Harness, + rng: &mut StdRng, + opcode: Rv32KeccakOpcode, + message: Option<&[u8]>, + len: Option, + expected_output: Option<[u8; 32]>, +) where + KeccakVmExecutor: PreflightExecutor, +{ + let len = len.unwrap_or(rng.gen_range(1..3000)); + let tmp = get_random_message(rng, len); + let message: &[u8] = message.unwrap_or(&tmp); + let len = message.len(); + + let rd = gen_pointer(rng, 4); + let rs1 = gen_pointer(rng, 4); + let rs2 = gen_pointer(rng, 4); - for (input, expected_output, _) in &io { - let [a, b, c] = [0, 4, 8]; // space apart for register limbs - let [d, e] = [1, 2]; + let max_mem_ptr: u32 = 1 << tester.address_bits(); + let dst_ptr = rng.gen_range(0..max_mem_ptr); + let dst_ptr = dst_ptr ^ (dst_ptr & 3); + tester.write(1, rd, dst_ptr.to_le_bytes().map(F::from_canonical_u8)); + let src_ptr = rng.gen_range(0..(max_mem_ptr - len as u32)); + let src_ptr = src_ptr ^ (src_ptr & 3); + tester.write(1, rs1, src_ptr.to_le_bytes().map(F::from_canonical_u8)); + tester.write(1, rs2, len.to_le_bytes().map(F::from_canonical_u8)); - tester.write(d, a, (dst as u32).to_le_bytes().map(F::from_canonical_u8)); - tester.write(d, b, (src as u32).to_le_bytes().map(F::from_canonical_u8)); + message.chunks(4).enumerate().for_each(|(i, chunk)| { + let chunk: [&u8; 4] = array::from_fn(|i| chunk.get(i).unwrap_or(&0)); tester.write( - d, - c, - (input.len() as u32).to_le_bytes().map(F::from_canonical_u8), + 2, + src_ptr as usize + i * 4, + chunk.map(|&x| F::from_canonical_u8(x)), ); - for (i, byte) in input.iter().enumerate() { - tester.write_cell(e, src + i, F::from_canonical_u8(*byte)); - } + }); + + tester.execute( + harness, + &Instruction::from_usize(opcode.global_opcode(), [rd, rs1, rs2, 1, 2]), + ); + + let expected_output = expected_output.unwrap_or(keccak256(message)); + println!("expected_output: {:?}", expected_output); + println!("keccak256(message): {:?}", keccak256(message)); + assert_eq!( + expected_output.map(F::from_canonical_u8), + tester.read(2, dst_ptr as usize) + ); +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// POSITIVE TESTS +/// +/// Randomly generate computations and execute, ensuring that the generated trace +/// passes all constraints. +/////////////////////////////////////////////////////////////////////////////////////// +#[test] +fn rand_keccak256_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut harness, bitwise) = create_test_chips(&mut tester); - tester.execute( - &mut chip, - &Instruction::from_isize( - Rv32KeccakOpcode::KECCAK256.global_opcode(), - a as isize, - b as isize, - c as isize, - d as isize, - e as isize, - ), + let num_ops: usize = 10; + for _ in 0..num_ops { + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + KECCAK256, + None, + None, + None, ); - if let Some(output) = expected_output { - for (i, byte) in output.iter().enumerate() { - assert_eq!(tester.read_cell(e, dst + i), F::from_canonical_u8(*byte)); - } - } - // shift dst to not deal with timestamps for pranking - dst += 32; } - let mut tester = tester.build().load(chip).load(bitwise_chip).finalize(); - - let keccak_trace = tester.air_proof_inputs[2] - .1 - .raw - .common_main - .as_mut() - .unwrap(); - let mut row = 0; - for (input, _, prank_output) in io { - let num_blocks = num_keccak_f(input.len()); - let num_rows = NUM_ROUNDS * num_blocks; - row += num_rows; - if prank_output.is_none() { - continue; - } - let output = prank_output.unwrap(); - let digest_row: &mut KeccakVmCols<_> = keccak_trace.row_mut(row - 1).borrow_mut(); + + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + KECCAK256, + None, + Some(10000), + None, + ); + + let tester = tester + .build() + .load(harness) + .load_periphery(bitwise) + .finalize(); + tester.simple_test().expect("Verification failed"); +} + +// Keccak Known Answer Test (KAT) vectors from https://keccak.team/obsolete/KeccakKAT-3.zip. +// Only selecting a small subset for now (add more later) +// KAT includes inputs at the bit level; we only include the ones that are bytes +#[test] +fn test_keccak256_positive_kat_vectors() { + // input, output, Len in bits + let test_vectors = vec![ + // ("", "C5D2460186F7233C927E7DB2DCC703C0E500B653CA82273B7BFAD8045D85A470"), // ShortMsgKAT_256 Len = 0 + ("CC", "EEAD6DBFC7340A56CAEDC044696A168870549A6A7F6F56961E84A54BD9970B8A"), // ShortMsgKAT_256 Len = 8 + ("B55C10EAE0EC684C16D13463F29291BF26C82E2FA0422A99C71DB4AF14DD9C7F33EDA52FD73D017CC0F2DBE734D831F0D820D06D5F89DACC485739144F8CFD4799223B1AFF9031A105CB6A029BA71E6E5867D85A554991C38DF3C9EF8C1E1E9A7630BE61CAABCA69280C399C1FB7A12D12AEFC", "0347901965D3635005E75A1095695CCA050BC9ED2D440C0372A31B348514A889"), // ShortMsgKAT_256 Len = 920 + ("2EDC282FFB90B97118DD03AAA03B145F363905E3CBD2D50ECD692B37BF000185C651D3E9726C690D3773EC1E48510E42B17742B0B0377E7DE6B8F55E00A8A4DB4740CEE6DB0830529DD19617501DC1E9359AA3BCF147E0A76B3AB70C4984C13E339E6806BB35E683AF8527093670859F3D8A0FC7D493BCBA6BB12B5F65E71E705CA5D6C948D66ED3D730B26DB395B3447737C26FAD089AA0AD0E306CB28BF0ACF106F89AF3745F0EC72D534968CCA543CD2CA50C94B1456743254E358C1317C07A07BF2B0ECA438A709367FAFC89A57239028FC5FECFD53B8EF958EF10EE0608B7F5CB9923AD97058EC067700CC746C127A61EE3", "DD1D2A92B3F3F3902F064365838E1F5F3468730C343E2974E7A9ECFCD84AA6DB"), // ShortMsgKAT_256 Len = 1952, + ("724627916C50338643E6996F07877EAFD96BDF01DA7E991D4155B9BE1295EA7D21C9391F4C4A41C75F77E5D27389253393725F1427F57914B273AB862B9E31DABCE506E558720520D33352D119F699E784F9E548FF91BC35CA147042128709820D69A8287EA3257857615EB0321270E94B84F446942765CE882B191FAEE7E1C87E0F0BD4E0CD8A927703524B559B769CA4ECE1F6DBF313FDCF67C572EC4185C1A88E86EC11B6454B371980020F19633B6B95BD280E4FBCB0161E1A82470320CEC6ECFA25AC73D09F1536F286D3F9DACAFB2CD1D0CE72D64D197F5C7520B3CCB2FD74EB72664BA93853EF41EABF52F015DD591500D018DD162815CC993595B195", "EA0E416C0F7B4F11E3F00479FDDF954F2539E5E557753BD546F69EE375A5DE29"), // LongMsgKAT_256 Len = 2048 + ("6E1CADFB2A14C5FFB1DD69919C0124ED1B9A414B2BEA1E5E422D53B022BDD13A9C88E162972EBB9852330006B13C5B2F2AFBE754AB7BACF12479D4558D19DDBB1A6289387B3AC084981DF335330D1570850B97203DBA5F20CF7FF21775367A8401B6EBE5B822ED16C39383232003ABC412B0CE0DD7C7DA064E4BB73E8C58F222A1512D5FE6D947316E02F8AA87E7AA7A3AA1C299D92E6414AE3B927DB8FF708AC86A09B24E1884743BC34067BB0412453B4A6A6509504B550F53D518E4BCC3D9C1EFDB33DA2EACCB84C9F1CAEC81057A8508F423B25DB5500E5FC86AB3B5EB10D6D0BF033A716DDE55B09FD53451BBEA644217AE1EF91FAD2B5DCC6515249C96EE7EABFD12F1EF65256BD1CFF2087DABF2F69AD1FFB9CF3BC8CA437C7F18B6095BC08D65DF99CC7F657C418D8EB109FDC91A13DC20A438941726EF24F9738B6552751A320C4EA9C8D7E8E8592A3B69D30A419C55FB6CB0850989C029AAAE66305E2C14530B39EAA86EA3BA2A7DECF4B2848B01FAA8AA91F2440B7CC4334F63061CE78AA1589BEFA38B194711697AE3AADCB15C9FBF06743315E2F97F1A8B52236ACB444069550C2345F4ED12E5B8E881CDD472E803E5DCE63AE485C2713F81BC307F25AC74D39BAF7E3BC5E7617465C2B9C309CB0AC0A570A7E46C6116B2242E1C54F456F6589E20B1C0925BF1CD5F9344E01F63B5BA9D4671ABBF920C7ED32937A074C33836F0E019DFB6B35D865312C6058DFDAFF844C8D58B75071523E79DFBAB2EA37479DF12C474584F4FF40F00F92C6BADA025CE4DF8FAF0AFB2CE75C07773907CA288167D6B011599C3DE0FFF16C1161D31DF1C1DDE217CB574ED5A33751759F8ED2B1E6979C5088B940926B9155C9D250B479948C20ACB5578DC02C97593F646CC5C558A6A0F3D8D273258887CCFF259197CB1A7380622E371FD2EB5376225EC04F9ED1D1F2F08FA2376DB5B790E73086F581064ED1C5F47E989E955D77716B50FB64B853388FBA01DAC2CEAE99642341F2DA64C56BEFC4789C051E5EB79B063F2F084DB4491C3C5AA7B4BCF7DD7A1D7CED1554FA67DCA1F9515746A237547A4A1D22ACF649FA1ED3B9BB52BDE0C6996620F8CFDB293F8BACAD02BCE428363D0BB3D391469461D212769048219220A7ED39D1F9157DFEA3B4394CA8F5F612D9AC162BF0B961BFBC157E5F863CE659EB235CF98E8444BC8C7880BDDCD0B3B389AAA89D5E05F84D0649EEBACAB4F1C75352E89F0E9D91E4ACA264493A50D2F4AED66BD13650D1F18E7199E931C78AEB763E903807499F1CD99AF81276B615BE8EC709B039584B2B57445B014F6162577F3548329FD288B0800F936FC5EA1A412E3142E609FC8E39988CA53DF4D8FB5B5FB5F42C0A01648946AC6864CFB0E92856345B08E5DF0D235261E44CFE776456B40AEF0AC1A0DFA2FE639486666C05EA196B0C1A9D346435E03965E6139B1CE10129F8A53745F80100A94AE04D996C13AC14CF2713E39DFBB19A936CF3861318BD749B1FB82F40D73D714E406CBEB3D920EA037B7DE566455CCA51980F0F53A762D5BF8A4DBB55AAC0EDDB4B1F2AED2AA3D01449D34A57FDE4329E7FF3F6BECE4456207A4225218EE9F174C2DE0FF51CEAF2A07CF84F03D1DF316331E3E725C5421356C40ED25D5ABF9D24C4570FED618CA41000455DBD759E32E2BF0B6C5E61297C20F752C3042394CE840C70943C451DD5598EB0E4953CE26E833E5AF64FC1007C04456D19F87E45636F456B7DC9D31E757622E2739573342DE75497AE181AAE7A5425756C8E2A7EEF918E5C6A968AEFE92E8B261BBFE936B19F9E69A3C90094096DAE896450E1505ED5828EE2A7F0EA3A28E6EC47C0AF711823E7689166EA07ECA00FFC493131D65F93A4E1D03E0354AFC2115CFB8D23DAE8C6F96891031B23226B8BC82F1A73DAA5BB740FC8CC36C0975BEFA0C7895A9BBC261EDB7FD384103968F7A18353D5FE56274E4515768E4353046C785267DE01E816A2873F97AAD3AB4D7234EBFD9832716F43BE8245CF0B4408BA0F0F764CE9D24947AB6ABDD9879F24FCFF10078F5894B0D64F6A8D3EA3DD92A0C38609D3C14FDC0A44064D501926BE84BF8034F1D7A8C5F382E6989BFFA2109D4FBC56D1F091E8B6FABFF04D21BB19656929D19DECB8E8291E6AE5537A169874E0FE9890DFF11FFD159AD23D749FB9E8B676E2C31313C16D1EFA06F4D7BC191280A4EE63049FCEF23042B20303AECDD412A526D7A53F760A089FBDF13F361586F0DCA76BB928EDB41931D11F679619F948A6A9E8DBA919327769006303C6EF841438A7255C806242E2E7FF4621BB0F8AFA0B4A248EAD1A1E946F3E826FBFBBF8013CE5CC814E20FEF21FA5DB19EC7FF0B06C592247B27E500EB4705E6C37D41D09E83CB0A618008CA1AAAE8A215171D817659063C2FA385CFA3C1078D5C2B28CE7312876A276773821BE145785DFF24BBB24D590678158A61EA49F2BE56FDAC8CE7F94B05D62F15ADD351E5930FD4F31B3E7401D5C0FF7FC845B165FB6ABAFD4788A8B0615FEC91092B34B710A68DA518631622BA2AAE5D19010D307E565A161E64A4319A6B261FB2F6A90533997B1AEC32EF89CF1F232696E213DAFE4DBEB1CF1D5BBD12E5FF2EBB2809184E37CD9A0E58A4E0AF099493E6D8CC98B05A2F040A7E39515038F6EE21FC25F8D459A327B83EC1A28A234237ACD52465506942646AC248EC96EBBA6E1B092475F7ADAE4D35E009FD338613C7D4C12E381847310A10E6F02C02392FC32084FBE939689BC6518BE27AF7842DEEA8043828E3DFFE3BBAC4794CA0CC78699722709F2E4B0EAE7287DEB06A27B462423EC3F0DF227ACF589043292685F2C0E73203E8588B62554FF19D6260C7FE48DF301509D33BE0D8B31D3F658C921EF7F55449FF3887D91BFB894116DF57206098E8C5835B", "3C79A3BD824542C20AF71F21D6C28DF2213A041F77DD79A328A0078123954E7B"), // LongMsgKAT_256 Len = 16664 + ("7ADC0B6693E61C269F278E6944A5A2D8300981E40022F839AC644387BFAC9086650085C2CDC585FEA47B9D2E52D65A2B29A7DC370401EF5D60DD0D21F9E2B90FAE919319B14B8C5565B0423CEFB827D5F1203302A9D01523498A4DB10374", "4CC2AFF141987F4C2E683FA2DE30042BACDCD06087D7A7B014996E9CFEAA58CE"), // ShortMsgKAT_256 Len = 752 + ]; + + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut harness, bitwise) = create_test_chips(&mut tester); + + for (input, output) in test_vectors { + let input = Vec::from_hex(input).unwrap(); + let output = Vec::from_hex(output).unwrap().try_into().unwrap(); + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + KECCAK256, + Some(&input), + None, + Some(output), + ); + } + let tester = tester + .build() + .load(harness) + .load_periphery(bitwise) + .finalize(); + tester.simple_test().expect("Verification failed"); +} + +////////////////////////////////////////////////////////////////////////////////////// +// NEGATIVE TESTS +// +// Given a fake trace of a single operation, setup a chip and run the test. We replace +// part of the trace and check that the chip throws the expected error. +////////////////////////////////////////////////////////////////////////////////////// +fn run_negative_keccak256_test( + input: &[u8], + prank_output: [u8; 32], + verification_error: VerificationError, +) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut harness, bitwise) = create_test_chips(&mut tester); + + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + KECCAK256, + Some(input), + None, + None, + ); + + let modify_trace = |trace: &mut DenseMatrix| { + let mut trace_row = trace.row_slice(16).to_vec(); + let digest_row: &mut KeccakVmCols<_> = trace_row.as_mut_slice().borrow_mut(); for i in 0..16 { - let out_limb = - F::from_canonical_u16(output[2 * i] as u16 + ((output[2 * i + 1] as u16) << 8)); + let out_limb = F::from_canonical_u16( + prank_output[2 * i] as u16 + ((prank_output[2 * i + 1] as u16) << 8), + ); let x = i / 4; let y = 0; let limb = i % 4; @@ -108,9 +246,16 @@ fn build_keccak256_test( digest_row.inner.a_prime_prime[y][x][limb] = out_limb; } } - } + *trace = RowMajorMatrix::new(trace_row, trace.width()); + }; - tester + disable_debug_builder(); + let tester = tester + .build() + .load_and_prank_trace(harness, modify_trace) + .load_periphery(bitwise) + .finalize(); + tester.simple_test_with_expected_error(verification_error); } #[test] @@ -122,37 +267,49 @@ fn test_keccak256_negative() { let mut out = [0u8; 32]; hasher.finalize(&mut out); out[0] = rng.gen(); - let tester = build_keccak256_test(vec![(input, None, Some(out))]); - disable_debug_builder(); - assert_eq!( - tester.simple_test().err(), - Some(VerificationError::OodEvaluationMismatch) - ); + run_negative_keccak256_test(&input, out, VerificationError::OodEvaluationMismatch); } -// Keccak Known Answer Test (KAT) vectors from https://keccak.team/obsolete/KeccakKAT-3.zip. -// Only selecting a small subset for now (add more later) -// KAT includes inputs at the bit level; we only include the ones that are bytes +/////////////////////////////////////////////////////////////////////////////////////// +/// DENSE TESTS +/// +/// Ensure that the chip works as expected with dense records. +/// We first execute some instructions with a [DenseRecordArena] and transfer the records +/// to a [MatrixRecordArena]. After transferring we generate the trace and make sure that +/// all the constraints pass. +/////////////////////////////////////////////////////////////////////////////////////// #[test] -fn test_keccak256_positive_kat_vectors() { - // input, output, Len in bits - let test_vectors = vec![ - ("", "C5D2460186F7233C927E7DB2DCC703C0E500B653CA82273B7BFAD8045D85A470"), // ShortMsgKAT_256 Len = 0 - ("CC", "EEAD6DBFC7340A56CAEDC044696A168870549A6A7F6F56961E84A54BD9970B8A"), // ShortMsgKAT_256 Len = 8 - ("B55C10EAE0EC684C16D13463F29291BF26C82E2FA0422A99C71DB4AF14DD9C7F33EDA52FD73D017CC0F2DBE734D831F0D820D06D5F89DACC485739144F8CFD4799223B1AFF9031A105CB6A029BA71E6E5867D85A554991C38DF3C9EF8C1E1E9A7630BE61CAABCA69280C399C1FB7A12D12AEFC", "0347901965D3635005E75A1095695CCA050BC9ED2D440C0372A31B348514A889"), // ShortMsgKAT_256 Len = 920 - ("2EDC282FFB90B97118DD03AAA03B145F363905E3CBD2D50ECD692B37BF000185C651D3E9726C690D3773EC1E48510E42B17742B0B0377E7DE6B8F55E00A8A4DB4740CEE6DB0830529DD19617501DC1E9359AA3BCF147E0A76B3AB70C4984C13E339E6806BB35E683AF8527093670859F3D8A0FC7D493BCBA6BB12B5F65E71E705CA5D6C948D66ED3D730B26DB395B3447737C26FAD089AA0AD0E306CB28BF0ACF106F89AF3745F0EC72D534968CCA543CD2CA50C94B1456743254E358C1317C07A07BF2B0ECA438A709367FAFC89A57239028FC5FECFD53B8EF958EF10EE0608B7F5CB9923AD97058EC067700CC746C127A61EE3", "DD1D2A92B3F3F3902F064365838E1F5F3468730C343E2974E7A9ECFCD84AA6DB"), // ShortMsgKAT_256 Len = 1952, - ("724627916C50338643E6996F07877EAFD96BDF01DA7E991D4155B9BE1295EA7D21C9391F4C4A41C75F77E5D27389253393725F1427F57914B273AB862B9E31DABCE506E558720520D33352D119F699E784F9E548FF91BC35CA147042128709820D69A8287EA3257857615EB0321270E94B84F446942765CE882B191FAEE7E1C87E0F0BD4E0CD8A927703524B559B769CA4ECE1F6DBF313FDCF67C572EC4185C1A88E86EC11B6454B371980020F19633B6B95BD280E4FBCB0161E1A82470320CEC6ECFA25AC73D09F1536F286D3F9DACAFB2CD1D0CE72D64D197F5C7520B3CCB2FD74EB72664BA93853EF41EABF52F015DD591500D018DD162815CC993595B195", "EA0E416C0F7B4F11E3F00479FDDF954F2539E5E557753BD546F69EE375A5DE29"), // LongMsgKAT_256 Len = 2048 - ("6E1CADFB2A14C5FFB1DD69919C0124ED1B9A414B2BEA1E5E422D53B022BDD13A9C88E162972EBB9852330006B13C5B2F2AFBE754AB7BACF12479D4558D19DDBB1A6289387B3AC084981DF335330D1570850B97203DBA5F20CF7FF21775367A8401B6EBE5B822ED16C39383232003ABC412B0CE0DD7C7DA064E4BB73E8C58F222A1512D5FE6D947316E02F8AA87E7AA7A3AA1C299D92E6414AE3B927DB8FF708AC86A09B24E1884743BC34067BB0412453B4A6A6509504B550F53D518E4BCC3D9C1EFDB33DA2EACCB84C9F1CAEC81057A8508F423B25DB5500E5FC86AB3B5EB10D6D0BF033A716DDE55B09FD53451BBEA644217AE1EF91FAD2B5DCC6515249C96EE7EABFD12F1EF65256BD1CFF2087DABF2F69AD1FFB9CF3BC8CA437C7F18B6095BC08D65DF99CC7F657C418D8EB109FDC91A13DC20A438941726EF24F9738B6552751A320C4EA9C8D7E8E8592A3B69D30A419C55FB6CB0850989C029AAAE66305E2C14530B39EAA86EA3BA2A7DECF4B2848B01FAA8AA91F2440B7CC4334F63061CE78AA1589BEFA38B194711697AE3AADCB15C9FBF06743315E2F97F1A8B52236ACB444069550C2345F4ED12E5B8E881CDD472E803E5DCE63AE485C2713F81BC307F25AC74D39BAF7E3BC5E7617465C2B9C309CB0AC0A570A7E46C6116B2242E1C54F456F6589E20B1C0925BF1CD5F9344E01F63B5BA9D4671ABBF920C7ED32937A074C33836F0E019DFB6B35D865312C6058DFDAFF844C8D58B75071523E79DFBAB2EA37479DF12C474584F4FF40F00F92C6BADA025CE4DF8FAF0AFB2CE75C07773907CA288167D6B011599C3DE0FFF16C1161D31DF1C1DDE217CB574ED5A33751759F8ED2B1E6979C5088B940926B9155C9D250B479948C20ACB5578DC02C97593F646CC5C558A6A0F3D8D273258887CCFF259197CB1A7380622E371FD2EB5376225EC04F9ED1D1F2F08FA2376DB5B790E73086F581064ED1C5F47E989E955D77716B50FB64B853388FBA01DAC2CEAE99642341F2DA64C56BEFC4789C051E5EB79B063F2F084DB4491C3C5AA7B4BCF7DD7A1D7CED1554FA67DCA1F9515746A237547A4A1D22ACF649FA1ED3B9BB52BDE0C6996620F8CFDB293F8BACAD02BCE428363D0BB3D391469461D212769048219220A7ED39D1F9157DFEA3B4394CA8F5F612D9AC162BF0B961BFBC157E5F863CE659EB235CF98E8444BC8C7880BDDCD0B3B389AAA89D5E05F84D0649EEBACAB4F1C75352E89F0E9D91E4ACA264493A50D2F4AED66BD13650D1F18E7199E931C78AEB763E903807499F1CD99AF81276B615BE8EC709B039584B2B57445B014F6162577F3548329FD288B0800F936FC5EA1A412E3142E609FC8E39988CA53DF4D8FB5B5FB5F42C0A01648946AC6864CFB0E92856345B08E5DF0D235261E44CFE776456B40AEF0AC1A0DFA2FE639486666C05EA196B0C1A9D346435E03965E6139B1CE10129F8A53745F80100A94AE04D996C13AC14CF2713E39DFBB19A936CF3861318BD749B1FB82F40D73D714E406CBEB3D920EA037B7DE566455CCA51980F0F53A762D5BF8A4DBB55AAC0EDDB4B1F2AED2AA3D01449D34A57FDE4329E7FF3F6BECE4456207A4225218EE9F174C2DE0FF51CEAF2A07CF84F03D1DF316331E3E725C5421356C40ED25D5ABF9D24C4570FED618CA41000455DBD759E32E2BF0B6C5E61297C20F752C3042394CE840C70943C451DD5598EB0E4953CE26E833E5AF64FC1007C04456D19F87E45636F456B7DC9D31E757622E2739573342DE75497AE181AAE7A5425756C8E2A7EEF918E5C6A968AEFE92E8B261BBFE936B19F9E69A3C90094096DAE896450E1505ED5828EE2A7F0EA3A28E6EC47C0AF711823E7689166EA07ECA00FFC493131D65F93A4E1D03E0354AFC2115CFB8D23DAE8C6F96891031B23226B8BC82F1A73DAA5BB740FC8CC36C0975BEFA0C7895A9BBC261EDB7FD384103968F7A18353D5FE56274E4515768E4353046C785267DE01E816A2873F97AAD3AB4D7234EBFD9832716F43BE8245CF0B4408BA0F0F764CE9D24947AB6ABDD9879F24FCFF10078F5894B0D64F6A8D3EA3DD92A0C38609D3C14FDC0A44064D501926BE84BF8034F1D7A8C5F382E6989BFFA2109D4FBC56D1F091E8B6FABFF04D21BB19656929D19DECB8E8291E6AE5537A169874E0FE9890DFF11FFD159AD23D749FB9E8B676E2C31313C16D1EFA06F4D7BC191280A4EE63049FCEF23042B20303AECDD412A526D7A53F760A089FBDF13F361586F0DCA76BB928EDB41931D11F679619F948A6A9E8DBA919327769006303C6EF841438A7255C806242E2E7FF4621BB0F8AFA0B4A248EAD1A1E946F3E826FBFBBF8013CE5CC814E20FEF21FA5DB19EC7FF0B06C592247B27E500EB4705E6C37D41D09E83CB0A618008CA1AAAE8A215171D817659063C2FA385CFA3C1078D5C2B28CE7312876A276773821BE145785DFF24BBB24D590678158A61EA49F2BE56FDAC8CE7F94B05D62F15ADD351E5930FD4F31B3E7401D5C0FF7FC845B165FB6ABAFD4788A8B0615FEC91092B34B710A68DA518631622BA2AAE5D19010D307E565A161E64A4319A6B261FB2F6A90533997B1AEC32EF89CF1F232696E213DAFE4DBEB1CF1D5BBD12E5FF2EBB2809184E37CD9A0E58A4E0AF099493E6D8CC98B05A2F040A7E39515038F6EE21FC25F8D459A327B83EC1A28A234237ACD52465506942646AC248EC96EBBA6E1B092475F7ADAE4D35E009FD338613C7D4C12E381847310A10E6F02C02392FC32084FBE939689BC6518BE27AF7842DEEA8043828E3DFFE3BBAC4794CA0CC78699722709F2E4B0EAE7287DEB06A27B462423EC3F0DF227ACF589043292685F2C0E73203E8588B62554FF19D6260C7FE48DF301509D33BE0D8B31D3F658C921EF7F55449FF3887D91BFB894116DF57206098E8C5835B", "3C79A3BD824542C20AF71F21D6C28DF2213A041F77DD79A328A0078123954E7B"), // LongMsgKAT_256 Len = 16664 - ("7ADC0B6693E61C269F278E6944A5A2D8300981E40022F839AC644387BFAC9086650085C2CDC585FEA47B9D2E52D65A2B29A7DC370401EF5D60DD0D21F9E2B90FAE919319B14B8C5565B0423CEFB827D5F1203302A9D01523498A4DB10374", "4CC2AFF141987F4C2E683FA2DE30042BACDCD06087D7A7B014996E9CFEAA58CE"), // ShortMsgKAT_256 Len = 752 - ]; +fn dense_record_arena_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut sparse_harness, bitwise) = create_test_chips(&mut tester); - let mut io = vec![]; - for (input, output) in test_vectors { - let input = Vec::from_hex(input).unwrap(); - let output = Vec::from_hex(output).unwrap(); - io.push((input, Some(output.try_into().unwrap()), None)); + { + let mut dense_harness = create_test_chips::(&mut tester).0; + + let num_ops: usize = 10; + for _ in 0..num_ops { + set_and_execute( + &mut tester, + &mut dense_harness, + &mut rng, + KECCAK256, + None, + None, + None, + ); + } + + let mut record_interpreter = dense_harness + .arena + .get_record_seeker::<_, KeccakVmRecordLayout>(); + record_interpreter.transfer_to_matrix_arena(&mut sparse_harness.arena); } - let tester = build_keccak256_test(io); + let tester = tester + .build() + .load(sparse_harness) + .load_periphery(bitwise) + .finalize(); tester.simple_test().expect("Verification failed"); } diff --git a/extensions/keccak256/circuit/src/trace.rs b/extensions/keccak256/circuit/src/trace.rs index c314c38eac..695fbafd5d 100644 --- a/extensions/keccak256/circuit/src/trace.rs +++ b/extensions/keccak256/circuit/src/trace.rs @@ -1,16 +1,30 @@ -use std::{array::from_fn, borrow::BorrowMut, sync::Arc}; +use std::{ + array::{self, from_fn}, + borrow::{Borrow, BorrowMut}, + cmp::min, +}; -use openvm_circuit::system::memory::RecordId; -use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use openvm_circuit::{ + arch::*, + system::memory::{ + offline_checker::{MemoryReadAuxRecord, MemoryWriteBytesAuxRecord}, + online::TracingMemory, + MemoryAuxColsFactory, + }, +}; +use openvm_circuit_primitives::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; +use openvm_keccak256_transpiler::Rv32KeccakOpcode; +use openvm_rv32im_circuit::adapters::{read_rv32_register, tracing_read, tracing_write}; use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - p3_air::BaseAir, - p3_field::{FieldAlgebra, PrimeField32}, + p3_field::PrimeField32, p3_matrix::{dense::RowMajorMatrix, Matrix}, p3_maybe_rayon::prelude::*, - prover::types::AirProofInput, - rap::get_air_name, - AirRef, Chip, ChipUsageGetter, }; use p3_keccak_air::{ generate_trace_rows, NUM_KECCAK_COLS as NUM_KECCAK_PERM_COLS, NUM_ROUNDS, U64_LIMBS, @@ -18,258 +32,537 @@ use p3_keccak_air::{ use tiny_keccak::keccakf; use super::{ - columns::{KeccakInstructionCols, KeccakVmCols}, - KeccakVmChip, KECCAK_ABSORB_READS, KECCAK_DIGEST_WRITES, KECCAK_RATE_BYTES, KECCAK_RATE_U16S, + columns::KeccakVmCols, KECCAK_ABSORB_READS, KECCAK_DIGEST_WRITES, KECCAK_RATE_BYTES, KECCAK_REGISTER_READS, NUM_ABSORB_ROUNDS, }; +use crate::{ + columns::NUM_KECCAK_VM_COLS, + utils::{keccak256, keccak_f, num_keccak_f}, + KeccakVmExecutor, KeccakVmFiller, KECCAK_DIGEST_BYTES, KECCAK_RATE_U16S, KECCAK_WORD_SIZE, +}; + +#[derive(Clone, Copy)] +pub struct KeccakVmMetadata { + pub len: usize, +} + +impl MultiRowMetadata for KeccakVmMetadata { + #[inline(always)] + fn get_num_rows(&self) -> usize { + num_keccak_f(self.len) * NUM_ROUNDS + } +} + +pub(crate) type KeccakVmRecordLayout = MultiRowLayout; + +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug, Clone)] +pub struct KeccakVmRecordHeader { + pub from_pc: u32, + pub timestamp: u32, + pub rd_ptr: u32, + pub rs1_ptr: u32, + pub rs2_ptr: u32, + pub dst: u32, + pub src: u32, + pub len: u32, + + pub register_reads_aux: [MemoryReadAuxRecord; KECCAK_REGISTER_READS], + pub write_aux: [MemoryWriteBytesAuxRecord; KECCAK_DIGEST_WRITES], +} + +pub struct KeccakVmRecordMut<'a> { + pub inner: &'a mut KeccakVmRecordHeader, + // Having a continuous slice of the input is useful for fast hashing in `execute` + pub input: &'a mut [u8], + pub read_aux: &'a mut [MemoryReadAuxRecord], +} + +/// Custom borrowing that splits the buffer into a fixed `KeccakVmRecord` header +/// followed by a slice of `u8`'s of length `num_reads * KECCAK_WORD_SIZE` where `num_reads` is +/// provided at runtime, followed by a slice of `MemoryReadAuxRecord`'s of length `num_reads`. +/// Uses `align_to_mut()` to make sure the slice is properly aligned to `MemoryReadAuxRecord`. +/// Has debug assertions that check the size and alignment of the slices. +impl<'a> CustomBorrow<'a, KeccakVmRecordMut<'a>, KeccakVmRecordLayout> for [u8] { + fn custom_borrow(&'a mut self, layout: KeccakVmRecordLayout) -> KeccakVmRecordMut<'a> { + let (record_buf, rest) = + unsafe { self.split_at_mut_unchecked(size_of::()) }; + + let num_reads = layout.metadata.len.div_ceil(KECCAK_WORD_SIZE); + // Note: each read is `KECCAK_WORD_SIZE` bytes + let (input, rest) = unsafe { rest.split_at_mut_unchecked(num_reads * KECCAK_WORD_SIZE) }; + let (_, read_aux_buf, _) = unsafe { rest.align_to_mut::() }; + KeccakVmRecordMut { + inner: record_buf.borrow_mut(), + input, + read_aux: &mut read_aux_buf[..num_reads], + } + } -impl Chip for KeccakVmChip> + unsafe fn extract_layout(&self) -> KeccakVmRecordLayout { + let header: &KeccakVmRecordHeader = self.borrow(); + KeccakVmRecordLayout { + metadata: KeccakVmMetadata { + len: header.len as usize, + }, + } + } +} + +impl SizedRecord for KeccakVmRecordMut<'_> { + fn size(layout: &KeccakVmRecordLayout) -> usize { + let num_reads = layout.metadata.len.div_ceil(KECCAK_WORD_SIZE); + let mut total_len = size_of::(); + total_len += num_reads * KECCAK_WORD_SIZE; + // Align the pointer to the alignment of `MemoryReadAuxRecord` + total_len = total_len.next_multiple_of(align_of::()); + total_len += num_reads * size_of::(); + total_len + } + + fn alignment(_layout: &KeccakVmRecordLayout) -> usize { + align_of::() + } +} + +impl PreflightExecutor for KeccakVmExecutor where - Val: PrimeField32, + F: PrimeField32, + for<'buf> RA: RecordArena<'buf, KeccakVmRecordLayout, KeccakVmRecordMut<'buf>>, { - fn air(&self) -> AirRef { - Arc::new(self.air) + fn get_opcode_name(&self, _: usize) -> String { + format!("{:?}", Rv32KeccakOpcode::KECCAK256) } - fn generate_air_proof_input(self) -> AirProofInput { - let trace_width = self.trace_width(); - let records = self.records; - let total_num_blocks: usize = records.iter().map(|r| r.input_blocks.len()).sum(); - let mut states = Vec::with_capacity(total_num_blocks); - let mut instruction_blocks = Vec::with_capacity(total_num_blocks); - let memory = self.offline_memory.lock().unwrap(); - - #[derive(Clone)] - struct StateDiff { - /// hi-byte of pre-state - pre_hi: [u8; KECCAK_RATE_U16S], - /// hi-byte of post-state - post_hi: [u8; KECCAK_RATE_U16S], - /// if first block - register_reads: Option<[RecordId; KECCAK_REGISTER_READS]>, - /// if last block - digest_writes: Option<[RecordId; KECCAK_DIGEST_WRITES]>, + fn execute( + &mut self, + state: VmStateMut, + instruction: &Instruction, + ) -> Result<(), ExecutionError> { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = instruction; + debug_assert_eq!(opcode, Rv32KeccakOpcode::KECCAK256.global_opcode()); + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); + + // Reading the length first without tracing to allocate a record of correct size + let len = read_rv32_register(state.memory.data(), c.as_canonical_u32()) as usize; + + let num_reads = len.div_ceil(KECCAK_WORD_SIZE); + let num_blocks = num_keccak_f(len); + let record = state + .ctx + .alloc(KeccakVmRecordLayout::new(KeccakVmMetadata { len })); + + record.inner.from_pc = *state.pc; + record.inner.timestamp = state.memory.timestamp(); + record.inner.rd_ptr = a.as_canonical_u32(); + record.inner.rs1_ptr = b.as_canonical_u32(); + record.inner.rs2_ptr = c.as_canonical_u32(); + + record.inner.dst = u32::from_le_bytes(tracing_read( + state.memory, + RV32_REGISTER_AS, + record.inner.rd_ptr, + &mut record.inner.register_reads_aux[0].prev_timestamp, + )); + record.inner.src = u32::from_le_bytes(tracing_read( + state.memory, + RV32_REGISTER_AS, + record.inner.rs1_ptr, + &mut record.inner.register_reads_aux[1].prev_timestamp, + )); + record.inner.len = u32::from_le_bytes(tracing_read( + state.memory, + RV32_REGISTER_AS, + record.inner.rs2_ptr, + &mut record.inner.register_reads_aux[2].prev_timestamp, + )); + + debug_assert!(record.inner.src as usize + len <= (1 << self.pointer_max_bits)); + debug_assert!( + record.inner.dst as usize + KECCAK_DIGEST_BYTES <= (1 << self.pointer_max_bits) + ); + // We don't support messages longer than 2^[pointer_max_bits] bytes + debug_assert!(record.inner.len < (1 << self.pointer_max_bits)); + + for idx in 0..num_reads { + if idx % KECCAK_ABSORB_READS == 0 && idx != 0 { + // Need to increment the timestamp according at the start of each block due to the + // AIR constraints + state + .memory + .increment_timestamp_by(KECCAK_REGISTER_READS as u32); + } + let read = tracing_read::( + state.memory, + RV32_MEMORY_AS, + record.inner.src + (idx * KECCAK_WORD_SIZE) as u32, + &mut record.read_aux[idx].prev_timestamp, + ); + record.input[idx * KECCAK_WORD_SIZE..(idx + 1) * KECCAK_WORD_SIZE] + .copy_from_slice(&read); } - impl Default for StateDiff { - fn default() -> Self { - Self { - pre_hi: [0; KECCAK_RATE_U16S], - post_hi: [0; KECCAK_RATE_U16S], - register_reads: None, - digest_writes: None, - } + // Due to the AIR constraints, need to set the timestamp to the following: + state.memory.timestamp = record.inner.timestamp + + (num_blocks * (KECCAK_ABSORB_READS + KECCAK_REGISTER_READS)) as u32; + + let digest = keccak256(&record.input[..len]); + for (i, word) in digest.chunks_exact(KECCAK_WORD_SIZE).enumerate() { + tracing_write::( + state.memory, + RV32_MEMORY_AS, + record.inner.dst + (i * KECCAK_WORD_SIZE) as u32, + word.try_into().unwrap(), + &mut record.inner.write_aux[i].prev_timestamp, + &mut record.inner.write_aux[i].prev_data, + ); + } + + // Due to the AIR constraints, the final memory timestamp should be the following: + state.memory.timestamp = record.inner.timestamp + + (len + KECCAK_REGISTER_READS + KECCAK_ABSORB_READS + KECCAK_DIGEST_WRITES) as u32; + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + Ok(()) + } +} + +impl TraceFiller for KeccakVmFiller { + fn fill_trace( + &self, + mem_helper: &MemoryAuxColsFactory, + trace_matrix: &mut RowMajorMatrix, + rows_used: usize, + ) { + if rows_used == 0 { + return; + } + + let mut chunks = Vec::with_capacity(trace_matrix.height() / NUM_ROUNDS); + let mut sizes = Vec::with_capacity(trace_matrix.height() / NUM_ROUNDS); + let mut trace = &mut trace_matrix.values[..]; + let mut num_blocks_so_far = 0; + + // First pass over the trace to get the number of blocks for each instruction + // and divide the matrix into chunks of needed sizes + loop { + if num_blocks_so_far * NUM_ROUNDS >= rows_used { + // Push all the dummy rows as a single chunk and break + chunks.push(trace); + sizes.push((0, 0)); + break; + } else { + let record: &KeccakVmRecordHeader = + unsafe { get_record_from_slice(&mut trace, ()) }; + let num_blocks = num_keccak_f(record.len as usize); + let (chunk, rest) = + trace.split_at_mut(NUM_KECCAK_VM_COLS * NUM_ROUNDS * num_blocks); + chunks.push(chunk); + sizes.push((num_blocks, record.len as usize)); + num_blocks_so_far += num_blocks; + trace = rest; } } - // prepare the states - let mut state: [u64; 25]; - for record in records { - let dst_read = memory.record_by_id(record.dst_read); - let src_read = memory.record_by_id(record.src_read); - let len_read = memory.record_by_id(record.len_read); - - state = [0u64; 25]; - let src_limbs: [_; RV32_REGISTER_NUM_LIMBS - 1] = src_read.data_slice() - [1..RV32_REGISTER_NUM_LIMBS] - .try_into() - .unwrap(); - let len_limbs: [_; RV32_REGISTER_NUM_LIMBS - 1] = len_read.data_slice() - [1..RV32_REGISTER_NUM_LIMBS] - .try_into() - .unwrap(); - let mut instruction = KeccakInstructionCols { - pc: record.pc, - is_enabled: Val::::ONE, - is_enabled_first_round: Val::::ZERO, - start_timestamp: Val::::from_canonical_u32(dst_read.timestamp), - dst_ptr: dst_read.pointer, - src_ptr: src_read.pointer, - len_ptr: len_read.pointer, - dst: dst_read.data_slice().try_into().unwrap(), - src_limbs, - src: Val::::from_canonical_usize(record.input_blocks[0].src), - len_limbs, - remaining_len: Val::::from_canonical_usize( - record.input_blocks[0].remaining_len, - ), - }; - let num_blocks = record.input_blocks.len(); - for (idx, block) in record.input_blocks.into_iter().enumerate() { - // absorb - for (bytes, s) in block.padded_bytes.chunks_exact(8).zip(state.iter_mut()) { - // u64 <-> bytes conversion is little-endian - for (i, &byte) in bytes.iter().enumerate() { - let s_byte = (*s >> (i * 8)) as u8; - // Update bitwise lookup (i.e. xor) chip state: order matters! - if idx != 0 { - self.bitwise_lookup_chip - .request_xor(byte as u32, s_byte as u32); - } - *s ^= (byte as u64) << (i * 8); - } - } - let pre_hi: [u8; KECCAK_RATE_U16S] = - from_fn(|i| (state[i / U64_LIMBS] >> ((i % U64_LIMBS) * 16 + 8)) as u8); - states.push(state); - keccakf(&mut state); - let post_hi: [u8; KECCAK_RATE_U16S] = - from_fn(|i| (state[i / U64_LIMBS] >> ((i % U64_LIMBS) * 16 + 8)) as u8); - // Range check the final state - if idx == num_blocks - 1 { - for s in state.into_iter().take(NUM_ABSORB_ROUNDS) { - for s_byte in s.to_le_bytes() { - self.bitwise_lookup_chip.request_xor(0, s_byte as u32); - } - } + // First, parallelize over instruction chunks, every instruction can have multiple blocks + // Then, compute some additional values for each block and parallelize over blocks within an + // instruction Finally, compute some additional values for each row and parallelize + // over rows within a block + chunks + .par_iter_mut() + .zip(sizes.par_iter()) + .for_each(|(slice, (num_blocks, len))| { + if *num_blocks == 0 { + // Fill in the dummy rows in parallel + // Note: a 'block' of dummy rows is generated by `generate_trace_rows` from the + // zero state dummy rows are repeated every + // `NUM_ROUNDS` rows + let p3_trace: RowMajorMatrix = generate_trace_rows(vec![[0u64; 25]; 1], 0); + + slice + .par_chunks_exact_mut(NUM_KECCAK_VM_COLS) + .enumerate() + .for_each(|(row_idx, row)| { + let idx = row_idx % NUM_ROUNDS; + row[..NUM_KECCAK_PERM_COLS].copy_from_slice( + &p3_trace.values + [idx * NUM_KECCAK_PERM_COLS..(idx + 1) * NUM_KECCAK_PERM_COLS], + ); + + // Need to get rid of the accidental garbage data that might overflow + // the F's prime field. Unfortunately, there + // is no good way around this + unsafe { + std::ptr::write_bytes( + row.as_mut_ptr().add(NUM_KECCAK_PERM_COLS) as *mut u8, + 0, + (NUM_KECCAK_VM_COLS - NUM_KECCAK_PERM_COLS) * size_of::(), + ); + } + let cols: &mut KeccakVmCols = row.borrow_mut(); + // The first row of a `dummy` block should have `is_new_start = F::ONE` + cols.sponge.is_new_start = F::from_bool(idx == 0); + cols.sponge.block_bytes[0] = F::ONE; + cols.sponge.block_bytes[KECCAK_RATE_BYTES - 1] = + F::from_canonical_u32(0x80); + cols.sponge.is_padding_byte = [F::ONE; KECCAK_RATE_BYTES]; + }); + return; } - let register_reads = - (idx == 0).then_some([record.dst_read, record.src_read, record.len_read]); - let digest_writes = (idx == num_blocks - 1).then_some(record.digest_writes); - let diff = StateDiff { - pre_hi, - post_hi, - register_reads, - digest_writes, + + let num_reads = len.div_ceil(KECCAK_WORD_SIZE); + let read_len = num_reads * KECCAK_WORD_SIZE; + + let record: KeccakVmRecordMut = unsafe { + get_record_from_slice( + slice, + KeccakVmRecordLayout::new(KeccakVmMetadata { len: *len }), + ) }; - instruction_blocks.push((instruction, diff, block)); - instruction.remaining_len -= Val::::from_canonical_usize(KECCAK_RATE_BYTES); - instruction.src += Val::::from_canonical_usize(KECCAK_RATE_BYTES); - instruction.start_timestamp += - Val::::from_canonical_usize(KECCAK_REGISTER_READS + KECCAK_ABSORB_READS); - } - } - // We need to transpose state matrices due to a plonky3 issue: https://github.com/Plonky3/Plonky3/issues/672 - // Note: the fix for this issue will be a commit after the major Field crate refactor PR https://github.com/Plonky3/Plonky3/pull/640 - // which will require a significant refactor to switch to. - let p3_states = states - .into_iter() - .map(|state| { - // transpose of 5x5 matrix - from_fn(|i| { - let x = i / 5; - let y = i % 5; - state[x + 5 * y] - }) - }) - .collect(); - let p3_keccak_trace: RowMajorMatrix> = generate_trace_rows(p3_states, 0); - let num_rows = p3_keccak_trace.height(); - // Every `NUM_ROUNDS` rows corresponds to one input block - let num_blocks = num_rows.div_ceil(NUM_ROUNDS); - // Resize with dummy `is_enabled = 0` - instruction_blocks.resize(num_blocks, Default::default()); - - let aux_cols_factory = memory.aux_cols_factory(); - - // Use unsafe alignment so we can parallelly write to the matrix - let mut trace = - RowMajorMatrix::new(Val::::zero_vec(num_rows * trace_width), trace_width); - let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.air.ptr_max_bits; - - trace - .values - .par_chunks_mut(trace_width * NUM_ROUNDS) - .zip( - p3_keccak_trace - .values - .par_chunks(NUM_KECCAK_PERM_COLS * NUM_ROUNDS), - ) - .zip(instruction_blocks.into_par_iter()) - .for_each(|((rows, p3_keccak_mat), (instruction, diff, block))| { - let height = rows.len() / trace_width; - for (row, p3_keccak_row) in rows - .chunks_exact_mut(trace_width) - .zip(p3_keccak_mat.chunks_exact(NUM_KECCAK_PERM_COLS)) - { - // Safety: `KeccakPermCols` **must** be the first field in `KeccakVmCols` - row[..NUM_KECCAK_PERM_COLS].copy_from_slice(p3_keccak_row); - let row_mut: &mut KeccakVmCols> = row.borrow_mut(); - row_mut.instruction = instruction; - - row_mut.sponge.block_bytes = - block.padded_bytes.map(Val::::from_canonical_u8); - if let Some(partial_read_idx) = block.partial_read_idx { - let partial_read = memory.record_by_id(block.reads[partial_read_idx]); - row_mut - .mem_oc - .partial_block - .copy_from_slice(&partial_read.data_slice()[1..]); - } - for (i, is_padding) in row_mut.sponge.is_padding_byte.iter_mut().enumerate() { - *is_padding = Val::::from_bool(i >= block.remaining_len); - } + // Copy the read aux records and inner record to another place + // to safely fill in the trace matrix without overwriting the record + let mut read_aux_records = Vec::with_capacity(num_reads); + read_aux_records.extend_from_slice(record.read_aux); + let vm_record = record.inner.clone(); + let partial_block = if read_len != *len { + record.input[read_len - KECCAK_WORD_SIZE + 1..] + .try_into() + .unwrap() + } else { + [0u8; KECCAK_WORD_SIZE - 1] } - let first_row: &mut KeccakVmCols> = rows[..trace_width].borrow_mut(); - first_row.sponge.is_new_start = Val::::from_bool(block.is_new_start); - first_row.sponge.state_hi = diff.pre_hi.map(Val::::from_canonical_u8); - first_row.instruction.is_enabled_first_round = first_row.instruction.is_enabled; - // Make memory access aux columns. Any aux column not explicitly defined defaults to - // all 0s - if let Some(register_reads) = diff.register_reads { - let need_range_check = [ - ®ister_reads[0], // dst - ®ister_reads[1], // src - ®ister_reads[2], // len - ®ister_reads[2], - ] - .map(|r| { - memory - .record_by_id(*r) - .data_slice() - .last() - .unwrap() - .as_canonical_u32() + .map(F::from_canonical_u8); + let mut input = Vec::with_capacity(*num_blocks * KECCAK_RATE_BYTES); + input.extend_from_slice(&record.input[..*len]); + // Pad the input according to the Keccak spec + input.push(0x01); + input.resize(input.capacity(), 0); + *input.last_mut().unwrap() += 0x80; + + let mut states = Vec::with_capacity(*num_blocks); + let mut state = [0u64; 25]; + + input + .chunks_exact(KECCAK_RATE_BYTES) + .enumerate() + .for_each(|(idx, chunk)| { + // absorb + for (bytes, s) in chunk.chunks_exact(8).zip(state.iter_mut()) { + // u64 <-> bytes conversion is little-endian + for (i, &byte) in bytes.iter().enumerate() { + let s_byte = (*s >> (i * 8)) as u8; + // Update bitwise lookup (i.e. xor) chip state: order matters! + if idx != 0 { + self.bitwise_lookup_chip + .request_xor(byte as u32, s_byte as u32); + } + *s ^= (byte as u64) << (i * 8); + } + } + states.push(state); + keccakf(&mut state); }); - for bytes in need_range_check.chunks(2) { - self.bitwise_lookup_chip.request_range( - bytes[0] << limb_shift_bits, - bytes[1] << limb_shift_bits, - ); - } - for (i, id) in register_reads.into_iter().enumerate() { - aux_cols_factory.generate_read_aux( - memory.record_by_id(id), - &mut first_row.mem_oc.register_aux[i], - ); - } - } - for (i, id) in block.reads.into_iter().enumerate() { - aux_cols_factory.generate_read_aux( - memory.record_by_id(id), - &mut first_row.mem_oc.absorb_reads[i], - ); - } - let last_row: &mut KeccakVmCols> = - rows[(height - 1) * trace_width..].borrow_mut(); - last_row.sponge.state_hi = diff.post_hi.map(Val::::from_canonical_u8); - last_row.inner.export = instruction.is_enabled - * Val::::from_bool(block.remaining_len < KECCAK_RATE_BYTES); - if let Some(digest_writes) = diff.digest_writes { - for (i, record_id) in digest_writes.into_iter().enumerate() { - let record = memory.record_by_id(record_id); - aux_cols_factory - .generate_write_aux(record, &mut last_row.mem_oc.digest_writes[i]); - } - } - }); + slice + .par_chunks_exact_mut(NUM_ROUNDS * NUM_KECCAK_VM_COLS) + .enumerate() + .for_each(|(block_idx, block_slice)| { + // We need to transpose state matrices due to a plonky3 issue: https://github.com/Plonky3/Plonky3/issues/672 + // Note: the fix for this issue will be a commit after the major Field crate refactor PR https://github.com/Plonky3/Plonky3/pull/640 + // which will require a significant refactor to switch to. + let state = from_fn(|i| { + let x = i / 5; + let y = i % 5; + states[block_idx][x + 5 * y] + }); - AirProofInput::simple_no_pis(trace) - } -} + // Note: we can call `generate_trace_rows` for each block separately because + // its trace only depends on the current `state` + // `generate_trace_rows` will generate additional dummy rows to make the + // height into power of 2, but we can safely discard them + let p3_trace: RowMajorMatrix = generate_trace_rows(vec![state], 0); + let input_offset = block_idx * KECCAK_RATE_BYTES; + let start_timestamp = vm_record.timestamp + + (block_idx * (KECCAK_REGISTER_READS + KECCAK_ABSORB_READS)) as u32; + let rem_len = *len - input_offset; -impl ChipUsageGetter for KeccakVmChip { - fn air_name(&self) -> String { - get_air_name(&self.air) - } - fn current_trace_height(&self) -> usize { - let num_blocks: usize = self.records.iter().map(|r| r.input_blocks.len()).sum(); - num_blocks * NUM_ROUNDS - } + block_slice + .par_chunks_exact_mut(NUM_KECCAK_VM_COLS) + .enumerate() + .zip(p3_trace.values.par_chunks(NUM_KECCAK_PERM_COLS)) + .for_each(|((row_idx, row), p3_row)| { + // Fill the inner columns + // Safety: `KeccakPermCols` **must** be the first field in + // `KeccakVmCols` + row[..NUM_KECCAK_PERM_COLS].copy_from_slice(p3_row); - fn trace_width(&self) -> usize { - BaseAir::::width(&self.air) + let cols: &mut KeccakVmCols = row.borrow_mut(); + // Fill the sponge columns + cols.sponge.is_new_start = + F::from_bool(block_idx == 0 && row_idx == 0); + if rem_len < KECCAK_RATE_BYTES { + cols.sponge.is_padding_byte[..rem_len].fill(F::ZERO); + cols.sponge.is_padding_byte[rem_len..].fill(F::ONE); + } else { + cols.sponge.is_padding_byte = [F::ZERO; KECCAK_RATE_BYTES]; + } + cols.sponge.block_bytes = array::from_fn(|i| { + F::from_canonical_u8(input[input_offset + i]) + }); + if row_idx == 0 { + cols.sponge.state_hi = from_fn(|i| { + F::from_canonical_u8( + (states[block_idx][i / U64_LIMBS] + >> ((i % U64_LIMBS) * 16 + 8)) + as u8, + ) + }); + } else if row_idx == NUM_ROUNDS - 1 { + let state = keccak_f(states[block_idx]); + cols.sponge.state_hi = from_fn(|i| { + F::from_canonical_u8( + (state[i / U64_LIMBS] >> ((i % U64_LIMBS) * 16 + 8)) + as u8, + ) + }); + if block_idx == num_blocks - 1 { + cols.inner.export = F::ONE; + for s in state.into_iter().take(NUM_ABSORB_ROUNDS) { + for s_byte in s.to_le_bytes() { + self.bitwise_lookup_chip + .request_xor(0, s_byte as u32); + } + } + } + } else { + cols.sponge.state_hi = [F::ZERO; KECCAK_RATE_U16S]; + } + + // Fill the instruction columns + cols.instruction.pc = F::from_canonical_u32(vm_record.from_pc); + cols.instruction.is_enabled = F::ONE; + cols.instruction.is_enabled_first_round = + F::from_bool(row_idx == 0); + cols.instruction.start_timestamp = + F::from_canonical_u32(start_timestamp); + cols.instruction.dst_ptr = F::from_canonical_u32(vm_record.rd_ptr); + cols.instruction.src_ptr = F::from_canonical_u32(vm_record.rs1_ptr); + cols.instruction.len_ptr = F::from_canonical_u32(vm_record.rs2_ptr); + cols.instruction.dst = + vm_record.dst.to_le_bytes().map(F::from_canonical_u8); + + let src = vm_record.src + (block_idx * KECCAK_RATE_BYTES) as u32; + cols.instruction.src = F::from_canonical_u32(src); + cols.instruction.src_limbs.copy_from_slice( + &src.to_le_bytes().map(F::from_canonical_u8)[1..], + ); + cols.instruction.len_limbs.copy_from_slice( + &(rem_len as u32).to_le_bytes().map(F::from_canonical_u8)[1..], + ); + cols.instruction.remaining_len = + F::from_canonical_u32(rem_len as u32); + + // Fill the register reads + if row_idx == 0 && block_idx == 0 { + for ((i, cols), vm_record) in cols + .mem_oc + .register_aux + .iter_mut() + .enumerate() + .zip(vm_record.register_reads_aux.iter()) + { + mem_helper.fill( + vm_record.prev_timestamp, + start_timestamp + i as u32, + cols.as_mut(), + ); + } + + let msl_rshift = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1); + let msl_lshift = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS + - self.pointer_max_bits; + // Update the bitwise lookup chip + self.bitwise_lookup_chip.request_range( + (vm_record.dst >> msl_rshift) << msl_lshift, + (vm_record.src >> msl_rshift) << msl_lshift, + ); + self.bitwise_lookup_chip.request_range( + (vm_record.len >> msl_rshift) << msl_lshift, + (vm_record.len >> msl_rshift) << msl_lshift, + ); + } else { + cols.mem_oc.register_aux.par_iter_mut().for_each(|aux| { + mem_helper.fill_zero(aux.as_mut()); + }); + } + + // Fill the absorb reads + if row_idx == 0 { + let reads_offs = block_idx * KECCAK_ABSORB_READS; + let num_reads = min( + rem_len.div_ceil(KECCAK_WORD_SIZE), + KECCAK_ABSORB_READS, + ); + let start_timestamp = + start_timestamp + KECCAK_REGISTER_READS as u32; + for i in 0..num_reads { + mem_helper.fill( + read_aux_records[i + reads_offs].prev_timestamp, + start_timestamp + i as u32, + cols.mem_oc.absorb_reads[i].as_mut(), + ); + } + for i in num_reads..KECCAK_ABSORB_READS { + mem_helper.fill_zero(cols.mem_oc.absorb_reads[i].as_mut()); + } + } else { + cols.mem_oc.absorb_reads.par_iter_mut().for_each(|aux| { + mem_helper.fill_zero(aux.as_mut()); + }); + } + + if block_idx == num_blocks - 1 && row_idx == NUM_ROUNDS - 1 { + let timestamp = start_timestamp + + (KECCAK_ABSORB_READS + KECCAK_REGISTER_READS) as u32; + cols.mem_oc + .digest_writes + .par_iter_mut() + .enumerate() + .zip(vm_record.write_aux.par_iter()) + .for_each(|((i, cols), vm_record)| { + cols.set_prev_data( + vm_record.prev_data.map(F::from_canonical_u8), + ); + mem_helper.fill( + vm_record.prev_timestamp, + timestamp + i as u32, + cols.as_mut(), + ); + }); + } else { + cols.mem_oc.digest_writes.par_iter_mut().for_each(|aux| { + aux.set_prev_data([F::ZERO; KECCAK_WORD_SIZE]); + mem_helper.fill_zero(aux.as_mut()); + }); + } + + // Set the partial block only for the last block + if block_idx == num_blocks - 1 { + cols.mem_oc.partial_block = partial_block; + } else { + cols.mem_oc.partial_block = [F::ZERO; KECCAK_WORD_SIZE - 1]; + } + }); + }); + }); } } diff --git a/extensions/keccak256/guest/src/lib.rs b/extensions/keccak256/guest/src/lib.rs index 7e2bb3da54..acfeea785b 100644 --- a/extensions/keccak256/guest/src/lib.rs +++ b/extensions/keccak256/guest/src/lib.rs @@ -1,5 +1,10 @@ #![no_std] +#[cfg(target_os = "zkvm")] +extern crate alloc; +#[cfg(target_os = "zkvm")] +use openvm_platform::alloc::AlignedBuf; + /// This is custom-0 defined in RISC-V spec document pub const OPCODE: u8 = 0x0b; pub const KECCAK256_FUNCT3: u8 = 0b100; @@ -21,6 +26,43 @@ pub const KECCAK256_FUNCT7: u8 = 0; #[inline(always)] #[no_mangle] pub extern "C" fn native_keccak256(bytes: *const u8, len: usize, output: *mut u8) { + // SAFETY: assuming safety assumptions of the inputs, we handle all cases where `bytes` or + // `output` are not aligned to 4 bytes. + const MIN_ALIGN: usize = 4; + unsafe { + if bytes as usize % MIN_ALIGN != 0 { + let aligned_buff = AlignedBuf::new(bytes, len, MIN_ALIGN); + if output as usize % MIN_ALIGN != 0 { + let aligned_out = AlignedBuf::uninit(32, MIN_ALIGN); + __native_keccak256(aligned_buff.ptr, len, aligned_out.ptr); + core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32); + } else { + __native_keccak256(aligned_buff.ptr, len, output); + } + } else { + if output as usize % MIN_ALIGN != 0 { + let aligned_out = AlignedBuf::uninit(32, MIN_ALIGN); + __native_keccak256(bytes, len, aligned_out.ptr); + core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32); + } else { + __native_keccak256(bytes, len, output); + } + }; + } +} + +/// keccak256 intrinsic binding +/// +/// # Safety +/// +/// The VM accepts the preimage by pointer and length, and writes the +/// 32-byte hash. +/// - `bytes` must point to an input buffer at least `len` long. +/// - `output` must point to a buffer that is at least 32-bytes long. +/// - `bytes` and `output` must be 4-byte aligned. +#[cfg(target_os = "zkvm")] +#[inline(always)] +fn __native_keccak256(bytes: *const u8, len: usize, output: *mut u8) { openvm_platform::custom_insn_r!( opcode = OPCODE, funct3 = KECCAK256_FUNCT3, diff --git a/extensions/native/circuit/Cargo.toml b/extensions/native/circuit/Cargo.toml index 5d5913b4be..67c3981ba7 100644 --- a/extensions/native/circuit/Cargo.toml +++ b/extensions/native/circuit/Cargo.toml @@ -17,23 +17,23 @@ openvm-circuit = { workspace = true } openvm-circuit-derive = { workspace = true } openvm-instructions = { workspace = true } openvm-rv32im-circuit = { workspace = true } +openvm-rv32im-transpiler = { workspace = true } openvm-native-compiler = { workspace = true } strum.workspace = true itertools.workspace = true -tracing.workspace = true derive-new.workspace = true derive_more = { workspace = true, features = ["from"] } rand.workspace = true eyre.workspace = true serde.workspace = true -serde-big-array.workspace = true static_assertions.workspace = true [dev-dependencies] openvm-stark-sdk = { workspace = true } openvm-circuit = { workspace = true, features = ["test-utils"] } +test-case.workspace = true [features] default = ["parallel"] diff --git a/extensions/native/circuit/src/adapters/alu_native_adapter.rs b/extensions/native/circuit/src/adapters/alu_native_adapter.rs index e85797536f..24ce7dfbbb 100644 --- a/extensions/native/circuit/src/adapters/alu_native_adapter.rs +++ b/extensions/native/circuit/src/adapters/alu_native_adapter.rs @@ -1,23 +1,26 @@ use std::{ borrow::{Borrow, BorrowMut}, - marker::PhantomData, + mem::size_of, }; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller, + BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir, }, system::{ memory::{ - offline_checker::{MemoryBridge, MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, + offline_checker::{ + MemoryBridge, MemoryReadAuxRecord, MemoryReadOrImmediateAuxCols, + MemoryWriteAuxCols, MemoryWriteAuxRecord, + }, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, - native_adapter::{NativeReadRecord, NativeWriteRecord}, - program::ProgramBus, + native_adapter::util::{tracing_read_or_imm_native, tracing_write_native}, }, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; use openvm_native_compiler::conversion::AS; @@ -27,28 +30,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, }; -#[derive(Debug)] -pub struct AluNativeAdapterChip { - pub air: AluNativeAdapterAir, - _marker: PhantomData, -} - -impl AluNativeAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: AluNativeAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} - #[repr(C)] #[derive(AlignedBorrow)] pub struct AluNativeAdapterCols { @@ -93,6 +74,8 @@ impl VmAdapterAir for AluNativeAdapterAir { let native_as = AB::Expr::from_canonical_u32(AS::Native as u32); + // TODO: we assume address space is either 0 or 4, should we add a + // constraint for that? self.memory_bridge .read_or_immediate( MemoryAddress::new(cols.e_as, cols.b_pointer), @@ -144,88 +127,131 @@ impl VmAdapterAir for AluNativeAdapterAir { } } -impl VmAdapterChip for AluNativeAdapterChip { - type ReadRecord = NativeReadRecord; - type WriteRecord = NativeWriteRecord; - type Air = AluNativeAdapterAir; - type Interface = BasicAdapterInterface, 2, 1, 1, 1>; +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct AluNativeAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { b, c, e, f, .. } = *instruction; - - let reads = vec![memory.read::<1>(e, b), memory.read::<1>(f, c)]; - let i_reads: [_; 2] = std::array::from_fn(|i| reads[i].1); - - Ok(( - i_reads, - Self::ReadRecord { - reads: reads.try_into().unwrap(), - }, - )) + pub a_ptr: F, + pub b: F, + pub c: F, + + // Will set prev_timestamp to `u32::MAX` if the read is an immediate + pub reads_aux: [MemoryReadAuxRecord; 2], + pub write_aux: MemoryWriteAuxRecord, +} + +#[derive(derive_new::new, Clone, Copy)] +pub struct AluNativeAdapterExecutor; + +#[derive(derive_new::new)] +pub struct AluNativeAdapterFiller; + +impl AdapterTraceExecutor for AluNativeAdapterExecutor { + const WIDTH: usize = size_of::>(); + type ReadData = [F; 2]; + type WriteData = [F; 1]; + type RecordMut<'a> = &'a mut AluNativeAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; } - fn postprocess( - &mut self, - memory: &mut MemoryController, - _instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, .. } = *_instruction; - let writes = vec![memory.write( - F::from_canonical_u32(AS::Native as u32), - a, - output.writes[0], - )]; - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { - from_state, - writes: writes.try_into().unwrap(), - }, - )) + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, + instruction: &Instruction, + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + let &Instruction { b, c, e, f, .. } = instruction; + + record.b = b; + let rs1 = tracing_read_or_imm_native(memory, e, b, &mut record.reads_aux[0].prev_timestamp); + record.c = c; + let rs2 = tracing_read_or_imm_native(memory, f, c, &mut record.reads_aux[1].prev_timestamp); + [rs1, rs2] } - fn generate_trace_row( + #[inline(always)] + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, ) { - let row_slice: &mut AluNativeAdapterCols<_> = row_slice.borrow_mut(); - let aux_cols_factory = memory.aux_cols_factory(); + let &Instruction { a, .. } = instruction; - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); + record.a_ptr = a; + tracing_write_native( + memory, + a.as_canonical_u32(), + data, + &mut record.write_aux.prev_timestamp, + &mut record.write_aux.prev_data, + ); + } +} - row_slice.a_pointer = memory.record_by_id(write_record.writes[0].0).pointer; - row_slice.b_pointer = memory.record_by_id(read_record.reads[0].0).pointer; - row_slice.c_pointer = memory.record_by_id(read_record.reads[1].0).pointer; - row_slice.e_as = memory.record_by_id(read_record.reads[0].0).address_space; - row_slice.f_as = memory.record_by_id(read_record.reads[1].0).address_space; +impl AdapterTraceFiller for AluNativeAdapterFiller { + const WIDTH: usize = size_of::>(); - for (i, x) in read_record.reads.iter().enumerate() { - let read = memory.record_by_id(x.0); - aux_cols_factory.generate_read_or_immediate_aux(read, &mut row_slice.reads_aux[i]); + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &AluNativeAdapterRecord = + unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_row: &mut AluNativeAdapterCols = adapter_row.borrow_mut(); + + // Writing in reverse order to avoid overwriting the `record` + adapter_row + .write_aux + .set_prev_data(record.write_aux.prev_data); + mem_helper.fill( + record.write_aux.prev_timestamp, + record.from_timestamp + 2, + adapter_row.write_aux.as_mut(), + ); + + let native_as = F::from_canonical_u32(AS::Native as u32); + for ((i, read_record), read_cols) in record + .reads_aux + .iter() + .enumerate() + .zip(adapter_row.reads_aux.iter_mut()) + .rev() + { + let as_col = if i == 0 { + &mut adapter_row.e_as + } else { + &mut adapter_row.f_as + }; + // previous timestamp is u32::MAX if the read is an immediate + if read_record.prev_timestamp == u32::MAX { + read_cols.is_zero_aux = F::ZERO; + read_cols.is_immediate = F::ONE; + mem_helper.fill(0, record.from_timestamp + i as u32, read_cols.as_mut()); + *as_col = F::ZERO; + } else { + read_cols.is_zero_aux = native_as.inverse(); + read_cols.is_immediate = F::ZERO; + mem_helper.fill( + read_record.prev_timestamp, + record.from_timestamp + i as u32, + read_cols.as_mut(), + ); + *as_col = native_as; + } } - let write = memory.record_by_id(write_record.writes[0].0); - aux_cols_factory.generate_write_aux(write, &mut row_slice.write_aux); - } + adapter_row.c_pointer = record.c; + adapter_row.b_pointer = record.b; + adapter_row.a_pointer = record.a_ptr; - fn air(&self) -> &Self::Air { - &self.air + adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/native/circuit/src/adapters/branch_native_adapter.rs b/extensions/native/circuit/src/adapters/branch_native_adapter.rs index 7d3e97a6bf..aa0c9c5259 100644 --- a/extensions/native/circuit/src/adapters/branch_native_adapter.rs +++ b/extensions/native/circuit/src/adapters/branch_native_adapter.rs @@ -1,23 +1,23 @@ use std::{ borrow::{Borrow, BorrowMut}, - marker::PhantomData, + mem::size_of, }; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, ImmInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller, + BasicAdapterInterface, ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir, }, system::{ memory::{ - offline_checker::{MemoryBridge, MemoryReadOrImmediateAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, + offline_checker::{MemoryBridge, MemoryReadAuxRecord, MemoryReadOrImmediateAuxCols}, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, - native_adapter::NativeReadRecord, - program::ProgramBus, + native_adapter::util::tracing_read_or_imm_native, }, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; use openvm_native_compiler::conversion::AS; @@ -27,37 +27,15 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, }; -#[derive(Debug)] -pub struct BranchNativeAdapterChip { - pub air: BranchNativeAdapterAir, - _marker: PhantomData, -} - -impl BranchNativeAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: BranchNativeAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} - #[repr(C)] -#[derive(AlignedBorrow)] +#[derive(AlignedBorrow, Debug)] pub struct BranchNativeAdapterReadCols { pub address: MemoryAddress, pub read_aux: MemoryReadOrImmediateAuxCols, } #[repr(C)] -#[derive(AlignedBorrow)] +#[derive(AlignedBorrow, Debug)] pub struct BranchNativeAdapterCols { pub from_state: ExecutionState, pub reads_aux: [BranchNativeAdapterReadCols; 2], @@ -145,71 +123,110 @@ impl VmAdapterAir for BranchNativeAdapterAir { } } -impl VmAdapterChip for BranchNativeAdapterChip { - type ReadRecord = NativeReadRecord; - type WriteRecord = ExecutionState; - type Air = BranchNativeAdapterAir; - type Interface = BasicAdapterInterface, 2, 0, 1, 1>; +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct BranchNativeAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, + + pub ptrs: [F; 2], + // Will set prev_timestamp to `u32::MAX` if the read is an immediate + pub reads_aux: [MemoryReadAuxRecord; 2], +} - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { a, b, d, e, .. } = *instruction; - - let reads = vec![memory.read::<1>(d, a), memory.read::<1>(e, b)]; - let i_reads: [_; 2] = std::array::from_fn(|i| reads[i].1); - - Ok(( - i_reads, - Self::ReadRecord { - reads: reads.try_into().unwrap(), - }, - )) +#[derive(derive_new::new, Clone, Copy)] +pub struct BranchNativeAdapterExecutor; + +#[derive(derive_new::new)] +pub struct BranchNativeAdapterFiller; + +impl AdapterTraceExecutor for BranchNativeAdapterExecutor +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = [F; 2]; + type WriteData = (); + type RecordMut<'a> = &'a mut BranchNativeAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; } - fn postprocess( - &mut self, - memory: &mut MemoryController, - _instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - from_state, - )) + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, + instruction: &Instruction, + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + let &Instruction { a, b, d, e, .. } = instruction; + + record.ptrs[0] = a; + let rs1 = tracing_read_or_imm_native(memory, d, a, &mut record.reads_aux[0].prev_timestamp); + record.ptrs[1] = b; + let rs2 = tracing_read_or_imm_native(memory, e, b, &mut record.reads_aux[1].prev_timestamp); + [rs1, rs2] } - fn generate_trace_row( + #[inline(always)] + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + _memory: &mut TracingMemory, + _instruction: &Instruction, + _data: Self::WriteData, + _record: &mut Self::RecordMut<'_>, ) { - let row_slice: &mut BranchNativeAdapterCols<_> = row_slice.borrow_mut(); - let aux_cols_factory = memory.aux_cols_factory(); - - row_slice.from_state = write_record.map(F::from_canonical_u32); - for (i, x) in read_record.reads.iter().enumerate() { - let read = memory.record_by_id(x.0); + // This adapter doesn't write anything + } +} - row_slice.reads_aux[i].address = MemoryAddress::new(read.address_space, read.pointer); - aux_cols_factory - .generate_read_or_immediate_aux(read, &mut row_slice.reads_aux[i].read_aux); +impl AdapterTraceFiller for BranchNativeAdapterFiller { + const WIDTH: usize = size_of::>(); + + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &BranchNativeAdapterRecord = + unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_row: &mut BranchNativeAdapterCols = adapter_row.borrow_mut(); + + // Writing in reverse order to avoid overwriting the `record` + + let native_as = F::from_canonical_u32(AS::Native as u32); + for ((i, read_record), read_cols) in record + .reads_aux + .iter() + .enumerate() + .zip(adapter_row.reads_aux.iter_mut()) + .rev() + { + // previous timestamp is u32::MAX if the read is an immediate + if read_record.prev_timestamp == u32::MAX { + read_cols.read_aux.is_zero_aux = F::ZERO; + read_cols.read_aux.is_immediate = F::ONE; + mem_helper.fill( + 0, + record.from_timestamp + i as u32, + read_cols.read_aux.as_mut(), + ); + read_cols.address.pointer = record.ptrs[i]; + read_cols.address.address_space = F::ZERO; + } else { + read_cols.read_aux.is_zero_aux = native_as.inverse(); + read_cols.read_aux.is_immediate = F::ZERO; + mem_helper.fill( + read_record.prev_timestamp, + record.from_timestamp + i as u32, + read_cols.read_aux.as_mut(), + ); + read_cols.address.pointer = record.ptrs[i]; + read_cols.address.address_space = native_as; + } } - } - fn air(&self) -> &Self::Air { - &self.air + adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/native/circuit/src/adapters/convert_adapter.rs b/extensions/native/circuit/src/adapters/convert_adapter.rs index cac6d91bac..9c76c73b59 100644 --- a/extensions/native/circuit/src/adapters/convert_adapter.rs +++ b/extensions/native/circuit/src/adapters/convert_adapter.rs @@ -1,71 +1,37 @@ use std::{ borrow::{Borrow, BorrowMut}, - marker::PhantomData, + mem::size_of, }; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller, + BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir, }, system::{ memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteBytesAuxRecord, + }, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, - program::ProgramBus, + native_adapter::util::tracing_read_native, }, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_MEMORY_AS, +}; use openvm_native_compiler::conversion::AS; +use openvm_rv32im_circuit::adapters::tracing_write; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct VectorReadRecord { - #[serde(with = "BigArray")] - pub reads: [RecordId; NUM_READS], -} - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct VectorWriteRecord { - pub from_state: ExecutionState, - pub writes: [RecordId; 1], -} - -#[allow(dead_code)] -#[derive(Debug)] -pub struct ConvertAdapterChip { - pub air: ConvertAdapterAir, - _marker: PhantomData, -} - -impl - ConvertAdapterChip -{ - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: ConvertAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} #[repr(C)] #[derive(AlignedBorrow)] @@ -155,74 +121,112 @@ impl Vm } } -impl VmAdapterChip - for ConvertAdapterChip -{ - type ReadRecord = VectorReadRecord<1, READ_SIZE>; - type WriteRecord = VectorWriteRecord; - type Air = ConvertAdapterAir; - type Interface = BasicAdapterInterface, 1, 1, READ_SIZE, WRITE_SIZE>; - - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { b, e, .. } = *instruction; +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct ConvertAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, + + pub a_ptr: F, + pub b_ptr: F, + + pub read_aux: MemoryReadAuxRecord, + pub write_aux: MemoryWriteBytesAuxRecord, +} + +#[derive(derive_new::new, Clone, Copy)] +pub struct ConvertAdapterExecutor; - let y_val = memory.read::(e, b); +#[derive(derive_new::new)] +pub struct ConvertAdapterFiller; - Ok(([y_val.1], Self::ReadRecord { reads: [y_val.0] })) +impl AdapterTraceExecutor + for ConvertAdapterExecutor +{ + const WIDTH: usize = size_of::>(); + type ReadData = [F; READ_SIZE]; + type WriteData = [u8; WRITE_SIZE]; + type RecordMut<'a> = &'a mut ConvertAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, d, .. } = *instruction; - let (write_id, _) = memory.write::(d, a, output.writes[0]); - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { - from_state, - writes: [write_id], - }, - )) + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + let &Instruction { b, e, .. } = instruction; + debug_assert_eq!(e.as_canonical_u32(), AS::Native as u32); + + record.b_ptr = b; + + tracing_read_native( + memory, + b.as_canonical_u32(), + &mut record.read_aux.prev_timestamp, + ) } - fn generate_trace_row( + #[inline(always)] + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, ) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut ConvertAdapterCols<_, READ_SIZE, WRITE_SIZE> = row_slice.borrow_mut(); - - let read = memory.record_by_id(read_record.reads[0]); - let write = memory.record_by_id(write_record.writes[0]); - - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - row_slice.a_pointer = write.pointer; - row_slice.b_pointer = read.pointer; - - aux_cols_factory.generate_read_aux(read, &mut row_slice.reads_aux[0]); - aux_cols_factory.generate_write_aux(write, &mut row_slice.writes_aux[0]); + let &Instruction { a, d, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_MEMORY_AS); + + record.a_ptr = a; + tracing_write( + memory, + RV32_MEMORY_AS, + a.as_canonical_u32(), + data, + &mut record.write_aux.prev_timestamp, + &mut record.write_aux.prev_data, + ); } +} - fn air(&self) -> &Self::Air { - &self.air +impl AdapterTraceFiller + for ConvertAdapterFiller +{ + const WIDTH: usize = size_of::>(); + + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut row_slice: &mut [F]) { + let record: &ConvertAdapterRecord = + unsafe { get_record_from_slice(&mut row_slice, ()) }; + let adapter_row: &mut ConvertAdapterCols = row_slice.borrow_mut(); + + // Writing in reverse order to avoid overwriting the `record` + mem_helper.fill( + record.read_aux.prev_timestamp, + record.from_timestamp, + adapter_row.reads_aux[0].as_mut(), + ); + + adapter_row.writes_aux[0] + .set_prev_data(record.write_aux.prev_data.map(F::from_canonical_u8)); + mem_helper.fill( + record.write_aux.prev_timestamp, + record.from_timestamp + 1, + adapter_row.writes_aux[0].as_mut(), + ); + + adapter_row.b_pointer = record.b_ptr; + adapter_row.a_pointer = record.a_ptr; + + adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/native/circuit/src/adapters/loadstore_native_adapter.rs b/extensions/native/circuit/src/adapters/loadstore_native_adapter.rs index 4bcf96d195..d33d74972e 100644 --- a/extensions/native/circuit/src/adapters/loadstore_native_adapter.rs +++ b/extensions/native/circuit/src/adapters/loadstore_native_adapter.rs @@ -5,19 +5,24 @@ use std::{ use openvm_circuit::{ arch::{ - instructions::LocalOpcode, AdapterAirContext, AdapterRuntimeContext, ExecutionBridge, - ExecutionBus, ExecutionState, Result, VmAdapterAir, VmAdapterChip, VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller, + ExecutionBridge, ExecutionState, VmAdapterAir, VmAdapterInterface, }, system::{ memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteAuxRecord, + }, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, - program::ProgramBus, + native_adapter::util::{tracing_read_native, tracing_write_native}, }, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_native_compiler::{ conversion::AS, NativeLoadStoreOpcode::{self, *}, @@ -27,7 +32,6 @@ use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; pub struct NativeLoadStoreInstruction { pub is_valid: T, @@ -48,55 +52,6 @@ impl VmAdapterInterface type ProcessedInstruction = NativeLoadStoreInstruction; } -#[derive(Debug)] -pub struct NativeLoadStoreAdapterChip { - pub air: NativeLoadStoreAdapterAir, - offset: usize, - _marker: PhantomData, -} - -impl NativeLoadStoreAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - offset: usize, - ) -> Self { - Self { - air: NativeLoadStoreAdapterAir { - memory_bridge, - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - }, - offset, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct NativeLoadStoreReadRecord { - pub pointer_read: RecordId, - pub data_read: Option, - pub write_as: F, - pub write_ptr: F, - - pub a: F, - pub b: F, - pub c: F, - pub d: F, - pub e: F, -} - -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct NativeLoadStoreWriteRecord { - pub from_state: ExecutionState, - pub write_id: RecordId, -} - #[repr(C)] #[derive(Clone, Debug, AlignedBorrow)] pub struct NativeLoadStoreAdapterCols { @@ -214,23 +169,52 @@ impl VmAdapterAir } } -impl VmAdapterChip - for NativeLoadStoreAdapterChip +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct NativeLoadStoreAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, + pub a: F, + pub b: F, + pub c: F, + pub write_ptr: F, + + pub ptr_read: MemoryReadAuxRecord, + // Will set `prev_timestamp` to u32::MAX if `HINT_STOREW` + pub data_read: MemoryReadAuxRecord, + pub data_write: MemoryWriteAuxRecord, +} + +#[derive(derive_new::new, Clone, Copy)] +pub struct NativeLoadStoreAdapterExecutor { + offset: usize, +} + +#[derive(derive_new::new)] +pub struct NativeLoadStoreAdapterFiller; + +impl AdapterTraceExecutor + for NativeLoadStoreAdapterExecutor { - type ReadRecord = NativeLoadStoreReadRecord; - type WriteRecord = NativeLoadStoreWriteRecord; - type Air = NativeLoadStoreAdapterAir; - type Interface = NativeLoadStoreAdapterInterface; - - fn preprocess( - &mut self, - memory: &mut MemoryController, + const WIDTH: usize = std::mem::size_of::>(); + type ReadData = (F, [F; NUM_CELLS]); + type WriteData = [F; NUM_CELLS]; + type RecordMut<'a> = &'a mut NativeLoadStoreAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp(); + } + + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + let &Instruction { opcode, a, b, @@ -238,100 +222,116 @@ impl VmAdapterChip d, e, .. - } = *instruction; + } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), AS::Native as u32); + debug_assert_eq!(e.as_canonical_u32(), AS::Native as u32); + let local_opcode = NativeLoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - let read_as = d; - let read_ptr = c; - let read_cell = memory.read_cell(read_as, read_ptr); + record.a = a; + record.b = b; + record.c = c; - let (data_read_as, data_write_as) = { - match local_opcode { - LOADW => (e, d), - STOREW | HINT_STOREW => (d, e), + // Read the pointer value from memory + let [read_cell] = tracing_read_native::( + memory, + c.as_canonical_u32(), + &mut record.ptr_read.prev_timestamp, + ); + + let data_read_ptr = match local_opcode { + LOADW => read_cell + record.b, + STOREW | HINT_STOREW => record.a, + } + .as_canonical_u32(); + + // It's easier to do this here than in `write` + match local_opcode { + LOADW => record.write_ptr = record.a, + STOREW | HINT_STOREW => record.write_ptr = read_cell + record.b, + } + + // Read data based on opcode + let data_read: [F; NUM_CELLS] = match local_opcode { + HINT_STOREW => { + record.data_read.prev_timestamp = u32::MAX; + [F::ZERO; NUM_CELLS] } - }; - let (data_read_ptr, data_write_ptr) = { - match local_opcode { - LOADW => (read_cell.1 + b, a), - STOREW | HINT_STOREW => (a, read_cell.1 + b), + LOADW | STOREW => { + tracing_read_native(memory, data_read_ptr, &mut record.data_read.prev_timestamp) } }; - let data_read = match local_opcode { - HINT_STOREW => None, - LOADW | STOREW => Some(memory.read::(data_read_as, data_read_ptr)), - }; - let record = NativeLoadStoreReadRecord { - pointer_read: read_cell.0, - data_read: data_read.map(|x| x.0), - write_as: data_write_as, - write_ptr: data_write_ptr, - a, - b, - c, - d, - e, - }; - - Ok(( - (read_cell.1, data_read.map_or([F::ZERO; NUM_CELLS], |x| x.1)), - record, - )) + (read_cell, data_read) } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn write( + &self, + memory: &mut TracingMemory, _instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let (write_id, _) = - memory.write::(read_record.write_as, read_record.write_ptr, output.writes); - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { - from_state: from_state.map(F::from_canonical_u32), - write_id, - }, - )) + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, + ) { + // Write data to memory + tracing_write_native( + memory, + record.write_ptr.as_canonical_u32(), + data, + &mut record.data_write.prev_timestamp, + &mut record.data_write.prev_data, + ); } +} - fn generate_trace_row( - &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, - ) { - let aux_cols_factory = memory.aux_cols_factory(); - let cols: &mut NativeLoadStoreAdapterCols<_, NUM_CELLS> = row_slice.borrow_mut(); - cols.from_state = write_record.from_state; - cols.a = read_record.a; - cols.b = read_record.b; - cols.c = read_record.c; - - let data_read = read_record.data_read.map(|read| memory.record_by_id(read)); - if let Some(data_read) = data_read { - aux_cols_factory.generate_read_aux(data_read, &mut cols.data_read_aux_cols); - } +impl AdapterTraceFiller + for NativeLoadStoreAdapterFiller +{ + const WIDTH: usize = size_of::>(); + + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &NativeLoadStoreAdapterRecord = + unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_row: &mut NativeLoadStoreAdapterCols = adapter_row.borrow_mut(); + + // Writing in reverse order to avoid overwriting the `record` + + let is_hint_storew = record.data_read.prev_timestamp == u32::MAX; + + adapter_row + .data_write_aux_cols + .set_prev_data(record.data_write.prev_data); + // Note, if `HINT_STOREW` we didn't do a data read and we didn't update the timestamp + mem_helper.fill( + record.data_write.prev_timestamp, + record.from_timestamp + 2 - is_hint_storew as u32, + adapter_row.data_write_aux_cols.as_mut(), + ); - let write = memory.record_by_id(write_record.write_id); - cols.data_write_pointer = write.pointer; + if !is_hint_storew { + mem_helper.fill( + record.data_read.prev_timestamp, + record.from_timestamp + 1, + adapter_row.data_read_aux_cols.as_mut(), + ); + } else { + mem_helper.fill_zero(adapter_row.data_read_aux_cols.as_mut()); + } - aux_cols_factory.generate_read_aux( - memory.record_by_id(read_record.pointer_read), - &mut cols.pointer_read_aux_cols, + mem_helper.fill( + record.ptr_read.prev_timestamp, + record.from_timestamp, + adapter_row.pointer_read_aux_cols.as_mut(), ); - aux_cols_factory.generate_write_aux(write, &mut cols.data_write_aux_cols); - } - fn air(&self) -> &Self::Air { - &self.air + adapter_row.data_write_pointer = record.write_ptr; + adapter_row.c = record.c; + adapter_row.b = record.b; + adapter_row.a = record.a; + + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); + adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); } } diff --git a/extensions/native/circuit/src/adapters/mod.rs b/extensions/native/circuit/src/adapters/mod.rs index c5cd3b9422..308a0705a3 100644 --- a/extensions/native/circuit/src/adapters/mod.rs +++ b/extensions/native/circuit/src/adapters/mod.rs @@ -6,3 +6,9 @@ pub mod convert_adapter; pub mod loadstore_native_adapter; // 2 reads, 1 write, read size = write size = N, no imm support, read/write to address space d pub mod native_vectorized_adapter; + +pub use alu_native_adapter::*; +pub use branch_native_adapter::*; +pub use convert_adapter::*; +pub use loadstore_native_adapter::*; +pub use native_vectorized_adapter::*; diff --git a/extensions/native/circuit/src/adapters/native_vectorized_adapter.rs b/extensions/native/circuit/src/adapters/native_vectorized_adapter.rs index c151197297..6545e8db39 100644 --- a/extensions/native/circuit/src/adapters/native_vectorized_adapter.rs +++ b/extensions/native/circuit/src/adapters/native_vectorized_adapter.rs @@ -1,22 +1,26 @@ use std::{ borrow::{Borrow, BorrowMut}, - marker::PhantomData, + mem::size_of, }; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller, + BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir, }, system::{ memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteAuxRecord, + }, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, - program::ProgramBus, + native_adapter::util::{tracing_read_native, tracing_write_native}, }, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; use openvm_native_compiler::conversion::AS; @@ -25,44 +29,6 @@ use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; - -#[allow(dead_code)] -#[derive(Debug)] -pub struct NativeVectorizedAdapterChip { - pub air: NativeVectorizedAdapterAir, - _marker: PhantomData, -} - -impl NativeVectorizedAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: NativeVectorizedAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct NativeVectorizedReadRecord { - pub b: RecordId, - pub c: RecordId, -} - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct NativeVectorizedWriteRecord { - pub from_state: ExecutionState, - pub a: RecordId, -} #[repr(C)] #[derive(AlignedBorrow)] @@ -156,80 +122,124 @@ impl VmAdapterAir for NativeVectoriz } } -impl VmAdapterChip for NativeVectorizedAdapterChip { - type ReadRecord = NativeVectorizedReadRecord; - type WriteRecord = NativeVectorizedWriteRecord; - type Air = NativeVectorizedAdapterAir; - type Interface = BasicAdapterInterface, 2, 1, N, N>; +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct NativeVectorizedAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, + pub a_ptr: F, + pub b_ptr: F, + pub c_ptr: F, + pub reads_aux: [MemoryReadAuxRecord; 2], + pub write_aux: MemoryWriteAuxRecord, +} - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { b, c, d, e, .. } = *instruction; - - let y_val = memory.read::(d, b); - let z_val = memory.read::(e, c); - - Ok(( - [y_val.1, z_val.1], - Self::ReadRecord { - b: y_val.0, - c: z_val.0, - }, - )) +#[derive(derive_new::new, Clone, Copy)] +pub struct NativeVectorizedAdapterExecutor; + +#[derive(derive_new::new)] +pub struct NativeVectorizedAdapterFiller; + +impl AdapterTraceExecutor + for NativeVectorizedAdapterExecutor +{ + const WIDTH: usize = size_of::>(); + type ReadData = [[F; N]; 2]; + type WriteData = [F; N]; + type RecordMut<'a> = &'a mut NativeVectorizedAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp(); } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, d, .. } = *instruction; - let (a_val, _) = memory.write(d, a, output.writes[0]); - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { - from_state, - a: a_val, - }, - )) + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + let &Instruction { b, c, d, e, .. } = instruction; + debug_assert_eq!(d.as_canonical_u32(), AS::Native as u32); + debug_assert_eq!(e.as_canonical_u32(), AS::Native as u32); + + record.b_ptr = b; + let b_val = tracing_read_native( + memory, + b.as_canonical_u32(), + &mut record.reads_aux[0].prev_timestamp, + ); + record.c_ptr = c; + let c_val = tracing_read_native( + memory, + c.as_canonical_u32(), + &mut record.reads_aux[1].prev_timestamp, + ); + + [b_val, c_val] } - fn generate_trace_row( + #[inline(always)] + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, ) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut NativeVectorizedAdapterCols<_, N> = row_slice.borrow_mut(); - - let b_record = memory.record_by_id(read_record.b); - let c_record = memory.record_by_id(read_record.c); - let a_record = memory.record_by_id(write_record.a); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - row_slice.a_pointer = a_record.pointer; - row_slice.b_pointer = b_record.pointer; - row_slice.c_pointer = c_record.pointer; - aux_cols_factory.generate_read_aux(b_record, &mut row_slice.reads_aux[0]); - aux_cols_factory.generate_read_aux(c_record, &mut row_slice.reads_aux[1]); - aux_cols_factory.generate_write_aux(a_record, &mut row_slice.writes_aux[0]); + let &Instruction { a, d, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), AS::Native as u32); + + record.a_ptr = a; + tracing_write_native( + memory, + a.as_canonical_u32(), + data, + &mut record.write_aux.prev_timestamp, + &mut record.write_aux.prev_data, + ); } +} - fn air(&self) -> &Self::Air { - &self.air +impl AdapterTraceFiller for NativeVectorizedAdapterFiller { + const WIDTH: usize = size_of::>(); + + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &NativeVectorizedAdapterRecord = + unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_row: &mut NativeVectorizedAdapterCols = adapter_row.borrow_mut(); + + // Writing in reverse order to avoid overwriting the `record` + adapter_row.writes_aux[0].set_prev_data(record.write_aux.prev_data); + mem_helper.fill( + record.write_aux.prev_timestamp, + record.from_timestamp + 2, + adapter_row.writes_aux[0].as_mut(), + ); + + adapter_row + .reads_aux + .iter_mut() + .enumerate() + .zip(record.reads_aux.iter()) + .rev() + .for_each(|((i, read_cols), read_record)| { + mem_helper.fill( + read_record.prev_timestamp, + record.from_timestamp + i as u32, + read_cols.as_mut(), + ); + }); + + adapter_row.c_pointer = record.c_ptr; + adapter_row.b_pointer = record.b_ptr; + adapter_row.a_pointer = record.a_ptr; + + adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/native/circuit/src/branch_eq/core.rs b/extensions/native/circuit/src/branch_eq/core.rs new file mode 100644 index 0000000000..1a60782e0f --- /dev/null +++ b/extensions/native/circuit/src/branch_eq/core.rs @@ -0,0 +1,320 @@ +use std::borrow::{Borrow, BorrowMut}; + +use openvm_circuit::{ + arch::*, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, + utils::{transmute_field_to_u32, transmute_u32_to_field}, +}; +use openvm_circuit_primitives::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_IMM_AS, LocalOpcode, NATIVE_AS, +}; +use openvm_native_compiler::NativeBranchEqualOpcode; +use openvm_rv32im_circuit::BranchEqualCoreCols; +use openvm_rv32im_transpiler::BranchEqualOpcode; +use openvm_stark_backend::p3_field::PrimeField32; + +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct NativeBranchEqualCoreRecord { + pub a: F, + pub b: F, + pub imm: F, + pub is_beq: bool, +} + +#[derive(derive_new::new, Clone, Copy)] +pub struct NativeBranchEqualExecutor { + adapter: A, + pub offset: usize, + pub pc_step: u32, +} + +#[derive(derive_new::new)] +pub struct NativeBranchEqualFiller { + adapter: A, +} + +impl PreflightExecutor for NativeBranchEqualExecutor +where + F: PrimeField32, + A: 'static + AdapterTraceExecutor, WriteData = ()>, + for<'buf> RA: RecordArena< + 'buf, + EmptyAdapterCoreLayout, + (A::RecordMut<'buf>, &'buf mut NativeBranchEqualCoreRecord), + >, +{ + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + NativeBranchEqualOpcode::from_usize(opcode - self.offset) + ) + } + + fn execute( + &mut self, + state: VmStateMut, + instruction: &Instruction, + ) -> Result<(), ExecutionError> { + let &Instruction { opcode, c: imm, .. } = instruction; + let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + [core_record.a, core_record.b] = self + .adapter + .read(state.memory, instruction, &mut adapter_record) + .into(); + + let cmp_result = core_record.a == core_record.b; + + core_record.imm = imm; + core_record.is_beq = + opcode.local_opcode_idx(self.offset) == BranchEqualOpcode::BEQ as usize; + + if cmp_result == core_record.is_beq { + *state.pc = (F::from_canonical_u32(*state.pc) + imm).as_canonical_u32(); + } else { + *state.pc = state.pc.wrapping_add(self.pc_step); + } + + Ok(()) + } +} + +impl TraceFiller for NativeBranchEqualFiller +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + let record: &NativeBranchEqualCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; + let core_row: &mut BranchEqualCoreCols = core_row.borrow_mut(); + let (cmp_result, diff_inv_val) = run_eq(record.is_beq, record.a, record.b); + + // Writing in reverse order to avoid overwriting the `record` + core_row.diff_inv_marker[0] = diff_inv_val; + + core_row.opcode_bne_flag = F::from_bool(!record.is_beq); + core_row.opcode_beq_flag = F::from_bool(record.is_beq); + + core_row.imm = record.imm; + core_row.cmp_result = F::from_bool(cmp_result); + + core_row.b = [record.b]; + core_row.a = [record.a]; + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct NativeBranchEqualPreCompute { + imm: isize, + a_or_imm: u32, + b_or_imm: u32, +} + +impl NativeBranchEqualExecutor { + #[inline(always)] + fn pre_compute_impl( + &self, + _pc: u32, + inst: &Instruction, + data: &mut NativeBranchEqualPreCompute, + ) -> Result<(bool, bool, bool), StaticProgramError> { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + let local_opcode = BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + let c = c.as_canonical_u32(); + let imm = if F::ORDER_U32 - c < c { + -((F::ORDER_U32 - c) as isize) + } else { + c as isize + }; + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + + let a_is_imm = d == RV32_IMM_AS; + let b_is_imm = e == RV32_IMM_AS; + + let a_or_imm = if a_is_imm { + transmute_field_to_u32(&a) + } else { + a.as_canonical_u32() + }; + let b_or_imm = if b_is_imm { + transmute_field_to_u32(&b) + } else { + b.as_canonical_u32() + }; + + *data = NativeBranchEqualPreCompute { + imm, + a_or_imm, + b_or_imm, + }; + + let is_bne = local_opcode == BranchEqualOpcode::BNE; + + Ok((a_is_imm, b_is_imm, is_bne)) + } +} + +impl Executor for NativeBranchEqualExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } + + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut NativeBranchEqualPreCompute = data.borrow_mut(); + + let (a_is_imm, b_is_imm, is_bne) = self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = match (a_is_imm, b_is_imm, is_bne) { + (true, true, true) => execute_e1_impl::<_, _, true, true, true>, + (true, true, false) => execute_e1_impl::<_, _, true, true, false>, + (true, false, true) => execute_e1_impl::<_, _, true, false, true>, + (true, false, false) => execute_e1_impl::<_, _, true, false, false>, + (false, true, true) => execute_e1_impl::<_, _, false, true, true>, + (false, true, false) => execute_e1_impl::<_, _, false, true, false>, + (false, false, true) => execute_e1_impl::<_, _, false, false, true>, + (false, false, false) => execute_e1_impl::<_, _, false, false, false>, + }; + + Ok(fn_ptr) + } +} + +impl MeteredExecutor for NativeBranchEqualExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn metered_pre_compute_size(&self) -> usize { + size_of::>() + } + + #[inline(always)] + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let (a_is_imm, b_is_imm, is_bne) = + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + let fn_ptr = match (a_is_imm, b_is_imm, is_bne) { + (true, true, true) => execute_e2_impl::<_, _, true, true, true>, + (true, true, false) => execute_e2_impl::<_, _, true, true, false>, + (true, false, true) => execute_e2_impl::<_, _, true, false, true>, + (true, false, false) => execute_e2_impl::<_, _, true, false, false>, + (false, true, true) => execute_e2_impl::<_, _, false, true, true>, + (false, true, false) => execute_e2_impl::<_, _, false, true, false>, + (false, false, true) => execute_e2_impl::<_, _, false, false, true>, + (false, false, false) => execute_e2_impl::<_, _, false, false, false>, + }; + + Ok(fn_ptr) + } +} + +unsafe fn execute_e1_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const A_IS_IMM: bool, + const B_IS_IMM: bool, + const IS_NE: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &NativeBranchEqualPreCompute = pre_compute.borrow(); + execute_e12_impl::<_, _, A_IS_IMM, B_IS_IMM, IS_NE>(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const A_IS_IMM: bool, + const B_IS_IMM: bool, + const IS_NE: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::<_, _, A_IS_IMM, B_IS_IMM, IS_NE>(&pre_compute.data, vm_state); +} + +#[inline(always)] +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const A_IS_IMM: bool, + const B_IS_IMM: bool, + const IS_NE: bool, +>( + pre_compute: &NativeBranchEqualPreCompute, + vm_state: &mut VmExecState, +) { + let rs1 = if A_IS_IMM { + transmute_u32_to_field(&pre_compute.a_or_imm) + } else { + vm_state.vm_read::(NATIVE_AS, pre_compute.a_or_imm)[0] + }; + let rs2 = if B_IS_IMM { + transmute_u32_to_field(&pre_compute.b_or_imm) + } else { + vm_state.vm_read::(NATIVE_AS, pre_compute.b_or_imm)[0] + }; + if (rs1 == rs2) ^ IS_NE { + vm_state.pc = (vm_state.pc as isize + pre_compute.imm) as u32; + } else { + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + } + vm_state.instret += 1; +} + +// Returns (cmp_result, diff_idx, x[diff_idx] - y[diff_idx]) +#[inline(always)] +pub(super) fn run_eq(is_beq: bool, x: F, y: F) -> (bool, F) +where + F: PrimeField32, +{ + if x != y { + return (!is_beq, (x - y).inverse()); + } + (is_beq, F::ZERO) +} diff --git a/extensions/native/circuit/src/branch_eq/mod.rs b/extensions/native/circuit/src/branch_eq/mod.rs index e1b566bb7f..9827384d67 100644 --- a/extensions/native/circuit/src/branch_eq/mod.rs +++ b/extensions/native/circuit/src/branch_eq/mod.rs @@ -1,8 +1,17 @@ use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; -use openvm_rv32im_circuit::{BranchEqualCoreAir, BranchEqualCoreChip}; +use openvm_rv32im_circuit::BranchEqualCoreAir; -use super::adapters::branch_native_adapter::{BranchNativeAdapterAir, BranchNativeAdapterChip}; +mod core; +pub use core::*; + +use crate::adapters::{ + BranchNativeAdapterAir, BranchNativeAdapterExecutor, BranchNativeAdapterFiller, +}; + +#[cfg(test)] +mod tests; pub type NativeBranchEqAir = VmAirWrapper>; +pub type NativeBranchEqExecutor = NativeBranchEqualExecutor; pub type NativeBranchEqChip = - VmChipWrapper, BranchEqualCoreChip<1>>; + VmChipWrapper>; diff --git a/extensions/native/circuit/src/branch_eq/tests.rs b/extensions/native/circuit/src/branch_eq/tests.rs new file mode 100644 index 0000000000..4a36045ed2 --- /dev/null +++ b/extensions/native/circuit/src/branch_eq/tests.rs @@ -0,0 +1,334 @@ +use std::borrow::BorrowMut; + +use openvm_circuit::arch::testing::{TestChipHarness, VmChipTestBuilder}; +use openvm_instructions::{ + instruction::Instruction, + program::{DEFAULT_PC_STEP, PC_BITS}, + utils::isize_to_field, + LocalOpcode, +}; +use openvm_native_compiler::NativeBranchEqualOpcode; +use openvm_rv32im_circuit::{ + adapters::RV_B_TYPE_IMM_BITS, BranchEqualCoreAir, BranchEqualCoreCols, +}; +use openvm_rv32im_transpiler::BranchEqualOpcode; +use openvm_stark_backend::{ + p3_air::BaseAir, + p3_field::{FieldAlgebra, PrimeField32}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, + utils::disable_debug_builder, + verifier::VerificationError, +}; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use rand::{rngs::StdRng, Rng}; +use test_case::test_case; + +use crate::{ + adapters::{BranchNativeAdapterAir, BranchNativeAdapterExecutor, BranchNativeAdapterFiller}, + branch_eq::{run_eq, NativeBranchEqAir, NativeBranchEqChip, NativeBranchEqExecutor}, + test_utils::write_native_or_imm, + NativeBranchEqualFiller, +}; + +type F = BabyBear; +const MAX_INS_CAPACITY: usize = 128; +const ABS_MAX_IMM: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); +type Harness = TestChipHarness>; + +fn create_test_chip(tester: &mut VmChipTestBuilder) -> Harness { + let air = NativeBranchEqAir::new( + BranchNativeAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + BranchEqualCoreAir::new(NativeBranchEqualOpcode::CLASS_OFFSET, DEFAULT_PC_STEP), + ); + let executor = NativeBranchEqExecutor::new( + BranchNativeAdapterExecutor, + NativeBranchEqualOpcode::CLASS_OFFSET, + DEFAULT_PC_STEP, + ); + let chip = NativeBranchEqChip::::new( + NativeBranchEqualFiller::new(BranchNativeAdapterFiller), + tester.memory_helper(), + ); + + Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY) +} + +#[allow(clippy::too_many_arguments)] +fn set_and_execute( + tester: &mut VmChipTestBuilder, + harness: &mut Harness, + rng: &mut StdRng, + opcode: NativeBranchEqualOpcode, + a: Option, + b: Option, + imm: Option, +) { + let a_val = a.unwrap_or(rng.gen()); + let b_val = b.unwrap_or(if rng.gen_bool(0.5) { a_val } else { rng.gen() }); + let imm = imm.unwrap_or(rng.gen_range((-ABS_MAX_IMM)..ABS_MAX_IMM)); + let (a, a_as) = write_native_or_imm(tester, rng, a_val, None); + let (b, b_as) = write_native_or_imm(tester, rng, b_val, None); + let initial_pc = rng.gen_range(imm.unsigned_abs()..(1 << (PC_BITS - 1)) - imm.unsigned_abs()); + + tester.execute_with_pc( + harness, + &Instruction::new( + opcode.global_opcode(), + a, + b, + isize_to_field::(imm as isize), + F::from_canonical_usize(a_as), + F::from_canonical_usize(b_as), + F::ZERO, + F::ZERO, + ), + initial_pc, + ); + + let cmp_result = run_eq(opcode.0 == BranchEqualOpcode::BEQ, a_val, b_val).0; + let from_pc = tester.execution.last_from_pc().as_canonical_u32() as i32; + let to_pc = tester.execution.last_to_pc().as_canonical_u32() as i32; + let pc_inc = if cmp_result { + imm + } else { + DEFAULT_PC_STEP as i32 + }; + + assert_eq!(to_pc, from_pc + pc_inc); +} + +////////////////////////////////////////////////////////////////////////////////////// +// POSITIVE TESTS +// +// Randomly generate computations and execute, ensuring that the generated trace +// passes all constraints. +////////////////////////////////////////////////////////////////////////////////////// + +#[test_case(BranchEqualOpcode::BEQ, 100)] +#[test_case(BranchEqualOpcode::BNE, 100)] +fn rand_rv32_branch_eq_test(opcode: BranchEqualOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut harness = create_test_chip(&mut tester); + let opcode = NativeBranchEqualOpcode(opcode); + for _ in 0..num_ops { + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + opcode, + None, + None, + None, + ); + } + + let tester = tester.build().load(harness).finalize(); + tester.simple_test().expect("Verification failed"); +} + +////////////////////////////////////////////////////////////////////////////////////// +// NEGATIVE TESTS +// +// Given a fake trace of a single operation, setup a chip and run the test. We replace +// part of the trace and check that the chip throws the expected error. +////////////////////////////////////////////////////////////////////////////////////// + +#[allow(clippy::too_many_arguments)] +fn run_negative_branch_eq_test( + opcode: BranchEqualOpcode, + a: F, + b: F, + prank_cmp_result: Option, + prank_diff_inv_marker: Option, + error: VerificationError, +) { + let imm = 16i32; + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut harness = create_test_chip(&mut tester); + + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + NativeBranchEqualOpcode(opcode), + Some(a), + Some(b), + Some(imm), + ); + + let adapter_width = BaseAir::::width(&harness.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { + let mut values = trace.row_slice(0).to_vec(); + let cols: &mut BranchEqualCoreCols = + values.split_at_mut(adapter_width).1.borrow_mut(); + if let Some(cmp_result) = prank_cmp_result { + cols.cmp_result = F::from_bool(cmp_result); + } + if let Some(diff_inv_marker) = prank_diff_inv_marker { + cols.diff_inv_marker = [diff_inv_marker]; + } + *trace = RowMajorMatrix::new(values, trace.width()); + }; + + disable_debug_builder(); + let tester = tester + .build() + .load_and_prank_trace(harness, modify_trace) + .finalize(); + tester.simple_test_with_expected_error(error); +} + +#[test] +fn rv32_beq_wrong_cmp_negative_test() { + run_negative_branch_eq_test( + BranchEqualOpcode::BEQ, + F::from_canonical_u32(7 << 16), + F::from_canonical_u32(7 << 24), + Some(true), + None, + VerificationError::OodEvaluationMismatch, + ); + + run_negative_branch_eq_test( + BranchEqualOpcode::BEQ, + F::from_canonical_u32(7 << 16), + F::from_canonical_u32(7 << 16), + Some(false), + None, + VerificationError::OodEvaluationMismatch, + ); +} + +#[test] +fn rv32_beq_zero_inv_marker_negative_test() { + run_negative_branch_eq_test( + BranchEqualOpcode::BEQ, + F::from_canonical_u32(7 << 16), + F::from_canonical_u32(7 << 24), + Some(true), + Some(F::ZERO), + VerificationError::OodEvaluationMismatch, + ); +} + +#[test] +fn rv32_beq_invalid_inv_marker_negative_test() { + run_negative_branch_eq_test( + BranchEqualOpcode::BEQ, + F::from_canonical_u32(7 << 16), + F::from_canonical_u32(7 << 24), + Some(false), + Some(F::from_canonical_u32(1 << 16)), + VerificationError::OodEvaluationMismatch, + ); +} + +#[test] +fn rv32_bne_wrong_cmp_negative_test() { + run_negative_branch_eq_test( + BranchEqualOpcode::BNE, + F::from_canonical_u32(7 << 16), + F::from_canonical_u32(7 << 24), + Some(false), + None, + VerificationError::OodEvaluationMismatch, + ); + + run_negative_branch_eq_test( + BranchEqualOpcode::BNE, + F::from_canonical_u32(7 << 16), + F::from_canonical_u32(7 << 16), + Some(true), + None, + VerificationError::OodEvaluationMismatch, + ); +} + +#[test] +fn rv32_bne_zero_inv_marker_negative_test() { + run_negative_branch_eq_test( + BranchEqualOpcode::BNE, + F::from_canonical_u32(7 << 16), + F::from_canonical_u32(7 << 24), + Some(false), + Some(F::ZERO), + VerificationError::OodEvaluationMismatch, + ); +} + +#[test] +fn rv32_bne_invalid_inv_marker_negative_test() { + run_negative_branch_eq_test( + BranchEqualOpcode::BNE, + F::from_canonical_u32(7 << 16), + F::from_canonical_u32(7 << 24), + Some(true), + Some(F::from_canonical_u32(1 << 16)), + VerificationError::OodEvaluationMismatch, + ); +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// SANITY TESTS +/// +/// Ensure that solve functions produce the correct results. +/////////////////////////////////////////////////////////////////////////////////////// + +#[test] +fn execute_roundtrip_sanity_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut harness = create_test_chip(&mut tester); + + let x = F::from_canonical_u32(u32::from_le_bytes([19, 4, 179, 60])); + let y = F::from_canonical_u32(u32::from_le_bytes([19, 32, 180, 60])); + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + NativeBranchEqualOpcode(BranchEqualOpcode::BEQ), + Some(x), + Some(y), + Some(8), + ); + + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + NativeBranchEqualOpcode(BranchEqualOpcode::BNE), + Some(x), + Some(y), + Some(8), + ); +} + +#[test] +fn run_eq_sanity_test() { + let x = F::from_canonical_u32(u32::from_le_bytes([19, 4, 17, 60])); + let (cmp_result, diff_val) = run_eq(true, x, x); + assert!(cmp_result); + assert_eq!(diff_val, F::ZERO); + + let (cmp_result, diff_val) = run_eq(false, x, x); + assert!(!cmp_result); + assert_eq!(diff_val, F::ZERO); +} + +#[test] +fn run_ne_sanity_test() { + let x = F::from_canonical_u32(u32::from_le_bytes([19, 4, 17, 60])); + let y = F::from_canonical_u32(u32::from_le_bytes([19, 32, 18, 60])); + let (cmp_result, diff_val) = run_eq(true, x, y); + assert!(!cmp_result); + assert_eq!(diff_val * (x - y), F::ONE); + + let (cmp_result, diff_val) = run_eq(false, x, y); + assert!(cmp_result); + assert_eq!(diff_val * (x - y), F::ONE); +} diff --git a/extensions/native/circuit/src/castf/core.rs b/extensions/native/circuit/src/castf/core.rs index 664767e35e..e48c1f06e1 100644 --- a/extensions/native/circuit/src/castf/core.rs +++ b/extensions/native/circuit/src/castf/core.rs @@ -1,15 +1,21 @@ use std::borrow::{Borrow, BorrowMut}; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::*, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +use openvm_circuit_primitives::{ + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_native_compiler::CastfOpcode; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_MEMORY_AS, LocalOpcode, +}; +use openvm_native_compiler::{conversion::AS, CastfOpcode}; use openvm_rv32im_circuit::adapters::RV32_REGISTER_NUM_LIMBS; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -17,7 +23,8 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; + +use crate::CASTF_MAX_BITS; // LIMB_BITS is the size of the limbs in bits. pub(crate) const LIMB_BITS: usize = 8; @@ -32,7 +39,7 @@ pub struct CastFCoreCols { pub is_valid: T, } -#[derive(Copy, Clone, Debug)] +#[derive(derive_new::new, Copy, Clone, Debug)] pub struct CastFCoreAir { pub bus: VariableRangeCheckerBus, /* to communicate with the range checker that checks that * all limbs are < 2^LIMB_BITS */ @@ -105,97 +112,216 @@ where } #[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct CastFRecord { - pub in_val: F, - pub out_val: [u32; RV32_REGISTER_NUM_LIMBS], +#[derive(AlignedBytesBorrow, Debug)] +pub struct CastFCoreRecord { + pub val: u32, } -pub struct CastFCoreChip { - pub air: CastFCoreAir, - pub range_checker_chip: SharedVariableRangeCheckerChip, +#[derive(derive_new::new, Clone, Copy)] +pub struct CastFCoreExecutor { + adapter: A, } -impl CastFCoreChip { - pub fn new(range_checker_chip: SharedVariableRangeCheckerChip) -> Self { - Self { - air: CastFCoreAir { - bus: range_checker_chip.bus(), - }, - range_checker_chip, - } - } +#[derive(derive_new::new)] +pub struct CastFCoreFiller { + adapter: A, + pub range_checker_chip: SharedVariableRangeCheckerChip, } -impl> VmCoreChip for CastFCoreChip +impl PreflightExecutor for CastFCoreExecutor where - I::Reads: Into<[[F; 1]; 1]>, - I::Writes: From<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + AdapterTraceExecutor, + for<'buf> RA: RecordArena< + 'buf, + EmptyAdapterCoreLayout, + (A::RecordMut<'buf>, &'buf mut CastFCoreRecord), + >, { - type Record = CastFRecord; - type Air = CastFCoreAir; + fn get_opcode_name(&self, _opcode: usize) -> String { + format!("{:?}", CastfOpcode::CASTF) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let Instruction { opcode, .. } = instruction; + ) -> Result<(), ExecutionError> { + let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new()); - assert_eq!( - opcode.local_opcode_idx(CastfOpcode::CLASS_OFFSET), - CastfOpcode::CASTF as usize - ); + A::start(*state.pc, state.memory, &mut adapter_record); - let y = reads.into()[0][0]; - let x = CastF::solve(y.as_canonical_u32()); + core_record.val = self + .adapter + .read(state.memory, instruction, &mut adapter_record)[0] + .as_canonical_u32(); - let output = AdapterRuntimeContext { - to_pc: None, - writes: [x.map(F::from_canonical_u32)].into(), - }; + let x = run_castf(core_record.val); - let record = CastFRecord { - in_val: y, - out_val: x, - }; + self.adapter + .write(state.memory, instruction, x, &mut adapter_record); - Ok((output, record)) - } + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - fn get_opcode_name(&self, _opcode: usize) -> String { - format!("{:?}", CastfOpcode::CASTF) + Ok(()) } +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - for (i, limb) in record.out_val.iter().enumerate() { - if i == 3 { - self.range_checker_chip.add_count(*limb, FINAL_LIMB_BITS); +impl TraceFiller for CastFCoreFiller +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + + let record: &CastFCoreRecord = unsafe { get_record_from_slice(&mut core_row, ()) }; + let core_row: &mut CastFCoreCols<_> = core_row.borrow_mut(); + + // Writing in reverse order to avoid overwriting the `record` + let out = run_castf(record.val); + for (i, &limb) in out.iter().enumerate() { + let limb_bits = if i == out.len() - 1 { + FINAL_LIMB_BITS } else { - self.range_checker_chip.add_count(*limb, LIMB_BITS); - } + LIMB_BITS + }; + self.range_checker_chip.add_count(limb as u32, limb_bits); + } + core_row.is_valid = F::ONE; + core_row.out_val = out.map(F::from_canonical_u8); + core_row.in_val = F::from_canonical_u32(record.val); + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct CastFPreCompute { + a: u32, + b: u32, +} + +impl CastFCoreExecutor { + #[inline(always)] + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut CastFPreCompute, + ) -> Result<(), StaticProgramError> { + let Instruction { + a, b, d, e, opcode, .. + } = inst; + + if opcode.local_opcode_idx(CastfOpcode::CLASS_OFFSET) != CastfOpcode::CASTF as usize { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + if d.as_canonical_u32() != RV32_MEMORY_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); } + if e.as_canonical_u32() != AS::Native as u32 { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); + *data = CastFPreCompute { a, b }; - let cols: &mut CastFCoreCols = row_slice.borrow_mut(); - cols.in_val = record.in_val; - cols.out_val = record.out_val.map(F::from_canonical_u32); - cols.is_valid = F::ONE; + Ok(()) } +} - fn air(&self) -> &Self::Air { - &self.air +impl Executor for CastFCoreExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } + + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut CastFPreCompute = data.borrow_mut(); + + self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = execute_e1_impl::<_, _>; + + Ok(fn_ptr) } } -pub struct CastF; -impl CastF { - pub(super) fn solve(y: u32) -> [u32; RV32_REGISTER_NUM_LIMBS] { - let mut x = [0; 4]; - for (i, limb) in x.iter_mut().enumerate() { - *limb = (y >> (8 * i)) & 0xFF; - } - x +impl MeteredExecutor for CastFCoreExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn metered_pre_compute_size(&self) -> usize { + size_of::>() + } + + #[inline(always)] + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + let fn_ptr = execute_e2_impl::<_, _>; + + Ok(fn_ptr) } } + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &CastFPreCompute = pre_compute.borrow(); + execute_e12_impl(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl(&pre_compute.data, vm_state); +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &CastFPreCompute, + vm_state: &mut VmExecState, +) { + let y = vm_state.vm_read::(AS::Native as u32, pre_compute.b)[0]; + let x = run_castf(y.as_canonical_u32()); + + vm_state.vm_write::(RV32_MEMORY_AS, pre_compute.a, &x); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} + +#[inline(always)] +pub(super) fn run_castf(y: u32) -> [u8; RV32_REGISTER_NUM_LIMBS] { + debug_assert!(y < 1 << CASTF_MAX_BITS); + y.to_le_bytes() +} diff --git a/extensions/native/circuit/src/castf/mod.rs b/extensions/native/circuit/src/castf/mod.rs index 9fbd77f245..8cbb48900b 100644 --- a/extensions/native/circuit/src/castf/mod.rs +++ b/extensions/native/circuit/src/castf/mod.rs @@ -1,12 +1,13 @@ use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; -use super::adapters::convert_adapter::{ConvertAdapterAir, ConvertAdapterChip}; - -#[cfg(test)] -mod tests; +use crate::adapters::{ConvertAdapterAir, ConvertAdapterExecutor, ConvertAdapterFiller}; mod core; pub use core::*; +#[cfg(test)] +mod tests; + pub type CastFAir = VmAirWrapper, CastFCoreAir>; -pub type CastFChip = VmChipWrapper, CastFCoreChip>; +pub type CastFExecutor = CastFCoreExecutor>; +pub type CastFChip = VmChipWrapper>>; diff --git a/extensions/native/circuit/src/castf/tests.rs b/extensions/native/circuit/src/castf/tests.rs index 9758e6b956..9801bb235c 100644 --- a/extensions/native/circuit/src/castf/tests.rs +++ b/extensions/native/circuit/src/castf/tests.rs @@ -1,254 +1,222 @@ use std::borrow::BorrowMut; -use openvm_circuit::arch::testing::{memory::gen_pointer, VmChipTestBuilder}; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_native_compiler::CastfOpcode; -use openvm_stark_backend::{ - p3_field::FieldAlgebra, utils::disable_debug_builder, verifier::VerificationError, Chip, +use openvm_circuit::arch::{ + testing::{memory::gen_pointer, TestChipHarness, VmChipTestBuilder}, + MemoryConfig, +}; +use openvm_instructions::{ + instruction::Instruction, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, }; -use openvm_stark_sdk::{ - config::baby_bear_poseidon2::BabyBearPoseidon2Engine, engine::StarkFriEngine, - p3_baby_bear::BabyBear, utils::create_seeded_rng, +use openvm_native_compiler::{conversion::AS, CastfOpcode}; +use openvm_stark_backend::{ + p3_air::BaseAir, + p3_field::{FieldAlgebra, PrimeField32}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, + utils::disable_debug_builder, + verifier::VerificationError, }; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; -use super::{ - super::adapters::convert_adapter::{ConvertAdapterChip, ConvertAdapterCols}, - CastF, CastFChip, CastFCoreChip, CastFCoreCols, FINAL_LIMB_BITS, LIMB_BITS, +use super::{CastFChip, CastFCoreAir, CastFCoreCols, CastFExecutor, LIMB_BITS}; +use crate::{ + adapters::{ + ConvertAdapterAir, ConvertAdapterCols, ConvertAdapterExecutor, ConvertAdapterFiller, + }, + castf::run_castf, + test_utils::write_native_array, + CastFAir, CastFCoreFiller, CASTF_MAX_BITS, }; + +const MAX_INS_CAPACITY: usize = 128; +const READ_SIZE: usize = 1; +const WRITE_SIZE: usize = 4; type F = BabyBear; +type Harness = TestChipHarness>; -fn generate_uint_number(rng: &mut StdRng) -> u32 { - rng.gen_range(0..(1 << 30) - 1) +fn create_test_chip(tester: &VmChipTestBuilder) -> Harness { + let range_checker = tester.range_checker().clone(); + let air = CastFAir::new( + ConvertAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + CastFCoreAir::new(range_checker.bus()), + ); + let executor = CastFExecutor::new(ConvertAdapterExecutor::::new()); + let chip = CastFChip::::new( + CastFCoreFiller::new(ConvertAdapterFiller, range_checker), + tester.memory_helper(), + ); + Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY) } -fn prepare_castf_rand_write_execute( +fn set_and_execute( tester: &mut VmChipTestBuilder, - chip: &mut CastFChip, - y: u32, + harness: &mut Harness, rng: &mut StdRng, + b: Option, ) { - let operand1 = y; - - let as_x = 2usize; // d - let as_y = 4usize; // e - let address_x = gen_pointer(rng, 32); // a - let address_y = gen_pointer(rng, 32); // b - - let operand1_f = F::from_canonical_u32(y); - - tester.write_cell(as_y, address_y, operand1_f); - let x = CastF::solve(operand1); + let b_val = b.unwrap_or(F::from_canonical_u32(rng.gen_range(0..1 << CASTF_MAX_BITS))); + let b_ptr = write_native_array(tester, rng, Some([b_val])).1; + let a = gen_pointer(rng, RV32_REGISTER_NUM_LIMBS); tester.execute( - chip, + harness, &Instruction::from_usize( CastfOpcode::CASTF.global_opcode(), - [address_x, address_y, 0, as_x, as_y], + [a, b_ptr, 0, RV32_MEMORY_AS as usize, AS::Native as usize], ), ); - assert_eq!( - x.map(F::from_canonical_u32), - tester.read::<4>(as_x, address_x) - ); + let expected = run_castf(b_val.as_canonical_u32()); + let result = tester.read::(RV32_MEMORY_AS as usize, a); + assert_eq!(result.map(|x| x.as_canonical_u32() as u8), expected); } +/////////////////////////////////////////////////////////////////////////////////////// +/// POSITIVE TESTS +/// +/// Randomly generate computations and execute, ensuring that the generated trace +/// passes all constraints. +/////////////////////////////////////////////////////////////////////////////////////// + #[test] fn castf_rand_test() { let mut rng = create_seeded_rng(); - let mut tester = VmChipTestBuilder::default(); - let mut chip = CastFChip::::new( - ConvertAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - CastFCoreChip::new(tester.range_checker()), - tester.offline_memory_mutex_arc(), - ); - let num_tests: usize = 3; + let mut tester = VmChipTestBuilder::volatile(MemoryConfig::default()); + let mut harness = create_test_chip(&tester); + let num_ops = 100; - for _ in 0..num_tests { - let y = generate_uint_number(&mut rng); - prepare_castf_rand_write_execute(&mut tester, &mut chip, y, &mut rng); + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut harness, &mut rng, None); } - let tester = tester.build().load(chip).finalize(); + set_and_execute(&mut tester, &mut harness, &mut rng, Some(F::ZERO)); + + let tester = tester.build().load(harness).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn negative_castf_overflow_test() { - let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.range_checker(); - let mut chip = CastFChip::::new( - ConvertAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - CastFCoreChip::new(range_checker_chip.clone()), - tester.offline_memory_mutex_arc(), - ); +////////////////////////////////////////////////////////////////////////////////////// +// NEGATIVE TESTS +// +// Given a fake trace of a single operation, setup a chip and run the test. We replace +// part of the trace and check that the chip throws the expected error. +////////////////////////////////////////////////////////////////////////////////////// + +#[derive(Clone, Copy, Default)] +struct CastFPrankValues { + pub in_val: Option, + pub out_val: Option<[u32; 4]>, + pub a_pointer: Option, + pub b_pointer: Option, +} +fn run_negative_castf_test(prank_vals: CastFPrankValues, b: Option, error: VerificationError) { let mut rng = create_seeded_rng(); - let y = generate_uint_number(&mut rng); - prepare_castf_rand_write_execute(&mut tester, &mut chip, y, &mut rng); - tester.build(); - - let chip_air = chip.air(); - let mut chip_input = chip.generate_air_proof_input(); - let trace = chip_input.raw.common_main.as_mut().unwrap(); - let row = trace.row_mut(0); - let cols: &mut CastFCoreCols = row - .split_at_mut(ConvertAdapterCols::::width()) - .1 - .borrow_mut(); - cols.out_val[3] = F::from_canonical_u32(rng.gen_range(1 << FINAL_LIMB_BITS..1 << LIMB_BITS)); - - let rc_air = range_checker_chip.air(); - let rc_p_input = range_checker_chip.generate_air_proof_input(); + let mut tester = VmChipTestBuilder::volatile(MemoryConfig::default()); + + let mut harness = create_test_chip(&tester); + set_and_execute(&mut tester, &mut harness, &mut rng, b); + + let adapter_width = BaseAir::::width(&harness.air.adapter); + + let modify_trace = |trace: &mut DenseMatrix| { + let mut values = trace.row_slice(0).to_vec(); + let (adapter_row, core_row) = values.split_at_mut(adapter_width); + let core_cols: &mut CastFCoreCols = core_row.borrow_mut(); + let adapter_cols: &mut ConvertAdapterCols = + adapter_row.borrow_mut(); + + if let Some(in_val) = prank_vals.in_val { + // TODO: in_val is actually never used in the AIR, should remove it + core_cols.in_val = F::from_canonical_u32(in_val); + } + if let Some(out_val) = prank_vals.out_val { + core_cols.out_val = out_val.map(F::from_canonical_u32); + } + if let Some(a_pointer) = prank_vals.a_pointer { + adapter_cols.a_pointer = F::from_canonical_u32(a_pointer); + } + if let Some(b_pointer) = prank_vals.b_pointer { + adapter_cols.b_pointer = F::from_canonical_u32(b_pointer); + } + *trace = RowMajorMatrix::new(values, trace.width()); + }; disable_debug_builder(); - assert_eq!( - BabyBearPoseidon2Engine::run_test_fast( - vec![chip_air, rc_air], - vec![chip_input, rc_p_input] - ) - .err(), - Some(VerificationError::ChallengePhaseError), - "Expected verification to fail, but it didn't" - ); + let tester = tester + .build() + .load_and_prank_trace(harness, modify_trace) + .finalize(); + tester.simple_test_with_expected_error(error); } #[test] -fn negative_castf_memread_test() { - let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let mut chip = CastFChip::::new( - ConvertAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - CastFCoreChip::new(range_checker_chip.clone()), - tester.offline_memory_mutex_arc(), +fn casf_invalid_out_val_test() { + run_negative_castf_test( + CastFPrankValues { + out_val: Some([2 << LIMB_BITS, 0, 0, 0]), + ..Default::default() + }, + Some(F::from_canonical_u32(2 << LIMB_BITS)), + VerificationError::ChallengePhaseError, ); - let mut rng = create_seeded_rng(); - let y = generate_uint_number(&mut rng); - prepare_castf_rand_write_execute(&mut tester, &mut chip, y, &mut rng); - tester.build(); - - let chip_air = chip.air(); - let mut chip_input = chip.generate_air_proof_input(); - let trace = chip_input.raw.common_main.as_mut().unwrap(); - let row = trace.row_mut(0); - let cols: &mut ConvertAdapterCols = row - .split_at_mut(ConvertAdapterCols::::width()) - .0 - .borrow_mut(); - cols.b_pointer += F::ONE; - - let rc_air = range_checker_chip.air(); - let rc_p_input = range_checker_chip.generate_air_proof_input(); - - disable_debug_builder(); - assert_eq!( - BabyBearPoseidon2Engine::run_test_fast( - vec![chip_air, rc_air], - vec![chip_input, rc_p_input] - ) - .err(), - Some(VerificationError::ChallengePhaseError), - "Expected verification to fail, but it didn't" + let prime = F::NEG_ONE.as_canonical_u32() + 1; + run_negative_castf_test( + CastFPrankValues { + out_val: Some(prime.to_le_bytes().map(|x| x as u32)), + ..Default::default() + }, + Some(F::ZERO), + VerificationError::ChallengePhaseError, ); } #[test] -fn negative_castf_memwrite_test() { - let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let mut chip = CastFChip::::new( - ConvertAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - CastFCoreChip::new(range_checker_chip.clone()), - tester.offline_memory_mutex_arc(), +fn negative_convert_adapter_test() { + // overflowing the memory pointer + run_negative_castf_test( + CastFPrankValues { + b_pointer: Some(1 << 30), + ..Default::default() + }, + None, + VerificationError::ChallengePhaseError, ); - let mut rng = create_seeded_rng(); - let y = generate_uint_number(&mut rng); - prepare_castf_rand_write_execute(&mut tester, &mut chip, y, &mut rng); - tester.build(); - - let chip_air = chip.air(); - let mut chip_input = chip.generate_air_proof_input(); - let trace = chip_input.raw.common_main.as_mut().unwrap(); - let row = trace.row_mut(0); - let cols: &mut ConvertAdapterCols = row - .split_at_mut(ConvertAdapterCols::::width()) - .0 - .borrow_mut(); - cols.a_pointer += F::ONE; - - let rc_air = range_checker_chip.air(); - let rc_p_input = range_checker_chip.generate_air_proof_input(); - - disable_debug_builder(); - assert_eq!( - BabyBearPoseidon2Engine::run_test_fast( - vec![chip_air, rc_air], - vec![chip_input, rc_p_input] - ) - .err(), - Some(VerificationError::ChallengePhaseError), - "Expected verification to fail, but it didn't" + // Memory address space pointer has to be 4-byte aligned + run_negative_castf_test( + CastFPrankValues { + a_pointer: Some(1), + ..Default::default() + }, + None, + VerificationError::ChallengePhaseError, ); } +#[should_panic] #[test] -fn negative_castf_as_test() { - let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let mut chip = CastFChip::::new( - ConvertAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - CastFCoreChip::new(range_checker_chip.clone()), - tester.offline_memory_mutex_arc(), - ); - +fn castf_overflow_in_val_test() { let mut rng = create_seeded_rng(); - let y = generate_uint_number(&mut rng); - prepare_castf_rand_write_execute(&mut tester, &mut chip, y, &mut rng); - tester.build(); - - let chip_air = chip.air(); - let mut chip_input = chip.generate_air_proof_input(); - let trace = chip_input.raw.common_main.as_mut().unwrap(); - let row = trace.row_mut(0); - let cols: &mut ConvertAdapterCols = row - .split_at_mut(ConvertAdapterCols::::width()) - .0 - .borrow_mut(); - cols.a_pointer += F::ONE; - - let rc_air = range_checker_chip.air(); - let rc_p_input = range_checker_chip.generate_air_proof_input(); + let mut tester = VmChipTestBuilder::volatile(MemoryConfig::default()); + let mut harness = create_test_chip(&tester); + set_and_execute(&mut tester, &mut harness, &mut rng, Some(F::NEG_ONE)); +} - disable_debug_builder(); - assert_eq!( - BabyBearPoseidon2Engine::run_test_fast( - vec![chip_air, rc_air], - vec![chip_input, rc_p_input] - ) - .err(), - Some(VerificationError::ChallengePhaseError), - "Expected verification to fail, but it didn't" - ); +/////////////////////////////////////////////////////////////////////////////////////// +/// SANITY TESTS +/// +/// Ensure that solve functions produce the correct results. +/////////////////////////////////////////////////////////////////////////////////////// + +#[test] +fn castf_sanity_test() { + let b = 160558167; + let expected = [87, 236, 145, 9]; + assert_eq!(run_castf(b), expected); } diff --git a/extensions/native/circuit/src/extension.rs b/extensions/native/circuit/src/extension.rs index 385c9392ac..b6fc08abe0 100644 --- a/extensions/native/circuit/src/extension.rs +++ b/extensions/native/circuit/src/extension.rs @@ -1,18 +1,20 @@ -use air::VerifyBatchBus; -use alu_native_adapter::AluNativeAdapterChip; -use branch_native_adapter::BranchNativeAdapterChip; +use alu_native_adapter::{AluNativeAdapterAir, AluNativeAdapterExecutor}; +use branch_native_adapter::{BranchNativeAdapterAir, BranchNativeAdapterExecutor}; +use convert_adapter::{ConvertAdapterAir, ConvertAdapterExecutor}; use derive_more::derive::From; -use loadstore_native_adapter::NativeLoadStoreAdapterChip; -use native_vectorized_adapter::NativeVectorizedAdapterChip; +use fri::{FriReducedOpeningAir, FriReducedOpeningChip, FriReducedOpeningExecutor}; +use jal_rangecheck::{JalRangeCheckAir, JalRangeCheckExecutor}; +use loadstore_native_adapter::{NativeLoadStoreAdapterAir, NativeLoadStoreAdapterExecutor}; +use native_vectorized_adapter::{NativeVectorizedAdapterAir, NativeVectorizedAdapterExecutor}; use openvm_circuit::{ arch::{ - ExecutionBridge, InitFileGenerator, MemoryConfig, SystemConfig, SystemPort, VmExtension, - VmInventory, VmInventoryBuilder, VmInventoryError, + AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, ExecutionBridge, + ExecutorInventoryBuilder, ExecutorInventoryError, RowMajorMatrixArena, VmCircuitExtension, + VmExecutionExtension, VmProverExtension, }, - system::phantom::PhantomChip, + system::{memory::SharedMemoryHelper, SystemPort}, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor}; use openvm_instructions::{program::DEFAULT_PC_STEP, LocalOpcode, PhantomDiscriminant}; use openvm_native_compiler::{ CastfOpcode, FieldArithmeticOpcode, FieldExtensionOpcode, FriOpcode, NativeBranchEqualOpcode, @@ -20,185 +22,106 @@ use openvm_native_compiler::{ NativeRangeCheckOpcode, Poseidon2Opcode, VerifyBatchOpcode, BLOCK_LOAD_STORE_SIZE, }; use openvm_poseidon2_air::Poseidon2Config; -use openvm_rv32im_circuit::{ - BranchEqualCoreChip, Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, - Rv32IoPeriphery, Rv32M, Rv32MExecutor, Rv32MPeriphery, +use openvm_rv32im_circuit::BranchEqualCoreAir; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + p3_field::{Field, PrimeField32}, + prover::cpu::{CpuBackend, CpuDevice}, }; -use openvm_stark_backend::p3_field::PrimeField32; +use openvm_stark_sdk::engine::StarkEngine; use serde::{Deserialize, Serialize}; use strum::IntoEnumIterator; use crate::{ - adapters::{convert_adapter::ConvertAdapterChip, *}, - chip::NativePoseidon2Chip, + adapters::*, + air::{NativePoseidon2Air, VerifyBatchBus}, + chip::{NativePoseidon2Executor, NativePoseidon2Filler}, phantom::*, *, }; -#[derive(Clone, Debug, Serialize, Deserialize, VmConfig, derive_new::new)] -pub struct NativeConfig { - #[system] - pub system: SystemConfig, - #[extension] - pub native: Native, -} - -impl NativeConfig { - pub fn aggregation(num_public_values: usize, max_constraint_degree: usize) -> Self { - Self { - system: SystemConfig::new( - max_constraint_degree, - MemoryConfig { - max_access_adapter_n: 8, - ..Default::default() - }, - num_public_values, - ) - .with_max_segment_len((1 << 24) - 100), - native: Default::default(), - } - } -} - -// Default implementation uses no init file -impl InitFileGenerator for NativeConfig {} +// ============ VmExtension Implementations ============ #[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] pub struct Native; -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] -pub enum NativeExecutor { - LoadStore(NativeLoadStoreChip), - BlockLoadStore(NativeLoadStoreChip), - BranchEqual(NativeBranchEqChip), - Jal(JalRangeCheckChip), - FieldArithmetic(FieldArithmeticChip), - FieldExtension(FieldExtensionChip), - FriReducedOpening(FriReducedOpeningChip), - VerifyBatch(NativePoseidon2Chip), +#[derive(Clone, From, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)] +pub enum NativeExecutor { + LoadStore(NativeLoadStoreExecutor<1>), + BlockLoadStore(NativeLoadStoreExecutor), + BranchEqual(NativeBranchEqExecutor), + Jal(JalRangeCheckExecutor), + FieldArithmetic(FieldArithmeticExecutor), + FieldExtension(FieldExtensionExecutor), + FriReducedOpening(FriReducedOpeningExecutor), + VerifyBatch(NativePoseidon2Executor), } -#[derive(From, ChipUsageGetter, Chip, AnyEnum)] -pub enum NativePeriphery { - Phantom(PhantomChip), -} - -impl VmExtension for Native { +impl VmExecutionExtension for Native { type Executor = NativeExecutor; - type Periphery = NativePeriphery; - fn build( + fn extend_execution( &self, - builder: &mut VmInventoryBuilder, - ) -> Result, NativePeriphery>, VmInventoryError> { - let mut inventory = VmInventory::new(); - let SystemPort { - execution_bus, - program_bus, - memory_bridge, - } = builder.system_port(); - let offline_memory = builder.system_base().offline_memory(); - - let mut load_store_chip = NativeLoadStoreChip::::new( - NativeLoadStoreAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - NativeLoadStoreOpcode::CLASS_OFFSET, - ), - NativeLoadStoreCoreChip::new(NativeLoadStoreOpcode::CLASS_OFFSET), - offline_memory.clone(), + inventory: &mut ExecutorInventoryBuilder>, + ) -> Result<(), ExecutorInventoryError> { + let load_store = NativeLoadStoreExecutor::<1>::new( + NativeLoadStoreAdapterExecutor::new(NativeLoadStoreOpcode::CLASS_OFFSET), + NativeLoadStoreOpcode::CLASS_OFFSET, ); - load_store_chip.core.set_streams(builder.streams().clone()); - inventory.add_executor( - load_store_chip, + load_store, NativeLoadStoreOpcode::iter().map(|x| x.global_opcode()), )?; - let mut block_load_store_chip = NativeLoadStoreChip::::new( - NativeLoadStoreAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - NativeLoadStore4Opcode::CLASS_OFFSET, - ), - NativeLoadStoreCoreChip::new(NativeLoadStore4Opcode::CLASS_OFFSET), - offline_memory.clone(), + let block_load_store = NativeLoadStoreExecutor::::new( + NativeLoadStoreAdapterExecutor::new(NativeLoadStore4Opcode::CLASS_OFFSET), + NativeLoadStore4Opcode::CLASS_OFFSET, ); - block_load_store_chip - .core - .set_streams(builder.streams().clone()); - inventory.add_executor( - block_load_store_chip, + block_load_store, NativeLoadStore4Opcode::iter().map(|x| x.global_opcode()), )?; - let branch_equal_chip = NativeBranchEqChip::new( - BranchNativeAdapterChip::<_>::new(execution_bus, program_bus, memory_bridge), - BranchEqualCoreChip::new(NativeBranchEqualOpcode::CLASS_OFFSET, DEFAULT_PC_STEP), - offline_memory.clone(), + let branch_equal = NativeBranchEqExecutor::new( + BranchNativeAdapterExecutor::new(), + NativeBranchEqualOpcode::CLASS_OFFSET, + DEFAULT_PC_STEP, ); inventory.add_executor( - branch_equal_chip, + branch_equal, NativeBranchEqualOpcode::iter().map(|x| x.global_opcode()), )?; - let jal_chip = JalRangeCheckChip::new( - ExecutionBridge::new(execution_bus, program_bus), - offline_memory.clone(), - builder.system_base().range_checker_chip.clone(), - ); + let jal_rangecheck = JalRangeCheckExecutor; inventory.add_executor( - jal_chip, + jal_rangecheck, [ NativeJalOpcode::JAL.global_opcode(), NativeRangeCheckOpcode::RANGE_CHECK.global_opcode(), ], )?; - let field_arithmetic_chip = FieldArithmeticChip::new( - AluNativeAdapterChip::::new(execution_bus, program_bus, memory_bridge), - FieldArithmeticCoreChip::new(), - offline_memory.clone(), - ); + let field_arithmetic = FieldArithmeticExecutor::new(AluNativeAdapterExecutor::new()); inventory.add_executor( - field_arithmetic_chip, + field_arithmetic, FieldArithmeticOpcode::iter().map(|x| x.global_opcode()), )?; - let field_extension_chip = FieldExtensionChip::new( - NativeVectorizedAdapterChip::new(execution_bus, program_bus, memory_bridge), - FieldExtensionCoreChip::new(), - offline_memory.clone(), - ); + let field_extension = FieldExtensionExecutor::new(NativeVectorizedAdapterExecutor::new()); inventory.add_executor( - field_extension_chip, + field_extension, FieldExtensionOpcode::iter().map(|x| x.global_opcode()), )?; - let fri_reduced_opening_chip = FriReducedOpeningChip::new( - execution_bus, - program_bus, - memory_bridge, - offline_memory.clone(), - builder.streams().clone(), - ); + let fri_reduced_opening = FriReducedOpeningExecutor::new(); inventory.add_executor( - fri_reduced_opening_chip, + fri_reduced_opening, FriOpcode::iter().map(|x| x.global_opcode()), )?; - let poseidon2_chip = NativePoseidon2Chip::new( - builder.system_port(), - offline_memory.clone(), - Poseidon2Config::default(), - VerifyBatchBus::new(builder.new_bus_idx()), - builder.streams().clone(), - ); + let verify_batch = NativePoseidon2Executor::::new(Poseidon2Config::default()); inventory.add_executor( - poseidon2_chip, + verify_batch, [ VerifyBatchOpcode::VERIFY_BATCH.global_opcode(), Poseidon2Opcode::PERM_POS2.global_opcode(), @@ -206,32 +129,180 @@ impl VmExtension for Native { ], )?; - builder.add_phantom_sub_executor( + inventory.add_phantom_sub_executor( NativeHintInputSubEx, PhantomDiscriminant(NativePhantom::HintInput as u16), )?; - builder.add_phantom_sub_executor( + inventory.add_phantom_sub_executor( NativeHintSliceSubEx::<1>, PhantomDiscriminant(NativePhantom::HintFelt as u16), )?; - builder.add_phantom_sub_executor( + inventory.add_phantom_sub_executor( NativeHintBitsSubEx, PhantomDiscriminant(NativePhantom::HintBits as u16), )?; - builder.add_phantom_sub_executor( + inventory.add_phantom_sub_executor( NativePrintSubEx, PhantomDiscriminant(NativePhantom::Print as u16), )?; - builder.add_phantom_sub_executor( + inventory.add_phantom_sub_executor( NativeHintLoadSubEx, PhantomDiscriminant(NativePhantom::HintLoad as u16), )?; - Ok(inventory) + Ok(()) + } +} + +impl VmCircuitExtension for Native +where + Val: PrimeField32, +{ + fn extend_circuit(&self, inventory: &mut AirInventory) -> Result<(), AirInventoryError> { + let SystemPort { + execution_bus, + program_bus, + memory_bridge, + } = inventory.system().port(); + let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); + let range_checker = inventory.range_checker().bus; + + let load_store = NativeLoadStoreAir::<1>::new( + NativeLoadStoreAdapterAir::new(memory_bridge, exec_bridge), + NativeLoadStoreCoreAir::new(NativeLoadStoreOpcode::CLASS_OFFSET), + ); + inventory.add_air(load_store); + + let block_load_store = NativeLoadStoreAir::::new( + NativeLoadStoreAdapterAir::new(memory_bridge, exec_bridge), + NativeLoadStoreCoreAir::new(NativeLoadStore4Opcode::CLASS_OFFSET), + ); + inventory.add_air(block_load_store); + + let branch_equal = NativeBranchEqAir::new( + BranchNativeAdapterAir::new(exec_bridge, memory_bridge), + BranchEqualCoreAir::new(NativeBranchEqualOpcode::CLASS_OFFSET, DEFAULT_PC_STEP), + ); + inventory.add_air(branch_equal); + + let jal_rangecheck = JalRangeCheckAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + range_checker, + ); + inventory.add_air(jal_rangecheck); + + let field_arithmetic = FieldArithmeticAir::new( + AluNativeAdapterAir::new(exec_bridge, memory_bridge), + FieldArithmeticCoreAir::new(), + ); + inventory.add_air(field_arithmetic); + + let field_extension = FieldExtensionAir::new( + NativeVectorizedAdapterAir::new(exec_bridge, memory_bridge), + FieldExtensionCoreAir::new(), + ); + inventory.add_air(field_extension); + + let fri_reduced_opening = FriReducedOpeningAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ); + inventory.add_air(fri_reduced_opening); + + let verify_batch = NativePoseidon2Air::<_, 1>::new( + exec_bridge, + memory_bridge, + VerifyBatchBus::new(inventory.new_bus_idx()), + Poseidon2Config::default(), + ); + inventory.add_air(verify_batch); + + Ok(()) + } +} + +pub struct NativeCpuProverExt; +// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker, +// BitwiseOperationLookupChip) are specific to CpuBackend. +impl VmProverExtension for NativeCpuProverExt +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + RA: RowMajorMatrixArena>, + Val: PrimeField32, +{ + fn extend_prover( + &self, + _: &Native, + inventory: &mut ChipInventory>, + ) -> Result<(), ChipInventoryError> { + let range_checker = inventory.range_checker()?.clone(); + let timestamp_max_bits = inventory.timestamp_max_bits(); + let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits); + + // These calls to next_air are not strictly necessary to construct the chips, but provide a + // safeguard to ensure that chip construction matches the circuit definition + inventory.next_air::>()?; + let load_store = NativeLoadStoreChip::<_, 1>::new( + NativeLoadStoreCoreFiller::new(NativeLoadStoreAdapterFiller), + mem_helper.clone(), + ); + inventory.add_executor_chip(load_store); + + inventory.next_air::>()?; + let block_load_store = NativeLoadStoreChip::<_, BLOCK_LOAD_STORE_SIZE>::new( + NativeLoadStoreCoreFiller::new(NativeLoadStoreAdapterFiller), + mem_helper.clone(), + ); + inventory.add_executor_chip(block_load_store); + + inventory.next_air::()?; + let branch_eq = NativeBranchEqChip::new( + NativeBranchEqualFiller::new(BranchNativeAdapterFiller), + mem_helper.clone(), + ); + + inventory.add_executor_chip(branch_eq); + + inventory.next_air::()?; + let jal_rangecheck = NativeJalRangeCheckChip::new( + JalRangeCheckFiller::new(range_checker.clone()), + mem_helper.clone(), + ); + inventory.add_executor_chip(jal_rangecheck); + + inventory.next_air::()?; + let field_arithmetic = FieldArithmeticChip::new( + FieldArithmeticCoreFiller::new(AluNativeAdapterFiller), + mem_helper.clone(), + ); + inventory.add_executor_chip(field_arithmetic); + + inventory.next_air::()?; + let field_extension = FieldExtensionChip::new( + FieldExtensionCoreFiller::new(NativeVectorizedAdapterFiller), + mem_helper.clone(), + ); + inventory.add_executor_chip(field_extension); + + inventory.next_air::()?; + let fri_reduced_opening = + FriReducedOpeningChip::new(FriReducedOpeningFiller::new(), mem_helper.clone()); + inventory.add_executor_chip(fri_reduced_opening); + + inventory.next_air::, 1>>()?; + let poseidon2 = NativePoseidon2Chip::<_, 1>::new( + NativePoseidon2Filler::new(Poseidon2Config::default()), + mem_helper.clone(), + ); + inventory.add_executor_chip(poseidon2); + + Ok(()) } } @@ -239,10 +310,11 @@ pub(crate) mod phantom { use eyre::bail; use openvm_circuit::{ arch::{PhantomSubExecutor, Streams}, - system::memory::MemoryController, + system::memory::online::GuestMemory, }; use openvm_instructions::PhantomDiscriminant; use openvm_stark_backend::p3_field::{Field, PrimeField32}; + use rand::rngs::StdRng; pub struct NativeHintInputSubEx; pub struct NativeHintSliceSubEx; @@ -252,12 +324,13 @@ pub(crate) mod phantom { impl PhantomSubExecutor for NativeHintInputSubEx { fn phantom_execute( - &mut self, - _: &MemoryController, + &self, + _: &GuestMemory, streams: &mut Streams, + _: &mut StdRng, _: PhantomDiscriminant, - _: F, - _: F, + _: u32, + _: u32, _: u16, ) -> eyre::Result<()> { let hint = match streams.input_stream.pop_front() { @@ -277,12 +350,13 @@ pub(crate) mod phantom { impl PhantomSubExecutor for NativeHintSliceSubEx { fn phantom_execute( - &mut self, - _: &MemoryController, + &self, + _: &GuestMemory, streams: &mut Streams, + _: &mut StdRng, _: PhantomDiscriminant, - _: F, - _: F, + _: u32, + _: u32, _: u16, ) -> eyre::Result<()> { let hint = match streams.input_stream.pop_front() { @@ -300,36 +374,35 @@ pub(crate) mod phantom { impl PhantomSubExecutor for NativePrintSubEx { fn phantom_execute( - &mut self, - memory: &MemoryController, + &self, + memory: &GuestMemory, _: &mut Streams, + _: &mut StdRng, _: PhantomDiscriminant, - a: F, - _: F, + a: u32, + _: u32, c_upper: u16, ) -> eyre::Result<()> { - let addr_space = F::from_canonical_u16(c_upper); - let value = memory.unsafe_read_cell(addr_space, a); - println!("{}", value); + let [value] = unsafe { memory.read::(c_upper as u32, a) }; + println!("{value}"); Ok(()) } } impl PhantomSubExecutor for NativeHintBitsSubEx { fn phantom_execute( - &mut self, - memory: &MemoryController, + &self, + memory: &GuestMemory, streams: &mut Streams, + _: &mut StdRng, _: PhantomDiscriminant, - a: F, - b: F, + a: u32, + len: u32, c_upper: u16, ) -> eyre::Result<()> { - let addr_space = F::from_canonical_u16(c_upper); - let val = memory.unsafe_read_cell(addr_space, a); + let [val] = unsafe { memory.read::(c_upper as u32, a) }; let mut val = val.as_canonical_u32(); - let len = b.as_canonical_u32(); assert!(streams.hint_stream.is_empty()); for _ in 0..len { streams @@ -343,12 +416,13 @@ pub(crate) mod phantom { impl PhantomSubExecutor for NativeHintLoadSubEx { fn phantom_execute( - &mut self, - _: &MemoryController, + &self, + _: &GuestMemory, streams: &mut Streams, + _: &mut StdRng, _: PhantomDiscriminant, - _: F, - _: F, + _: u32, + _: u32, _: u16, ) -> eyre::Result<()> { let payload = match streams.input_stream.pop_front() { @@ -370,72 +444,74 @@ pub(crate) mod phantom { #[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] pub struct CastFExtension; -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] -pub enum CastFExtensionExecutor { - CastF(CastFChip), +#[derive(Clone, From, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)] +pub enum CastFExtensionExecutor { + CastF(CastFExecutor), } -#[derive(From, ChipUsageGetter, Chip, AnyEnum)] -pub enum CastFExtensionPeriphery { - Placeholder(CastFChip), -} +impl VmExecutionExtension for CastFExtension { + type Executor = CastFExtensionExecutor; -impl VmExtension for CastFExtension { - type Executor = CastFExtensionExecutor; - type Periphery = CastFExtensionPeriphery; - - fn build( + fn extend_execution( &self, - builder: &mut VmInventoryBuilder, - ) -> Result, VmInventoryError> { - let mut inventory = VmInventory::new(); + inventory: &mut ExecutorInventoryBuilder, + ) -> Result<(), ExecutorInventoryError> { + let castf = CastFExecutor::new(ConvertAdapterExecutor::new()); + inventory.add_executor(castf, [CastfOpcode::CASTF.global_opcode()])?; + Ok(()) + } +} + +impl VmCircuitExtension for CastFExtension { + fn extend_circuit(&self, inventory: &mut AirInventory) -> Result<(), AirInventoryError> { let SystemPort { execution_bus, program_bus, memory_bridge, - } = builder.system_port(); - let offline_memory = builder.system_base().offline_memory(); - let range_checker = builder.system_base().range_checker_chip.clone(); - - let castf_chip = CastFChip::new( - ConvertAdapterChip::new(execution_bus, program_bus, memory_bridge), - CastFCoreChip::new(range_checker.clone()), - offline_memory.clone(), - ); - inventory.add_executor(castf_chip, [CastfOpcode::CASTF.global_opcode()])?; + } = inventory.system().port(); + let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); + let range_checker = inventory.range_checker().bus; - Ok(inventory) + let castf = CastFAir::new( + ConvertAdapterAir::new(exec_bridge, memory_bridge), + CastFCoreAir::new(range_checker), + ); + inventory.add_air(castf); + Ok(()) } } -#[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] -pub struct Rv32WithKernelsConfig { - #[system] - pub system: SystemConfig, - #[extension] - pub rv32i: Rv32I, - #[extension] - pub rv32m: Rv32M, - #[extension] - pub io: Rv32Io, - #[extension] - pub native: Native, - #[extension] - pub castf: CastFExtension, -} +impl VmProverExtension for NativeCpuProverExt +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + RA: RowMajorMatrixArena>, + Val: PrimeField32, +{ + fn extend_prover( + &self, + _: &CastFExtension, + inventory: &mut ChipInventory>, + ) -> Result<(), ChipInventoryError> { + let range_checker = inventory.range_checker()?.clone(); + let timestamp_max_bits = inventory.timestamp_max_bits(); + let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits); + + inventory.next_air::()?; + let castf = CastFChip::new( + CastFCoreFiller::new(ConvertAdapterFiller::new(), range_checker), + mem_helper.clone(), + ); + inventory.add_executor_chip(castf); -impl Default for Rv32WithKernelsConfig { - fn default() -> Self { - Self { - system: SystemConfig::default().with_continuations(), - rv32i: Rv32I, - rv32m: Rv32M::default(), - io: Rv32Io, - native: Native, - castf: CastFExtension, - } + Ok(()) } } -// Default implementation uses no init file -impl InitFileGenerator for Rv32WithKernelsConfig {} +// Pre-computed maximum trace heights for NativeConfig. Found by doubling +// the actual trace heights of kitchen-sink leaf verification (except for +// VariableRangeChecker, which has a fixed height). +pub const NATIVE_MAX_TRACE_HEIGHTS: &[u32] = &[ + 4194304, 4, 128, 2097152, 8388608, 4194304, 262144, 2097152, 16777216, 2097152, 8388608, + 262144, 2097152, 1048576, 4194304, 65536, 262144, +]; diff --git a/extensions/native/circuit/src/field_arithmetic/core.rs b/extensions/native/circuit/src/field_arithmetic/core.rs index c813f6a066..c144c44e97 100644 --- a/extensions/native/circuit/src/field_arithmetic/core.rs +++ b/extensions/native/circuit/src/field_arithmetic/core.rs @@ -1,20 +1,29 @@ use std::borrow::{Borrow, BorrowMut}; use itertools::izip; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::*, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, + utils::{transmute_field_to_u32, transmute_u32_to_field}, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_native_compiler::FieldArithmeticOpcode::{self, *}; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_IMM_AS, LocalOpcode, +}; +use openvm_native_compiler::{ + conversion::AS, + FieldArithmeticOpcode::{self, *}, +}; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; #[repr(C)] #[derive(AlignedBorrow)] @@ -31,7 +40,7 @@ pub struct FieldArithmeticCoreCols { pub divisor_inv: T, } -#[derive(Copy, Clone, Debug)] +#[derive(derive_new::new, Copy, Clone, Debug)] pub struct FieldArithmeticCoreAir {} impl BaseAir for FieldArithmeticCoreAir { @@ -106,120 +115,402 @@ where } #[repr(C)] -#[derive(Debug, Serialize, Deserialize)] +#[derive(AlignedBytesBorrow, Debug)] pub struct FieldArithmeticRecord { - pub opcode: FieldArithmeticOpcode, - pub a: F, pub b: F, pub c: F, + pub local_opcode: u8, } -pub struct FieldArithmeticCoreChip { - pub air: FieldArithmeticCoreAir, +#[derive(derive_new::new, Clone, Copy)] +pub struct FieldArithmeticCoreExecutor { + adapter: A, } -impl FieldArithmeticCoreChip { - pub fn new() -> Self { - Self { - air: FieldArithmeticCoreAir {}, - } - } +#[derive(derive_new::new)] +pub struct FieldArithmeticCoreFiller { + adapter: A, } -impl Default for FieldArithmeticCoreChip { - fn default() -> Self { - Self::new() +impl PreflightExecutor for FieldArithmeticCoreExecutor +where + F: PrimeField32, + A: 'static + AdapterTraceExecutor, + for<'buf> RA: RecordArena< + 'buf, + EmptyAdapterCoreLayout, + (A::RecordMut<'buf>, &'buf mut FieldArithmeticRecord), + >, +{ + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + FieldArithmeticOpcode::from_usize(opcode - FieldArithmeticOpcode::CLASS_OFFSET) + ) + } + + fn execute( + &mut self, + state: VmStateMut, + instruction: &Instruction, + ) -> Result<(), ExecutionError> { + let &Instruction { opcode, .. } = instruction; + let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + [core_record.b, core_record.c] = + self.adapter + .read(state.memory, instruction, &mut adapter_record); + + core_record.local_opcode = + opcode.local_opcode_idx(FieldArithmeticOpcode::CLASS_OFFSET) as u8; + + let opcode = FieldArithmeticOpcode::from_usize(core_record.local_opcode as usize); + let a_val = run_field_arithmetic(opcode, core_record.b, core_record.c); + + self.adapter + .write(state.memory, instruction, [a_val], &mut adapter_record); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) } } -impl> VmCoreChip for FieldArithmeticCoreChip +impl TraceFiller for FieldArithmeticCoreFiller where - I::Reads: Into<[[F; 1]; 2]>, - I::Writes: From<[[F; 1]; 1]>, + F: PrimeField32, + A: 'static + AdapterTraceFiller, { - type Record = FieldArithmeticRecord; - type Air = FieldArithmeticCoreAir; + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + let record: &FieldArithmeticRecord = unsafe { get_record_from_slice(&mut core_row, ()) }; + let core_row: &mut FieldArithmeticCoreCols<_> = core_row.borrow_mut(); + + let opcode = FieldArithmeticOpcode::from_usize(record.local_opcode as usize); + let result = run_field_arithmetic(opcode, record.b, record.c); + + // Writing in reverse order to avoid overwriting the `record` + core_row.divisor_inv = if opcode == FieldArithmeticOpcode::DIV { + record.c.inverse() + } else { + F::ZERO + }; + + core_row.is_div = F::from_bool(opcode == FieldArithmeticOpcode::DIV); + core_row.is_mul = F::from_bool(opcode == FieldArithmeticOpcode::MUL); + core_row.is_sub = F::from_bool(opcode == FieldArithmeticOpcode::SUB); + core_row.is_add = F::from_bool(opcode == FieldArithmeticOpcode::ADD); + + core_row.c = record.c; + core_row.b = record.b; + core_row.a = result; + } +} - #[allow(clippy::type_complexity)] - fn execute_instruction( +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct FieldArithmeticPreCompute { + a: u32, + b_or_imm: u32, + c_or_imm: u32, + e: u32, + f: u32, +} + +impl FieldArithmeticCoreExecutor { + #[inline(always)] + fn pre_compute_impl( &self, - instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let Instruction { opcode, .. } = instruction; + _pc: u32, + inst: &Instruction, + data: &mut FieldArithmeticPreCompute, + ) -> Result<(bool, bool, FieldArithmeticOpcode), StaticProgramError> { + let &Instruction { + opcode, + a, + b, + c, + e, + f, + .. + } = inst; + let local_opcode = FieldArithmeticOpcode::from_usize( opcode.local_opcode_idx(FieldArithmeticOpcode::CLASS_OFFSET), ); - let data: [[F; 1]; 2] = reads.into(); - let b = data[0][0]; - let c = data[1][0]; - let a = FieldArithmetic::run_field_arithmetic(local_opcode, b, c).unwrap(); + let a = a.as_canonical_u32(); + let e = e.as_canonical_u32(); + let f = f.as_canonical_u32(); - let output: AdapterRuntimeContext = AdapterRuntimeContext { - to_pc: None, - writes: [[a]].into(), + let a_is_imm = e == RV32_IMM_AS; + let b_is_imm = f == RV32_IMM_AS; + + let b_or_imm = if a_is_imm { + transmute_field_to_u32(&b) + } else { + b.as_canonical_u32() + }; + let c_or_imm = if b_is_imm { + transmute_field_to_u32(&c) + } else { + c.as_canonical_u32() }; - let record = Self::Record { - opcode: local_opcode, + *data = FieldArithmeticPreCompute { a, - b, - c, + b_or_imm, + c_or_imm, + e, + f, }; - Ok((output, record)) + Ok((a_is_imm, b_is_imm, local_opcode)) } +} - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - FieldArithmeticOpcode::from_usize(opcode - FieldArithmeticOpcode::CLASS_OFFSET) - ) +impl Executor for FieldArithmeticCoreExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() } - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let FieldArithmeticRecord { opcode, a, b, c } = record; - let row_slice: &mut FieldArithmeticCoreCols<_> = row_slice.borrow_mut(); - row_slice.a = a; - row_slice.b = b; - row_slice.c = c; - - row_slice.is_add = F::from_bool(opcode == FieldArithmeticOpcode::ADD); - row_slice.is_sub = F::from_bool(opcode == FieldArithmeticOpcode::SUB); - row_slice.is_mul = F::from_bool(opcode == FieldArithmeticOpcode::MUL); - row_slice.is_div = F::from_bool(opcode == FieldArithmeticOpcode::DIV); - row_slice.divisor_inv = if opcode == FieldArithmeticOpcode::DIV { - c.inverse() - } else { - F::ZERO + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut FieldArithmeticPreCompute = data.borrow_mut(); + + let (a_is_imm, b_is_imm, local_opcode) = self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = match (local_opcode, a_is_imm, b_is_imm) { + (FieldArithmeticOpcode::ADD, true, true) => { + execute_e1_impl::<_, _, true, true, { FieldArithmeticOpcode::ADD as u8 }> + } + (FieldArithmeticOpcode::ADD, true, false) => { + execute_e1_impl::<_, _, true, false, { FieldArithmeticOpcode::ADD as u8 }> + } + (FieldArithmeticOpcode::ADD, false, true) => { + execute_e1_impl::<_, _, false, true, { FieldArithmeticOpcode::ADD as u8 }> + } + (FieldArithmeticOpcode::ADD, false, false) => { + execute_e1_impl::<_, _, false, false, { FieldArithmeticOpcode::ADD as u8 }> + } + (FieldArithmeticOpcode::SUB, true, true) => { + execute_e1_impl::<_, _, true, true, { FieldArithmeticOpcode::SUB as u8 }> + } + (FieldArithmeticOpcode::SUB, true, false) => { + execute_e1_impl::<_, _, true, false, { FieldArithmeticOpcode::SUB as u8 }> + } + (FieldArithmeticOpcode::SUB, false, true) => { + execute_e1_impl::<_, _, false, true, { FieldArithmeticOpcode::SUB as u8 }> + } + (FieldArithmeticOpcode::SUB, false, false) => { + execute_e1_impl::<_, _, false, false, { FieldArithmeticOpcode::SUB as u8 }> + } + (FieldArithmeticOpcode::MUL, true, true) => { + execute_e1_impl::<_, _, true, true, { FieldArithmeticOpcode::MUL as u8 }> + } + (FieldArithmeticOpcode::MUL, true, false) => { + execute_e1_impl::<_, _, true, false, { FieldArithmeticOpcode::MUL as u8 }> + } + (FieldArithmeticOpcode::MUL, false, true) => { + execute_e1_impl::<_, _, false, true, { FieldArithmeticOpcode::MUL as u8 }> + } + (FieldArithmeticOpcode::MUL, false, false) => { + execute_e1_impl::<_, _, false, false, { FieldArithmeticOpcode::MUL as u8 }> + } + (FieldArithmeticOpcode::DIV, true, true) => { + execute_e1_impl::<_, _, true, true, { FieldArithmeticOpcode::DIV as u8 }> + } + (FieldArithmeticOpcode::DIV, true, false) => { + execute_e1_impl::<_, _, true, false, { FieldArithmeticOpcode::DIV as u8 }> + } + (FieldArithmeticOpcode::DIV, false, true) => { + execute_e1_impl::<_, _, false, true, { FieldArithmeticOpcode::DIV as u8 }> + } + (FieldArithmeticOpcode::DIV, false, false) => { + execute_e1_impl::<_, _, false, false, { FieldArithmeticOpcode::DIV as u8 }> + } }; + + Ok(fn_ptr) } +} - fn air(&self) -> &Self::Air { - &self.air +impl MeteredExecutor for FieldArithmeticCoreExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn metered_pre_compute_size(&self) -> usize { + size_of::>() } + + #[inline(always)] + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let (a_is_imm, b_is_imm, local_opcode) = + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + let fn_ptr = match (local_opcode, a_is_imm, b_is_imm) { + (FieldArithmeticOpcode::ADD, true, true) => { + execute_e2_impl::<_, _, true, true, { FieldArithmeticOpcode::ADD as u8 }> + } + (FieldArithmeticOpcode::ADD, true, false) => { + execute_e2_impl::<_, _, true, false, { FieldArithmeticOpcode::ADD as u8 }> + } + (FieldArithmeticOpcode::ADD, false, true) => { + execute_e2_impl::<_, _, false, true, { FieldArithmeticOpcode::ADD as u8 }> + } + (FieldArithmeticOpcode::ADD, false, false) => { + execute_e2_impl::<_, _, false, false, { FieldArithmeticOpcode::ADD as u8 }> + } + (FieldArithmeticOpcode::SUB, true, true) => { + execute_e2_impl::<_, _, true, true, { FieldArithmeticOpcode::SUB as u8 }> + } + (FieldArithmeticOpcode::SUB, true, false) => { + execute_e2_impl::<_, _, true, false, { FieldArithmeticOpcode::SUB as u8 }> + } + (FieldArithmeticOpcode::SUB, false, true) => { + execute_e2_impl::<_, _, false, true, { FieldArithmeticOpcode::SUB as u8 }> + } + (FieldArithmeticOpcode::SUB, false, false) => { + execute_e2_impl::<_, _, false, false, { FieldArithmeticOpcode::SUB as u8 }> + } + (FieldArithmeticOpcode::MUL, true, true) => { + execute_e2_impl::<_, _, true, true, { FieldArithmeticOpcode::MUL as u8 }> + } + (FieldArithmeticOpcode::MUL, true, false) => { + execute_e2_impl::<_, _, true, false, { FieldArithmeticOpcode::MUL as u8 }> + } + (FieldArithmeticOpcode::MUL, false, true) => { + execute_e2_impl::<_, _, false, true, { FieldArithmeticOpcode::MUL as u8 }> + } + (FieldArithmeticOpcode::MUL, false, false) => { + execute_e2_impl::<_, _, false, false, { FieldArithmeticOpcode::MUL as u8 }> + } + (FieldArithmeticOpcode::DIV, true, true) => { + execute_e2_impl::<_, _, true, true, { FieldArithmeticOpcode::DIV as u8 }> + } + (FieldArithmeticOpcode::DIV, true, false) => { + execute_e2_impl::<_, _, true, false, { FieldArithmeticOpcode::DIV as u8 }> + } + (FieldArithmeticOpcode::DIV, false, true) => { + execute_e2_impl::<_, _, false, true, { FieldArithmeticOpcode::DIV as u8 }> + } + (FieldArithmeticOpcode::DIV, false, false) => { + execute_e2_impl::<_, _, false, false, { FieldArithmeticOpcode::DIV as u8 }> + } + }; + + Ok(fn_ptr) + } +} + +unsafe fn execute_e1_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const A_IS_IMM: bool, + const B_IS_IMM: bool, + const OPCODE: u8, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &FieldArithmeticPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const A_IS_IMM: bool, + const B_IS_IMM: bool, + const OPCODE: u8, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); } -pub struct FieldArithmetic; -impl FieldArithmetic { - pub(super) fn run_field_arithmetic( - opcode: FieldArithmeticOpcode, - b: F, - c: F, - ) -> Option { - match opcode { - FieldArithmeticOpcode::ADD => Some(b + c), - FieldArithmeticOpcode::SUB => Some(b - c), - FieldArithmeticOpcode::MUL => Some(b * c), - FieldArithmeticOpcode::DIV => { - if c.is_zero() { - None - } else { - Some(b * c.inverse()) - } +#[inline(always)] +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const A_IS_IMM: bool, + const B_IS_IMM: bool, + const OPCODE: u8, +>( + pre_compute: &FieldArithmeticPreCompute, + vm_state: &mut VmExecState, +) { + // Read values based on the adapter logic + let b_val = if A_IS_IMM { + transmute_u32_to_field(&pre_compute.b_or_imm) + } else { + vm_state.vm_read::(pre_compute.e, pre_compute.b_or_imm)[0] + }; + let c_val = if B_IS_IMM { + transmute_u32_to_field(&pre_compute.c_or_imm) + } else { + vm_state.vm_read::(pre_compute.f, pre_compute.c_or_imm)[0] + }; + + let a_val = match OPCODE { + 0 => b_val + c_val, // ADD + 1 => b_val - c_val, // SUB + 2 => b_val * c_val, // MUL + 3 => { + // DIV + if c_val.is_zero() { + vm_state.exit_code = Err(ExecutionError::Fail { + pc: vm_state.pc, + msg: "DivF divide by zero", + }); + return; } + b_val * c_val.inverse() + } + _ => panic!("Invalid field arithmetic opcode: {OPCODE}"), + }; + + vm_state.vm_write::(AS::Native as u32, pre_compute.a, &[a_val]); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} + +pub(super) fn run_field_arithmetic(opcode: FieldArithmeticOpcode, b: F, c: F) -> F { + match opcode { + FieldArithmeticOpcode::ADD => b + c, + FieldArithmeticOpcode::SUB => b - c, + FieldArithmeticOpcode::MUL => b * c, + FieldArithmeticOpcode::DIV => { + assert!(!c.is_zero(), "Division by zero"); + b * c.inverse() } } } diff --git a/extensions/native/circuit/src/field_arithmetic/mod.rs b/extensions/native/circuit/src/field_arithmetic/mod.rs index 865434cb37..e861dda42f 100644 --- a/extensions/native/circuit/src/field_arithmetic/mod.rs +++ b/extensions/native/circuit/src/field_arithmetic/mod.rs @@ -1,6 +1,6 @@ use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; -use crate::adapters::alu_native_adapter::{AluNativeAdapterAir, AluNativeAdapterChip}; +use crate::adapters::{AluNativeAdapterAir, AluNativeAdapterExecutor, AluNativeAdapterFiller}; #[cfg(test)] mod tests; @@ -9,5 +9,6 @@ mod core; pub use core::*; pub type FieldArithmeticAir = VmAirWrapper; +pub type FieldArithmeticExecutor = FieldArithmeticCoreExecutor; pub type FieldArithmeticChip = - VmChipWrapper, FieldArithmeticCoreChip>; + VmChipWrapper>; diff --git a/extensions/native/circuit/src/field_arithmetic/tests.rs b/extensions/native/circuit/src/field_arithmetic/tests.rs index 8e69f8c44b..06e0837d14 100644 --- a/extensions/native/circuit/src/field_arithmetic/tests.rs +++ b/extensions/native/circuit/src/field_arithmetic/tests.rs @@ -1,184 +1,254 @@ use std::borrow::BorrowMut; -use openvm_circuit::arch::testing::{memory::gen_pointer, VmChipTestBuilder}; +use openvm_circuit::arch::testing::{memory::gen_pointer, TestChipHarness, VmChipTestBuilder}; use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_native_compiler::FieldArithmeticOpcode; +use openvm_native_compiler::{conversion::AS, FieldArithmeticOpcode}; use openvm_stark_backend::{ + p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, utils::disable_debug_builder, verifier::VerificationError, - Chip, }; -use openvm_stark_sdk::{ - config::baby_bear_poseidon2::BabyBearPoseidon2Engine, engine::StarkFriEngine, - p3_baby_bear::BabyBear, utils::create_seeded_rng, -}; -use rand::Rng; -use strum::EnumCount; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use rand::{rngs::StdRng, Rng}; +use test_case::test_case; use super::{ - core::FieldArithmeticCoreChip, FieldArithmetic, FieldArithmeticChip, FieldArithmeticCoreCols, + FieldArithmeticChip, FieldArithmeticCoreAir, FieldArithmeticCoreCols, FieldArithmeticExecutor, +}; +use crate::{ + adapters::{AluNativeAdapterAir, AluNativeAdapterExecutor, AluNativeAdapterFiller}, + field_arithmetic::{run_field_arithmetic, FieldArithmeticAir}, + test_utils::write_native_or_imm, + FieldArithmeticCoreFiller, }; -use crate::adapters::alu_native_adapter::{AluNativeAdapterChip, AluNativeAdapterCols}; -#[test] -fn new_field_arithmetic_air_test() { - let num_ops = 3; // non-power-of-2 to also test padding - let elem_range = || 1..=100; - let xy_address_space_range = || 0usize..=1; - - let mut tester = VmChipTestBuilder::default(); - let mut chip = FieldArithmeticChip::new( - AluNativeAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), +const MAX_INS_CAPACITY: usize = 128; +type F = BabyBear; +type Harness = + TestChipHarness>; + +fn create_test_chip(tester: &VmChipTestBuilder) -> Harness { + let air = FieldArithmeticAir::new( + AluNativeAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + FieldArithmeticCoreAir::new(), + ); + let executor = FieldArithmeticExecutor::new(AluNativeAdapterExecutor::new()); + let chip = FieldArithmeticChip::::new( + FieldArithmeticCoreFiller::new(AluNativeAdapterFiller), + tester.memory_helper(), + ); + + Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY) +} + +#[allow(clippy::too_many_arguments)] +fn set_and_execute( + tester: &mut VmChipTestBuilder, + harness: &mut Harness, + rng: &mut StdRng, + opcode: FieldArithmeticOpcode, + b: Option, + c: Option, +) { + let b_val = b.unwrap_or(rng.gen()); + let c_val = c.unwrap_or(if opcode == FieldArithmeticOpcode::DIV { + // If division, make sure c is not zero + F::from_canonical_u32(rng.gen_range(0..F::NEG_ONE.as_canonical_u32()) + 1) + } else { + rng.gen() + }); + assert!(!c_val.is_zero(), "Division by zero"); + let (b, b_as) = write_native_or_imm(tester, rng, b_val, None); + let (c, c_as) = write_native_or_imm(tester, rng, c_val, None); + let a = gen_pointer(rng, 1); + + tester.execute( + harness, + &Instruction::new( + opcode.global_opcode(), + F::from_canonical_usize(a), + b, + c, + F::from_canonical_usize(AS::Native as usize), + F::from_canonical_usize(b_as), + F::from_canonical_usize(c_as), + F::ZERO, ), - FieldArithmeticCoreChip::new(), - tester.offline_memory_mutex_arc(), ); + let expected = run_field_arithmetic(opcode, b_val, c_val); + let result = tester.read::<1>(AS::Native as usize, a)[0]; + assert_eq!(result, expected); +} + +////////////////////////////////////////////////////////////////////////////////////// +// POSITIVE TESTS +// +// Randomly generate computations and execute, ensuring that the generated trace +// passes all constraints. +////////////////////////////////////////////////////////////////////////////////////// +#[test_case(FieldArithmeticOpcode::ADD, 100)] +#[test_case(FieldArithmeticOpcode::SUB, 100)] +#[test_case(FieldArithmeticOpcode::MUL, 100)] +#[test_case(FieldArithmeticOpcode::DIV, 100)] +fn new_field_arithmetic_air_test(opcode: FieldArithmeticOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut harness = create_test_chip(&tester); for _ in 0..num_ops { - let opcode = - FieldArithmeticOpcode::from_usize(rng.gen_range(0..FieldArithmeticOpcode::COUNT)); + set_and_execute(&mut tester, &mut harness, &mut rng, opcode, None, None); + } - let operand1 = BabyBear::from_canonical_u32(rng.gen_range(elem_range())); - let operand2 = BabyBear::from_canonical_u32(rng.gen_range(elem_range())); + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + opcode, + Some(F::ZERO), + None, + ); - if opcode == FieldArithmeticOpcode::DIV && operand2.is_zero() { - continue; - } + let tester = tester.build().load(harness).finalize(); + tester.simple_test().expect("Verification failed"); +} - let result_as = 4usize; - let as1 = rng.gen_range(xy_address_space_range()) * 4; - let as2 = rng.gen_range(xy_address_space_range()) * 4; - let address1 = if as1 == 0 { - operand1.as_canonical_u32() as usize - } else { - gen_pointer(&mut rng, 1) - }; - let address2 = if as2 == 0 { - operand2.as_canonical_u32() as usize - } else { - gen_pointer(&mut rng, 1) - }; - assert_ne!(address1, address2); - let result_address = gen_pointer(&mut rng, 1); - - let result = FieldArithmetic::run_field_arithmetic(opcode, operand1, operand2).unwrap(); - tracing::debug!( - "{opcode:?} d = {}, e = {}, f = {}, result_addr = {}, addr1 = {}, addr2 = {}, z = {}, x = {}, y = {}", - result_as, as1, as2, result_address, address1, address2, result, operand1, operand2, - ); - - if as1 != 0 { - tester.write_cell(as1, address1, operand1); - } - if as2 != 0 { - tester.write_cell(as2, address2, operand2); - } - tester.execute( - &mut chip, - &Instruction::from_usize( - opcode.global_opcode(), - [result_address, address1, address2, result_as, as1, as2], - ), - ); - assert_eq!(result, tester.read_cell(result_as, result_address)); - } +////////////////////////////////////////////////////////////////////////////////////// +// NEGATIVE TESTS +// +// Given a fake trace of a single operation, setup a chip and run the test. We replace +// part of the trace and check that the chip throws the expected error. +////////////////////////////////////////////////////////////////////////////////////// - let mut tester = tester.build().load(chip).finalize(); - tester.simple_test().expect("Verification failed"); +#[derive(Default)] +struct FieldExpressionPrankVals { + a: Option, + b: Option, + c: Option, + opcode_flags: Option<[bool; 4]>, + divisor_inv: Option, +} +#[allow(clippy::too_many_arguments)] +fn run_negative_field_arithmetic_test( + opcode: FieldArithmeticOpcode, + b: F, + c: F, + prank_vals: FieldExpressionPrankVals, + error: VerificationError, +) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut harness = create_test_chip(&tester); - disable_debug_builder(); - // negative test pranking each IO value - for height in 0..num_ops { - // TODO: better way to modify existing traces in tester - let arith_trace = tester.air_proof_inputs[2] - .1 - .raw - .common_main - .as_mut() - .unwrap(); - let old_trace = arith_trace.clone(); - for width in 0..FieldArithmeticCoreCols::::width() { - let prank_value = BabyBear::from_canonical_u32(rng.gen_range(1..=100)); - arith_trace.row_mut(height)[width] = prank_value; - } + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + opcode, + Some(b), + Some(c), + ); - // Run a test after pranking each row - assert_eq!( - tester.simple_test().err(), - Some(VerificationError::OodEvaluationMismatch), - "Expected constraint to fail" - ); + let adapter_width = BaseAir::::width(&harness.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { + let mut values = trace.row_slice(0).to_vec(); + let cols: &mut FieldArithmeticCoreCols = + values.split_at_mut(adapter_width).1.borrow_mut(); + if let Some(a) = prank_vals.a { + cols.a = a; + } + if let Some(b) = prank_vals.b { + cols.b = b; + } + if let Some(c) = prank_vals.c { + cols.c = c; + } + if let Some(opcode_flags) = prank_vals.opcode_flags { + [cols.is_add, cols.is_sub, cols.is_mul, cols.is_div] = opcode_flags.map(F::from_bool); + } + if let Some(divisor_inv) = prank_vals.divisor_inv { + cols.divisor_inv = divisor_inv; + } + *trace = RowMajorMatrix::new(values, trace.width()); + }; - tester.air_proof_inputs[2].1.raw.common_main = Some(old_trace); - } + disable_debug_builder(); + let tester = tester + .build() + .load_and_prank_trace(harness, modify_trace) + .finalize(); + tester.simple_test_with_expected_error(error); } #[test] -fn new_field_arithmetic_air_zero_div_zero() { - let mut tester = VmChipTestBuilder::default(); - let mut chip = FieldArithmeticChip::new( - AluNativeAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - FieldArithmeticCoreChip::new(), - tester.offline_memory_mutex_arc(), +fn field_arithmetic_negative_zero_div_test() { + run_negative_field_arithmetic_test( + FieldArithmeticOpcode::DIV, + F::from_canonical_u32(111), + F::from_canonical_u32(222), + FieldExpressionPrankVals { + b: Some(F::ZERO), + ..Default::default() + }, + VerificationError::OodEvaluationMismatch, ); - tester.write_cell(4, 6, BabyBear::from_canonical_u32(111)); - tester.write_cell(4, 7, BabyBear::from_canonical_u32(222)); - tester.execute( - &mut chip, - &Instruction::from_usize( - FieldArithmeticOpcode::DIV.global_opcode(), - [5, 6, 7, 4, 4, 4], - ), + run_negative_field_arithmetic_test( + FieldArithmeticOpcode::DIV, + F::ZERO, + F::TWO, + FieldExpressionPrankVals { + c: Some(F::ZERO), + ..Default::default() + }, + VerificationError::OodEvaluationMismatch, ); - tester.build(); - - let chip_air = chip.air(); - let mut chip_input = chip.generate_air_proof_input(); - // set the value of [c]_f to zero, necessary to bypass trace gen checks - let row = chip_input.raw.common_main.as_mut().unwrap().row_mut(0); - let cols: &mut FieldArithmeticCoreCols = row - .split_at_mut(AluNativeAdapterCols::::width()) - .1 - .borrow_mut(); - cols.b = BabyBear::ZERO; - disable_debug_builder(); + run_negative_field_arithmetic_test( + FieldArithmeticOpcode::DIV, + F::ZERO, + F::TWO, + FieldExpressionPrankVals { + c: Some(F::ZERO), + opcode_flags: Some([false, false, true, false]), + ..Default::default() + }, + VerificationError::ChallengePhaseError, + ); +} - assert_eq!( - BabyBearPoseidon2Engine::run_test_fast(vec![chip_air], vec![chip_input]).err(), - Some(VerificationError::OodEvaluationMismatch), - "Expected constraint to fail" +#[test] +fn field_arithmetic_negative_rand() { + let mut rng = create_seeded_rng(); + run_negative_field_arithmetic_test( + FieldArithmeticOpcode::DIV, + F::from_canonical_u32(111), + F::from_canonical_u32(222), + FieldExpressionPrankVals { + a: Some(rng.gen()), + b: Some(rng.gen()), + c: Some(rng.gen()), + opcode_flags: Some([rng.gen(), rng.gen(), rng.gen(), rng.gen()]), + divisor_inv: Some(rng.gen()), + }, + VerificationError::OodEvaluationMismatch, ); } #[should_panic] #[test] fn new_field_arithmetic_air_test_panic() { - let mut tester = VmChipTestBuilder::default(); - let mut chip = FieldArithmeticChip::new( - AluNativeAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - FieldArithmeticCoreChip::new(), - tester.offline_memory_mutex_arc(), - ); - tester.write_cell(4, 0, BabyBear::ZERO); + let mut tester = VmChipTestBuilder::default_native(); + let mut harness = create_test_chip(&tester); + tester.write(4, 0, [BabyBear::ZERO]); // should panic tester.execute( - &mut chip, + &mut harness, &Instruction::from_usize( FieldArithmeticOpcode::DIV.global_opcode(), [0, 0, 0, 4, 4, 4], diff --git a/extensions/native/circuit/src/field_extension/core.rs b/extensions/native/circuit/src/field_extension/core.rs index d8c83fabdd..3da9b69a92 100644 --- a/extensions/native/circuit/src/field_extension/core.rs +++ b/extensions/native/circuit/src/field_extension/core.rs @@ -5,20 +5,26 @@ use std::{ }; use itertools::izip; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::*, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_native_compiler::FieldExtensionOpcode::{self, *}; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; +use openvm_native_compiler::{ + conversion::AS, + FieldExtensionOpcode::{self, *}, +}; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; pub const BETA: usize = 11; pub const EXT_DEG: usize = 4; @@ -34,11 +40,11 @@ pub struct FieldExtensionCoreCols { pub is_sub: T, pub is_mul: T, pub is_div: T, - /// `divisor_inv` is y.inverse() when opcode is FDIV and zero otherwise. + /// `divisor_inv` is z.inverse() when opcode is FDIV and zero otherwise. pub divisor_inv: [T; EXT_DEG], } -#[derive(Copy, Clone, Debug)] +#[derive(derive_new::new, Copy, Clone, Debug)] pub struct FieldExtensionCoreAir {} impl BaseAir for FieldExtensionCoreAir { @@ -78,8 +84,8 @@ where // - Each flag in `flags` is a boolean. // - Exactly one flag in `flags` is true. // - The inner product of the `flags` and `opcodes` equals `io.opcode`. - // - The inner product of the `flags` and `results[:,j]` equals `io.z[j]` for each `j`. - // - If `is_div` is true, then `aux.divisor_inv` correctly represents the inverse of `io.y`. + // - The inner product of the `flags` and `results[:,j]` equals `io.x[j]` for each `j`. + // - If `is_div` is true, then `aux.divisor_inv` correctly represents the inverse of `io.z`. let mut is_valid = AB::Expr::ZERO; let mut expected_opcode = AB::Expr::ZERO; @@ -133,116 +139,278 @@ where } #[repr(C)] -#[derive(Debug, Serialize, Deserialize)] +#[derive(AlignedBytesBorrow, Debug)] pub struct FieldExtensionRecord { - pub opcode: FieldExtensionOpcode, - pub x: [F; EXT_DEG], pub y: [F; EXT_DEG], pub z: [F; EXT_DEG], + pub local_opcode: u8, } -pub struct FieldExtensionCoreChip { - pub air: FieldExtensionCoreAir, +#[derive(derive_new::new, Clone, Copy)] +pub struct FieldExtensionCoreExecutor { + adapter: A, } -impl FieldExtensionCoreChip { - pub fn new() -> Self { - Self { - air: FieldExtensionCoreAir {}, - } - } +#[derive(derive_new::new)] +pub struct FieldExtensionCoreFiller { + adapter: A, } -impl Default for FieldExtensionCoreChip { - fn default() -> Self { - Self::new() +impl PreflightExecutor for FieldExtensionCoreExecutor +where + F: PrimeField32, + A: 'static + AdapterTraceExecutor, + for<'buf> RA: RecordArena< + 'buf, + EmptyAdapterCoreLayout, + (A::RecordMut<'buf>, &'buf mut FieldExtensionRecord), + >, +{ + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + FieldExtensionOpcode::from_usize(opcode - FieldExtensionOpcode::CLASS_OFFSET) + ) + } + + fn execute( + &mut self, + state: VmStateMut, + instruction: &Instruction, + ) -> Result<(), ExecutionError> { + let &Instruction { opcode, .. } = instruction; + + let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + core_record.local_opcode = + opcode.local_opcode_idx(FieldExtensionOpcode::CLASS_OFFSET) as u8; + + [core_record.y, core_record.z] = + self.adapter + .read(state.memory, instruction, &mut adapter_record); + + let x = run_field_extension( + FieldExtensionOpcode::from_usize(core_record.local_opcode as usize), + core_record.y, + core_record.z, + ); + + self.adapter + .write(state.memory, instruction, x, &mut adapter_record); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) } } -impl> VmCoreChip for FieldExtensionCoreChip +impl TraceFiller for FieldExtensionCoreFiller where - I::Reads: Into<[[F; EXT_DEG]; 2]>, - I::Writes: From<[[F; EXT_DEG]; 1]>, + F: PrimeField32, + A: 'static + AdapterTraceFiller, { - type Record = FieldExtensionRecord; - type Air = FieldExtensionCoreAir; + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + + let record: &FieldExtensionRecord = unsafe { get_record_from_slice(&mut core_row, ()) }; + let core_row: &mut FieldExtensionCoreCols<_> = core_row.borrow_mut(); + + // Writing in reverse order to avoid overwriting the `record` + let opcode = FieldExtensionOpcode::from_usize(record.local_opcode as usize); + if opcode == FieldExtensionOpcode::BBE4DIV { + core_row.divisor_inv = FieldExtension::invert(record.z); + } else { + core_row.divisor_inv = [F::ZERO; EXT_DEG]; + } - #[allow(clippy::type_complexity)] - fn execute_instruction( + core_row.is_div = F::from_bool(opcode == FieldExtensionOpcode::BBE4DIV); + core_row.is_mul = F::from_bool(opcode == FieldExtensionOpcode::BBE4MUL); + core_row.is_sub = F::from_bool(opcode == FieldExtensionOpcode::FE4SUB); + core_row.is_add = F::from_bool(opcode == FieldExtensionOpcode::FE4ADD); + + core_row.z = record.z; + core_row.y = record.y; + core_row.x = run_field_extension(opcode, core_row.y, core_row.z); + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct FieldExtensionPreCompute { + a: u32, + b: u32, + c: u32, +} + +impl FieldExtensionCoreExecutor { + #[inline(always)] + fn pre_compute_impl( &self, - instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let Instruction { opcode, .. } = instruction; - let local_opcode_idx = opcode.local_opcode_idx(FieldExtensionOpcode::CLASS_OFFSET); + pc: u32, + inst: &Instruction, + data: &mut FieldExtensionPreCompute, + ) -> Result { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + + let local_opcode = FieldExtensionOpcode::from_usize( + opcode.local_opcode_idx(FieldExtensionOpcode::CLASS_OFFSET), + ); + + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); + let c = c.as_canonical_u32(); + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + + if d != AS::Native as u32 { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + if e != AS::Native as u32 { + return Err(StaticProgramError::InvalidInstruction(pc)); + } - let data: [[F; EXT_DEG]; 2] = reads.into(); - let y: [F; EXT_DEG] = data[0]; - let z: [F; EXT_DEG] = data[1]; + *data = FieldExtensionPreCompute { a, b, c }; - let x = FieldExtension::solve(FieldExtensionOpcode::from_usize(local_opcode_idx), y, z) - .unwrap(); + Ok(local_opcode as u8) + } +} - let output = AdapterRuntimeContext { - to_pc: None, - writes: [x].into(), - }; +impl Executor for FieldExtensionCoreExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } - let record = Self::Record { - opcode: FieldExtensionOpcode::from_usize(local_opcode_idx), - x, - y, - z, + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut FieldExtensionPreCompute = data.borrow_mut(); + + let opcode = self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = match opcode { + 0 => execute_e1_impl::<_, _, 0>, // FE4ADD + 1 => execute_e1_impl::<_, _, 1>, // FE4SUB + 2 => execute_e1_impl::<_, _, 2>, // BBE4MUL + 3 => execute_e1_impl::<_, _, 3>, // BBE4DIV + _ => panic!("Invalid field extension opcode: {opcode}"), }; - Ok((output, record)) + Ok(fn_ptr) } +} - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - FieldExtensionOpcode::from_usize(opcode - FieldExtensionOpcode::CLASS_OFFSET) - ) +impl MeteredExecutor for FieldExtensionCoreExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn metered_pre_compute_size(&self) -> usize { + size_of::>() } - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let FieldExtensionRecord { opcode, x, y, z } = record; - let cols: &mut FieldExtensionCoreCols<_> = row_slice.borrow_mut(); - cols.x = x; - cols.y = y; - cols.z = z; - cols.is_add = F::from_bool(opcode == FieldExtensionOpcode::FE4ADD); - cols.is_sub = F::from_bool(opcode == FieldExtensionOpcode::FE4SUB); - cols.is_mul = F::from_bool(opcode == FieldExtensionOpcode::BBE4MUL); - cols.is_div = F::from_bool(opcode == FieldExtensionOpcode::BBE4DIV); - cols.divisor_inv = if opcode == FieldExtensionOpcode::BBE4DIV { - FieldExtension::invert(z) - } else { - [F::ZERO; EXT_DEG] + #[inline(always)] + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let opcode = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + let fn_ptr = match opcode { + 0 => execute_e2_impl::<_, _, 0>, // FE4ADD + 1 => execute_e2_impl::<_, _, 1>, // FE4SUB + 2 => execute_e2_impl::<_, _, 2>, // BBE4MUL + 3 => execute_e2_impl::<_, _, 3>, // BBE4DIV + _ => panic!("Invalid field extension opcode: {opcode}"), }; - } - fn air(&self) -> &Self::Air { - &self.air + Ok(fn_ptr) } } -pub struct FieldExtension; -impl FieldExtension { - pub(super) fn solve( - opcode: FieldExtensionOpcode, - x: [F; EXT_DEG], - y: [F; EXT_DEG], - ) -> Option<[F; EXT_DEG]> { - match opcode { - FieldExtensionOpcode::FE4ADD => Some(Self::add(x, y)), - FieldExtensionOpcode::FE4SUB => Some(Self::subtract(x, y)), - FieldExtensionOpcode::BBE4MUL => Some(Self::multiply(x, y)), - FieldExtensionOpcode::BBE4DIV => Some(Self::divide(x, y)), - } +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &FieldExtensionPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &FieldExtensionPreCompute, + vm_state: &mut VmExecState, +) { + let y: [F; EXT_DEG] = vm_state.vm_read::(AS::Native as u32, pre_compute.b); + let z: [F; EXT_DEG] = vm_state.vm_read::(AS::Native as u32, pre_compute.c); + + let x = match OPCODE { + 0 => FieldExtension::add(y, z), // FE4ADD + 1 => FieldExtension::subtract(y, z), // FE4SUB + 2 => FieldExtension::multiply(y, z), // BBE4MUL + 3 => FieldExtension::divide(y, z), // BBE4DIV + _ => panic!("Invalid field extension opcode: {OPCODE}"), + }; + + vm_state.vm_write(AS::Native as u32, pre_compute.a, &x); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} + +// Returns the result of the field extension operation. +// Will panic if divide by zero. +pub(super) fn run_field_extension( + opcode: FieldExtensionOpcode, + y: [F; EXT_DEG], + z: [F; EXT_DEG], +) -> [F; EXT_DEG] { + match opcode { + FieldExtensionOpcode::FE4ADD => FieldExtension::add(y, z), + FieldExtensionOpcode::FE4SUB => FieldExtension::subtract(y, z), + FieldExtensionOpcode::BBE4MUL => FieldExtension::multiply(y, z), + FieldExtensionOpcode::BBE4DIV => FieldExtension::divide(y, z), } +} +pub(crate) struct FieldExtension; + +impl FieldExtension { pub(crate) fn add(x: [V; EXT_DEG], y: [V; EXT_DEG]) -> [E; EXT_DEG] where V: Copy, diff --git a/extensions/native/circuit/src/field_extension/mod.rs b/extensions/native/circuit/src/field_extension/mod.rs index d109deb528..6506ed8809 100644 --- a/extensions/native/circuit/src/field_extension/mod.rs +++ b/extensions/native/circuit/src/field_extension/mod.rs @@ -1,16 +1,18 @@ use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; -use super::adapters::native_vectorized_adapter::{ - NativeVectorizedAdapterAir, NativeVectorizedAdapterChip, +use crate::adapters::{ + NativeVectorizedAdapterAir, NativeVectorizedAdapterExecutor, NativeVectorizedAdapterFiller, }; -#[cfg(test)] -mod tests; - mod core; pub use core::*; +#[cfg(test)] +mod tests; + pub type FieldExtensionAir = VmAirWrapper, FieldExtensionCoreAir>; +pub type FieldExtensionExecutor = + FieldExtensionCoreExecutor>; pub type FieldExtensionChip = - VmChipWrapper, FieldExtensionCoreChip>; + VmChipWrapper>>; diff --git a/extensions/native/circuit/src/field_extension/tests.rs b/extensions/native/circuit/src/field_extension/tests.rs index 66d6c94004..afe6b649ba 100644 --- a/extensions/native/circuit/src/field_extension/tests.rs +++ b/extensions/native/circuit/src/field_extension/tests.rs @@ -1,102 +1,228 @@ use std::{ array, + borrow::BorrowMut, ops::{Add, Div, Mul, Sub}, }; -use openvm_circuit::arch::testing::{memory::gen_pointer, VmChipTestBuilder}; +use openvm_circuit::arch::testing::{memory::gen_pointer, TestChipHarness, VmChipTestBuilder}; use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_native_compiler::FieldExtensionOpcode; +use openvm_native_compiler::{conversion::AS, FieldExtensionOpcode}; use openvm_stark_backend::{ + p3_air::BaseAir, p3_field::{extension::BinomialExtensionField, FieldAlgebra, FieldExtensionAlgebra}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, utils::disable_debug_builder, verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::Rng; -use strum::EnumCount; +use rand::{rngs::StdRng, Rng}; +use test_case::test_case; -use super::{ - super::adapters::native_vectorized_adapter::NativeVectorizedAdapterChip, FieldExtension, - FieldExtensionChip, FieldExtensionCoreChip, +use crate::{ + adapters::{ + NativeVectorizedAdapterAir, NativeVectorizedAdapterExecutor, NativeVectorizedAdapterFiller, + }, + field_extension::run_field_extension, + test_utils::write_native_array, + FieldExtension, FieldExtensionAir, FieldExtensionChip, FieldExtensionCoreAir, + FieldExtensionCoreCols, FieldExtensionCoreFiller, FieldExtensionExecutor, EXT_DEG, }; -#[test] -fn new_field_extension_air_test() { - type F = BabyBear; +const MAX_INS_CAPACITY: usize = 128; +type F = BabyBear; +type Harness = TestChipHarness>; + +fn create_test_chip(tester: &VmChipTestBuilder) -> Harness { + let air = FieldExtensionAir::new( + NativeVectorizedAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + FieldExtensionCoreAir::new(), + ); + let executor = FieldExtensionExecutor::new(NativeVectorizedAdapterExecutor::new()); + let chip = FieldExtensionChip::::new( + FieldExtensionCoreFiller::new(NativeVectorizedAdapterFiller), + tester.memory_helper(), + ); - let mut tester = VmChipTestBuilder::default(); - let mut chip = FieldExtensionChip::new( - NativeVectorizedAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), + Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY) +} + +fn set_and_execute( + tester: &mut VmChipTestBuilder, + harness: &mut Harness, + rng: &mut StdRng, + opcode: FieldExtensionOpcode, + y: Option<[F; EXT_DEG]>, + z: Option<[F; EXT_DEG]>, +) { + let (y_val, y_ptr) = write_native_array(tester, rng, y); + let (z_val, z_ptr) = write_native_array(tester, rng, z); + + let x_ptr = gen_pointer(rng, EXT_DEG); + + tester.execute( + harness, + &Instruction::from_usize( + opcode.global_opcode(), + [ + x_ptr, + y_ptr, + z_ptr, + AS::Native as usize, + AS::Native as usize, + ], ), - FieldExtensionCoreChip::new(), - tester.offline_memory_mutex_arc(), ); - let trace_width = chip.trace_width(); + let result = tester.read::(AS::Native as usize, x_ptr); + let expected = run_field_extension(opcode, y_val, z_val); + assert_eq!(result, expected); +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// POSITIVE TESTS +/// +/// Randomly generate computations and execute, ensuring that the generated trace +/// passes all constraints. +/////////////////////////////////////////////////////////////////////////////////////// + +#[test_case(FieldExtensionOpcode::FE4ADD, 100)] +#[test_case(FieldExtensionOpcode::FE4SUB, 100)] +#[test_case(FieldExtensionOpcode::BBE4MUL, 100)] +#[test_case(FieldExtensionOpcode::BBE4DIV, 100)] +fn rand_field_extension_test(opcode: FieldExtensionOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); - let num_ops: usize = 7; // test padding with dummy row + let mut tester = VmChipTestBuilder::default_native(); + let mut harness = create_test_chip(&tester); for _ in 0..num_ops { - let opcode = - FieldExtensionOpcode::from_usize(rng.gen_range(0..FieldExtensionOpcode::COUNT)); + set_and_execute(&mut tester, &mut harness, &mut rng, opcode, None, None); + } - let as_d = 4usize; - let as_e = 4usize; - let address1 = gen_pointer(&mut rng, 4); - let address2 = gen_pointer(&mut rng, 4); - let result_address = gen_pointer(&mut rng, 4); + let tester = tester.build().load(harness).finalize(); + tester.simple_test().expect("Verification failed"); +} - let operand1 = array::from_fn(|_| rng.gen::()); - let operand2 = array::from_fn(|_| rng.gen::()); +////////////////////////////////////////////////////////////////////////////////////// +// NEGATIVE TESTS +// +// Given a fake trace of a single operation, setup a chip and run the test. We replace +// part of the trace and check that the chip throws the expected error. +////////////////////////////////////////////////////////////////////////////////////// - assert!(address1.abs_diff(address2) >= 4); +#[derive(Clone, Copy, Default)] +struct FieldExtensionPrankValues { + pub x: Option<[F; EXT_DEG]>, + pub y: Option<[F; EXT_DEG]>, + pub z: Option<[F; EXT_DEG]>, + pub opcode_flags: Option<[bool; 4]>, + pub divisor_inv: Option<[F; EXT_DEG]>, +} - tester.write(as_d, address1, operand1); - tester.write(as_e, address2, operand2); +fn run_negative_field_extension_test( + opcode: FieldExtensionOpcode, + y: Option<[F; EXT_DEG]>, + z: Option<[F; EXT_DEG]>, + prank_vals: FieldExtensionPrankValues, + error: VerificationError, +) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut harness = create_test_chip(&tester); + set_and_execute(&mut tester, &mut harness, &mut rng, opcode, y, z); - let result = FieldExtension::solve(opcode, operand1, operand2).unwrap(); + let adapter_width = BaseAir::::width(&harness.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { + let mut values = trace.row_slice(0).to_vec(); + let core_cols: &mut FieldExtensionCoreCols = + values.split_at_mut(adapter_width).1.borrow_mut(); - tester.execute( - &mut chip, - &Instruction::from_usize( - opcode.global_opcode(), - [result_address, address1, address2, as_d, as_e], - ), - ); - assert_eq!(result, tester.read(as_d, result_address)); - } + if let Some(x) = prank_vals.x { + core_cols.x = x; + } + if let Some(y) = prank_vals.y { + core_cols.y = y; + } + if let Some(z) = prank_vals.z { + core_cols.z = z; + } + if let Some(opcode_flags) = prank_vals.opcode_flags { + [ + core_cols.is_add, + core_cols.is_sub, + core_cols.is_mul, + core_cols.is_div, + ] = opcode_flags.map(F::from_bool); + } + if let Some(divisor_inv) = prank_vals.divisor_inv { + core_cols.divisor_inv = divisor_inv; + } - // positive test - let mut tester = tester.build().load(chip).finalize(); - tester.simple_test().expect("Verification failed"); + *trace = RowMajorMatrix::new(values, trace.width()); + }; disable_debug_builder(); - // negative test pranking each IO value - for height in [0, num_ops - 1] { - // TODO: better way to modify existing traces in tester - let extension_trace = tester.air_proof_inputs[2] - .1 - .raw - .common_main - .as_mut() - .unwrap(); - let original_trace = extension_trace.clone(); - for width in 0..trace_width { - let prank_value = BabyBear::from_canonical_u32(rng.gen_range(1..=100)); - extension_trace.row_mut(height)[width] = prank_value; - } + let tester = tester + .build() + .load_and_prank_trace(harness, modify_trace) + .finalize(); + tester.simple_test_with_expected_error(error); +} - assert_eq!( - tester.simple_test().err(), - Some(VerificationError::OodEvaluationMismatch), - "Expected constraint to fail" - ); - tester.air_proof_inputs[2].1.raw.common_main = Some(original_trace); - } +#[test] +fn rand_negative_field_extension_test() { + let mut rng = create_seeded_rng(); + run_negative_field_extension_test( + FieldExtensionOpcode::FE4ADD, + None, + None, + FieldExtensionPrankValues { + x: Some(array::from_fn(|_| rng.gen::())), + y: Some(array::from_fn(|_| rng.gen::())), + z: Some(array::from_fn(|_| rng.gen::())), + opcode_flags: Some(array::from_fn(|_| rng.gen_bool(0.5))), + divisor_inv: Some(array::from_fn(|_| rng.gen::())), + }, + VerificationError::OodEvaluationMismatch, + ); +} + +#[test] +fn field_extension_negative_tests() { + run_negative_field_extension_test( + FieldExtensionOpcode::BBE4DIV, + None, + None, + FieldExtensionPrankValues { + z: Some([F::ZERO; EXT_DEG]), + ..Default::default() + }, + VerificationError::OodEvaluationMismatch, + ); + + run_negative_field_extension_test( + FieldExtensionOpcode::BBE4DIV, + None, + None, + FieldExtensionPrankValues { + divisor_inv: Some([F::ZERO; EXT_DEG]), + ..Default::default() + }, + VerificationError::OodEvaluationMismatch, + ); + + run_negative_field_extension_test( + FieldExtensionOpcode::BBE4MUL, + Some([F::ZERO; EXT_DEG]), + None, + FieldExtensionPrankValues { + z: Some([F::ZERO; EXT_DEG]), + ..Default::default() + }, + VerificationError::ChallengePhaseError, + ); } #[test] diff --git a/extensions/native/circuit/src/fri/mod.rs b/extensions/native/circuit/src/fri/mod.rs index 7dbc3fd851..01148ed613 100644 --- a/extensions/native/circuit/src/fri/mod.rs +++ b/extensions/native/circuit/src/fri/mod.rs @@ -2,38 +2,35 @@ use core::ops::Deref; use std::{ borrow::{Borrow, BorrowMut}, mem::offset_of, - sync::{Arc, Mutex}, }; -use itertools::{zip_eq, Itertools}; +use itertools::zip_eq; use openvm_circuit::{ - arch::{ - ExecutionBridge, ExecutionBus, ExecutionError, ExecutionState, InstructionExecutor, Streams, - }, + arch::*, system::{ memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryAuxColsFactory, MemoryController, OfflineMemory, RecordId, + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteAuxRecord, + }, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, }, - program::ProgramBus, + native_adapter::util::{memory_read_native, tracing_read_native, tracing_write_native}, }, }; -use openvm_circuit_primitives::utils::next_power_of_two_or_zero; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_native_compiler::{conversion::AS, FriOpcode::FRI_REDUCED_OPENING}; use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, interaction::InteractionBuilder, p3_air::{Air, AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, p3_matrix::{dense::RowMajorMatrix, Matrix}, p3_maybe_rayon::prelude::*, - prover::types::AirProofInput, rap::{BaseAirWithPublicValues, PartitionedBaseAir}, - AirRef, Chip, ChipUsageGetter, }; -use serde::{Deserialize, Serialize}; use static_assertions::const_assert_eq; use crate::{ @@ -219,8 +216,8 @@ const INSTRUCTION_READS: usize = 5; /// it starts with a Workload row (T1) and ends with either a Disabled or Instruction2 row (T7). /// The other transition constraints then ensure the proper state transitions from Workload to /// Instruction2. -#[derive(Copy, Clone, Debug)] -struct FriReducedOpeningAir { +#[derive(Copy, Clone, Debug, derive_new::new)] +pub struct FriReducedOpeningAir { execution_bridge: ExecutionBridge, memory_bridge: MemoryBridge, } @@ -544,355 +541,729 @@ fn elem_to_ext(elem: F) -> [F; EXT_DEG] { ret } -#[derive(Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct FriReducedOpeningRecord { - pub pc: F, - pub start_timestamp: F, - pub instruction: Instruction, - pub alpha_read: RecordId, - pub length_read: RecordId, - pub a_ptr_read: RecordId, - pub is_init_read: RecordId, - pub b_ptr_read: RecordId, - pub a_rws: Vec, - pub b_reads: Vec, - pub result_write: RecordId, +#[derive(Copy, Clone, Debug)] +pub struct FriReducedOpeningMetadata { + length: usize, + is_init: bool, } -impl FriReducedOpeningRecord { - pub fn get_height(&self) -> usize { - // 2 for instruction rows - self.a_rws.len() + 2 +impl MultiRowMetadata for FriReducedOpeningMetadata { + #[inline(always)] + fn get_num_rows(&self) -> usize { + // Allocates `length` workload rows + 1 Instruction1 row + 1 Instruction2 row + self.length + 2 } } -pub struct FriReducedOpeningChip { - air: FriReducedOpeningAir, - pub records: Vec>, - pub height: usize, - offline_memory: Arc>>, - streams: Arc>>, +type FriReducedOpeningLayout = MultiRowLayout; + +// Header of record that is common for all trace rows for an instruction +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct FriReducedOpeningHeaderRecord { + pub length: u32, + pub is_init: bool, +} + +// Part of record that is common for all trace rows for an instruction +// NOTE: Order for fields is important here to prevent overwriting. +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct FriReducedOpeningCommonRecord { + pub timestamp: u32, + + pub a_ptr: u32, + + pub b_ptr: u32, + + pub alpha: [F; EXT_DEG], + + pub from_pc: u32, + + pub a_ptr_ptr: F, + pub a_ptr_aux: MemoryReadAuxRecord, + + pub b_ptr_ptr: F, + pub b_ptr_aux: MemoryReadAuxRecord, + + pub length_ptr: F, + pub length_aux: MemoryReadAuxRecord, + + pub alpha_ptr: F, + pub alpha_aux: MemoryReadAuxRecord, + + pub result_ptr: F, + pub result_aux: MemoryWriteAuxRecord, + + pub hint_id_ptr: F, + + pub is_init_ptr: F, + pub is_init_aux: MemoryReadAuxRecord, +} + +// Part of record for each workload row that calculates the partial `result` +// NOTE: Order for fields is important here to prevent overwriting. +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct FriReducedOpeningWorkloadRowRecord { + pub a: F, + pub a_aux: MemoryReadAuxRecord, + // The result of this workload row + // b can be computed from a, alpha, result, and previous result: + // b = result + a - prev_result * alpha + pub result: [F; EXT_DEG], + pub b_aux: MemoryReadAuxRecord, +} + +// NOTE: Order for fields is important here to prevent overwriting. +#[derive(Debug)] +pub struct FriReducedOpeningRecordMut<'a, F> { + pub header: &'a mut FriReducedOpeningHeaderRecord, + pub workload: &'a mut [FriReducedOpeningWorkloadRowRecord], + // if is_init this will be an empty slice, otherwise it will be the previous data of writing + // `a`s + pub a_write_prev_data: &'a mut [F], + pub common: &'a mut FriReducedOpeningCommonRecord, } -impl FriReducedOpeningChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - offline_memory: Arc>>, - streams: Arc>>, - ) -> Self { - let air = FriReducedOpeningAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, + +impl<'a, F> CustomBorrow<'a, FriReducedOpeningRecordMut<'a, F>, FriReducedOpeningLayout> + for [u8] +{ + fn custom_borrow( + &'a mut self, + layout: FriReducedOpeningLayout, + ) -> FriReducedOpeningRecordMut<'a, F> { + let (header_buf, rest) = + unsafe { self.split_at_mut_unchecked(size_of::()) }; + let header: &mut FriReducedOpeningHeaderRecord = header_buf.borrow_mut(); + + let workload_size = + layout.metadata.length * size_of::>(); + + let (workload_buf, rest) = unsafe { rest.split_at_mut_unchecked(workload_size) }; + let a_prev_size = if layout.metadata.is_init { + 0 + } else { + layout.metadata.length * size_of::() }; - Self { - records: vec![], - air, - height: 0, - offline_memory, - streams, + + let (a_prev_buf, common_buf) = unsafe { rest.split_at_mut_unchecked(a_prev_size) }; + + let (_, a_prev_records, _) = unsafe { a_prev_buf.align_to_mut::() }; + let (_, workload_records, _) = + unsafe { workload_buf.align_to_mut::>() }; + + let common: &mut FriReducedOpeningCommonRecord = common_buf.borrow_mut(); + + FriReducedOpeningRecordMut { + header, + workload: &mut workload_records[..layout.metadata.length], + a_write_prev_data: &mut a_prev_records[..], + common, + } + } + + unsafe fn extract_layout(&self) -> FriReducedOpeningLayout { + let header: &FriReducedOpeningHeaderRecord = self.borrow(); + FriReducedOpeningLayout::new(FriReducedOpeningMetadata { + length: header.length as usize, + is_init: header.is_init, + }) + } +} + +impl SizedRecord for FriReducedOpeningRecordMut<'_, F> { + fn size(layout: &FriReducedOpeningLayout) -> usize { + let mut total_len = size_of::(); + total_len += layout.metadata.length * size_of::>(); + if !layout.metadata.is_init { + total_len += layout.metadata.length * size_of::(); } + total_len += size_of::>(); + total_len + } + + fn alignment(_layout: &FriReducedOpeningLayout) -> usize { + align_of::() + } +} + +#[derive(derive_new::new, Copy, Clone)] +pub struct FriReducedOpeningExecutor; + +#[derive(derive_new::new)] +pub struct FriReducedOpeningFiller; + +pub type FriReducedOpeningChip = VmChipWrapper; + +impl Default for FriReducedOpeningExecutor { + fn default() -> Self { + Self::new() } } -impl InstructionExecutor for FriReducedOpeningChip { + +impl PreflightExecutor for FriReducedOpeningExecutor +where + F: PrimeField32, + for<'buf> RA: RecordArena<'buf, FriReducedOpeningLayout, FriReducedOpeningRecordMut<'buf, F>>, +{ + fn get_opcode_name(&self, opcode: usize) -> String { + assert_eq!(opcode, FRI_REDUCED_OPENING.global_opcode().as_usize()); + String::from("FRI_REDUCED_OPENING") + } + fn execute( &mut self, - memory: &mut MemoryController, + state: VmStateMut, instruction: &Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { + ) -> Result<(), ExecutionError> { let &Instruction { - a: a_ptr_ptr, - b: b_ptr_ptr, - c: length_ptr, - d: alpha_ptr, - e: result_ptr, - f: hint_id_ptr, - g: is_init_ptr, + a, + b, + c, + d, + e, + f, + g, .. } = instruction; - let addr_space = F::from_canonical_u32(AS::Native as u32); - let alpha_read = memory.read(addr_space, alpha_ptr); - let length_read = memory.read_cell(addr_space, length_ptr); - let a_ptr_read = memory.read_cell(addr_space, a_ptr_ptr); - let b_ptr_read = memory.read_cell(addr_space, b_ptr_ptr); - let is_init_read = memory.read_cell(addr_space, is_init_ptr); - let is_init = is_init_read.1.as_canonical_u32(); + let timestamp_start = state.memory.timestamp; - let hint_id_f = memory.unsafe_read_cell(addr_space, hint_id_ptr); - let hint_id = hint_id_f.as_canonical_u32() as usize; + // Read length from memory to allocate record + let length_ptr = c.as_canonical_u32(); + let [length]: [F; 1] = memory_read_native(&state.memory.data, length_ptr); + let length = length.as_canonical_u32(); + let is_init_ptr = g.as_canonical_u32(); + let [is_init]: [F; 1] = memory_read_native(&state.memory.data, is_init_ptr); + let is_init = is_init != F::ZERO; - let alpha = alpha_read.1; - let length = length_read.1.as_canonical_u32() as usize; - let a_ptr = a_ptr_read.1; - let b_ptr = b_ptr_read.1; + let metadata = FriReducedOpeningMetadata { + length: length as usize, + is_init, + }; + let record = state.ctx.alloc(MultiRowLayout::new(metadata)); - let mut a_rws = Vec::with_capacity(length); - let mut b_reads = Vec::with_capacity(length); - let mut result = [F::ZERO; EXT_DEG]; + record.common.from_pc = *state.pc; + record.common.timestamp = timestamp_start; + + let alpha_ptr = d.as_canonical_u32(); + let alpha = tracing_read_native( + state.memory, + alpha_ptr, + &mut record.common.alpha_aux.prev_timestamp, + ); + record.common.alpha_ptr = d; + record.common.alpha = alpha; + + tracing_read_native::( + state.memory, + length_ptr, + &mut record.common.length_aux.prev_timestamp, + ); + record.common.length_ptr = c; + record.header.length = length; + + let a_ptr_ptr = a.as_canonical_u32(); + let [a_ptr]: [F; 1] = tracing_read_native( + state.memory, + a_ptr_ptr, + &mut record.common.a_ptr_aux.prev_timestamp, + ); + record.common.a_ptr_ptr = a; + record.common.a_ptr = a_ptr.as_canonical_u32(); + + let b_ptr_ptr = b.as_canonical_u32(); + let [b_ptr]: [F; 1] = tracing_read_native( + state.memory, + b_ptr_ptr, + &mut record.common.b_ptr_aux.prev_timestamp, + ); + record.common.b_ptr_ptr = b; + record.common.b_ptr = b_ptr.as_canonical_u32(); + + tracing_read_native::( + state.memory, + is_init_ptr, + &mut record.common.is_init_aux.prev_timestamp, + ); + record.common.is_init_ptr = g; + record.header.is_init = is_init; + + let hint_id_ptr = f.as_canonical_u32(); + let [hint_id]: [F; 1] = memory_read_native(state.memory.data(), hint_id_ptr); + let hint_id = hint_id.as_canonical_u32() as usize; + record.common.hint_id_ptr = f; - let data = if is_init == 0 { - let mut streams = self.streams.lock().unwrap(); - let hint_steam = &mut streams.hint_space[hint_id]; + let length = length as usize; + + let data = if !is_init { + let hint_steam = &mut state.streams.hint_space[hint_id]; hint_steam.drain(0..length).collect() } else { vec![] }; + + let mut as_and_bs = Vec::with_capacity(length); #[allow(clippy::needless_range_loop)] for i in 0..length { - let a_rw = if is_init == 0 { - let (record_id, _) = - memory.write_cell(addr_space, a_ptr + F::from_canonical_usize(i), data[i]); - (record_id, data[i]) + let workload_row = &mut record.workload[length - i - 1]; + + let a_ptr_i = record.common.a_ptr + i as u32; + let [a]: [F; 1] = if !is_init { + let mut prev = [F::ZERO; 1]; + tracing_write_native( + state.memory, + a_ptr_i, + [data[i]], + &mut workload_row.a_aux.prev_timestamp, + &mut prev, + ); + record.a_write_prev_data[length - i - 1] = prev[0]; + [data[i]] } else { - memory.read_cell(addr_space, a_ptr + F::from_canonical_usize(i)) + tracing_read_native( + state.memory, + a_ptr_i, + &mut workload_row.a_aux.prev_timestamp, + ) }; - let b_read = - memory.read::(addr_space, b_ptr + F::from_canonical_usize(EXT_DEG * i)); - a_rws.push(a_rw); - b_reads.push(b_read); + let b_ptr_i = record.common.b_ptr + (EXT_DEG * i) as u32; + let b = tracing_read_native::( + state.memory, + b_ptr_i, + &mut workload_row.b_aux.prev_timestamp, + ); + + as_and_bs.push((a, b)); } - for (a_rw, b_read) in a_rws.iter().rev().zip_eq(b_reads.iter().rev()) { - let a = a_rw.1; - let b = b_read.1; + let mut result = [F::ZERO; EXT_DEG]; + for (i, (a, b)) in as_and_bs.into_iter().rev().enumerate() { + let workload_row = &mut record.workload[i]; + // result = result * alpha + (b - a) result = FieldExtension::add( FieldExtension::multiply(result, alpha), FieldExtension::subtract(b, elem_to_ext(a)), ); + workload_row.a = a; + workload_row.result = result; } - let (result_write, _) = memory.write(addr_space, result_ptr, result); - - let record = FriReducedOpeningRecord { - pc: F::from_canonical_u32(from_state.pc), - start_timestamp: F::from_canonical_u32(from_state.timestamp), - instruction: instruction.clone(), - alpha_read: alpha_read.0, - length_read: length_read.0, - a_ptr_read: a_ptr_read.0, - is_init_read: is_init_read.0, - b_ptr_read: b_ptr_read.0, - a_rws: a_rws.into_iter().map(|r| r.0).collect(), - b_reads: b_reads.into_iter().map(|r| r.0).collect(), - result_write, - }; - self.height += record.get_height(); - self.records.push(record); + let result_ptr = e.as_canonical_u32(); + tracing_write_native( + state.memory, + result_ptr, + result, + &mut record.common.result_aux.prev_timestamp, + &mut record.common.result_aux.prev_data, + ); + record.common.result_ptr = e; - Ok(ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }) - } + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - fn get_opcode_name(&self, opcode: usize) -> String { - assert_eq!(opcode, FRI_REDUCED_OPENING.global_opcode().as_usize()); - String::from("FRI_REDUCED_OPENING") + Ok(()) } } -fn record_to_rows( - record: FriReducedOpeningRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, -) { - let Instruction { - a: a_ptr_ptr, - b: b_ptr_ptr, - c: length_ptr, - d: alpha_ptr, - e: result_ptr, - f: hint_id_ptr, - g: is_init_ptr, - .. - } = record.instruction; - - let length_read = memory.record_by_id(record.length_read); - let alpha_read = memory.record_by_id(record.alpha_read); - let a_ptr_read = memory.record_by_id(record.a_ptr_read); - let b_ptr_read = memory.record_by_id(record.b_ptr_read); - let is_init_read = memory.record_by_id(record.is_init_read); - let is_init = is_init_read.data_at(0); - let write_a = F::ONE - is_init; - - let length = length_read.data_at(0).as_canonical_u32() as usize; - let alpha: [F; EXT_DEG] = alpha_read.data_slice().try_into().unwrap(); - let a_ptr = a_ptr_read.data_at(0); - let b_ptr = b_ptr_read.data_at(0); +impl TraceFiller for FriReducedOpeningFiller { + fn fill_trace( + &self, + mem_helper: &MemoryAuxColsFactory, + trace: &mut RowMajorMatrix, + rows_used: usize, + ) { + if rows_used == 0 { + return; + } + debug_assert_eq!(trace.width, OVERALL_WIDTH); - let mut result = [F::ZERO; EXT_DEG]; + let mut remaining_trace = &mut trace.values[..OVERALL_WIDTH * rows_used]; + let mut chunks = Vec::with_capacity(rows_used); + while !remaining_trace.is_empty() { + let header: &FriReducedOpeningHeaderRecord = + unsafe { get_record_from_slice(&mut remaining_trace, ()) }; + let num_rows = header.length as usize + 2; + let chunk_size = OVERALL_WIDTH * num_rows; + let (chunk, rest) = remaining_trace.split_at_mut(chunk_size); + chunks.push((chunk, header.is_init)); + remaining_trace = rest; + } - let alpha_aux = aux_cols_factory.make_read_aux_cols(alpha_read); - let length_aux = aux_cols_factory.make_read_aux_cols(length_read); - let a_ptr_aux = aux_cols_factory.make_read_aux_cols(a_ptr_read); - let b_ptr_aux = aux_cols_factory.make_read_aux_cols(b_ptr_read); - let is_init_aux = aux_cols_factory.make_read_aux_cols(is_init_read); - - let result_aux = aux_cols_factory.make_write_aux_cols(memory.record_by_id(record.result_write)); - - // WorkloadCols - for (i, (&a_record_id, &b_record_id)) in record - .a_rws - .iter() - .rev() - .zip_eq(record.b_reads.iter().rev()) - .enumerate() - { - let a_rw = memory.record_by_id(a_record_id); - let b_read = memory.record_by_id(b_record_id); - let a = a_rw.data_at(0); - let b: [F; EXT_DEG] = b_read.data_slice().try_into().unwrap(); - - let start = i * OVERALL_WIDTH; - let cols: &mut WorkloadCols = slice[start..start + WL_WIDTH].borrow_mut(); - *cols = WorkloadCols { - prefix: PrefixCols { - general: GeneralCols { - is_workload_row: F::ONE, - is_ins_row: F::ZERO, - timestamp: record.start_timestamp + F::from_canonical_usize((length - i) * 2), - }, - a_or_is_first: a, - data: DataCols { - a_ptr: a_ptr + F::from_canonical_usize(length - i), - write_a, - b_ptr: b_ptr + F::from_canonical_usize((length - i) * EXT_DEG), - idx: F::from_canonical_usize(i), - result, - alpha, - }, - }, - // Generate write aux columns no matter `a` is read or written. When `a` is written, - // `prev_data` is not constrained. - a_aux: if a_rw.prev_data_slice().is_some() { - aux_cols_factory.make_write_aux_cols(a_rw) + chunks.into_par_iter().for_each(|(mut chunk, is_init)| { + let num_rows = chunk.len() / OVERALL_WIDTH; + let metadata = FriReducedOpeningMetadata { + length: num_rows - 2, + is_init, + }; + let record: FriReducedOpeningRecordMut = + unsafe { get_record_from_slice(&mut chunk, MultiRowLayout::new(metadata)) }; + + let timestamp = record.common.timestamp; + let length = record.header.length as usize; + let alpha = record.common.alpha; + let is_init = record.header.is_init; + let write_a = F::from_bool(!is_init); + + let a_ptr = record.common.a_ptr; + let b_ptr = record.common.b_ptr; + + let (workload_chunk, rest) = chunk.split_at_mut(length * OVERALL_WIDTH); + let (ins1_chunk, ins2_chunk) = rest.split_at_mut(OVERALL_WIDTH); + + { + // ins2 row + let cols: &mut Instruction2Cols = ins2_chunk[..INS_2_WIDTH].borrow_mut(); + + cols.write_a_x_is_first = F::ZERO; + + mem_helper.fill( + record.common.is_init_aux.prev_timestamp, + timestamp + 4, + cols.is_init_aux.as_mut(), + ); + cols.is_init_ptr = record.common.is_init_ptr; + + cols.hint_id_ptr = record.common.hint_id_ptr; + + cols.result_aux + .set_prev_data(record.common.result_aux.prev_data); + mem_helper.fill( + record.common.result_aux.prev_timestamp, + timestamp + 5 + 2 * length as u32, + cols.result_aux.as_mut(), + ); + cols.result_ptr = record.common.result_ptr; + + mem_helper.fill( + record.common.alpha_aux.prev_timestamp, + timestamp, + cols.alpha_aux.as_mut(), + ); + cols.alpha_ptr = record.common.alpha_ptr; + + mem_helper.fill( + record.common.length_aux.prev_timestamp, + timestamp + 1, + cols.length_aux.as_mut(), + ); + cols.length_ptr = record.common.length_ptr; + + cols.is_first = F::ZERO; + + cols.general.timestamp = F::from_canonical_u32(timestamp); + cols.general.is_ins_row = F::ONE; + cols.general.is_workload_row = F::ZERO; + + ins2_chunk[INS_2_WIDTH..OVERALL_WIDTH].fill(F::ZERO); + } + + { + // ins 1 row + let cols: &mut Instruction1Cols = ins1_chunk[..INS_1_WIDTH].borrow_mut(); + + cols.write_a_x_is_first = write_a; + + mem_helper.fill( + record.common.b_ptr_aux.prev_timestamp, + timestamp + 3, + cols.b_ptr_aux.as_mut(), + ); + cols.b_ptr_ptr = record.common.b_ptr_ptr; + + mem_helper.fill( + record.common.a_ptr_aux.prev_timestamp, + timestamp + 2, + cols.a_ptr_aux.as_mut(), + ); + cols.a_ptr_ptr = record.common.a_ptr_ptr; + + cols.pc = F::from_canonical_u32(record.common.from_pc); + + cols.prefix.data.alpha = alpha; + cols.prefix.data.result = record.workload.last().unwrap().result; + cols.prefix.data.idx = F::from_canonical_usize(length); + cols.prefix.data.b_ptr = F::from_canonical_u32(b_ptr); + cols.prefix.data.write_a = write_a; + cols.prefix.data.a_ptr = F::from_canonical_u32(a_ptr); + + cols.prefix.a_or_is_first = F::ONE; + + cols.prefix.general.timestamp = F::from_canonical_u32(timestamp); + cols.prefix.general.is_ins_row = F::ONE; + cols.prefix.general.is_workload_row = F::ZERO; + ins1_chunk[INS_1_WIDTH..OVERALL_WIDTH].fill(F::ZERO); + } + + // To fill the WorkloadRows we do 2 passes: + // - First, a serial pass to fill some of the records into the trace + // - Then, a parallel pass to fill the rest of the records into the trace + // Note, the first pass is done to avoid overwriting the records + + // Copy of `a_write_prev_data` to avoid overwriting it and to use it in the parallel + // pass + let a_prev_data = if !is_init { + let mut tmp = Vec::with_capacity(length); + tmp.extend_from_slice(record.a_write_prev_data); + tmp } else { - let read_aux = aux_cols_factory.make_read_aux_cols(a_rw); - MemoryWriteAuxCols::from_base(read_aux.get_base(), [F::ZERO]) - }, - b, - b_aux: aux_cols_factory.make_read_aux_cols(b_read), - }; - // result = result * alpha + (b - a) - result = FieldExtension::add( - FieldExtension::multiply(result, alpha), - FieldExtension::subtract(b, elem_to_ext(a)), - ); + vec![] + }; + + for (i, (workload_row, row_chunk)) in record + .workload + .iter() + .zip(workload_chunk.chunks_exact_mut(OVERALL_WIDTH)) + .enumerate() + .rev() + { + let cols: &mut WorkloadCols = row_chunk[..WL_WIDTH].borrow_mut(); + + let timestamp = timestamp + ((length - i) * 2) as u32; + + // fill in reverse order + mem_helper.fill( + workload_row.b_aux.prev_timestamp, + timestamp + 4, + cols.b_aux.as_mut(), + ); + + // We temporarily store the result here + // the correct value of b is computed during the serial pass below + cols.b = record.workload[i].result; + + mem_helper.fill( + workload_row.a_aux.prev_timestamp, + timestamp + 3, + cols.a_aux.as_mut(), + ); + cols.prefix.a_or_is_first = workload_row.a; + + if i > 0 { + cols.prefix.data.result = record.workload[i - 1].result; + } + } + + workload_chunk + .par_chunks_exact_mut(OVERALL_WIDTH) + .enumerate() + .for_each(|(i, row_chunk)| { + let cols: &mut WorkloadCols = row_chunk[..WL_WIDTH].borrow_mut(); + let timestamp = timestamp + ((length - i) * 2) as u32; + if is_init { + cols.a_aux.set_prev_data([F::ZERO; 1]); + } else { + cols.a_aux.set_prev_data([a_prev_data[i]]); + } + + // DataCols + cols.prefix.data.a_ptr = F::from_canonical_u32(a_ptr + (length - i) as u32); + cols.prefix.data.write_a = write_a; + cols.prefix.data.b_ptr = + F::from_canonical_u32(b_ptr + ((length - i) * EXT_DEG) as u32); + cols.prefix.data.idx = F::from_canonical_usize(i); + if i == 0 { + cols.prefix.data.result = [F::ZERO; EXT_DEG]; + } + cols.prefix.data.alpha = alpha; + + // GeneralCols + cols.prefix.general.is_workload_row = F::ONE; + cols.prefix.general.is_ins_row = F::ZERO; + + // WorkloadCols + cols.prefix.general.timestamp = F::from_canonical_u32(timestamp); + + cols.b = FieldExtension::subtract( + FieldExtension::add(cols.b, elem_to_ext(cols.prefix.a_or_is_first)), + FieldExtension::multiply(cols.prefix.data.result, alpha), + ); + row_chunk[WL_WIDTH..OVERALL_WIDTH].fill(F::ZERO); + }); + }); } - // Instruction1Cols - { - let start = length * OVERALL_WIDTH; - let cols: &mut Instruction1Cols = slice[start..start + INS_1_WIDTH].borrow_mut(); - *cols = Instruction1Cols { - prefix: PrefixCols { - general: GeneralCols { - is_workload_row: F::ZERO, - is_ins_row: F::ONE, - timestamp: record.start_timestamp, - }, - a_or_is_first: F::ONE, - data: DataCols { - a_ptr, - write_a, - b_ptr, - idx: F::from_canonical_usize(length), - result, - alpha, - }, - }, - pc: record.pc, +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct FriReducedOpeningPreCompute { + a_ptr_ptr: u32, + b_ptr_ptr: u32, + length_ptr: u32, + alpha_ptr: u32, + result_ptr: u32, + hint_id_ptr: u32, + is_init_ptr: u32, +} + +impl FriReducedOpeningExecutor { + #[inline(always)] + fn pre_compute_impl( + &self, + _pc: u32, + inst: &Instruction, + data: &mut FriReducedOpeningPreCompute, + ) -> Result<(), StaticProgramError> { + let &Instruction { + a, + b, + c, + d, + e, + f, + g, + .. + } = inst; + + let a_ptr_ptr = a.as_canonical_u32(); + let b_ptr_ptr = b.as_canonical_u32(); + let length_ptr = c.as_canonical_u32(); + let alpha_ptr = d.as_canonical_u32(); + let result_ptr = e.as_canonical_u32(); + let hint_id_ptr = f.as_canonical_u32(); + let is_init_ptr = g.as_canonical_u32(); + + *data = FriReducedOpeningPreCompute { a_ptr_ptr, - a_ptr_aux, b_ptr_ptr, - b_ptr_aux, - write_a_x_is_first: write_a, - }; - } - // Instruction2Cols - { - let start = (length + 1) * OVERALL_WIDTH; - let cols: &mut Instruction2Cols = slice[start..start + INS_2_WIDTH].borrow_mut(); - *cols = Instruction2Cols { - general: GeneralCols { - is_workload_row: F::ZERO, - is_ins_row: F::ONE, - timestamp: record.start_timestamp, - }, - is_first: F::ZERO, length_ptr, - length_aux, alpha_ptr, - alpha_aux, result_ptr, - result_aux, hint_id_ptr, is_init_ptr, - is_init_aux, - write_a_x_is_first: F::ZERO, }; + + Ok(()) } } -impl ChipUsageGetter for FriReducedOpeningChip { - fn air_name(&self) -> String { - "FriReducedOpeningAir".to_string() +impl Executor for FriReducedOpeningExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() } - fn current_trace_height(&self) -> usize { - self.height - } + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut FriReducedOpeningPreCompute = data.borrow_mut(); - fn trace_width(&self) -> usize { - OVERALL_WIDTH + self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = execute_e1_impl; + Ok(fn_ptr) } } -impl Chip for FriReducedOpeningChip> +impl MeteredExecutor for FriReducedOpeningExecutor where - Val: PrimeField32, + F: PrimeField32, { - fn air(&self) -> AirRef { - Arc::new(self.air) + #[inline(always)] + fn metered_pre_compute_size(&self) -> usize { + size_of::>() } - fn generate_air_proof_input(self) -> AirProofInput { - let height = next_power_of_two_or_zero(self.height); - let mut flat_trace = Val::::zero_vec(OVERALL_WIDTH * height); - let chunked_trace = { - let sizes: Vec<_> = self - .records - .par_iter() - .map(|record| OVERALL_WIDTH * record.get_height()) - .collect(); - variable_chunks_mut(&mut flat_trace, &sizes) - }; - let memory = self.offline_memory.lock().unwrap(); - let aux_cols_factory = memory.aux_cols_factory(); + #[inline(always)] + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; - self.records - .into_par_iter() - .zip_eq(chunked_trace.into_par_iter()) - .for_each(|(record, slice)| { - record_to_rows(record, &aux_cols_factory, slice, &memory); - }); + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; - let matrix = RowMajorMatrix::new(flat_trace, OVERALL_WIDTH); - AirProofInput::simple_no_pis(matrix) + let fn_ptr = execute_e2_impl; + Ok(fn_ptr) } } -fn variable_chunks_mut<'a, T>(mut slice: &'a mut [T], sizes: &[usize]) -> Vec<&'a mut [T]> { - let mut result = Vec::with_capacity(sizes.len()); - for &size in sizes { - // split_at_mut guarantees disjoint slices - let (left, right) = slice.split_at_mut(size); - result.push(left); - slice = right; // move forward for the next chunk +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &FriReducedOpeningPreCompute = pre_compute.borrow(); + execute_e12_impl(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + let height = execute_e12_impl(&pre_compute.data, vm_state); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height); +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &FriReducedOpeningPreCompute, + vm_state: &mut VmExecState, +) -> u32 { + let alpha = vm_state.vm_read(AS::Native as u32, pre_compute.alpha_ptr); + + let [length]: [F; 1] = vm_state.vm_read(AS::Native as u32, pre_compute.length_ptr); + let length = length.as_canonical_u32() as usize; + + let [a_ptr]: [F; 1] = vm_state.vm_read(AS::Native as u32, pre_compute.a_ptr_ptr); + let [b_ptr]: [F; 1] = vm_state.vm_read(AS::Native as u32, pre_compute.b_ptr_ptr); + + let [is_init_read]: [F; 1] = vm_state.vm_read(AS::Native as u32, pre_compute.is_init_ptr); + let is_init = is_init_read.as_canonical_u32(); + + let [hint_id_f]: [F; 1] = vm_state.host_read(AS::Native as u32, pre_compute.hint_id_ptr); + let hint_id = hint_id_f.as_canonical_u32() as usize; + + let data = if is_init == 0 { + let hint_steam = &mut vm_state.streams.hint_space[hint_id]; + hint_steam.drain(0..length).collect() + } else { + vec![] + }; + + let mut as_and_bs = Vec::with_capacity(length); + #[allow(clippy::needless_range_loop)] + for i in 0..length { + let a_ptr_i = (a_ptr + F::from_canonical_usize(i)).as_canonical_u32(); + let [a]: [F; 1] = if is_init == 0 { + vm_state.vm_write(AS::Native as u32, a_ptr_i, &[data[i]]); + [data[i]] + } else { + vm_state.vm_read(AS::Native as u32, a_ptr_i) + }; + let b_ptr_i = (b_ptr + F::from_canonical_usize(EXT_DEG * i)).as_canonical_u32(); + let b = vm_state.vm_read(AS::Native as u32, b_ptr_i); + + as_and_bs.push((a, b)); + } + + let mut result = [F::ZERO; EXT_DEG]; + for (a, b) in as_and_bs.into_iter().rev() { + // result = result * alpha + (b - a) + result = FieldExtension::add( + FieldExtension::multiply(result, alpha), + FieldExtension::subtract(b, elem_to_ext(a)), + ); } - result + + vm_state.vm_write(AS::Native as u32, pre_compute.result_ptr, &result); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; + + length as u32 + 2 } diff --git a/extensions/native/circuit/src/fri/tests.rs b/extensions/native/circuit/src/fri/tests.rs index 97dcdbc532..5910f69e93 100644 --- a/extensions/native/circuit/src/fri/tests.rs +++ b/extensions/native/circuit/src/fri/tests.rs @@ -1,22 +1,42 @@ -use std::sync::{Arc, Mutex}; +use std::borrow::BorrowMut; use itertools::Itertools; -use openvm_circuit::arch::{ - testing::{memory::gen_pointer, VmChipTestBuilder}, - Streams, -}; +use openvm_circuit::arch::testing::{memory::gen_pointer, TestChipHarness, VmChipTestBuilder}; use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_native_compiler::FriOpcode::FRI_REDUCED_OPENING; +use openvm_native_compiler::{conversion::AS, FriOpcode::FRI_REDUCED_OPENING}; use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, utils::disable_debug_builder, verifier::VerificationError, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::Rng; +use rand::{rngs::StdRng, Rng}; + +use super::{ + super::field_extension::FieldExtension, elem_to_ext, FriReducedOpeningAir, + FriReducedOpeningChip, FriReducedOpeningExecutor, EXT_DEG, +}; +use crate::{ + fri::{WorkloadCols, OVERALL_WIDTH, WL_WIDTH}, + write_native_array, FriReducedOpeningFiller, +}; + +const MAX_INS_CAPACITY: usize = 1024; +type F = BabyBear; +type Harness = + TestChipHarness>; -use super::{super::field_extension::FieldExtension, elem_to_ext, FriReducedOpeningChip, EXT_DEG}; -use crate::OVERALL_WIDTH; +fn create_test_chip(tester: &VmChipTestBuilder) -> Harness { + let air = FriReducedOpeningAir::new(tester.execution_bridge(), tester.memory_bridge()); + let step = FriReducedOpeningExecutor::new(); + let chip = FriReducedOpeningChip::new(FriReducedOpeningFiller, tester.memory_helper()); + + Harness::with_capacity(step, air, chip, MAX_INS_CAPACITY) +} fn compute_fri_mat_opening( alpha: [F; EXT_DEG], @@ -35,146 +55,111 @@ fn compute_fri_mat_opening( result } -#[test] -fn fri_mat_opening_air_test() { - let num_ops = 14; // non-power-of-2 to also test padding - let elem_range = || 1..=100; - let length_range = || 1..=49; - - let mut tester = VmChipTestBuilder::default(); - - let streams = Arc::new(Mutex::new(Streams::default())); - let mut chip = FriReducedOpeningChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.offline_memory_mutex_arc(), - streams.clone(), +fn set_and_execute(tester: &mut VmChipTestBuilder, harness: &mut Harness, rng: &mut StdRng) { + let len = rng.gen_range(1..=28); + let a_ptr = gen_pointer(rng, len); + let b_ptr = gen_pointer(rng, len); + let a_ptr_ptr = + write_native_array::(tester, rng, Some([F::from_canonical_usize(a_ptr)])).1; + let b_ptr_ptr = + write_native_array::(tester, rng, Some([F::from_canonical_usize(b_ptr)])).1; + + let len_ptr = write_native_array::(tester, rng, Some([F::from_canonical_usize(len)])).1; + let (alpha, alpha_ptr) = write_native_array::(tester, rng, None); + let out_ptr = gen_pointer(rng, EXT_DEG); + let is_init = true; + let is_init_ptr = write_native_array::(tester, rng, Some([F::from_bool(is_init)])).1; + + let mut vec_a = Vec::with_capacity(len); + let mut vec_b = Vec::with_capacity(len); + for i in 0..len { + let a = rng.gen(); + let b: [F; EXT_DEG] = std::array::from_fn(|_| rng.gen()); + vec_a.push(a); + vec_b.push(b); + if !is_init { + tester.streams.hint_space[0].push(a); + } else { + tester.write(AS::Native as usize, a_ptr + i, [a]); + } + tester.write(AS::Native as usize, b_ptr + (EXT_DEG * i), b); + } + + tester.execute( + harness, + &Instruction::from_usize( + FRI_REDUCED_OPENING.global_opcode(), + [ + a_ptr_ptr, + b_ptr_ptr, + len_ptr, + alpha_ptr, + out_ptr, + 0, // hint id, will just use 0 for testing + is_init_ptr, + ], + ), ); - let mut rng = create_seeded_rng(); + let expected_result = compute_fri_mat_opening(alpha, &vec_a, &vec_b); + assert_eq!(expected_result, tester.read(AS::Native as usize, out_ptr)); - macro_rules! gen_ext { - () => { - std::array::from_fn::<_, EXT_DEG, _>(|_| { - BabyBear::from_canonical_u32(rng.gen_range(elem_range())) - }) - }; + for (i, ai) in vec_a.iter().enumerate() { + let [found] = tester.read(AS::Native as usize, a_ptr + i); + assert_eq!(*ai, found); } +} - streams.lock().unwrap().hint_space = vec![vec![]]; - - for _ in 0..num_ops { - let alpha = gen_ext!(); - let length = rng.gen_range(length_range()); - let a = (0..length) - .map(|_| BabyBear::from_canonical_u32(rng.gen_range(elem_range()))) - .collect_vec(); - let b = (0..length).map(|_| gen_ext!()).collect_vec(); - - let result = compute_fri_mat_opening(alpha, &a, &b); - - let alpha_pointer = gen_pointer(&mut rng, 4); - let length_pointer = gen_pointer(&mut rng, 1); - let a_pointer_pointer = gen_pointer(&mut rng, 1); - let b_pointer_pointer = gen_pointer(&mut rng, 1); - let result_pointer = gen_pointer(&mut rng, 4); - let a_pointer = gen_pointer(&mut rng, 1); - let b_pointer = gen_pointer(&mut rng, 4); - let is_init_ptr = gen_pointer(&mut rng, 1); - - let address_space = 4usize; - - /*tracing::debug!( - "{opcode:?} d = {}, e = {}, f = {}, result_addr = {}, addr1 = {}, addr2 = {}, z = {}, x = {}, y = {}", - result_as, as1, as2, result_pointer, address1, address2, result, operand1, operand2, - );*/ - - tester.write(address_space, alpha_pointer, alpha); - tester.write_cell( - address_space, - length_pointer, - BabyBear::from_canonical_usize(length), - ); - tester.write_cell( - address_space, - a_pointer_pointer, - BabyBear::from_canonical_usize(a_pointer), - ); - tester.write_cell( - address_space, - b_pointer_pointer, - BabyBear::from_canonical_usize(b_pointer), - ); - let is_init = rng.gen_range(0..2); - tester.write_cell( - address_space, - is_init_ptr, - BabyBear::from_canonical_u32(is_init), - ); +/////////////////////////////////////////////////////////////////////////////////////// +/// POSITIVE TESTS +/// +/// Randomly generate computations and execute, ensuring that the generated trace +/// passes all constraints. +/////////////////////////////////////////////////////////////////////////////////////// - if is_init == 0 { - streams.lock().unwrap().hint_space[0].extend_from_slice(&a); - } else { - for (i, ai) in a.iter().enumerate() { - tester.write_cell(address_space, a_pointer + i, *ai); - } - } - for (i, bi) in b.iter().enumerate() { - tester.write(address_space, b_pointer + (4 * i), *bi); - } +#[test] +fn fri_mat_opening_air_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut harness = create_test_chip(&tester); - tester.execute( - &mut chip, - &Instruction::from_usize( - FRI_REDUCED_OPENING.global_opcode(), - [ - a_pointer_pointer, - b_pointer_pointer, - length_pointer, - alpha_pointer, - result_pointer, - 0, // hint id - is_init_ptr, - ], - ), - ); - assert_eq!(result, tester.read(address_space, result_pointer)); - // Check that `a` was populated. - for (i, ai) in a.iter().enumerate() { - let found = tester.read_cell(address_space, a_pointer + i); - assert_eq!(*ai, found); - } + let num_ops = 28; // non-power-of-2 to also test padding + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut harness, &mut rng); } - let mut tester = tester.build().load(chip).finalize(); + let tester = tester.build().load(harness).finalize(); tester.simple_test().expect("Verification failed"); +} - disable_debug_builder(); - // negative test pranking each value - for height in 0..num_ops { - // TODO: better way to modify existing traces in tester - let trace = tester.air_proof_inputs[2] - .1 - .raw - .common_main - .as_mut() - .unwrap(); - let old_trace = trace.clone(); - for width in 0..OVERALL_WIDTH - /* num operands */ - { - let prank_value = BabyBear::from_canonical_u32(rng.gen_range(1..=100)); - trace.row_mut(height)[width] = prank_value; - } +////////////////////////////////////////////////////////////////////////////////////// +// NEGATIVE TESTS +// +// Given a fake trace of a single operation, setup a chip and run the test. We replace +// part of the trace and check that the chip throws the expected error. +////////////////////////////////////////////////////////////////////////////////////// - // Run a test after pranking each row - assert_eq!( - tester.simple_test().err(), - Some(VerificationError::OodEvaluationMismatch), - "Expected constraint to fail" - ); +#[test] +fn run_negative_fri_mat_opening_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut harness = create_test_chip(&tester); - tester.air_proof_inputs[2].1.raw.common_main = Some(old_trace); - } + set_and_execute(&mut tester, &mut harness, &mut rng); + + let modify_trace = |trace: &mut DenseMatrix| { + let mut values = trace.row_slice(0).to_vec(); + let cols: &mut WorkloadCols = values[..WL_WIDTH].borrow_mut(); + + cols.prefix.a_or_is_first = F::from_canonical_u32(42); + + *trace = RowMajorMatrix::new(values, OVERALL_WIDTH); + }; + + disable_debug_builder(); + let tester = tester + .build() + .load_and_prank_trace(harness, modify_trace) + .finalize(); + tester.simple_test_with_expected_error(VerificationError::OodEvaluationMismatch); } diff --git a/extensions/native/circuit/src/jal/mod.rs b/extensions/native/circuit/src/jal/mod.rs deleted file mode 100644 index 28322834a2..0000000000 --- a/extensions/native/circuit/src/jal/mod.rs +++ /dev/null @@ -1,342 +0,0 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - ops::Deref, - sync::{Arc, Mutex}, -}; - -use openvm_circuit::{ - arch::{ExecutionBridge, ExecutionError, ExecutionState, InstructionExecutor, PcIncOrSet}, - system::memory::{ - offline_checker::{MemoryBridge, MemoryWriteAuxCols}, - MemoryAddress, MemoryAuxColsFactory, MemoryController, OfflineMemory, RecordId, - }, -}; -use openvm_circuit_primitives::{ - utils::next_power_of_two_or_zero, - var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, VariableRangeCheckerChip, - }, -}; -use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; -use openvm_native_compiler::{conversion::AS, NativeJalOpcode, NativeRangeCheckOpcode}; -use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - interaction::InteractionBuilder, - p3_air::{Air, AirBuilder, BaseAir}, - p3_field::{Field, FieldAlgebra, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, - p3_maybe_rayon::prelude::*, - prover::types::AirProofInput, - rap::{BaseAirWithPublicValues, PartitionedBaseAir}, - AirRef, Chip, ChipUsageGetter, -}; -use serde::{Deserialize, Serialize}; -use static_assertions::const_assert_eq; -use AS::Native; - -#[cfg(test)] -mod tests; - -#[repr(C)] -#[derive(AlignedBorrow)] -struct JalRangeCheckCols { - is_jal: T, - is_range_check: T, - a_pointer: T, - state: ExecutionState, - // Write when is_jal, read when is_range_check. - writes_aux: MemoryWriteAuxCols, - b: T, - // Only used by range check. - c: T, - // Only used by range check. - y: T, -} - -const OVERALL_WIDTH: usize = JalRangeCheckCols::::width(); -const_assert_eq!(OVERALL_WIDTH, 12); - -#[derive(Copy, Clone, Debug)] -pub struct JalRangeCheckAir { - execution_bridge: ExecutionBridge, - memory_bridge: MemoryBridge, - range_bus: VariableRangeCheckerBus, -} - -impl BaseAir for JalRangeCheckAir { - fn width(&self) -> usize { - OVERALL_WIDTH - } -} - -impl BaseAirWithPublicValues for JalRangeCheckAir {} -impl PartitionedBaseAir for JalRangeCheckAir {} -impl Air for JalRangeCheckAir -where - AB::F: PrimeField32, -{ - fn eval(&self, builder: &mut AB) { - let main = builder.main(); - let local = main.row_slice(0); - let local_slice = local.deref(); - let local: &JalRangeCheckCols = local_slice.borrow(); - builder.assert_bool(local.is_jal); - builder.assert_bool(local.is_range_check); - let is_valid = local.is_jal + local.is_range_check; - builder.assert_bool(is_valid.clone()); - - let d = AB::Expr::from_canonical_u32(Native as u32); - let a_val = local.writes_aux.prev_data()[0]; - // if is_jal, write pc + DEFAULT_PC_STEP, else if is_range_check, read a_val. - let write_val = local.is_jal - * (local.state.pc + AB::Expr::from_canonical_u32(DEFAULT_PC_STEP)) - + local.is_range_check * a_val; - self.memory_bridge - .write( - MemoryAddress::new(d.clone(), local.a_pointer), - [write_val], - local.state.timestamp, - &local.writes_aux, - ) - .eval(builder, is_valid.clone()); - - let opcode = local.is_jal - * AB::F::from_canonical_usize(NativeJalOpcode::JAL.global_opcode().as_usize()) - + local.is_range_check - * AB::F::from_canonical_usize( - NativeRangeCheckOpcode::RANGE_CHECK - .global_opcode() - .as_usize(), - ); - // Increment pc by b if is_jal, else by DEFAULT_PC_STEP if is_range_check. - let pc_inc = local.is_jal * local.b - + local.is_range_check * AB::F::from_canonical_u32(DEFAULT_PC_STEP); - builder.when(local.is_jal).assert_zero(local.c); - self.execution_bridge - .execute_and_increment_or_set_pc( - opcode, - [local.a_pointer.into(), local.b.into(), local.c.into(), d], - local.state, - AB::F::ONE, - PcIncOrSet::Inc(pc_inc), - ) - .eval(builder, is_valid); - - // Range check specific: - // a_val = x + y * (1 << 16) - let x = a_val - local.y * AB::Expr::from_canonical_u32(1 << 16); - self.range_bus - .send(x.clone(), local.b) - .eval(builder, local.is_range_check); - // Assert y < (1 << c), where c <= 14. - self.range_bus - .send(local.y, local.c) - .eval(builder, local.is_range_check); - } -} - -impl JalRangeCheckAir { - fn new( - execution_bridge: ExecutionBridge, - memory_bridge: MemoryBridge, - range_bus: VariableRangeCheckerBus, - ) -> Self { - Self { - execution_bridge, - memory_bridge, - range_bus, - } - } -} - -#[repr(C)] -#[derive(Serialize, Deserialize)] -pub struct JalRangeCheckRecord { - pub state: ExecutionState, - pub a_rw: RecordId, - pub b: u32, - pub c: u8, - pub is_jal: bool, -} - -/// Chip for JAL and RANGE_CHECK. These opcodes are logically irrelevant. Putting these opcodes into -/// the same chip is just to save columns. -pub struct JalRangeCheckChip { - air: JalRangeCheckAir, - pub records: Vec, - offline_memory: Arc>>, - range_checker_chip: SharedVariableRangeCheckerChip, - /// If true, ignore execution errors. - debug: bool, -} - -impl JalRangeCheckChip { - pub fn new( - execution_bridge: ExecutionBridge, - offline_memory: Arc>>, - range_checker_chip: SharedVariableRangeCheckerChip, - ) -> Self { - let memory_bridge = offline_memory.lock().unwrap().memory_bridge(); - let air = JalRangeCheckAir::new(execution_bridge, memory_bridge, range_checker_chip.bus()); - Self { - air, - records: vec![], - offline_memory, - range_checker_chip, - debug: false, - } - } - pub fn with_debug(mut self) -> Self { - self.debug = true; - self - } -} - -impl InstructionExecutor for JalRangeCheckChip { - fn execute( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { - if instruction.opcode == NativeJalOpcode::JAL.global_opcode() { - let (record_id, _) = memory.write( - F::from_canonical_u32(AS::Native as u32), - instruction.a, - [F::from_canonical_u32(from_state.pc + DEFAULT_PC_STEP)], - ); - let b = instruction.b.as_canonical_u32(); - self.records.push(JalRangeCheckRecord { - state: from_state, - a_rw: record_id, - b, - c: 0, - is_jal: true, - }); - return Ok(ExecutionState { - pc: (F::from_canonical_u32(from_state.pc) + instruction.b).as_canonical_u32(), - timestamp: memory.timestamp(), - }); - } else if instruction.opcode == NativeRangeCheckOpcode::RANGE_CHECK.global_opcode() { - let d = F::from_canonical_u32(AS::Native as u32); - // This is a read, but we make the record have prev_data - let a_val = memory.unsafe_read_cell(d, instruction.a); - let (record_id, _) = memory.write(d, instruction.a, [a_val]); - let a_val = a_val.as_canonical_u32(); - let b = instruction.b.as_canonical_u32(); - let c = instruction.c.as_canonical_u32(); - debug_assert!(!self.debug || b <= 16); - debug_assert!(!self.debug || c <= 14); - let x = a_val & ((1 << 16) - 1); - if !self.debug && x >= 1 << b { - return Err(ExecutionError::Fail { pc: from_state.pc }); - } - let y = a_val >> 16; - if !self.debug && y >= 1 << c { - return Err(ExecutionError::Fail { pc: from_state.pc }); - } - self.records.push(JalRangeCheckRecord { - state: from_state, - a_rw: record_id, - b, - c: c as u8, - is_jal: false, - }); - return Ok(ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }); - } - panic!("Unknown opcode {}", instruction.opcode); - } - - fn get_opcode_name(&self, opcode: usize) -> String { - let jal_opcode = NativeJalOpcode::JAL.global_opcode().as_usize(); - let range_check_opcode = NativeRangeCheckOpcode::RANGE_CHECK - .global_opcode() - .as_usize(); - if opcode == jal_opcode { - return String::from("JAL"); - } - if opcode == range_check_opcode { - return String::from("RANGE_CHECK"); - } - panic!("Unknown opcode {}", opcode); - } -} - -impl ChipUsageGetter for JalRangeCheckChip { - fn air_name(&self) -> String { - "JalRangeCheck".to_string() - } - - fn current_trace_height(&self) -> usize { - self.records.len() - } - - fn trace_width(&self) -> usize { - OVERALL_WIDTH - } -} - -impl Chip for JalRangeCheckChip> -where - Val: PrimeField32, -{ - fn air(&self) -> AirRef { - Arc::new(self.air) - } - fn generate_air_proof_input(self) -> AirProofInput { - let height = next_power_of_two_or_zero(self.records.len()); - let mut flat_trace = Val::::zero_vec(OVERALL_WIDTH * height); - let memory = self.offline_memory.lock().unwrap(); - let aux_cols_factory = memory.aux_cols_factory(); - - self.records - .into_par_iter() - .zip(flat_trace.par_chunks_mut(OVERALL_WIDTH)) - .for_each(|(record, slice)| { - record_to_row( - record, - &aux_cols_factory, - self.range_checker_chip.as_ref(), - slice, - &memory, - ); - }); - - let matrix = RowMajorMatrix::new(flat_trace, OVERALL_WIDTH); - AirProofInput::simple_no_pis(matrix) - } -} - -fn record_to_row( - record: JalRangeCheckRecord, - aux_cols_factory: &MemoryAuxColsFactory, - range_checker_chip: &VariableRangeCheckerChip, - slice: &mut [F], - memory: &OfflineMemory, -) { - let a_record = memory.record_by_id(record.a_rw); - let col: &mut JalRangeCheckCols<_> = slice.borrow_mut(); - col.is_jal = F::from_bool(record.is_jal); - col.is_range_check = F::from_bool(!record.is_jal); - col.a_pointer = a_record.pointer; - col.state = ExecutionState { - pc: F::from_canonical_u32(record.state.pc), - timestamp: F::from_canonical_u32(record.state.timestamp), - }; - aux_cols_factory.generate_write_aux(a_record, &mut col.writes_aux); - col.b = F::from_canonical_u32(record.b); - if !record.is_jal { - let a_val = a_record.data_at(0); - let a_val_u32 = a_val.as_canonical_u32(); - let y = a_val_u32 >> 16; - let x = a_val_u32 & ((1 << 16) - 1); - range_checker_chip.add_count(x, record.b as usize); - range_checker_chip.add_count(y, record.c as usize); - col.c = F::from_canonical_u32(record.c as u32); - col.y = F::from_canonical_u32(y); - } -} diff --git a/extensions/native/circuit/src/jal/tests.rs b/extensions/native/circuit/src/jal/tests.rs deleted file mode 100644 index dd56b73c8f..0000000000 --- a/extensions/native/circuit/src/jal/tests.rs +++ /dev/null @@ -1,198 +0,0 @@ -use std::borrow::BorrowMut; - -use openvm_circuit::arch::{testing::VmChipTestBuilder, ExecutionBridge}; -use openvm_instructions::{ - instruction::Instruction, - program::{DEFAULT_PC_STEP, PC_BITS}, - LocalOpcode, -}; -use openvm_native_compiler::{NativeJalOpcode::*, NativeRangeCheckOpcode::RANGE_CHECK}; -use openvm_stark_backend::{ - p3_field::{FieldAlgebra, PrimeField32}, - utils::disable_debug_builder, - verifier::VerificationError, - Chip, -}; -use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::{rngs::StdRng, Rng}; - -use crate::{jal::JalRangeCheckCols, JalRangeCheckChip}; -type F = BabyBear; - -fn set_and_execute( - tester: &mut VmChipTestBuilder, - chip: &mut JalRangeCheckChip, - rng: &mut StdRng, - initial_imm: Option, - initial_pc: Option, -) { - let imm = initial_imm.unwrap_or(rng.gen_range(0..20)); - let a = rng.gen_range(0..32) << 2; - let d = 4usize; - - tester.execute_with_pc( - chip, - &Instruction::from_usize(JAL.global_opcode(), [a, imm as usize, 0, d, 0, 0, 0]), - initial_pc.unwrap_or(rng.gen_range(0..(1 << PC_BITS))), - ); - let initial_pc = tester.execution.last_from_pc().as_canonical_u32(); - let final_pc = tester.execution.last_to_pc().as_canonical_u32(); - - let next_pc = initial_pc + imm; - let rd_data = initial_pc + DEFAULT_PC_STEP; - - assert_eq!(next_pc, final_pc); - assert_eq!(rd_data, tester.read::<1>(d, a)[0].as_canonical_u32()); -} - -struct RangeCheckTestCase { - val: u32, - x_bit: u32, - y_bit: u32, -} - -fn set_and_execute_range_check( - tester: &mut VmChipTestBuilder, - chip: &mut JalRangeCheckChip, - rng: &mut StdRng, - test_cases: Vec, -) { - let a = rng.gen_range(0..32) << 2; - for RangeCheckTestCase { val, x_bit, y_bit } in test_cases { - let d = 4usize; - - tester.write_cell(d, a, F::from_canonical_u32(val)); - tester.execute_with_pc( - chip, - &Instruction::from_usize( - RANGE_CHECK.global_opcode(), - [a, x_bit as usize, y_bit as usize, d, 0, 0, 0], - ), - rng.gen_range(0..(1 << PC_BITS)), - ); - } -} - -fn setup() -> (StdRng, VmChipTestBuilder, JalRangeCheckChip) { - let rng = create_seeded_rng(); - let tester = VmChipTestBuilder::default(); - let execution_bridge = ExecutionBridge::new(tester.execution_bus(), tester.program_bus()); - let offline_memory = tester.offline_memory_mutex_arc(); - let range_checker = tester.range_checker(); - let chip = JalRangeCheckChip::::new(execution_bridge, offline_memory, range_checker); - (rng, tester, chip) -} - -#[test] -fn rand_jal_test() { - let (mut rng, mut tester, mut chip) = setup(); - let num_tests: usize = 100; - for _ in 0..num_tests { - set_and_execute(&mut tester, &mut chip, &mut rng, None, None); - } - - let tester = tester.build().load(chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn rand_range_check_test() { - let (mut rng, mut tester, mut chip) = setup(); - let f = |x: u32, y: u32| RangeCheckTestCase { - val: x + y * (1 << 16), - x_bit: 32 - x.leading_zeros(), - y_bit: 32 - y.leading_zeros(), - }; - let mut test_cases: Vec<_> = (0..10) - .map(|_| { - let x = 0; - let y = rng.gen_range(0..1 << 14); - f(x, y) - }) - .collect(); - test_cases.extend((0..10).map(|_| { - let x = rng.gen_range(0..1 << 16); - let y = 0; - f(x, y) - })); - test_cases.extend((0..10).map(|_| { - let x = rng.gen_range(0..1 << 16); - let y = rng.gen_range(0..1 << 14); - f(x, y) - })); - test_cases.push(f((1 << 16) - 1, (1 << 14) - 1)); - set_and_execute_range_check(&mut tester, &mut chip, &mut rng, test_cases); - let tester = tester.build().load(chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn negative_range_check_test() { - { - let (mut rng, mut tester, chip) = setup(); - let mut chip = chip.with_debug(); - set_and_execute_range_check( - &mut tester, - &mut chip, - &mut rng, - vec![RangeCheckTestCase { - x_bit: 1, - y_bit: 1, - val: 2, - }], - ); - let tester = tester.build().load(chip).finalize(); - disable_debug_builder(); - let result = tester.simple_test(); - assert!(result.is_err()); - } - { - let (mut rng, mut tester, chip) = setup(); - let mut chip = chip.with_debug(); - set_and_execute_range_check( - &mut tester, - &mut chip, - &mut rng, - vec![RangeCheckTestCase { - x_bit: 1, - y_bit: 0, - val: 1 << 16, - }], - ); - let tester = tester.build().load(chip).finalize(); - disable_debug_builder(); - let result = tester.simple_test(); - assert!(result.is_err()); - } -} - -#[test] -fn negative_jal_test() { - let (mut rng, mut tester, mut chip) = setup(); - set_and_execute(&mut tester, &mut chip, &mut rng, None, None); - - let tester = tester.build(); - - let chip_air = chip.air(); - let mut chip_input = chip.generate_air_proof_input(); - let jal_trace = chip_input.raw.common_main.as_mut().unwrap(); - { - let col: &mut JalRangeCheckCols<_> = jal_trace.row_mut(0).borrow_mut(); - col.b = F::from_canonical_u32(rng.gen_range(1 << 11..1 << 12)); - } - disable_debug_builder(); - let tester = tester - .load_air_proof_input((chip_air, chip_input)) - .finalize(); - let msg = format!( - "Expected verification to fail with {:?}, but it didn't", - VerificationError::ChallengePhaseError - ); - let result = tester.simple_test(); - assert_eq!( - result.err(), - Some(VerificationError::ChallengePhaseError), - "{}", - msg - ); -} diff --git a/extensions/native/circuit/src/jal_rangecheck/mod.rs b/extensions/native/circuit/src/jal_rangecheck/mod.rs new file mode 100644 index 0000000000..f6e28c9f2d --- /dev/null +++ b/extensions/native/circuit/src/jal_rangecheck/mod.rs @@ -0,0 +1,508 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + ops::Deref, +}; + +use openvm_circuit::{ + arch::*, + system::{ + memory::{ + offline_checker::{MemoryBridge, MemoryWriteAuxCols, MemoryWriteAuxRecord}, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, + }, + native_adapter::util::{memory_read_native, tracing_write_native}, + }, +}; +use openvm_circuit_primitives::{ + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + AlignedBytesBorrow, +}; +use openvm_circuit_primitives_derive::AlignedBorrow; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; +use openvm_native_compiler::{conversion::AS, NativeJalOpcode, NativeRangeCheckOpcode}; +use openvm_stark_backend::{ + interaction::InteractionBuilder, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra, PrimeField32}, + p3_matrix::Matrix, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, +}; +use static_assertions::const_assert_eq; +use AS::Native; + +#[cfg(test)] +mod tests; + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct JalRangeCheckCols { + is_jal: T, + is_range_check: T, + a_pointer: T, + state: ExecutionState, + // Write when is_jal, read when is_range_check. + writes_aux: MemoryWriteAuxCols, + b: T, + // Only used by range check. + c: T, + // Only used by range check. + y: T, +} + +const OVERALL_WIDTH: usize = JalRangeCheckCols::::width(); +const_assert_eq!(OVERALL_WIDTH, 12); + +#[derive(Copy, Clone, Debug, derive_new::new)] +pub struct JalRangeCheckAir { + execution_bridge: ExecutionBridge, + memory_bridge: MemoryBridge, + range_bus: VariableRangeCheckerBus, +} + +impl BaseAir for JalRangeCheckAir { + fn width(&self) -> usize { + OVERALL_WIDTH + } +} + +impl BaseAirWithPublicValues for JalRangeCheckAir {} +impl PartitionedBaseAir for JalRangeCheckAir {} +impl Air for JalRangeCheckAir +where + AB::F: PrimeField32, +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + let local_slice = local.deref(); + let local: &JalRangeCheckCols = local_slice.borrow(); + builder.assert_bool(local.is_jal); + builder.assert_bool(local.is_range_check); + let is_valid = local.is_jal + local.is_range_check; + builder.assert_bool(is_valid.clone()); + + let d = AB::Expr::from_canonical_u32(Native as u32); + let a_val = local.writes_aux.prev_data()[0]; + // if is_jal, write pc + DEFAULT_PC_STEP, else if is_range_check, read a_val. + let write_val = local.is_jal + * (local.state.pc + AB::Expr::from_canonical_u32(DEFAULT_PC_STEP)) + + local.is_range_check * a_val; + self.memory_bridge + .write( + MemoryAddress::new(d.clone(), local.a_pointer), + [write_val], + local.state.timestamp, + &local.writes_aux, + ) + .eval(builder, is_valid.clone()); + + let opcode = local.is_jal + * AB::F::from_canonical_usize(NativeJalOpcode::JAL.global_opcode().as_usize()) + + local.is_range_check + * AB::F::from_canonical_usize( + NativeRangeCheckOpcode::RANGE_CHECK + .global_opcode() + .as_usize(), + ); + // Increment pc by b if is_jal, else by DEFAULT_PC_STEP if is_range_check. + let pc_inc = local.is_jal * local.b + + local.is_range_check * AB::F::from_canonical_u32(DEFAULT_PC_STEP); + builder.when(local.is_jal).assert_zero(local.c); + self.execution_bridge + .execute_and_increment_or_set_pc( + opcode, + [local.a_pointer.into(), local.b.into(), local.c.into(), d], + local.state, + AB::F::ONE, + PcIncOrSet::Inc(pc_inc), + ) + .eval(builder, is_valid); + + // Range check specific: + // a_val = x + y * (1 << 16) + let x = a_val - local.y * AB::Expr::from_canonical_u32(1 << 16); + self.range_bus + .send(x.clone(), local.b) + .eval(builder, local.is_range_check); + // Assert y < (1 << c), where c <= 14. + self.range_bus + .send(local.y, local.c) + .eval(builder, local.is_range_check); + } +} + +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct JalRangeCheckRecord { + pub is_jal: bool, + pub a: F, + pub from_pc: u32, + pub from_timestamp: u32, + pub write: MemoryWriteAuxRecord, + pub b: F, + pub c: F, +} + +/// Chip for JAL and RANGE_CHECK. These opcodes are logically irrelevant. Putting these opcodes into +/// the same chip is just to save columns. +#[derive(derive_new::new, Clone, Copy)] +pub struct JalRangeCheckExecutor; + +#[derive(derive_new::new)] +pub struct JalRangeCheckFiller { + range_checker_chip: SharedVariableRangeCheckerChip, +} +pub type NativeJalRangeCheckChip = VmChipWrapper; + +impl PreflightExecutor for JalRangeCheckExecutor +where + F: PrimeField32, + for<'buf> RA: RecordArena<'buf, EmptyMultiRowLayout, &'buf mut JalRangeCheckRecord>, +{ + fn get_opcode_name(&self, opcode: usize) -> String { + let jal_opcode = NativeJalOpcode::JAL.global_opcode().as_usize(); + let range_check_opcode = NativeRangeCheckOpcode::RANGE_CHECK + .global_opcode() + .as_usize(); + if opcode == jal_opcode { + return String::from("JAL"); + } + if opcode == range_check_opcode { + return String::from("RANGE_CHECK"); + } + panic!("Unknown opcode {opcode}"); + } + + fn execute( + &mut self, + state: VmStateMut, + instruction: &Instruction, + ) -> Result<(), ExecutionError> { + let &Instruction { + opcode, a, b, c, .. + } = instruction; + + debug_assert!( + opcode == NativeJalOpcode::JAL.global_opcode() + || opcode == NativeRangeCheckOpcode::RANGE_CHECK.global_opcode() + ); + + let record = state.ctx.alloc(EmptyMultiRowLayout::default()); + + record.from_pc = *state.pc; + record.from_timestamp = state.memory.timestamp; + + record.a = a; + record.b = b; + + if opcode == NativeJalOpcode::JAL.global_opcode() { + record.is_jal = true; + record.c = F::ZERO; + + tracing_write_native( + state.memory, + a.as_canonical_u32(), + [F::from_canonical_u32( + state.pc.wrapping_add(DEFAULT_PC_STEP), + )], + &mut record.write.prev_timestamp, + &mut record.write.prev_data, + ); + *state.pc = (F::from_canonical_u32(*state.pc) + b).as_canonical_u32(); + } else if opcode == NativeRangeCheckOpcode::RANGE_CHECK.global_opcode() { + record.is_jal = false; + record.c = c; + + let a_ptr = a.as_canonical_u32(); + let [a_val]: [F; 1] = memory_read_native(state.memory.data(), a_ptr); + tracing_write_native( + state.memory, + a_ptr, + [a_val], + &mut record.write.prev_timestamp, + &mut record.write.prev_data, + ); + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + } + + Ok(()) + } +} + +impl TraceFiller for JalRangeCheckFiller { + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut row_slice: &mut [F]) { + let record: &mut JalRangeCheckRecord = + unsafe { get_record_from_slice(&mut row_slice, ()) }; + let cols: &mut JalRangeCheckCols = row_slice.borrow_mut(); + + // Writing in reverse order to avoid overwriting the `record` + if record.is_jal { + cols.y = F::ZERO; + cols.c = F::ZERO; + cols.b = record.b; + cols.writes_aux.set_prev_data(record.write.prev_data); + mem_helper.fill( + record.write.prev_timestamp, + record.from_timestamp, + cols.writes_aux.as_mut(), + ); + cols.state.timestamp = F::from_canonical_u32(record.from_timestamp); + cols.state.pc = F::from_canonical_u32(record.from_pc); + cols.a_pointer = record.a; + cols.is_range_check = F::ZERO; + cols.is_jal = F::ONE; + } else { + let a_val = record.write.prev_data[0].as_canonical_u32(); + let b = record.b.as_canonical_u32(); + let c = record.c.as_canonical_u32(); + let x = a_val & 0xffff; + let y = a_val >> 16; + #[cfg(debug_assertions)] + { + assert!(b <= 16); + assert!(c <= 14); + assert!(x < (1 << b)); + assert!(y < (1 << c)); + } + + self.range_checker_chip.add_count(x, b as usize); + self.range_checker_chip.add_count(y, c as usize); + + cols.y = F::from_canonical_u32(y); + cols.c = record.c; + cols.b = record.b; + cols.writes_aux.set_prev_data(record.write.prev_data); + mem_helper.fill( + record.write.prev_timestamp, + record.from_timestamp, + cols.writes_aux.as_mut(), + ); + cols.state.timestamp = F::from_canonical_u32(record.from_timestamp); + cols.state.pc = F::from_canonical_u32(record.from_pc); + cols.a_pointer = record.a; + cols.is_range_check = F::ONE; + cols.is_jal = F::ZERO; + } + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct JalPreCompute { + a: u32, + b: F, + return_pc: F, +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct RangeCheckPreCompute { + a: u32, + b: u8, + c: u8, +} + +impl JalRangeCheckExecutor { + #[inline(always)] + fn pre_compute_jal_impl( + &self, + pc: u32, + inst: &Instruction, + jal_data: &mut JalPreCompute, + ) -> Result<(), StaticProgramError> { + let &Instruction { opcode, a, b, .. } = inst; + + if opcode != NativeJalOpcode::JAL.global_opcode() { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + let a = a.as_canonical_u32(); + let return_pc = F::from_canonical_u32(pc.wrapping_add(DEFAULT_PC_STEP)); + + *jal_data = JalPreCompute { a, b, return_pc }; + Ok(()) + } + + #[inline(always)] + fn pre_compute_range_check_impl( + &self, + pc: u32, + inst: &Instruction, + range_check_data: &mut RangeCheckPreCompute, + ) -> Result<(), StaticProgramError> { + let &Instruction { + opcode, a, b, c, .. + } = inst; + + if opcode != NativeRangeCheckOpcode::RANGE_CHECK.global_opcode() { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); + let c = c.as_canonical_u32(); + if b > 16 || c > 14 { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + *range_check_data = RangeCheckPreCompute { + a, + b: b as u8, + c: c as u8, + }; + Ok(()) + } +} + +impl Executor for JalRangeCheckExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + std::cmp::max( + size_of::>(), + size_of::(), + ) + } + + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let &Instruction { opcode, .. } = inst; + + let is_jal = opcode == NativeJalOpcode::JAL.global_opcode(); + + if is_jal { + let jal_data: &mut JalPreCompute = data.borrow_mut(); + self.pre_compute_jal_impl(pc, inst, jal_data)?; + Ok(execute_jal_e1_impl) + } else { + let range_check_data: &mut RangeCheckPreCompute = data.borrow_mut(); + self.pre_compute_range_check_impl(pc, inst, range_check_data)?; + Ok(execute_range_check_e1_impl) + } + } +} + +impl MeteredExecutor for JalRangeCheckExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn metered_pre_compute_size(&self) -> usize { + std::cmp::max( + size_of::>>(), + size_of::>(), + ) + } + + #[inline(always)] + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let &Instruction { opcode, .. } = inst; + + let is_jal = opcode == NativeJalOpcode::JAL.global_opcode(); + + if is_jal { + let pre_compute: &mut E2PreCompute> = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + self.pre_compute_jal_impl(pc, inst, &mut pre_compute.data)?; + Ok(execute_jal_e2_impl) + } else { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + self.pre_compute_range_check_impl(pc, inst, &mut pre_compute.data)?; + Ok(execute_range_check_e2_impl) + } + } +} + +unsafe fn execute_jal_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &JalPreCompute = pre_compute.borrow(); + execute_jal_e12_impl(pre_compute, vm_state); +} + +unsafe fn execute_jal_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute> = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_jal_e12_impl(&pre_compute.data, vm_state); +} + +unsafe fn execute_range_check_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &RangeCheckPreCompute = pre_compute.borrow(); + execute_range_check_e12_impl(pre_compute, vm_state); +} + +unsafe fn execute_range_check_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_range_check_e12_impl(&pre_compute.data, vm_state); +} + +#[inline(always)] +unsafe fn execute_jal_e12_impl( + pre_compute: &JalPreCompute, + vm_state: &mut VmExecState, +) { + vm_state.vm_write(AS::Native as u32, pre_compute.a, &[pre_compute.return_pc]); + // TODO(ayush): better way to do this + vm_state.pc = (F::from_canonical_u32(vm_state.pc) + pre_compute.b).as_canonical_u32(); + vm_state.instret += 1; +} + +#[inline(always)] +unsafe fn execute_range_check_e12_impl( + pre_compute: &RangeCheckPreCompute, + vm_state: &mut VmExecState, +) { + let [a_val]: [F; 1] = vm_state.host_read(AS::Native as u32, pre_compute.a); + + vm_state.vm_write(AS::Native as u32, pre_compute.a, &[a_val]); + { + let a_val = a_val.as_canonical_u32(); + let b = pre_compute.b; + let c = pre_compute.c; + let x = a_val & 0xffff; + let y = a_val >> 16; + + // The range of `b`,`c` had already been checked in `pre_compute_e1`. + if !(x < (1 << b) && y < (1 << c)) { + vm_state.exit_code = Err(ExecutionError::Fail { + pc: vm_state.pc, + msg: "NativeRangeCheck", + }); + return; + } + } + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} diff --git a/extensions/native/circuit/src/jal_rangecheck/tests.rs b/extensions/native/circuit/src/jal_rangecheck/tests.rs new file mode 100644 index 0000000000..6109b9373d --- /dev/null +++ b/extensions/native/circuit/src/jal_rangecheck/tests.rs @@ -0,0 +1,307 @@ +use std::borrow::BorrowMut; + +use openvm_circuit::arch::testing::{memory::gen_pointer, TestChipHarness, VmChipTestBuilder}; +use openvm_instructions::{ + instruction::Instruction, + program::{DEFAULT_PC_STEP, PC_BITS}, + LocalOpcode, VmOpcode, +}; +use openvm_native_compiler::{ + conversion::AS, NativeJalOpcode::*, NativeRangeCheckOpcode::RANGE_CHECK, +}; +use openvm_stark_backend::{ + p3_field::{FieldAlgebra, PrimeField32}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, + utils::disable_debug_builder, + verifier::VerificationError, +}; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use rand::{rngs::StdRng, Rng}; +use test_case::test_case; + +use super::{JalRangeCheckAir, JalRangeCheckExecutor}; +use crate::{ + jal_rangecheck::{JalRangeCheckCols, NativeJalRangeCheckChip}, + test_utils::write_native_array, + JalRangeCheckFiller, +}; + +const MAX_INS_CAPACITY: usize = 128; +type F = BabyBear; +type Harness = + TestChipHarness>; + +fn create_test_chip(tester: &VmChipTestBuilder) -> Harness { + let range_checker = tester.range_checker().clone(); + let air = JalRangeCheckAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + range_checker.bus(), + ); + let executor = JalRangeCheckExecutor::new(); + let chip = NativeJalRangeCheckChip::::new( + JalRangeCheckFiller::new(range_checker), + tester.memory_helper(), + ); + + Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY) +} + +// `a_val` and `c` will be disregarded if opcode is JAL +fn set_and_execute( + tester: &mut VmChipTestBuilder, + harness: &mut Harness, + rng: &mut StdRng, + opcode: VmOpcode, + a_val: Option, + b: Option, + c: Option, +) { + if opcode == JAL.global_opcode() { + let initial_pc = rng.gen_range(0..(1 << PC_BITS)); + let a = gen_pointer(rng, 1); + let final_pc = F::from_canonical_u32(rng.gen_range(0..(1 << PC_BITS))); + let b = b.unwrap_or((final_pc - F::from_canonical_u32(initial_pc)).as_canonical_u32()); + tester.execute_with_pc( + harness, + &Instruction::from_usize(opcode, [a, b as usize, 0, AS::Native as usize, 0, 0, 0]), + initial_pc, + ); + + let final_pc = tester.execution.last_to_pc(); + let expected_final_pc = F::from_canonical_u32(initial_pc) + F::from_canonical_u32(b); + assert_eq!(final_pc, expected_final_pc); + let result_a_val = tester.read::<1>(AS::Native as usize, a)[0].as_canonical_u32(); + let expected_a_val = initial_pc + DEFAULT_PC_STEP; + assert_eq!(result_a_val, expected_a_val); + } else { + let a_val = a_val.unwrap_or(rng.gen_range(0..(1 << 30))); + let a = write_native_array(tester, rng, Some([F::from_canonical_u32(a_val)])).1; + let x = a_val & 0xffff; + let y = a_val >> 16; + + let min_b = 32 - x.leading_zeros(); + let min_c = 32 - y.leading_zeros(); + let b = b.unwrap_or(rng.gen_range(min_b..=16)); + let c = c.unwrap_or(rng.gen_range(min_c..=14)); + tester.execute( + harness, + &Instruction::from_usize( + opcode, + [a, b as usize, c as usize, AS::Native as usize, 0, 0, 0], + ), + ); + // There is nothing to assert for range check since it doesn't write to the memory + }; +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// POSITIVE TESTS +/// +/// Randomly generate computations and execute, ensuring that the generated trace +/// passes all constraints. +/////////////////////////////////////////////////////////////////////////////////////// + +#[test_case(JAL.global_opcode(), 100)] +#[test_case(RANGE_CHECK.global_opcode(), 100)] +fn rand_jal_range_check_test(opcode: VmOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut harness = create_test_chip(&tester); + + for _ in 0..num_ops { + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + opcode, + None, + None, + None, + ); + } + let tester = tester.build().load(harness).finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn range_check_edge_cases_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut harness = create_test_chip(&tester); + + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + RANGE_CHECK.global_opcode(), + Some(0), + None, + None, + ); + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + RANGE_CHECK.global_opcode(), + Some((1 << 30) - 1), + None, + None, + ); + + // x = 0 + let a = rng.gen_range(0..(1 << 14)) << 16; + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + RANGE_CHECK.global_opcode(), + Some(a), + None, + None, + ); + + // y = 0 + let a = rng.gen_range(0..(1 << 16)); + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + RANGE_CHECK.global_opcode(), + Some(a), + None, + None, + ); + + let tester = tester.build().load(harness).finalize(); + tester.simple_test().expect("Verification failed"); +} + +////////////////////////////////////////////////////////////////////////////////////// +// NEGATIVE TESTS +// +// Given a fake trace of a single operation, setup a chip and run the test. We replace +// part of the trace and check that the chip throws the expected error. +////////////////////////////////////////////////////////////////////////////////////// + +#[derive(Clone, Copy, Default)] +struct JalRangeCheckPrankValues { + pub flags: Option<[bool; 2]>, + pub a_val: Option, + pub b: Option, + pub c: Option, + pub y: Option, +} + +fn run_negative_jal_range_check_test( + opcode: VmOpcode, + a_val: Option, + b: Option, + c: Option, + prank_vals: JalRangeCheckPrankValues, + error: VerificationError, +) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut harness = create_test_chip(&tester); + set_and_execute(&mut tester, &mut harness, &mut rng, opcode, a_val, b, c); + + let modify_trace = |trace: &mut DenseMatrix| { + let mut values = trace.row_slice(0).to_vec(); + let cols: &mut JalRangeCheckCols = values[..].borrow_mut(); + + if let Some(flags) = prank_vals.flags { + cols.is_jal = F::from_bool(flags[0]); + cols.is_range_check = F::from_bool(flags[1]); + } + if let Some(a_val) = prank_vals.a_val { + cols.writes_aux + .set_prev_data([F::from_canonical_u32(a_val)]); + } + + if let Some(b) = prank_vals.b { + cols.b = F::from_canonical_u32(b); + } + if let Some(c) = prank_vals.c { + cols.c = F::from_canonical_u32(c); + } + if let Some(y) = prank_vals.y { + cols.y = F::from_canonical_u32(y); + } + + *trace = RowMajorMatrix::new(values, trace.width()); + }; + + disable_debug_builder(); + let tester = tester + .build() + .load_and_prank_trace(harness, modify_trace) + .finalize(); + tester.simple_test_with_expected_error(error); +} + +#[test] +fn negative_range_check_test() { + run_negative_jal_range_check_test( + RANGE_CHECK.global_opcode(), + Some(2), + Some(2), + Some(1), + JalRangeCheckPrankValues { + b: Some(1), + ..Default::default() + }, + VerificationError::ChallengePhaseError, + ); + run_negative_jal_range_check_test( + RANGE_CHECK.global_opcode(), + Some(1 << 16), + None, + None, + JalRangeCheckPrankValues { + c: Some(0), + ..Default::default() + }, + VerificationError::ChallengePhaseError, + ); + run_negative_jal_range_check_test( + RANGE_CHECK.global_opcode(), + Some((1 << 30) - 1), + None, + None, + JalRangeCheckPrankValues { + a_val: Some(1 << 30), + ..Default::default() + }, + VerificationError::ChallengePhaseError, + ); + run_negative_jal_range_check_test( + RANGE_CHECK.global_opcode(), + Some(1 << 17), + None, + None, + JalRangeCheckPrankValues { + y: Some(1), + ..Default::default() + }, + VerificationError::ChallengePhaseError, + ); +} + +#[test] +fn negative_jal_test() { + run_negative_jal_range_check_test( + JAL.global_opcode(), + None, + None, + None, + JalRangeCheckPrankValues { + b: Some(0), + ..Default::default() + }, + VerificationError::ChallengePhaseError, + ); +} diff --git a/extensions/native/circuit/src/lib.rs b/extensions/native/circuit/src/lib.rs index 46c6bc890f..17839a1975 100644 --- a/extensions/native/circuit/src/lib.rs +++ b/extensions/native/circuit/src/lib.rs @@ -1,3 +1,22 @@ +use openvm_circuit::{ + arch::{ + AirInventory, ChipInventoryError, InitFileGenerator, MatrixRecordArena, MemoryConfig, + SystemConfig, VmBuilder, VmChipComplex, VmProverExtension, + }, + system::{SystemChipInventory, SystemCpuBuilder, SystemExecutor}, +}; +use openvm_circuit_derive::VmConfig; +use openvm_rv32im_circuit::{ + Rv32I, Rv32IExecutor, Rv32ImCpuProverExt, Rv32Io, Rv32IoExecutor, Rv32M, Rv32MExecutor, +}; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + p3_field::PrimeField32, + prover::cpu::{CpuBackend, CpuDevice}, +}; +use openvm_stark_sdk::engine::StarkEngine; +use serde::{Deserialize, Serialize}; + pub mod adapters; mod branch_eq; @@ -5,7 +24,7 @@ mod castf; mod field_arithmetic; mod field_extension; mod fri; -mod jal; +mod jal_rangecheck; mod loadstore; mod poseidon2; @@ -14,7 +33,7 @@ pub use castf::*; pub use field_arithmetic::*; pub use field_extension::*; pub use fri::*; -pub use jal::*; +pub use jal_rangecheck::*; pub use loadstore::*; pub use poseidon2::*; @@ -22,4 +41,133 @@ mod extension; pub use extension::*; mod utils; -pub use utils::*; +#[cfg(any(test, feature = "test-utils"))] +pub use utils::test_utils::*; +pub(crate) use utils::*; + +#[derive(Clone, Debug, derive_new::new, VmConfig, Serialize, Deserialize)] +pub struct NativeConfig { + #[config(executor = "SystemExecutor")] + pub system: SystemConfig, + #[extension(generics = true)] + pub native: Native, +} + +impl NativeConfig { + pub fn aggregation(num_public_values: usize, max_constraint_degree: usize) -> Self { + Self { + system: SystemConfig::new( + max_constraint_degree, + MemoryConfig::aggregation(), + num_public_values, + ) + .with_max_segment_len((1 << 24) - 100), + native: Default::default(), + } + } +} + +// Default implementation uses no init file +impl InitFileGenerator for NativeConfig {} + +#[derive(Clone, Default)] +pub struct NativeCpuBuilder; + +impl VmBuilder for NativeCpuBuilder +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + Val: PrimeField32, +{ + type VmConfig = NativeConfig; + type SystemChipInventory = SystemChipInventory; + type RecordArena = MatrixRecordArena>; + + fn create_chip_complex( + &self, + config: &NativeConfig, + circuit: AirInventory, + ) -> Result< + VmChipComplex, + ChipInventoryError, + > { + let mut chip_complex = + VmBuilder::::create_chip_complex(&SystemCpuBuilder, &config.system, circuit)?; + let inventory = &mut chip_complex.inventory; + VmProverExtension::::extend_prover( + &NativeCpuProverExt, + &config.native, + inventory, + )?; + Ok(chip_complex) + } +} + +#[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] +pub struct Rv32WithKernelsConfig { + #[config(executor = "SystemExecutor")] + pub system: SystemConfig, + #[extension] + pub rv32i: Rv32I, + #[extension] + pub rv32m: Rv32M, + #[extension] + pub io: Rv32Io, + #[extension(generics = true)] + pub native: Native, + #[extension] + pub castf: CastFExtension, +} + +impl Default for Rv32WithKernelsConfig { + fn default() -> Self { + Self { + system: SystemConfig::default().with_continuations(), + rv32i: Rv32I, + rv32m: Rv32M::default(), + io: Rv32Io, + native: Native, + castf: CastFExtension, + } + } +} + +// Default implementation uses no init file +impl InitFileGenerator for Rv32WithKernelsConfig {} + +#[derive(Clone)] +pub struct Rv32WithKernelsCpuBuilder; + +impl VmBuilder for Rv32WithKernelsCpuBuilder +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + Val: PrimeField32, +{ + type VmConfig = Rv32WithKernelsConfig; + type SystemChipInventory = SystemChipInventory; + type RecordArena = MatrixRecordArena>; + + fn create_chip_complex( + &self, + config: &Rv32WithKernelsConfig, + circuit: AirInventory, + ) -> Result< + VmChipComplex, + ChipInventoryError, + > { + let mut chip_complex = + VmBuilder::::create_chip_complex(&SystemCpuBuilder, &config.system, circuit)?; + let inventory = &mut chip_complex.inventory; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.rv32i, inventory)?; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.rv32m, inventory)?; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.io, inventory)?; + VmProverExtension::::extend_prover( + &NativeCpuProverExt, + &config.native, + inventory, + )?; + VmProverExtension::::extend_prover(&NativeCpuProverExt, &config.castf, inventory)?; + Ok(chip_complex) + } +} diff --git a/extensions/native/circuit/src/loadstore/core.rs b/extensions/native/circuit/src/loadstore/core.rs index 094a57dccc..6e9b3497cd 100644 --- a/extensions/native/circuit/src/loadstore/core.rs +++ b/extensions/native/circuit/src/loadstore/core.rs @@ -1,27 +1,28 @@ use std::{ array, borrow::{Borrow, BorrowMut}, - sync::{Arc, Mutex, OnceLock}, }; -use openvm_circuit::arch::{ - instructions::LocalOpcode, AdapterAirContext, AdapterRuntimeContext, ExecutionError, Result, - Streams, VmAdapterInterface, VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::*, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::instruction::Instruction; -use openvm_native_compiler::NativeLoadStoreOpcode; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; +use openvm_native_compiler::{conversion::AS, NativeLoadStoreOpcode}; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; -use super::super::adapters::loadstore_native_adapter::NativeLoadStoreInstruction; +use crate::adapters::NativeLoadStoreInstruction; #[repr(C)] #[derive(AlignedBorrow)] @@ -34,17 +35,7 @@ pub struct NativeLoadStoreCoreCols { pub data: [T; NUM_CELLS], } -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct NativeLoadStoreCoreRecord { - pub opcode: NativeLoadStoreOpcode, - - pub pointer_read: F, - #[serde(with = "BigArray")] - pub data: [F; NUM_CELLS], -} - -#[derive(Clone, Debug)] +#[derive(Clone, Debug, derive_new::new)] pub struct NativeLoadStoreCoreAir { pub offset: usize, } @@ -113,89 +104,320 @@ where } } -pub struct NativeLoadStoreCoreChip { - pub air: NativeLoadStoreCoreAir, - pub streams: OnceLock>>>, +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct NativeLoadStoreCoreRecord { + pub pointer_read: F, + pub data: [F; NUM_CELLS], + pub local_opcode: u8, } -impl NativeLoadStoreCoreChip { - pub fn new(offset: usize) -> Self { - Self { - air: NativeLoadStoreCoreAir:: { offset }, - streams: OnceLock::new(), - } - } - pub fn set_streams(&mut self, streams: Arc>>) { - self.streams - .set(streams) - .map_err(|_| "streams have already been set.") - .unwrap(); - } +#[derive(derive_new::new, Debug, Clone, Copy)] +pub struct NativeLoadStoreCoreExecutor { + adapter: A, + offset: usize, } -impl Default for NativeLoadStoreCoreChip { - fn default() -> Self { - Self::new(NativeLoadStoreOpcode::CLASS_OFFSET) - } +#[derive(derive_new::new)] +pub struct NativeLoadStoreCoreFiller { + adapter: A, } -impl, const NUM_CELLS: usize> VmCoreChip - for NativeLoadStoreCoreChip +impl PreflightExecutor + for NativeLoadStoreCoreExecutor where - I::Reads: Into<(F, [F; NUM_CELLS])>, - I::Writes: From<[F; NUM_CELLS]>, + F: PrimeField32, + A: 'static + + AdapterTraceExecutor, + for<'buf> RA: RecordArena< + 'buf, + EmptyAdapterCoreLayout, + ( + A::RecordMut<'buf>, + &'buf mut NativeLoadStoreCoreRecord, + ), + >, { - type Record = NativeLoadStoreCoreRecord; - type Air = NativeLoadStoreCoreAir; + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + NativeLoadStoreOpcode::from_usize(opcode - self.offset) + ) + } - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, instruction: &Instruction, - from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let Instruction { opcode, .. } = *instruction; - let local_opcode = - NativeLoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - let (pointer_read, data_read) = reads.into(); - - let data = if local_opcode == NativeLoadStoreOpcode::HINT_STOREW { - let mut streams = self.streams.get().unwrap().lock().unwrap(); - if streams.hint_stream.len() < NUM_CELLS { - return Err(ExecutionError::HintOutOfBounds { pc: from_pc }); + ) -> Result<(), ExecutionError> { + let &Instruction { opcode, .. } = instruction; + + let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + let (pointer_read, data_read) = + self.adapter + .read(state.memory, instruction, &mut adapter_record); + + core_record.local_opcode = opcode.local_opcode_idx(self.offset) as u8; + let opcode = NativeLoadStoreOpcode::from_usize(core_record.local_opcode as usize); + + let data = if opcode == NativeLoadStoreOpcode::HINT_STOREW { + if state.streams.hint_stream.len() < NUM_CELLS { + return Err(ExecutionError::HintOutOfBounds { pc: *state.pc }); } - array::from_fn(|_| streams.hint_stream.pop_front().unwrap()) + array::from_fn(|_| state.streams.hint_stream.pop_front().unwrap()) } else { data_read }; - let output = AdapterRuntimeContext::without_pc(data); - let record = NativeLoadStoreCoreRecord { - opcode: NativeLoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)), - pointer_read, - data, + self.adapter + .write(state.memory, instruction, data, &mut adapter_record); + + core_record.pointer_read = pointer_read; + core_record.data = data; + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } +} + +impl TraceFiller for NativeLoadStoreCoreFiller +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + + let record: &NativeLoadStoreCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; + let core_row: &mut NativeLoadStoreCoreCols = core_row.borrow_mut(); + + let opcode = NativeLoadStoreOpcode::from_usize(record.local_opcode as usize); + + // Writing in reverse order to avoid overwriting the `record` + core_row.data = record.data; + core_row.pointer_read = record.pointer_read; + core_row.is_hint_storew = F::from_bool(opcode == NativeLoadStoreOpcode::HINT_STOREW); + core_row.is_storew = F::from_bool(opcode == NativeLoadStoreOpcode::STOREW); + core_row.is_loadw = F::from_bool(opcode == NativeLoadStoreOpcode::LOADW); + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct NativeLoadStorePreCompute { + a: u32, + b: F, + c: u32, +} + +impl NativeLoadStoreCoreExecutor { + #[inline(always)] + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut NativeLoadStorePreCompute, + ) -> Result { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + + let local_opcode = NativeLoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + + let a = a.as_canonical_u32(); + let c = c.as_canonical_u32(); + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + + if d != AS::Native as u32 || e != AS::Native as u32 { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + *data = NativeLoadStorePreCompute { a, b, c }; + + Ok(local_opcode) + } +} + +impl Executor for NativeLoadStoreCoreExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::>() + } + + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut NativeLoadStorePreCompute = data.borrow_mut(); + + let local_opcode = self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = match local_opcode { + NativeLoadStoreOpcode::LOADW => execute_e1_loadw::, + NativeLoadStoreOpcode::STOREW => execute_e1_storew::, + NativeLoadStoreOpcode::HINT_STOREW => execute_e1_hint_storew::, }; - Ok((output, record)) + + Ok(fn_ptr) } +} - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - NativeLoadStoreOpcode::from_usize(opcode - self.air.offset) - ) +impl MeteredExecutor for NativeLoadStoreCoreExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn metered_pre_compute_size(&self) -> usize { + size_of::>>() } - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let cols: &mut NativeLoadStoreCoreCols<_, NUM_CELLS> = row_slice.borrow_mut(); - cols.is_loadw = F::from_bool(record.opcode == NativeLoadStoreOpcode::LOADW); - cols.is_storew = F::from_bool(record.opcode == NativeLoadStoreOpcode::STOREW); - cols.is_hint_storew = F::from_bool(record.opcode == NativeLoadStoreOpcode::HINT_STOREW); + #[inline(always)] + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut E2PreCompute> = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let local_opcode = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; - cols.pointer_read = record.pointer_read; - cols.data = record.data; + let fn_ptr = match local_opcode { + NativeLoadStoreOpcode::LOADW => execute_e2_loadw::, + NativeLoadStoreOpcode::STOREW => execute_e2_storew::, + NativeLoadStoreOpcode::HINT_STOREW => execute_e2_hint_storew::, + }; + + Ok(fn_ptr) } +} - fn air(&self) -> &Self::Air { - &self.air +unsafe fn execute_e1_loadw( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &NativeLoadStorePreCompute = pre_compute.borrow(); + execute_e12_loadw::<_, _, NUM_CELLS>(pre_compute, vm_state); +} + +unsafe fn execute_e1_storew( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &NativeLoadStorePreCompute = pre_compute.borrow(); + execute_e12_storew::<_, _, NUM_CELLS>(pre_compute, vm_state); +} + +unsafe fn execute_e1_hint_storew( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &NativeLoadStorePreCompute = pre_compute.borrow(); + execute_e12_hint_storew::<_, _, NUM_CELLS>(pre_compute, vm_state); +} + +unsafe fn execute_e2_loadw( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute> = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_loadw::<_, _, NUM_CELLS>(&pre_compute.data, vm_state); +} + +unsafe fn execute_e2_storew( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute> = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_storew::<_, _, NUM_CELLS>(&pre_compute.data, vm_state); +} + +unsafe fn execute_e2_hint_storew( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute> = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_hint_storew::<_, _, NUM_CELLS>(&pre_compute.data, vm_state); +} + +#[inline(always)] +unsafe fn execute_e12_loadw( + pre_compute: &NativeLoadStorePreCompute, + vm_state: &mut VmExecState, +) { + let [read_cell]: [F; 1] = vm_state.vm_read(AS::Native as u32, pre_compute.c); + + let data_read_ptr = (read_cell + pre_compute.b).as_canonical_u32(); + let data_read: [F; NUM_CELLS] = vm_state.vm_read(AS::Native as u32, data_read_ptr); + + vm_state.vm_write(AS::Native as u32, pre_compute.a, &data_read); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} + +#[inline(always)] +unsafe fn execute_e12_storew( + pre_compute: &NativeLoadStorePreCompute, + vm_state: &mut VmExecState, +) { + let [read_cell]: [F; 1] = vm_state.vm_read(AS::Native as u32, pre_compute.c); + let data_read: [F; NUM_CELLS] = vm_state.vm_read(AS::Native as u32, pre_compute.a); + + let data_write_ptr = (read_cell + pre_compute.b).as_canonical_u32(); + vm_state.vm_write(AS::Native as u32, data_write_ptr, &data_read); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} + +#[inline(always)] +unsafe fn execute_e12_hint_storew( + pre_compute: &NativeLoadStorePreCompute, + vm_state: &mut VmExecState, +) { + let [read_cell]: [F; 1] = vm_state.vm_read(AS::Native as u32, pre_compute.c); + + if vm_state.streams.hint_stream.len() < NUM_CELLS { + vm_state.exit_code = Err(ExecutionError::HintOutOfBounds { pc: vm_state.pc }); + return; } + let data: [F; NUM_CELLS] = + array::from_fn(|_| vm_state.streams.hint_stream.pop_front().unwrap()); + + let data_write_ptr = (read_cell + pre_compute.b).as_canonical_u32(); + vm_state.vm_write(AS::Native as u32, data_write_ptr, &data); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; } diff --git a/extensions/native/circuit/src/loadstore/mod.rs b/extensions/native/circuit/src/loadstore/mod.rs index 3dd51113a9..eb4ed579ae 100644 --- a/extensions/native/circuit/src/loadstore/mod.rs +++ b/extensions/native/circuit/src/loadstore/mod.rs @@ -1,19 +1,18 @@ use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; -#[cfg(test)] -mod tests; +use crate::adapters::{ + NativeLoadStoreAdapterAir, NativeLoadStoreAdapterExecutor, NativeLoadStoreAdapterFiller, +}; mod core; pub use core::*; -use super::adapters::loadstore_native_adapter::{ - NativeLoadStoreAdapterAir, NativeLoadStoreAdapterChip, -}; +#[cfg(test)] +mod tests; pub type NativeLoadStoreAir = VmAirWrapper, NativeLoadStoreCoreAir>; -pub type NativeLoadStoreChip = VmChipWrapper< - F, - NativeLoadStoreAdapterChip, - NativeLoadStoreCoreChip, ->; +pub type NativeLoadStoreExecutor = + NativeLoadStoreCoreExecutor, NUM_CELLS>; +pub type NativeLoadStoreChip = + VmChipWrapper, NUM_CELLS>>; diff --git a/extensions/native/circuit/src/loadstore/tests.rs b/extensions/native/circuit/src/loadstore/tests.rs index cd653c2fc0..9c4c5ee587 100644 --- a/extensions/native/circuit/src/loadstore/tests.rs +++ b/extensions/native/circuit/src/loadstore/tests.rs @@ -1,175 +1,258 @@ -use std::sync::{Arc, Mutex}; +use std::{array, borrow::BorrowMut}; -use openvm_circuit::arch::{testing::VmChipTestBuilder, Streams}; +use openvm_circuit::arch::testing::{memory::gen_pointer, TestChipHarness, VmChipTestBuilder}; use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_native_compiler::NativeLoadStoreOpcode::{self, *}; -use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32}; -use openvm_stark_sdk::{config::setup_tracing, p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use openvm_native_compiler::{ + conversion::AS, + NativeLoadStoreOpcode::{self, *}, +}; +use openvm_stark_backend::{ + p3_air::BaseAir, + p3_field::{FieldAlgebra, PrimeField32}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, + utils::disable_debug_builder, + verifier::VerificationError, +}; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; +use test_case::test_case; -use super::{ - super::adapters::loadstore_native_adapter::NativeLoadStoreAdapterChip, NativeLoadStoreChip, - NativeLoadStoreCoreChip, +use super::{NativeLoadStoreChip, NativeLoadStoreCoreAir}; +use crate::{ + adapters::{ + NativeLoadStoreAdapterAir, NativeLoadStoreAdapterCols, NativeLoadStoreAdapterExecutor, + NativeLoadStoreAdapterFiller, + }, + test_utils::write_native_array, + NativeLoadStoreAir, NativeLoadStoreCoreCols, NativeLoadStoreCoreFiller, + NativeLoadStoreExecutor, }; +const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; +type Harness = TestChipHarness< + F, + NativeLoadStoreExecutor, + NativeLoadStoreAir, + NativeLoadStoreChip, +>; + +fn create_test_chip(tester: &VmChipTestBuilder) -> Harness { + let air = NativeLoadStoreAir::new( + NativeLoadStoreAdapterAir::new(tester.memory_bridge(), tester.execution_bridge()), + NativeLoadStoreCoreAir::new(NativeLoadStoreOpcode::CLASS_OFFSET), + ); + let executor = NativeLoadStoreExecutor::new( + NativeLoadStoreAdapterExecutor::new(NativeLoadStoreOpcode::CLASS_OFFSET), + NativeLoadStoreOpcode::CLASS_OFFSET, + ); + let chip = NativeLoadStoreChip::::new( + NativeLoadStoreCoreFiller::new(NativeLoadStoreAdapterFiller), + tester.memory_helper(), + ); -#[derive(Debug)] -struct TestData { - a: F, - b: F, - c: F, - d: F, - e: F, - ad_val: F, - cd_val: F, - data_val: F, - is_load: bool, - is_hint: bool, + Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY) } -fn setup() -> (StdRng, VmChipTestBuilder, NativeLoadStoreChip) { - let rng = create_seeded_rng(); - let tester = VmChipTestBuilder::default(); +fn set_and_execute( + tester: &mut VmChipTestBuilder, + harness: &mut Harness, + rng: &mut StdRng, + opcode: NativeLoadStoreOpcode, +) { + let a = gen_pointer(rng, NUM_CELLS); + let ([c_val], c) = write_native_array(tester, rng, None); + + let mem_ptr = gen_pointer(rng, NUM_CELLS); + let b = F::from_canonical_usize(mem_ptr) - c_val; + let data: [F; NUM_CELLS] = array::from_fn(|_| rng.gen()); - let adapter = NativeLoadStoreAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - NativeLoadStoreOpcode::CLASS_OFFSET, + match opcode { + LOADW => { + tester.write(AS::Native as usize, mem_ptr, data); + } + STOREW => { + tester.write(AS::Native as usize, a, data); + } + HINT_STOREW => { + tester.streams.hint_stream.extend(data); + } + } + + tester.execute( + harness, + &Instruction::from_usize( + opcode.global_opcode(), + [ + a, + b.as_canonical_u32() as usize, + c, + AS::Native as usize, + AS::Native as usize, + ], + ), ); - let mut inner = NativeLoadStoreCoreChip::new(NativeLoadStoreOpcode::CLASS_OFFSET); - inner.set_streams(Arc::new(Mutex::new(Streams::default()))); - let chip = NativeLoadStoreChip::::new(adapter, inner, tester.offline_memory_mutex_arc()); - (rng, tester, chip) + + let result = match opcode { + STOREW | HINT_STOREW => tester.read(AS::Native as usize, mem_ptr), + LOADW => tester.read(AS::Native as usize, a), + }; + assert_eq!(result, data); } -fn gen_test_data(rng: &mut StdRng, opcode: NativeLoadStoreOpcode) -> TestData { - let is_load = matches!(opcode, NativeLoadStoreOpcode::LOADW); - - let a = rng.gen_range(0..1 << 20); - let b = rng.gen_range(0..1 << 20); - let c = rng.gen_range(0..1 << 20); - let d = F::from_canonical_u32(4u32); - let e = F::from_canonical_u32(4u32); - - TestData { - a: F::from_canonical_u32(a), - b: F::from_canonical_u32(b), - c: F::from_canonical_u32(c), - d, - e, - ad_val: F::from_canonical_u32(111), - cd_val: F::from_canonical_u32(222), - data_val: F::from_canonical_u32(444), - is_load, - is_hint: matches!(opcode, NativeLoadStoreOpcode::HINT_STOREW), +/////////////////////////////////////////////////////////////////////////////////////// +/// POSITIVE TESTS +/// +/// Randomly generate computations and execute, ensuring that the generated trace +/// passes all constraints. +/////////////////////////////////////////////////////////////////////////////////////// + +#[test_case(STOREW, 100)] +#[test_case(HINT_STOREW, 100)] +#[test_case(LOADW, 100)] +fn rand_native_loadstore_test_1(opcode: NativeLoadStoreOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut harness = create_test_chip::<1>(&tester); + + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut harness, &mut rng, opcode); } + let tester = tester.build().load(harness).finalize(); + tester.simple_test().expect("Verification failed"); } -fn get_data_pointer(data: &TestData) -> F { - if data.d != F::ZERO { - data.cd_val + data.b - } else { - data.c + data.b +#[test_case(STOREW, 100)] +#[test_case(HINT_STOREW, 100)] +#[test_case(LOADW, 100)] +fn rand_native_loadstore_test_4(opcode: NativeLoadStoreOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut harness = create_test_chip::<4>(&tester); + + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut harness, &mut rng, opcode); } + let tester = tester.build().load(harness).finalize(); + tester.simple_test().expect("Verification failed"); } -fn set_values( - tester: &mut VmChipTestBuilder, - chip: &mut NativeLoadStoreChip, - data: &TestData, -) { - if data.d != F::ZERO { - tester.write( - data.d.as_canonical_u32() as usize, - data.a.as_canonical_u32() as usize, - [data.ad_val], - ); - tester.write( - data.d.as_canonical_u32() as usize, - data.c.as_canonical_u32() as usize, - [data.cd_val], - ); - } - if data.is_load { - let data_pointer = get_data_pointer(data); - tester.write( - data.e.as_canonical_u32() as usize, - data_pointer.as_canonical_u32() as usize, - [data.data_val], - ); - } - if data.is_hint { - for _ in 0..data.e.as_canonical_u32() { - chip.core - .streams - .get() - .unwrap() - .lock() - .unwrap() - .hint_stream - .push_back(data.data_val); - } - } +////////////////////////////////////////////////////////////////////////////////////// +// NEGATIVE TESTS +// +// Given a fake trace of a single operation, setup a chip and run the test. We replace +// part of the trace and check that the chip throws the expected error. +////////////////////////////////////////////////////////////////////////////////////// + +#[derive(Clone, Copy, Default)] +struct NativeLoadStorePrankValues { + // Core cols + pub data: Option<[F; NUM_CELLS]>, + pub opcode_flags: Option<[bool; 3]>, + pub pointer_read: Option, + // Adapter cols + pub data_write_pointer: Option, } -fn check_values(tester: &mut VmChipTestBuilder, data: &TestData) { - let data_pointer = get_data_pointer(data); - - let written_data_val = if data.is_load { - tester.read::<1>( - data.d.as_canonical_u32() as usize, - data.a.as_canonical_u32() as usize, - )[0] - } else { - tester.read::<1>( - data.e.as_canonical_u32() as usize, - data_pointer.as_canonical_u32() as usize, - )[0] - }; +fn run_negative_native_loadstore_test( + opcode: NativeLoadStoreOpcode, + prank_vals: NativeLoadStorePrankValues, + error: VerificationError, +) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut harness = create_test_chip::(&tester); + + set_and_execute(&mut tester, &mut harness, &mut rng, opcode); - let correct_data_val = if data.is_load || data.is_hint { - data.data_val - } else if data.d != F::ZERO { - data.ad_val - } else { - data.a + let adapter_width = BaseAir::::width(&harness.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { + let mut values = trace.row_slice(0).to_vec(); + let (adapter_row, core_row) = values.split_at_mut(adapter_width); + let adapter_cols: &mut NativeLoadStoreAdapterCols = adapter_row.borrow_mut(); + let core_cols: &mut NativeLoadStoreCoreCols = core_row.borrow_mut(); + + if let Some(data) = prank_vals.data { + core_cols.data = data; + } + if let Some(pointer_read) = prank_vals.pointer_read { + core_cols.pointer_read = pointer_read; + } + if let Some(opcode_flags) = prank_vals.opcode_flags { + [ + core_cols.is_loadw, + core_cols.is_storew, + core_cols.is_hint_storew, + ] = opcode_flags.map(F::from_bool); + } + if let Some(data_write_pointer) = prank_vals.data_write_pointer { + adapter_cols.data_write_pointer = data_write_pointer; + } + + *trace = RowMajorMatrix::new(values, trace.width()); }; - assert_eq!(written_data_val, correct_data_val, "{:?}", data); + disable_debug_builder(); + let tester = tester + .build() + .load_and_prank_trace(harness, modify_trace) + .finalize(); + tester.simple_test_with_expected_error(error); } -fn set_and_execute( - tester: &mut VmChipTestBuilder, - chip: &mut NativeLoadStoreChip, - rng: &mut StdRng, - opcode: NativeLoadStoreOpcode, -) { - let data = gen_test_data(rng, opcode); - set_values(tester, chip, &data); +#[test] +fn negative_native_loadstore_tests() { + run_negative_native_loadstore_test::<1>( + STOREW, + NativeLoadStorePrankValues { + data_write_pointer: Some(F::ZERO), + ..Default::default() + }, + VerificationError::OodEvaluationMismatch, + ); - tester.execute_with_pc( - chip, - &Instruction::from_usize( - opcode.global_opcode(), - [data.a, data.b, data.c, data.d, data.e].map(|x| x.as_canonical_u32() as usize), - ), - 0u32, + run_negative_native_loadstore_test::<1>( + LOADW, + NativeLoadStorePrankValues { + data_write_pointer: Some(F::ZERO), + ..Default::default() + }, + VerificationError::OodEvaluationMismatch, ); +} - check_values(tester, &data); +#[test] +fn invalid_flags_native_loadstore_tests() { + run_negative_native_loadstore_test::<1>( + HINT_STOREW, + NativeLoadStorePrankValues { + opcode_flags: Some([false, false, false]), + ..Default::default() + }, + VerificationError::ChallengePhaseError, + ); + + run_negative_native_loadstore_test::<1>( + LOADW, + NativeLoadStorePrankValues { + opcode_flags: Some([false, false, true]), + ..Default::default() + }, + VerificationError::OodEvaluationMismatch, + ); } #[test] -fn rand_native_loadstore_test() { - setup_tracing(); - let (mut rng, mut tester, mut chip) = setup(); - for _ in 0..20 { - set_and_execute(&mut tester, &mut chip, &mut rng, STOREW); - set_and_execute(&mut tester, &mut chip, &mut rng, HINT_STOREW); - set_and_execute(&mut tester, &mut chip, &mut rng, LOADW); - } - let tester = tester.build().load(chip).finalize(); - tester.simple_test().expect("Verification failed"); +fn invalid_data_native_loadstore_tests() { + run_negative_native_loadstore_test( + LOADW, + NativeLoadStorePrankValues { + data: Some([F::ZERO; 4]), + ..Default::default() + }, + VerificationError::ChallengePhaseError, + ); } diff --git a/extensions/native/circuit/src/poseidon2/air.rs b/extensions/native/circuit/src/poseidon2/air.rs index 5ed28abd60..adf2c09a62 100644 --- a/extensions/native/circuit/src/poseidon2/air.rs +++ b/extensions/native/circuit/src/poseidon2/air.rs @@ -7,10 +7,13 @@ use openvm_circuit::{ use openvm_circuit_primitives::utils::not; use openvm_instructions::LocalOpcode; use openvm_native_compiler::{ + conversion::AS, Poseidon2Opcode::{COMP_POS2, PERM_POS2}, VerifyBatchOpcode::VERIFY_BATCH, }; -use openvm_poseidon2_air::{Poseidon2SubAir, BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS}; +use openvm_poseidon2_air::{ + Poseidon2Config, Poseidon2SubAir, BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS, +}; use openvm_stark_backend::{ air_builders::sub::SubAirBuilder, interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, @@ -20,15 +23,13 @@ use openvm_stark_backend::{ rap::{BaseAirWithPublicValues, PartitionedBaseAir}, }; -use crate::{ +use crate::poseidon2::{ chip::{NUM_INITIAL_READS, NUM_SIMPLE_ACCESSES}, - poseidon2::{ - columns::{ - InsideRowSpecificCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, - TopLevelSpecificCols, - }, - CHUNK, + columns::{ + InsideRowSpecificCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, + TopLevelSpecificCols, }, + CHUNK, }; #[derive(Clone, Debug)] @@ -40,6 +41,23 @@ pub struct NativePoseidon2Air { pub(crate) address_space: F, } +impl NativePoseidon2Air { + pub fn new( + execution_bridge: ExecutionBridge, + memory_bridge: MemoryBridge, + verify_batch_bus: VerifyBatchBus, + poseidon2_config: Poseidon2Config, + ) -> Self { + NativePoseidon2Air { + execution_bridge, + memory_bridge, + internal_bus: verify_batch_bus, + subair: Arc::new(Poseidon2SubAir::new(poseidon2_config.constants.into())), + address_space: F::from_canonical_u32(AS::Native as u32), + } + } +} + impl BaseAir for NativePoseidon2Air { fn width(&self) -> usize { NativePoseidon2Cols::::width() diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index 426b089a9c..8ad46e5bbf 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -1,189 +1,173 @@ -use std::sync::{Arc, Mutex}; +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ - arch::{ - ExecutionBridge, ExecutionError, ExecutionState, InstructionExecutor, Streams, SystemPort, + arch::*, + system::{ + memory::{ + offline_checker::MemoryBaseAuxCols, + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, + native_adapter::util::{ + memory_read_native, tracing_read_native, tracing_write_native_inplace, + }, }, - system::memory::{MemoryController, OfflineMemory, RecordId}, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_native_compiler::{ conversion::AS, Poseidon2Opcode::{COMP_POS2, PERM_POS2}, VerifyBatchOpcode::VERIFY_BATCH, }; -use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubAir, Poseidon2SubChip}; +use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubChip, Poseidon2SubCols}; use openvm_stark_backend::{ + p3_air::BaseAir, p3_field::{Field, PrimeField32}, - p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice}, + p3_matrix::{dense::RowMajorMatrix, Matrix}, + p3_maybe_rayon::prelude::{IntoParallelIterator, ParallelSliceMut, *}, }; -use serde::{Deserialize, Serialize}; use crate::poseidon2::{ - air::{NativePoseidon2Air, VerifyBatchBus}, + columns::{ + InsideRowSpecificCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, + TopLevelSpecificCols, + }, CHUNK, }; -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct VerifyBatchRecord { - pub from_state: ExecutionState, - pub instruction: Instruction, - - pub dim_base_pointer: F, - pub opened_base_pointer: F, - pub opened_length: usize, - pub index_base_pointer: F, - pub commit_pointer: F, - - pub dim_base_pointer_read: RecordId, - pub opened_base_pointer_read: RecordId, - pub opened_length_read: RecordId, - pub index_base_pointer_read: RecordId, - pub commit_pointer_read: RecordId, - - pub commit_read: RecordId, - pub initial_log_height: usize, - pub top_level: Vec>, +#[derive(Clone)] +pub struct NativePoseidon2Executor { + pub(super) subchip: Poseidon2SubChip, + /// If true, `verify_batch` assumes the verification is always passed and skips poseidon2 + /// computation during execution for performance. + optimistic: bool, } -impl VerifyBatchRecord { - pub fn opened_element_size_inv(&self) -> F { - self.instruction.g - } +pub struct NativePoseidon2Filler { + // pre-computed Poseidon2 sub cols for dummy rows. + empty_poseidon2_sub_cols: Vec, + pub(super) subchip: Poseidon2SubChip, } -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct TopLevelRecord { - // must be present in first record - pub incorporate_row: Option>, - // must be present in all bust last record - pub incorporate_sibling: Option>, +impl NativePoseidon2Executor { + pub fn new(poseidon2_config: Poseidon2Config) -> Self { + let subchip = Poseidon2SubChip::new(poseidon2_config.constants); + Self { + subchip, + optimistic: true, + } + } + pub fn set_optimistic(&mut self, optimistic: bool) { + self.optimistic = optimistic; + } } -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct IncorporateSiblingRecord { - pub read_sibling_is_on_right: RecordId, - pub sibling_is_on_right: bool, - pub p2_input: [F; 2 * CHUNK], +fn compress( + subchip: &Poseidon2SubChip, + left: [F; CHUNK], + right: [F; CHUNK], +) -> ([F; 2 * CHUNK], [F; CHUNK]) { + let concatenated = std::array::from_fn(|i| if i < CHUNK { left[i] } else { right[i - CHUNK] }); + let permuted = subchip.permute(concatenated); + (concatenated, std::array::from_fn(|i| permuted[i])) } -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct IncorporateRowRecord { - pub chunks: Vec>, - pub initial_opened_index: usize, - pub final_opened_index: usize, - pub initial_height_read: RecordId, - pub final_height_read: RecordId, - pub p2_input: [F; 2 * CHUNK], +impl NativePoseidon2Filler { + pub fn new(poseidon2_config: Poseidon2Config) -> Self { + let subchip = Poseidon2SubChip::new(poseidon2_config.constants); + let empty_poseidon2_sub_cols = subchip.generate_trace(vec![[F::ZERO; CHUNK * 2]]).values; + Self { + empty_poseidon2_sub_cols, + subchip, + } + } } -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct InsideRowRecord { - pub cells: Vec, - pub p2_input: [F; 2 * CHUNK], -} +pub(super) const NUM_INITIAL_READS: usize = 6; +pub(super) const NUM_SIMPLE_ACCESSES: u32 = 7; -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CellRecord { - pub read: RecordId, - pub opened_index: usize, - pub read_row_pointer_and_length: Option, - pub row_pointer: usize, - pub row_end: usize, +#[derive(Debug, Clone, Default)] +pub struct NativePoseidon2Metadata { + num_rows: usize, } -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct SimplePoseidonRecord { - pub from_state: ExecutionState, - pub instruction: Instruction, - - pub read_input_pointer_1: RecordId, - pub read_input_pointer_2: Option, - pub read_output_pointer: RecordId, - pub read_data_1: RecordId, - pub read_data_2: RecordId, - pub write_data_1: RecordId, - pub write_data_2: Option, - - pub input_pointer_1: F, - pub input_pointer_2: F, - pub output_pointer: F, - pub p2_input: [F; 2 * CHUNK], +impl MultiRowMetadata for NativePoseidon2Metadata { + #[inline(always)] + fn get_num_rows(&self) -> usize { + self.num_rows + } } -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -#[serde(bound = "F: Field")] -pub struct NativePoseidon2RecordSet { - pub verify_batch_records: Vec>, - pub simple_permute_records: Vec>, -} +type NativePoseidon2RecordLayout = MultiRowLayout; -pub struct NativePoseidon2Chip { - pub(super) air: NativePoseidon2Air, - pub record_set: NativePoseidon2RecordSet, - pub height: usize, - pub(super) offline_memory: Arc>>, - pub(super) subchip: Poseidon2SubChip, - pub(super) streams: Arc>>, -} +pub struct NativePoseidon2RecordMut<'a, F, const SBOX_REGISTERS: usize>( + &'a mut [NativePoseidon2Cols], +); -impl NativePoseidon2Chip { - pub fn new( - port: SystemPort, - offline_memory: Arc>>, - poseidon2_config: Poseidon2Config, - verify_batch_bus: VerifyBatchBus, - streams: Arc>>, - ) -> Self { - let air = NativePoseidon2Air { - execution_bridge: ExecutionBridge::new(port.execution_bus, port.program_bus), - memory_bridge: port.memory_bridge, - internal_bus: verify_batch_bus, - subair: Arc::new(Poseidon2SubAir::new(poseidon2_config.constants.into())), - address_space: F::from_canonical_u32(AS::Native as u32), +impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> + CustomBorrow<'a, NativePoseidon2RecordMut<'a, F, SBOX_REGISTERS>, NativePoseidon2RecordLayout> + for [u8] +{ + fn custom_borrow( + &'a mut self, + layout: NativePoseidon2RecordLayout, + ) -> NativePoseidon2RecordMut<'a, F, SBOX_REGISTERS> { + let arr = unsafe { + self.align_to_mut::>() + .1 }; - Self { - record_set: Default::default(), - air, - height: 0, - offline_memory, - subchip: Poseidon2SubChip::new(poseidon2_config.constants), - streams, - } + NativePoseidon2RecordMut(&mut arr[..layout.metadata.num_rows]) } - fn compress(&self, left: [F; CHUNK], right: [F; CHUNK]) -> ([F; 2 * CHUNK], [F; CHUNK]) { - let concatenated = - std::array::from_fn(|i| if i < CHUNK { left[i] } else { right[i - CHUNK] }); - let permuted = self.subchip.permute(concatenated); - (concatenated, std::array::from_fn(|i| permuted[i])) + unsafe fn extract_layout(&self) -> NativePoseidon2RecordLayout { + // Each instruction record consists solely of some number of contiguously + // stored NativePoseidon2Cols<...> structs, each of which corresponds to a + // single trace row. Trace fillers don't actually need to know how many rows + // each instruction uses, and can thus treat each NativePoseidon2Cols<...> + // as a single record. + NativePoseidon2RecordLayout { + metadata: NativePoseidon2Metadata { num_rows: 1 }, + } } } -pub(super) const NUM_INITIAL_READS: usize = 6; -pub(super) const NUM_SIMPLE_ACCESSES: u32 = 7; +impl SizedRecord + for NativePoseidon2RecordMut<'_, F, SBOX_REGISTERS> +{ + fn size(layout: &NativePoseidon2RecordLayout) -> usize { + layout.metadata.num_rows * size_of::>() + } + + fn alignment(_layout: &NativePoseidon2RecordLayout) -> usize { + align_of::>() + } +} -impl InstructionExecutor - for NativePoseidon2Chip +impl PreflightExecutor + for NativePoseidon2Executor +where + for<'buf> RA: RecordArena< + 'buf, + MultiRowLayout, + NativePoseidon2RecordMut<'buf, F, SBOX_REGISTERS>, + >, { fn execute( &mut self, - memory: &mut MemoryController, + state: VmStateMut, instruction: &Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { + ) -> Result<(), ExecutionError> { + let arena = state.ctx; + let init_timestamp_u32 = state.memory.timestamp; if instruction.opcode == PERM_POS2.global_opcode() || instruction.opcode == COMP_POS2.global_opcode() { + let cols = &mut arena + .alloc(MultiRowLayout::new(NativePoseidon2Metadata { num_rows: 1 })) + .0[0]; + let simple_cols: &mut SimplePoseidonSpecificCols = + cols.specific[..SimplePoseidonSpecificCols::::width()].borrow_mut(); let &Instruction { a: output_register, b: input_register_1, @@ -192,22 +176,45 @@ impl InstructionExecutor e: data_address_space, .. } = instruction; + debug_assert_eq!( + register_address_space, + F::from_canonical_u32(AS::Native as u32) + ); + debug_assert_eq!(data_address_space, F::from_canonical_u32(AS::Native as u32)); + let [output_pointer]: [F; 1] = tracing_read_native_helper( + state.memory, + output_register.as_canonical_u32(), + simple_cols.read_output_pointer.as_mut(), + ); + let output_pointer_u32 = output_pointer.as_canonical_u32(); + let [input_pointer_1]: [F; 1] = tracing_read_native_helper( + state.memory, + input_register_1.as_canonical_u32(), + simple_cols.read_input_pointer_1.as_mut(), + ); + let input_pointer_1_u32 = input_pointer_1.as_canonical_u32(); + let [input_pointer_2]: [F; 1] = if instruction.opcode == PERM_POS2.global_opcode() { + state.memory.increment_timestamp(); + [input_pointer_1 + F::from_canonical_usize(CHUNK)] + } else { + tracing_read_native_helper( + state.memory, + input_register_2.as_canonical_u32(), + simple_cols.read_input_pointer_2.as_mut(), + ) + }; + let input_pointer_2_u32 = input_pointer_2.as_canonical_u32(); + let data_1: [F; CHUNK] = tracing_read_native_helper( + state.memory, + input_pointer_1_u32, + simple_cols.read_data_1.as_mut(), + ); + let data_2: [F; CHUNK] = tracing_read_native_helper( + state.memory, + input_pointer_2_u32, + simple_cols.read_data_2.as_mut(), + ); - let (read_output_pointer, output_pointer) = - memory.read_cell(register_address_space, output_register); - let (read_input_pointer_1, input_pointer_1) = - memory.read_cell(register_address_space, input_register_1); - let (read_input_pointer_2, input_pointer_2) = - if instruction.opcode == PERM_POS2.global_opcode() { - memory.increment_timestamp(); - (None, input_pointer_1 + F::from_canonical_usize(CHUNK)) - } else { - let (read_input_pointer_2, input_pointer_2) = - memory.read_cell(register_address_space, input_register_2); - (Some(read_input_pointer_2), input_pointer_2) - }; - let (read_data_1, data_1) = memory.read::(data_address_space, input_pointer_1); - let (read_data_2, data_2) = memory.read::(data_address_space, input_pointer_2); let p2_input = std::array::from_fn(|i| { if i < CHUNK { data_1[i] @@ -216,50 +223,51 @@ impl InstructionExecutor } }); let output = self.subchip.permute(p2_input); - let (write_data_1, _) = memory.write::( - data_address_space, - output_pointer, + tracing_write_native_inplace( + state.memory, + output_pointer_u32, std::array::from_fn(|i| output[i]), + &mut simple_cols.write_data_1, ); - let write_data_2 = if instruction.opcode == PERM_POS2.global_opcode() { - Some( - memory - .write::( - data_address_space, - output_pointer + F::from_canonical_usize(CHUNK), - std::array::from_fn(|i| output[CHUNK + i]), - ) - .0, - ) + if instruction.opcode == PERM_POS2.global_opcode() { + tracing_write_native_inplace( + state.memory, + output_pointer_u32 + CHUNK as u32, + std::array::from_fn(|i| output[i + CHUNK]), + &mut simple_cols.write_data_2, + ); } else { - memory.increment_timestamp(); - None - }; - - assert_eq!( - memory.timestamp(), - from_state.timestamp + NUM_SIMPLE_ACCESSES + state.memory.increment_timestamp(); + } + debug_assert_eq!( + state.memory.timestamp, + init_timestamp_u32 + NUM_SIMPLE_ACCESSES ); + cols.incorporate_row = F::ZERO; + cols.incorporate_sibling = F::ZERO; + cols.inside_row = F::ZERO; + cols.simple = F::ONE; + cols.end_inside_row = F::ZERO; + cols.end_top_level = F::ZERO; + cols.is_exhausted = [F::ZERO; CHUNK - 1]; + cols.start_timestamp = F::from_canonical_u32(init_timestamp_u32); - self.record_set - .simple_permute_records - .push(SimplePoseidonRecord { - from_state, - instruction: instruction.clone(), - read_input_pointer_1, - read_input_pointer_2, - read_output_pointer, - read_data_1, - read_data_2, - write_data_1, - write_data_2, - input_pointer_1, - input_pointer_2, - output_pointer, - p2_input, - }); - self.height += 1; + cols.inner.inputs = p2_input; + simple_cols.pc = F::from_canonical_u32(*state.pc); + simple_cols.is_compress = F::from_bool(instruction.opcode == COMP_POS2.global_opcode()); + simple_cols.output_register = output_register; + simple_cols.input_register_1 = input_register_1; + simple_cols.input_register_2 = input_register_2; + simple_cols.output_pointer = output_pointer; + simple_cols.input_pointer_1 = input_pointer_1; + simple_cols.input_pointer_2 = input_pointer_2; } else if instruction.opcode == VERIFY_BATCH.global_opcode() { + let init_timestamp = F::from_canonical_u32(init_timestamp_u32); + let mut col_buffer = vec![F::ZERO; NativePoseidon2Cols::::width()]; + let last_top_level_cols: &mut NativePoseidon2Cols = + col_buffer.as_mut_slice().borrow_mut(); + let ltl_specific_cols: &mut TopLevelSpecificCols = + last_top_level_cols.specific[..TopLevelSpecificCols::::width()].borrow_mut(); let &Instruction { a: dim_register, b: opened_register, @@ -270,228 +278,379 @@ impl InstructionExecutor g: opened_element_size_inv, .. } = instruction; - let address_space = self.air.address_space; // calc inverse fast assuming opened_element_size in {1, 4} let mut opened_element_size = F::ONE; while opened_element_size * opened_element_size_inv != F::ONE { opened_element_size += F::ONE; } - let proof_id = memory.unsafe_read_cell(address_space, proof_id_ptr); - let (dim_base_pointer_read, dim_base_pointer) = - memory.read_cell(address_space, dim_register); - let (opened_base_pointer_read, opened_base_pointer) = - memory.read_cell(address_space, opened_register); - let (opened_length_read, opened_length) = - memory.read_cell(address_space, opened_length_register); - let (index_base_pointer_read, index_base_pointer) = - memory.read_cell(address_space, index_register); - let (commit_pointer_read, commit_pointer) = - memory.read_cell(address_space, commit_register); - let (commit_read, commit) = memory.read(address_space, commit_pointer); + let [proof_id]: [F; 1] = + memory_read_native(state.memory.data(), proof_id_ptr.as_canonical_u32()); + let [dim_base_pointer]: [F; 1] = tracing_read_native_helper( + state.memory, + dim_register.as_canonical_u32(), + ltl_specific_cols.dim_base_pointer_read.as_mut(), + ); + let dim_base_pointer_u32 = dim_base_pointer.as_canonical_u32(); + let [opened_base_pointer]: [F; 1] = tracing_read_native_helper( + state.memory, + opened_register.as_canonical_u32(), + ltl_specific_cols.opened_base_pointer_read.as_mut(), + ); + let opened_base_pointer_u32 = opened_base_pointer.as_canonical_u32(); + let [opened_length]: [F; 1] = tracing_read_native_helper( + state.memory, + opened_length_register.as_canonical_u32(), + ltl_specific_cols.opened_length_read.as_mut(), + ); + let [index_base_pointer]: [F; 1] = tracing_read_native_helper( + state.memory, + index_register.as_canonical_u32(), + ltl_specific_cols.index_base_pointer_read.as_mut(), + ); + let index_base_pointer_u32 = index_base_pointer.as_canonical_u32(); + let [commit_pointer]: [F; 1] = tracing_read_native_helper( + state.memory, + commit_register.as_canonical_u32(), + ltl_specific_cols.commit_pointer_read.as_mut(), + ); + // In E3, the proof is assumed to be valid. The verification during execution is + // skipped. + let commit: [F; CHUNK] = tracing_read_native_helper( + state.memory, + commit_pointer.as_canonical_u32(), + ltl_specific_cols.commit_read.as_mut(), + ); let opened_length = opened_length.as_canonical_u32() as usize; + let [initial_log_height]: [F; 1] = + memory_read_native(state.memory.data(), dim_base_pointer_u32); + let initial_log_height_u32 = initial_log_height.as_canonical_u32(); + let mut log_height = initial_log_height_u32 as i32; - let initial_log_height = memory - .unsafe_read_cell(address_space, dim_base_pointer) - .as_canonical_u32(); - let mut log_height = initial_log_height as i32; - let mut sibling_index = 0; + // Number of non-inside rows, this is used to compute the offset of the inside row + // section. + let (num_inside_rows, num_non_inside_rows) = { + let opened_element_size_u32 = opened_element_size.as_canonical_u32(); + let mut num_non_inside_rows = initial_log_height_u32 as usize; + let mut num_inside_rows = 0; + let mut log_height = initial_log_height_u32; + let mut opened_index = 0; + loop { + let mut total_len = 0; + while opened_index < opened_length { + let [height]: [F; 1] = memory_read_native( + state.memory.data(), + dim_base_pointer_u32 + opened_index as u32, + ); + if height.as_canonical_u32() != log_height { + break; + } + let [row_len]: [F; 1] = memory_read_native( + state.memory.data(), + opened_base_pointer_u32 + 2 * opened_index as u32 + 1, + ); + total_len += row_len.as_canonical_u32() * opened_element_size_u32; + opened_index += 1; + } + if total_len != 0 { + num_non_inside_rows += 1; + num_inside_rows += (total_len as usize).div_ceil(CHUNK); + } + if log_height == 0 { + break; + } + log_height -= 1; + } + (num_inside_rows, num_non_inside_rows) + }; + let mut proof_index = 0; let mut opened_index = 0; - let mut top_level = vec![]; let mut root = [F::ZERO; CHUNK]; let sibling_proof: Vec<[F; CHUNK]> = { - let streams = self.streams.lock().unwrap(); let proof_idx = proof_id.as_canonical_u32() as usize; - streams.hint_space[proof_idx] + state.streams.hint_space[proof_idx] .par_chunks(CHUNK) .map(|c| c.try_into().unwrap()) .collect() }; + let total_num_row = num_inside_rows + num_non_inside_rows; + let allocated_rows = arena + .alloc(MultiRowLayout::new(NativePoseidon2Metadata { + num_rows: total_num_row, + })) + .0; + allocated_rows[0].inner.export = F::from_canonical_u32(num_non_inside_rows as u32); + let mut inside_row_idx = num_non_inside_rows; + let mut non_inside_row_idx = 0; + while log_height >= 0 { - let incorporate_row = if opened_index < opened_length - && memory.unsafe_read_cell( - address_space, - dim_base_pointer + F::from_canonical_usize(opened_index), - ) == F::from_canonical_u32(log_height as u32) + if opened_index < opened_length + && memory_read_native::( + state.memory.data(), + dim_base_pointer_u32 + opened_index as u32, + )[0] == F::from_canonical_u32(log_height as u32) { + state + .memory + .increment_timestamp_by(NUM_INITIAL_READS as u32); + let incorporate_start_timestamp = state.memory.timestamp; let initial_opened_index = opened_index; - for _ in 0..NUM_INITIAL_READS { - memory.increment_timestamp(); - } - let mut chunks = vec![]; - let mut row_pointer = 0; let mut row_end = 0; - - let mut prev_rolling_hash: Option<[F; 2 * CHUNK]> = None; let mut rolling_hash = [F::ZERO; 2 * CHUNK]; - let mut is_first_in_segment = true; loop { - let mut cells = vec![]; + if inside_row_idx == total_num_row { + opened_index += 1; + break; + } + let inside_cols = &mut allocated_rows[inside_row_idx]; + let inside_specific_cols: &mut InsideRowSpecificCols = inside_cols + .specific[..InsideRowSpecificCols::::width()] + .borrow_mut(); + let start_timestamp_u32 = state.memory.timestamp; + + let mut cells_idx = 0; for chunk_elem in rolling_hash.iter_mut().take(CHUNK) { - let read_row_pointer_and_length = if is_first_in_segment - || row_pointer == row_end - { + let cell_cols = &mut inside_specific_cols.cells[cells_idx]; + if is_first_in_segment || row_pointer == row_end { if is_first_in_segment { is_first_in_segment = false; } else { opened_index += 1; if opened_index == opened_length - || memory.unsafe_read_cell( - address_space, - dim_base_pointer - + F::from_canonical_usize(opened_index), - ) != F::from_canonical_u32(log_height as u32) + || memory_read_native::( + state.memory.data(), + dim_base_pointer_u32 + opened_index as u32, + )[0] != F::from_canonical_u32(log_height as u32) { break; } } - let (result, [new_row_pointer, row_len]) = memory.read( - address_space, - opened_base_pointer + F::from_canonical_usize(2 * opened_index), + let [new_row_pointer, row_len]: [F; 2] = tracing_read_native_helper( + state.memory, + opened_base_pointer_u32 + 2 * opened_index as u32, + cell_cols.read_row_pointer_and_length.as_mut(), ); row_pointer = new_row_pointer.as_canonical_u32() as usize; row_end = row_pointer + (opened_element_size * row_len).as_canonical_u32() as usize; - Some(result) + cell_cols.is_first_in_row = F::ONE; } else { - memory.increment_timestamp(); - None - }; - let (read, value) = memory - .read_cell(address_space, F::from_canonical_usize(row_pointer)); - cells.push(CellRecord { - read, - opened_index, - read_row_pointer_and_length, - row_pointer, - row_end, - }); + state.memory.increment_timestamp(); + } + let [value]: [F; 1] = tracing_read_native_helper( + state.memory, + row_pointer as u32, + cell_cols.read.as_mut(), + ); + + cell_cols.opened_index = F::from_canonical_usize(opened_index); + cell_cols.row_pointer = F::from_canonical_usize(row_pointer); + cell_cols.row_end = F::from_canonical_usize(row_end); + *chunk_elem = value; row_pointer += 1; + cells_idx += 1; } - if cells.is_empty() { + if cells_idx == 0 { break; } - let cells_len = cells.len(); - chunks.push(InsideRowRecord { - cells, - p2_input: rolling_hash, - }); - self.height += 1; - prev_rolling_hash = Some(rolling_hash); - self.subchip.permute_mut(&mut rolling_hash); - if cells_len < CHUNK { - for _ in 0..CHUNK - cells_len { - memory.increment_timestamp(); - memory.increment_timestamp(); + inside_cols.inner.inputs[..CHUNK].copy_from_slice(&rolling_hash[..CHUNK]); + if !self.optimistic { + self.subchip.permute_mut(&mut rolling_hash); + } + if cells_idx < CHUNK { + state + .memory + .increment_timestamp_by(2 * (CHUNK - cells_idx) as u32); + } + + inside_row_idx += 1; + // left + inside_cols.incorporate_row = F::ZERO; + inside_cols.incorporate_sibling = F::ZERO; + inside_cols.inside_row = F::ONE; + inside_cols.simple = F::ZERO; + // `end_inside_row` of the last row will be set to 1 after this loop. + inside_cols.end_inside_row = F::ZERO; + inside_cols.end_top_level = F::ZERO; + inside_cols.opened_element_size_inv = opened_element_size_inv; + inside_cols.very_first_timestamp = + F::from_canonical_u32(incorporate_start_timestamp); + inside_cols.start_timestamp = F::from_canonical_u32(start_timestamp_u32); + + inside_cols.initial_opened_index = + F::from_canonical_usize(initial_opened_index); + inside_cols.opened_base_pointer = opened_base_pointer; + if cells_idx < CHUNK { + let exhausted_opened_idx = F::from_canonical_usize(opened_index - 1); + for exhausted_idx in cells_idx..CHUNK { + inside_cols.is_exhausted[exhausted_idx - 1] = F::ONE; + inside_specific_cols.cells[exhausted_idx].opened_index = + exhausted_opened_idx; } break; } } + { + let inside_cols = &mut allocated_rows[inside_row_idx - 1]; + inside_cols.end_inside_row = F::ONE; + } + + let incorporate_cols = &mut allocated_rows[non_inside_row_idx]; + let top_level_specific_cols: &mut TopLevelSpecificCols = incorporate_cols + .specific[..TopLevelSpecificCols::::width()] + .borrow_mut(); + let final_opened_index = opened_index - 1; - let (initial_height_read, height_check) = memory.read_cell( - address_space, - dim_base_pointer + F::from_canonical_usize(initial_opened_index), + let [height_check]: [F; 1] = tracing_read_native_helper( + state.memory, + dim_base_pointer_u32 + initial_opened_index as u32, + top_level_specific_cols + .read_initial_height_or_sibling_is_on_right + .as_mut(), ); assert_eq!(height_check, F::from_canonical_u32(log_height as u32)); - let (final_height_read, height_check) = memory.read_cell( - address_space, - dim_base_pointer + F::from_canonical_usize(final_opened_index), + let final_height_read_timestamp = state.memory.timestamp; + let [height_check]: [F; 1] = tracing_read_native_helper( + state.memory, + dim_base_pointer_u32 + final_opened_index as u32, + top_level_specific_cols.read_final_height.as_mut(), ); assert_eq!(height_check, F::from_canonical_u32(log_height as u32)); - let hash: [F; CHUNK] = std::array::from_fn(|i| rolling_hash[i]); + if !self.optimistic { + let hash: [F; CHUNK] = std::array::from_fn(|i| rolling_hash[i]); + root = if log_height as u32 == initial_log_height_u32 { + hash + } else { + compress(&self.subchip, root, hash).1 + }; + } + non_inside_row_idx += 1; - let (p2_input, new_root) = if log_height as u32 == initial_log_height { - (prev_rolling_hash.unwrap(), hash) - } else { - self.compress(root, hash) - }; - root = new_root; - - self.height += 1; - Some(IncorporateRowRecord { - chunks, - initial_opened_index, - final_opened_index, - initial_height_read, - final_height_read, - p2_input, - }) - } else { - None - }; + incorporate_cols.incorporate_row = F::ONE; + incorporate_cols.incorporate_sibling = F::ZERO; + incorporate_cols.inside_row = F::ZERO; + incorporate_cols.simple = F::ZERO; + incorporate_cols.end_inside_row = F::ZERO; + incorporate_cols.end_top_level = F::ZERO; + incorporate_cols.start_top_level = F::from_bool(proof_index == 0); + incorporate_cols.opened_element_size_inv = opened_element_size_inv; + incorporate_cols.very_first_timestamp = init_timestamp; + incorporate_cols.start_timestamp = F::from_canonical_u32( + incorporate_start_timestamp - NUM_INITIAL_READS as u32, + ); + top_level_specific_cols.end_timestamp = + F::from_canonical_u32(final_height_read_timestamp + 1); - let incorporate_sibling = if log_height == 0 { - None - } else { - for _ in 0..NUM_INITIAL_READS { - memory.increment_timestamp(); - } + incorporate_cols.initial_opened_index = + F::from_canonical_usize(initial_opened_index); + top_level_specific_cols.final_opened_index = + F::from_canonical_usize(final_opened_index); + top_level_specific_cols.log_height = F::from_canonical_u32(log_height as u32); + top_level_specific_cols.opened_length = F::from_canonical_usize(opened_length); + top_level_specific_cols.dim_base_pointer = dim_base_pointer; + incorporate_cols.opened_base_pointer = opened_base_pointer; + top_level_specific_cols.index_base_pointer = index_base_pointer; + top_level_specific_cols.proof_index = F::from_canonical_usize(proof_index); + } + + if log_height != 0 { + let row_start_timestamp = state.memory.timestamp; + state + .memory + .increment_timestamp_by(NUM_INITIAL_READS as u32); + + let sibling_cols = &mut allocated_rows[non_inside_row_idx]; + let top_level_specific_cols: &mut TopLevelSpecificCols = + sibling_cols.specific[..TopLevelSpecificCols::::width()].borrow_mut(); - let (read_sibling_is_on_right, sibling_is_on_right) = memory.read_cell( - address_space, - index_base_pointer + F::from_canonical_usize(sibling_index), + let read_sibling_is_on_right_timestamp = state.memory.timestamp; + let [sibling_is_on_right]: [F; 1] = tracing_read_native_helper( + state.memory, + index_base_pointer_u32 + proof_index as u32, + top_level_specific_cols + .read_initial_height_or_sibling_is_on_right + .as_mut(), ); - let sibling_is_on_right = sibling_is_on_right == F::ONE; - let sibling = sibling_proof[sibling_index]; - let (p2_input, new_root) = if sibling_is_on_right { - self.compress(sibling, root) - } else { - self.compress(root, sibling) - }; - root = new_root; - - self.height += 1; - Some(IncorporateSiblingRecord { - read_sibling_is_on_right, - sibling_is_on_right, - p2_input, - }) - }; + let sibling = sibling_proof[proof_index]; + if !self.optimistic { + root = if sibling_is_on_right == F::ONE { + compress(&self.subchip, sibling, root).1 + } else { + compress(&self.subchip, root, sibling).1 + }; + } + + non_inside_row_idx += 1; - top_level.push(TopLevelRecord { - incorporate_row, - incorporate_sibling, - }); + sibling_cols.inner.inputs[..CHUNK].copy_from_slice(&sibling); + + sibling_cols.incorporate_row = F::ZERO; + sibling_cols.incorporate_sibling = F::ONE; + sibling_cols.inside_row = F::ZERO; + sibling_cols.simple = F::ZERO; + sibling_cols.end_inside_row = F::ZERO; + sibling_cols.end_top_level = F::ZERO; + sibling_cols.start_top_level = F::ZERO; + sibling_cols.opened_element_size_inv = opened_element_size_inv; + sibling_cols.very_first_timestamp = init_timestamp; + sibling_cols.start_timestamp = F::from_canonical_u32(row_start_timestamp); + + top_level_specific_cols.end_timestamp = + F::from_canonical_u32(read_sibling_is_on_right_timestamp + 1); + sibling_cols.initial_opened_index = F::from_canonical_usize(opened_index); + top_level_specific_cols.final_opened_index = + F::from_canonical_usize(opened_index - 1); + top_level_specific_cols.log_height = F::from_canonical_u32(log_height as u32); + top_level_specific_cols.opened_length = F::from_canonical_usize(opened_length); + top_level_specific_cols.dim_base_pointer = dim_base_pointer; + sibling_cols.opened_base_pointer = opened_base_pointer; + top_level_specific_cols.index_base_pointer = index_base_pointer; + + top_level_specific_cols.proof_index = F::from_canonical_usize(proof_index); + top_level_specific_cols.sibling_is_on_right = sibling_is_on_right; + }; log_height -= 1; - sibling_index += 1; + proof_index += 1; + } + let ltl_trace_cols = &mut allocated_rows[non_inside_row_idx - 1]; + let ltl_trace_specific_cols: &mut TopLevelSpecificCols = + ltl_trace_cols.specific[..TopLevelSpecificCols::::width()].borrow_mut(); + ltl_trace_cols.inner.export = F::from_canonical_u32(total_num_row as u32); + ltl_trace_cols.end_top_level = F::ONE; + ltl_trace_specific_cols.pc = F::from_canonical_u32(*state.pc); + ltl_trace_specific_cols.dim_register = dim_register; + ltl_trace_specific_cols.opened_register = opened_register; + ltl_trace_specific_cols.opened_length_register = opened_length_register; + ltl_trace_specific_cols.proof_id = proof_id_ptr; + ltl_trace_specific_cols.index_register = index_register; + ltl_trace_specific_cols.commit_register = commit_register; + ltl_trace_specific_cols.commit_pointer = commit_pointer; + ltl_trace_specific_cols.dim_base_pointer_read = ltl_specific_cols.dim_base_pointer_read; + ltl_trace_specific_cols.opened_base_pointer_read = + ltl_specific_cols.opened_base_pointer_read; + ltl_trace_specific_cols.opened_length_read = ltl_specific_cols.opened_length_read; + ltl_trace_specific_cols.index_base_pointer_read = + ltl_specific_cols.index_base_pointer_read; + ltl_trace_specific_cols.commit_pointer_read = ltl_specific_cols.commit_pointer_read; + ltl_trace_specific_cols.commit_read = ltl_specific_cols.commit_read; + if !self.optimistic { + assert_eq!(commit, root); } - - assert_eq!(commit, root); - self.record_set - .verify_batch_records - .push(VerifyBatchRecord { - from_state, - instruction: instruction.clone(), - dim_base_pointer, - opened_base_pointer, - opened_length, - index_base_pointer, - commit_pointer, - dim_base_pointer_read, - opened_base_pointer_read, - opened_length_read, - index_base_pointer_read, - commit_pointer_read, - commit_read, - initial_log_height: initial_log_height as usize, - top_level, - }); } else { unreachable!() } - Ok(ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }) + + *state.pc += DEFAULT_PC_STEP; + Ok(()) } fn get_opcode_name(&self, opcode: usize) -> String { @@ -506,3 +665,806 @@ impl InstructionExecutor } } } + +impl TraceFiller + for NativePoseidon2Filler +{ + fn fill_trace( + &self, + mem_helper: &MemoryAuxColsFactory, + trace: &mut RowMajorMatrix, + rows_used: usize, + ) where + F: Send + Sync + Clone, + { + // Split the trace rows by instruction + let width = trace.width(); + let mut row_idx = 0; + let mut row_slice = trace.values.as_mut_slice(); + let mut chunk_start = Vec::new(); + while row_idx < rows_used { + let cols: &NativePoseidon2Cols = row_slice[..width].borrow(); + let (curr, rest) = if cols.simple.is_one() { + row_idx += 1; + row_slice.split_at_mut(width) + } else { + let num_non_inside_row = cols.inner.export.as_canonical_u32() as usize; + let start = (num_non_inside_row - 1) * width; + let cols: &NativePoseidon2Cols = + row_slice[start..(start + width)].borrow(); + let total_num_row = cols.inner.export.as_canonical_u32() as usize; + row_idx += total_num_row; + row_slice.split_at_mut(total_num_row * width) + }; + chunk_start.push(curr); + row_slice = rest; + } + chunk_start.into_par_iter().for_each(|chunk_slice| { + let cols: &NativePoseidon2Cols = chunk_slice[..width].borrow(); + if cols.simple.is_one() { + self.fill_simple_chunk(mem_helper, chunk_slice); + } else { + self.fill_verify_batch_chunk(mem_helper, chunk_slice); + } + }); + // Remaining rows are dummy rows. + let inner_width = self.subchip.air.width(); + row_slice.par_chunks_exact_mut(width).for_each(|row_slice| { + row_slice[..inner_width].copy_from_slice(&self.empty_poseidon2_sub_cols); + }); + } +} + +impl NativePoseidon2Filler { + fn fill_simple_chunk(&self, mem_helper: &MemoryAuxColsFactory, chunk_slice: &mut [F]) { + { + let inner_width = self.subchip.air.width(); + let cols: &NativePoseidon2Cols = chunk_slice.as_ref().borrow(); + let inner_cols = &self.subchip.generate_trace(vec![cols.inner.inputs]).values; + chunk_slice[..inner_width].copy_from_slice(inner_cols); + } + + let cols: &mut NativePoseidon2Cols = chunk_slice.borrow_mut(); + // Simple poseidon2 row + let simple_cols: &mut SimplePoseidonSpecificCols = + cols.specific[..SimplePoseidonSpecificCols::::width()].borrow_mut(); + let start_timestamp_u32 = cols.start_timestamp.as_canonical_u32(); + mem_fill_helper( + mem_helper, + start_timestamp_u32, + simple_cols.read_output_pointer.as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 1, + simple_cols.read_input_pointer_1.as_mut(), + ); + if simple_cols.is_compress.is_one() { + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 2, + simple_cols.read_input_pointer_2.as_mut(), + ); + } + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 3, + simple_cols.read_data_1.as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 4, + simple_cols.read_data_2.as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 5, + simple_cols.write_data_1.as_mut(), + ); + if simple_cols.is_compress.is_zero() { + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 6, + simple_cols.write_data_2.as_mut(), + ); + } + } + + fn fill_verify_batch_chunk(&self, mem_helper: &MemoryAuxColsFactory, chunk_slice: &mut [F]) { + let inner_width = self.subchip.air.width(); + let width = NativePoseidon2Cols::::width(); + let num_non_inside_rows = { + let cols: &NativePoseidon2Cols = chunk_slice[..width].borrow(); + cols.inner.export.as_canonical_u32() as usize + }; + let total_num_rows = { + let start = (num_non_inside_rows - 1) * width; + let last_cols: &NativePoseidon2Cols = + chunk_slice[start..(start + width)].borrow(); + // During execution, this field hasn't been filled with meaningful data. So we use this + // field to store the number of inside rows. + last_cols.inner.export.as_canonical_u32() as usize + }; + let mut first_round = true; + let mut root = [F::ZERO; CHUNK]; + let mut inside_idx = num_non_inside_rows; + let mut non_inside_idx = 0; + while inside_idx < total_num_rows || non_inside_idx < num_non_inside_rows { + debug_assert!(non_inside_idx < num_non_inside_rows); + let incorporate_sibling = { + let start = non_inside_idx * width; + let row_slice = &mut chunk_slice[start..(start + width)]; + let cols: &NativePoseidon2Cols = row_slice.as_ref().borrow(); + cols.incorporate_sibling.is_one() + }; + if !incorporate_sibling { + let mut prev_rolling_hash: [F; 2 * CHUNK]; + let mut rolling_hash = [F::ZERO; 2 * CHUNK]; + loop { + let start = inside_idx * width; + let row_slice = &mut chunk_slice[start..(start + width)]; + let mut input_len = 0; + { + let cols: &mut NativePoseidon2Cols = + row_slice.borrow_mut(); + let inside_row_specific_cols: &mut InsideRowSpecificCols = + cols.specific[..InsideRowSpecificCols::::width()].borrow_mut(); + let start_timestamp_u32 = cols.start_timestamp.as_canonical_u32(); + for (i, cell) in inside_row_specific_cols.cells.iter_mut().enumerate() { + if i > 0 && cols.is_exhausted[i - 1].is_one() { + break; + } + input_len += 1; + if cell.is_first_in_row.is_one() { + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 2 * i as u32, + cell.read_row_pointer_and_length.as_mut(), + ); + } + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 2 * i as u32 + 1, + cell.read.as_mut(), + ); + } + } + { + let cols: &NativePoseidon2Cols = + row_slice.as_ref().borrow(); + rolling_hash[..input_len].copy_from_slice(&cols.inner.inputs[..input_len]); + } + prev_rolling_hash = rolling_hash; + + let inner_cols = &self.subchip.generate_trace(vec![rolling_hash]).values; + row_slice[..inner_width].copy_from_slice(inner_cols); + let cols: &NativePoseidon2Cols = row_slice.as_ref().borrow(); + rolling_hash = *Self::poseidon2_output_from_trace(&cols.inner); + inside_idx += 1; + if cols.end_inside_row.is_one() { + break; + } + } + + let start = non_inside_idx * width; + let row_slice = &mut chunk_slice[start..(start + width)]; + let mut p2_input = [F::ZERO; 2 * CHUNK]; + if first_round { + p2_input.copy_from_slice(&prev_rolling_hash); + } else { + p2_input[..CHUNK].copy_from_slice(&root); + p2_input[CHUNK..].copy_from_slice(&rolling_hash[..CHUNK]); + } + + first_round = false; + let inner_cols = &self.subchip.generate_trace(vec![p2_input]).values; + row_slice[..inner_width].copy_from_slice(inner_cols); + let cols: &mut NativePoseidon2Cols = row_slice.borrow_mut(); + Self::fill_timestamp_for_top_level(mem_helper, cols); + root.copy_from_slice(&Self::poseidon2_output_from_trace(&cols.inner)[..CHUNK]); + non_inside_idx += 1; + } + + if non_inside_idx < num_non_inside_rows { + let start = non_inside_idx * width; + let row_slice = &mut chunk_slice[start..(start + width)]; + let p2_input = { + let cols: &mut NativePoseidon2Cols = row_slice.borrow_mut(); + Self::fill_timestamp_for_top_level(mem_helper, cols); + let sibling = &cols.inner.inputs[..CHUNK]; + let top_level_specific_cols: &TopLevelSpecificCols = + cols.specific[..TopLevelSpecificCols::::width()].borrow(); + let sibling_is_on_right = top_level_specific_cols.sibling_is_on_right.is_one(); + let mut p2_input = [F::ZERO; 2 * CHUNK]; + if sibling_is_on_right { + p2_input[..CHUNK].copy_from_slice(sibling); + p2_input[CHUNK..].copy_from_slice(&root); + } else { + p2_input[..CHUNK].copy_from_slice(&root); + p2_input[CHUNK..].copy_from_slice(sibling); + }; + p2_input + }; + let inner_cols = &self.subchip.generate_trace(vec![p2_input]).values; + row_slice[..inner_width].copy_from_slice(inner_cols); + let cols: &NativePoseidon2Cols = row_slice.as_ref().borrow(); + root.copy_from_slice(&Self::poseidon2_output_from_trace(&cols.inner)[..CHUNK]); + non_inside_idx += 1; + } + } + } + fn fill_timestamp_for_top_level( + mem_helper: &MemoryAuxColsFactory, + cols: &mut NativePoseidon2Cols, + ) { + let top_level_specific_cols: &mut TopLevelSpecificCols = + cols.specific[..TopLevelSpecificCols::::width()].borrow_mut(); + let start_timestamp_u32 = cols.start_timestamp.as_canonical_u32(); + if cols.end_top_level.is_one() { + let very_start_timestamp_u32 = cols.very_first_timestamp.as_canonical_u32(); + mem_fill_helper( + mem_helper, + very_start_timestamp_u32, + top_level_specific_cols.dim_base_pointer_read.as_mut(), + ); + mem_fill_helper( + mem_helper, + very_start_timestamp_u32 + 1, + top_level_specific_cols.opened_base_pointer_read.as_mut(), + ); + mem_fill_helper( + mem_helper, + very_start_timestamp_u32 + 2, + top_level_specific_cols.opened_length_read.as_mut(), + ); + mem_fill_helper( + mem_helper, + very_start_timestamp_u32 + 3, + top_level_specific_cols.index_base_pointer_read.as_mut(), + ); + mem_fill_helper( + mem_helper, + very_start_timestamp_u32 + 4, + top_level_specific_cols.commit_pointer_read.as_mut(), + ); + mem_fill_helper( + mem_helper, + very_start_timestamp_u32 + 5, + top_level_specific_cols.commit_read.as_mut(), + ); + } + if cols.incorporate_row.is_one() { + let end_timestamp = top_level_specific_cols.end_timestamp.as_canonical_u32(); + mem_fill_helper( + mem_helper, + end_timestamp - 2, + top_level_specific_cols + .read_initial_height_or_sibling_is_on_right + .as_mut(), + ); + mem_fill_helper( + mem_helper, + end_timestamp - 1, + top_level_specific_cols.read_final_height.as_mut(), + ); + } else if cols.incorporate_sibling.is_one() { + mem_fill_helper( + mem_helper, + start_timestamp_u32 + NUM_INITIAL_READS as u32, + top_level_specific_cols + .read_initial_height_or_sibling_is_on_right + .as_mut(), + ); + } else { + unreachable!() + } + } + + #[inline(always)] + fn poseidon2_output_from_trace(inner: &Poseidon2SubCols) -> &[F; 2 * CHUNK] { + &inner.ending_full_rounds.last().unwrap().post + } +} + +fn tracing_read_native_helper( + memory: &mut TracingMemory, + ptr: u32, + base_aux: &mut MemoryBaseAuxCols, +) -> [F; BLOCK_SIZE] { + let mut prev_ts = 0; + let ret = tracing_read_native(memory, ptr, &mut prev_ts); + base_aux.set_prev(F::from_canonical_u32(prev_ts)); + ret +} + +/// Fill `MemoryBaseAuxCols`, assuming that the `prev_timestamp` is already set in `base_aux`. +fn mem_fill_helper( + mem_helper: &MemoryAuxColsFactory, + timestamp: u32, + base_aux: &mut MemoryBaseAuxCols, +) { + let prev_ts = base_aux.prev_timestamp.as_canonical_u32(); + mem_helper.fill(prev_ts, timestamp, base_aux); +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct Pos2PreCompute<'a, F: Field, const SBOX_REGISTERS: usize> { + subchip: &'a Poseidon2SubChip, + output_register: u32, + input_register_1: u32, + input_register_2: u32, +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct VerifyBatchPreCompute<'a, F: Field, const SBOX_REGISTERS: usize> { + subchip: &'a Poseidon2SubChip, + dim_register: u32, + opened_register: u32, + opened_length_register: u32, + proof_id_ptr: u32, + index_register: u32, + commit_register: u32, + opened_element_size: F, +} + +impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Executor { + #[inline(always)] + fn pre_compute_pos2_impl( + &'a self, + pc: u32, + inst: &Instruction, + pos2_data: &mut Pos2PreCompute<'a, F, SBOX_REGISTERS>, + ) -> Result<(), StaticProgramError> { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + + if opcode != PERM_POS2.global_opcode() && opcode != COMP_POS2.global_opcode() { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); + let c = c.as_canonical_u32(); + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + + if d != AS::Native as u32 { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + if e != AS::Native as u32 { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + *pos2_data = Pos2PreCompute { + subchip: &self.subchip, + output_register: a, + input_register_1: b, + input_register_2: c, + }; + + Ok(()) + } + + #[inline(always)] + fn pre_compute_verify_batch_impl( + &'a self, + pc: u32, + inst: &Instruction, + verify_batch_data: &mut VerifyBatchPreCompute<'a, F, SBOX_REGISTERS>, + ) -> Result<(), StaticProgramError> { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + f, + g, + .. + } = inst; + + if opcode != VERIFY_BATCH.global_opcode() { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); + let c = c.as_canonical_u32(); + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + let f = f.as_canonical_u32(); + + let opened_element_size_inv = g; + // calc inverse fast assuming opened_element_size in {1, 4} + let mut opened_element_size = F::ONE; + while opened_element_size * opened_element_size_inv != F::ONE { + opened_element_size += F::ONE; + } + + *verify_batch_data = VerifyBatchPreCompute { + subchip: &self.subchip, + dim_register: a, + opened_register: b, + opened_length_register: c, + proof_id_ptr: d, + index_register: e, + commit_register: f, + opened_element_size, + }; + + Ok(()) + } +} + +impl Executor + for NativePoseidon2Executor +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + std::cmp::max( + size_of::>(), + size_of::>(), + ) + } + + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let &Instruction { opcode, .. } = inst; + + let is_pos2 = opcode == PERM_POS2.global_opcode() || opcode == COMP_POS2.global_opcode(); + + if is_pos2 { + let pos2_data: &mut Pos2PreCompute = data.borrow_mut(); + self.pre_compute_pos2_impl(pc, inst, pos2_data)?; + if opcode == PERM_POS2.global_opcode() { + Ok(execute_pos2_e1_impl::<_, _, SBOX_REGISTERS, true>) + } else { + Ok(execute_pos2_e1_impl::<_, _, SBOX_REGISTERS, false>) + } + } else { + let verify_batch_data: &mut VerifyBatchPreCompute = + data.borrow_mut(); + self.pre_compute_verify_batch_impl(pc, inst, verify_batch_data)?; + Ok(execute_verify_batch_e1_impl::<_, _, SBOX_REGISTERS>) + } + } +} + +impl MeteredExecutor + for NativePoseidon2Executor +{ + #[inline(always)] + fn metered_pre_compute_size(&self) -> usize { + std::cmp::max( + size_of::>>(), + size_of::>>(), + ) + } + + #[inline(always)] + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let &Instruction { opcode, .. } = inst; + + let is_pos2 = opcode == PERM_POS2.global_opcode() || opcode == COMP_POS2.global_opcode(); + + if is_pos2 { + let pre_compute: &mut E2PreCompute> = + data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + self.pre_compute_pos2_impl(pc, inst, &mut pre_compute.data)?; + if opcode == PERM_POS2.global_opcode() { + Ok(execute_pos2_e2_impl::<_, _, SBOX_REGISTERS, true>) + } else { + Ok(execute_pos2_e2_impl::<_, _, SBOX_REGISTERS, false>) + } + } else { + let pre_compute: &mut E2PreCompute> = + data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + self.pre_compute_verify_batch_impl(pc, inst, &mut pre_compute.data)?; + Ok(execute_verify_batch_e2_impl::<_, _, SBOX_REGISTERS>) + } + } +} + +unsafe fn execute_pos2_e1_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const SBOX_REGISTERS: usize, + const IS_PERM: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &Pos2PreCompute = pre_compute.borrow(); + execute_pos2_e12_impl::<_, _, SBOX_REGISTERS, IS_PERM>(pre_compute, vm_state); +} + +unsafe fn execute_pos2_e2_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const SBOX_REGISTERS: usize, + const IS_PERM: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute> = pre_compute.borrow(); + let height = + execute_pos2_e12_impl::<_, _, SBOX_REGISTERS, IS_PERM>(&pre_compute.data, vm_state); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height); +} + +unsafe fn execute_verify_batch_e1_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const SBOX_REGISTERS: usize, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &VerifyBatchPreCompute = pre_compute.borrow(); + // NOTE: using optimistic execution + execute_verify_batch_e12_impl::<_, _, SBOX_REGISTERS, true>(pre_compute, vm_state); +} + +unsafe fn execute_verify_batch_e2_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const SBOX_REGISTERS: usize, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute> = pre_compute.borrow(); + // NOTE: using optimistic execution + let height = + execute_verify_batch_e12_impl::<_, _, SBOX_REGISTERS, true>(&pre_compute.data, vm_state); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height); +} + +#[inline(always)] +unsafe fn execute_pos2_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const SBOX_REGISTERS: usize, + const IS_PERM: bool, +>( + pre_compute: &Pos2PreCompute, + vm_state: &mut VmExecState, +) -> u32 { + let subchip = pre_compute.subchip; + + let [output_pointer]: [F; 1] = vm_state.vm_read(AS::Native as u32, pre_compute.output_register); + let [input_pointer_1]: [F; 1] = + vm_state.vm_read(AS::Native as u32, pre_compute.input_register_1); + let [input_pointer_2] = if IS_PERM { + [input_pointer_1 + F::from_canonical_usize(CHUNK)] + } else { + vm_state.vm_read(AS::Native as u32, pre_compute.input_register_2) + }; + + let data_1: [F; CHUNK] = + vm_state.vm_read(AS::Native as u32, input_pointer_1.as_canonical_u32()); + let data_2: [F; CHUNK] = + vm_state.vm_read(AS::Native as u32, input_pointer_2.as_canonical_u32()); + + let p2_input = std::array::from_fn(|i| { + if i < CHUNK { + data_1[i] + } else { + data_2[i - CHUNK] + } + }); + let output = subchip.permute(p2_input); + let output_pointer_u32 = output_pointer.as_canonical_u32(); + + vm_state.vm_write::( + AS::Native as u32, + output_pointer_u32, + &std::array::from_fn(|i| output[i]), + ); + if IS_PERM { + vm_state.vm_write::( + AS::Native as u32, + output_pointer_u32 + CHUNK as u32, + &std::array::from_fn(|i| output[i + CHUNK]), + ); + } + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; + + 1 +} + +#[inline(always)] +unsafe fn execute_verify_batch_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const SBOX_REGISTERS: usize, + const OPTIMISTIC: bool, +>( + pre_compute: &VerifyBatchPreCompute, + vm_state: &mut VmExecState, +) -> u32 { + let subchip = pre_compute.subchip; + let opened_element_size = pre_compute.opened_element_size; + + let [proof_id]: [F; 1] = vm_state.host_read(AS::Native as u32, pre_compute.proof_id_ptr); + let [dim_base_pointer]: [F; 1] = vm_state.vm_read(AS::Native as u32, pre_compute.dim_register); + let dim_base_pointer_u32 = dim_base_pointer.as_canonical_u32(); + let [opened_base_pointer]: [F; 1] = + vm_state.vm_read(AS::Native as u32, pre_compute.opened_register); + let opened_base_pointer_u32 = opened_base_pointer.as_canonical_u32(); + let [opened_length]: [F; 1] = + vm_state.vm_read(AS::Native as u32, pre_compute.opened_length_register); + let [index_base_pointer]: [F; 1] = + vm_state.vm_read(AS::Native as u32, pre_compute.index_register); + let index_base_pointer_u32 = index_base_pointer.as_canonical_u32(); + let [commit_pointer]: [F; 1] = vm_state.vm_read(AS::Native as u32, pre_compute.commit_register); + let commit: [F; CHUNK] = vm_state.vm_read(AS::Native as u32, commit_pointer.as_canonical_u32()); + + let opened_length = opened_length.as_canonical_u32() as usize; + + let initial_log_height = { + let [height]: [F; 1] = vm_state.host_read(AS::Native as u32, dim_base_pointer_u32); + height.as_canonical_u32() + }; + + let mut log_height = initial_log_height as i32; + let mut sibling_index = 0; + let mut opened_index = 0; + let mut height = 0; + + let mut root = [F::ZERO; CHUNK]; + let sibling_proof: Vec<[F; CHUNK]> = { + let proof_idx = proof_id.as_canonical_u32() as usize; + vm_state.streams.hint_space[proof_idx] + .par_chunks(CHUNK) + .map(|c| c.try_into().unwrap()) + .collect() + }; + + while log_height >= 0 { + if opened_index < opened_length + && vm_state.host_read::( + AS::Native as u32, + dim_base_pointer_u32 + opened_index as u32, + )[0] == F::from_canonical_u32(log_height as u32) + { + let initial_opened_index = opened_index; + + let mut row_pointer = 0; + let mut row_end = 0; + + let mut rolling_hash = [F::ZERO; 2 * CHUNK]; + + let mut is_first_in_segment = true; + + loop { + let mut cells_len = 0; + for chunk_elem in rolling_hash.iter_mut().take(CHUNK) { + if is_first_in_segment || row_pointer == row_end { + if is_first_in_segment { + is_first_in_segment = false; + } else { + opened_index += 1; + if opened_index == opened_length + || vm_state.host_read::( + AS::Native as u32, + dim_base_pointer_u32 + opened_index as u32, + )[0] != F::from_canonical_u32(log_height as u32) + { + break; + } + } + let [new_row_pointer, row_len]: [F; 2] = vm_state.vm_read( + AS::Native as u32, + opened_base_pointer_u32 + 2 * opened_index as u32, + ); + row_pointer = new_row_pointer.as_canonical_u32() as usize; + row_end = row_pointer + + (opened_element_size * row_len).as_canonical_u32() as usize; + } + let [value]: [F; 1] = vm_state.vm_read(AS::Native as u32, row_pointer as u32); + cells_len += 1; + *chunk_elem = value; + row_pointer += 1; + } + if cells_len == 0 { + break; + } + height += 1; + if !OPTIMISTIC { + subchip.permute_mut(&mut rolling_hash); + } + if cells_len < CHUNK { + break; + } + } + + let final_opened_index = opened_index - 1; + let [height_check]: [F; 1] = vm_state.host_read( + AS::Native as u32, + dim_base_pointer_u32 + initial_opened_index as u32, + ); + assert_eq!(height_check, F::from_canonical_u32(log_height as u32)); + let [height_check]: [F; 1] = vm_state.host_read( + AS::Native as u32, + dim_base_pointer_u32 + final_opened_index as u32, + ); + assert_eq!(height_check, F::from_canonical_u32(log_height as u32)); + + if !OPTIMISTIC { + let hash: [F; CHUNK] = std::array::from_fn(|i| rolling_hash[i]); + + let new_root = if log_height as u32 == initial_log_height { + hash + } else { + let (_, new_root) = compress(subchip, root, hash); + new_root + }; + root = new_root; + } + height += 1; + } + + if log_height != 0 { + let [sibling_is_on_right]: [F; 1] = vm_state.vm_read( + AS::Native as u32, + index_base_pointer_u32 + sibling_index as u32, + ); + let sibling_is_on_right = sibling_is_on_right == F::ONE; + let sibling = sibling_proof[sibling_index]; + if !OPTIMISTIC { + let (_, new_root) = if sibling_is_on_right { + compress(subchip, sibling, root) + } else { + compress(subchip, root, sibling) + }; + root = new_root; + } + height += 1; + } + + log_height -= 1; + sibling_index += 1; + } + + if !OPTIMISTIC { + assert_eq!(commit, root); + } + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; + + height +} diff --git a/extensions/native/circuit/src/poseidon2/mod.rs b/extensions/native/circuit/src/poseidon2/mod.rs index af503e20f4..96907f6587 100644 --- a/extensions/native/circuit/src/poseidon2/mod.rs +++ b/extensions/native/circuit/src/poseidon2/mod.rs @@ -1,8 +1,13 @@ +use openvm_circuit::arch::VmChipWrapper; + +use crate::chip::NativePoseidon2Filler; + pub mod air; pub mod chip; -mod columns; +pub mod columns; #[cfg(test)] mod tests; -mod trace; const CHUNK: usize = 8; +pub type NativePoseidon2Chip = + VmChipWrapper>; diff --git a/extensions/native/circuit/src/poseidon2/tests.rs b/extensions/native/circuit/src/poseidon2/tests.rs index 32a0e483a3..64966585c0 100644 --- a/extensions/native/circuit/src/poseidon2/tests.rs +++ b/extensions/native/circuit/src/poseidon2/tests.rs @@ -1,11 +1,8 @@ -use std::{ - cmp::min, - sync::{Arc, Mutex}, -}; +use std::cmp::min; -use openvm_circuit::arch::{ - testing::{memory::gen_pointer, VmChipTestBuilder, VmChipTester}, - verify_single, Streams, VirtualMachine, +use openvm_circuit::{ + arch::testing::{memory::gen_pointer, TestChipHarness, VmChipTestBuilder, VmChipTester}, + utils::air_test, }; use openvm_instructions::{instruction::Instruction, program::Program, LocalOpcode, SystemOpcode}; use openvm_native_compiler::{ @@ -16,14 +13,16 @@ use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubChip}; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32, PrimeField64}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, utils::disable_debug_builder, verifier::VerificationError, }; use openvm_stark_sdk::{ config::{ baby_bear_blake3::{BabyBearBlake3Config, BabyBearBlake3Engine}, - baby_bear_poseidon2::BabyBearPoseidon2Engine, - fri_params::standard_fri_params_with_100_bits_conjectured_security, FriParameters, }, engine::StarkFriEngine, @@ -34,12 +33,38 @@ use rand::{rngs::StdRng, Rng}; use super::air::VerifyBatchBus; use crate::{ - poseidon2::{chip::NativePoseidon2Chip, CHUNK}, - NativeConfig, + air::NativePoseidon2Air, + chip::NativePoseidon2Executor, + poseidon2::{chip::NativePoseidon2Filler, CHUNK}, + NativeConfig, NativeCpuBuilder, NativePoseidon2Chip, }; const VERIFY_BATCH_BUS: VerifyBatchBus = VerifyBatchBus::new(7); +const MAX_INS_CAPACITY: usize = 1 << 15; +type Harness = TestChipHarness< + F, + NativePoseidon2Executor, + NativePoseidon2Air, + NativePoseidon2Chip, +>; + +fn create_test_chip( + tester: &VmChipTestBuilder, +) -> Harness { + let air = NativePoseidon2Air::new( + tester.execution_bridge(), + tester.memory_bridge(), + VERIFY_BATCH_BUS, + Poseidon2Config::default(), + ); + let step = NativePoseidon2Executor::new(Poseidon2Config::default()); + let chip = NativePoseidon2Chip::new( + NativePoseidon2Filler::new(Poseidon2Config::default()), + tester.memory_helper(), + ); + Harness::with_capacity(step, air, chip, MAX_INS_CAPACITY) +} fn compute_commit( dim: &[usize], opened: &[Vec], @@ -140,140 +165,144 @@ fn random_instance( const SBOX_REGISTERS: usize = 1; +#[derive(Clone)] struct Case { row_lengths: Vec>, opened_element_size: usize, } +fn set_and_execute( + tester: &mut VmChipTestBuilder, + harness: &mut Harness, + rng: &mut StdRng, + case: Case, +) { + let instance = random_instance( + rng, + case.row_lengths, + case.opened_element_size, + |left, right| { + let concatenated = + std::array::from_fn(|i| if i < CHUNK { left[i] } else { right[i - CHUNK] }); + let permuted = harness.executor.subchip.permute(concatenated); + ( + std::array::from_fn(|i| permuted[i]), + std::array::from_fn(|i| permuted[i + CHUNK]), + ) + }, + ); + let VerifyBatchInstance { + dim, + opened, + proof, + sibling_is_on_right, + commit, + } = instance; + + let dim_register = gen_pointer(rng, 1); + let opened_register = gen_pointer(rng, 1); + let opened_length_register = gen_pointer(rng, 1); + let proof_id = gen_pointer(rng, 1); + let index_register = gen_pointer(rng, 1); + let commit_register = gen_pointer(rng, 1); + + let dim_base_pointer = gen_pointer(rng, 1); + let opened_base_pointer = gen_pointer(rng, 2); + let index_base_pointer = gen_pointer(rng, 1); + let commit_pointer = gen_pointer(rng, 1); + + let address_space = AS::Native as usize; + tester.write_usize(address_space, dim_register, [dim_base_pointer]); + tester.write_usize(address_space, opened_register, [opened_base_pointer]); + tester.write_usize(address_space, opened_length_register, [opened.len()]); + tester.write_usize(address_space, proof_id, [tester.streams.hint_space.len()]); + tester.write_usize(address_space, index_register, [index_base_pointer]); + tester.write_usize(address_space, commit_register, [commit_pointer]); + + for (i, &dim_value) in dim.iter().enumerate() { + tester.write_usize(address_space, dim_base_pointer + i, [dim_value]); + } + for (i, opened_row) in opened.iter().enumerate() { + let row_pointer = gen_pointer(rng, 1); + tester.write_usize( + address_space, + opened_base_pointer + (2 * i), + [row_pointer, opened_row.len() / case.opened_element_size], + ); + for (j, &opened_value) in opened_row.iter().enumerate() { + tester.write(address_space, row_pointer + j, [opened_value]); + } + } + tester + .streams + .hint_space + .push(proof.iter().flatten().copied().collect()); + for (i, &bit) in sibling_is_on_right.iter().enumerate() { + tester.write(address_space, index_base_pointer + i, [F::from_bool(bit)]); + } + tester.write(address_space, commit_pointer, commit); + + let opened_element_size_inv = F::from_canonical_usize(case.opened_element_size) + .inverse() + .as_canonical_u32() as usize; + tester.execute( + harness, + &Instruction::from_usize( + VERIFY_BATCH.global_opcode(), + [ + dim_register, + opened_register, + opened_length_register, + proof_id, + index_register, + commit_register, + opened_element_size_inv, + ], + ), + ); +} + fn test(cases: [Case; N]) { unsafe { std::env::set_var("RUST_BACKTRACE", "1"); } - - // single op - let address_space = AS::Native as usize; - - let mut tester = VmChipTestBuilder::default(); - let streams = Arc::new(Mutex::new(Streams::default())); - let mut chip = NativePoseidon2Chip::::new( - tester.system_port(), - tester.offline_memory_mutex_arc(), - Poseidon2Config::default(), - VERIFY_BATCH_BUS, - streams.clone(), - ); + let mut valid_tester = VmChipTestBuilder::default_native(); + let mut valid_harness = create_test_chip::(&valid_tester); + let mut prank_tester = VmChipTestBuilder::default_native(); + let mut prank_harness = create_test_chip::(&prank_tester); let mut rng = create_seeded_rng(); - for Case { - row_lengths, - opened_element_size, - } in cases - { - let mut streams = streams.lock().unwrap(); - let instance = - random_instance(&mut rng, row_lengths, opened_element_size, |left, right| { - let concatenated = - std::array::from_fn(|i| if i < CHUNK { left[i] } else { right[i - CHUNK] }); - let permuted = chip.subchip.permute(concatenated); - ( - std::array::from_fn(|i| permuted[i]), - std::array::from_fn(|i| permuted[i + CHUNK]), - ) - }); - let VerifyBatchInstance { - dim, - opened, - proof, - sibling_is_on_right, - commit, - } = instance; - - let dim_register = gen_pointer(&mut rng, 1); - let opened_register = gen_pointer(&mut rng, 1); - let opened_length_register = gen_pointer(&mut rng, 1); - let proof_id = gen_pointer(&mut rng, 1); - let index_register = gen_pointer(&mut rng, 1); - let commit_register = gen_pointer(&mut rng, 1); - - let dim_base_pointer = gen_pointer(&mut rng, 1); - let opened_base_pointer = gen_pointer(&mut rng, 2); - let index_base_pointer = gen_pointer(&mut rng, 1); - let commit_pointer = gen_pointer(&mut rng, 1); - - tester.write_usize(address_space, dim_register, [dim_base_pointer]); - tester.write_usize(address_space, opened_register, [opened_base_pointer]); - tester.write_usize(address_space, opened_length_register, [opened.len()]); - tester.write_usize(address_space, proof_id, [streams.hint_space.len()]); - tester.write_usize(address_space, index_register, [index_base_pointer]); - tester.write_usize(address_space, commit_register, [commit_pointer]); - - for (i, &dim_value) in dim.iter().enumerate() { - tester.write_usize(address_space, dim_base_pointer + i, [dim_value]); - } - for (i, opened_row) in opened.iter().enumerate() { - let row_pointer = gen_pointer(&mut rng, 1); - tester.write_usize( - address_space, - opened_base_pointer + (2 * i), - [row_pointer, opened_row.len() / opened_element_size], - ); - for (j, &opened_value) in opened_row.iter().enumerate() { - tester.write_cell(address_space, row_pointer + j, opened_value); - } - } - streams - .hint_space - .push(proof.iter().flatten().copied().collect()); - drop(streams); - for (i, &bit) in sibling_is_on_right.iter().enumerate() { - tester.write_cell(address_space, index_base_pointer + i, F::from_bool(bit)); - } - tester.write(address_space, commit_pointer, commit); - - let opened_element_size_inv = F::from_canonical_usize(opened_element_size) - .inverse() - .as_canonical_u32() as usize; - tester.execute( - &mut chip, - &Instruction::from_usize( - VERIFY_BATCH.global_opcode(), - [ - dim_register, - opened_register, - opened_length_register, - proof_id, - index_register, - commit_register, - opened_element_size_inv, - ], - ), + for case in cases { + set_and_execute( + &mut valid_tester, + &mut valid_harness, + &mut rng, + case.clone(), ); + set_and_execute(&mut prank_tester, &mut prank_harness, &mut rng, case); } - let mut tester = tester.build().load(chip).finalize(); - tester.simple_test().expect("Verification failed"); + let valid_tester = valid_tester.build().load(valid_harness).finalize(); + valid_tester.simple_test().expect("Verification failed"); disable_debug_builder(); - let trace = tester.air_proof_inputs[2] - .1 - .raw - .common_main - .as_mut() - .unwrap(); - let row_index = 0; - trace.row_mut(row_index); - let p2_chip = Poseidon2SubChip::::new(Poseidon2Config::default().constants); let inner_trace = p2_chip.generate_trace(vec![[F::ZERO; 2 * CHUNK]]); let inner_width = p2_chip.air.width(); - trace.row_mut(row_index)[..inner_width].copy_from_slice(&inner_trace.values); + let modify_trace = |trace: &mut DenseMatrix| { + let mut trace_row = trace.row_slice(0).to_vec(); + trace_row[..inner_width].copy_from_slice(&inner_trace.values); + *trace = RowMajorMatrix::new(trace_row, trace.width()); + }; + + let prank_tester = prank_tester + .build() + .load_and_prank_trace(prank_harness, modify_trace) + .finalize(); + // Run a test after pranking the poseidon2 stuff - assert_eq!( - tester.simple_test().err(), - Some(VerificationError::OodEvaluationMismatch), - "Expected constraint to fail" - ); + prank_tester.simple_test_with_expected_error(VerificationError::OodEvaluationMismatch); } #[test] @@ -383,15 +412,8 @@ fn random_instructions(num_ops: usize) -> Vec> { fn tester_with_random_poseidon2_ops(num_ops: usize) -> VmChipTester { let elem_range = || 1..=100; - let mut tester = VmChipTestBuilder::default(); - let streams = Arc::new(Mutex::new(Streams::default())); - let mut chip = NativePoseidon2Chip::::new( - tester.system_port(), - tester.offline_memory_mutex_arc(), - Poseidon2Config::default(), - VERIFY_BATCH_BUS, - streams.clone(), - ); + let mut tester = VmChipTestBuilder::default_native(); + let mut harness = create_test_chip::(&tester); let mut rng = create_seeded_rng(); @@ -417,27 +439,28 @@ fn tester_with_random_poseidon2_ops(num_ops: usize) -> VmChipTester { - let data_left: [_; CHUNK] = std::array::from_fn(|i| data[i]); - let data_right: [_; CHUNK] = std::array::from_fn(|i| data[CHUNK + i]); tester.write(e, lhs, data_left); tester.write(e, rhs, data_right); } PERM_POS2 => { - tester.write(e, lhs, data); + tester.write(e, lhs, data_left); + tester.write(e, lhs + CHUNK, data_right); } } - tester.execute(&mut chip, &instruction); + tester.execute(&mut harness, &instruction); match opcode { COMP_POS2 => { @@ -446,12 +469,14 @@ fn tester_with_random_poseidon2_ops(num_ops: usize) -> VmChipTester { - let actual = tester.read::<{ 2 * CHUNK }>(e, dst); - assert_eq!(hash, actual); + let actual_0 = tester.read::<{ CHUNK }>(e, dst); + let actual_1 = tester.read::<{ CHUNK }>(e, dst + CHUNK); + let actual = [actual_0, actual_1].concat(); + assert_eq!(&hash, &actual[..]); } } } - tester.build().load(chip).finalize() + tester.build().load(harness).finalize() } fn get_engine() -> BabyBearBlake3Engine { @@ -476,34 +501,6 @@ fn verify_batch_chip_simple_50() { tester.test(get_engine).expect("Verification failed"); } -// log_blowup = 3 for poseidon2 chip -fn air_test_with_compress_poseidon2( - poseidon2_max_constraint_degree: usize, - program: Program, -) { - let fri_params = if matches!(std::env::var("OPENVM_FAST_TEST"), Ok(x) if &x == "1") { - FriParameters { - log_blowup: 3, - log_final_poly_len: 0, - num_queries: 2, - proof_of_work_bits: 0, - } - } else { - standard_fri_params_with_100_bits_conjectured_security(3) - }; - let engine = BabyBearPoseidon2Engine::new(fri_params); - - let config = NativeConfig::aggregation(0, poseidon2_max_constraint_degree); - let vm = VirtualMachine::new(engine, config); - - let pk = vm.keygen(); - let result = vm.execute_and_generate(program, vec![]).unwrap(); - let proofs = vm.prove(&pk, result); - for proof in proofs { - verify_single(&vm.engine, &pk.get_vk(), &proof).expect("Verification failed"); - } -} - #[test] fn test_vm_compress_poseidon2_as4() { let mut rng = create_seeded_rng(); @@ -594,6 +591,14 @@ fn test_vm_compress_poseidon2_as4() { let program = Program::from_instructions(&instructions); - air_test_with_compress_poseidon2(3, program.clone()); - air_test_with_compress_poseidon2(7, program.clone()); + air_test( + NativeCpuBuilder, + NativeConfig::aggregation(0, 3), + program.clone(), + ); + air_test( + NativeCpuBuilder, + NativeConfig::aggregation(0, 7), + program.clone(), + ); } diff --git a/extensions/native/circuit/src/poseidon2/trace.rs b/extensions/native/circuit/src/poseidon2/trace.rs deleted file mode 100644 index df8547767f..0000000000 --- a/extensions/native/circuit/src/poseidon2/trace.rs +++ /dev/null @@ -1,485 +0,0 @@ -use std::{borrow::BorrowMut, sync::Arc}; - -use openvm_circuit::system::memory::{MemoryAuxColsFactory, OfflineMemory}; -use openvm_circuit_primitives::utils::next_power_of_two_or_zero; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_native_compiler::Poseidon2Opcode::COMP_POS2; -use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - p3_air::BaseAir, - p3_field::{Field, PrimeField32}, - p3_matrix::dense::RowMajorMatrix, - p3_maybe_rayon::prelude::*, - prover::types::AirProofInput, - AirRef, Chip, ChipUsageGetter, -}; - -use crate::{ - chip::{SimplePoseidonRecord, NUM_INITIAL_READS}, - poseidon2::{ - chip::{ - CellRecord, IncorporateRowRecord, IncorporateSiblingRecord, InsideRowRecord, - NativePoseidon2Chip, VerifyBatchRecord, - }, - columns::{ - InsideRowSpecificCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, - TopLevelSpecificCols, - }, - CHUNK, - }, -}; -impl ChipUsageGetter - for NativePoseidon2Chip -{ - fn air_name(&self) -> String { - "VerifyBatchAir".to_string() - } - - fn current_trace_height(&self) -> usize { - self.height - } - - fn trace_width(&self) -> usize { - NativePoseidon2Cols::::width() - } -} - -impl NativePoseidon2Chip { - fn generate_subair_cols(&self, input: [F; 2 * CHUNK], cols: &mut [F]) { - let inner_trace = self.subchip.generate_trace(vec![input]); - let inner_width = self.air.subair.width(); - cols[..inner_width].copy_from_slice(inner_trace.values.as_slice()); - } - #[allow(clippy::too_many_arguments)] - fn incorporate_sibling_record_to_row( - &self, - record: &IncorporateSiblingRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, - parent: &VerifyBatchRecord, - proof_index: usize, - opened_index: usize, - log_height: usize, - ) { - let &IncorporateSiblingRecord { - read_sibling_is_on_right, - sibling_is_on_right, - p2_input, - } = record; - - let read_sibling_is_on_right = memory.record_by_id(read_sibling_is_on_right); - - self.generate_subair_cols(p2_input, slice); - let cols: &mut NativePoseidon2Cols = slice.borrow_mut(); - cols.incorporate_row = F::ZERO; - cols.incorporate_sibling = F::ONE; - cols.inside_row = F::ZERO; - cols.simple = F::ZERO; - cols.end_inside_row = F::ZERO; - cols.end_top_level = F::ZERO; - cols.start_top_level = F::ZERO; - cols.opened_element_size_inv = parent.opened_element_size_inv(); - cols.very_first_timestamp = F::from_canonical_u32(parent.from_state.timestamp); - cols.start_timestamp = - F::from_canonical_u32(read_sibling_is_on_right.timestamp - NUM_INITIAL_READS as u32); - - let specific: &mut TopLevelSpecificCols = - cols.specific[..TopLevelSpecificCols::::width()].borrow_mut(); - - specific.end_timestamp = - F::from_canonical_usize(read_sibling_is_on_right.timestamp as usize + 1); - cols.initial_opened_index = F::from_canonical_usize(opened_index); - specific.final_opened_index = F::from_canonical_usize(opened_index - 1); - specific.log_height = F::from_canonical_usize(log_height); - specific.opened_length = F::from_canonical_usize(parent.opened_length); - specific.dim_base_pointer = parent.dim_base_pointer; - cols.opened_base_pointer = parent.opened_base_pointer; - specific.index_base_pointer = parent.index_base_pointer; - - specific.proof_index = F::from_canonical_usize(proof_index); - aux_cols_factory.generate_read_aux( - read_sibling_is_on_right, - &mut specific.read_initial_height_or_sibling_is_on_right, - ); - specific.sibling_is_on_right = F::from_bool(sibling_is_on_right); - } - fn correct_last_top_level_row( - &self, - record: &VerifyBatchRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, - ) { - let &VerifyBatchRecord { - from_state, - commit_pointer, - dim_base_pointer_read, - opened_base_pointer_read, - opened_length_read, - index_base_pointer_read, - commit_pointer_read, - commit_read, - .. - } = record; - let instruction = &record.instruction; - let cols: &mut NativePoseidon2Cols = slice.borrow_mut(); - cols.end_top_level = F::ONE; - - let specific: &mut TopLevelSpecificCols = - cols.specific[..TopLevelSpecificCols::::width()].borrow_mut(); - - specific.pc = F::from_canonical_u32(from_state.pc); - specific.dim_register = instruction.a; - specific.opened_register = instruction.b; - specific.opened_length_register = instruction.c; - specific.proof_id = instruction.d; - specific.index_register = instruction.e; - specific.commit_register = instruction.f; - specific.commit_pointer = commit_pointer; - aux_cols_factory.generate_read_aux( - memory.record_by_id(dim_base_pointer_read), - &mut specific.dim_base_pointer_read, - ); - aux_cols_factory.generate_read_aux( - memory.record_by_id(opened_base_pointer_read), - &mut specific.opened_base_pointer_read, - ); - aux_cols_factory.generate_read_aux( - memory.record_by_id(opened_length_read), - &mut specific.opened_length_read, - ); - aux_cols_factory.generate_read_aux( - memory.record_by_id(index_base_pointer_read), - &mut specific.index_base_pointer_read, - ); - aux_cols_factory.generate_read_aux( - memory.record_by_id(commit_pointer_read), - &mut specific.commit_pointer_read, - ); - aux_cols_factory - .generate_read_aux(memory.record_by_id(commit_read), &mut specific.commit_read); - } - #[allow(clippy::too_many_arguments)] - fn incorporate_row_record_to_row( - &self, - record: &IncorporateRowRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, - parent: &VerifyBatchRecord, - proof_index: usize, - log_height: usize, - ) { - let &IncorporateRowRecord { - initial_opened_index, - final_opened_index, - initial_height_read, - final_height_read, - p2_input, - .. - } = record; - - let initial_height_read = memory.record_by_id(initial_height_read); - let final_height_read = memory.record_by_id(final_height_read); - - self.generate_subair_cols(p2_input, slice); - let cols: &mut NativePoseidon2Cols = slice.borrow_mut(); - cols.incorporate_row = F::ONE; - cols.incorporate_sibling = F::ZERO; - cols.inside_row = F::ZERO; - cols.simple = F::ZERO; - cols.end_inside_row = F::ZERO; - cols.end_top_level = F::ZERO; - cols.start_top_level = F::from_bool(proof_index == 0); - cols.opened_element_size_inv = parent.opened_element_size_inv(); - cols.very_first_timestamp = F::from_canonical_u32(parent.from_state.timestamp); - cols.start_timestamp = F::from_canonical_u32( - memory - .record_by_id( - record.chunks[0].cells[0] - .read_row_pointer_and_length - .unwrap(), - ) - .timestamp - - NUM_INITIAL_READS as u32, - ); - let specific: &mut TopLevelSpecificCols = - cols.specific[..TopLevelSpecificCols::::width()].borrow_mut(); - - specific.end_timestamp = F::from_canonical_u32(final_height_read.timestamp + 1); - - cols.initial_opened_index = F::from_canonical_usize(initial_opened_index); - specific.final_opened_index = F::from_canonical_usize(final_opened_index); - specific.log_height = F::from_canonical_usize(log_height); - specific.opened_length = F::from_canonical_usize(parent.opened_length); - specific.dim_base_pointer = parent.dim_base_pointer; - cols.opened_base_pointer = parent.opened_base_pointer; - specific.index_base_pointer = parent.index_base_pointer; - - specific.proof_index = F::from_canonical_usize(proof_index); - aux_cols_factory.generate_read_aux( - initial_height_read, - &mut specific.read_initial_height_or_sibling_is_on_right, - ); - aux_cols_factory.generate_read_aux(final_height_read, &mut specific.read_final_height); - } - #[allow(clippy::too_many_arguments)] - fn inside_row_record_to_row( - &self, - record: &InsideRowRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, - parent: &IncorporateRowRecord, - grandparent: &VerifyBatchRecord, - is_last: bool, - ) { - let InsideRowRecord { cells, p2_input } = record; - - self.generate_subair_cols(*p2_input, slice); - let cols: &mut NativePoseidon2Cols = slice.borrow_mut(); - cols.incorporate_row = F::ZERO; - cols.incorporate_sibling = F::ZERO; - cols.inside_row = F::ONE; - cols.simple = F::ZERO; - cols.end_inside_row = F::from_bool(is_last); - cols.end_top_level = F::ZERO; - cols.opened_element_size_inv = grandparent.opened_element_size_inv(); - cols.very_first_timestamp = F::from_canonical_u32( - memory - .record_by_id( - parent.chunks[0].cells[0] - .read_row_pointer_and_length - .unwrap(), - ) - .timestamp, - ); - cols.start_timestamp = - F::from_canonical_u32(memory.record_by_id(cells[0].read).timestamp - 1); - let specific: &mut InsideRowSpecificCols = - cols.specific[..InsideRowSpecificCols::::width()].borrow_mut(); - - for (record, cell) in cells.iter().zip(specific.cells.iter_mut()) { - let &CellRecord { - read, - opened_index, - read_row_pointer_and_length, - row_pointer, - row_end, - } = record; - aux_cols_factory.generate_read_aux(memory.record_by_id(read), &mut cell.read); - cell.opened_index = F::from_canonical_usize(opened_index); - if let Some(read_row_pointer_and_length) = read_row_pointer_and_length { - aux_cols_factory.generate_read_aux( - memory.record_by_id(read_row_pointer_and_length), - &mut cell.read_row_pointer_and_length, - ); - } - cell.row_pointer = F::from_canonical_usize(row_pointer); - cell.row_end = F::from_canonical_usize(row_end); - cell.is_first_in_row = F::from_bool(read_row_pointer_and_length.is_some()); - } - - for cell in specific.cells.iter_mut().skip(cells.len()) { - cell.opened_index = F::from_canonical_usize(parent.final_opened_index); - } - - cols.is_exhausted = std::array::from_fn(|i| F::from_bool(i + 1 >= cells.len())); - - cols.initial_opened_index = F::from_canonical_usize(parent.initial_opened_index); - cols.opened_base_pointer = grandparent.opened_base_pointer; - } - // returns number of used cells - fn verify_batch_record_to_rows( - &self, - record: &VerifyBatchRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, - ) -> usize { - let width = NativePoseidon2Cols::::width(); - let mut used_cells = 0; - - let mut opened_index = 0; - for (proof_index, top_level) in record.top_level.iter().enumerate() { - let log_height = record.initial_log_height - proof_index; - if let Some(incorporate_row) = &top_level.incorporate_row { - self.incorporate_row_record_to_row( - incorporate_row, - aux_cols_factory, - &mut slice[used_cells..used_cells + width], - memory, - record, - proof_index, - log_height, - ); - opened_index = incorporate_row.final_opened_index + 1; - used_cells += width; - } - if let Some(incorporate_sibling) = &top_level.incorporate_sibling { - self.incorporate_sibling_record_to_row( - incorporate_sibling, - aux_cols_factory, - &mut slice[used_cells..used_cells + width], - memory, - record, - proof_index, - opened_index, - log_height, - ); - used_cells += width; - } - } - self.correct_last_top_level_row( - record, - aux_cols_factory, - &mut slice[used_cells - width..used_cells], - memory, - ); - - for top_level in record.top_level.iter() { - if let Some(incorporate_row) = &top_level.incorporate_row { - for (i, chunk) in incorporate_row.chunks.iter().enumerate() { - self.inside_row_record_to_row( - chunk, - aux_cols_factory, - &mut slice[used_cells..used_cells + width], - memory, - incorporate_row, - record, - i == incorporate_row.chunks.len() - 1, - ); - used_cells += width; - } - } - } - - used_cells - } - fn simple_record_to_row( - &self, - record: &SimplePoseidonRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, - ) { - let &SimplePoseidonRecord { - from_state, - instruction: - Instruction { - opcode, - a: output_register, - b: input_register_1, - c: input_register_2, - .. - }, - read_input_pointer_1, - read_input_pointer_2, - read_output_pointer, - read_data_1, - read_data_2, - write_data_1, - write_data_2, - input_pointer_1, - input_pointer_2, - output_pointer, - p2_input, - } = record; - - let read_input_pointer_1 = memory.record_by_id(read_input_pointer_1); - let read_output_pointer = memory.record_by_id(read_output_pointer); - let read_data_1 = memory.record_by_id(read_data_1); - let read_data_2 = memory.record_by_id(read_data_2); - let write_data_1 = memory.record_by_id(write_data_1); - - self.generate_subair_cols(p2_input, slice); - let cols: &mut NativePoseidon2Cols = slice.borrow_mut(); - cols.incorporate_row = F::ZERO; - cols.incorporate_sibling = F::ZERO; - cols.inside_row = F::ZERO; - cols.simple = F::ONE; - cols.end_inside_row = F::ZERO; - cols.end_top_level = F::ZERO; - cols.is_exhausted = [F::ZERO; CHUNK - 1]; - - cols.start_timestamp = F::from_canonical_u32(from_state.timestamp); - let specific: &mut SimplePoseidonSpecificCols = - cols.specific[..SimplePoseidonSpecificCols::::width()].borrow_mut(); - - specific.pc = F::from_canonical_u32(from_state.pc); - specific.is_compress = F::from_bool(opcode == COMP_POS2.global_opcode()); - specific.output_register = output_register; - specific.input_register_1 = input_register_1; - specific.input_register_2 = input_register_2; - specific.output_pointer = output_pointer; - specific.input_pointer_1 = input_pointer_1; - specific.input_pointer_2 = input_pointer_2; - aux_cols_factory.generate_read_aux(read_output_pointer, &mut specific.read_output_pointer); - aux_cols_factory - .generate_read_aux(read_input_pointer_1, &mut specific.read_input_pointer_1); - aux_cols_factory.generate_read_aux(read_data_1, &mut specific.read_data_1); - aux_cols_factory.generate_read_aux(read_data_2, &mut specific.read_data_2); - aux_cols_factory.generate_write_aux(write_data_1, &mut specific.write_data_1); - - if opcode == COMP_POS2.global_opcode() { - let read_input_pointer_2 = memory.record_by_id(read_input_pointer_2.unwrap()); - aux_cols_factory - .generate_read_aux(read_input_pointer_2, &mut specific.read_input_pointer_2); - } else { - let write_data_2 = memory.record_by_id(write_data_2.unwrap()); - aux_cols_factory.generate_write_aux(write_data_2, &mut specific.write_data_2); - } - } - - fn generate_trace(self) -> RowMajorMatrix { - let width = self.trace_width(); - let height = next_power_of_two_or_zero(self.height); - let mut flat_trace = F::zero_vec(width * height); - - let memory = self.offline_memory.lock().unwrap(); - - let aux_cols_factory = memory.aux_cols_factory(); - - let mut used_cells = 0; - for record in self.record_set.verify_batch_records.iter() { - used_cells += self.verify_batch_record_to_rows( - record, - &aux_cols_factory, - &mut flat_trace[used_cells..], - &memory, - ); - } - for record in self.record_set.simple_permute_records.iter() { - self.simple_record_to_row( - record, - &aux_cols_factory, - &mut flat_trace[used_cells..used_cells + width], - &memory, - ); - used_cells += width; - } - // poseidon2 constraints are always checked - // following can be optimized to only hash [0; _] once - flat_trace[used_cells..] - .par_chunks_mut(width) - .for_each(|row| { - self.generate_subair_cols([F::ZERO; 2 * CHUNK], row); - }); - - RowMajorMatrix::new(flat_trace, width) - } -} - -impl Chip - for NativePoseidon2Chip, SBOX_REGISTERS> -where - Val: PrimeField32, -{ - fn air(&self) -> AirRef { - Arc::new(self.air.clone()) - } - fn generate_air_proof_input(self) -> AirProofInput { - AirProofInput::simple_no_pis(self.generate_trace()) - } -} diff --git a/extensions/native/circuit/src/utils.rs b/extensions/native/circuit/src/utils.rs index 2815427336..25b04b356e 100644 --- a/extensions/native/circuit/src/utils.rs +++ b/extensions/native/circuit/src/utils.rs @@ -1,19 +1,162 @@ -use openvm_circuit::arch::{Streams, SystemConfig, VmExecutor}; -use openvm_instructions::program::Program; -use openvm_stark_sdk::p3_baby_bear::BabyBear; +pub(crate) const CASTF_MAX_BITS: usize = 30; -use crate::{Native, NativeConfig}; +pub(crate) const fn const_max(a: usize, b: usize) -> usize { + [a, b][(a < b) as usize] +} -pub fn execute_program(program: Program, input_stream: impl Into>) { - let system_config = SystemConfig::default() - .with_public_values(4) - .with_max_segment_len((1 << 25) - 100); - let config = NativeConfig::new(system_config, Native); - let executor = VmExecutor::::new(config); +/// Testing framework +#[cfg(any(test, feature = "test-utils"))] +pub mod test_utils { + use std::array; - executor.execute(program, input_stream).unwrap(); -} + use openvm_circuit::{ + arch::{ + execution_mode::metered::Segment, + testing::{memory::gen_pointer, VmChipTestBuilder}, + MatrixRecordArena, PreflightExecutionOutput, Streams, VirtualMachine, + VirtualMachineError, VmBuilder, VmState, + }, + utils::test_system_config, + }; + use openvm_instructions::{ + exe::VmExe, + program::Program, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + }; + use openvm_native_compiler::conversion::AS; + use openvm_stark_backend::{ + config::Domain, p3_commit::PolynomialSpace, p3_field::PrimeField32, + }; + use openvm_stark_sdk::{ + config::{baby_bear_poseidon2::BabyBearPoseidon2Engine, setup_tracing, FriParameters}, + engine::StarkFriEngine, + p3_baby_bear::BabyBear, + }; + use rand::{distributions::Standard, prelude::Distribution, rngs::StdRng, Rng}; -pub(crate) const fn const_max(a: usize, b: usize) -> usize { - [a, b][(a < b) as usize] + use crate::{NativeConfig, NativeCpuBuilder, Rv32WithKernelsConfig}; + + // If immediate, returns (value, AS::Immediate). Otherwise, writes to native memory and returns + // (ptr, AS::Native). If is_imm is None, randomizes it. + pub fn write_native_or_imm( + tester: &mut VmChipTestBuilder, + rng: &mut StdRng, + value: F, + is_imm: Option, + ) -> (F, usize) { + let is_imm = is_imm.unwrap_or(rng.gen_bool(0.5)); + if is_imm { + (value, AS::Immediate as usize) + } else { + let ptr = gen_pointer(rng, 1); + tester.write::<1>(AS::Native as usize, ptr, [value]); + (F::from_canonical_usize(ptr), AS::Native as usize) + } + } + + // Writes value to native memory and returns a pointer to the first element together with the + // value If `value` is None, randomizes it. + pub fn write_native_array( + tester: &mut VmChipTestBuilder, + rng: &mut StdRng, + value: Option<[F; N]>, + ) -> ([F; N], usize) + where + Standard: Distribution, // Needed for `rng.gen` + { + let value = value.unwrap_or(array::from_fn(|_| rng.gen())); + let ptr = gen_pointer(rng, N); + tester.write::(AS::Native as usize, ptr, value); + (value, ptr) + } + + // Besides taking in system_config, this also returns Result and the full + // (PreflightExecutionOutput, VirtualMachine) for more advanced testing needs. + #[allow(clippy::type_complexity)] + pub fn execute_program_with_config( + program: Program, + input_stream: impl Into>, + builder: VB, + config: VB::VmConfig, + ) -> Result< + ( + PreflightExecutionOutput>, + VirtualMachine, + ), + VirtualMachineError, + > + where + E: StarkFriEngine, + Domain: PolynomialSpace, + VB: VmBuilder>, + { + setup_tracing(); + assert!(!config.as_ref().continuation_enabled); + let input = input_stream.into(); + + let engine = E::new(FriParameters::new_for_testing(1)); + let (vm, _) = VirtualMachine::new_with_keygen(engine, builder, config)?; + let ctx = vm.build_metered_ctx(); + let executor_idx_to_air_idx = vm.executor_idx_to_air_idx(); + let exe = VmExe::new(program); + let interpreter = vm + .executor() + .metered_instance(&exe, &executor_idx_to_air_idx)?; + let (mut segments, _) = interpreter.execute_metered(input.clone(), ctx)?; + assert_eq!(segments.len(), 1, "test only supports one segment"); + let Segment { + instret_start, + num_insns, + trace_heights, + } = segments.pop().unwrap(); + assert_eq!(instret_start, 0); + let state = vm.create_initial_state(&exe, input); + let output = vm.execute_preflight(&exe, state, None, &trace_heights)?; + assert_eq!( + output.to_state.instret, num_insns, + "metered execution insn count doesn't match preflight execution" + ); + Ok((output, vm)) + } + + pub fn execute_program( + program: Program, + input_stream: impl Into>, + ) -> VmState { + let mut config = test_native_config(); + config.system.num_public_values = 4; + // we set max segment len large so it doesn't segment + let (output, _) = execute_program_with_config::( + program, + input_stream, + NativeCpuBuilder, + config, + ) + .unwrap(); + output.to_state + } + + pub fn test_native_config() -> NativeConfig { + let mut system = test_system_config(); + system.memory_config.addr_spaces[RV32_REGISTER_AS as usize].num_cells = 0; + system.memory_config.addr_spaces[RV32_MEMORY_AS as usize].num_cells = 0; + NativeConfig { + system, + native: Default::default(), + } + } + + pub fn test_native_continuations_config() -> NativeConfig { + NativeConfig { + system: test_system_config().with_continuations(), + native: Default::default(), + } + } + + pub fn test_rv32_with_kernels_config() -> Rv32WithKernelsConfig { + Rv32WithKernelsConfig { + system: test_system_config().with_continuations(), + ..Default::default() + } + } } diff --git a/extensions/native/compiler/Cargo.toml b/extensions/native/compiler/Cargo.toml index cb41c17f63..83938ed4ab 100644 --- a/extensions/native/compiler/Cargo.toml +++ b/extensions/native/compiler/Cargo.toml @@ -34,7 +34,7 @@ strum = { workspace = true } [dev-dependencies] p3-symmetric = { workspace = true } openvm-circuit = { workspace = true, features = ["test-utils"] } -openvm-native-circuit = { workspace = true } +openvm-native-circuit = { workspace = true, features = ["test-utils"]} openvm-stark-sdk = { workspace = true } rand.workspace = true @@ -42,4 +42,4 @@ rand.workspace = true default = ["parallel", "halo2-compiler"] halo2-compiler = ["dep:snark-verifier-sdk"] parallel = ["openvm-circuit/parallel"] -bench-metrics = ["dep:metrics", "openvm-circuit/bench-metrics"] +metrics = ["dep:metrics", "openvm-circuit/metrics"] diff --git a/extensions/native/compiler/src/constraints/halo2/compiler.rs b/extensions/native/compiler/src/constraints/halo2/compiler.rs index ce108addaa..fd75d526d2 100644 --- a/extensions/native/compiler/src/constraints/halo2/compiler.rs +++ b/extensions/native/compiler/src/constraints/halo2/compiler.rs @@ -7,7 +7,7 @@ use std::{ }; use itertools::Itertools; -#[cfg(feature = "bench-metrics")] +#[cfg(feature = "metrics")] use openvm_circuit::metrics::cycle_tracker::CycleTracker; use openvm_stark_backend::p3_field::{ExtensionField, Field, FieldAlgebra, PrimeField}; use openvm_stark_sdk::{p3_baby_bear::BabyBear, p3_bn254_fr::Bn254Fr}; @@ -135,7 +135,7 @@ impl Halo2ConstraintCompiler { where C: Config, { - #[cfg(feature = "bench-metrics")] + #[cfg(feature = "metrics")] let mut cell_tracker = CycleTracker::new(); let range = Arc::new(halo2_state.builder.range_chip()); let f_chip = Arc::new(BabyBearChip::new(range.clone())); @@ -149,10 +149,10 @@ impl Halo2ConstraintCompiler { let mut felts = HashMap::::new(); let mut exts = HashMap::::new(); - #[cfg(feature = "bench-metrics")] + #[cfg(feature = "metrics")] let mut old_stats = stats_snapshot(ctx, range.clone()); for (instruction, backtrace) in operations { - #[cfg(feature = "bench-metrics")] + #[cfg(feature = "metrics")] if self.profiling { old_stats = stats_snapshot(ctx, range.clone()); } @@ -492,11 +492,11 @@ impl Halo2ConstraintCompiler { range.check_less_than(ctx, vars[&a.0], vars[&b.0], C::F::bits()); } DslIr::CycleTrackerStart(_name) => { - #[cfg(feature = "bench-metrics")] + #[cfg(feature = "metrics")] cell_tracker.start(_name); } DslIr::CycleTrackerEnd(_name) => { - #[cfg(feature = "bench-metrics")] + #[cfg(feature = "metrics")] cell_tracker.end(_name); } DslIr::CircuitPublish(val, index) => { @@ -512,7 +512,7 @@ impl Halo2ConstraintCompiler { } res.unwrap(); } - #[cfg(feature = "bench-metrics")] + #[cfg(feature = "metrics")] if self.profiling { let mut new_stats = stats_snapshot(ctx, range.clone()); new_stats.diff(&old_stats); @@ -538,7 +538,7 @@ pub fn convert_efr>(a: &EF) -> Vec { } // Unfortunately `builder.statistics()` cannot be called when `ctx` exists. -#[allow(dead_code)] // used only in bench-metrics +#[allow(dead_code)] // used only in metrics fn stats_snapshot(ctx: &Context, range_chip: Arc>) -> Halo2Stats { Halo2Stats { total_gate_cell: ctx.advice.len(), diff --git a/extensions/native/compiler/src/constraints/halo2/stats.rs b/extensions/native/compiler/src/constraints/halo2/stats.rs index 0d5192ec82..c18f64d5cb 100644 --- a/extensions/native/compiler/src/constraints/halo2/stats.rs +++ b/extensions/native/compiler/src/constraints/halo2/stats.rs @@ -14,7 +14,7 @@ impl Halo2Stats { } } -#[cfg(feature = "bench-metrics")] +#[cfg(feature = "metrics")] mod emit { use metrics::counter; diff --git a/extensions/native/compiler/src/conversion/mod.rs b/extensions/native/compiler/src/conversion/mod.rs index 9c3fc8d752..af4e5080fb 100644 --- a/extensions/native/compiler/src/conversion/mod.rs +++ b/extensions/native/compiler/src/conversion/mod.rs @@ -565,7 +565,7 @@ pub fn convert_program>( } } - let mut result = Program::new_empty(DEFAULT_PC_STEP, 0); + let mut result = Program::new_empty(0); result.push_instruction_and_debug_info(init_register_0, init_debug_info); for block in program.blocks.iter() { for (instruction, debug_info) in block.0.iter().zip(block.1.iter()) { diff --git a/extensions/native/compiler/tests/arithmetic.rs b/extensions/native/compiler/tests/arithmetic.rs index cd68fab563..6b50566a5d 100644 --- a/extensions/native/compiler/tests/arithmetic.rs +++ b/extensions/native/compiler/tests/arithmetic.rs @@ -1,4 +1,5 @@ use openvm_circuit::arch::{ExecutionError, VmExecutor}; +use openvm_instructions::exe::VmExe; use openvm_native_circuit::{execute_program, NativeConfig}; use openvm_native_compiler::{ asm::{AsmBuilder, AsmCompiler, AsmConfig}, @@ -391,8 +392,15 @@ fn assert_failed_assertion( builder: Builder>>, ) { let program = builder.compile_isa(); - - let executor = VmExecutor::::new(NativeConfig::aggregation(4, 3)); - let result = executor.execute(program, vec![]); - assert!(matches!(result, Err(ExecutionError::Fail { .. }))); + let exe = VmExe::new(program); + + let config = NativeConfig::aggregation(4, 3); + let executor = VmExecutor::new(config).unwrap(); + let instance = executor.instance(&exe).unwrap(); + let result = instance.execute(vec![], None); + assert!( + matches!(result, Err(ExecutionError::Fail { .. })), + "Unexpected result: {:?}", + result.err() + ); } diff --git a/extensions/native/compiler/tests/cycle_tracker.rs b/extensions/native/compiler/tests/cycle_tracker.rs index 3561dfd2ec..d8bcb59bce 100644 --- a/extensions/native/compiler/tests/cycle_tracker.rs +++ b/extensions/native/compiler/tests/cycle_tracker.rs @@ -1,3 +1,5 @@ +use std::ops::Deref; + use openvm_native_circuit::execute_program; use openvm_native_compiler::{asm::AsmBuilder, conversion::CompilerOptions, ir::Var}; use openvm_stark_backend::p3_field::{extension::BinomialExtensionField, FieldAlgebra}; @@ -43,7 +45,7 @@ fn test_cycle_tracker() { ..Default::default() }); - for (i, debug_info) in program.debug_infos().iter().enumerate() { + for (i, debug_info) in program.debug_infos().deref().iter().enumerate() { println!("debug_info {}: {:?}", i, debug_info); } diff --git a/extensions/native/circuit/examples/fibonacci.rs b/extensions/native/compiler/tests/fibonacci.rs similarity index 97% rename from extensions/native/circuit/examples/fibonacci.rs rename to extensions/native/compiler/tests/fibonacci.rs index aca5e2d6c5..8dfb29a835 100644 --- a/extensions/native/circuit/examples/fibonacci.rs +++ b/extensions/native/compiler/tests/fibonacci.rs @@ -47,6 +47,6 @@ fn main() { builder.halt(); let program = builder.compile_isa(); - println!("{}", program); + println!("{program}"); execute_program(program, vec![]); } diff --git a/extensions/native/compiler/tests/public_values.rs b/extensions/native/compiler/tests/public_values.rs index 7c7abe3bc6..f3ad166d77 100644 --- a/extensions/native/compiler/tests/public_values.rs +++ b/extensions/native/compiler/tests/public_values.rs @@ -1,8 +1,11 @@ -use openvm_circuit::arch::{SingleSegmentVmExecutor, SystemConfig}; -use openvm_native_circuit::{execute_program, Native, NativeConfig}; +use openvm_circuit::{arch::PUBLIC_VALUES_AIR_ID, utils::air_test_impl}; +use openvm_native_circuit::{execute_program_with_config, test_native_config, NativeCpuBuilder}; use openvm_native_compiler::{asm::AsmBuilder, prelude::*}; use openvm_stark_backend::p3_field::{extension::BinomialExtensionField, FieldAlgebra}; -use openvm_stark_sdk::p3_baby_bear::BabyBear; +use openvm_stark_sdk::{ + config::{baby_bear_poseidon2::BabyBearPoseidon2Engine, FriParameters}, + p3_baby_bear::BabyBear, +}; type F = BabyBear; type EF = BinomialExtensionField; @@ -28,21 +31,26 @@ fn test_compiler_public_values() { } let program = builder.compile_isa(); - let executor = SingleSegmentVmExecutor::new(NativeConfig::new( - SystemConfig::default().with_public_values(2), - Native, - )); - - let exe_result = executor - .execute_and_compute_heights(program, vec![]) - .unwrap(); + let mut config = test_native_config(); + config.system.num_public_values = 2; + // This is to justify using log_blowup=1 + assert!(config.as_ref().max_constraint_degree <= 3); + let fri_params = FriParameters::new_for_testing(1); + let (_, mut vdata) = air_test_impl::( + fri_params, + NativeCpuBuilder, + config, + program, + vec![], + 1, + true, + ) + .unwrap(); + assert_eq!(vdata.len(), 1); + let proof = vdata.pop().unwrap().data.proof; assert_eq!( - exe_result - .public_values - .into_iter() - .flatten() - .collect::>(), - vec![public_value_0, public_value_1] + &proof.get_public_values()[PUBLIC_VALUES_AIR_ID], + &[public_value_0, public_value_1] ); } @@ -66,5 +74,13 @@ fn test_compiler_public_values_no_initial() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + let (output, _) = execute_program_with_config::( + program, + vec![], + NativeCpuBuilder, + test_native_config(), + ) + .unwrap(); + assert_eq!(output.system_records.public_values[0], public_value_0); + assert_eq!(output.system_records.public_values[1], public_value_1); } diff --git a/extensions/native/recursion/Cargo.toml b/extensions/native/recursion/Cargo.toml index c799671a55..4a2600915f 100644 --- a/extensions/native/recursion/Cargo.toml +++ b/extensions/native/recursion/Cargo.toml @@ -8,7 +8,7 @@ repository.workspace = true [dependencies] openvm-stark-backend = { workspace = true } -openvm-native-circuit = { workspace = true } +openvm-native-circuit = { workspace = true, features = ["test-utils"] } openvm-native-compiler = { workspace = true } openvm-native-compiler-derive = { workspace = true } openvm-stark-sdk = { workspace = true } @@ -49,10 +49,10 @@ evm-verify = [ "snark-verifier-sdk/revm", ] # evm-verify needs REVM to simulate EVM contract verification test-utils = ["openvm-circuit/test-utils"] -bench-metrics = [ +metrics = [ "dep:metrics", - "openvm-circuit/bench-metrics", - "openvm-native-compiler/bench-metrics", + "openvm-circuit/metrics", + "openvm-native-compiler/metrics", ] parallel = ["openvm-stark-backend/parallel"] mimalloc = ["openvm-stark-backend/mimalloc"] diff --git a/extensions/native/recursion/src/fri/two_adic_pcs.rs b/extensions/native/recursion/src/fri/two_adic_pcs.rs index 676da7493f..3e66e05e61 100644 --- a/extensions/native/recursion/src/fri/two_adic_pcs.rs +++ b/extensions/native/recursion/src/fri/two_adic_pcs.rs @@ -627,6 +627,7 @@ pub mod tests { }; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, + engine::StarkEngine, p3_challenger::{CanObserve, FieldChallenger}, p3_commit::{Pcs, TwoAdicMultiplicativeCoset}, p3_matrix::dense::RowMajorMatrix, @@ -662,8 +663,8 @@ pub mod tests { let mut rng = &mut OsRng; let log_degrees = &[nb_log2_rows]; let engine = default_engine(); - let pcs = engine.config.pcs(); - let perm = engine.perm; + let pcs = engine.config().pcs(); + let perm = engine.perm.clone(); // Generate proof. let domains_and_polys = log_degrees diff --git a/extensions/native/recursion/src/halo2/mod.rs b/extensions/native/recursion/src/halo2/mod.rs index b53a298eb4..2046af0993 100644 --- a/extensions/native/recursion/src/halo2/mod.rs +++ b/extensions/native/recursion/src/halo2/mod.rs @@ -116,7 +116,7 @@ impl Halo2Prover { state.load_witness(witness); let backend = Halo2ConstraintCompiler::::new(dsl_operations.num_public_values); - #[cfg(feature = "bench-metrics")] + #[cfg(feature = "metrics")] let backend = if profiling { backend.with_profiling() } else { @@ -174,10 +174,10 @@ impl Halo2Prover { // // pk // }; - #[cfg(feature = "bench-metrics")] + #[cfg(feature = "metrics")] let start = std::time::Instant::now(); let pk = keygen_pk2(params, &builder, false).unwrap(); - #[cfg(feature = "bench-metrics")] + #[cfg(feature = "metrics")] metrics::gauge!("halo2_keygen_time_ms").set(start.elapsed().as_millis() as f64); let break_points = builder.break_points(); @@ -212,13 +212,13 @@ impl Halo2Prover { profiling: bool, ) -> Snark { let k = config_params.k; - #[cfg(feature = "bench-metrics")] + #[cfg(feature = "metrics")] let start = std::time::Instant::now(); let builder = Self::builder(CircuitBuilderStage::Prover, k) .use_params(config_params) .use_break_points(break_points); let builder = Self::populate(builder, dsl_operations, witness, profiling); - #[cfg(feature = "bench-metrics")] + #[cfg(feature = "metrics")] { let stats = builder.statistics(); let total_advices: usize = stats.gate.total_advice_per_phase.into_iter().sum(); @@ -228,7 +228,7 @@ impl Halo2Prover { } let snark = gen_snark_shplonk(params, pk, builder, None::<&str>); - #[cfg(feature = "bench-metrics")] + #[cfg(feature = "metrics")] metrics::gauge!("total_proof_time_ms").set(start.elapsed().as_millis() as f64); snark diff --git a/extensions/native/recursion/src/halo2/wrapper.rs b/extensions/native/recursion/src/halo2/wrapper.rs index 958c502a86..77d9978c38 100644 --- a/extensions/native/recursion/src/halo2/wrapper.rs +++ b/extensions/native/recursion/src/halo2/wrapper.rs @@ -57,7 +57,7 @@ impl Halo2WrapperProvingKey { } pub fn keygen(params: &Halo2Params, dummy_snark: Snark) -> Self { let k = params.k(); - #[cfg(feature = "bench-metrics")] + #[cfg(feature = "metrics")] let start = std::time::Instant::now(); let mut circuit = generate_wrapper_circuit_object(CircuitBuilderStage::Keygen, k as usize, dummy_snark); @@ -67,11 +67,11 @@ impl Halo2WrapperProvingKey { "Wrapper circuit num advice: {:?}", config_params.num_advice_per_phase ); - #[cfg(feature = "bench-metrics")] + #[cfg(feature = "metrics")] emit_wrapper_circuit_metrics(&circuit); let pk = keygen_pk2(params, &circuit, false).unwrap(); let num_pvs = circuit.instances().iter().map(|x| x.len()).collect_vec(); - #[cfg(feature = "bench-metrics")] + #[cfg(feature = "metrics")] metrics::gauge!("halo2_keygen_time_ms").set(start.elapsed().as_millis() as f64); Self { pinning: Halo2ProvingPinning { @@ -112,7 +112,7 @@ impl Halo2WrapperProvingKey { } #[cfg(feature = "evm-prove")] pub fn prove_for_evm(&self, params: &Halo2Params, snark_to_verify: Snark) -> RawEvmProof { - #[cfg(feature = "bench-metrics")] + #[cfg(feature = "metrics")] let start = std::time::Instant::now(); let k = self.pinning.metadata.config_params.k; let prover_circuit = self.generate_circuit_object_for_proving(k, snark_to_verify); @@ -124,7 +124,7 @@ impl Halo2WrapperProvingKey { prover_circuit, pvs.clone(), ); - #[cfg(feature = "bench-metrics")] + #[cfg(feature = "metrics")] metrics::gauge!("total_proof_time_ms").set(start.elapsed().as_millis() as f64); RawEvmProof { @@ -212,7 +212,7 @@ fn generate_wrapper_circuit_object( circuit } -#[cfg(feature = "bench-metrics")] +#[cfg(feature = "metrics")] fn emit_wrapper_circuit_metrics(agg_circuit: &AggregationCircuit) { let stats = agg_circuit.builder.statistics(); let total_advices: usize = stats.gate.total_advice_per_phase.into_iter().sum(); diff --git a/extensions/native/recursion/src/testing_utils.rs b/extensions/native/recursion/src/testing_utils.rs index 380b2aa9a3..62a8a25b27 100644 --- a/extensions/native/recursion/src/testing_utils.rs +++ b/extensions/native/recursion/src/testing_utils.rs @@ -1,17 +1,7 @@ -use inner::build_verification_program; -use openvm_circuit::{arch::instructions::program::Program, utils::execute_and_prove_program}; -use openvm_native_circuit::NativeConfig; -use openvm_native_compiler::conversion::CompilerOptions; -use openvm_stark_backend::{ - config::{Com, Domain, PcsProof, PcsProverData, StarkGenericConfig}, - engine::VerificationData, - p3_commit::PolynomialSpace, - verifier::VerificationError, -}; +use openvm_circuit::{arch::instructions::program::Program, utils::air_test_impl}; +use openvm_stark_backend::engine::VerificationData; use openvm_stark_sdk::{ - config::baby_bear_poseidon2::BabyBearPoseidon2Config, - engine::{StarkFriEngine, VerificationDataWithFriParams}, - p3_baby_bear::BabyBear, + config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear, utils::ProofInputForTest, }; @@ -20,7 +10,7 @@ use crate::hints::InnerVal; type InnerSC = BabyBearPoseidon2Config; pub mod inner { - use openvm_native_circuit::NativeConfig; + use openvm_native_circuit::{test_native_config, NativeCpuBuilder}; use openvm_native_compiler::conversion::CompilerOptions; use openvm_stark_sdk::{ config::{ @@ -42,12 +32,12 @@ pub mod inner { let advice = new_from_inner_multi_vk(&vk); cfg_if::cfg_if! { - if #[cfg(feature = "bench-metrics")] { + if #[cfg(feature = "metrics")] { let start = std::time::Instant::now(); } } let program = VerifierProgram::build_with_options(advice, &fri_params, compiler_options); - #[cfg(feature = "bench-metrics")] + #[cfg(feature = "metrics")] metrics::gauge!("verify_program_compile_ms").set(start.elapsed().as_millis() as f64); let mut input_stream = Vec::new(); @@ -72,36 +62,17 @@ pub mod inner { )) .unwrap(); - recursive_stark_test( - vparams, - CompilerOptions::default(), - NativeConfig::aggregation(4, 7), - &BabyBearPoseidon2Engine::new(fri_params), + let compiler_options = CompilerOptions::default(); + let (program, witness_stream) = build_verification_program(vparams, compiler_options); + air_test_impl::( + fri_params, + NativeCpuBuilder, + test_native_config(), + program, + witness_stream, + 1, + true, ) .unwrap(); } } - -/// 1. Builds the recursive verification program to verify `vparams` -/// 2. Execute and proves the program in VM with `AggSC` config using `engine`. -/// -/// The `vparams` must be from the BabyBearPoseidon2 stark config for the recursion -/// program to work at the moment. -#[allow(clippy::type_complexity)] -pub fn recursive_stark_test>( - vparams: VerificationDataWithFriParams, - compiler_options: CompilerOptions, - vm_config: NativeConfig, - engine: &E, -) -> Result, VerificationError> -where - Domain: PolynomialSpace, - Domain: Send + Sync, - PcsProverData: Send + Sync, - Com: Send + Sync, - PcsProof: Send + Sync, -{ - let (program, witness_stream) = build_verification_program(vparams, compiler_options); - - execute_and_prove_program(program, witness_stream, vm_config, engine) -} diff --git a/extensions/native/recursion/src/tests.rs b/extensions/native/recursion/src/tests.rs index 4077ee6f1d..627304e866 100644 --- a/extensions/native/recursion/src/tests.rs +++ b/extensions/native/recursion/src/tests.rs @@ -1,18 +1,19 @@ -use std::{panic::catch_unwind, sync::Arc}; +use std::sync::Arc; -use openvm_circuit::utils::gen_vm_program_test_proof_input; -use openvm_native_circuit::NativeConfig; +use openvm_native_circuit::{execute_program_with_config, test_native_config, NativeCpuBuilder}; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, interaction::BusIndex, p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix, - prover::types::AirProofInput, + prover::{ + hal::DeviceDataTransporter, + types::{AirProvingContext, ProvingContext}, + }, utils::disable_debug_builder, Chip, }; use openvm_stark_sdk::{ - collect_airs_and_inputs, config::{ baby_bear_poseidon2::{BabyBearPoseidon2Config, BabyBearPoseidon2Engine}, FriParameters, @@ -37,8 +38,12 @@ where Val: PrimeField32, { let fib_chip = FibonacciChip::new(0, 1, n); - let (airs, per_air) = collect_airs_and_inputs!(fib_chip); - ProofInputForTest { airs, per_air } + let airs = vec![fib_chip.air()]; + let air_ctx = fib_chip.generate_proving_ctx(()); + ProofInputForTest { + airs, + per_air: vec![air_ctx], + } } pub fn interaction_test_proof_input() -> ProofInputForTest @@ -62,7 +67,12 @@ where fields: vec![vec![1, 1], vec![1, 2], vec![3, 4], vec![9999, 0]], }); - let (airs, per_air) = collect_airs_and_inputs!(send_chip1, send_chip2, recv_chip); + let airs = vec![send_chip1.air(), send_chip2.air(), recv_chip.air()]; + let per_air = vec![ + send_chip1.generate_proving_ctx(()), + send_chip2.generate_proving_ctx(()), + recv_chip.generate_proving_ctx(()), + ]; ProofInputForTest { airs, per_air } } @@ -84,12 +94,12 @@ where receiver_air.field_width() + 1, ); - let sender_air_proof_input = AirProofInput::simple_no_pis(sender_trace); - let receiver_air_proof_input = AirProofInput::simple_no_pis(receiver_trace); + let sender_ctx = AirProvingContext::simple_no_pis(Arc::new(sender_trace)); + let receiver_ctx = AirProvingContext::simple_no_pis(Arc::new(receiver_trace)); ProofInputForTest { airs: vec![Arc::new(sender_air), Arc::new(receiver_air)], - per_air: vec![sender_air_proof_input, receiver_air_proof_input], + per_air: vec![sender_ctx, receiver_ctx], } } @@ -133,12 +143,12 @@ fn test_unordered() { #[test] fn test_optional_air() { - use openvm_stark_backend::{engine::StarkEngine, prover::types::ProofInput, Chip}; + use openvm_stark_backend::engine::StarkEngine; let fri_params = FriParameters::new_for_testing(3); let engine = BabyBearPoseidon2Engine::new(fri_params); let fib_chip = FibonacciChip::new(0, 1, 8); let send_chip1 = DummyInteractionChip::new_without_partition(1, true, 0); - let send_chip2 = DummyInteractionChip::new_with_partition(engine.config(), 1, true, 0); + let send_chip2 = DummyInteractionChip::new_with_partition(engine.device().clone(), 1, true, 0); let recv_chip1 = DummyInteractionChip::new_without_partition(1, false, 0); let mut keygen_builder = engine.keygen_builder(); let fib_chip_id = keygen_builder.add_air(fib_chip.air()); @@ -148,7 +158,7 @@ fn test_optional_air() { let pk = keygen_builder.generate_pk(); let m_advice = new_from_inner_multi_vk(&pk.get_vk()); - let vm_config = NativeConfig::aggregation(4, 7); + let config = test_native_config(); let program = VerifierProgram::build(m_advice, &fri_params); // Case 1: All AIRs are present. @@ -169,26 +179,27 @@ fn test_optional_air() { count: vec![2, 4, 12], fields: vec![vec![1], vec![2], vec![3]], }); - let proof = engine.prove( - &pk, - ProofInput { - per_air: vec![ - fib_chip.generate_air_proof_input_with_id(fib_chip_id), - send_chip1.generate_air_proof_input_with_id(send_chip1_id), - send_chip2.generate_air_proof_input_with_id(send_chip2_id), - recv_chip1.generate_air_proof_input_with_id(recv_chip1_id), - ], - }, - ); - engine - .verify(&pk.get_vk(), &proof) - .expect("Verification failed"); + let proof = engine + .prove_then_verify( + &pk, + ProvingContext { + per_air: vec![ + (fib_chip_id, fib_chip.generate_proving_ctx(())), + (send_chip1_id, send_chip1.generate_proving_ctx(())), + (send_chip2_id, send_chip2.generate_proving_ctx(())), + (recv_chip1_id, recv_chip1.generate_proving_ctx(())), + ], + }, + ) + .unwrap(); // The VM program will panic when the program cannot verify the proof. - gen_vm_program_test_proof_input::( + assert!(execute_program_with_config::( program.clone(), proof.write(), - vm_config.clone(), - ); + NativeCpuBuilder, + config.clone() + ) + .is_ok()); } // Case 2: The second AIR is not presented. { @@ -202,24 +213,25 @@ fn test_optional_air() { count: vec![1, 2, 4], fields: vec![vec![1], vec![2], vec![3]], }); - let proof = engine.prove( - &pk, - ProofInput { - per_air: vec![ - send_chip1.generate_air_proof_input_with_id(send_chip1_id), - recv_chip1.generate_air_proof_input_with_id(recv_chip1_id), - ], - }, - ); - engine - .verify(&pk.get_vk(), &proof) - .expect("Verification failed"); + let proof = engine + .prove_then_verify( + &pk, + ProvingContext { + per_air: vec![ + (send_chip1_id, send_chip1.generate_proving_ctx(())), + (recv_chip1_id, recv_chip1.generate_proving_ctx(())), + ], + }, + ) + .unwrap(); // The VM program will panic when the program cannot verify the proof. - gen_vm_program_test_proof_input::( + assert!(execute_program_with_config::( program.clone(), proof.write(), - vm_config.clone(), - ); + NativeCpuBuilder, + config.clone() + ) + .is_ok()); } // Case 3: Negative - unbalanced interactions. { @@ -229,21 +241,21 @@ fn test_optional_air() { count: vec![1, 2, 4], fields: vec![vec![1], vec![2], vec![3]], }); + let d_pk = engine.device().transport_pk_to_device(&pk); let proof = engine.prove( - &pk, - ProofInput { - per_air: vec![recv_chip1.generate_air_proof_input_with_id(recv_chip1_id)], + &d_pk, + ProvingContext { + per_air: vec![(recv_chip1_id, recv_chip1.generate_proving_ctx(()))], }, ); assert!(engine.verify(&pk.get_vk(), &proof).is_err()); // The VM program should panic when the proof cannot be verified. - let unwind_res = catch_unwind(|| { - gen_vm_program_test_proof_input::( - program.clone(), - proof.write(), - vm_config, - ) - }); - assert!(unwind_res.is_err()); + assert!(execute_program_with_config::( + program.clone(), + proof.write(), + NativeCpuBuilder, + config.clone() + ) + .is_err()); } } diff --git a/extensions/native/recursion/tests/recursion.rs b/extensions/native/recursion/tests/recursion.rs index 8f354f3316..bec363e367 100644 --- a/extensions/native/recursion/tests/recursion.rs +++ b/extensions/native/recursion/tests/recursion.rs @@ -1,13 +1,28 @@ -use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmConfig, VmExecutor}; -use openvm_native_circuit::{Native, NativeConfig}; +use itertools::Itertools; +use openvm_circuit::arch::{ + instructions::program::Program, MatrixRecordArena, PreflightExecutionOutput, VmBuilder, + VmCircuitConfig, +}; +use openvm_native_circuit::{ + execute_program_with_config, test_native_config, NativeConfig, NativeCpuBuilder, +}; use openvm_native_compiler::{asm::AsmBuilder, ir::Felt}; use openvm_native_recursion::testing_utils::inner::run_recursive_test; use openvm_stark_backend::{ config::{Domain, StarkGenericConfig}, p3_commit::PolynomialSpace, p3_field::{extension::BinomialExtensionField, FieldAlgebra}, + prover::cpu::{CpuBackend, CpuDevice}, +}; +use openvm_stark_sdk::{ + config::{ + baby_bear_poseidon2::{BabyBearPoseidon2Config, BabyBearPoseidon2Engine}, + FriParameters, + }, + engine::StarkFriEngine, + p3_baby_bear::BabyBear, + utils::ProofInputForTest, }; -use openvm_stark_sdk::{config::FriParameters, p3_baby_bear::BabyBear, utils::ProofInputForTest}; fn fibonacci_program(a: u32, b: u32, n: u32) -> Program { type F = BabyBear; @@ -35,28 +50,47 @@ fn fibonacci_program(a: u32, b: u32, n: u32) -> Program { builder.compile_isa() } -pub(crate) fn fibonacci_program_test_proof_input( +// We need this for both BabyBearPoseidon2Config and BabyBearPoseidon2RootConfig +pub(crate) fn fibonacci_program_test_proof_input( a: u32, b: u32, n: u32, -) -> ProofInputForTest +) -> ProofInputForTest where + SC: StarkGenericConfig, + E: StarkFriEngine, PD = CpuDevice>, Domain: PolynomialSpace, + NativeCpuBuilder: + VmBuilder>, { let fib_program = fibonacci_program(a, b, n); - let vm_config = NativeConfig::new(SystemConfig::default().with_public_values(3), Native); - let airs = vm_config.create_chip_complex().unwrap().airs(); + let mut config = test_native_config(); + config.as_mut().num_public_values = 3; - let executor = VmExecutor::::new(vm_config); + let (output, mut vm) = execute_program_with_config::( + fib_program.clone(), + vec![], + NativeCpuBuilder, + config.clone(), + ) + .unwrap(); + let PreflightExecutionOutput { + system_records, + record_arenas, + .. + } = output; + let committed_exe = vm.commit_exe(fib_program); + let cached_program_trace = vm.transport_committed_exe_to_device(&committed_exe); + vm.load_program(cached_program_trace); + let ctx = vm + .generate_proving_ctx(system_records, record_arenas) + .unwrap(); - let mut result = executor.execute_and_generate(fib_program, vec![]).unwrap(); - assert_eq!(result.per_segment.len(), 1, "unexpected continuation"); - let proof_input = result.per_segment.remove(0); - // Filter out unused AIRS (where trace is empty) - let (used_airs, per_air) = proof_input + let airs = config.create_airs().unwrap().into_airs().collect_vec(); + let (used_airs, per_air): (Vec<_>, Vec<_>) = ctx .per_air .into_iter() - .map(|(air_id, x)| (airs[air_id].clone(), x)) + .map(|(air_id, air_ctx)| (airs[air_id].clone(), air_ctx)) .unzip(); ProofInputForTest { airs: used_airs, @@ -66,7 +100,10 @@ where #[test] fn test_fibonacci_program_verify() { - let fib_program_stark = fibonacci_program_test_proof_input(0, 1, 32); + let fib_program_stark = fibonacci_program_test_proof_input::< + BabyBearPoseidon2Config, + BabyBearPoseidon2Engine, + >(0, 1, 32); run_recursive_test(fib_program_stark, FriParameters::new_for_testing(3)); } @@ -75,7 +112,13 @@ fn test_fibonacci_program_verify() { #[ignore = "slow"] fn test_fibonacci_program_halo2_verify() { use openvm_native_recursion::halo2::testing_utils::run_static_verifier_test; + use openvm_stark_sdk::config::baby_bear_poseidon2_root::{ + BabyBearPoseidon2RootConfig, BabyBearPoseidon2RootEngine, + }; - let fib_program_stark = fibonacci_program_test_proof_input(0, 1, 32); + let fib_program_stark = fibonacci_program_test_proof_input::< + BabyBearPoseidon2RootConfig, + BabyBearPoseidon2RootEngine, + >(0, 1, 32); run_static_verifier_test(fib_program_stark, FriParameters::new_for_testing(3)); } diff --git a/extensions/pairing/circuit/Cargo.toml b/extensions/pairing/circuit/Cargo.toml index af16f7eeab..a44afff0f8 100644 --- a/extensions/pairing/circuit/Cargo.toml +++ b/extensions/pairing/circuit/Cargo.toml @@ -8,7 +8,6 @@ homepage.workspace = true repository.workspace = true [dependencies] -openvm-circuit-primitives-derive = { workspace = true } openvm-circuit-primitives = { workspace = true } openvm-circuit-derive = { workspace = true } openvm-circuit = { workspace = true } @@ -23,7 +22,6 @@ openvm-mod-circuit-builder = { workspace = true } openvm-stark-backend = { workspace = true } openvm-rv32im-circuit = { workspace = true } openvm-algebra-circuit = { workspace = true } -openvm-rv32-adapters = { workspace = true } openvm-ecc-circuit = { workspace = true } openvm-pairing-transpiler = { workspace = true } @@ -33,7 +31,6 @@ strum = { workspace = true } derive_more = { workspace = true } derive-new = { workspace = true } rand = { workspace = true } -itertools = { workspace = true } eyre = { workspace = true } serde = { workspace = true, features = ["derive", "std"] } halo2curves-axiom = { workspace = true } @@ -45,7 +42,6 @@ openvm-pairing-guest = { workspace = true } openvm-stark-sdk = { workspace = true } openvm-mod-circuit-builder = { workspace = true, features = ["test-utils"] } openvm-circuit = { workspace = true, features = ["test-utils"] } -openvm-rv32-adapters = { workspace = true, features = ["test-utils"] } halo2curves-axiom = { workspace = true } openvm-ecc-guest = { workspace = true } openvm-pairing-guest = { workspace = true, features = [ diff --git a/extensions/pairing/circuit/src/config.rs b/extensions/pairing/circuit/src/config.rs index d63bac664e..3fe21a0437 100644 --- a/extensions/pairing/circuit/src/config.rs +++ b/extensions/pairing/circuit/src/config.rs @@ -1,30 +1,37 @@ -use openvm_algebra_circuit::*; -use openvm_circuit::arch::{InitFileGenerator, SystemConfig}; +use std::result::Result; + +use openvm_algebra_circuit::{ + AlgebraCpuProverExt, Fp2Extension, Fp2ExtensionExecutor, Rv32ModularConfig, + Rv32ModularConfigExecutor, Rv32ModularCpuBuilder, +}; +use openvm_circuit::{ + arch::{ + AirInventory, ChipInventoryError, InitFileGenerator, MatrixRecordArena, SystemConfig, + VmBuilder, VmChipComplex, VmProverExtension, + }, + system::SystemChipInventory, +}; use openvm_circuit_derive::VmConfig; -use openvm_ecc_circuit::*; -use openvm_rv32im_circuit::*; -use openvm_stark_backend::p3_field::PrimeField32; +use openvm_ecc_circuit::{EccCpuProverExt, EccExtension, EccExtensionExecutor}; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + engine::StarkEngine, + p3_field::PrimeField32, + prover::cpu::{CpuBackend, CpuDevice}, +}; use serde::{Deserialize, Serialize}; use super::*; #[derive(Clone, Debug, VmConfig, Serialize, Deserialize)] pub struct Rv32PairingConfig { - #[system] - pub system: SystemConfig, - #[extension] - pub base: Rv32I, - #[extension] - pub mul: Rv32M, - #[extension] - pub io: Rv32Io, - #[extension] - pub modular: ModularExtension, + #[config(generics = true)] + pub modular: Rv32ModularConfig, #[extension] pub fp2: Fp2Extension, #[extension] - pub weierstrass: WeierstrassExtension, - #[extension] + pub ecc: EccExtension, + #[extension(generics = true)] pub pairing: PairingExtension, } @@ -37,20 +44,14 @@ impl Rv32PairingConfig { let mut modulus_and_scalar_primes = modulus_primes.clone(); modulus_and_scalar_primes.extend(curves.iter().map(|c| c.curve_config().scalar.clone())); Self { - system: SystemConfig::default().with_continuations(), - base: Default::default(), - mul: Default::default(), - io: Default::default(), - modular: ModularExtension::new(modulus_and_scalar_primes), + modular: Rv32ModularConfig::new(modulus_and_scalar_primes), fp2: Fp2Extension::new( complex_struct_names .into_iter() .zip(modulus_primes) .collect(), ), - weierstrass: WeierstrassExtension::new( - curves.iter().map(|c| c.curve_config()).collect(), - ), + ecc: EccExtension::new(curves.iter().map(|c| c.curve_config()).collect(), vec![]), pairing: PairingExtension::new(curves), } } @@ -60,9 +61,40 @@ impl InitFileGenerator for Rv32PairingConfig { fn generate_init_file_contents(&self) -> Option { Some(format!( "// This file is automatically generated by cargo openvm. Do not rename or edit.\n{}\n{}\n{}\n", - self.modular.generate_moduli_init(), - self.fp2.generate_complex_init(&self.modular), - self.weierstrass.generate_sw_init() + self.modular.modular.generate_moduli_init(), + self.fp2.generate_complex_init(&self.modular.modular), + self.ecc.generate_ecc_init() )) } } + +#[derive(Clone)] +pub struct Rv32PairingCpuBuilder; + +impl VmBuilder for Rv32PairingCpuBuilder +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + Val: PrimeField32, +{ + type VmConfig = Rv32PairingConfig; + type SystemChipInventory = SystemChipInventory; + type RecordArena = MatrixRecordArena>; + + fn create_chip_complex( + &self, + config: &Rv32PairingConfig, + circuit: AirInventory, + ) -> Result< + VmChipComplex, + ChipInventoryError, + > { + let mut chip_complex = + VmBuilder::::create_chip_complex(&Rv32ModularCpuBuilder, &config.modular, circuit)?; + let inventory = &mut chip_complex.inventory; + VmProverExtension::::extend_prover(&AlgebraCpuProverExt, &config.fp2, inventory)?; + VmProverExtension::::extend_prover(&EccCpuProverExt, &config.ecc, inventory)?; + VmProverExtension::::extend_prover(&PairingProverExt, &config.pairing, inventory)?; + Ok(chip_complex) + } +} diff --git a/extensions/pairing/circuit/src/fp12_chip/add.rs b/extensions/pairing/circuit/src/fp12_chip/add.rs deleted file mode 100644 index 643c68ef27..0000000000 --- a/extensions/pairing/circuit/src/fp12_chip/add.rs +++ /dev/null @@ -1,20 +0,0 @@ -use std::{cell::RefCell, rc::Rc}; - -use openvm_circuit_primitives::var_range::VariableRangeCheckerBus; -use openvm_mod_circuit_builder::{ExprBuilder, ExprBuilderConfig, FieldExpr}; - -use crate::Fp12; - -pub fn fp12_add_expr(config: ExprBuilderConfig, range_bus: VariableRangeCheckerBus) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config, range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut x = Fp12::new(builder.clone()); - let mut y = Fp12::new(builder.clone()); - let mut res = x.add(&mut y); - res.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} diff --git a/extensions/pairing/circuit/src/fp12_chip/mod.rs b/extensions/pairing/circuit/src/fp12_chip/mod.rs deleted file mode 100644 index c6894d0d27..0000000000 --- a/extensions/pairing/circuit/src/fp12_chip/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -mod add; -mod mul; -mod sub; - -pub use add::*; -pub use mul::*; -pub use sub::*; - -#[cfg(test)] -mod tests; diff --git a/extensions/pairing/circuit/src/fp12_chip/mul.rs b/extensions/pairing/circuit/src/fp12_chip/mul.rs deleted file mode 100644 index 0736981de7..0000000000 --- a/extensions/pairing/circuit/src/fp12_chip/mul.rs +++ /dev/null @@ -1,175 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::Fp12Opcode; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -use crate::Fp12; -// Input: Fp12 * 2 -// Output: Fp12 -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct Fp12MulChip( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl - Fp12MulChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - config: ExprBuilderConfig, - xi: [isize; 2], - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let expr = fp12_mul_expr(config, range_checker.bus(), xi); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![Fp12Opcode::MUL as usize], - vec![], - range_checker, - "Fp12Mul", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -pub fn fp12_mul_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, - xi: [isize; 2], -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config, range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut x = Fp12::new(builder.clone()); - let mut y = Fp12::new(builder.clone()); - let mut res = x.mul(&mut y, xi); - res.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} - -#[cfg(test)] -mod tests { - use halo2curves_axiom::{bn256::Fq12, ff::Field}; - use itertools::Itertools; - use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; - use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, - }; - use openvm_ecc_guest::algebra::field::FieldExtension; - use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; - use openvm_mod_circuit_builder::{ - test_utils::{biguint_to_limbs, bn254_fq12_to_biguint_vec, bn254_fq2_to_biguint_vec}, - ExprBuilderConfig, - }; - use openvm_pairing_guest::bn254::{BN254_MODULUS, BN254_XI_ISIZE}; - use openvm_rv32_adapters::rv32_write_heap_default_with_increment; - use openvm_stark_backend::p3_field::FieldAlgebra; - use openvm_stark_sdk::p3_baby_bear::BabyBear; - use rand::{rngs::StdRng, SeedableRng}; - - use super::*; - - const LIMB_BITS: usize = 8; - type F = BabyBear; - - #[test] - fn test_fp12_mul_bn254() { - const NUM_LIMBS: usize = 32; - const BLOCK_SIZE: usize = 32; - - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - num_limbs: NUM_LIMBS, - limb_bits: LIMB_BITS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - - let mut chip = Fp12MulChip::new( - adapter, - config, - BN254_XI_ISIZE, - Fp12Opcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - - let mut rng = StdRng::seed_from_u64(64); - let x = Fq12::random(&mut rng); - let y = Fq12::random(&mut rng); - let inputs = [x.to_coeffs(), y.to_coeffs()] - .concat() - .iter() - .flat_map(|&x| bn254_fq2_to_biguint_vec(x)) - .collect::>(); - - let cmp = bn254_fq12_to_biguint_vec(x * y); - let res = chip - .0 - .core - .expr() - .execute_with_output(inputs.clone(), vec![true]); - assert_eq!(res.len(), cmp.len()); - for i in 0..res.len() { - assert_eq!(res[i], cmp[i]); - } - - let x_limbs = inputs[..12] - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS) - .map(BabyBear::from_canonical_u32) - }) - .collect_vec(); - let y_limbs = inputs[12..] - .iter() - .map(|y| { - biguint_to_limbs::(y.clone(), LIMB_BITS) - .map(BabyBear::from_canonical_u32) - }) - .collect_vec(); - let instruction = rv32_write_heap_default_with_increment( - &mut tester, - x_limbs, - y_limbs, - 512, - chip.0.core.air.offset + Fp12Opcode::MUL as usize, - ); - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); - } -} diff --git a/extensions/pairing/circuit/src/fp12_chip/sub.rs b/extensions/pairing/circuit/src/fp12_chip/sub.rs deleted file mode 100644 index 470e700910..0000000000 --- a/extensions/pairing/circuit/src/fp12_chip/sub.rs +++ /dev/null @@ -1,20 +0,0 @@ -use std::{cell::RefCell, rc::Rc}; - -use openvm_circuit_primitives::var_range::VariableRangeCheckerBus; -use openvm_mod_circuit_builder::{ExprBuilder, ExprBuilderConfig, FieldExpr}; - -use crate::Fp12; - -pub fn fp12_sub_expr(config: ExprBuilderConfig, range_bus: VariableRangeCheckerBus) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config, range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut x = Fp12::new(builder.clone()); - let mut y = Fp12::new(builder.clone()); - let mut res = x.sub(&mut y); - res.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} diff --git a/extensions/pairing/circuit/src/fp12_chip/tests.rs b/extensions/pairing/circuit/src/fp12_chip/tests.rs deleted file mode 100644 index a9f6b235d5..0000000000 --- a/extensions/pairing/circuit/src/fp12_chip/tests.rs +++ /dev/null @@ -1,271 +0,0 @@ -use num_bigint::BigUint; -use openvm_circuit::arch::{ - testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - VmChipWrapper, -}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, -}; -use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; -use openvm_mod_circuit_builder::{ - test_utils::{ - biguint_to_limbs, bls12381_fq12_random, bn254_fq12_random, bn254_fq12_to_biguint_vec, - }, - ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_guest::{ - bls12_381::{ - BLS12_381_BLOCK_SIZE, BLS12_381_LIMB_BITS, BLS12_381_MODULUS, BLS12_381_NUM_LIMBS, - BLS12_381_XI_ISIZE, - }, - bn254::{BN254_BLOCK_SIZE, BN254_LIMB_BITS, BN254_MODULUS, BN254_NUM_LIMBS, BN254_XI_ISIZE}, -}; -use openvm_pairing_transpiler::{Bls12381Fp12Opcode, Bn254Fp12Opcode, Fp12Opcode}; -use openvm_rv32_adapters::{rv32_write_heap_default, Rv32VecHeapAdapterChip}; -use openvm_stark_backend::p3_field::FieldAlgebra; -use openvm_stark_sdk::p3_baby_bear::BabyBear; - -use super::{fp12_add_expr, fp12_mul_expr, fp12_sub_expr}; - -type F = BabyBear; - -#[allow(clippy::too_many_arguments)] -fn test_fp12_fn< - const INPUT_SIZE: usize, - const NUM_LIMBS: usize, - const LIMB_BITS: usize, - const BLOCK_SIZE: usize, ->( - mut tester: VmChipTestBuilder, - expr: FieldExpr, - offset: usize, - local_opcode_idx: usize, - name: &str, - x: Vec, - y: Vec, - var_len: usize, -) { - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![local_opcode_idx], - vec![], - tester.memory_controller().borrow().range_checker.clone(), - name, - false, - ); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let adapter = - Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - - let x_limbs = x - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - let y_limbs = y - .iter() - .map(|y| { - biguint_to_limbs::(y.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - let mut chip = VmChipWrapper::new(adapter, core, tester.offline_memory_mutex_arc()); - - let res = chip.core.air.expr.execute([x, y].concat(), vec![]); - assert_eq!(res.len(), var_len); - - let instruction = rv32_write_heap_default( - &mut tester, - x_limbs, - y_limbs, - chip.core.air.offset + local_opcode_idx, - ); - tester.execute(&mut chip, &instruction); - - let run_tester = tester.build().load(chip).load(bitwise_chip).finalize(); - run_tester.simple_test().expect("Verification failed"); -} - -#[test] -fn test_fp12_add_bn254() { - let tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - num_limbs: BN254_NUM_LIMBS, - limb_bits: BN254_LIMB_BITS, - }; - let expr = fp12_add_expr( - config, - tester.memory_controller().borrow().range_checker.bus(), - ); - - let x = bn254_fq12_to_biguint_vec(bn254_fq12_random(1)); - let y = bn254_fq12_to_biguint_vec(bn254_fq12_random(2)); - - test_fp12_fn::<12, BN254_NUM_LIMBS, BN254_LIMB_BITS, BN254_BLOCK_SIZE>( - tester, - expr, - Bn254Fp12Opcode::CLASS_OFFSET, - Fp12Opcode::ADD as usize, - "Bn254Fp12Add", - x, - y, - 12, - ); -} - -#[test] -fn test_fp12_sub_bn254() { - let tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - num_limbs: BN254_NUM_LIMBS, - limb_bits: BN254_LIMB_BITS, - }; - let expr = fp12_sub_expr( - config, - tester.memory_controller().borrow().range_checker.bus(), - ); - - let x = bn254_fq12_to_biguint_vec(bn254_fq12_random(59)); - let y = bn254_fq12_to_biguint_vec(bn254_fq12_random(3)); - - test_fp12_fn::<12, BN254_NUM_LIMBS, BN254_LIMB_BITS, BN254_BLOCK_SIZE>( - tester, - expr, - Bn254Fp12Opcode::CLASS_OFFSET, - Fp12Opcode::SUB as usize, - "Bn254Fp12Sub", - x, - y, - 12, - ); -} - -#[test] -fn test_fp12_mul_bn254() { - let tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - num_limbs: BN254_NUM_LIMBS, - limb_bits: BN254_LIMB_BITS, - }; - let xi = BN254_XI_ISIZE; - let expr = fp12_mul_expr( - config, - tester.memory_controller().borrow().range_checker.bus(), - xi, - ); - - let x = bn254_fq12_to_biguint_vec(bn254_fq12_random(5)); - let y = bn254_fq12_to_biguint_vec(bn254_fq12_random(25)); - - test_fp12_fn::<12, BN254_NUM_LIMBS, BN254_LIMB_BITS, BN254_BLOCK_SIZE>( - tester, - expr, - Bn254Fp12Opcode::CLASS_OFFSET, - Fp12Opcode::MUL as usize, - "Bn254Fp12Mul", - x, - y, - 33, - ); -} - -#[test] -fn test_fp12_add_bls12381() { - let tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BLS12_381_MODULUS.clone(), - num_limbs: BLS12_381_NUM_LIMBS, - limb_bits: BLS12_381_LIMB_BITS, - }; - let expr = fp12_add_expr( - config, - tester.memory_controller().borrow().range_checker.bus(), - ); - - let x = bls12381_fq12_random(3); - let y = bls12381_fq12_random(99); - - test_fp12_fn::<36, BLS12_381_NUM_LIMBS, BLS12_381_LIMB_BITS, BLS12_381_BLOCK_SIZE>( - tester, - expr, - Bls12381Fp12Opcode::CLASS_OFFSET, - Fp12Opcode::ADD as usize, - "Bls12381Fp12Add", - x, - y, - 12, - ); -} - -#[test] -fn test_fp12_sub_bls12381() { - let tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BLS12_381_MODULUS.clone(), - num_limbs: BLS12_381_NUM_LIMBS, - limb_bits: BLS12_381_LIMB_BITS, - }; - let expr = fp12_sub_expr( - config, - tester.memory_controller().borrow().range_checker.bus(), - ); - - let x = bls12381_fq12_random(8); - let y = bls12381_fq12_random(9); - - test_fp12_fn::<36, BLS12_381_NUM_LIMBS, BLS12_381_LIMB_BITS, BLS12_381_BLOCK_SIZE>( - tester, - expr, - Bls12381Fp12Opcode::CLASS_OFFSET, - Fp12Opcode::SUB as usize, - "Bls12381Fp12Sub", - x, - y, - 12, - ); -} - -// NOTE[yj]: This test requires RUST_MIN_STACK=8388608 to run without overflowing the stack, so it -// is ignored by the test runner for now -#[test] -#[ignore] -fn test_fp12_mul_bls12381() { - let tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BLS12_381_MODULUS.clone(), - num_limbs: BLS12_381_NUM_LIMBS, - limb_bits: BLS12_381_LIMB_BITS, - }; - let xi = BLS12_381_XI_ISIZE; - let expr = fp12_mul_expr( - config, - tester.memory_controller().borrow().range_checker.bus(), - xi, - ); - - let x = bls12381_fq12_random(5); - let y = bls12381_fq12_random(25); - - test_fp12_fn::<36, BLS12_381_NUM_LIMBS, BLS12_381_LIMB_BITS, BLS12_381_BLOCK_SIZE>( - tester, - expr, - Bls12381Fp12Opcode::CLASS_OFFSET, - Fp12Opcode::MUL as usize, - "Bls12381Fp12Mul", - x, - y, - 46, - ); -} diff --git a/extensions/pairing/circuit/src/lib.rs b/extensions/pairing/circuit/src/lib.rs index b2b962b7f7..f96d126555 100644 --- a/extensions/pairing/circuit/src/lib.rs +++ b/extensions/pairing/circuit/src/lib.rs @@ -1,11 +1,7 @@ mod config; mod fp12; -mod fp12_chip; -mod pairing_chip; mod pairing_extension; pub use config::*; pub use fp12::*; -pub use fp12_chip::*; -pub use pairing_chip::*; pub use pairing_extension::*; diff --git a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mod.rs b/extensions/pairing/circuit/src/pairing_chip/line/d_type/mod.rs deleted file mode 100644 index 08857995f3..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -mod mul_013_by_013; -mod mul_by_01234; - -pub use mul_013_by_013::*; -pub use mul_by_01234::*; - -#[cfg(test)] -mod tests; diff --git a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_013_by_013.rs b/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_013_by_013.rs deleted file mode 100644 index 36d1012e9b..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_013_by_013.rs +++ /dev/null @@ -1,101 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_algebra_circuit::Fp2; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -// Input: line0.b, line0.c, line1.b, line1.c : 2 x 4 field elements -// Output: 5 Fp2 coefficients -> 10 field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct EcLineMul013By013Chip< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, ->( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, - > EcLineMul013By013Chip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - range_checker: SharedVariableRangeCheckerChip, - config: ExprBuilderConfig, - xi: [isize; 2], - offset: usize, - offline_memory: Arc>>, - ) -> Self { - assert!( - xi[0].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); // not a hard rule, but we expect xi to be small - assert!( - xi[1].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); - let expr = mul_013_by_013_expr(config, range_checker.bus(), xi); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![PairingOpcode::MUL_013_BY_013 as usize], - vec![], - range_checker, - "Mul013By013", - true, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -pub fn mul_013_by_013_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, - xi: [isize; 2], -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config.clone(), range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut b0 = Fp2::new(builder.clone()); - let mut c0 = Fp2::new(builder.clone()); - let mut b1 = Fp2::new(builder.clone()); - let mut c1 = Fp2::new(builder.clone()); - - // where w⁶ = xi - // l0 * l1 = 1 + (b0 + b1)w + (b0b1)w² + (c0 + c1)w³ + (b0c1 + b1c0)w⁴ + (c0c1)w⁶ - // = (1 + c0c1 * xi) + (b0 + b1)w + (b0b1)w² + (c0 + c1)w³ + (b0c1 + b1c0)w⁴ - let l0 = c0.mul(&mut c1).int_mul(xi).int_add([1, 0]); - let l1 = b0.add(&mut b1); - let l2 = b0.mul(&mut b1); - let l3 = c0.add(&mut c1); - let l4 = b0.mul(&mut c1).add(&mut b1.mul(&mut c0)); - - [l0, l1, l2, l3, l4].map(|mut l| l.save_output()); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} diff --git a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_by_01234.rs b/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_by_01234.rs deleted file mode 100644 index 996372e994..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_by_01234.rs +++ /dev/null @@ -1,113 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_algebra_circuit::Fp2; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::Rv32VecHeapTwoReadsAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -use crate::Fp12; - -// Input: Fp12 (12 field elements), [Fp2; 5] (5 x 2 field elements) -// Output: Fp12 (12 field elements) -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct EcLineMulBy01234Chip< - F: PrimeField32, - const INPUT_BLOCKS1: usize, - const INPUT_BLOCKS2: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, ->( - pub VmChipWrapper< - F, - Rv32VecHeapTwoReadsAdapterChip< - F, - INPUT_BLOCKS1, - INPUT_BLOCKS2, - OUTPUT_BLOCKS, - BLOCK_SIZE, - BLOCK_SIZE, - >, - FieldExpressionCoreChip, - >, -); - -impl< - F: PrimeField32, - const INPUT_BLOCKS1: usize, - const INPUT_BLOCKS2: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, - > EcLineMulBy01234Chip -{ - pub fn new( - adapter: Rv32VecHeapTwoReadsAdapterChip< - F, - INPUT_BLOCKS1, - INPUT_BLOCKS2, - OUTPUT_BLOCKS, - BLOCK_SIZE, - BLOCK_SIZE, - >, - config: ExprBuilderConfig, - xi: [isize; 2], - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - assert!( - xi[0].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); // not a hard rule, but we expect xi to be small - assert!( - xi[1].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); - let expr = mul_by_01234_expr(config, range_checker.bus(), xi); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![PairingOpcode::MUL_BY_01234 as usize], - vec![], - range_checker.clone(), - "MulBy01234", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -pub fn mul_by_01234_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, - xi: [isize; 2], -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config.clone(), range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut f = Fp12::new(builder.clone()); - let mut x0 = Fp2::new(builder.clone()); - let mut x1 = Fp2::new(builder.clone()); - let mut x2 = Fp2::new(builder.clone()); - let mut x3 = Fp2::new(builder.clone()); - let mut x4 = Fp2::new(builder.clone()); - - let mut r = f.mul_by_01234(&mut x0, &mut x1, &mut x2, &mut x3, &mut x4, xi); - r.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} diff --git a/extensions/pairing/circuit/src/pairing_chip/line/d_type/tests.rs b/extensions/pairing/circuit/src/pairing_chip/line/d_type/tests.rs deleted file mode 100644 index 81da3169fa..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/d_type/tests.rs +++ /dev/null @@ -1,287 +0,0 @@ -use halo2curves_axiom::{ - bn256::{Fq, Fq12, Fq2, G1Affine}, - ff::Field, -}; -use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, -}; -use openvm_ecc_guest::AffinePoint; -use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; -use openvm_mod_circuit_builder::{ - test_utils::{ - biguint_to_limbs, bn254_fq12_to_biguint_vec, bn254_fq2_to_biguint_vec, bn254_fq_to_biguint, - }, - ExprBuilderConfig, -}; -use openvm_pairing_guest::{ - bn254::{BN254_LIMB_BITS, BN254_MODULUS, BN254_NUM_LIMBS, BN254_XI_ISIZE}, - halo2curves_shims::bn254::{tangent_line_013, Bn254}, - pairing::{Evaluatable, LineMulDType, UnevaluatedLine}, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::{ - rv32_write_heap_default, rv32_write_heap_default_with_increment, Rv32VecHeapAdapterChip, - Rv32VecHeapTwoReadsAdapterChip, -}; -use openvm_stark_backend::p3_field::FieldAlgebra; -use openvm_stark_sdk::p3_baby_bear::BabyBear; -use rand::{rngs::StdRng, SeedableRng}; - -use super::{super::EvaluateLineChip, *}; - -type F = BabyBear; -const NUM_LIMBS: usize = 32; -const LIMB_BITS: usize = 8; -const BLOCK_SIZE: usize = 32; - -#[test] -fn test_mul_013_by_013() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = EcLineMul013By013Chip::new( - adapter, - tester.memory_controller().borrow().range_checker.clone(), - ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - num_limbs: NUM_LIMBS, - limb_bits: LIMB_BITS, - }, - BN254_XI_ISIZE, - PairingOpcode::CLASS_OFFSET, - tester.offline_memory_mutex_arc(), - ); - - let mut rng0 = StdRng::seed_from_u64(8); - let mut rng1 = StdRng::seed_from_u64(95); - let rnd_pt_0 = G1Affine::random(&mut rng0); - let rnd_pt_1 = G1Affine::random(&mut rng1); - let ec_pt_0 = AffinePoint:: { - x: rnd_pt_0.x, - y: rnd_pt_0.y, - }; - let ec_pt_1 = AffinePoint:: { - x: rnd_pt_1.x, - y: rnd_pt_1.y, - }; - let line0 = tangent_line_013::(ec_pt_0); - let line1 = tangent_line_013::(ec_pt_1); - let input_line0 = [ - bn254_fq2_to_biguint_vec(line0.b), - bn254_fq2_to_biguint_vec(line0.c), - ] - .concat(); - let input_line1 = [ - bn254_fq2_to_biguint_vec(line1.b), - bn254_fq2_to_biguint_vec(line1.c), - ] - .concat(); - - let vars = chip - .0 - .core - .expr() - .execute([input_line0.clone(), input_line1.clone()].concat(), vec![]); - let output_indices = chip.0.core.expr().builder.output_indices.clone(); - let output = output_indices - .iter() - .map(|i| vars[*i].clone()) - .collect::>(); - assert_eq!(output.len(), 10); - - let r_cmp = Bn254::mul_013_by_013(&line0, &line1); - let r_cmp_bigint = r_cmp - .map(|x| [bn254_fq_to_biguint(x.c0), bn254_fq_to_biguint(x.c1)]) - .concat(); - - for i in 0..10 { - assert_eq!(output[i], r_cmp_bigint[i]); - } - - let input_line0_limbs = input_line0 - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - let input_line1_limbs = input_line1 - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - - let instruction = rv32_write_heap_default( - &mut tester, - input_line0_limbs, - input_line1_limbs, - chip.0.core.air.offset + PairingOpcode::MUL_013_BY_013 as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn test_mul_by_01234() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapTwoReadsAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = EcLineMulBy01234Chip::new( - adapter, - ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - num_limbs: NUM_LIMBS, - limb_bits: LIMB_BITS, - }, - BN254_XI_ISIZE, - PairingOpcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - - let mut rng = StdRng::seed_from_u64(8); - let f = Fq12::random(&mut rng); - let x0 = Fq2::random(&mut rng); - let x1 = Fq2::random(&mut rng); - let x2 = Fq2::random(&mut rng); - let x3 = Fq2::random(&mut rng); - let x4 = Fq2::random(&mut rng); - - let input_f = bn254_fq12_to_biguint_vec(f); - let input_x = [ - bn254_fq2_to_biguint_vec(x0), - bn254_fq2_to_biguint_vec(x1), - bn254_fq2_to_biguint_vec(x2), - bn254_fq2_to_biguint_vec(x3), - bn254_fq2_to_biguint_vec(x4), - ] - .concat(); - - let vars = chip - .0 - .core - .expr() - .execute([input_f.clone(), input_x.clone()].concat(), vec![]); - let output_indices = chip.0.core.expr().builder.output_indices.clone(); - let output = output_indices - .iter() - .map(|i| vars[*i].clone()) - .collect::>(); - assert_eq!(output.len(), 12); - - let r_cmp = Bn254::mul_by_01234(&f, &[x0, x1, x2, x3, x4]); - let r_cmp_bigint = bn254_fq12_to_biguint_vec(r_cmp); - - for i in 0..12 { - assert_eq!(output[i], r_cmp_bigint[i]); - } - - let input_f_limbs = input_f - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - let input_x_limbs = input_x - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - - let instruction = rv32_write_heap_default_with_increment( - &mut tester, - input_f_limbs, - input_x_limbs, - 512, - chip.0.core.air.offset + PairingOpcode::MUL_BY_01234 as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn test_evaluate_line() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - limb_bits: BN254_LIMB_BITS, - num_limbs: BN254_NUM_LIMBS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapTwoReadsAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = EvaluateLineChip::new( - adapter, - config, - PairingOpcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - - let mut rng = StdRng::seed_from_u64(42); - let uneval_b = Fq2::random(&mut rng); - let uneval_c = Fq2::random(&mut rng); - let x_over_y = Fq::random(&mut rng); - let y_inv = Fq::random(&mut rng); - let mut inputs = vec![]; - inputs.extend(bn254_fq2_to_biguint_vec(uneval_b)); - inputs.extend(bn254_fq2_to_biguint_vec(uneval_c)); - inputs.push(bn254_fq_to_biguint(x_over_y)); - inputs.push(bn254_fq_to_biguint(y_inv)); - let input_limbs = inputs - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect(); - - let uneval: UnevaluatedLine = UnevaluatedLine { - b: uneval_b, - c: uneval_c, - }; - let evaluated = uneval.evaluate(&(x_over_y, y_inv)); - - let result = chip.0.core.expr().execute_with_output(inputs, vec![]); - assert_eq!(result.len(), 4); - assert_eq!(result[0], bn254_fq_to_biguint(evaluated.b.c0)); - assert_eq!(result[1], bn254_fq_to_biguint(evaluated.b.c1)); - assert_eq!(result[2], bn254_fq_to_biguint(evaluated.c.c0)); - assert_eq!(result[3], bn254_fq_to_biguint(evaluated.c.c1)); - - let instruction = rv32_write_heap_default( - &mut tester, - input_limbs, - vec![], - chip.0.core.air.offset + PairingOpcode::EVALUATE_LINE as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} diff --git a/extensions/pairing/circuit/src/pairing_chip/line/evaluate_line.rs b/extensions/pairing/circuit/src/pairing_chip/line/evaluate_line.rs deleted file mode 100644 index dc0a8cdfe1..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/evaluate_line.rs +++ /dev/null @@ -1,102 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_algebra_circuit::Fp2; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::Rv32VecHeapTwoReadsAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -// Input: UnevaluatedLine, (Fp, Fp) -// Output: EvaluatedLine -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct EvaluateLineChip< - F: PrimeField32, - const INPUT_BLOCKS1: usize, - const INPUT_BLOCKS2: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, ->( - pub VmChipWrapper< - F, - Rv32VecHeapTwoReadsAdapterChip< - F, - INPUT_BLOCKS1, - INPUT_BLOCKS2, - OUTPUT_BLOCKS, - BLOCK_SIZE, - BLOCK_SIZE, - >, - FieldExpressionCoreChip, - >, -); - -impl< - F: PrimeField32, - const INPUT_BLOCKS1: usize, - const INPUT_BLOCKS2: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, - > EvaluateLineChip -{ - pub fn new( - adapter: Rv32VecHeapTwoReadsAdapterChip< - F, - INPUT_BLOCKS1, - INPUT_BLOCKS2, - OUTPUT_BLOCKS, - BLOCK_SIZE, - BLOCK_SIZE, - >, - config: ExprBuilderConfig, - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let expr = evaluate_line_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![PairingOpcode::EVALUATE_LINE as usize], - vec![], - range_checker, - "EvaluateLine", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -pub fn evaluate_line_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config, range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut uneval_b = Fp2::new(builder.clone()); - let mut uneval_c = Fp2::new(builder.clone()); - - let mut x_over_y = ExprBuilder::new_input(builder.clone()); - let mut y_inv = ExprBuilder::new_input(builder.clone()); - - let mut b = uneval_b.scalar_mul(&mut x_over_y); - let mut c = uneval_c.scalar_mul(&mut y_inv); - b.save_output(); - c.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} diff --git a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mod.rs b/extensions/pairing/circuit/src/pairing_chip/line/m_type/mod.rs deleted file mode 100644 index b454d260ce..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -mod mul_023_by_023; -mod mul_by_02345; - -pub use mul_023_by_023::*; -pub use mul_by_02345::*; - -#[cfg(test)] -mod tests; diff --git a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_023_by_023.rs b/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_023_by_023.rs deleted file mode 100644 index 0d760b886e..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_023_by_023.rs +++ /dev/null @@ -1,101 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_algebra_circuit::Fp2; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -// Input: line0.b, line0.c, line1.b, line1.c : 2 x 4 field elements -// Output: 5 Fp2 coefficients -> 10 field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct EcLineMul023By023Chip< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, ->( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, - > EcLineMul023By023Chip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - range_checker: SharedVariableRangeCheckerChip, - config: ExprBuilderConfig, - xi: [isize; 2], - offset: usize, - offline_memory: Arc>>, - ) -> Self { - assert!( - xi[0].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); // not a hard rule, but we expect xi to be small - assert!( - xi[1].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); - let expr = mul_023_by_023_expr(config, range_checker.bus(), xi); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![PairingOpcode::MUL_023_BY_023 as usize], - vec![], - range_checker, - "Mul023By023", - true, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -pub fn mul_023_by_023_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, - xi: [isize; 2], -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config.clone(), range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut b0 = Fp2::new(builder.clone()); // x2 - let mut c0 = Fp2::new(builder.clone()); // x3 - let mut b1 = Fp2::new(builder.clone()); // y2 - let mut c1 = Fp2::new(builder.clone()); // y3 - - // where w⁶ = xi - // l0 * l1 = c0c1 + (c0b1 + c1b0)w² + (c0 + c1)w³ + (b0b1)w⁴ + (b0 +b1)w⁵ + w⁶ - // = (c0c1 + xi) + (c0b1 + c1b0)w² + (c0 + c1)w³ + (b0b1)w⁴ + (b0 + b1)w⁵ - let l0 = c0.mul(&mut c1).int_add(xi); - let l2 = c0.mul(&mut b1).add(&mut c1.mul(&mut b0)); - let l3 = c0.add(&mut c1); - let l4 = b0.mul(&mut b1); - let l5 = b0.add(&mut b1); - - [l0, l2, l3, l4, l5].map(|mut l| l.save_output()); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} diff --git a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_by_02345.rs b/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_by_02345.rs deleted file mode 100644 index ad0e91e7bd..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_by_02345.rs +++ /dev/null @@ -1,113 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_algebra_circuit::Fp2; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::Rv32VecHeapTwoReadsAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -use crate::Fp12; - -// Input: 2 Fp12: 2 x 12 field elements -// Output: Fp12 -> 12 field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct EcLineMulBy02345Chip< - F: PrimeField32, - const INPUT_BLOCKS1: usize, - const INPUT_BLOCKS2: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, ->( - pub VmChipWrapper< - F, - Rv32VecHeapTwoReadsAdapterChip< - F, - INPUT_BLOCKS1, - INPUT_BLOCKS2, - OUTPUT_BLOCKS, - BLOCK_SIZE, - BLOCK_SIZE, - >, - FieldExpressionCoreChip, - >, -); - -impl< - F: PrimeField32, - const INPUT_BLOCKS1: usize, - const INPUT_BLOCKS2: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, - > EcLineMulBy02345Chip -{ - pub fn new( - adapter: Rv32VecHeapTwoReadsAdapterChip< - F, - INPUT_BLOCKS1, - INPUT_BLOCKS2, - OUTPUT_BLOCKS, - BLOCK_SIZE, - BLOCK_SIZE, - >, - range_checker: SharedVariableRangeCheckerChip, - config: ExprBuilderConfig, - xi: [isize; 2], - offset: usize, - offline_memory: Arc>>, - ) -> Self { - assert!( - xi[0].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); // not a hard rule, but we expect xi to be small - assert!( - xi[1].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); - let expr = mul_by_02345_expr(config, range_checker.bus(), xi); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![PairingOpcode::MUL_BY_02345 as usize], - vec![], - range_checker, - "MulBy02345", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -pub fn mul_by_02345_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, - xi: [isize; 2], -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config.clone(), range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut f = Fp12::new(builder.clone()); - let mut x0 = Fp2::new(builder.clone()); - let mut x2 = Fp2::new(builder.clone()); - let mut x3 = Fp2::new(builder.clone()); - let mut x4 = Fp2::new(builder.clone()); - let mut x5 = Fp2::new(builder.clone()); - - let mut r = f.mul_by_02345(&mut x0, &mut x2, &mut x3, &mut x4, &mut x5, xi); - r.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} diff --git a/extensions/pairing/circuit/src/pairing_chip/line/m_type/tests.rs b/extensions/pairing/circuit/src/pairing_chip/line/m_type/tests.rs deleted file mode 100644 index 4331d2278e..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/m_type/tests.rs +++ /dev/null @@ -1,217 +0,0 @@ -use halo2curves_axiom::{ - bls12_381::{Fq, Fq12, Fq2, G1Affine}, - ff::Field, -}; -use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, -}; -use openvm_ecc_guest::AffinePoint; -use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; -use openvm_mod_circuit_builder::{test_utils::*, ExprBuilderConfig}; -use openvm_pairing_guest::{ - bls12_381::{BLS12_381_LIMB_BITS, BLS12_381_MODULUS, BLS12_381_NUM_LIMBS, BLS12_381_XI_ISIZE}, - halo2curves_shims::bls12_381::{tangent_line_023, Bls12_381}, - pairing::LineMulMType, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::{ - rv32_write_heap_default_with_increment, Rv32VecHeapAdapterChip, Rv32VecHeapTwoReadsAdapterChip, -}; -use openvm_stark_backend::p3_field::FieldAlgebra; -use openvm_stark_sdk::p3_baby_bear::BabyBear; -use rand::{rngs::StdRng, SeedableRng}; - -use super::*; - -type F = BabyBear; -const NUM_LIMBS: usize = 48; -const LIMB_BITS: usize = 8; -const BLOCK_SIZE: usize = 16; - -#[test] -fn test_mul_023_by_023() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = EcLineMul023By023Chip::new( - adapter, - tester.memory_controller().borrow().range_checker.clone(), - ExprBuilderConfig { - modulus: BLS12_381_MODULUS.clone(), - num_limbs: BLS12_381_NUM_LIMBS, - limb_bits: BLS12_381_LIMB_BITS, - }, - BLS12_381_XI_ISIZE, - PairingOpcode::CLASS_OFFSET, - tester.offline_memory_mutex_arc(), - ); - - let mut rng0 = StdRng::seed_from_u64(15); - let mut rng1 = StdRng::seed_from_u64(95); - let rnd_pt_0 = G1Affine::random(&mut rng0); - let rnd_pt_1 = G1Affine::random(&mut rng1); - let ec_pt_0 = AffinePoint:: { - x: rnd_pt_0.x, - y: rnd_pt_0.y, - }; - let ec_pt_1 = AffinePoint:: { - x: rnd_pt_1.x, - y: rnd_pt_1.y, - }; - let line0 = tangent_line_023::(ec_pt_0); - let line1 = tangent_line_023::(ec_pt_1); - let input_line0 = [ - bls12381_fq2_to_biguint_vec(line0.b), - bls12381_fq2_to_biguint_vec(line0.c), - ] - .concat(); - let input_line1 = [ - bls12381_fq2_to_biguint_vec(line1.b), - bls12381_fq2_to_biguint_vec(line1.c), - ] - .concat(); - - let vars = chip - .0 - .core - .expr() - .execute([input_line0.clone(), input_line1.clone()].concat(), vec![]); - let output_indices = chip.0.core.expr().builder.output_indices.clone(); - let output = output_indices - .iter() - .map(|i| vars[*i].clone()) - .collect::>(); - assert_eq!(output.len(), 10); - - let r_cmp = Bls12_381::mul_023_by_023(&line0, &line1); - let r_cmp_bigint = r_cmp - .map(|x| [bls12381_fq_to_biguint(x.c0), bls12381_fq_to_biguint(x.c1)]) - .concat(); - - for i in 0..10 { - assert_eq!(output[i], r_cmp_bigint[i]); - } - - let input_line0_limbs = input_line0 - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - let input_line1_limbs = input_line1 - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - - let instruction = rv32_write_heap_default_with_increment( - &mut tester, - input_line0_limbs, - input_line1_limbs, - 512, - chip.0.core.air.offset + PairingOpcode::MUL_023_BY_023 as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -// NOTE[yj]: this test requires `RUST_MIN_STACK=8388608` to run otherwise it will overflow the stack -#[test] -#[ignore] -fn test_mul_by_02345() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapTwoReadsAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = EcLineMulBy02345Chip::new( - adapter, - tester.memory_controller().borrow().range_checker.clone(), - ExprBuilderConfig { - modulus: BLS12_381_MODULUS.clone(), - num_limbs: BLS12_381_NUM_LIMBS, - limb_bits: BLS12_381_LIMB_BITS, - }, - BLS12_381_XI_ISIZE, - PairingOpcode::CLASS_OFFSET, - tester.offline_memory_mutex_arc(), - ); - - let mut rng = StdRng::seed_from_u64(19); - let f = Fq12::random(&mut rng); - let x0 = Fq2::random(&mut rng); - let x2 = Fq2::random(&mut rng); - let x3 = Fq2::random(&mut rng); - let x4 = Fq2::random(&mut rng); - let x5 = Fq2::random(&mut rng); - - let input_f = bls12381_fq12_to_biguint_vec(f); - let input_x = [ - bls12381_fq2_to_biguint_vec(x0), - bls12381_fq2_to_biguint_vec(x2), - bls12381_fq2_to_biguint_vec(x3), - bls12381_fq2_to_biguint_vec(x4), - bls12381_fq2_to_biguint_vec(x5), - ] - .concat(); - - let vars = chip - .0 - .core - .expr() - .execute([input_f.clone(), input_x.clone()].concat(), vec![]); - let output_indices = chip.0.core.expr().builder.output_indices.clone(); - let output = output_indices - .iter() - .map(|i| vars[*i].clone()) - .collect::>(); - assert_eq!(output.len(), 12); - - let r_cmp = Bls12_381::mul_by_02345(&f, &[x0, x2, x3, x4, x5]); - let r_cmp_bigint = bls12381_fq12_to_biguint_vec(r_cmp); - - for i in 0..12 { - assert_eq!(output[i], r_cmp_bigint[i]); - } - - let input_f_limbs = input_f - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - let input_x_limbs = input_x - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - - let instruction = rv32_write_heap_default_with_increment( - &mut tester, - input_f_limbs, - input_x_limbs, - 1024, - chip.0.core.air.offset + PairingOpcode::MUL_BY_02345 as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} diff --git a/extensions/pairing/circuit/src/pairing_chip/line/mod.rs b/extensions/pairing/circuit/src/pairing_chip/line/mod.rs deleted file mode 100644 index acf02c72be..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -mod d_type; -mod evaluate_line; -mod m_type; - -pub use d_type::*; -pub use evaluate_line::*; -pub use m_type::*; diff --git a/extensions/pairing/circuit/src/pairing_chip/miller_double_and_add_step.rs b/extensions/pairing/circuit/src/pairing_chip/miller_double_and_add_step.rs deleted file mode 100644 index 77084428c9..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/miller_double_and_add_step.rs +++ /dev/null @@ -1,215 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_algebra_circuit::Fp2; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -// Input: two AffinePoint: 4 field elements each -// Output: (AffinePoint, UnevaluatedLine, UnevaluatedLine) -> 2*2 + 2*2 + 2*2 = 12 -// field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct MillerDoubleAndAddStepChip< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, ->( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, - > MillerDoubleAndAddStepChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - config: ExprBuilderConfig, - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let expr = miller_double_and_add_step_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![PairingOpcode::MILLER_DOUBLE_AND_ADD_STEP as usize], - vec![], - range_checker, - "MillerDoubleAndAddStep", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -// Ref: openvm_pairing_guest::miller_step -pub fn miller_double_and_add_step_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config, range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut x_s = Fp2::new(builder.clone()); - let mut y_s = Fp2::new(builder.clone()); - let mut x_q = Fp2::new(builder.clone()); - let mut y_q = Fp2::new(builder.clone()); - - // λ1 = (y_s - y_q) / (x_s - x_q) - let mut lambda1 = y_s.sub(&mut y_q).div(&mut x_s.sub(&mut x_q)); - let mut x_sq = lambda1.square().sub(&mut x_s).sub(&mut x_q); - // λ2 = -λ1 - 2y_s / (x_{s+q} - x_s) - let mut lambda2 = lambda1 - .neg() - .sub(&mut y_s.int_mul([2, 0]).div(&mut x_sq.sub(&mut x_s))); - let mut x_sqs = lambda2.square().sub(&mut x_s).sub(&mut x_sq); - let mut y_sqs = lambda2.mul(&mut (x_s.sub(&mut x_sqs))).sub(&mut y_s); - - x_sqs.save_output(); - y_sqs.save_output(); - - let mut b0 = lambda1.neg(); - let mut c0 = lambda1.mul(&mut x_s).sub(&mut y_s); - b0.save_output(); - c0.save_output(); - - let mut b1 = lambda2.neg(); - let mut c1 = lambda2.mul(&mut x_s).sub(&mut y_s); - b1.save_output(); - c1.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} - -#[cfg(test)] -mod tests { - use halo2curves_axiom::bn256::G2Affine; - use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; - use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, - }; - use openvm_ecc_guest::AffinePoint; - use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; - use openvm_mod_circuit_builder::test_utils::{biguint_to_limbs, bn254_fq_to_biguint}; - use openvm_pairing_guest::{ - bn254::BN254_MODULUS, halo2curves_shims::bn254::Bn254, pairing::MillerStep, - }; - use openvm_pairing_transpiler::PairingOpcode; - use openvm_rv32_adapters::{rv32_write_heap_default, Rv32VecHeapAdapterChip}; - use openvm_stark_backend::p3_field::FieldAlgebra; - use openvm_stark_sdk::p3_baby_bear::BabyBear; - use rand::{rngs::StdRng, SeedableRng}; - - use super::*; - - type F = BabyBear; - const NUM_LIMBS: usize = 32; - const LIMB_BITS: usize = 8; - const BLOCK_SIZE: usize = 32; - - #[test] - #[allow(non_snake_case)] - fn test_miller_double_and_add() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = MillerDoubleAndAddStepChip::new( - adapter, - ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - limb_bits: LIMB_BITS, - num_limbs: NUM_LIMBS, - }, - PairingOpcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - - let mut rng0 = StdRng::seed_from_u64(2); - let Q = G2Affine::random(&mut rng0); - let Q2 = G2Affine::random(&mut rng0); - let inputs = [ - Q.x.c0, Q.x.c1, Q.y.c0, Q.y.c1, Q2.x.c0, Q2.x.c1, Q2.y.c0, Q2.y.c1, - ] - .map(bn254_fq_to_biguint); - - let Q_ecpoint = AffinePoint { x: Q.x, y: Q.y }; - let Q_ecpoint2 = AffinePoint { x: Q2.x, y: Q2.y }; - let (Q_daa, l_qa, l_sqs) = Bn254::miller_double_and_add_step(&Q_ecpoint, &Q_ecpoint2); - let result = chip - .0 - .core - .expr() - .execute_with_output(inputs.to_vec(), vec![]); - assert_eq!(result.len(), 12); // AffinePoint and 4 Fp2 coefficients - assert_eq!(result[0], bn254_fq_to_biguint(Q_daa.x.c0)); - assert_eq!(result[1], bn254_fq_to_biguint(Q_daa.x.c1)); - assert_eq!(result[2], bn254_fq_to_biguint(Q_daa.y.c0)); - assert_eq!(result[3], bn254_fq_to_biguint(Q_daa.y.c1)); - assert_eq!(result[4], bn254_fq_to_biguint(l_qa.b.c0)); - assert_eq!(result[5], bn254_fq_to_biguint(l_qa.b.c1)); - assert_eq!(result[6], bn254_fq_to_biguint(l_qa.c.c0)); - assert_eq!(result[7], bn254_fq_to_biguint(l_qa.c.c1)); - assert_eq!(result[8], bn254_fq_to_biguint(l_sqs.b.c0)); - assert_eq!(result[9], bn254_fq_to_biguint(l_sqs.b.c1)); - assert_eq!(result[10], bn254_fq_to_biguint(l_sqs.c.c0)); - assert_eq!(result[11], bn254_fq_to_biguint(l_sqs.c.c1)); - - let input1_limbs = inputs[0..4] - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS) - .map(BabyBear::from_canonical_u32) - }) - .collect::>(); - - let input2_limbs = inputs[4..8] - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS) - .map(BabyBear::from_canonical_u32) - }) - .collect::>(); - - let instruction = rv32_write_heap_default( - &mut tester, - input1_limbs, - input2_limbs, - chip.0.core.air.offset + PairingOpcode::MILLER_DOUBLE_AND_ADD_STEP as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); - } -} diff --git a/extensions/pairing/circuit/src/pairing_chip/miller_double_step.rs b/extensions/pairing/circuit/src/pairing_chip/miller_double_step.rs deleted file mode 100644 index 519eb473a5..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/miller_double_step.rs +++ /dev/null @@ -1,253 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_algebra_circuit::Fp2; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -// Input: AffinePoint: 4 field elements -// Output: (AffinePoint, Fp2, Fp2) -> 8 field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct MillerDoubleStepChip< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, ->( - VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, - > MillerDoubleStepChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - config: ExprBuilderConfig, - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let expr = miller_double_step_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![PairingOpcode::MILLER_DOUBLE_STEP as usize], - vec![], - range_checker, - "MillerDoubleStep", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -// Ref: https://github.com/openvm-org/openvm/blob/f7d6fa7b8ef247e579740eb652fcdf5a04259c28/lib/ecc-execution/src/common/miller_step.rs#L7 -pub fn miller_double_step_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config, range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut x_s = Fp2::new(builder.clone()); - let mut y_s = Fp2::new(builder.clone()); - - let mut three_x_square = x_s.square().int_mul([3, 0]); - let mut lambda = three_x_square.div(&mut y_s.int_mul([2, 0])); - let mut x_2s = lambda.square().sub(&mut x_s.int_mul([2, 0])); - let mut y_2s = lambda.mul(&mut (x_s.sub(&mut x_2s))).sub(&mut y_s); - x_2s.save_output(); - y_2s.save_output(); - - let mut b = lambda.neg(); - let mut c = lambda.mul(&mut x_s).sub(&mut y_s); - b.save_output(); - c.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} - -#[cfg(test)] -mod tests { - use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; - use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, - }; - use openvm_ecc_guest::AffinePoint; - use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; - use openvm_mod_circuit_builder::test_utils::{ - biguint_to_limbs, bls12381_fq_to_biguint, bn254_fq_to_biguint, - }; - use openvm_pairing_guest::{ - bls12_381::{BLS12_381_LIMB_BITS, BLS12_381_MODULUS, BLS12_381_NUM_LIMBS}, - bn254::{BN254_LIMB_BITS, BN254_MODULUS, BN254_NUM_LIMBS}, - halo2curves_shims::{bls12_381::Bls12_381, bn254::Bn254}, - pairing::MillerStep, - }; - use openvm_pairing_transpiler::PairingOpcode; - use openvm_rv32_adapters::{rv32_write_heap_default, Rv32VecHeapAdapterChip}; - use openvm_stark_backend::p3_field::FieldAlgebra; - use openvm_stark_sdk::p3_baby_bear::BabyBear; - use rand::{rngs::StdRng, SeedableRng}; - - use super::*; - - type F = BabyBear; - - #[test] - #[allow(non_snake_case)] - fn test_miller_double_bn254() { - use halo2curves_axiom::bn256::G2Affine; - const NUM_LIMBS: usize = 32; - const LIMB_BITS: usize = 8; - const BLOCK_SIZE: usize = 32; - - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - limb_bits: BN254_LIMB_BITS, - num_limbs: BN254_NUM_LIMBS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = MillerDoubleStepChip::new( - adapter, - config, - PairingOpcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - - let mut rng0 = StdRng::seed_from_u64(2); - let Q = G2Affine::random(&mut rng0); - let inputs = [Q.x.c0, Q.x.c1, Q.y.c0, Q.y.c1].map(bn254_fq_to_biguint); - - let Q_ecpoint = AffinePoint { x: Q.x, y: Q.y }; - let (Q_acc_init, l_init) = Bn254::miller_double_step(&Q_ecpoint); - let result = chip - .0 - .core - .expr() - .execute_with_output(inputs.to_vec(), vec![]); - assert_eq!(result.len(), 8); // AffinePoint and two Fp2 coefficients - assert_eq!(result[0], bn254_fq_to_biguint(Q_acc_init.x.c0)); - assert_eq!(result[1], bn254_fq_to_biguint(Q_acc_init.x.c1)); - assert_eq!(result[2], bn254_fq_to_biguint(Q_acc_init.y.c0)); - assert_eq!(result[3], bn254_fq_to_biguint(Q_acc_init.y.c1)); - assert_eq!(result[4], bn254_fq_to_biguint(l_init.b.c0)); - assert_eq!(result[5], bn254_fq_to_biguint(l_init.b.c1)); - assert_eq!(result[6], bn254_fq_to_biguint(l_init.c.c0)); - assert_eq!(result[7], bn254_fq_to_biguint(l_init.c.c1)); - - let input_limbs = inputs - .map(|x| biguint_to_limbs::(x, LIMB_BITS).map(BabyBear::from_canonical_u32)); - - let instruction = rv32_write_heap_default( - &mut tester, - input_limbs.to_vec(), - vec![], - chip.0.core.air.offset + PairingOpcode::MILLER_DOUBLE_STEP as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); - } - - #[test] - #[allow(non_snake_case)] - fn test_miller_double_bls12_381() { - use halo2curves_axiom::bls12_381::G2Affine; - const NUM_LIMBS: usize = 48; - const LIMB_BITS: usize = 8; - const BLOCK_SIZE: usize = 16; - - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BLS12_381_MODULUS.clone(), - limb_bits: BLS12_381_LIMB_BITS, - num_limbs: BLS12_381_NUM_LIMBS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = MillerDoubleStepChip::new( - adapter, - config, - PairingOpcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - - let mut rng0 = StdRng::seed_from_u64(12); - let Q = G2Affine::random(&mut rng0); - let inputs = [Q.x.c0, Q.x.c1, Q.y.c0, Q.y.c1].map(bls12381_fq_to_biguint); - - let Q_ecpoint = AffinePoint { x: Q.x, y: Q.y }; - let (Q_acc_init, l_init) = Bls12_381::miller_double_step(&Q_ecpoint); - let result = chip - .0 - .core - .expr() - .execute_with_output(inputs.to_vec(), vec![]); - assert_eq!(result.len(), 8); // AffinePoint and two Fp2 coefficients - assert_eq!(result[0], bls12381_fq_to_biguint(Q_acc_init.x.c0)); - assert_eq!(result[1], bls12381_fq_to_biguint(Q_acc_init.x.c1)); - assert_eq!(result[2], bls12381_fq_to_biguint(Q_acc_init.y.c0)); - assert_eq!(result[3], bls12381_fq_to_biguint(Q_acc_init.y.c1)); - assert_eq!(result[4], bls12381_fq_to_biguint(l_init.b.c0)); - assert_eq!(result[5], bls12381_fq_to_biguint(l_init.b.c1)); - assert_eq!(result[6], bls12381_fq_to_biguint(l_init.c.c0)); - assert_eq!(result[7], bls12381_fq_to_biguint(l_init.c.c1)); - - let input_limbs = inputs - .map(|x| biguint_to_limbs::(x, LIMB_BITS).map(BabyBear::from_canonical_u32)); - - let instruction = rv32_write_heap_default( - &mut tester, - input_limbs.to_vec(), - vec![], - chip.0.core.air.offset + PairingOpcode::MILLER_DOUBLE_STEP as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); - } -} diff --git a/extensions/pairing/circuit/src/pairing_chip/mod.rs b/extensions/pairing/circuit/src/pairing_chip/mod.rs deleted file mode 100644 index df00df16ce..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -mod line; -mod miller_double_step; - -pub use line::*; -pub use miller_double_step::*; - -mod miller_double_and_add_step; -pub use miller_double_and_add_step::*; diff --git a/extensions/pairing/circuit/src/pairing_extension.rs b/extensions/pairing/circuit/src/pairing_extension.rs index c75687f404..eacea93859 100644 --- a/extensions/pairing/circuit/src/pairing_extension.rs +++ b/extensions/pairing/circuit/src/pairing_extension.rs @@ -2,13 +2,15 @@ use derive_more::derive::From; use num_bigint::BigUint; use num_traits::{FromPrimitive, Zero}; use openvm_circuit::{ - arch::{VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError}, - system::phantom::PhantomChip, + arch::{ + AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, + ExecutorInventoryBuilder, ExecutorInventoryError, VmCircuitExtension, VmExecutionExtension, + VmProverExtension, + }, + system::phantom::PhantomExecutor, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; -use openvm_circuit_primitives::bitwise_op_lookup::SharedBitwiseOperationLookupChip; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_ecc_circuit::CurveConfig; +use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor}; +use openvm_ecc_circuit::{CurveConfig, SwCurveCoeffs}; use openvm_instructions::PhantomDiscriminant; use openvm_pairing_guest::{ bls12_381::{ @@ -17,12 +19,10 @@ use openvm_pairing_guest::{ bn254::{BN254_ECC_STRUCT_NAME, BN254_MODULUS, BN254_ORDER, BN254_XI_ISIZE}, }; use openvm_pairing_transpiler::PairingPhantom; -use openvm_stark_backend::p3_field::PrimeField32; +use openvm_stark_backend::{config::StarkGenericConfig, engine::StarkEngine, p3_field::Field}; use serde::{Deserialize, Serialize}; use strum::FromRepr; -use super::*; - // All the supported pairing curves. #[derive(Clone, Copy, Debug, FromRepr, Serialize, Deserialize)] #[repr(usize)] @@ -32,21 +32,25 @@ pub enum PairingCurve { } impl PairingCurve { - pub fn curve_config(&self) -> CurveConfig { + pub fn curve_config(&self) -> CurveConfig { match self { PairingCurve::Bn254 => CurveConfig::new( BN254_ECC_STRUCT_NAME.to_string(), BN254_MODULUS.clone(), BN254_ORDER.clone(), - BigUint::zero(), - BigUint::from_u8(3).unwrap(), + SwCurveCoeffs { + a: BigUint::zero(), + b: BigUint::from_u8(3).unwrap(), + }, ), PairingCurve::Bls12_381 => CurveConfig::new( BLS12_381_ECC_STRUCT_NAME.to_string(), BLS12_381_MODULUS.clone(), BLS12_381_ORDER.clone(), - BigUint::zero(), - BigUint::from_u8(4).unwrap(), + SwCurveCoeffs { + a: BigUint::zero(), + b: BigUint::from_u8(4).unwrap(), + }, ), } } @@ -59,43 +63,48 @@ impl PairingCurve { } } -#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)] +#[derive(Clone, Debug, From, derive_new::new, Serialize, Deserialize)] pub struct PairingExtension { pub supported_curves: Vec, } -#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum)] -pub enum PairingExtensionExecutor { - // bn254 (32 limbs) - MillerDoubleAndAddStepRv32_32(MillerDoubleAndAddStepChip), - EvaluateLineRv32_32(EvaluateLineChip), - // bls12-381 (48 limbs) - MillerDoubleAndAddStepRv32_48(MillerDoubleAndAddStepChip), - EvaluateLineRv32_48(EvaluateLineChip), +#[derive(Clone, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)] +pub enum PairingExtensionExecutor { + Phantom(PhantomExecutor), } -#[derive(ChipUsageGetter, Chip, AnyEnum, From)] -pub enum PairingExtensionPeriphery { - BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), - Phantom(PhantomChip), -} - -impl VmExtension for PairingExtension { +impl VmExecutionExtension for PairingExtension { type Executor = PairingExtensionExecutor; - type Periphery = PairingExtensionPeriphery; - fn build( + fn extend_execution( &self, - builder: &mut VmInventoryBuilder, - ) -> Result, VmInventoryError> { - let inventory = VmInventory::new(); - - builder.add_phantom_sub_executor( + inventory: &mut ExecutorInventoryBuilder>, + ) -> Result<(), ExecutorInventoryError> { + inventory.add_phantom_sub_executor( phantom::PairingHintSubEx, PhantomDiscriminant(PairingPhantom::HintFinalExp as u16), )?; + Ok(()) + } +} + +impl VmCircuitExtension for PairingExtension { + fn extend_circuit(&self, _inventory: &mut AirInventory) -> Result<(), AirInventoryError> { + Ok(()) + } +} - Ok(inventory) +pub struct PairingProverExt; +impl VmProverExtension for PairingProverExt +where + E: StarkEngine, +{ + fn extend_prover( + &self, + _: &PairingExtension, + _inventory: &mut ChipInventory, + ) -> Result<(), ChipInventoryError> { + Ok(()) } } @@ -106,7 +115,7 @@ pub(crate) mod phantom { use halo2curves_axiom::ff; use openvm_circuit::{ arch::{PhantomSubExecutor, Streams}, - system::memory::MemoryController, + system::memory::online::GuestMemory, }; use openvm_ecc_guest::{algebra::field::FieldExtension, AffinePoint}; use openvm_instructions::{ @@ -118,53 +127,52 @@ pub(crate) mod phantom { bn254::BN254_NUM_LIMBS, pairing::{FinalExp, MultiMillerLoop}, }; - use openvm_rv32im_circuit::adapters::{compose, unsafe_read_rv32_register}; - use openvm_stark_backend::p3_field::PrimeField32; + use openvm_rv32im_circuit::adapters::{memory_read, read_rv32_register}; + use openvm_stark_backend::p3_field::Field; + use rand::rngs::StdRng; use super::PairingCurve; pub struct PairingHintSubEx; - impl PhantomSubExecutor for PairingHintSubEx { + impl PhantomSubExecutor for PairingHintSubEx { fn phantom_execute( - &mut self, - memory: &MemoryController, + &self, + memory: &GuestMemory, streams: &mut Streams, + _: &mut StdRng, _: PhantomDiscriminant, - a: F, - b: F, + a: u32, + b: u32, c_upper: u16, ) -> eyre::Result<()> { - let rs1 = unsafe_read_rv32_register(memory, a); - let rs2 = unsafe_read_rv32_register(memory, b); + let rs1 = read_rv32_register(memory, a); + let rs2 = read_rv32_register(memory, b); hint_pairing(memory, &mut streams.hint_stream, rs1, rs2, c_upper) } } - fn hint_pairing( - memory: &MemoryController, + fn hint_pairing( + memory: &GuestMemory, hint_stream: &mut VecDeque, rs1: u32, rs2: u32, c_upper: u16, ) -> eyre::Result<()> { - let p_ptr = compose(memory.unsafe_read( - F::from_canonical_u32(RV32_MEMORY_AS), - F::from_canonical_u32(rs1), - )); + let p_ptr = u32::from_le_bytes(memory_read(memory, RV32_MEMORY_AS, rs1)); // len in bytes - let p_len = compose(memory.unsafe_read( - F::from_canonical_u32(RV32_MEMORY_AS), - F::from_canonical_u32(rs1 + RV32_REGISTER_NUM_LIMBS as u32), - )); - let q_ptr = compose(memory.unsafe_read( - F::from_canonical_u32(RV32_MEMORY_AS), - F::from_canonical_u32(rs2), + let p_len = u32::from_le_bytes(memory_read( + memory, + RV32_MEMORY_AS, + rs1 + RV32_REGISTER_NUM_LIMBS as u32, )); + + let q_ptr = u32::from_le_bytes(memory_read(memory, RV32_MEMORY_AS, rs2)); // len in bytes - let q_len = compose(memory.unsafe_read( - F::from_canonical_u32(RV32_MEMORY_AS), - F::from_canonical_u32(rs2 + RV32_REGISTER_NUM_LIMBS as u32), + let q_len = u32::from_le_bytes(memory_read( + memory, + RV32_MEMORY_AS, + rs2 + RV32_REGISTER_NUM_LIMBS as u32, )); match PairingCurve::from_repr(c_upper as usize) { @@ -178,8 +186,8 @@ pub(crate) mod phantom { let p = (0..p_len) .map(|i| -> eyre::Result<_> { let ptr = p_ptr + i * 2 * (N as u32); - let x = read_fp::(memory, ptr)?; - let y = read_fp::(memory, ptr + N as u32)?; + let x = read_fp::(memory, ptr)?; + let y = read_fp::(memory, ptr + N as u32)?; Ok(AffinePoint::new(x, y)) }) .collect::>>()?; @@ -187,8 +195,8 @@ pub(crate) mod phantom { .map(|i| -> eyre::Result<_> { let mut ptr = q_ptr + i * 4 * (N as u32); let mut read_fp2 = || -> eyre::Result<_> { - let c0 = read_fp::(memory, ptr)?; - let c1 = read_fp::(memory, ptr + N as u32)?; + let c0 = read_fp::(memory, ptr)?; + let c1 = read_fp::(memory, ptr + N as u32)?; ptr += 2 * N as u32; Ok(Fq2::new(c0, c1)) }; @@ -220,8 +228,8 @@ pub(crate) mod phantom { let p = (0..p_len) .map(|i| -> eyre::Result<_> { let ptr = p_ptr + i * 2 * (N as u32); - let x = read_fp::(memory, ptr)?; - let y = read_fp::(memory, ptr + N as u32)?; + let x = read_fp::(memory, ptr)?; + let y = read_fp::(memory, ptr + N as u32)?; Ok(AffinePoint::new(x, y)) }) .collect::>>()?; @@ -229,8 +237,8 @@ pub(crate) mod phantom { .map(|i| -> eyre::Result<_> { let mut ptr = q_ptr + i * 4 * (N as u32); let mut read_fp2 = || -> eyre::Result<_> { - let c0 = read_fp::(memory, ptr)?; - let c1 = read_fp::(memory, ptr + N as u32)?; + let c0 = read_fp::(memory, ptr)?; + let c1 = read_fp::(memory, ptr + N as u32)?; ptr += 2 * N as u32; Ok(Fq2 { c0, c1 }) }; @@ -259,24 +267,21 @@ pub(crate) mod phantom { Ok(()) } - fn read_fp( - memory: &MemoryController, + fn read_fp( + memory: &GuestMemory, ptr: u32, ) -> eyre::Result where Fp::Repr: From<[u8; N]>, { - let mut repr = [0u8; N]; - for (i, byte) in repr.iter_mut().enumerate() { - *byte = memory - .unsafe_read_cell( - F::from_canonical_u32(RV32_MEMORY_AS), - F::from_canonical_u32(ptr + i as u32), - ) - .as_canonical_u32() - .try_into()?; - } - Fp::from_repr(repr.into()) + let repr: &[u8; N] = unsafe { + memory + .memory + .get_slice::((RV32_MEMORY_AS, ptr), N) + .try_into() + .unwrap() + }; + Fp::from_repr((*repr).into()) .into_option() .ok_or(eyre::eyre!("bad ff::PrimeField repr")) } diff --git a/extensions/pairing/guest/src/halo2curves_shims/bn254/final_exp.rs b/extensions/pairing/guest/src/halo2curves_shims/bn254/final_exp.rs index f4808e08b6..7cc7e78c40 100644 --- a/extensions/pairing/guest/src/halo2curves_shims/bn254/final_exp.rs +++ b/extensions/pairing/guest/src/halo2curves_shims/bn254/final_exp.rs @@ -2,11 +2,32 @@ use halo2curves_axiom::{ bn256::{Fq, Fq12, Fq2}, ff::Field, }; +use lazy_static::lazy_static; use openvm_ecc_guest::{algebra::field::FieldExtension, AffinePoint}; use super::{Bn254, EXP1, EXP2, M_INV, R_INV, U27_COEFF_0, U27_COEFF_1}; use crate::pairing::{FinalExp, MultiMillerLoop}; +lazy_static! { + pub static ref UNITY_ROOT_27: Fq12 = { + let u0 = U27_COEFF_0.to_u64_digits(); + let u1 = U27_COEFF_1.to_u64_digits(); + let u_coeffs = Fq2::from_coeffs([ + Fq::from_raw([u0[0], u0[1], u0[2], u0[3]]), + Fq::from_raw([u1[0], u1[1], u1[2], u1[3]]), + ]); + Fq12::from_coeffs([ + Fq2::ZERO, + Fq2::ZERO, + u_coeffs, + Fq2::ZERO, + Fq2::ZERO, + Fq2::ZERO, + ]) + }; + pub static ref UNITY_ROOT_27_EXP2: Fq12 = UNITY_ROOT_27.pow(EXP2.to_u64_digits()); +} + #[allow(non_snake_case)] impl FinalExp for Bn254 { type Fp = Fq; @@ -50,21 +71,7 @@ impl FinalExp for Bn254 { // Cubic nonresidue power let u; - // get the 27th root of unity - let u0 = U27_COEFF_0.to_u64_digits(); - let u1 = U27_COEFF_1.to_u64_digits(); - let u_coeffs = Fq2::from_coeffs([ - Fq::from_raw([u0[0], u0[1], u0[2], u0[3]]), - Fq::from_raw([u1[0], u1[1], u1[2], u1[3]]), - ]); - let unity_root_27 = Fq12::from_coeffs([ - Fq2::ZERO, - Fq2::ZERO, - u_coeffs, - Fq2::ZERO, - Fq2::ZERO, - Fq2::ZERO, - ]); + let unity_root_27 = *UNITY_ROOT_27; debug_assert_eq!(unity_root_27.pow([27]), Fq12::one()); if f.pow(EXP1.to_u64_digits()) == Fq12::ONE { @@ -115,8 +122,9 @@ impl FinalExp for Bn254 { tonelli_shanks_loop(&mut x3, &mut tmp, &mut t); + let unity_root_27_exp2 = *UNITY_ROOT_27_EXP2; while t != 0 { - tmp = unity_root_27.pow(EXP2.to_u64_digits()); + tmp = unity_root_27_exp2; x *= tmp; x3 = x.square() * x * c_inv; diff --git a/extensions/pairing/transpiler/Cargo.toml b/extensions/pairing/transpiler/Cargo.toml index a5557b03d1..9ce32bc85c 100644 --- a/extensions/pairing/transpiler/Cargo.toml +++ b/extensions/pairing/transpiler/Cargo.toml @@ -14,4 +14,3 @@ openvm-transpiler = { workspace = true } rrs-lib = { workspace = true } strum = { workspace = true } openvm-pairing-guest = { workspace = true } -openvm-instructions-derive = { workspace = true } diff --git a/extensions/pairing/transpiler/src/lib.rs b/extensions/pairing/transpiler/src/lib.rs index 7777c37c91..e80deaf154 100644 --- a/extensions/pairing/transpiler/src/lib.rs +++ b/extensions/pairing/transpiler/src/lib.rs @@ -1,71 +1,11 @@ use openvm_instructions::{ - instruction::Instruction, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode, PhantomDiscriminant, + instruction::Instruction, riscv::RV32_REGISTER_NUM_LIMBS, PhantomDiscriminant, }; -use openvm_instructions_derive::LocalOpcode; use openvm_pairing_guest::{PairingBaseFunct7, OPCODE, PAIRING_FUNCT3}; use openvm_stark_backend::p3_field::PrimeField32; use openvm_transpiler::{TranspilerExtension, TranspilerOutput}; use rrs_lib::instruction_formats::RType; -use strum::{EnumCount, EnumIter, FromRepr}; - -// NOTE: the following opcodes are enabled only in testing and not enabled in the VM Extension -#[derive( - Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, -)] -#[opcode_offset = 0x750] -#[repr(usize)] -#[allow(non_camel_case_types)] -pub enum PairingOpcode { - MILLER_DOUBLE_AND_ADD_STEP, - MILLER_DOUBLE_STEP, - EVALUATE_LINE, - MUL_013_BY_013, - MUL_023_BY_023, - MUL_BY_01234, - MUL_BY_02345, -} - -// NOTE: Fp12 opcodes are only enabled in testing and not enabled in the VM Extension -#[derive( - Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, -)] -#[opcode_offset = 0x700] -#[repr(usize)] -#[allow(non_camel_case_types)] -pub enum Fp12Opcode { - ADD, - SUB, - MUL, -} -const FP12_OPS: usize = 4; - -pub struct Bn254Fp12Opcode(Fp12Opcode); - -impl LocalOpcode for Bn254Fp12Opcode { - const CLASS_OFFSET: usize = Fp12Opcode::CLASS_OFFSET; - - fn from_usize(value: usize) -> Self { - Self(Fp12Opcode::from_usize(value)) - } - - fn local_usize(&self) -> usize { - self.0.local_usize() - } -} - -pub struct Bls12381Fp12Opcode(Fp12Opcode); - -impl LocalOpcode for Bls12381Fp12Opcode { - const CLASS_OFFSET: usize = Fp12Opcode::CLASS_OFFSET + FP12_OPS; - - fn from_usize(value: usize) -> Self { - Self(Fp12Opcode::from_usize(value - FP12_OPS)) - } - - fn local_usize(&self) -> usize { - self.0.local_usize() + FP12_OPS - } -} +use strum::FromRepr; #[derive(Copy, Clone, Debug, PartialEq, Eq, FromRepr)] #[repr(u16)] diff --git a/extensions/rv32-adapters/Cargo.toml b/extensions/rv32-adapters/Cargo.toml index adf133555b..54ec529e2c 100644 --- a/extensions/rv32-adapters/Cargo.toml +++ b/extensions/rv32-adapters/Cargo.toml @@ -19,9 +19,6 @@ openvm-instructions = { workspace = true } itertools.workspace = true derive-new.workspace = true rand.workspace = true -serde = { workspace = true, features = ["derive"] } -serde-big-array.workspace = true -serde_with.workspace = true [dev-dependencies] openvm-stark-sdk = { workspace = true } diff --git a/extensions/rv32-adapters/src/eq_mod.rs b/extensions/rv32-adapters/src/eq_mod.rs index ab80481f19..0d06ae83e2 100644 --- a/extensions/rv32-adapters/src/eq_mod.rs +++ b/extensions/rv32-adapters/src/eq_mod.rs @@ -1,26 +1,26 @@ use std::{ array::from_fn, borrow::{Borrow, BorrowMut}, - marker::PhantomData, }; use itertools::izip; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller, + BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, + system::memory::{ + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteBytesAuxRecord, }, - program::ProgramBus, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, }; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ @@ -29,16 +29,13 @@ use openvm_instructions::{ riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, }; use openvm_rv32im_circuit::adapters::{ - read_rv32_register, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + tracing_read, tracing_write, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, }; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; -use serde_with::serde_as; /// This adapter reads from NUM_READS <= 2 pointers and writes to a register. /// * The data is read from the heap (address space 2), and the pointers are read from registers @@ -47,7 +44,7 @@ use serde_with::serde_as; /// starting from the addresses in `rs[0]` (and `rs[1]` if `R = 2`). /// * Writes are to 32-bit register rd. #[repr(C)] -#[derive(AlignedBorrow)] +#[derive(AlignedBorrow, Debug)] pub struct Rv32IsEqualModAdapterCols< T, const NUM_READS: usize, @@ -227,209 +224,233 @@ impl< } } -pub struct Rv32IsEqualModAdapterChip< - F: Field, +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32IsEqualModAdapterRecord< + const NUM_READS: usize, + const BLOCKS_PER_READ: usize, + const BLOCK_SIZE: usize, + const TOTAL_READ_SIZE: usize, +> { + pub from_pc: u32, + pub timestamp: u32, + + pub rs_ptr: [u32; NUM_READS], + pub rs_val: [u32; NUM_READS], + pub rs_read_aux: [MemoryReadAuxRecord; NUM_READS], + pub heap_read_aux: [[MemoryReadAuxRecord; BLOCKS_PER_READ]; NUM_READS], + + pub rd_ptr: u32, + pub writes_aux: MemoryWriteBytesAuxRecord, +} + +#[derive(Clone, Copy)] +pub struct Rv32IsEqualModAdapterExecutor< const NUM_READS: usize, const BLOCKS_PER_READ: usize, const BLOCK_SIZE: usize, const TOTAL_READ_SIZE: usize, > { - pub air: Rv32IsEqualModAdapterAir, + pointer_max_bits: usize, +} + +#[derive(derive_new::new)] +pub struct Rv32IsEqualModAdapterFiller< + const NUM_READS: usize, + const BLOCKS_PER_READ: usize, + const BLOCK_SIZE: usize, + const TOTAL_READ_SIZE: usize, +> { + pointer_max_bits: usize, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - _marker: PhantomData, } impl< - F: PrimeField32, const NUM_READS: usize, const BLOCKS_PER_READ: usize, const BLOCK_SIZE: usize, const TOTAL_READ_SIZE: usize, - > Rv32IsEqualModAdapterChip + > Rv32IsEqualModAdapterExecutor { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - address_bits: usize, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - ) -> Self { + pub fn new(pointer_max_bits: usize) -> Self { assert!(NUM_READS <= 2); assert_eq!(TOTAL_READ_SIZE, BLOCKS_PER_READ * BLOCK_SIZE); assert!( - RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS, - "address_bits={address_bits} needs to be large enough for high limb range check" + RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS, + "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check" ); - Self { - air: Rv32IsEqualModAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bus: bitwise_lookup_chip.bus(), - address_bits, - }, - bitwise_lookup_chip, - _marker: PhantomData, - } + Self { pointer_max_bits } } } -#[repr(C)] -#[serde_as] -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct Rv32IsEqualModReadRecord< - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCK_SIZE: usize, -> { - #[serde(with = "BigArray")] - pub rs: [RecordId; NUM_READS], - #[serde_as(as = "[[_; BLOCKS_PER_READ]; NUM_READS]")] - pub reads: [[RecordId; BLOCKS_PER_READ]; NUM_READS], -} - -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct Rv32IsEqualModWriteRecord { - pub from_state: ExecutionState, - pub rd_id: RecordId, -} - impl< F: PrimeField32, const NUM_READS: usize, const BLOCKS_PER_READ: usize, const BLOCK_SIZE: usize, const TOTAL_READ_SIZE: usize, - > VmAdapterChip - for Rv32IsEqualModAdapterChip + > AdapterTraceExecutor + for Rv32IsEqualModAdapterExecutor +where + F: PrimeField32, { - type ReadRecord = Rv32IsEqualModReadRecord; - type WriteRecord = Rv32IsEqualModWriteRecord; - type Air = Rv32IsEqualModAdapterAir; - type Interface = BasicAdapterInterface< - F, - MinimalInstruction, + const WIDTH: usize = + Rv32IsEqualModAdapterCols::::width(); + type ReadData = [[u8; TOTAL_READ_SIZE]; NUM_READS]; + type WriteData = [u8; RV32_REGISTER_NUM_LIMBS]; + type RecordMut<'a> = &'a mut Rv32IsEqualModAdapterRecord< NUM_READS, - 1, + BLOCKS_PER_READ, + BLOCK_SIZE, TOTAL_READ_SIZE, - RV32_REGISTER_NUM_LIMBS, >; - fn preprocess( - &mut self, - memory: &mut MemoryController, + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.timestamp = memory.timestamp; + } + + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { let Instruction { b, c, d, e, .. } = *instruction; debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); - let mut rs_vals = [0; NUM_READS]; - let rs_records: [_; NUM_READS] = from_fn(|i| { - let addr = if i == 0 { b } else { c }; - let (record, val) = read_rv32_register(memory, d, addr); - rs_vals[i] = val; - record - }); + // Read register values + record.rs_val = from_fn(|i| { + record.rs_ptr[i] = if i == 0 { b } else { c }.as_canonical_u32(); - let read_records = rs_vals.map(|address| { - debug_assert!(address < (1 << self.air.address_bits)); - from_fn(|i| { - memory - .read::(e, F::from_canonical_u32(address + (i * BLOCK_SIZE) as u32)) - }) + u32::from_le_bytes(tracing_read( + memory, + RV32_REGISTER_AS, + record.rs_ptr[i], + &mut record.rs_read_aux[i].prev_timestamp, + )) }); - let read_data = read_records.map(|r| { - let read = r.map(|x| x.1); - let mut read_it = read.iter().flatten(); - from_fn(|_| *(read_it.next().unwrap())) - }); - let record = Rv32IsEqualModReadRecord { - rs: rs_records, - reads: read_records.map(|r| r.map(|x| x.0)), - }; - - Ok((read_data, record)) + // Read memory values + from_fn(|i| { + debug_assert!( + record.rs_val[i] as usize + TOTAL_READ_SIZE - 1 < (1 << self.pointer_max_bits) + ); + from_fn::<_, BLOCKS_PER_READ, _>(|j| { + tracing_read::( + memory, + RV32_MEMORY_AS, + record.rs_val[i] + (j * BLOCK_SIZE) as u32, + &mut record.heap_read_aux[i][j].prev_timestamp, + ) + }) + .concat() + .try_into() + .unwrap() + }) } - fn postprocess( - &mut self, - memory: &mut MemoryController, + fn write( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, d, .. } = *instruction; - let (rd_id, _) = memory.write(d, a, output.writes[0]); - - debug_assert!( - memory.timestamp() - from_state.timestamp - == (NUM_READS * (BLOCKS_PER_READ + 1) + 1) as u32, - "timestamp delta is {}, expected {}", - memory.timestamp() - from_state.timestamp, - NUM_READS * (BLOCKS_PER_READ + 1) + 1 + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, + ) { + let Instruction { a, .. } = *instruction; + record.rd_ptr = a.as_canonical_u32(); + tracing_write( + memory, + RV32_REGISTER_AS, + record.rd_ptr, + data, + &mut record.writes_aux.prev_timestamp, + &mut record.writes_aux.prev_data, ); - - Ok(( - ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, rd_id }, - )) } +} - fn generate_trace_row( - &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, - ) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut Rv32IsEqualModAdapterCols = - row_slice.borrow_mut(); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - - let rs = read_record.rs.map(|r| memory.record_by_id(r)); - for (i, r) in rs.iter().enumerate() { - row_slice.rs_ptr[i] = r.pointer; - row_slice.rs_val[i].copy_from_slice(r.data_slice()); - aux_cols_factory.generate_read_aux(r, &mut row_slice.rs_read_aux[i]); - for (j, x) in read_record.reads[i].iter().enumerate() { - let read = memory.record_by_id(*x); - aux_cols_factory.generate_read_aux(read, &mut row_slice.heap_read_aux[i][j]); - } - } - - let rd = memory.record_by_id(write_record.rd_id); - row_slice.rd_ptr = rd.pointer; - aux_cols_factory.generate_write_aux(rd, &mut row_slice.writes_aux); - - // Range checks - let need_range_check: [u32; 2] = from_fn(|i| { - if i < NUM_READS { - rs[i] - .data_at(RV32_REGISTER_NUM_LIMBS - 1) - .as_canonical_u32() +impl< + F: PrimeField32, + const NUM_READS: usize, + const BLOCKS_PER_READ: usize, + const BLOCK_SIZE: usize, + const TOTAL_READ_SIZE: usize, + > AdapterTraceFiller + for Rv32IsEqualModAdapterFiller +{ + const WIDTH: usize = + Rv32IsEqualModAdapterCols::::width(); + + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &Rv32IsEqualModAdapterRecord< + NUM_READS, + BLOCKS_PER_READ, + BLOCK_SIZE, + TOTAL_READ_SIZE, + > = unsafe { get_record_from_slice(&mut adapter_row, ()) }; + + let cols: &mut Rv32IsEqualModAdapterCols = + adapter_row.borrow_mut(); + + let mut timestamp = record.timestamp + (NUM_READS + NUM_READS * BLOCKS_PER_READ) as u32 + 1; + let mut timestamp_mm = || { + timestamp -= 1; + timestamp + }; + // Do range checks before writing anything: + debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); + let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits; + const MSL_SHIFT: usize = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1); + self.bitwise_lookup_chip.request_range( + (record.rs_val[0] >> MSL_SHIFT) << limb_shift_bits, + if NUM_READS > 1 { + (record.rs_val[1] >> MSL_SHIFT) << limb_shift_bits } else { 0 - } - }); - let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.air.address_bits; - self.bitwise_lookup_chip.request_range( - need_range_check[0] << limb_shift_bits, - need_range_check[1] << limb_shift_bits, + }, ); - } + // Writing in reverse order + cols.writes_aux + .set_prev_data(record.writes_aux.prev_data.map(F::from_canonical_u8)); + mem_helper.fill( + record.writes_aux.prev_timestamp, + timestamp_mm(), + cols.writes_aux.as_mut(), + ); + cols.rd_ptr = F::from_canonical_u32(record.rd_ptr); + + // **NOTE**: Must iterate everything in reverse order to avoid overwriting the records + cols.heap_read_aux + .iter_mut() + .rev() + .zip(record.heap_read_aux.iter().rev()) + .for_each(|(col_reads, record_reads)| { + col_reads + .iter_mut() + .rev() + .zip(record_reads.iter().rev()) + .for_each(|(col, record)| { + mem_helper.fill(record.prev_timestamp, timestamp_mm(), col.as_mut()); + }); + }); + + cols.rs_read_aux + .iter_mut() + .rev() + .zip(record.rs_read_aux.iter().rev()) + .for_each(|(col, record)| { + mem_helper.fill(record.prev_timestamp, timestamp_mm(), col.as_mut()); + }); + + cols.rs_val = record + .rs_val + .map(|val| val.to_le_bytes().map(F::from_canonical_u8)); + cols.rs_ptr = record.rs_ptr.map(|ptr| F::from_canonical_u32(ptr)); - fn air(&self) -> &Self::Air { - &self.air + cols.from_state.timestamp = F::from_canonical_u32(record.timestamp); + cols.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/rv32-adapters/src/heap.rs b/extensions/rv32-adapters/src/heap.rs index cd9f93abbc..10409d97e9 100644 --- a/extensions/rv32-adapters/src/heap.rs +++ b/extensions/rv32-adapters/src/heap.rs @@ -1,38 +1,28 @@ -use std::{ - array::{self, from_fn}, - borrow::Borrow, - marker::PhantomData, -}; +use std::borrow::Borrow; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, - }, - system::{ - memory::{offline_checker::MemoryBridge, MemoryController, OfflineMemory}, - program::ProgramBus, + AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller, BasicAdapterInterface, + ExecutionBridge, MinimalInstruction, VmAdapterAir, }, + system::memory::{offline_checker::MemoryBridge, online::TracingMemory, MemoryAuxColsFactory}, }; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_instructions::{ instruction::Instruction, - program::DEFAULT_PC_STEP, - riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, }; -use openvm_rv32im_circuit::adapters::read_rv32_register; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, PrimeField32}, }; -use super::{ - vec_heap_generate_trace_row_impl, Rv32VecHeapAdapterAir, Rv32VecHeapAdapterCols, - Rv32VecHeapReadRecord, Rv32VecHeapWriteRecord, +use crate::{ + Rv32VecHeapAdapterAir, Rv32VecHeapAdapterCols, Rv32VecHeapAdapterExecutor, + Rv32VecHeapAdapterFiller, Rv32VecHeapAdapterRecord, }; /// This adapter reads from NUM_READS <= 2 pointers and writes to 1 pointer. @@ -101,137 +91,95 @@ impl< } } -pub struct Rv32HeapAdapterChip< - F: Field, +#[derive(Clone, Copy)] +pub struct Rv32HeapAdapterExecutor< const NUM_READS: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, -> { - pub air: Rv32HeapAdapterAir, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - _marker: PhantomData, +>(Rv32VecHeapAdapterExecutor); + +impl + Rv32HeapAdapterExecutor +{ + pub fn new(pointer_max_bits: usize) -> Self { + assert!(NUM_READS <= 2); + assert!( + RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS, + "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check" + ); + Rv32HeapAdapterExecutor(Rv32VecHeapAdapterExecutor::new(pointer_max_bits)) + } } -impl - Rv32HeapAdapterChip +pub struct Rv32HeapAdapterFiller< + const NUM_READS: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, +>(Rv32VecHeapAdapterFiller); + +impl + Rv32HeapAdapterFiller { pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - address_bits: usize, + pointer_max_bits: usize, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, ) -> Self { assert!(NUM_READS <= 2); assert!( - RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS, - "address_bits={address_bits} needs to be large enough for high limb range check" + RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS, + "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check" ); - Self { - air: Rv32HeapAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bus: bitwise_lookup_chip.bus(), - address_bits, - }, + Rv32HeapAdapterFiller(Rv32VecHeapAdapterFiller::new( + pointer_max_bits, bitwise_lookup_chip, - _marker: PhantomData, - } + )) } } impl - VmAdapterChip for Rv32HeapAdapterChip + AdapterTraceExecutor for Rv32HeapAdapterExecutor +where + F: PrimeField32, { - type ReadRecord = Rv32VecHeapReadRecord; - type WriteRecord = Rv32VecHeapWriteRecord<1, WRITE_SIZE>; - type Air = Rv32HeapAdapterAir; - type Interface = - BasicAdapterInterface, NUM_READS, 1, READ_SIZE, WRITE_SIZE>; - - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { a, b, c, d, e, .. } = *instruction; - - debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); - - let mut rs_vals = [0; NUM_READS]; - let rs_records: [_; NUM_READS] = from_fn(|i| { - let addr = if i == 0 { b } else { c }; - let (record, val) = read_rv32_register(memory, d, addr); - rs_vals[i] = val; - record - }); - let (rd_record, rd_val) = read_rv32_register(memory, d, a); - - let read_records = rs_vals.map(|address| { - debug_assert!(address as usize + READ_SIZE - 1 < (1 << self.air.address_bits)); - [memory.read::(e, F::from_canonical_u32(address))] - }); - let read_data = read_records.map(|r| r[0].1); - - let record = Rv32VecHeapReadRecord { - rs: rs_records, - rd: rd_record, - rd_val: F::from_canonical_u32(rd_val), - reads: read_records.map(|r| array::from_fn(|i| r[i].0)), - }; - - Ok((read_data, record)) + const WIDTH: usize = + Rv32VecHeapAdapterCols::::width(); + type ReadData = [[u8; READ_SIZE]; NUM_READS]; + type WriteData = [[u8; WRITE_SIZE]; 1]; + type RecordMut<'a> = &'a mut Rv32VecHeapAdapterRecord; + + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; } - fn postprocess( - &mut self, - memory: &mut MemoryController, + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let e = instruction.e; - let writes = [memory.write(e, read_record.rd_val, output.writes[0]).0]; - - let timestamp_delta = memory.timestamp() - from_state.timestamp; - debug_assert!( - timestamp_delta == 6, - "timestamp delta is {}, expected 6", - timestamp_delta - ); - - Ok(( - ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, writes }, - )) + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + let read_data = AdapterTraceExecutor::::read(&self.0, memory, instruction, record); + read_data.map(|r| r[0]) } - fn generate_trace_row( + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, ) { - vec_heap_generate_trace_row_impl( - row_slice, - &read_record, - &write_record, - self.bitwise_lookup_chip.clone(), - self.air.address_bits, - memory, - ); + AdapterTraceExecutor::::write(&self.0, memory, instruction, data, record); } +} + +impl + AdapterTraceFiller for Rv32HeapAdapterFiller +{ + const WIDTH: usize = + Rv32VecHeapAdapterCols::::width(); - fn air(&self) -> &Self::Air { - &self.air + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, adapter_row: &mut [F]) { + AdapterTraceFiller::::fill_trace_row(&self.0, mem_helper, adapter_row); } } diff --git a/extensions/rv32-adapters/src/heap_branch.rs b/extensions/rv32-adapters/src/heap_branch.rs index 29c9a151c9..3585e5e91f 100644 --- a/extensions/rv32-adapters/src/heap_branch.rs +++ b/extensions/rv32-adapters/src/heap_branch.rs @@ -1,27 +1,23 @@ use std::{ array::from_fn, borrow::{Borrow, BorrowMut}, - iter::once, - marker::PhantomData, }; use itertools::izip; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, ImmInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller, + BasicAdapterInterface, ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord}, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, }; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ @@ -29,16 +25,12 @@ use openvm_instructions::{ program::DEFAULT_PC_STEP, riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, }; -use openvm_rv32im_circuit::adapters::{ - read_rv32_register, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, -}; +use openvm_rv32im_circuit::adapters::{tracing_read, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; /// This adapter reads from NUM_READS <= 2 pointers. /// * The data is read from the heap (address space 2), and the pointers are read from registers @@ -170,158 +162,162 @@ impl VmA } } -pub struct Rv32HeapBranchAdapterChip { - pub air: Rv32HeapBranchAdapterAir, +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32HeapBranchAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, + + pub rs_ptr: [u32; NUM_READS], + pub rs_vals: [u32; NUM_READS], + + pub rs_read_aux: [MemoryReadAuxRecord; NUM_READS], + pub heap_read_aux: [MemoryReadAuxRecord; NUM_READS], +} + +#[derive(Clone, Copy)] +pub struct Rv32HeapBranchAdapterExecutor { + pub pointer_max_bits: usize, +} + +#[derive(derive_new::new)] +pub struct Rv32HeapBranchAdapterFiller { + pub pointer_max_bits: usize, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - _marker: PhantomData, } -impl - Rv32HeapBranchAdapterChip +impl + Rv32HeapBranchAdapterExecutor { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - address_bits: usize, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - ) -> Self { + pub fn new(pointer_max_bits: usize) -> Self { assert!(NUM_READS <= 2); assert!( - RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS, - "address_bits={address_bits} needs to be large enough for high limb range check" + RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS, + "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check" ); - Self { - air: Rv32HeapBranchAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bus: bitwise_lookup_chip.bus(), - address_bits, - }, - bitwise_lookup_chip, - _marker: PhantomData, - } + Self { pointer_max_bits } } } -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Rv32HeapBranchReadRecord { - #[serde(with = "BigArray")] - pub rs_reads: [RecordId; NUM_READS], - #[serde(with = "BigArray")] - pub heap_reads: [RecordId; NUM_READS], -} - -impl VmAdapterChip - for Rv32HeapBranchAdapterChip +impl AdapterTraceExecutor + for Rv32HeapBranchAdapterExecutor { - type ReadRecord = Rv32HeapBranchReadRecord; - type WriteRecord = ExecutionState; - type Air = Rv32HeapBranchAdapterAir; - type Interface = BasicAdapterInterface, NUM_READS, 0, READ_SIZE, 0>; - - fn preprocess( - &mut self, - memory: &mut MemoryController, + const WIDTH: usize = Rv32HeapBranchAdapterCols::::width(); + type ReadData = [[u8; READ_SIZE]; NUM_READS]; + type WriteData = (); + type RecordMut<'a> = &'a mut Rv32HeapBranchAdapterRecord; + + fn start(pc: u32, memory: &TracingMemory, adapter_record: &mut Self::RecordMut<'_>) { + adapter_record.from_pc = pc; + adapter_record.from_timestamp = memory.timestamp; + } + + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { let Instruction { a, b, d, e, .. } = *instruction; debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); - let mut rs_vals = [0; NUM_READS]; - let rs_records: [_; NUM_READS] = from_fn(|i| { - let addr = if i == 0 { a } else { b }; - let (record, val) = read_rv32_register(memory, d, addr); - rs_vals[i] = val; - record - }); - - let heap_records = rs_vals.map(|address| { - assert!(address as usize + READ_SIZE - 1 < (1 << self.air.address_bits)); - memory.read::(e, F::from_canonical_u32(address)) + // Read register values + record.rs_vals = from_fn(|i| { + record.rs_ptr[i] = if i == 0 { a } else { b }.as_canonical_u32(); + u32::from_le_bytes(tracing_read( + memory, + RV32_REGISTER_AS, + record.rs_ptr[i], + &mut record.rs_read_aux[i].prev_timestamp, + )) }); - let record = Rv32HeapBranchReadRecord { - rs_reads: rs_records, - heap_reads: heap_records.map(|r| r.0), - }; - Ok((heap_records.map(|r| r.1), record)) + // Read memory values + from_fn(|i| { + debug_assert!( + record.rs_vals[i] as usize + READ_SIZE - 1 < (1 << self.pointer_max_bits) + ); + tracing_read( + memory, + RV32_MEMORY_AS, + record.rs_vals[i], + &mut record.heap_read_aux[i].prev_timestamp, + ) + }) } - fn postprocess( - &mut self, - memory: &mut MemoryController, + fn write( + &self, + _memory: &mut TracingMemory, _instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let timestamp_delta = memory.timestamp() - from_state.timestamp; - debug_assert!( - timestamp_delta == 4, - "timestamp delta is {}, expected 4", - timestamp_delta - ); - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - from_state, - )) + _data: Self::WriteData, + _record: &mut Self::RecordMut<'_>, + ) { + // This adapter doesn't write anything } +} - fn generate_trace_row( - &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, - ) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = - row_slice.borrow_mut(); - row_slice.from_state = write_record.map(F::from_canonical_u32); +impl AdapterTraceFiller + for Rv32HeapBranchAdapterFiller +{ + const WIDTH: usize = Rv32HeapBranchAdapterCols::::width(); - let rs_reads = read_record.rs_reads.map(|r| memory.record_by_id(r)); + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &Rv32HeapBranchAdapterRecord = + unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let cols: &mut Rv32HeapBranchAdapterCols = + adapter_row.borrow_mut(); - for (i, rs_read) in rs_reads.iter().enumerate() { - row_slice.rs_ptr[i] = rs_read.pointer; - row_slice.rs_val[i].copy_from_slice(rs_read.data_slice()); - aux_cols_factory.generate_read_aux(rs_read, &mut row_slice.rs_read_aux[i]); - } + // Range checks: + // **NOTE**: Must do the range checks before overwriting the records + debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); + let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits; + const MSL_SHIFT: usize = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1); + self.bitwise_lookup_chip.request_range( + (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits, + if NUM_READS > 1 { + (record.rs_vals[1] >> MSL_SHIFT) << limb_shift_bits + } else { + 0 + }, + ); - for (i, heap_read) in read_record.heap_reads.iter().enumerate() { - let record = memory.record_by_id(*heap_read); - aux_cols_factory.generate_read_aux(record, &mut row_slice.heap_read_aux[i]); + // **NOTE**: Must iterate everything in reverse order to avoid overwriting the records + for i in (0..NUM_READS).rev() { + mem_helper.fill( + record.heap_read_aux[i].prev_timestamp, + record.from_timestamp + (i + NUM_READS) as u32, + cols.heap_read_aux[i].as_mut(), + ); } - // Range checks: - let need_range_check: Vec = rs_reads - .iter() - .map(|record| { - record - .data_at(RV32_REGISTER_NUM_LIMBS - 1) - .as_canonical_u32() - }) - .chain(once(0)) // in case NUM_READS is odd - .collect(); - debug_assert!(self.air.address_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); - let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.air.address_bits; - for pair in need_range_check.chunks_exact(2) { - self.bitwise_lookup_chip - .request_range(pair[0] << limb_shift_bits, pair[1] << limb_shift_bits); + for i in (0..NUM_READS).rev() { + mem_helper.fill( + record.rs_read_aux[i].prev_timestamp, + record.from_timestamp + i as u32, + cols.rs_read_aux[i].as_mut(), + ); } - } - fn air(&self) -> &Self::Air { - &self.air + cols.rs_val + .iter_mut() + .rev() + .zip(record.rs_vals.iter().rev()) + .for_each(|(col, record)| { + *col = record.to_le_bytes().map(F::from_canonical_u8); + }); + + cols.rs_ptr + .iter_mut() + .rev() + .zip(record.rs_ptr.iter().rev()) + .for_each(|(col, record)| { + *col = F::from_canonical_u32(*record); + }); + + cols.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + cols.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/rv32-adapters/src/lib.rs b/extensions/rv32-adapters/src/lib.rs index d84c82f617..6d884daedf 100644 --- a/extensions/rv32-adapters/src/lib.rs +++ b/extensions/rv32-adapters/src/lib.rs @@ -2,13 +2,11 @@ mod eq_mod; mod heap; mod heap_branch; mod vec_heap; -mod vec_heap_two_reads; pub use eq_mod::*; pub use heap::*; pub use heap_branch::*; pub use vec_heap::*; -pub use vec_heap_two_reads::*; #[cfg(any(test, feature = "test-utils"))] mod test_utils; diff --git a/extensions/rv32-adapters/src/vec_heap.rs b/extensions/rv32-adapters/src/vec_heap.rs index fab0df3334..ea3fc80113 100644 --- a/extensions/rv32-adapters/src/vec_heap.rs +++ b/extensions/rv32-adapters/src/vec_heap.rs @@ -2,25 +2,26 @@ use std::{ array::from_fn, borrow::{Borrow, BorrowMut}, iter::{once, zip}, - marker::PhantomData, }; use itertools::izip; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, ExecutionBridge, ExecutionBus, ExecutionState, - Result, VecHeapAdapterInterface, VmAdapterAir, VmAdapterChip, VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller, + ExecutionBridge, ExecutionState, VecHeapAdapterInterface, VmAdapterAir, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, + system::memory::{ + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteBytesAuxRecord, }, - program::ProgramBus, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, }; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ @@ -29,15 +30,13 @@ use openvm_instructions::{ riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, }; use openvm_rv32im_circuit::adapters::{ - abstract_compose, read_rv32_register, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + abstract_compose, tracing_read, tracing_write, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, }; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -use serde_with::serde_as; /// This adapter reads from R (R <= 2) pointers and writes to 1 pointer. /// * The data is read from the heap (address space 2), and the pointers are read from registers @@ -46,89 +45,8 @@ use serde_with::serde_as; /// starting from the addresses in `rs[0]` (and `rs[1]` if `R = 2`). /// * Writes take the form of `BLOCKS_PER_WRITE` consecutive writes of size `WRITE_SIZE` to the /// heap, starting from the address in `rd`. -#[derive(Clone)] -pub struct Rv32VecHeapAdapterChip< - F: Field, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, -> { - pub air: - Rv32VecHeapAdapterAir, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - _marker: PhantomData, -} - -impl< - F: PrimeField32, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > - Rv32VecHeapAdapterChip -{ - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - address_bits: usize, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - ) -> Self { - assert!(NUM_READS <= 2); - assert!( - RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS, - "address_bits={address_bits} needs to be large enough for high limb range check" - ); - Self { - air: Rv32VecHeapAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bus: bitwise_lookup_chip.bus(), - address_bits, - }, - bitwise_lookup_chip, - _marker: PhantomData, - } - } -} - #[repr(C)] -#[serde_as] -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -#[serde(bound = "F: Field")] -pub struct Rv32VecHeapReadRecord< - F: Field, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const READ_SIZE: usize, -> { - /// Read register value from address space e=1 - #[serde_as(as = "[_; NUM_READS]")] - pub rs: [RecordId; NUM_READS], - /// Read register value from address space d=1 - pub rd: RecordId, - - pub rd_val: F, - - #[serde_as(as = "[[_; BLOCKS_PER_READ]; NUM_READS]")] - pub reads: [[RecordId; BLOCKS_PER_READ]; NUM_READS], -} - -#[repr(C)] -#[serde_as] -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct Rv32VecHeapWriteRecord { - pub from_state: ExecutionState, - #[serde_as(as = "[_; BLOCKS_PER_WRITE]")] - pub writes: [RecordId; BLOCKS_PER_WRITE], -} - -#[repr(C)] -#[derive(AlignedBorrow)] +#[derive(AlignedBorrow, Debug)] pub struct Rv32VecHeapAdapterCols< T, const NUM_READS: usize, @@ -346,6 +264,55 @@ impl< } } +// Intermediate type that should not be copied or cloned and should be directly written to +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32VecHeapAdapterRecord< + const NUM_READS: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, +> { + pub from_pc: u32, + pub from_timestamp: u32, + + pub rs_ptrs: [u32; NUM_READS], + pub rd_ptr: u32, + + pub rs_vals: [u32; NUM_READS], + pub rd_val: u32, + + pub rs_read_aux: [MemoryReadAuxRecord; NUM_READS], + pub rd_read_aux: MemoryReadAuxRecord, + + pub reads_aux: [[MemoryReadAuxRecord; BLOCKS_PER_READ]; NUM_READS], + pub writes_aux: [MemoryWriteBytesAuxRecord; BLOCKS_PER_WRITE], +} + +#[derive(derive_new::new, Clone, Copy)] +pub struct Rv32VecHeapAdapterExecutor< + const NUM_READS: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, +> { + pointer_max_bits: usize, +} + +#[derive(derive_new::new)] +pub struct Rv32VecHeapAdapterFiller< + const NUM_READS: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, +> { + pointer_max_bits: usize, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, +} + impl< F: PrimeField32, const NUM_READS: usize, @@ -353,9 +320,8 @@ impl< const BLOCKS_PER_WRITE: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, - > VmAdapterChip - for Rv32VecHeapAdapterChip< - F, + > AdapterTraceExecutor + for Rv32VecHeapAdapterExecutor< NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, @@ -363,184 +329,246 @@ impl< WRITE_SIZE, > { - type ReadRecord = Rv32VecHeapReadRecord; - type WriteRecord = Rv32VecHeapWriteRecord; - type Air = - Rv32VecHeapAdapterAir; - type Interface = VecHeapAdapterInterface< + const WIDTH: usize = Rv32VecHeapAdapterCols::< F, NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE, + >::width(); + type ReadData = [[[u8; READ_SIZE]; BLOCKS_PER_READ]; NUM_READS]; + type WriteData = [[u8; WRITE_SIZE]; BLOCKS_PER_WRITE]; + type RecordMut<'a> = &'a mut Rv32VecHeapAdapterRecord< + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, >; - fn preprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; + } + + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { a, b, c, d, e, .. } = *instruction; + record: &mut &mut Rv32VecHeapAdapterRecord< + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + >, + ) -> Self::ReadData { + let &Instruction { a, b, c, d, e, .. } = instruction; debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); // Read register values - let mut rs_vals = [0; NUM_READS]; - let rs_records: [_; NUM_READS] = from_fn(|i| { - let addr = if i == 0 { b } else { c }; - let (record, val) = read_rv32_register(memory, d, addr); - rs_vals[i] = val; - record + record.rs_vals = from_fn(|i| { + record.rs_ptrs[i] = if i == 0 { b } else { c }.as_canonical_u32(); + u32::from_le_bytes(tracing_read( + memory, + RV32_REGISTER_AS, + record.rs_ptrs[i], + &mut record.rs_read_aux[i].prev_timestamp, + )) }); - let (rd_record, rd_val) = read_rv32_register(memory, d, a); + + record.rd_ptr = a.as_canonical_u32(); + record.rd_val = u32::from_le_bytes(tracing_read( + memory, + RV32_REGISTER_AS, + a.as_canonical_u32(), + &mut record.rd_read_aux.prev_timestamp, + )); // Read memory values - let read_records = rs_vals.map(|address| { - assert!( - address as usize + READ_SIZE * BLOCKS_PER_READ - 1 < (1 << self.air.address_bits) + from_fn(|i| { + debug_assert!( + (record.rs_vals[i] + (READ_SIZE * BLOCKS_PER_READ - 1) as u32) + < (1 << self.pointer_max_bits) as u32 ); - from_fn(|i| { - memory.read::(e, F::from_canonical_u32(address + (i * READ_SIZE) as u32)) + from_fn(|j| { + tracing_read( + memory, + RV32_MEMORY_AS, + record.rs_vals[i] + (j * READ_SIZE) as u32, + &mut record.reads_aux[i][j].prev_timestamp, + ) }) - }); - let read_data = read_records.map(|r| r.map(|x| x.1)); - assert!(rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1 < (1 << self.air.address_bits)); - - let record = Rv32VecHeapReadRecord { - rs: rs_records, - rd: rd_record, - rd_val: F::from_canonical_u32(rd_val), - reads: read_records.map(|r| r.map(|x| x.0)), - }; - - Ok((read_data, record)) - } - - fn postprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let e = instruction.e; - let mut i = 0; - let writes = output.writes.map(|write| { - let (record_id, _) = memory.write( - e, - read_record.rd_val + F::from_canonical_u32((i * WRITE_SIZE) as u32), - write, - ); - i += 1; - record_id - }); - - Ok(( - ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, writes }, - )) + }) } - fn generate_trace_row( + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut &mut Rv32VecHeapAdapterRecord< + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + >, ) { - vec_heap_generate_trace_row_impl( - row_slice, - &read_record, - &write_record, - self.bitwise_lookup_chip.clone(), - self.air.address_bits, - memory, - ) - } + debug_assert_eq!(instruction.e.as_canonical_u32(), RV32_MEMORY_AS); + + debug_assert!( + record.rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1 + < (1 << self.pointer_max_bits) + ); - fn air(&self) -> &Self::Air { - &self.air + #[allow(clippy::needless_range_loop)] + for i in 0..BLOCKS_PER_WRITE { + tracing_write( + memory, + RV32_MEMORY_AS, + record.rd_val + (i * WRITE_SIZE) as u32, + data[i], + &mut record.writes_aux[i].prev_timestamp, + &mut record.writes_aux[i].prev_data, + ); + } } } -pub(super) fn vec_heap_generate_trace_row_impl< - F: PrimeField32, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, ->( - row_slice: &mut [F], - read_record: &Rv32VecHeapReadRecord, - write_record: &Rv32VecHeapWriteRecord, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - address_bits: usize, - memory: &OfflineMemory, -) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut Rv32VecHeapAdapterCols< +impl< + F: PrimeField32, + const NUM_READS: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, + > AdapterTraceFiller + for Rv32VecHeapAdapterFiller< + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > +{ + const WIDTH: usize = Rv32VecHeapAdapterCols::< F, NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE, - > = row_slice.borrow_mut(); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - - let rd = memory.record_by_id(read_record.rd); - let rs = read_record - .rs - .into_iter() - .map(|r| memory.record_by_id(r)) - .collect::>(); - - row_slice.rd_ptr = rd.pointer; - row_slice.rd_val.copy_from_slice(rd.data_slice()); - - for (i, r) in rs.iter().enumerate() { - row_slice.rs_ptr[i] = r.pointer; - row_slice.rs_val[i].copy_from_slice(r.data_slice()); - aux_cols_factory.generate_read_aux(r, &mut row_slice.rs_read_aux[i]); - } + >::width(); - aux_cols_factory.generate_read_aux(rd, &mut row_slice.rd_read_aux); + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &Rv32VecHeapAdapterRecord< + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > = unsafe { get_record_from_slice(&mut adapter_row, ()) }; - for (i, reads) in read_record.reads.iter().enumerate() { - for (j, &x) in reads.iter().enumerate() { - let record = memory.record_by_id(x); - aux_cols_factory.generate_read_aux(record, &mut row_slice.reads_aux[i][j]); + let cols: &mut Rv32VecHeapAdapterCols< + F, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > = adapter_row.borrow_mut(); + + // Range checks: + // **NOTE**: Must do the range checks before overwriting the records + debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); + let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits; + const MSL_SHIFT: usize = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1); + if NUM_READS > 1 { + self.bitwise_lookup_chip.request_range( + (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits, + (record.rs_vals[1] >> MSL_SHIFT) << limb_shift_bits, + ); + self.bitwise_lookup_chip.request_range( + (record.rd_val >> MSL_SHIFT) << limb_shift_bits, + (record.rd_val >> MSL_SHIFT) << limb_shift_bits, + ); + } else { + self.bitwise_lookup_chip.request_range( + (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits, + (record.rd_val >> MSL_SHIFT) << limb_shift_bits, + ); } - } - for (i, &w) in write_record.writes.iter().enumerate() { - let record = memory.record_by_id(w); - aux_cols_factory.generate_write_aux(record, &mut row_slice.writes_aux[i]); - } + let timestamp_delta = NUM_READS + 1 + NUM_READS * BLOCKS_PER_READ + BLOCKS_PER_WRITE; + let mut timestamp = record.from_timestamp + timestamp_delta as u32; + let mut timestamp_mm = || { + timestamp -= 1; + timestamp + }; - // Range checks: - let need_range_check: Vec = rs - .iter() - .chain(std::iter::repeat_n(&rd, 2)) - .map(|record| { - record - .data_at(RV32_REGISTER_NUM_LIMBS - 1) - .as_canonical_u32() - }) - .collect(); - debug_assert!(address_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); - let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits; - for pair in need_range_check.chunks_exact(2) { - bitwise_lookup_chip.request_range(pair[0] << limb_shift_bits, pair[1] << limb_shift_bits); + // **NOTE**: Must iterate everything in reverse order to avoid overwriting the records + record + .writes_aux + .iter() + .rev() + .zip(cols.writes_aux.iter_mut().rev()) + .for_each(|(write, cols_write)| { + cols_write.set_prev_data(write.prev_data.map(F::from_canonical_u8)); + mem_helper.fill(write.prev_timestamp, timestamp_mm(), cols_write.as_mut()); + }); + + record + .reads_aux + .iter() + .zip(cols.reads_aux.iter_mut()) + .rev() + .for_each(|(reads, cols_reads)| { + reads + .iter() + .zip(cols_reads.iter_mut()) + .rev() + .for_each(|(read, cols_read)| { + mem_helper.fill(read.prev_timestamp, timestamp_mm(), cols_read.as_mut()); + }); + }); + + mem_helper.fill( + record.rd_read_aux.prev_timestamp, + timestamp_mm(), + cols.rd_read_aux.as_mut(), + ); + + record + .rs_read_aux + .iter() + .zip(cols.rs_read_aux.iter_mut()) + .rev() + .for_each(|(aux, cols_aux)| { + mem_helper.fill(aux.prev_timestamp, timestamp_mm(), cols_aux.as_mut()); + }); + + cols.rd_val = record.rd_val.to_le_bytes().map(F::from_canonical_u8); + cols.rs_val + .iter_mut() + .rev() + .zip(record.rs_vals.iter().rev()) + .for_each(|(cols_val, val)| { + *cols_val = val.to_le_bytes().map(F::from_canonical_u8); + }); + cols.rd_ptr = F::from_canonical_u32(record.rd_ptr); + cols.rs_ptr + .iter_mut() + .rev() + .zip(record.rs_ptrs.iter().rev()) + .for_each(|(cols_ptr, ptr)| { + *cols_ptr = F::from_canonical_u32(*ptr); + }); + cols.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + cols.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/rv32-adapters/src/vec_heap_two_reads.rs b/extensions/rv32-adapters/src/vec_heap_two_reads.rs deleted file mode 100644 index f829db8bbc..0000000000 --- a/extensions/rv32-adapters/src/vec_heap_two_reads.rs +++ /dev/null @@ -1,577 +0,0 @@ -use std::{ - array::from_fn, - borrow::{Borrow, BorrowMut}, - iter::zip, - marker::PhantomData, -}; - -use itertools::izip; -use openvm_circuit::{ - arch::{ - AdapterAirContext, AdapterRuntimeContext, ExecutionBridge, ExecutionBus, ExecutionState, - Result, VecHeapTwoReadsAdapterInterface, VmAdapterAir, VmAdapterChip, VmAdapterInterface, - }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, - }, -}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, -}; -use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{ - instruction::Instruction, - program::DEFAULT_PC_STEP, - riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, -}; -use openvm_rv32im_circuit::adapters::{ - abstract_compose, read_rv32_register, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, -}; -use openvm_stark_backend::{ - interaction::InteractionBuilder, - p3_air::BaseAir, - p3_field::{Field, FieldAlgebra, PrimeField32}, -}; -use serde::{Deserialize, Serialize}; -use serde_with::serde_as; - -/// This adapter reads from 2 pointers and writes to 1 pointer. -/// * The data is read from the heap (address space 2), and the pointers are read from registers -/// (address space 1). -/// * Reads take the form of `BLOCKS_PER_READX` consecutive reads of size `READ_SIZE` from the heap, -/// starting from the addresses in `rs[X]` -/// * NOTE that the two reads can read different numbers of blocks. -/// * Writes take the form of `BLOCKS_PER_WRITE` consecutive writes of size `WRITE_SIZE` to the -/// heap, starting from the address in `rd`. -pub struct Rv32VecHeapTwoReadsAdapterChip< - F: Field, - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, -> { - pub air: Rv32VecHeapTwoReadsAdapterAir< - BLOCKS_PER_READ1, - BLOCKS_PER_READ2, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - _marker: PhantomData, -} - -impl< - F: PrimeField32, - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > - Rv32VecHeapTwoReadsAdapterChip< - F, - BLOCKS_PER_READ1, - BLOCKS_PER_READ2, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - > -{ - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - address_bits: usize, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - ) -> Self { - assert!( - RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS, - "address_bits={address_bits} needs to be large enough for high limb range check" - ); - Self { - air: Rv32VecHeapTwoReadsAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bus: bitwise_lookup_chip.bus(), - address_bits, - }, - bitwise_lookup_chip, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[serde_as] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct Rv32VecHeapTwoReadsReadRecord< - F: Field, - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const READ_SIZE: usize, -> { - /// Read register value from address space e=1 - pub rs1: RecordId, - pub rs2: RecordId, - /// Read register value from address space d=1 - pub rd: RecordId, - - pub rd_val: F, - - #[serde_as(as = "[_; BLOCKS_PER_READ1]")] - pub reads1: [RecordId; BLOCKS_PER_READ1], - #[serde_as(as = "[_; BLOCKS_PER_READ2]")] - pub reads2: [RecordId; BLOCKS_PER_READ2], -} - -#[repr(C)] -#[serde_as] -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Rv32VecHeapTwoReadsWriteRecord { - pub from_state: ExecutionState, - #[serde_as(as = "[_; BLOCKS_PER_WRITE]")] - pub writes: [RecordId; BLOCKS_PER_WRITE], -} - -#[repr(C)] -#[derive(AlignedBorrow)] -pub struct Rv32VecHeapTwoReadsAdapterCols< - T, - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, -> { - pub from_state: ExecutionState, - - pub rs1_ptr: T, - pub rs2_ptr: T, - pub rd_ptr: T, - - pub rs1_val: [T; RV32_REGISTER_NUM_LIMBS], - pub rs2_val: [T; RV32_REGISTER_NUM_LIMBS], - pub rd_val: [T; RV32_REGISTER_NUM_LIMBS], - - pub rs1_read_aux: MemoryReadAuxCols, - pub rs2_read_aux: MemoryReadAuxCols, - pub rd_read_aux: MemoryReadAuxCols, - - pub reads1_aux: [MemoryReadAuxCols; BLOCKS_PER_READ1], - pub reads2_aux: [MemoryReadAuxCols; BLOCKS_PER_READ2], - pub writes_aux: [MemoryWriteAuxCols; BLOCKS_PER_WRITE], -} - -#[allow(dead_code)] -#[derive(Clone, Copy, Debug, derive_new::new)] -pub struct Rv32VecHeapTwoReadsAdapterAir< - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, -> { - pub(super) execution_bridge: ExecutionBridge, - pub(super) memory_bridge: MemoryBridge, - pub bus: BitwiseOperationLookupBus, - /// The max number of bits for an address in memory - address_bits: usize, -} - -impl< - F: Field, - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > BaseAir - for Rv32VecHeapTwoReadsAdapterAir< - BLOCKS_PER_READ1, - BLOCKS_PER_READ2, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - > -{ - fn width(&self) -> usize { - Rv32VecHeapTwoReadsAdapterCols::< - F, - BLOCKS_PER_READ1, - BLOCKS_PER_READ2, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >::width() - } -} - -impl< - AB: InteractionBuilder, - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > VmAdapterAir - for Rv32VecHeapTwoReadsAdapterAir< - BLOCKS_PER_READ1, - BLOCKS_PER_READ2, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - > -{ - type Interface = VecHeapTwoReadsAdapterInterface< - AB::Expr, - BLOCKS_PER_READ1, - BLOCKS_PER_READ2, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >; - - fn eval( - &self, - builder: &mut AB, - local: &[AB::Var], - ctx: AdapterAirContext, - ) { - let cols: &Rv32VecHeapTwoReadsAdapterCols< - _, - BLOCKS_PER_READ1, - BLOCKS_PER_READ2, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - > = local.borrow(); - let timestamp = cols.from_state.timestamp; - let mut timestamp_delta: usize = 0; - let mut timestamp_pp = || { - timestamp_delta += 1; - timestamp + AB::F::from_canonical_usize(timestamp_delta - 1) - }; - - let ptrs = [cols.rs1_ptr, cols.rs2_ptr, cols.rd_ptr]; - let vals = [cols.rs1_val, cols.rs2_val, cols.rd_val]; - let auxs = [&cols.rs1_read_aux, &cols.rs2_read_aux, &cols.rd_read_aux]; - - // Read register values for rs1, rs2, rd - for (ptr, val, aux) in izip!(ptrs, vals, auxs) { - self.memory_bridge - .read( - MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), ptr), - val, - timestamp_pp(), - aux, - ) - .eval(builder, ctx.instruction.is_valid.clone()); - } - - // Range checks: see Rv32VecHeaperAdapterAir - let need_range_check = [&cols.rs1_val, &cols.rs2_val, &cols.rd_val, &cols.rd_val] - .map(|val| val[RV32_REGISTER_NUM_LIMBS - 1]); - - // range checks constrain to RV32_CELL_BITS bits, so we need to shift the limbs to constrain - // the correct amount of bits - let limb_shift = AB::F::from_canonical_usize( - 1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits), - ); - - // Note: since limbs are read from memory we already know that limb[i] < 2^RV32_CELL_BITS - // thus range checking limb[i] * shift < 2^RV32_CELL_BITS, gives us that - // limb[i] < 2^(addr_bits - (RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1))) - for pair in need_range_check.chunks_exact(2) { - self.bus - .send_range(pair[0] * limb_shift, pair[1] * limb_shift) - .eval(builder, ctx.instruction.is_valid.clone()); - } - - let rd_val_f: AB::Expr = abstract_compose(cols.rd_val); - let rs1_val_f: AB::Expr = abstract_compose(cols.rs1_val); - let rs2_val_f: AB::Expr = abstract_compose(cols.rs2_val); - - let e = AB::F::from_canonical_u32(RV32_MEMORY_AS); - // Reads from heap - for (i, (read, aux)) in zip(ctx.reads.0, &cols.reads1_aux).enumerate() { - self.memory_bridge - .read( - MemoryAddress::new( - e, - rs1_val_f.clone() + AB::Expr::from_canonical_usize(i * READ_SIZE), - ), - read, - timestamp_pp(), - aux, - ) - .eval(builder, ctx.instruction.is_valid.clone()); - } - for (i, (read, aux)) in zip(ctx.reads.1, &cols.reads2_aux).enumerate() { - self.memory_bridge - .read( - MemoryAddress::new( - e, - rs2_val_f.clone() + AB::Expr::from_canonical_usize(i * READ_SIZE), - ), - read, - timestamp_pp(), - aux, - ) - .eval(builder, ctx.instruction.is_valid.clone()); - } - - // Writes to heap - for (i, (write, aux)) in zip(ctx.writes, &cols.writes_aux).enumerate() { - self.memory_bridge - .write( - MemoryAddress::new( - e, - rd_val_f.clone() + AB::Expr::from_canonical_usize(i * WRITE_SIZE), - ), - write, - timestamp_pp(), - aux, - ) - .eval(builder, ctx.instruction.is_valid.clone()); - } - - self.execution_bridge - .execute_and_increment_or_set_pc( - ctx.instruction.opcode, - [ - cols.rd_ptr.into(), - cols.rs1_ptr.into(), - cols.rs2_ptr.into(), - AB::Expr::from_canonical_u32(RV32_REGISTER_AS), - e.into(), - ], - cols.from_state, - AB::F::from_canonical_usize(timestamp_delta), - (DEFAULT_PC_STEP, ctx.to_pc), - ) - .eval(builder, ctx.instruction.is_valid.clone()); - } - - fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var { - let cols: &Rv32VecHeapTwoReadsAdapterCols< - _, - BLOCKS_PER_READ1, - BLOCKS_PER_READ2, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - > = local.borrow(); - cols.from_state.pc - } -} - -impl< - F: PrimeField32, - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > VmAdapterChip - for Rv32VecHeapTwoReadsAdapterChip< - F, - BLOCKS_PER_READ1, - BLOCKS_PER_READ2, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - > -{ - type ReadRecord = - Rv32VecHeapTwoReadsReadRecord; - type WriteRecord = Rv32VecHeapTwoReadsWriteRecord; - type Air = Rv32VecHeapTwoReadsAdapterAir< - BLOCKS_PER_READ1, - BLOCKS_PER_READ2, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >; - type Interface = VecHeapTwoReadsAdapterInterface< - F, - BLOCKS_PER_READ1, - BLOCKS_PER_READ2, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >; - - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { a, b, c, d, e, .. } = *instruction; - - debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); - - let (rs1_record, rs1_val) = read_rv32_register(memory, d, b); - let (rs2_record, rs2_val) = read_rv32_register(memory, d, c); - let (rd_record, rd_val) = read_rv32_register(memory, d, a); - - assert!(rs1_val as usize + READ_SIZE * BLOCKS_PER_READ1 - 1 < (1 << self.air.address_bits)); - let read1_records = from_fn(|i| { - memory.read::(e, F::from_canonical_u32(rs1_val + (i * READ_SIZE) as u32)) - }); - let read1_data = read1_records.map(|r| r.1); - assert!(rs2_val as usize + READ_SIZE * BLOCKS_PER_READ2 - 1 < (1 << self.air.address_bits)); - let read2_records = from_fn(|i| { - memory.read::(e, F::from_canonical_u32(rs2_val + (i * READ_SIZE) as u32)) - }); - let read2_data = read2_records.map(|r| r.1); - assert!(rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1 < (1 << self.air.address_bits)); - - let record = Rv32VecHeapTwoReadsReadRecord { - rs1: rs1_record, - rs2: rs2_record, - rd: rd_record, - rd_val: F::from_canonical_u32(rd_val), - reads1: read1_records.map(|r| r.0), - reads2: read2_records.map(|r| r.0), - }; - - Ok(((read1_data, read2_data), record)) - } - - fn postprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let e = instruction.e; - let mut i = 0; - let writes = output.writes.map(|write| { - let (record_id, _) = memory.write( - e, - read_record.rd_val + F::from_canonical_u32((i * WRITE_SIZE) as u32), - write, - ); - i += 1; - record_id - }); - - Ok(( - ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, writes }, - )) - } - - fn generate_trace_row( - &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, - ) { - vec_heap_two_reads_generate_trace_row_impl( - row_slice, - &read_record, - &write_record, - self.bitwise_lookup_chip.clone(), - self.air.address_bits, - memory, - ) - } - - fn air(&self) -> &Self::Air { - &self.air - } -} - -pub(super) fn vec_heap_two_reads_generate_trace_row_impl< - F: PrimeField32, - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, ->( - row_slice: &mut [F], - read_record: &Rv32VecHeapTwoReadsReadRecord, - write_record: &Rv32VecHeapTwoReadsWriteRecord, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - address_bits: usize, - memory: &OfflineMemory, -) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut Rv32VecHeapTwoReadsAdapterCols< - F, - BLOCKS_PER_READ1, - BLOCKS_PER_READ2, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - > = row_slice.borrow_mut(); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - - let rd = memory.record_by_id(read_record.rd); - let rs1 = memory.record_by_id(read_record.rs1); - let rs2 = memory.record_by_id(read_record.rs2); - - row_slice.rd_ptr = rd.pointer; - row_slice.rs1_ptr = rs1.pointer; - row_slice.rs2_ptr = rs2.pointer; - - row_slice.rd_val.copy_from_slice(rd.data_slice()); - row_slice.rs1_val.copy_from_slice(rs1.data_slice()); - row_slice.rs2_val.copy_from_slice(rs2.data_slice()); - - aux_cols_factory.generate_read_aux(rs1, &mut row_slice.rs1_read_aux); - aux_cols_factory.generate_read_aux(rs2, &mut row_slice.rs2_read_aux); - aux_cols_factory.generate_read_aux(rd, &mut row_slice.rd_read_aux); - - for (i, r) in read_record.reads1.iter().enumerate() { - let record = memory.record_by_id(*r); - aux_cols_factory.generate_read_aux(record, &mut row_slice.reads1_aux[i]); - } - - for (i, r) in read_record.reads2.iter().enumerate() { - let record = memory.record_by_id(*r); - aux_cols_factory.generate_read_aux(record, &mut row_slice.reads2_aux[i]); - } - - for (i, w) in write_record.writes.iter().enumerate() { - let record = memory.record_by_id(*w); - aux_cols_factory.generate_write_aux(record, &mut row_slice.writes_aux[i]); - } - // Range checks: - let need_range_check = [ - &read_record.rs1, - &read_record.rs2, - &read_record.rd, - &read_record.rd, - ] - .map(|record| { - memory - .record_by_id(*record) - .data_at(RV32_REGISTER_NUM_LIMBS - 1) - .as_canonical_u32() - }); - debug_assert!(address_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); - let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits; - for pair in need_range_check.chunks_exact(2) { - bitwise_lookup_chip.request_range(pair[0] << limb_shift_bits, pair[1] << limb_shift_bits); - } -} diff --git a/extensions/rv32im/circuit/Cargo.toml b/extensions/rv32im/circuit/Cargo.toml index 8b20385104..9f6bbb6824 100644 --- a/extensions/rv32im/circuit/Cargo.toml +++ b/extensions/rv32im/circuit/Cargo.toml @@ -21,15 +21,16 @@ derive-new.workspace = true derive_more = { workspace = true, features = ["from"] } rand.workspace = true eyre.workspace = true + # for div_rem: num-bigint.workspace = true num-integer.workspace = true serde = { workspace = true, features = ["derive", "std"] } -serde-big-array.workspace = true [dev-dependencies] openvm-stark-sdk = { workspace = true } openvm-circuit = { workspace = true, features = ["test-utils"] } +test-case.workspace = true [features] default = ["parallel", "jemalloc"] diff --git a/extensions/rv32im/circuit/src/adapters/alu.rs b/extensions/rv32im/circuit/src/adapters/alu.rs index b61e2a224a..08cfa31b08 100644 --- a/extensions/rv32im/circuit/src/adapters/alu.rs +++ b/extensions/rv32im/circuit/src/adapters/alu.rs @@ -1,25 +1,23 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - marker::PhantomData, -}; +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller, + BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, + system::memory::{ + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteBytesAuxRecord, }, - program::ProgramBus, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ @@ -32,60 +30,10 @@ use openvm_stark_backend::{ p3_air::{AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; - -use super::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; - -/// Reads instructions of the form OP a, b, c, d, e where \[a:4\]_d = \[b:4\]_d op \[c:4\]_e. -/// Operand d can only be 1, and e can be either 1 (for register reads) or 0 (when c -/// is an immediate). -pub struct Rv32BaseAluAdapterChip { - pub air: Rv32BaseAluAdapterAir, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - _marker: PhantomData, -} - -impl Rv32BaseAluAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - ) -> Self { - Self { - air: Rv32BaseAluAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bitwise_lookup_bus: bitwise_lookup_chip.bus(), - }, - bitwise_lookup_chip, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct Rv32BaseAluReadRecord { - /// Read register value from address space d=1 - pub rs1: RecordId, - /// Either - /// - read rs2 register value or - /// - if `rs2_is_imm` is true, this is None - pub rs2: Option, - /// immediate value of rs2 or 0 - pub rs2_imm: F, -} -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct Rv32BaseAluWriteRecord { - pub from_state: ExecutionState, - /// Write to destination register - pub rd: (RecordId, [F; 4]), -} +use super::{ + tracing_read, tracing_read_imm, tracing_write, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, +}; #[repr(C)] #[derive(AlignedBorrow)] @@ -101,7 +49,9 @@ pub struct Rv32BaseAluAdapterCols { pub writes_aux: MemoryWriteAuxCols, } -#[allow(dead_code)] +/// Reads instructions of the form OP a, b, c, d, e where \[a:4\]_d = \[b:4\]_d op \[c:4\]_e. +/// Operand d can only be 1, and e can be either 1 (for register reads) or 0 (when c +/// is an immediate). #[derive(Clone, Copy, Debug, derive_new::new)] pub struct Rv32BaseAluAdapterAir { pub(super) execution_bridge: ExecutionBridge, @@ -213,129 +163,169 @@ impl VmAdapterAir for Rv32BaseAluAdapterAir { } } -impl VmAdapterChip for Rv32BaseAluAdapterChip { - type ReadRecord = Rv32BaseAluReadRecord; - type WriteRecord = Rv32BaseAluWriteRecord; - type Air = Rv32BaseAluAdapterAir; - type Interface = BasicAdapterInterface< - F, - MinimalInstruction, - 2, - 1, - RV32_REGISTER_NUM_LIMBS, - RV32_REGISTER_NUM_LIMBS, - >; +#[derive(Clone, derive_new::new)] +pub struct Rv32BaseAluAdapterExecutor; + +#[derive(derive_new::new)] +pub struct Rv32BaseAluAdapterFiller { + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, +} + +// Intermediate type that should not be copied or cloned and should be directly written to +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32BaseAluAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, + + pub rd_ptr: u32, + pub rs1_ptr: u32, + /// Pointer if rs2 was a read, immediate value otherwise + pub rs2: u32, + /// 1 if rs2 was a read, 0 if an immediate + pub rs2_as: u8, + + pub reads_aux: [MemoryReadAuxRecord; 2], + pub writes_aux: MemoryWriteBytesAuxRecord, +} - fn preprocess( - &mut self, - memory: &mut MemoryController, +impl AdapterTraceExecutor + for Rv32BaseAluAdapterExecutor +{ + const WIDTH: usize = size_of::>(); + type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2]; + type WriteData = [[u8; RV32_REGISTER_NUM_LIMBS]; 1]; + type RecordMut<'a> = &'a mut Rv32BaseAluAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut &mut Rv32BaseAluAdapterRecord) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; + } + + // @dev cannot get rid of double &mut due to trait + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { b, c, d, e, .. } = *instruction; + record: &mut &mut Rv32BaseAluAdapterRecord, + ) -> Self::ReadData { + let &Instruction { b, c, d, e, .. } = instruction; debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); debug_assert!( - e.as_canonical_u32() == RV32_IMM_AS || e.as_canonical_u32() == RV32_REGISTER_AS + e.as_canonical_u32() == RV32_REGISTER_AS || e.as_canonical_u32() == RV32_IMM_AS ); - let rs1 = memory.read::(d, b); - let (rs2, rs2_data, rs2_imm) = if e.is_zero() { - let c_u32 = c.as_canonical_u32(); - debug_assert_eq!(c_u32 >> 24, 0); - memory.increment_timestamp(); - ( - None, - [ - c_u32 as u8, - (c_u32 >> 8) as u8, - (c_u32 >> 16) as u8, - (c_u32 >> 16) as u8, - ] - .map(F::from_canonical_u8), - c, + record.rs1_ptr = b.as_canonical_u32(); + let rs1 = tracing_read( + memory, + RV32_REGISTER_AS, + record.rs1_ptr, + &mut record.reads_aux[0].prev_timestamp, + ); + + let rs2 = if e.as_canonical_u32() == RV32_REGISTER_AS { + record.rs2_as = RV32_REGISTER_AS as u8; + record.rs2 = c.as_canonical_u32(); + + tracing_read( + memory, + RV32_REGISTER_AS, + record.rs2, + &mut record.reads_aux[1].prev_timestamp, ) } else { - let rs2_read = memory.read::(e, c); - (Some(rs2_read.0), rs2_read.1, F::ZERO) + record.rs2_as = RV32_IMM_AS as u8; + + tracing_read_imm(memory, c.as_canonical_u32(), &mut record.rs2) }; - Ok(( - [rs1.1, rs2_data], - Self::ReadRecord { - rs1: rs1.0, - rs2, - rs2_imm, - }, - )) + [rs1, rs2] } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn write( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, d, .. } = instruction; - let rd = memory.write(*d, *a, output.writes[0]); - - let timestamp_delta = memory.timestamp() - from_state.timestamp; - debug_assert!( - timestamp_delta == 3, - "timestamp delta is {}, expected 3", - timestamp_delta - ); + data: Self::WriteData, + record: &mut &mut Rv32BaseAluAdapterRecord, + ) { + let &Instruction { a, d, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - Ok(( - ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, rd }, - )) + record.rd_ptr = a.as_canonical_u32(); + tracing_write( + memory, + RV32_REGISTER_AS, + record.rd_ptr, + data[0], + &mut record.writes_aux.prev_timestamp, + &mut record.writes_aux.prev_data, + ); } +} - fn generate_trace_row( - &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, - ) { - let row_slice: &mut Rv32BaseAluAdapterCols<_> = row_slice.borrow_mut(); - let aux_cols_factory = memory.aux_cols_factory(); - - let rd = memory.record_by_id(write_record.rd.0); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - row_slice.rd_ptr = rd.pointer; - - let rs1 = memory.record_by_id(read_record.rs1); - let rs2 = read_record.rs2.map(|rs2| memory.record_by_id(rs2)); - row_slice.rs1_ptr = rs1.pointer; - - if let Some(rs2) = rs2 { - row_slice.rs2 = rs2.pointer; - row_slice.rs2_as = rs2.address_space; - aux_cols_factory.generate_read_aux(rs1, &mut row_slice.reads_aux[0]); - aux_cols_factory.generate_read_aux(rs2, &mut row_slice.reads_aux[1]); +impl AdapterTraceFiller + for Rv32BaseAluAdapterFiller +{ + const WIDTH: usize = size_of::>(); + + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + // SAFETY: the following is highly unsafe. We are going to cast `adapter_row` to a record + // buffer, and then do an _overlapping_ write to the `adapter_row` as a row of field + // elements. This requires: + // - Cols struct should be repr(C) and we write in reverse order (to ensure non-overlapping) + // - Do not overwrite any reference in `record` before it has already been used or moved + // - alignment of `F` must be >= alignment of Record (AlignedBytesBorrow will panic + // otherwise) + let record: &Rv32BaseAluAdapterRecord = + unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_row: &mut Rv32BaseAluAdapterCols = adapter_row.borrow_mut(); + + // We must assign in reverse + const TIMESTAMP_DELTA: u32 = 2; + let mut timestamp = record.from_timestamp + TIMESTAMP_DELTA; + + adapter_row + .writes_aux + .set_prev_data(record.writes_aux.prev_data.map(F::from_canonical_u8)); + mem_helper.fill( + record.writes_aux.prev_timestamp, + timestamp, + adapter_row.writes_aux.as_mut(), + ); + timestamp -= 1; + + if record.rs2_as != 0 { + mem_helper.fill( + record.reads_aux[1].prev_timestamp, + timestamp, + adapter_row.reads_aux[1].as_mut(), + ); } else { - row_slice.rs2 = read_record.rs2_imm; - row_slice.rs2_as = F::ZERO; - let rs2_imm = row_slice.rs2.as_canonical_u32(); + mem_helper.fill_zero(adapter_row.reads_aux[1].as_mut()); + let rs2_imm = record.rs2; let mask = (1 << RV32_CELL_BITS) - 1; self.bitwise_lookup_chip .request_range(rs2_imm & mask, (rs2_imm >> 8) & mask); - aux_cols_factory.generate_read_aux(rs1, &mut row_slice.reads_aux[0]); - // row_slice.reads_aux[1] is disabled } - aux_cols_factory.generate_write_aux(rd, &mut row_slice.writes_aux); - } + timestamp -= 1; + + mem_helper.fill( + record.reads_aux[0].prev_timestamp, + timestamp, + adapter_row.reads_aux[0].as_mut(), + ); - fn air(&self) -> &Self::Air { - &self.air + adapter_row.rs2_as = F::from_canonical_u8(record.rs2_as); + adapter_row.rs2 = F::from_canonical_u32(record.rs2); + adapter_row.rs1_ptr = F::from_canonical_u32(record.rs1_ptr); + adapter_row.rd_ptr = F::from_canonical_u32(record.rd_ptr); + adapter_row.from_state.timestamp = F::from_canonical_u32(timestamp); + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/rv32im/circuit/src/adapters/branch.rs b/extensions/rv32im/circuit/src/adapters/branch.rs index 3e26f37f4c..3f891f0791 100644 --- a/extensions/rv32im/circuit/src/adapters/branch.rs +++ b/extensions/rv32im/circuit/src/adapters/branch.rs @@ -1,22 +1,17 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - marker::PhantomData, -}; +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, ImmInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller, + BasicAdapterInterface, ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord}, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, @@ -26,48 +21,9 @@ use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; use super::RV32_REGISTER_NUM_LIMBS; - -/// Reads instructions of the form OP a, b, c, d, e where if(\[a:4\]_d op \[b:4\]_e) pc += c. -/// Operands d and e can only be 1. -#[derive(Debug)] -pub struct Rv32BranchAdapterChip { - pub air: Rv32BranchAdapterAir, - _marker: PhantomData, -} - -impl Rv32BranchAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: Rv32BranchAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct Rv32BranchReadRecord { - /// Read register value from address space d = 1 - pub rs1: RecordId, - /// Read register value from address space e = 1 - pub rs2: RecordId, -} - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct Rv32BranchWriteRecord { - pub from_state: ExecutionState, -} +use crate::adapters::tracing_read; #[repr(C)] #[derive(AlignedBorrow)] @@ -149,80 +105,108 @@ impl VmAdapterAir for Rv32BranchAdapterAir { } } -impl VmAdapterChip for Rv32BranchAdapterChip { - type ReadRecord = Rv32BranchReadRecord; - type WriteRecord = Rv32BranchWriteRecord; - type Air = Rv32BranchAdapterAir; - type Interface = BasicAdapterInterface, 2, 0, RV32_REGISTER_NUM_LIMBS, 0>; +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32BranchAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, + pub rs1_ptr: u32, + pub rs2_ptr: u32, + pub reads_aux: [MemoryReadAuxRecord; 2], +} - fn preprocess( - &mut self, - memory: &mut MemoryController, +/// Reads instructions of the form OP a, b, c, d, e where if(\[a:4\]_d op \[b:4\]_e) pc += c. +/// Operands d and e can only be 1. +#[derive(Clone, Copy, derive_new::new)] +pub struct Rv32BranchAdapterExecutor; + +#[derive(derive_new::new)] +pub struct Rv32BranchAdapterFiller; + +impl AdapterTraceExecutor for Rv32BranchAdapterExecutor +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2]; + type WriteData = (); + type RecordMut<'a> = &'a mut Rv32BranchAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut &mut Rv32BranchAdapterRecord) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; + } + + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { a, b, d, e, .. } = *instruction; + record: &mut &mut Rv32BranchAdapterRecord, + ) -> Self::ReadData { + let &Instruction { a, b, d, e, .. } = instruction; debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); debug_assert_eq!(e.as_canonical_u32(), RV32_REGISTER_AS); - let rs1 = memory.read::(d, a); - let rs2 = memory.read::(e, b); - - Ok(( - [rs1.1, rs2.1], - Self::ReadRecord { - rs1: rs1.0, - rs2: rs2.0, - }, - )) - } - - fn postprocess( - &mut self, - memory: &mut MemoryController, - _instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let timestamp_delta = memory.timestamp() - from_state.timestamp; - debug_assert!( - timestamp_delta == 2, - "timestamp delta is {}, expected 2", - timestamp_delta + record.rs1_ptr = a.as_canonical_u32(); + let rs1 = tracing_read( + memory, + RV32_REGISTER_AS, + a.as_canonical_u32(), + &mut record.reads_aux[0].prev_timestamp, + ); + record.rs2_ptr = b.as_canonical_u32(); + let rs2 = tracing_read( + memory, + RV32_REGISTER_AS, + b.as_canonical_u32(), + &mut record.reads_aux[1].prev_timestamp, ); - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state }, - )) + [rs1, rs2] } - fn generate_trace_row( + #[inline(always)] + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + _memory: &mut TracingMemory, + _instruction: &Instruction, + _data: Self::WriteData, + _record: &mut Self::RecordMut<'_>, ) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut Rv32BranchAdapterCols<_> = row_slice.borrow_mut(); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - let rs1 = memory.record_by_id(read_record.rs1); - let rs2 = memory.record_by_id(read_record.rs2); - row_slice.rs1_ptr = rs1.pointer; - row_slice.rs2_ptr = rs2.pointer; - aux_cols_factory.generate_read_aux(rs1, &mut row_slice.reads_aux[0]); - aux_cols_factory.generate_read_aux(rs2, &mut row_slice.reads_aux[1]); + // This function is intentionally left empty } +} + +impl AdapterTraceFiller for Rv32BranchAdapterFiller { + const WIDTH: usize = size_of::>(); + + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &Rv32BranchAdapterRecord = + unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_row: &mut Rv32BranchAdapterCols = adapter_row.borrow_mut(); + + // We must assign in reverse + let timestamp = record.from_timestamp; + + mem_helper.fill( + record.reads_aux[1].prev_timestamp, + timestamp + 1, + adapter_row.reads_aux[1].as_mut(), + ); + + mem_helper.fill( + record.reads_aux[0].prev_timestamp, + timestamp, + adapter_row.reads_aux[0].as_mut(), + ); - fn air(&self) -> &Self::Air { - &self.air + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); + adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + adapter_row.rs1_ptr = F::from_canonical_u32(record.rs1_ptr); + adapter_row.rs2_ptr = F::from_canonical_u32(record.rs2_ptr); } } diff --git a/extensions/rv32im/circuit/src/adapters/jalr.rs b/extensions/rv32im/circuit/src/adapters/jalr.rs index f7dbf623b8..c1b3434e83 100644 --- a/extensions/rv32im/circuit/src/adapters/jalr.rs +++ b/extensions/rv32im/circuit/src/adapters/jalr.rs @@ -1,23 +1,20 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - marker::PhantomData, -}; +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, Result, SignedImmInstruction, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller, + BasicAdapterInterface, ExecutionBridge, ExecutionState, SignedImmInstruction, VmAdapterAir, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, + system::memory::{ + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteBytesAuxRecord, }, - program::ProgramBus, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, }; -use openvm_circuit_primitives::utils::not; +use openvm_circuit_primitives::{utils::not, AlignedBytesBorrow}; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, @@ -27,44 +24,9 @@ use openvm_stark_backend::{ p3_air::{AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; use super::RV32_REGISTER_NUM_LIMBS; - -// This adapter reads from [b:4]_d (rs1) and writes to [a:4]_d (rd) -#[derive(Debug)] -pub struct Rv32JalrAdapterChip { - pub air: Rv32JalrAdapterAir, - _marker: PhantomData, -} - -impl Rv32JalrAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: Rv32JalrAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Rv32JalrReadRecord { - pub rs1: RecordId, -} - -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Rv32JalrWriteRecord { - pub from_state: ExecutionState, - pub rd_id: Option, -} +use crate::adapters::{tracing_read, tracing_write}; #[repr(C)] #[derive(Debug, Clone, AlignedBorrow)] @@ -179,84 +141,126 @@ impl VmAdapterAir for Rv32JalrAdapterAir { } } -impl VmAdapterChip for Rv32JalrAdapterChip { - type ReadRecord = Rv32JalrReadRecord; - type WriteRecord = Rv32JalrWriteRecord; - type Air = Rv32JalrAdapterAir; - type Interface = BasicAdapterInterface< - F, - SignedImmInstruction, - 1, - 1, - RV32_REGISTER_NUM_LIMBS, - RV32_REGISTER_NUM_LIMBS, - >; - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { b, d, .. } = *instruction; - debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32JalrAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, + + pub rs1_ptr: u32, + // Will use u32::MAX to indicate no write + pub rd_ptr: u32, + + pub reads_aux: MemoryReadAuxRecord, + pub writes_aux: MemoryWriteBytesAuxRecord, +} - let rs1 = memory.read::(d, b); +// This adapter reads from [b:4]_d (rs1) and writes to [a:4]_d (rd) +#[derive(Clone, Copy, derive_new::new)] +pub struct Rv32JalrAdapterExecutor; + +#[derive(Clone, Copy, derive_new::new)] +pub struct Rv32JalrAdapterFiller; - Ok(([rs1.1], Rv32JalrReadRecord { rs1: rs1.0 })) +impl AdapterTraceExecutor for Rv32JalrAdapterExecutor +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = [u8; RV32_REGISTER_NUM_LIMBS]; + type WriteData = [u8; RV32_REGISTER_NUM_LIMBS]; + type RecordMut<'a> = &'a mut Rv32JalrAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { - a, d, f: enabled, .. - } = *instruction; - let rd_id = if enabled != F::ZERO { - let (record_id, _) = memory.write(d, a, output.writes[0]); - Some(record_id) - } else { - memory.increment_timestamp(); - None - }; + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + let &Instruction { b, d, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, rd_id }, - )) + record.rs1_ptr = b.as_canonical_u32(); + tracing_read( + memory, + RV32_REGISTER_AS, + b.as_canonical_u32(), + &mut record.reads_aux.prev_timestamp, + ) } - fn generate_trace_row( + #[inline(always)] + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, ) { - let aux_cols_factory = memory.aux_cols_factory(); - let adapter_cols: &mut Rv32JalrAdapterCols<_> = row_slice.borrow_mut(); - adapter_cols.from_state = write_record.from_state.map(F::from_canonical_u32); - let rs1 = memory.record_by_id(read_record.rs1); - adapter_cols.rs1_ptr = rs1.pointer; - aux_cols_factory.generate_read_aux(rs1, &mut adapter_cols.rs1_aux_cols); - if let Some(id) = write_record.rd_id { - let rd = memory.record_by_id(id); - adapter_cols.rd_ptr = rd.pointer; - adapter_cols.needs_write = F::ONE; - aux_cols_factory.generate_write_aux(rd, &mut adapter_cols.rd_aux_cols); + let &Instruction { + a, d, f: enabled, .. + } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + + if enabled.is_one() { + record.rd_ptr = a.as_canonical_u32(); + + tracing_write( + memory, + RV32_REGISTER_AS, + a.as_canonical_u32(), + data, + &mut record.writes_aux.prev_timestamp, + &mut record.writes_aux.prev_data, + ); + } else { + record.rd_ptr = u32::MAX; + memory.increment_timestamp(); } } +} + +impl AdapterTraceFiller for Rv32JalrAdapterFiller { + const WIDTH: usize = size_of::>(); + + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &Rv32JalrAdapterRecord = unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_row: &mut Rv32JalrAdapterCols = adapter_row.borrow_mut(); + + // We must assign in reverse + adapter_row.needs_write = F::from_bool(record.rd_ptr != u32::MAX); + + if record.rd_ptr != u32::MAX { + adapter_row + .rd_aux_cols + .set_prev_data(record.writes_aux.prev_data.map(F::from_canonical_u8)); + mem_helper.fill( + record.writes_aux.prev_timestamp, + record.from_timestamp + 1, + adapter_row.rd_aux_cols.as_mut(), + ); + adapter_row.rd_ptr = F::from_canonical_u32(record.rd_ptr); + } else { + adapter_row.rd_ptr = F::ZERO; + } - fn air(&self) -> &Self::Air { - &self.air + mem_helper.fill( + record.reads_aux.prev_timestamp, + record.from_timestamp, + adapter_row.rs1_aux_cols.as_mut(), + ); + adapter_row.rs1_ptr = F::from_canonical_u32(record.rs1_ptr); + adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/rv32im/circuit/src/adapters/loadstore.rs b/extensions/rv32im/circuit/src/adapters/loadstore.rs index b92680a0c7..8e151789b5 100644 --- a/extensions/rv32im/circuit/src/adapters/loadstore.rs +++ b/extensions/rv32im/circuit/src/adapters/loadstore.rs @@ -1,34 +1,36 @@ use std::{ - array, borrow::{Borrow, BorrowMut}, marker::PhantomData, }; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, ExecutionBridge, ExecutionBus, ExecutionState, - Result, VmAdapterAir, VmAdapterChip, VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller, + ExecutionBridge, ExecutionState, VmAdapterAir, VmAdapterInterface, }, system::{ memory::{ offline_checker::{ - MemoryBaseAuxCols, MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols, + MemoryBaseAuxCols, MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, + MemoryWriteAuxCols, }, - MemoryAddress, MemoryController, OfflineMemory, RecordId, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, - program::ProgramBus, + native_adapter::util::{memory_read_native, timed_write_native}, }, }; use openvm_circuit_primitives::{ utils::{not, select}, var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ instruction::Instruction, program::DEFAULT_PC_STEP, - riscv::{RV32_IMM_AS, RV32_REGISTER_AS}, - LocalOpcode, + riscv::{RV32_IMM_AS, RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, NATIVE_AS, }; use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *}; use openvm_stark_backend::{ @@ -36,10 +38,9 @@ use openvm_stark_backend::{ p3_air::{AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -use super::{compose, RV32_REGISTER_NUM_LIMBS}; -use crate::adapters::RV32_CELL_BITS; +use super::RV32_REGISTER_NUM_LIMBS; +use crate::adapters::{memory_read, timed_write, tracing_read, RV32_CELL_BITS}; /// LoadStore Adapter handles all memory and register operations, so it must be aware /// of the instruction type, specifically whether it is a load or store @@ -64,22 +65,6 @@ pub struct LoadStoreInstruction { pub store_shift_amount: T, } -/// The LoadStoreAdapter separates Runtime and Air AdapterInterfaces. -/// This is necessary because `prev_data` should be owned by the core chip and sent to the adapter, -/// and it must have an AB::Var type in AIR as to satisfy the memory_bridge interface. -/// This is achieved by having different types for reads and writes in Air AdapterInterface. -/// This method ensures that there are no modifications to the global interfaces. -/// -/// Here 2 reads represent read_data and prev_data, -/// The second element of the tuple in Reads is the shift amount needed to be passed to the core -/// chip Getting the intermediate pointer is completely internal to the adapter and shouldn't be a -/// part of the AdapterInterface -pub struct Rv32LoadStoreAdapterRuntimeInterface(PhantomData); -impl VmAdapterInterface for Rv32LoadStoreAdapterRuntimeInterface { - type Reads = ([[T; RV32_REGISTER_NUM_LIMBS]; 2], T); - type Writes = [[T; RV32_REGISTER_NUM_LIMBS]; 1]; - type ProcessedInstruction = (); -} pub struct Rv32LoadStoreAdapterAirInterface(PhantomData); /// Using AB::Var for prev_data and AB::Expr for read_data @@ -92,65 +77,6 @@ impl VmAdapterInterface for Rv32LoadStoreAdapt type ProcessedInstruction = LoadStoreInstruction; } -/// This chip reads rs1 and gets a intermediate memory pointer address with rs1 + imm. -/// In case of Loads, reads from the shifted intermediate pointer and writes to rd. -/// In case of Stores, reads from rs2 and writes to the shifted intermediate pointer. -pub struct Rv32LoadStoreAdapterChip { - pub air: Rv32LoadStoreAdapterAir, - pub range_checker_chip: SharedVariableRangeCheckerChip, - _marker: PhantomData, -} - -impl Rv32LoadStoreAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - pointer_max_bits: usize, - range_checker_chip: SharedVariableRangeCheckerChip, - ) -> Self { - assert!(range_checker_chip.range_max_bits() >= 15); - Self { - air: Rv32LoadStoreAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - range_bus: range_checker_chip.bus(), - pointer_max_bits, - }, - range_checker_chip, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct Rv32LoadStoreReadRecord { - pub rs1_record: RecordId, - /// This will be a read from a register in case of Stores and a read from RISC-V memory in case - /// of Loads. - pub read: RecordId, - pub rs1_ptr: F, - pub imm: F, - pub imm_sign: F, - pub mem_as: F, - pub mem_ptr_limbs: [u32; 2], - pub shift_amount: u32, -} - -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct Rv32LoadStoreWriteRecord { - /// This will be a write to a register in case of Load and a write to RISC-V memory in case of - /// Stores. For better struct packing, `RecordId(usize::MAX)` is used to indicate that - /// there is no write. - pub write_id: RecordId, - pub from_state: ExecutionState, - pub rd_rs2_ptr: F, -} - #[repr(C)] #[derive(Debug, Clone, AlignedBorrow)] pub struct Rv32LoadStoreAdapterCols { @@ -366,22 +292,69 @@ impl VmAdapterAir for Rv32LoadStoreAdapterAir { } } -impl VmAdapterChip for Rv32LoadStoreAdapterChip { - type ReadRecord = Rv32LoadStoreReadRecord; - type WriteRecord = Rv32LoadStoreWriteRecord; - type Air = Rv32LoadStoreAdapterAir; - type Interface = Rv32LoadStoreAdapterRuntimeInterface; +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32LoadStoreAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, - #[allow(clippy::type_complexity)] - fn preprocess( - &mut self, - memory: &mut MemoryController, + pub rs1_ptr: u32, + pub rs1_val: u32, + pub rs1_aux_record: MemoryReadAuxRecord, + + pub rd_rs2_ptr: u32, + pub read_data_aux: MemoryReadAuxRecord, + pub imm: u16, + pub imm_sign: bool, + + pub mem_as: u8, + + pub write_prev_timestamp: u32, +} + +/// This chip reads rs1 and gets a intermediate memory pointer address with rs1 + imm. +/// In case of Loads, reads from the shifted intermediate pointer and writes to rd. +/// In case of Stores, reads from rs2 and writes to the shifted intermediate pointer. +#[derive(Clone, Copy, derive_new::new)] +pub struct Rv32LoadStoreAdapterExecutor { + pointer_max_bits: usize, +} + +#[derive(derive_new::new)] +pub struct Rv32LoadStoreAdapterFiller { + pointer_max_bits: usize, + pub range_checker_chip: SharedVariableRangeCheckerChip, +} + +impl AdapterTraceExecutor for Rv32LoadStoreAdapterExecutor +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = ( + ( + [u32; RV32_REGISTER_NUM_LIMBS], + [u8; RV32_REGISTER_NUM_LIMBS], + ), + u8, + ); + type WriteData = [u32; RV32_REGISTER_NUM_LIMBS]; + type RecordMut<'a> = &'a mut Rv32LoadStoreAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; + } + + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + let &Instruction { opcode, a, b, @@ -390,154 +363,193 @@ impl VmAdapterChip for Rv32LoadStoreAdapterChip { e, g, .. - } = *instruction; + } = instruction; + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - debug_assert!(e.as_canonical_u32() != RV32_IMM_AS); let local_opcode = Rv32LoadStoreOpcode::from_usize( opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET), ); - let rs1_record = memory.read::(d, b); - let rs1_val = compose(rs1_record.1); - let imm = c.as_canonical_u32(); - let imm_sign = g.as_canonical_u32(); - let imm_extended = imm + imm_sign * 0xffff0000; + record.rs1_ptr = b.as_canonical_u32(); + record.rs1_val = u32::from_le_bytes(tracing_read( + memory, + RV32_REGISTER_AS, + record.rs1_ptr, + &mut record.rs1_aux_record.prev_timestamp, + )); + + record.imm = c.as_canonical_u32() as u16; + record.imm_sign = g.is_one(); + let imm_extended = record.imm as u32 + record.imm_sign as u32 * 0xffff0000; + + let ptr_val = record.rs1_val.wrapping_add(imm_extended); + let shift_amount = ptr_val & 3; + let ptr_val = ptr_val - shift_amount; - let ptr_val = rs1_val.wrapping_add(imm_extended); - let shift_amount = ptr_val % 4; assert!( - ptr_val < (1 << self.air.pointer_max_bits), - "ptr_val: {ptr_val} = rs1_val: {rs1_val} + imm_extended: {imm_extended} >= 2 ** {}", - self.air.pointer_max_bits + ptr_val < (1 << self.pointer_max_bits), + "ptr_val: {ptr_val} = rs1_val: {} + imm_extended: {imm_extended} >= 2 ** {}", + record.rs1_val, + self.pointer_max_bits ); - let mem_ptr_limbs = array::from_fn(|i| ((ptr_val >> (i * (RV32_CELL_BITS * 2))) & 0xffff)); - - let ptr_val = ptr_val - shift_amount; - let read_record = match local_opcode { + // prev_data: We need to keep values of some cells to keep them unchanged when writing to + // those cells + let (read_data, prev_data) = match local_opcode { LOADW | LOADB | LOADH | LOADBU | LOADHU => { - memory.read::(e, F::from_canonical_u32(ptr_val)) + debug_assert_eq!(e, F::from_canonical_u32(RV32_MEMORY_AS)); + record.mem_as = RV32_MEMORY_AS as u8; + let read_data = tracing_read( + memory, + RV32_MEMORY_AS, + ptr_val, + &mut record.read_data_aux.prev_timestamp, + ); + let prev_data = memory_read(memory.data(), RV32_REGISTER_AS, a.as_canonical_u32()) + .map(u32::from); + (read_data, prev_data) } - STOREW | STOREH | STOREB => memory.read::(d, a), - }; - - // We need to keep values of some cells to keep them unchanged when writing to those cells - let prev_data = match local_opcode { - STOREW | STOREH | STOREB => array::from_fn(|i| { - memory.unsafe_read_cell(e, F::from_canonical_usize(ptr_val as usize + i)) - }), - LOADW | LOADB | LOADH | LOADBU | LOADHU => { - array::from_fn(|i| memory.unsafe_read_cell(d, a + F::from_canonical_usize(i))) + STOREW | STOREH | STOREB => { + let e = e.as_canonical_u32(); + debug_assert_ne!(e, RV32_IMM_AS); + debug_assert_ne!(e, RV32_REGISTER_AS); + record.mem_as = e as u8; + let read_data = tracing_read( + memory, + RV32_REGISTER_AS, + a.as_canonical_u32(), + &mut record.read_data_aux.prev_timestamp, + ); + let prev_data = if e == NATIVE_AS { + memory_read_native(memory.data(), ptr_val).map(|x: F| x.as_canonical_u32()) + } else { + memory_read(memory.data(), e, ptr_val).map(u32::from) + }; + (read_data, prev_data) } }; - Ok(( - ( - [prev_data, read_record.1], - F::from_canonical_u32(shift_amount), - ), - Self::ReadRecord { - rs1_record: rs1_record.0, - rs1_ptr: b, - read: read_record.0, - imm: c, - imm_sign: g, - shift_amount, - mem_ptr_limbs, - mem_as: e, - }, - )) + ((prev_data, read_data), shift_amount as u8) } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn write( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, + ) { + let &Instruction { opcode, a, d, e, f: enabled, .. - } = *instruction; + } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + debug_assert_ne!(e.as_canonical_u32(), RV32_IMM_AS); + debug_assert_ne!(e.as_canonical_u32(), RV32_REGISTER_AS); let local_opcode = Rv32LoadStoreOpcode::from_usize( opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET), ); - let write_id = if enabled != F::ZERO { - let (record_id, _) = match local_opcode { + if enabled != F::ZERO { + record.rd_rs2_ptr = a.as_canonical_u32(); + + record.write_prev_timestamp = match local_opcode { STOREW | STOREH | STOREB => { - let ptr = read_record.mem_ptr_limbs[0] - + read_record.mem_ptr_limbs[1] * (1 << (RV32_CELL_BITS * 2)); - memory.write(e, F::from_canonical_u32(ptr & 0xfffffffc), output.writes[0]) + let imm_extended = record.imm as u32 + record.imm_sign as u32 * 0xffff0000; + let ptr = record.rs1_val.wrapping_add(imm_extended) & !3; + + if record.mem_as == 4 { + timed_write_native(memory, ptr, data.map(F::from_canonical_u32)).0 + } else { + timed_write(memory, record.mem_as as u32, ptr, data.map(|x| x as u8)).0 + } + } + LOADW | LOADB | LOADH | LOADBU | LOADHU => { + timed_write( + memory, + RV32_REGISTER_AS, + record.rd_rs2_ptr, + data.map(|x| x as u8), + ) + .0 } - LOADW | LOADB | LOADH | LOADBU | LOADHU => memory.write(d, a, output.writes[0]), }; - record_id } else { + record.rd_rs2_ptr = u32::MAX; memory.increment_timestamp(); - // RecordId will never get to usize::MAX, so it can be used as a flag for no write - RecordId(usize::MAX) }; - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { - from_state, - write_id, - rd_rs2_ptr: a, - }, - )) } +} - fn generate_trace_row( - &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, - ) { - self.range_checker_chip.add_count( - (read_record.mem_ptr_limbs[0] - read_record.shift_amount) / 4, - RV32_CELL_BITS * 2 - 2, +impl AdapterTraceFiller for Rv32LoadStoreAdapterFiller { + const WIDTH: usize = size_of::>(); + + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + debug_assert!(self.range_checker_chip.range_max_bits() >= 15); + + let record: &Rv32LoadStoreAdapterRecord = + unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_row: &mut Rv32LoadStoreAdapterCols = adapter_row.borrow_mut(); + + let needs_write = record.rd_rs2_ptr != u32::MAX; + // Writing in reverse order + adapter_row.needs_write = F::from_bool(needs_write); + + if needs_write { + mem_helper.fill( + record.write_prev_timestamp, + record.from_timestamp + 2, + &mut adapter_row.write_base_aux, + ); + } else { + mem_helper.fill_zero(&mut adapter_row.write_base_aux); + } + + adapter_row.mem_as = F::from_canonical_u8(record.mem_as); + let ptr = record + .rs1_val + .wrapping_add(record.imm as u32 + record.imm_sign as u32 * 0xffff0000); + + let ptr_limbs = [ptr & 0xffff, ptr >> 16]; + self.range_checker_chip + .add_count(ptr_limbs[0] >> 2, RV32_CELL_BITS * 2 - 2); + self.range_checker_chip + .add_count(ptr_limbs[1], self.pointer_max_bits - 16); + adapter_row.mem_ptr_limbs = ptr_limbs.map(F::from_canonical_u32); + + adapter_row.imm_sign = F::from_bool(record.imm_sign); + adapter_row.imm = F::from_canonical_u16(record.imm); + + mem_helper.fill( + record.read_data_aux.prev_timestamp, + record.from_timestamp + 1, + adapter_row.read_data_aux.as_mut(), ); - self.range_checker_chip.add_count( - read_record.mem_ptr_limbs[1], - self.air.pointer_max_bits - RV32_CELL_BITS * 2, + adapter_row.rd_rs2_ptr = if record.rd_rs2_ptr != u32::MAX { + F::from_canonical_u32(record.rd_rs2_ptr) + } else { + F::ZERO + }; + + mem_helper.fill( + record.rs1_aux_record.prev_timestamp, + record.from_timestamp, + adapter_row.rs1_aux_cols.as_mut(), ); - let aux_cols_factory = memory.aux_cols_factory(); - let adapter_cols: &mut Rv32LoadStoreAdapterCols<_> = row_slice.borrow_mut(); - adapter_cols.from_state = write_record.from_state.map(F::from_canonical_u32); - let rs1 = memory.record_by_id(read_record.rs1_record); - adapter_cols.rs1_data.copy_from_slice(rs1.data_slice()); - aux_cols_factory.generate_read_aux(rs1, &mut adapter_cols.rs1_aux_cols); - adapter_cols.rs1_ptr = read_record.rs1_ptr; - adapter_cols.rd_rs2_ptr = write_record.rd_rs2_ptr; - let read = memory.record_by_id(read_record.read); - aux_cols_factory.generate_read_aux(read, &mut adapter_cols.read_data_aux); - adapter_cols.imm = read_record.imm; - adapter_cols.imm_sign = read_record.imm_sign; - adapter_cols.mem_ptr_limbs = read_record.mem_ptr_limbs.map(F::from_canonical_u32); - adapter_cols.mem_as = read_record.mem_as; - if write_record.write_id.0 != usize::MAX { - let write = memory.record_by_id(write_record.write_id); - aux_cols_factory.generate_base_aux(write, &mut adapter_cols.write_base_aux); - adapter_cols.needs_write = F::ONE; - } - } + adapter_row.rs1_data = record.rs1_val.to_le_bytes().map(F::from_canonical_u8); + adapter_row.rs1_ptr = F::from_canonical_u32(record.rs1_ptr); - fn air(&self) -> &Self::Air { - &self.air + adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/rv32im/circuit/src/adapters/mod.rs b/extensions/rv32im/circuit/src/adapters/mod.rs index ab15671b74..c95df6ac43 100644 --- a/extensions/rv32im/circuit/src/adapters/mod.rs +++ b/extensions/rv32im/circuit/src/adapters/mod.rs @@ -1,6 +1,13 @@ use std::ops::Mul; -use openvm_circuit::system::memory::{MemoryController, RecordId}; +use openvm_circuit::{ + arch::{execution_mode::E1ExecutionCtx, VmStateMut}, + system::memory::{ + merkle::public_values::PUBLIC_VALUES_AS, + online::{GuestMemory, TracingMemory}, + }, +}; +use openvm_instructions::riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}; use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32}; mod alu; @@ -46,25 +53,177 @@ pub fn decompose(value: u32) -> [F; RV32_REGISTER_NUM_LIMBS] { }) } -/// Read register value as [RV32_REGISTER_NUM_LIMBS] limbs from memory. -/// Returns the read record and the register value as u32. -/// Does not make any range check calls. -pub fn read_rv32_register( - memory: &mut MemoryController, - address_space: F, - pointer: F, -) -> (RecordId, u32) { - debug_assert_eq!(address_space, F::ONE); - let record = memory.read::(address_space, pointer); - let val = compose(record.1); - (record.0, val) +#[inline(always)] +pub fn imm_to_bytes(imm: u32) -> [u8; RV32_REGISTER_NUM_LIMBS] { + debug_assert_eq!(imm >> 24, 0); + let mut imm_le = imm.to_le_bytes(); + imm_le[3] = imm_le[2]; + imm_le } -/// Peeks at the value of a register without updating the memory state or incrementing the -/// timestamp. -pub fn unsafe_read_rv32_register(memory: &MemoryController, pointer: F) -> u32 { - let data = memory.unsafe_read::(F::ONE, pointer); - compose(data) +#[inline(always)] +pub fn memory_read(memory: &GuestMemory, address_space: u32, ptr: u32) -> [u8; N] { + debug_assert!( + address_space == RV32_REGISTER_AS + || address_space == RV32_MEMORY_AS + || address_space == PUBLIC_VALUES_AS, + ); + + // SAFETY: + // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and + // minimum alignment of `RV32_REGISTER_NUM_LIMBS` + unsafe { memory.read::(address_space, ptr) } +} + +#[inline(always)] +pub fn memory_write( + memory: &mut GuestMemory, + address_space: u32, + ptr: u32, + data: [u8; N], +) { + debug_assert!( + address_space == RV32_REGISTER_AS + || address_space == RV32_MEMORY_AS + || address_space == PUBLIC_VALUES_AS + ); + + // SAFETY: + // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and + // minimum alignment of `RV32_REGISTER_NUM_LIMBS` + unsafe { memory.write::(address_space, ptr, data) } +} + +/// Atomic read operation which increments the timestamp by 1. +/// Returns `(t_prev, [ptr:4]_{address_space})` where `t_prev` is the timestamp of the last memory +/// access. +#[inline(always)] +pub fn timed_read( + memory: &mut TracingMemory, + address_space: u32, + ptr: u32, +) -> (u32, [u8; N]) { + debug_assert!( + address_space == RV32_REGISTER_AS + || address_space == RV32_MEMORY_AS + || address_space == PUBLIC_VALUES_AS + ); + + // SAFETY: + // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and + // minimum alignment of `RV32_REGISTER_NUM_LIMBS` + unsafe { memory.read::(address_space, ptr) } +} + +#[inline(always)] +pub fn timed_write( + memory: &mut TracingMemory, + address_space: u32, + ptr: u32, + data: [u8; N], +) -> (u32, [u8; N]) { + debug_assert!( + address_space == RV32_REGISTER_AS + || address_space == RV32_MEMORY_AS + || address_space == PUBLIC_VALUES_AS + ); + + // SAFETY: + // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and + // minimum alignment of `RV32_REGISTER_NUM_LIMBS` + unsafe { memory.write::(address_space, ptr, data) } +} + +/// Reads register value at `reg_ptr` from memory and records the memory access in mutable buffer. +/// Trace generation relevant to this memory access can be done fully from the recorded buffer. +#[inline(always)] +pub fn tracing_read( + memory: &mut TracingMemory, + address_space: u32, + ptr: u32, + prev_timestamp: &mut u32, +) -> [u8; N] { + let (t_prev, data) = timed_read(memory, address_space, ptr); + *prev_timestamp = t_prev; + data +} + +#[inline(always)] +pub fn tracing_read_imm( + memory: &mut TracingMemory, + imm: u32, + imm_mut: &mut u32, +) -> [u8; RV32_REGISTER_NUM_LIMBS] { + *imm_mut = imm; + debug_assert_eq!(imm >> 24, 0); // highest byte should be zero to prevent overflow + + memory.increment_timestamp(); + + let mut imm_le = imm.to_le_bytes(); + // Important: we set the highest byte equal to the second highest byte, using the assumption + // that imm is at most 24 bits + imm_le[3] = imm_le[2]; + imm_le +} + +/// Writes `reg_ptr, reg_val` into memory and records the memory access in mutable buffer. +/// Trace generation relevant to this memory access can be done fully from the recorded buffer. +#[inline(always)] +pub fn tracing_write( + memory: &mut TracingMemory, + address_space: u32, + ptr: u32, + data: [u8; N], + prev_timestamp: &mut u32, + prev_data: &mut [u8; N], +) { + let (t_prev, data_prev) = timed_write(memory, address_space, ptr, data); + *prev_timestamp = t_prev; + *prev_data = data_prev; +} + +#[inline(always)] +pub fn memory_read_from_state( + state: &mut VmStateMut, + address_space: u32, + ptr: u32, +) -> [u8; N] +where + Ctx: E1ExecutionCtx, +{ + state.ctx.on_memory_operation(address_space, ptr, N as u32); + + memory_read(state.memory, address_space, ptr) +} + +#[inline(always)] +pub fn memory_write_from_state( + state: &mut VmStateMut, + address_space: u32, + ptr: u32, + data: [u8; N], +) where + Ctx: E1ExecutionCtx, +{ + state.ctx.on_memory_operation(address_space, ptr, N as u32); + + memory_write(state.memory, address_space, ptr, data) +} + +#[inline(always)] +pub fn read_rv32_register_from_state( + state: &mut VmStateMut, + ptr: u32, +) -> u32 +where + Ctx: E1ExecutionCtx, +{ + u32::from_le_bytes(memory_read_from_state(state, RV32_REGISTER_AS, ptr)) +} + +#[inline(always)] +pub fn read_rv32_register(memory: &GuestMemory, ptr: u32) -> u32 { + u32::from_le_bytes(memory_read(memory, RV32_REGISTER_AS, ptr)) } pub fn abstract_compose>( @@ -76,3 +235,8 @@ pub fn abstract_compose>( acc + limb * T::from_canonical_u32(1 << (i * RV32_CELL_BITS)) }) } + +// TEMP[jpw] +pub fn tmp_convert_to_u8s(data: [F; N]) -> [u8; N] { + data.map(|x| x.as_canonical_u32() as u8) +} diff --git a/extensions/rv32im/circuit/src/adapters/mul.rs b/extensions/rv32im/circuit/src/adapters/mul.rs index a82e83acaa..f0a8281c22 100644 --- a/extensions/rv32im/circuit/src/adapters/mul.rs +++ b/extensions/rv32im/circuit/src/adapters/mul.rs @@ -1,22 +1,20 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - marker::PhantomData, -}; +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller, + BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, + system::memory::{ + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteBytesAuxRecord, }, - program::ProgramBus, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, @@ -26,49 +24,9 @@ use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; - -use super::RV32_REGISTER_NUM_LIMBS; - -/// Reads instructions of the form OP a, b, c, d where \[a:4\]_d = \[b:4\]_d op \[c:4\]_d. -/// Operand d can only be 1, and there is no immediate support. -#[derive(Debug)] -pub struct Rv32MultAdapterChip { - pub air: Rv32MultAdapterAir, - _marker: PhantomData, -} - -impl Rv32MultAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: Rv32MultAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct Rv32MultReadRecord { - /// Reads from operand registers - pub rs1: RecordId, - pub rs2: RecordId, -} -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct Rv32MultWriteRecord { - pub from_state: ExecutionState, - /// Write to destination register - pub rd_id: RecordId, -} +use super::{tracing_write, RV32_REGISTER_NUM_LIMBS}; +use crate::adapters::tracing_read; #[repr(C)] #[derive(AlignedBorrow)] @@ -81,6 +39,8 @@ pub struct Rv32MultAdapterCols { pub writes_aux: MemoryWriteAuxCols, } +/// Reads instructions of the form OP a, b, c, d where \[a:4\]_d = \[b:4\]_d op \[c:4\]_d. +/// Operand d can only be 1, and there is no immediate support. #[derive(Clone, Copy, Debug, derive_new::new)] pub struct Rv32MultAdapterAir { pub(super) execution_bridge: ExecutionBridge, @@ -167,92 +127,130 @@ impl VmAdapterAir for Rv32MultAdapterAir { } } -impl VmAdapterChip for Rv32MultAdapterChip { - type ReadRecord = Rv32MultReadRecord; - type WriteRecord = Rv32MultWriteRecord; - type Air = Rv32MultAdapterAir; - type Interface = BasicAdapterInterface< - F, - MinimalInstruction, - 2, - 1, - RV32_REGISTER_NUM_LIMBS, - RV32_REGISTER_NUM_LIMBS, - >; +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32MultAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { b, c, d, .. } = *instruction; + pub rd_ptr: u32, + pub rs1_ptr: u32, + pub rs2_ptr: u32, - debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + pub reads_aux: [MemoryReadAuxRecord; 2], + pub writes_aux: MemoryWriteBytesAuxRecord, +} + +#[derive(Clone, Copy, derive_new::new)] +pub struct Rv32MultAdapterExecutor; - let rs1 = memory.read::(d, b); - let rs2 = memory.read::(d, c); +#[derive(Clone, Copy, derive_new::new)] +pub struct Rv32MultAdapterFiller; - Ok(( - [rs1.1, rs2.1], - Self::ReadRecord { - rs1: rs1.0, - rs2: rs2.0, - }, - )) +impl AdapterTraceExecutor for Rv32MultAdapterExecutor +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2]; + type WriteData = [[u8; RV32_REGISTER_NUM_LIMBS]; 1]; + type RecordMut<'a> = &'a mut Rv32MultAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, d, .. } = *instruction; - let (rd_id, _) = memory.write(d, a, output.writes[0]); - - let timestamp_delta = memory.timestamp() - from_state.timestamp; - debug_assert!( - timestamp_delta == 3, - "timestamp delta is {}, expected 3", - timestamp_delta + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + let &Instruction { b, c, d, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + + record.rs1_ptr = b.as_canonical_u32(); + let rs1 = tracing_read( + memory, + RV32_REGISTER_AS, + b.as_canonical_u32(), + &mut record.reads_aux[0].prev_timestamp, + ); + record.rs2_ptr = c.as_canonical_u32(); + let rs2 = tracing_read( + memory, + RV32_REGISTER_AS, + c.as_canonical_u32(), + &mut record.reads_aux[1].prev_timestamp, ); - Ok(( - ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, rd_id }, - )) + [rs1, rs2] } - fn generate_trace_row( + #[inline(always)] + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, ) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut Rv32MultAdapterCols<_> = row_slice.borrow_mut(); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - let rd = memory.record_by_id(write_record.rd_id); - row_slice.rd_ptr = rd.pointer; - let rs1 = memory.record_by_id(read_record.rs1); - let rs2 = memory.record_by_id(read_record.rs2); - row_slice.rs1_ptr = rs1.pointer; - row_slice.rs2_ptr = rs2.pointer; - aux_cols_factory.generate_read_aux(rs1, &mut row_slice.reads_aux[0]); - aux_cols_factory.generate_read_aux(rs2, &mut row_slice.reads_aux[1]); - aux_cols_factory.generate_write_aux(rd, &mut row_slice.writes_aux); + let &Instruction { a, d, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + + record.rd_ptr = a.as_canonical_u32(); + tracing_write( + memory, + RV32_REGISTER_AS, + a.as_canonical_u32(), + data[0], + &mut record.writes_aux.prev_timestamp, + &mut record.writes_aux.prev_data, + ) } +} + +impl AdapterTraceFiller for Rv32MultAdapterFiller { + const WIDTH: usize = size_of::>(); + + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &Rv32MultAdapterRecord = unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_row: &mut Rv32MultAdapterCols = adapter_row.borrow_mut(); + + let timestamp = record.from_timestamp; + + adapter_row + .writes_aux + .set_prev_data(record.writes_aux.prev_data.map(F::from_canonical_u8)); + mem_helper.fill( + record.writes_aux.prev_timestamp, + timestamp + 2, + adapter_row.writes_aux.as_mut(), + ); + + mem_helper.fill( + record.reads_aux[1].prev_timestamp, + timestamp + 1, + adapter_row.reads_aux[1].as_mut(), + ); + + mem_helper.fill( + record.reads_aux[0].prev_timestamp, + timestamp, + adapter_row.reads_aux[0].as_mut(), + ); + + adapter_row.rs2_ptr = F::from_canonical_u32(record.rs2_ptr); + adapter_row.rs1_ptr = F::from_canonical_u32(record.rs1_ptr); + adapter_row.rd_ptr = F::from_canonical_u32(record.rd_ptr); - fn air(&self) -> &Self::Air { - &self.air + adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/rv32im/circuit/src/adapters/rdwrite.rs b/extensions/rv32im/circuit/src/adapters/rdwrite.rs index abd4d8eb17..02d669a2d9 100644 --- a/extensions/rv32im/circuit/src/adapters/rdwrite.rs +++ b/extensions/rv32im/circuit/src/adapters/rdwrite.rs @@ -1,23 +1,17 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - marker::PhantomData, -}; +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, ImmInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller, + BasicAdapterInterface, ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryWriteAuxCols, MemoryWriteBytesAuxRecord}, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, }; -use openvm_circuit_primitives::utils::not; +use openvm_circuit_primitives::{utils::not, AlignedBytesBorrow}; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, @@ -27,59 +21,9 @@ use openvm_stark_backend::{ p3_air::{AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; use super::RV32_REGISTER_NUM_LIMBS; - -/// This adapter doesn't read anything, and writes to \[a:4\]_d, where d == 1 -#[derive(Debug)] -pub struct Rv32RdWriteAdapterChip { - pub air: Rv32RdWriteAdapterAir, - _marker: PhantomData, -} - -/// This adapter doesn't read anything, and **maybe** writes to \[a:4\]_d, where d == 1 -#[derive(Debug)] -pub struct Rv32CondRdWriteAdapterChip { - /// Do not use the inner air directly, use `air` instead. - inner: Rv32RdWriteAdapterChip, - pub air: Rv32CondRdWriteAdapterAir, -} - -impl Rv32RdWriteAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: Rv32RdWriteAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} - -impl Rv32CondRdWriteAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - let inner = Rv32RdWriteAdapterChip::new(execution_bus, program_bus, memory_bridge); - let air = Rv32CondRdWriteAdapterAir { inner: inner.air }; - Self { inner, air } - } -} - -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Rv32RdWriteWriteRecord { - pub from_state: ExecutionState, - pub rd_id: Option, -} +use crate::adapters::tracing_write; #[repr(C)] #[derive(Debug, Clone, AlignedBorrow)] @@ -92,16 +36,18 @@ pub struct Rv32RdWriteAdapterCols { #[repr(C)] #[derive(Debug, Clone, AlignedBorrow)] pub struct Rv32CondRdWriteAdapterCols { - inner: Rv32RdWriteAdapterCols, + pub inner: Rv32RdWriteAdapterCols, pub needs_write: T, } +/// This adapter doesn't read anything, and writes to \[a:4\]_d, where d == 1 #[derive(Clone, Copy, Debug, derive_new::new)] pub struct Rv32RdWriteAdapterAir { pub(super) memory_bridge: MemoryBridge, pub(super) execution_bridge: ExecutionBridge, } +/// This adapter doesn't read anything, and **maybe** writes to \[a:4\]_d, where d == 1 #[derive(Clone, Copy, Debug, derive_new::new)] pub struct Rv32CondRdWriteAdapterAir { inner: Rv32RdWriteAdapterAir, @@ -241,131 +187,187 @@ impl VmAdapterAir for Rv32CondRdWriteAdapterAir { } } -impl VmAdapterChip for Rv32RdWriteAdapterChip { - type ReadRecord = (); - type WriteRecord = Rv32RdWriteWriteRecord; - type Air = Rv32RdWriteAdapterAir; - type Interface = BasicAdapterInterface, 0, 1, 0, RV32_REGISTER_NUM_LIMBS>; - - fn preprocess( - &mut self, - _memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let d = instruction.d; - debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); +/// This adapter doesn't read anything, and writes to \[a:4\]_d, where d == 1 +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug, Clone)] +pub struct Rv32RdWriteAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, + + // Will use u32::MAX to indicate no write + pub rd_ptr: u32, + pub rd_aux_record: MemoryWriteBytesAuxRecord, +} - Ok(([], ())) +#[derive(Clone, Copy, derive_new::new)] +pub struct Rv32RdWriteAdapterExecutor; + +#[derive(Clone, Copy, derive_new::new)] +pub struct Rv32RdWriteAdapterFiller; + +impl AdapterTraceExecutor for Rv32RdWriteAdapterExecutor +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = (); + type WriteData = [u8; RV32_REGISTER_NUM_LIMBS]; + type RecordMut<'a> = &'a mut Rv32RdWriteAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; } - fn postprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, d, .. } = *instruction; - let (rd_id, _) = memory.write(d, a, output.writes[0]); - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { - from_state, - rd_id: Some(rd_id), - }, - )) + #[inline(always)] + fn read( + &self, + _memory: &mut TracingMemory, + _instruction: &Instruction, + _record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + // Rv32RdWriteAdapter doesn't read anything } - fn generate_trace_row( + #[inline(always)] + fn write( &self, - row_slice: &mut [F], - _read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, ) { - let aux_cols_factory = memory.aux_cols_factory(); - let adapter_cols: &mut Rv32RdWriteAdapterCols = row_slice.borrow_mut(); - adapter_cols.from_state = write_record.from_state.map(F::from_canonical_u32); - let rd = memory.record_by_id(write_record.rd_id.unwrap()); - adapter_cols.rd_ptr = rd.pointer; - aux_cols_factory.generate_write_aux(rd, &mut adapter_cols.rd_aux_cols); + let &Instruction { a, d, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + + record.rd_ptr = a.as_canonical_u32(); + tracing_write( + memory, + RV32_REGISTER_AS, + record.rd_ptr, + data, + &mut record.rd_aux_record.prev_timestamp, + &mut record.rd_aux_record.prev_data, + ); } +} - fn air(&self) -> &Self::Air { - &self.air +impl AdapterTraceFiller for Rv32RdWriteAdapterFiller { + const WIDTH: usize = size_of::>(); + + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &Rv32RdWriteAdapterRecord = + unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_row: &mut Rv32RdWriteAdapterCols = adapter_row.borrow_mut(); + + adapter_row + .rd_aux_cols + .set_prev_data(record.rd_aux_record.prev_data.map(F::from_canonical_u8)); + mem_helper.fill( + record.rd_aux_record.prev_timestamp, + record.from_timestamp, + adapter_row.rd_aux_cols.as_mut(), + ); + adapter_row.rd_ptr = F::from_canonical_u32(record.rd_ptr); + adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); } } -impl VmAdapterChip for Rv32CondRdWriteAdapterChip { - type ReadRecord = (); - type WriteRecord = Rv32RdWriteWriteRecord; - type Air = Rv32CondRdWriteAdapterAir; - type Interface = BasicAdapterInterface, 0, 1, 0, RV32_REGISTER_NUM_LIMBS>; +/// This adapter doesn't read anything, and **maybe** writes to \[a:4\]_d, where d == 1 +#[derive(Clone, Copy, derive_new::new)] +pub struct Rv32CondRdWriteAdapterExecutor { + inner: Rv32RdWriteAdapterExecutor, +} - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - self.inner.preprocess(memory, instruction) +#[derive(Clone, Copy, derive_new::new)] +pub struct Rv32CondRdWriteAdapterFiller { + inner: Rv32RdWriteAdapterFiller, +} + +impl AdapterTraceExecutor for Rv32CondRdWriteAdapterExecutor +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = (); + type WriteData = [u8; RV32_REGISTER_NUM_LIMBS]; + type RecordMut<'a> = &'a mut Rv32RdWriteAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, d, .. } = *instruction; - let rd_id = if instruction.f != F::ZERO { - let (rd_id, _) = memory.write(d, a, output.writes[0]); - Some(rd_id) - } else { - memory.increment_timestamp(); - None - }; - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, rd_id }, - )) + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + >::read( + &self.inner, + memory, + instruction, + record, + ) } - fn generate_trace_row( + #[inline(always)] + fn write( &self, - row_slice: &mut [F], - _read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, ) { - let aux_cols_factory = memory.aux_cols_factory(); - let adapter_cols: &mut Rv32CondRdWriteAdapterCols = row_slice.borrow_mut(); - adapter_cols.inner.from_state = write_record.from_state.map(F::from_canonical_u32); - if let Some(rd_id) = write_record.rd_id { - let rd = memory.record_by_id(rd_id); - adapter_cols.inner.rd_ptr = rd.pointer; - aux_cols_factory.generate_write_aux(rd, &mut adapter_cols.inner.rd_aux_cols); - adapter_cols.needs_write = F::ONE; + let Instruction { f: enabled, .. } = instruction; + + if enabled.is_one() { + >::write( + &self.inner, + memory, + instruction, + data, + record, + ); + } else { + memory.increment_timestamp(); + record.rd_ptr = u32::MAX; } } +} - fn air(&self) -> &Self::Air { - &self.air +impl AdapterTraceFiller for Rv32CondRdWriteAdapterFiller { + const WIDTH: usize = size_of::>(); + + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &Rv32RdWriteAdapterRecord = + unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_cols: &mut Rv32CondRdWriteAdapterCols = adapter_row.borrow_mut(); + + adapter_cols.needs_write = F::from_bool(record.rd_ptr != u32::MAX); + + if record.rd_ptr != u32::MAX { + unsafe { + self.inner.fill_trace_row( + mem_helper, + adapter_row + .split_at_mut_unchecked(size_of::>()) + .0, + ) + }; + } else { + adapter_cols.inner.rd_ptr = F::ZERO; + mem_helper.fill_zero(adapter_cols.inner.rd_aux_cols.as_mut()); + adapter_cols.inner.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + adapter_cols.inner.from_state.pc = F::from_canonical_u32(record.from_pc); + } } } diff --git a/extensions/rv32im/circuit/src/auipc/core.rs b/extensions/rv32im/circuit/src/auipc/core.rs index 8ec9e274f6..7ec7fdeeaf 100644 --- a/extensions/rv32im/circuit/src/auipc/core.rs +++ b/extensions/rv32im/circuit/src/auipc/core.rs @@ -1,17 +1,26 @@ use std::{ - array, + array::{self, from_fn}, borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, ImmInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::*, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, program::PC_BITS, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, + program::{DEFAULT_PC_STEP, PC_BITS}, + riscv::RV32_REGISTER_AS, + LocalOpcode, +}; use openvm_rv32im_transpiler::Rv32AuipcOpcode::{self, *}; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -19,11 +28,10 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; - -use crate::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; -const RV32_LIMB_MAX: u32 = (1 << RV32_CELL_BITS) - 1; +use crate::adapters::{ + Rv32RdWriteAdapterExecutor, Rv32RdWriteAdapterFiller, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, +}; #[repr(C)] #[derive(Debug, Clone, AlignedBorrow)] @@ -36,7 +44,7 @@ pub struct Rv32AuipcCoreCols { pub rd_data: [T; RV32_REGISTER_NUM_LIMBS], } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy, derive_new::new)] pub struct Rv32AuipcCoreAir { pub bus: BitwiseOperationLookupBus, } @@ -186,116 +194,205 @@ where } #[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Rv32AuipcCoreRecord { - pub imm_limbs: [F; RV32_REGISTER_NUM_LIMBS - 1], - pub pc_limbs: [F; RV32_REGISTER_NUM_LIMBS - 2], - pub rd_data: [F; RV32_REGISTER_NUM_LIMBS], +#[derive(AlignedBytesBorrow, Debug, Clone)] +pub struct Rv32AuipcCoreRecord { + pub from_pc: u32, + pub imm: u32, } -pub struct Rv32AuipcCoreChip { - pub air: Rv32AuipcCoreAir, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, +#[derive(Clone, Copy, derive_new::new)] +pub struct Rv32AuipcExecutor { + adapter: A, } -impl Rv32AuipcCoreChip { - pub fn new(bitwise_lookup_chip: SharedBitwiseOperationLookupChip) -> Self { - Self { - air: Rv32AuipcCoreAir { - bus: bitwise_lookup_chip.bus(), - }, - bitwise_lookup_chip, - } - } +#[derive(Clone, derive_new::new)] +pub struct Rv32AuipcFiller { + adapter: A, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } -impl> VmCoreChip for Rv32AuipcCoreChip +impl PreflightExecutor for Rv32AuipcExecutor where - I::Writes: From<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + AdapterTraceExecutor, + for<'buf> RA: RecordArena< + 'buf, + EmptyAdapterCoreLayout, + (A::RecordMut<'buf>, &'buf mut Rv32AuipcCoreRecord), + >, { - type Record = Rv32AuipcCoreRecord; - type Air = Rv32AuipcCoreAir; + fn get_opcode_name(&self, _: usize) -> String { + format!("{:?}", AUIPC) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, instruction: &Instruction, - from_pc: u32, - _reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let local_opcode = Rv32AuipcOpcode::from_usize( - instruction - .opcode - .local_opcode_idx(Rv32AuipcOpcode::CLASS_OFFSET), - ); - let imm = instruction.c.as_canonical_u32(); - let rd_data = run_auipc(local_opcode, from_pc, imm); - let rd_data_field = rd_data.map(F::from_canonical_u32); - - let output = AdapterRuntimeContext::without_pc([rd_data_field]); - - let imm_limbs = array::from_fn(|i| (imm >> (i * RV32_CELL_BITS)) & RV32_LIMB_MAX); - let pc_limbs: [u32; RV32_REGISTER_NUM_LIMBS] = - array::from_fn(|i| (from_pc >> (i * RV32_CELL_BITS)) & RV32_LIMB_MAX); + ) -> Result<(), ExecutionError> { + let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new()); - for i in 0..(RV32_REGISTER_NUM_LIMBS / 2) { + A::start(*state.pc, state.memory, &mut adapter_record); + + core_record.from_pc = *state.pc; + core_record.imm = instruction.c.as_canonical_u32(); + + let rd = run_auipc(*state.pc, core_record.imm); + + self.adapter + .write(state.memory, instruction, rd, &mut adapter_record); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } +} + +impl TraceFiller for Rv32AuipcFiller +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + + let record: &Rv32AuipcCoreRecord = unsafe { get_record_from_slice(&mut core_row, ()) }; + + let core_row: &mut Rv32AuipcCoreCols = core_row.borrow_mut(); + + let imm_limbs = record.imm.to_le_bytes(); + let pc_limbs = record.from_pc.to_le_bytes(); + let rd_data = run_auipc(record.from_pc, record.imm); + debug_assert_eq!(imm_limbs[3], 0); + + // range checks: + // hardcoding for performance: first 3 limbs of imm_limbs, last 3 limbs of pc_limbs where + // most significant limb of pc_limbs is shifted up + self.bitwise_lookup_chip + .request_range(imm_limbs[0] as u32, imm_limbs[1] as u32); + self.bitwise_lookup_chip + .request_range(imm_limbs[2] as u32, pc_limbs[1] as u32); + let msl_shift = RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - PC_BITS; + self.bitwise_lookup_chip + .request_range(pc_limbs[2] as u32, (pc_limbs[3] as u32) << msl_shift); + for pair in rd_data.chunks_exact(2) { self.bitwise_lookup_chip - .request_range(rd_data[i * 2], rd_data[i * 2 + 1]); + .request_range(pair[0] as u32, pair[1] as u32); } + // Writing in reverse order + core_row.rd_data = rd_data.map(F::from_canonical_u8); + // only the middle 2 limbs: + core_row.pc_limbs = from_fn(|i| F::from_canonical_u8(pc_limbs[i + 1])); + core_row.imm_limbs = from_fn(|i| F::from_canonical_u8(imm_limbs[i])); - let mut need_range_check: Vec = Vec::new(); - for limb in imm_limbs { - need_range_check.push(limb); - } + core_row.is_valid = F::ONE; + } +} - for (i, limb) in pc_limbs.iter().enumerate().skip(1) { - if i == pc_limbs.len() - 1 { - need_range_check.push((*limb) << (pc_limbs.len() * RV32_CELL_BITS - PC_BITS)); - } else { - need_range_check.push(*limb); - } - } +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct AuiPcPreCompute { + imm: u32, + a: u8, +} - for pair in need_range_check.chunks(2) { - self.bitwise_lookup_chip.request_range(pair[0], pair[1]); - } +impl Executor for Rv32AuipcExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } - Ok(( - output, - Self::Record { - imm_limbs: imm_limbs.map(F::from_canonical_u32), - pc_limbs: array::from_fn(|i| F::from_canonical_u32(pc_limbs[i + 1])), - rd_data: rd_data.map(F::from_canonical_u32), - }, - )) + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let data: &mut AuiPcPreCompute = data.borrow_mut(); + self.pre_compute_impl(pc, inst, data)?; + Ok(|pre_compute, vm_state| { + let pre_compute: &AuiPcPreCompute = pre_compute.borrow(); + unsafe { + execute_e1_impl(pre_compute, vm_state); + } + }) } +} + +#[inline(always)] +unsafe fn execute_e1_impl( + pre_compute: &AuiPcPreCompute, + vm_state: &mut VmExecState, +) { + let rd = run_auipc(vm_state.pc, pre_compute.imm); + vm_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - Rv32AuipcOpcode::from_usize(opcode - Rv32AuipcOpcode::CLASS_OFFSET) - ) +impl MeteredExecutor for Rv32AuipcExecutor +where + F: PrimeField32, +{ + fn metered_pre_compute_size(&self) -> usize { + size_of::>() } - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let core_cols: &mut Rv32AuipcCoreCols = row_slice.borrow_mut(); - core_cols.imm_limbs = record.imm_limbs; - core_cols.pc_limbs = record.pc_limbs; - core_cols.rd_data = record.rd_data; - core_cols.is_valid = F::ONE; + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + self.pre_compute_impl(pc, inst, &mut data.data)?; + Ok(|pre_compute, vm_state| { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + unsafe { + execute_e1_impl(&pre_compute.data, vm_state); + } + }) } +} - fn air(&self) -> &Self::Air { - &self.air +impl Rv32AuipcExecutor { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut AuiPcPreCompute, + ) -> Result<(), StaticProgramError> { + let Instruction { a, c: imm, d, .. } = inst; + if d.as_canonical_u32() != RV32_REGISTER_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + let imm = imm.as_canonical_u32(); + let data: &mut AuiPcPreCompute = data.borrow_mut(); + *data = AuiPcPreCompute { + imm, + a: a.as_canonical_u32() as u8, + }; + Ok(()) } } // returns rd_data -pub(super) fn run_auipc( - _opcode: Rv32AuipcOpcode, - pc: u32, - imm: u32, -) -> [u32; RV32_REGISTER_NUM_LIMBS] { +#[inline(always)] +pub(super) fn run_auipc(pc: u32, imm: u32) -> [u8; RV32_REGISTER_NUM_LIMBS] { let rd = pc.wrapping_add(imm << RV32_CELL_BITS); - array::from_fn(|i| (rd >> (RV32_CELL_BITS * i)) & RV32_LIMB_MAX) + rd.to_le_bytes() } diff --git a/extensions/rv32im/circuit/src/auipc/mod.rs b/extensions/rv32im/circuit/src/auipc/mod.rs index 6e2234bfbd..dfe71956bd 100644 --- a/extensions/rv32im/circuit/src/auipc/mod.rs +++ b/extensions/rv32im/circuit/src/auipc/mod.rs @@ -1,6 +1,6 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; -use crate::adapters::Rv32RdWriteAdapterChip; +use crate::adapters::Rv32RdWriteAdapterAir; mod core; pub use core::*; @@ -8,4 +8,5 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32AuipcChip = VmChipWrapper, Rv32AuipcCoreChip>; +pub type Rv32AuipcAir = VmAirWrapper; +pub type Rv32AuipcChip = VmChipWrapper; diff --git a/extensions/rv32im/circuit/src/auipc/tests.rs b/extensions/rv32im/circuit/src/auipc/tests.rs index 2c8a399198..be80b756ba 100644 --- a/extensions/rv32im/circuit/src/auipc/tests.rs +++ b/extensions/rv32im/circuit/src/auipc/tests.rs @@ -1,52 +1,92 @@ -use std::borrow::BorrowMut; +use std::{borrow::BorrowMut, sync::Arc}; -use openvm_circuit::arch::{testing::VmChipTestBuilder, VmAdapterChip}; +use openvm_circuit::arch::{ + testing::{TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, + Arena, DenseRecordArena, EmptyAdapterCoreLayout, PreflightExecutor, VmAirWrapper, + VmChipWrapper, +}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, }; use openvm_instructions::{instruction::Instruction, program::PC_BITS, LocalOpcode}; use openvm_rv32im_transpiler::Rv32AuipcOpcode::{self, *}; use openvm_stark_backend::{ - interaction::BusIndex, p3_air::BaseAir, p3_field::{FieldAlgebra, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, utils::disable_debug_builder, - verifier::VerificationError, - Chip, ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; -use super::{run_auipc, Rv32AuipcChip, Rv32AuipcCoreChip, Rv32AuipcCoreCols}; -use crate::adapters::{Rv32RdWriteAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use super::{run_auipc, Rv32AuipcChip, Rv32AuipcCoreAir, Rv32AuipcCoreCols, Rv32AuipcExecutor}; +use crate::{ + adapters::{ + Rv32RdWriteAdapterAir, Rv32RdWriteAdapterExecutor, Rv32RdWriteAdapterFiller, + Rv32RdWriteAdapterRecord, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + }, + test_utils::get_verification_error, + Rv32AuipcAir, Rv32AuipcCoreRecord, Rv32AuipcFiller, +}; const IMM_BITS: usize = 24; -const BITWISE_OP_LOOKUP_BUS: BusIndex = 9; - +const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; +type Harness = TestChipHarness, RA>; + +fn create_test_chip( + tester: &VmChipTestBuilder, +) -> ( + Harness, + ( + BitwiseOperationLookupAir, + SharedBitwiseOperationLookupChip, + ), +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + + let air = VmAirWrapper::new( + Rv32RdWriteAdapterAir::new(tester.memory_bridge(), tester.execution_bridge()), + Rv32AuipcCoreAir::new(bitwise_bus), + ); + let executor = Rv32AuipcExecutor::new(Rv32RdWriteAdapterExecutor::new()); + let chip = VmChipWrapper::::new( + Rv32AuipcFiller::new(Rv32RdWriteAdapterFiller::new(), bitwise_chip.clone()), + tester.memory_helper(), + ); + let harness = Harness::::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + + (harness, (bitwise_chip.air, bitwise_chip)) +} -fn set_and_execute( +fn set_and_execute( tester: &mut VmChipTestBuilder, - chip: &mut Rv32AuipcChip, + harness: &mut Harness, rng: &mut StdRng, opcode: Rv32AuipcOpcode, imm: Option, initial_pc: Option, -) { +) where + Rv32AuipcExecutor: PreflightExecutor, +{ let imm = imm.unwrap_or(rng.gen_range(0..(1 << IMM_BITS))) as usize; let a = rng.gen_range(0..32) << 2; tester.execute_with_pc( - chip, + harness, &Instruction::from_usize(opcode.global_opcode(), [a, 0, imm, 1, 0]), initial_pc.unwrap_or(rng.gen_range(0..(1 << PC_BITS))), ); let initial_pc = tester.execution.last_from_pc().as_canonical_u32(); - - let rd_data = run_auipc(opcode, initial_pc, imm as u32); - - assert_eq!(rd_data.map(F::from_canonical_u32), tester.read::<4>(1, a)); + let rd_data = run_auipc(initial_pc, imm as u32); + assert_eq!(rd_data.map(F::from_canonical_u8), tester.read::<4>(1, a)); } /////////////////////////////////////////////////////////////////////////////////////// @@ -59,24 +99,18 @@ fn set_and_execute( #[test] fn rand_auipc_test() { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let adapter = Rv32RdWriteAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let core = Rv32AuipcCoreChip::new(bitwise_chip.clone()); - let mut chip = Rv32AuipcChip::::new(adapter, core, tester.offline_memory_mutex_arc()); + let (mut harness, bitwise) = create_test_chip(&tester); let num_tests: usize = 100; for _ in 0..num_tests { - set_and_execute(&mut tester, &mut chip, &mut rng, AUIPC, None, None); + set_and_execute(&mut tester, &mut harness, &mut rng, AUIPC, None, None); } - - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + let tester = tester + .build() + .load(harness) + .load_periphery(bitwise) + .finalize(); tester.simple_test().expect("Verification failed"); } @@ -84,75 +118,62 @@ fn rand_auipc_test() { // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adaptor is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// +#[derive(Clone, Copy, Default, PartialEq)] +struct AuipcPrankValues { + pub rd_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + pub imm_limbs: Option<[u32; RV32_REGISTER_NUM_LIMBS - 1]>, + pub pc_limbs: Option<[u32; RV32_REGISTER_NUM_LIMBS - 2]>, +} + fn run_negative_auipc_test( opcode: Rv32AuipcOpcode, initial_imm: Option, initial_pc: Option, - rd_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, - imm_limbs: Option<[u32; RV32_REGISTER_NUM_LIMBS - 1]>, - pc_limbs: Option<[u32; RV32_REGISTER_NUM_LIMBS - 2]>, - expected_error: VerificationError, + prank_vals: AuipcPrankValues, + interaction_error: bool, ) { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let adapter = Rv32RdWriteAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let adapter_width = BaseAir::::width(adapter.air()); - let core = Rv32AuipcCoreChip::new(bitwise_chip.clone()); - let mut chip = Rv32AuipcChip::::new(adapter, core, tester.offline_memory_mutex_arc()); + let (mut harness, bitwise) = create_test_chip(&tester); set_and_execute( &mut tester, - &mut chip, + &mut harness, &mut rng, opcode, initial_imm, initial_pc, ); - let tester = tester.build(); - - let auipc_trace_width = chip.trace_width(); - let air = chip.air(); - let mut chip_input = chip.generate_air_proof_input(); - let auipc_trace = chip_input.raw.common_main.as_mut().unwrap(); - { - let mut trace_row = auipc_trace.row_slice(0).to_vec(); - + let adapter_width = BaseAir::::width(&harness.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { + let mut trace_row = trace.row_slice(0).to_vec(); let (_, core_row) = trace_row.split_at_mut(adapter_width); - let core_cols: &mut Rv32AuipcCoreCols = core_row.borrow_mut(); - if let Some(data) = rd_data { + if let Some(data) = prank_vals.rd_data { core_cols.rd_data = data.map(F::from_canonical_u32); } - - if let Some(data) = imm_limbs { + if let Some(data) = prank_vals.imm_limbs { core_cols.imm_limbs = data.map(F::from_canonical_u32); } - - if let Some(data) = pc_limbs { + if let Some(data) = prank_vals.pc_limbs { core_cols.pc_limbs = data.map(F::from_canonical_u32); } - *auipc_trace = RowMajorMatrix::new(trace_row, auipc_trace_width); - } + *trace = RowMajorMatrix::new(trace_row, trace.width()); + }; + disable_debug_builder(); let tester = tester - .load_air_proof_input((air, chip_input)) - .load(bitwise_chip) + .build() + .load_and_prank_trace(harness, modify_trace) + .load_periphery(bitwise) .finalize(); - tester.simple_test_with_expected_error(expected_error); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -161,47 +182,53 @@ fn invalid_limb_negative_tests() { AUIPC, Some(9722891), None, - None, - Some([107, 46, 81]), - None, - VerificationError::OodEvaluationMismatch, + AuipcPrankValues { + imm_limbs: Some([107, 46, 81]), + ..Default::default() + }, + false, ); run_negative_auipc_test( AUIPC, Some(0), Some(2110400), - Some([194, 51, 32, 240]), - None, - Some([51, 32]), - VerificationError::ChallengePhaseError, + AuipcPrankValues { + rd_data: Some([194, 51, 32, 240]), + pc_limbs: Some([51, 32]), + ..Default::default() + }, + true, ); run_negative_auipc_test( AUIPC, None, None, - None, - None, - Some([206, 166]), - VerificationError::OodEvaluationMismatch, + AuipcPrankValues { + pc_limbs: Some([206, 166]), + ..Default::default() + }, + false, ); run_negative_auipc_test( AUIPC, None, None, - Some([30, 92, 82, 132]), - None, - None, - VerificationError::OodEvaluationMismatch, + AuipcPrankValues { + rd_data: Some([30, 92, 82, 132]), + ..Default::default() + }, + false, ); - run_negative_auipc_test( AUIPC, None, Some(876487877), - Some([197, 202, 49, 70]), - Some([166, 243, 17]), - Some([36, 62]), - VerificationError::ChallengePhaseError, + AuipcPrankValues { + rd_data: Some([197, 202, 49, 70]), + imm_limbs: Some([166, 243, 17]), + pc_limbs: Some([36, 62]), + }, + true, ); } @@ -211,37 +238,42 @@ fn overflow_negative_tests() { AUIPC, Some(256264), None, - None, - Some([3592, 219, 3]), - None, - VerificationError::OodEvaluationMismatch, + AuipcPrankValues { + imm_limbs: Some([3592, 219, 3]), + ..Default::default() + }, + false, ); run_negative_auipc_test( AUIPC, None, None, - None, - None, - Some([0, 0]), - VerificationError::OodEvaluationMismatch, + AuipcPrankValues { + pc_limbs: Some([0, 0]), + ..Default::default() + }, + false, ); run_negative_auipc_test( AUIPC, Some(255), None, - None, - Some([F::NEG_ONE.as_canonical_u32(), 1, 0]), - None, - VerificationError::ChallengePhaseError, + AuipcPrankValues { + imm_limbs: Some([F::NEG_ONE.as_canonical_u32(), 1, 0]), + ..Default::default() + }, + true, ); run_negative_auipc_test( AUIPC, Some(0), Some(255), - Some([F::NEG_ONE.as_canonical_u32(), 1, 0, 0]), - Some([0, 0, 0]), - Some([1, 0]), - VerificationError::ChallengePhaseError, + AuipcPrankValues { + rd_data: Some([F::NEG_ONE.as_canonical_u32(), 1, 0, 0]), + imm_limbs: Some([0, 0, 0]), + pc_limbs: Some([1, 0]), + }, + true, ); } @@ -251,33 +283,54 @@ fn overflow_negative_tests() { /// Ensure that solve functions produce the correct results. /////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn execute_roundtrip_sanity_test() { - let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let mut tester = VmChipTestBuilder::default(); - let adapter = Rv32RdWriteAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let inner = Rv32AuipcCoreChip::new(bitwise_chip); - let mut chip = Rv32AuipcChip::::new(adapter, inner, tester.offline_memory_mutex_arc()); - - let num_tests: usize = 100; - for _ in 0..num_tests { - set_and_execute(&mut tester, &mut chip, &mut rng, AUIPC, None, None); - } -} - #[test] fn run_auipc_sanity_test() { - let opcode = AUIPC; let initial_pc = 234567890; let imm = 11302451; - let rd_data = run_auipc(opcode, initial_pc, imm); + let rd_data = run_auipc(initial_pc, imm); assert_eq!(rd_data, [210, 107, 113, 186]); } + +// //////////////////////////////////////////////////////////////////////////////////// +// DENSE TESTS + +// Ensure that the chip works as expected with dense records. +// We first execute some instructions with a [DenseRecordArena] and transfer the records +// to a [MatrixRecordArena]. After transferring we generate the trace and make sure that +// all the constraints pass. +// //////////////////////////////////////////////////////////////////////////////////// + +#[test] +fn dense_record_arena_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut sparse_harness, bitwise) = create_test_chip(&tester); + + { + let mut dense_harness = create_test_chip::(&tester).0; + + let num_ops: usize = 100; + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut dense_harness, &mut rng, AUIPC, None, None); + } + + type Record<'a> = ( + &'a mut Rv32RdWriteAdapterRecord, + &'a mut Rv32AuipcCoreRecord, + ); + + let mut record_interpreter = dense_harness.arena.get_record_seeker::(); + record_interpreter.transfer_to_matrix_arena( + &mut sparse_harness.arena, + EmptyAdapterCoreLayout::::new(), + ); + } + + let tester = tester + .build() + .load(sparse_harness) + .load_periphery(bitwise) + .finalize(); + tester.simple_test().expect("Verification failed"); +} diff --git a/extensions/rv32im/circuit/src/base_alu/core.rs b/extensions/rv32im/circuit/src/base_alu/core.rs index a87418cc91..bf2c1e1e89 100644 --- a/extensions/rv32im/circuit/src/base_alu/core.rs +++ b/extensions/rv32im/circuit/src/base_alu/core.rs @@ -1,18 +1,28 @@ use std::{ array, borrow::{Borrow, BorrowMut}, + iter::zip, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::*, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_IMM_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; use openvm_rv32im_transpiler::BaseAluOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -20,12 +30,12 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; +use crate::adapters::imm_to_bytes; + #[repr(C)] -#[derive(AlignedBorrow)] +#[derive(AlignedBorrow, Debug)] pub struct BaseAluCoreCols { pub a: [T; NUM_LIMBS], pub b: [T; NUM_LIMBS], @@ -38,10 +48,10 @@ pub struct BaseAluCoreCols { pub opcode_and_flag: T, } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, derive_new::new)] pub struct BaseAluCoreAir { pub bus: BitwiseOperationLookupBus, - offset: usize, + pub offset: usize, } impl BaseAir @@ -165,175 +175,397 @@ where } } -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "T: Serialize + DeserializeOwned")] -pub struct BaseAluCoreRecord { - pub opcode: BaseAluOpcode, - #[serde(with = "BigArray")] - pub a: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub c: [T; NUM_LIMBS], +#[repr(C, align(4))] +#[derive(AlignedBytesBorrow, Debug)] +pub struct BaseAluCoreRecord { + pub b: [u8; NUM_LIMBS], + pub c: [u8; NUM_LIMBS], + // Use u8 instead of usize for better packing + pub local_opcode: u8, } -pub struct BaseAluCoreChip { - pub air: BaseAluCoreAir, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, +#[derive(Clone, Copy, derive_new::new)] +pub struct BaseAluExecutor { + adapter: A, + pub offset: usize, } -impl BaseAluCoreChip { - pub fn new( - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - offset: usize, - ) -> Self { - Self { - air: BaseAluCoreAir { - bus: bitwise_lookup_chip.bus(), - offset, - }, - bitwise_lookup_chip, - } - } +#[derive(derive_new::new)] +pub struct BaseAluFiller { + adapter: A, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pub offset: usize, } -impl VmCoreChip - for BaseAluCoreChip +impl PreflightExecutor + for BaseAluExecutor where F: PrimeField32, - I: VmAdapterInterface, - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: From<[[F; NUM_LIMBS]; 1]>, + A: 'static + + AdapterTraceExecutor< + F, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + >, + for<'buf> RA: RecordArena< + 'buf, + EmptyAdapterCoreLayout, + (A::RecordMut<'buf>, &'buf mut BaseAluCoreRecord), + >, { - type Record = BaseAluCoreRecord; - type Air = BaseAluCoreAir; + fn get_opcode_name(&self, opcode: usize) -> String { + format!("{:?}", BaseAluOpcode::from_usize(opcode - self.offset)) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + ) -> Result<(), ExecutionError> { let Instruction { opcode, .. } = instruction; - let local_opcode = BaseAluOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let b = data[0].map(|x| x.as_canonical_u32()); - let c = data[1].map(|y| y.as_canonical_u32()); - let a = run_alu::(local_opcode, &b, &c); + let local_opcode = BaseAluOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new()); - let output = AdapterRuntimeContext { - to_pc: None, - writes: [a.map(F::from_canonical_u32)].into(), - }; + A::start(*state.pc, state.memory, &mut adapter_record); + + [core_record.b, core_record.c] = self + .adapter + .read(state.memory, instruction, &mut adapter_record) + .into(); + + let rd = run_alu::(local_opcode, &core_record.b, &core_record.c); + + core_record.local_opcode = local_opcode as u8; + + self.adapter + .write(state.memory, instruction, [rd].into(), &mut adapter_record); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } +} + +impl TraceFiller + for BaseAluFiller +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + + let record: &BaseAluCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; + let core_row: &mut BaseAluCoreCols = core_row.borrow_mut(); + // SAFETY: the following is highly unsafe. We are going to cast `core_row` to a record + // buffer, and then do an _overlapping_ write to the `core_row` as a row of field elements. + // This requires: + // - Cols and Record structs should be repr(C) and we write in reverse order (to ensure + // non-overlapping) + // - Do not overwrite any reference in `record` before it has already been used or moved + // - alignment of `F` must be >= alignment of Record (AlignedBytesBorrow will panic + // otherwise) + + let local_opcode = BaseAluOpcode::from_usize(record.local_opcode as usize); + let a = run_alu::(local_opcode, &record.b, &record.c); + // PERF: needless conversion + core_row.opcode_and_flag = F::from_bool(local_opcode == BaseAluOpcode::AND); + core_row.opcode_or_flag = F::from_bool(local_opcode == BaseAluOpcode::OR); + core_row.opcode_xor_flag = F::from_bool(local_opcode == BaseAluOpcode::XOR); + core_row.opcode_sub_flag = F::from_bool(local_opcode == BaseAluOpcode::SUB); + core_row.opcode_add_flag = F::from_bool(local_opcode == BaseAluOpcode::ADD); if local_opcode == BaseAluOpcode::ADD || local_opcode == BaseAluOpcode::SUB { for a_val in a { - self.bitwise_lookup_chip.request_xor(a_val, a_val); + self.bitwise_lookup_chip + .request_xor(a_val as u32, a_val as u32); } } else { - for (b_val, c_val) in b.iter().zip(c.iter()) { - self.bitwise_lookup_chip.request_xor(*b_val, *c_val); + for (b_val, c_val) in zip(record.b, record.c) { + self.bitwise_lookup_chip + .request_xor(b_val as u32, c_val as u32); } } + core_row.c = record.c.map(F::from_canonical_u8); + core_row.b = record.b.map(F::from_canonical_u8); + core_row.a = a.map(F::from_canonical_u8); + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct BaseAluPreCompute { + c: u32, + a: u8, + b: u8, +} + +impl Executor + for BaseAluExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } - let record = Self::Record { - opcode: local_opcode, - a: a.map(F::from_canonical_u32), - b: data[0], - c: data[1], + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E1ExecutionCtx, + { + let data: &mut BaseAluPreCompute = data.borrow_mut(); + let is_imm = self.pre_compute_impl(pc, inst, data)?; + let opcode = inst.opcode; + + let fn_ptr = match ( + is_imm, + BaseAluOpcode::from_usize(opcode.local_opcode_idx(self.offset)), + ) { + (true, BaseAluOpcode::ADD) => execute_e1_impl::<_, _, true, AddOp>, + (false, BaseAluOpcode::ADD) => execute_e1_impl::<_, _, false, AddOp>, + (true, BaseAluOpcode::SUB) => execute_e1_impl::<_, _, true, SubOp>, + (false, BaseAluOpcode::SUB) => execute_e1_impl::<_, _, false, SubOp>, + (true, BaseAluOpcode::XOR) => execute_e1_impl::<_, _, true, XorOp>, + (false, BaseAluOpcode::XOR) => execute_e1_impl::<_, _, false, XorOp>, + (true, BaseAluOpcode::OR) => execute_e1_impl::<_, _, true, OrOp>, + (false, BaseAluOpcode::OR) => execute_e1_impl::<_, _, false, OrOp>, + (true, BaseAluOpcode::AND) => execute_e1_impl::<_, _, true, AndOp>, + (false, BaseAluOpcode::AND) => execute_e1_impl::<_, _, false, AndOp>, }; + Ok(fn_ptr) + } +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &BaseAluPreCompute, + vm_state: &mut VmExecState, +) { + let rs1 = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let rs2 = if IS_IMM { + pre_compute.c.to_le_bytes() + } else { + vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.c) + }; + let rs1 = u32::from_le_bytes(rs1); + let rs2 = u32::from_le_bytes(rs2); + let rd = ::compute(rs1, rs2); + let rd = rd.to_le_bytes(); + vm_state.vm_write::(RV32_REGISTER_AS, pre_compute.a as u32, &rd); + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} - Ok((output, record)) +#[inline(always)] +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &BaseAluPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +#[inline(always)] +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl MeteredExecutor + for BaseAluExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn metered_pre_compute_size(&self) -> usize { + size_of::>() } - fn get_opcode_name(&self, opcode: usize) -> String { - format!("{:?}", BaseAluOpcode::from_usize(opcode - self.air.offset)) + #[inline(always)] + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let is_imm = self.pre_compute_impl(pc, inst, &mut data.data)?; + let opcode = inst.opcode; + + let fn_ptr = match ( + is_imm, + BaseAluOpcode::from_usize(opcode.local_opcode_idx(self.offset)), + ) { + (true, BaseAluOpcode::ADD) => execute_e2_impl::<_, _, true, AddOp>, + (false, BaseAluOpcode::ADD) => execute_e2_impl::<_, _, false, AddOp>, + (true, BaseAluOpcode::SUB) => execute_e2_impl::<_, _, true, SubOp>, + (false, BaseAluOpcode::SUB) => execute_e2_impl::<_, _, false, SubOp>, + (true, BaseAluOpcode::XOR) => execute_e2_impl::<_, _, true, XorOp>, + (false, BaseAluOpcode::XOR) => execute_e2_impl::<_, _, false, XorOp>, + (true, BaseAluOpcode::OR) => execute_e2_impl::<_, _, true, OrOp>, + (false, BaseAluOpcode::OR) => execute_e2_impl::<_, _, false, OrOp>, + (true, BaseAluOpcode::AND) => execute_e2_impl::<_, _, true, AndOp>, + (false, BaseAluOpcode::AND) => execute_e2_impl::<_, _, false, AndOp>, + }; + Ok(fn_ptr) } +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut BaseAluCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut(); - row_slice.a = record.a; - row_slice.b = record.b; - row_slice.c = record.c; - row_slice.opcode_add_flag = F::from_bool(record.opcode == BaseAluOpcode::ADD); - row_slice.opcode_sub_flag = F::from_bool(record.opcode == BaseAluOpcode::SUB); - row_slice.opcode_xor_flag = F::from_bool(record.opcode == BaseAluOpcode::XOR); - row_slice.opcode_or_flag = F::from_bool(record.opcode == BaseAluOpcode::OR); - row_slice.opcode_and_flag = F::from_bool(record.opcode == BaseAluOpcode::AND); +impl BaseAluExecutor { + /// Return `is_imm`, true if `e` is RV32_IMM_AS. + #[inline(always)] + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut BaseAluPreCompute, + ) -> Result { + let Instruction { a, b, c, d, e, .. } = inst; + let e_u32 = e.as_canonical_u32(); + if (d.as_canonical_u32() != RV32_REGISTER_AS) + || !(e_u32 == RV32_IMM_AS || e_u32 == RV32_REGISTER_AS) + { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + let is_imm = e_u32 == RV32_IMM_AS; + let c_u32 = c.as_canonical_u32(); + *data = BaseAluPreCompute { + c: if is_imm { + u32::from_le_bytes(imm_to_bytes(c_u32)) + } else { + c_u32 + }, + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + }; + Ok(is_imm) } +} - fn air(&self) -> &Self::Air { - &self.air +trait AluOp { + fn compute(rs1: u32, rs2: u32) -> u32; +} +struct AddOp; +struct SubOp; +struct XorOp; +struct OrOp; +struct AndOp; +impl AluOp for AddOp { + #[inline(always)] + fn compute(rs1: u32, rs2: u32) -> u32 { + rs1.wrapping_add(rs2) + } +} +impl AluOp for SubOp { + #[inline(always)] + fn compute(rs1: u32, rs2: u32) -> u32 { + rs1.wrapping_sub(rs2) + } +} +impl AluOp for XorOp { + #[inline(always)] + fn compute(rs1: u32, rs2: u32) -> u32 { + rs1 ^ rs2 + } +} +impl AluOp for OrOp { + #[inline(always)] + fn compute(rs1: u32, rs2: u32) -> u32 { + rs1 | rs2 + } +} +impl AluOp for AndOp { + #[inline(always)] + fn compute(rs1: u32, rs2: u32) -> u32 { + rs1 & rs2 } } +#[inline(always)] pub(super) fn run_alu( opcode: BaseAluOpcode, - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> [u32; NUM_LIMBS] { + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> [u8; NUM_LIMBS] { + debug_assert!(LIMB_BITS <= 8, "specialize for bytes"); match opcode { BaseAluOpcode::ADD => run_add::(x, y), BaseAluOpcode::SUB => run_subtract::(x, y), - BaseAluOpcode::XOR => run_xor::(x, y), - BaseAluOpcode::OR => run_or::(x, y), - BaseAluOpcode::AND => run_and::(x, y), + BaseAluOpcode::XOR => run_xor::(x, y), + BaseAluOpcode::OR => run_or::(x, y), + BaseAluOpcode::AND => run_and::(x, y), } } +#[inline(always)] fn run_add( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> [u32; NUM_LIMBS] { - let mut z = [0u32; NUM_LIMBS]; - let mut carry = [0u32; NUM_LIMBS]; + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> [u8; NUM_LIMBS] { + let mut z = [0u8; NUM_LIMBS]; + let mut carry = [0u8; NUM_LIMBS]; for i in 0..NUM_LIMBS { - z[i] = x[i] + y[i] + if i > 0 { carry[i - 1] } else { 0 }; - carry[i] = z[i] >> LIMB_BITS; - z[i] &= (1 << LIMB_BITS) - 1; + let mut overflow = + (x[i] as u16) + (y[i] as u16) + if i > 0 { carry[i - 1] as u16 } else { 0 }; + carry[i] = (overflow >> LIMB_BITS) as u8; + overflow &= (1u16 << LIMB_BITS) - 1; + z[i] = overflow as u8; } z } +#[inline(always)] fn run_subtract( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> [u32; NUM_LIMBS] { - let mut z = [0u32; NUM_LIMBS]; - let mut carry = [0u32; NUM_LIMBS]; + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> [u8; NUM_LIMBS] { + let mut z = [0u8; NUM_LIMBS]; + let mut carry = [0u8; NUM_LIMBS]; for i in 0..NUM_LIMBS { - let rhs = y[i] + if i > 0 { carry[i - 1] } else { 0 }; - if x[i] >= rhs { - z[i] = x[i] - rhs; + let rhs = y[i] as u16 + if i > 0 { carry[i - 1] as u16 } else { 0 }; + if x[i] as u16 >= rhs { + z[i] = x[i] - rhs as u8; carry[i] = 0; } else { - z[i] = x[i] + (1 << LIMB_BITS) - rhs; + z[i] = (x[i] as u16 + (1u16 << LIMB_BITS) - rhs) as u8; carry[i] = 1; } } z } -fn run_xor( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> [u32; NUM_LIMBS] { +#[inline(always)] +fn run_xor(x: &[u8; NUM_LIMBS], y: &[u8; NUM_LIMBS]) -> [u8; NUM_LIMBS] { array::from_fn(|i| x[i] ^ y[i]) } -fn run_or( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> [u32; NUM_LIMBS] { +#[inline(always)] +fn run_or(x: &[u8; NUM_LIMBS], y: &[u8; NUM_LIMBS]) -> [u8; NUM_LIMBS] { array::from_fn(|i| x[i] | y[i]) } -fn run_and( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> [u32; NUM_LIMBS] { +#[inline(always)] +fn run_and(x: &[u8; NUM_LIMBS], y: &[u8; NUM_LIMBS]) -> [u8; NUM_LIMBS] { array::from_fn(|i| x[i] & y[i]) } diff --git a/extensions/rv32im/circuit/src/base_alu/mod.rs b/extensions/rv32im/circuit/src/base_alu/mod.rs index cbda8ce555..f6ee100c9b 100644 --- a/extensions/rv32im/circuit/src/base_alu/mod.rs +++ b/extensions/rv32im/circuit/src/base_alu/mod.rs @@ -1,7 +1,9 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; -use super::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; -use crate::adapters::Rv32BaseAluAdapterChip; +use super::adapters::{ + Rv32BaseAluAdapterAir, Rv32BaseAluAdapterExecutor, Rv32BaseAluAdapterFiller, RV32_CELL_BITS, + RV32_REGISTER_NUM_LIMBS, +}; mod core; pub use core::*; @@ -9,8 +11,18 @@ pub use core::*; #[cfg(test)] mod tests; +pub type Rv32BaseAluAir = + VmAirWrapper>; +pub type Rv32BaseAluExecutor = BaseAluExecutor< + Rv32BaseAluAdapterExecutor, + RV32_REGISTER_NUM_LIMBS, + RV32_CELL_BITS, +>; pub type Rv32BaseAluChip = VmChipWrapper< F, - Rv32BaseAluAdapterChip, - BaseAluCoreChip, + BaseAluFiller< + Rv32BaseAluAdapterFiller, + RV32_REGISTER_NUM_LIMBS, + RV32_CELL_BITS, + >, >; diff --git a/extensions/rv32im/circuit/src/base_alu/tests.rs b/extensions/rv32im/circuit/src/base_alu/tests.rs index 165cd12526..bc1880953c 100644 --- a/extensions/rv32im/circuit/src/base_alu/tests.rs +++ b/extensions/rv32im/circuit/src/base_alu/tests.rs @@ -1,44 +1,119 @@ -use std::borrow::BorrowMut; +use std::{array, borrow::BorrowMut, sync::Arc}; -use openvm_circuit::{ - arch::{ - testing::{TestAdapterChip, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, ExecutionState, - MinimalInstruction, Result, VmAdapterChip, VmAdapterInterface, VmChipWrapper, - }, - system::memory::{MemoryController, OfflineMemory}, - utils::generate_long_number, -}; +use openvm_circuit::arch::testing::{TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, }; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_rv32im_transpiler::BaseAluOpcode; +use openvm_instructions::LocalOpcode; +use openvm_rv32im_transpiler::BaseAluOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, - p3_field::{Field, FieldAlgebra, PrimeField32}, + p3_field::{FieldAlgebra, PrimeField32}, p3_matrix::{ dense::{DenseMatrix, RowMajorMatrix}, Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::Rng; +use rand::{rngs::StdRng, Rng}; +use test_case::test_case; -use super::{core::run_alu, BaseAluCoreChip, Rv32BaseAluChip}; +use super::{core::run_alu, BaseAluCoreAir, Rv32BaseAluChip, Rv32BaseAluExecutor}; use crate::{ adapters::{ - Rv32BaseAluAdapterAir, Rv32BaseAluAdapterChip, Rv32BaseAluReadRecord, - Rv32BaseAluWriteRecord, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + Rv32BaseAluAdapterAir, Rv32BaseAluAdapterExecutor, Rv32BaseAluAdapterFiller, + RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, }, base_alu::BaseAluCoreCols, - test_utils::{generate_rv32_is_type_immediate, rv32_rand_write_register_or_imm}, + test_utils::{ + generate_rv32_is_type_immediate, get_verification_error, rv32_rand_write_register_or_imm, + }, + BaseAluFiller, Rv32BaseAluAir, }; +const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; +type Harness = TestChipHarness>; + +fn create_test_chip( + tester: &VmChipTestBuilder, +) -> ( + Harness, + ( + BitwiseOperationLookupAir, + SharedBitwiseOperationLookupChip, + ), +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + + let air = Rv32BaseAluAir::new( + Rv32BaseAluAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + ), + BaseAluCoreAir::new(bitwise_bus, BaseAluOpcode::CLASS_OFFSET), + ); + let executor = Rv32BaseAluExecutor::new( + Rv32BaseAluAdapterExecutor::new(), + BaseAluOpcode::CLASS_OFFSET, + ); + let chip = Rv32BaseAluChip::new( + BaseAluFiller::new( + Rv32BaseAluAdapterFiller::new(bitwise_chip.clone()), + bitwise_chip.clone(), + BaseAluOpcode::CLASS_OFFSET, + ), + tester.memory_helper(), + ); + let harness = Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + + (harness, (bitwise_chip.air, bitwise_chip)) +} + +fn set_and_execute( + tester: &mut VmChipTestBuilder, + harness: &mut Harness, + rng: &mut StdRng, + opcode: BaseAluOpcode, + b: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + is_imm: Option, + c: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, +) { + let b = b.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); + let (c_imm, c) = if is_imm.unwrap_or(rng.gen_bool(0.5)) { + let (imm, c) = if let Some(c) = c { + ((u32::from_le_bytes(c) & 0xFFFFFF) as usize, c) + } else { + generate_rv32_is_type_immediate(rng) + }; + (Some(imm), c) + } else { + ( + None, + c.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))), + ) + }; + + let (instruction, rd) = rv32_rand_write_register_or_imm( + tester, + b, + c, + c_imm, + opcode.global_opcode().as_usize(), + rng, + ); + tester.execute(harness, &instruction); + + let a = run_alu::(opcode, &b, &c) + .map(F::from_canonical_u8); + assert_eq!(a, tester.read::(1, rd)) +} ////////////////////////////////////////////////////////////////////////////////////// // POSITIVE TESTS @@ -47,227 +122,266 @@ type F = BabyBear; // passes all constraints. ////////////////////////////////////////////////////////////////////////////////////// -fn run_rv32_alu_rand_test(opcode: BaseAluOpcode, num_ops: usize) { +#[test_case(ADD, 100)] +#[test_case(SUB, 100)] +#[test_case(XOR, 100)] +#[test_case(OR, 100)] +#[test_case(AND, 100)] +fn rand_rv32_alu_test(opcode: BaseAluOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32BaseAluChip::::new( - Rv32BaseAluAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - bitwise_chip.clone(), - ), - BaseAluCoreChip::new(bitwise_chip.clone(), BaseAluOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); + let (mut harness, bitwise) = create_test_chip(&tester); - for _ in 0..num_ops { - let b = generate_long_number::(&mut rng); - let (c_imm, c) = if rng.gen_bool(0.5) { - ( - None, - generate_long_number::(&mut rng), - ) - } else { - let (imm, c) = generate_rv32_is_type_immediate(&mut rng); - (Some(imm), c) - }; + // TODO(AG): make a more meaningful test for memory accesses + tester.write(2, 1024, [F::ONE; 4]); + tester.write(2, 1028, [F::ONE; 4]); + let sm = tester.read(2, 1024); + assert_eq!(sm, [F::ONE; 8]); - let (instruction, rd) = rv32_rand_write_register_or_imm( + for _ in 0..num_ops { + set_and_execute( &mut tester, - b, - c, - c_imm, - opcode.global_opcode().as_usize(), + &mut harness, &mut rng, + opcode, + None, + None, + None, ); - tester.execute(&mut chip, &instruction); - - let a = run_alu::(opcode, &b, &c) - .map(F::from_canonical_u32); - assert_eq!(a, tester.read::(1, rd)) } - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + let tester = tester + .build() + .load(harness) + .load_periphery(bitwise) + .finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn rv32_alu_add_rand_test() { - run_rv32_alu_rand_test(BaseAluOpcode::ADD, 100); -} +#[test_case(ADD, 100)] +#[test_case(SUB, 100)] +#[test_case(XOR, 100)] +#[test_case(OR, 100)] +#[test_case(AND, 100)] +fn rand_rv32_alu_test_persistent(opcode: BaseAluOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); -#[test] -fn rv32_alu_sub_rand_test() { - run_rv32_alu_rand_test(BaseAluOpcode::SUB, 100); -} + let mut tester = VmChipTestBuilder::default_persistent(); + let (mut harness, bitwise) = create_test_chip(&tester); -#[test] -fn rv32_alu_xor_rand_test() { - run_rv32_alu_rand_test(BaseAluOpcode::XOR, 100); -} + // TODO(AG): make a more meaningful test for memory accesses + tester.write(2, 1024, [F::ONE; 4]); + tester.write(2, 1028, [F::ONE; 4]); + let sm = tester.read(2, 1024); + assert_eq!(sm, [F::ONE; 8]); -#[test] -fn rv32_alu_or_rand_test() { - run_rv32_alu_rand_test(BaseAluOpcode::OR, 100); -} + for _ in 0..num_ops { + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + opcode, + None, + None, + None, + ); + } -#[test] -fn rv32_alu_and_rand_test() { - run_rv32_alu_rand_test(BaseAluOpcode::AND, 100); + let tester = tester + .build() + .load(harness) + .load_periphery(bitwise) + .finalize(); + tester.simple_test().expect("Verification failed"); } ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32BaseAluTestChip = - VmChipWrapper, BaseAluCoreChip>; - #[allow(clippy::too_many_arguments)] -fn run_rv32_alu_negative_test( +fn run_negative_alu_test( opcode: BaseAluOpcode, - a: [u32; RV32_REGISTER_NUM_LIMBS], - b: [u32; RV32_REGISTER_NUM_LIMBS], - c: [u32; RV32_REGISTER_NUM_LIMBS], + prank_a: [u32; RV32_REGISTER_NUM_LIMBS], + b: [u8; RV32_REGISTER_NUM_LIMBS], + c: [u8; RV32_REGISTER_NUM_LIMBS], + prank_c: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + prank_opcode_flags: Option<[bool; 5]>, + is_imm: Option, interaction_error: bool, ) { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - + let mut rng = create_seeded_rng(); let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = Rv32BaseAluTestChip::::new( - TestAdapterChip::new( - vec![[b.map(F::from_canonical_u32), c.map(F::from_canonical_u32)].concat()], - vec![None], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - BaseAluCoreChip::new(bitwise_chip.clone(), BaseAluOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); + let (mut chip, bitwise) = create_test_chip(&tester); - tester.execute( + set_and_execute( + &mut tester, &mut chip, - &Instruction::from_usize(opcode.global_opcode(), [0, 0, 0, 1, 1]), + &mut rng, + opcode, + Some(b), + is_imm, + Some(c), ); - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); - - if (opcode == BaseAluOpcode::ADD || opcode == BaseAluOpcode::SUB) - && a.iter().all(|&a_val| a_val < (1 << RV32_CELL_BITS)) - { - bitwise_chip.clear(); - for a_val in a { - bitwise_chip.request_xor(a_val, a_val); - } - } - + let adapter_width = BaseAir::::width(&chip.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); let cols: &mut BaseAluCoreCols = values.split_at_mut(adapter_width).1.borrow_mut(); - cols.a = a.map(F::from_canonical_u32); - *trace = RowMajorMatrix::new(values, trace_width); + cols.a = prank_a.map(F::from_canonical_u32); + if let Some(prank_c) = prank_c { + cols.c = prank_c.map(F::from_canonical_u32); + } + if let Some(prank_opcode_flags) = prank_opcode_flags { + cols.opcode_add_flag = F::from_bool(prank_opcode_flags[0]); + cols.opcode_and_flag = F::from_bool(prank_opcode_flags[1]); + cols.opcode_or_flag = F::from_bool(prank_opcode_flags[2]); + cols.opcode_sub_flag = F::from_bool(prank_opcode_flags[3]); + cols.opcode_xor_flag = F::from_bool(prank_opcode_flags[4]); + } + *trace = RowMajorMatrix::new(values, trace.width()); }; disable_debug_builder(); let tester = tester .build() .load_and_prank_trace(chip, modify_trace) - .load(bitwise_chip) + .load_periphery(bitwise) .finalize(); - tester.simple_test_with_expected_error(if interaction_error { - VerificationError::ChallengePhaseError - } else { - VerificationError::OodEvaluationMismatch - }); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] fn rv32_alu_add_wrong_negative_test() { - run_rv32_alu_negative_test( - BaseAluOpcode::ADD, + run_negative_alu_test( + ADD, [246, 0, 0, 0], [250, 0, 0, 0], [250, 0, 0, 0], + None, + None, + None, false, ); } #[test] fn rv32_alu_add_out_of_range_negative_test() { - run_rv32_alu_negative_test( - BaseAluOpcode::ADD, + run_negative_alu_test( + ADD, [500, 0, 0, 0], [250, 0, 0, 0], [250, 0, 0, 0], + None, + None, + None, true, ); } #[test] fn rv32_alu_sub_wrong_negative_test() { - run_rv32_alu_negative_test( - BaseAluOpcode::SUB, + run_negative_alu_test( + SUB, [255, 0, 0, 0], [1, 0, 0, 0], [2, 0, 0, 0], + None, + None, + None, false, ); } #[test] fn rv32_alu_sub_out_of_range_negative_test() { - run_rv32_alu_negative_test( - BaseAluOpcode::SUB, + run_negative_alu_test( + SUB, [F::NEG_ONE.as_canonical_u32(), 0, 0, 0], [1, 0, 0, 0], [2, 0, 0, 0], + None, + None, + None, true, ); } #[test] fn rv32_alu_xor_wrong_negative_test() { - run_rv32_alu_negative_test( - BaseAluOpcode::XOR, + run_negative_alu_test( + XOR, [255, 255, 255, 255], [0, 0, 1, 0], [255, 255, 255, 255], + None, + None, + None, true, ); } #[test] fn rv32_alu_or_wrong_negative_test() { - run_rv32_alu_negative_test( - BaseAluOpcode::OR, + run_negative_alu_test( + OR, [255, 255, 255, 255], [255, 255, 255, 254], [0, 0, 0, 0], + None, + None, + None, true, ); } #[test] fn rv32_alu_and_wrong_negative_test() { - run_rv32_alu_negative_test( - BaseAluOpcode::AND, + run_negative_alu_test( + AND, [255, 255, 255, 255], [0, 0, 1, 0], [0, 0, 0, 0], + None, + None, + None, true, ); } +#[test] +fn rv32_alu_adapter_unconstrained_imm_limb_test() { + run_negative_alu_test( + ADD, + [255, 7, 0, 0], + [0, 0, 0, 0], + [255, 7, 0, 0], + Some([511, 6, 0, 0]), + None, + Some(true), + true, + ); +} + +#[test] +fn rv32_alu_adapter_unconstrained_rs2_read_test() { + run_negative_alu_test( + ADD, + [2, 2, 2, 2], + [1, 1, 1, 1], + [1, 1, 1, 1], + None, + Some([false, false, false, false, false]), + Some(false), + false, + ); +} + /////////////////////////////////////////////////////////////////////////////////////// /// SANITY TESTS /// @@ -276,10 +390,10 @@ fn rv32_alu_and_wrong_negative_test() { #[test] fn run_add_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [23, 205, 73, 49]; - let result = run_alu::(BaseAluOpcode::ADD, &x, &y); + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [23, 205, 73, 49]; + let result = run_alu::(ADD, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } @@ -287,10 +401,10 @@ fn run_add_sanity_test() { #[test] fn run_sub_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [179, 118, 240, 172]; - let result = run_alu::(BaseAluOpcode::SUB, &x, &y); + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [179, 118, 240, 172]; + let result = run_alu::(SUB, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } @@ -298,10 +412,10 @@ fn run_sub_sanity_test() { #[test] fn run_xor_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [215, 138, 49, 173]; - let result = run_alu::(BaseAluOpcode::XOR, &x, &y); + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [215, 138, 49, 173]; + let result = run_alu::(XOR, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } @@ -309,10 +423,10 @@ fn run_xor_sanity_test() { #[test] fn run_or_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [247, 171, 61, 239]; - let result = run_alu::(BaseAluOpcode::OR, &x, &y); + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [247, 171, 61, 239]; + let result = run_alu::(OR, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } @@ -320,195 +434,11 @@ fn run_or_sanity_test() { #[test] fn run_and_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [32, 33, 12, 66]; - let result = run_alu::(BaseAluOpcode::AND, &x, &y); + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [32, 33, 12, 66]; + let result = run_alu::(AND, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } } - -////////////////////////////////////////////////////////////////////////////////////// -// ADAPTER TESTS -// -// Ensure that the adapter is correct. -////////////////////////////////////////////////////////////////////////////////////// - -// A pranking chip where `preprocess` can have `rs2` limbs that overflow. -struct Rv32BaseAluAdapterTestChip(Rv32BaseAluAdapterChip); - -impl VmAdapterChip for Rv32BaseAluAdapterTestChip { - type ReadRecord = Rv32BaseAluReadRecord; - type WriteRecord = Rv32BaseAluWriteRecord; - type Air = Rv32BaseAluAdapterAir; - type Interface = BasicAdapterInterface< - F, - MinimalInstruction, - 2, - 1, - RV32_REGISTER_NUM_LIMBS, - RV32_REGISTER_NUM_LIMBS, - >; - - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { b, c, d, e, .. } = *instruction; - - let rs1 = memory.read::(d, b); - let (rs2, rs2_data, rs2_imm) = if e.is_zero() { - let c_u32 = c.as_canonical_u32(); - memory.increment_timestamp(); - let mask1 = (1 << 9) - 1; - let mask2 = (1 << 3) - 2; - ( - None, - [ - (c_u32 & mask1) as u16, - ((c_u32 >> 8) & mask2) as u16, - (c_u32 >> 16) as u16, - (c_u32 >> 16) as u16, - ] - .map(F::from_canonical_u16), - c, - ) - } else { - let rs2_read = memory.read::(e, c); - (Some(rs2_read.0), rs2_read.1, F::ZERO) - }; - - Ok(( - [rs1.1, rs2_data], - Self::ReadRecord { - rs1: rs1.0, - rs2, - rs2_imm, - }, - )) - } - - fn postprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - self.0 - .postprocess(memory, instruction, from_state, output, _read_record) - } - - fn generate_trace_row( - &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, - ) { - self.0 - .generate_trace_row(row_slice, read_record, write_record, memory) - } - - fn air(&self) -> &Self::Air { - self.0.air() - } -} - -#[test] -fn rv32_alu_adapter_unconstrained_imm_limb_test() { - let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let mut tester = VmChipTestBuilder::default(); - let mut chip = VmChipWrapper::new( - Rv32BaseAluAdapterTestChip(Rv32BaseAluAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - bitwise_chip.clone(), - )), - BaseAluCoreChip::new(bitwise_chip.clone(), BaseAluOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); - - let b = [0, 0, 0, 0]; - let (c_imm, c) = { - let imm = (1 << 11) - 1; - let fake_c = [(1 << 9) - 1, (1 << 3) - 2, 0, 0]; - (Some(imm), fake_c) - }; - - let (instruction, _rd) = rv32_rand_write_register_or_imm( - &mut tester, - b, - c, - c_imm, - BaseAluOpcode::ADD.global_opcode().as_usize(), - &mut rng, - ); - tester.execute(&mut chip, &instruction); - - disable_debug_builder(); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test_with_expected_error(VerificationError::ChallengePhaseError); -} - -#[test] -fn rv32_alu_adapter_unconstrained_rs2_read_test() { - let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32BaseAluChip::::new( - Rv32BaseAluAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - bitwise_chip.clone(), - ), - BaseAluCoreChip::new(bitwise_chip.clone(), BaseAluOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); - - let b = [1, 1, 1, 1]; - let c = [1, 1, 1, 1]; - let (instruction, _rd) = rv32_rand_write_register_or_imm( - &mut tester, - b, - c, - None, - BaseAluOpcode::ADD.global_opcode().as_usize(), - &mut rng, - ); - tester.execute(&mut chip, &instruction); - - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); - - let modify_trace = |trace: &mut DenseMatrix| { - let mut values = trace.row_slice(0).to_vec(); - let mut dummy_values = values.clone(); - let cols: &mut BaseAluCoreCols = - dummy_values.split_at_mut(adapter_width).1.borrow_mut(); - cols.opcode_add_flag = F::ZERO; - values.extend(dummy_values); - *trace = RowMajorMatrix::new(values, trace_width); - }; - - disable_debug_builder(); - let tester = tester - .build() - .load_and_prank_trace(chip, modify_trace) - .load(bitwise_chip) - .finalize(); - tester.simple_test_with_expected_error(VerificationError::OodEvaluationMismatch); -} diff --git a/extensions/rv32im/circuit/src/branch_eq/core.rs b/extensions/rv32im/circuit/src/branch_eq/core.rs index bb04d86ee5..6a1e680ee3 100644 --- a/extensions/rv32im/circuit/src/branch_eq/core.rs +++ b/extensions/rv32im/circuit/src/branch_eq/core.rs @@ -1,15 +1,17 @@ -use std::{ - array, - borrow::{Borrow, BorrowMut}, -}; +use std::borrow::{Borrow, BorrowMut}; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, ImmInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::*, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::utils::not; -use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_circuit_primitives_derive::{AlignedBorrow, AlignedBytesBorrow}; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, LocalOpcode, +}; use openvm_rv32im_transpiler::BranchEqualOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -17,8 +19,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; #[repr(C)] @@ -37,7 +37,7 @@ pub struct BranchEqualCoreCols { pub diff_inv_marker: [T; NUM_LIMBS], } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, derive_new::new)] pub struct BranchEqualCoreAir { offset: usize, pc_step: u32, @@ -135,117 +135,273 @@ where } #[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct BranchEqualCoreRecord { - #[serde(with = "BigArray")] - pub a: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - pub cmp_result: T, - pub imm: T, - pub diff_inv_val: T, - pub diff_idx: usize, - pub opcode: BranchEqualOpcode, +#[derive(AlignedBytesBorrow, Debug)] +pub struct BranchEqualCoreRecord { + pub a: [u8; NUM_LIMBS], + pub b: [u8; NUM_LIMBS], + pub imm: u32, + pub local_opcode: u8, +} + +#[derive(Clone, Copy, derive_new::new)] +pub struct BranchEqualExecutor { + adapter: A, + pub offset: usize, + pub pc_step: u32, } -#[derive(Debug)] -pub struct BranchEqualCoreChip { - pub air: BranchEqualCoreAir, +#[derive(Clone, Copy, derive_new::new)] +pub struct BranchEqualFiller { + adapter: A, + pub offset: usize, + pub pc_step: u32, } -impl BranchEqualCoreChip { - pub fn new(offset: usize, pc_step: u32) -> Self { - Self { - air: BranchEqualCoreAir { offset, pc_step }, +impl PreflightExecutor + for BranchEqualExecutor +where + F: PrimeField32, + A: 'static + AdapterTraceExecutor, WriteData = ()>, + for<'buf> RA: RecordArena< + 'buf, + EmptyAdapterCoreLayout, + ( + A::RecordMut<'buf>, + &'buf mut BranchEqualCoreRecord, + ), + >, +{ + fn get_opcode_name(&self, opcode: usize) -> String { + format!("{:?}", BranchEqualOpcode::from_usize(opcode - self.offset)) + } + + fn execute( + &mut self, + state: VmStateMut, + instruction: &Instruction, + ) -> Result<(), ExecutionError> { + let &Instruction { opcode, c: imm, .. } = instruction; + + let branch_eq_opcode = BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + + let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, &mut adapter_record) + .into(); + + core_record.a = rs1; + core_record.b = rs2; + core_record.imm = imm.as_canonical_u32(); + core_record.local_opcode = branch_eq_opcode as u8; + + if fast_run_eq(branch_eq_opcode, &rs1, &rs2) { + *state.pc = (F::from_canonical_u32(*state.pc) + imm).as_canonical_u32(); + } else { + *state.pc = state.pc.wrapping_add(self.pc_step); } + + Ok(()) } } -impl, const NUM_LIMBS: usize> VmCoreChip - for BranchEqualCoreChip +impl TraceFiller for BranchEqualFiller where - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: Default, + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + let record: &BranchEqualCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; + let core_row: &mut BranchEqualCoreCols = core_row.borrow_mut(); + + let (cmp_result, diff_idx, diff_inv_val) = run_eq::( + record.local_opcode == BranchEqualOpcode::BEQ as u8, + &record.a, + &record.b, + ); + core_row.diff_inv_marker = [F::ZERO; NUM_LIMBS]; + core_row.diff_inv_marker[diff_idx] = diff_inv_val; + + core_row.opcode_bne_flag = + F::from_bool(record.local_opcode == BranchEqualOpcode::BNE as u8); + core_row.opcode_beq_flag = + F::from_bool(record.local_opcode == BranchEqualOpcode::BEQ as u8); + + core_row.imm = F::from_canonical_u32(record.imm); + core_row.cmp_result = F::from_bool(cmp_result); + + core_row.b = record.b.map(F::from_canonical_u8); + core_row.a = record.a.map(F::from_canonical_u8); + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct BranchEqualPreCompute { + imm: isize, + a: u8, + b: u8, +} + +impl Executor for BranchEqualExecutor +where + F: PrimeField32, { - type Record = BranchEqualCoreRecord; - type Air = BranchEqualCoreAir; + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } - #[allow(clippy::type_complexity)] - fn execute_instruction( + #[inline(always)] + fn pre_compute( &self, - instruction: &Instruction, - from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let Instruction { opcode, c: imm, .. } = *instruction; - let branch_eq_opcode = - BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let x = data[0].map(|x| x.as_canonical_u32()); - let y = data[1].map(|y| y.as_canonical_u32()); - let (cmp_result, diff_idx, diff_inv_val) = run_eq::(branch_eq_opcode, &x, &y); - - let output = AdapterRuntimeContext { - to_pc: cmp_result.then_some((F::from_canonical_u32(from_pc) + imm).as_canonical_u32()), - writes: Default::default(), - }; - let record = BranchEqualCoreRecord { - opcode: branch_eq_opcode, - a: data[0], - b: data[1], - cmp_result: F::from_bool(cmp_result), - imm, - diff_idx, - diff_inv_val, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let data: &mut BranchEqualPreCompute = data.borrow_mut(); + let is_bne = self.pre_compute_impl(pc, inst, data)?; + let fn_ptr = if is_bne { + execute_e1_impl::<_, _, true> + } else { + execute_e1_impl::<_, _, false> }; + Ok(fn_ptr) + } +} - Ok((output, record)) +impl MeteredExecutor for BranchEqualExecutor +where + F: PrimeField32, +{ + fn metered_pre_compute_size(&self) -> usize { + size_of::>() } - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - BranchEqualOpcode::from_usize(opcode - self.air.offset) - ) + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let is_bne = self.pre_compute_impl(pc, inst, &mut data.data)?; + let fn_ptr = if is_bne { + execute_e2_impl::<_, _, true> + } else { + execute_e2_impl::<_, _, false> + }; + Ok(fn_ptr) } +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut BranchEqualCoreCols<_, NUM_LIMBS> = row_slice.borrow_mut(); - row_slice.a = record.a; - row_slice.b = record.b; - row_slice.cmp_result = record.cmp_result; - row_slice.imm = record.imm; - row_slice.opcode_beq_flag = F::from_bool(record.opcode == BranchEqualOpcode::BEQ); - row_slice.opcode_bne_flag = F::from_bool(record.opcode == BranchEqualOpcode::BNE); - row_slice.diff_inv_marker = array::from_fn(|i| { - if i == record.diff_idx { - record.diff_inv_val - } else { - F::ZERO - } - }); +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &BranchEqualPreCompute, + vm_state: &mut VmExecState, +) { + let rs1 = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.a as u32); + let rs2 = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + if (rs1 == rs2) ^ IS_NE { + vm_state.pc = (vm_state.pc as isize + pre_compute.imm) as u32; + } else { + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); } + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &BranchEqualPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} - fn air(&self) -> &Self::Air { - &self.air +impl BranchEqualExecutor { + /// Return `is_bne`, true if the local opcode is BNE. + #[inline(always)] + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut BranchEqualPreCompute, + ) -> Result { + let data: &mut BranchEqualPreCompute = data.borrow_mut(); + let &Instruction { + opcode, a, b, c, d, .. + } = inst; + let local_opcode = BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + let c = c.as_canonical_u32(); + let imm = if F::ORDER_U32 - c < c { + -((F::ORDER_U32 - c) as isize) + } else { + c as isize + }; + if d.as_canonical_u32() != RV32_REGISTER_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + *data = BranchEqualPreCompute { + imm, + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + }; + Ok(local_opcode == BranchEqualOpcode::BNE) } } // Returns (cmp_result, diff_idx, x[diff_idx] - y[diff_idx]) -pub(super) fn run_eq( +#[inline(always)] +pub(super) fn fast_run_eq( local_opcode: BranchEqualOpcode, - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> (bool, usize, F) { + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> bool { + match local_opcode { + BranchEqualOpcode::BEQ => x == y, + BranchEqualOpcode::BNE => x != y, + } +} + +// Returns (cmp_result, diff_idx, x[diff_idx] - y[diff_idx]) +#[inline(always)] +pub(super) fn run_eq( + is_beq: bool, + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> (bool, usize, F) +where + F: PrimeField32, +{ for i in 0..NUM_LIMBS { if x[i] != y[i] { return ( - local_opcode == BranchEqualOpcode::BNE, + !is_beq, i, - (F::from_canonical_u32(x[i]) - F::from_canonical_u32(y[i])).inverse(), + (F::from_canonical_u8(x[i]) - F::from_canonical_u8(y[i])).inverse(), ); } } - (local_opcode == BranchEqualOpcode::BEQ, 0, F::ZERO) + (is_beq, 0, F::ZERO) } diff --git a/extensions/rv32im/circuit/src/branch_eq/mod.rs b/extensions/rv32im/circuit/src/branch_eq/mod.rs index 7d53946a73..8e71c48c94 100644 --- a/extensions/rv32im/circuit/src/branch_eq/mod.rs +++ b/extensions/rv32im/circuit/src/branch_eq/mod.rs @@ -1,7 +1,7 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; use super::adapters::RV32_REGISTER_NUM_LIMBS; -use crate::adapters::Rv32BranchAdapterChip; +use crate::adapters::{Rv32BranchAdapterAir, Rv32BranchAdapterExecutor, Rv32BranchAdapterFiller}; mod core; pub use core::*; @@ -9,5 +9,9 @@ pub use core::*; #[cfg(test)] mod tests; +pub type Rv32BranchEqualAir = + VmAirWrapper>; +pub type Rv32BranchEqualExecutor = + BranchEqualExecutor; pub type Rv32BranchEqualChip = - VmChipWrapper, BranchEqualCoreChip>; + VmChipWrapper>; diff --git a/extensions/rv32im/circuit/src/branch_eq/tests.rs b/extensions/rv32im/circuit/src/branch_eq/tests.rs index c16858b071..094a13f259 100644 --- a/extensions/rv32im/circuit/src/branch_eq/tests.rs +++ b/extensions/rv32im/circuit/src/branch_eq/tests.rs @@ -1,11 +1,11 @@ use std::{array, borrow::BorrowMut}; -use openvm_circuit::arch::{ - testing::{memory::gen_pointer, TestAdapterChip, VmChipTestBuilder}, - BasicAdapterInterface, ExecutionBridge, ImmInstruction, InstructionExecutor, VmAdapterChip, - VmChipWrapper, VmCoreChip, +use openvm_circuit::arch::testing::{memory::gen_pointer, TestChipHarness, VmChipTestBuilder}; +use openvm_instructions::{ + instruction::Instruction, + program::{DEFAULT_PC_STEP, PC_BITS}, + LocalOpcode, }; -use openvm_instructions::{instruction::Instruction, program::PC_BITS, LocalOpcode}; use openvm_rv32im_transpiler::BranchEqualOpcode; use openvm_stark_backend::{ p3_air::BaseAir, @@ -15,44 +15,76 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; +use test_case::test_case; -use super::{ - core::{run_eq, BranchEqualCoreChip}, - BranchEqualCoreCols, Rv32BranchEqualChip, +use super::{core::run_eq, BranchEqualCoreCols, Rv32BranchEqualChip}; +use crate::{ + adapters::{ + Rv32BranchAdapterAir, Rv32BranchAdapterExecutor, Rv32BranchAdapterFiller, + RV32_REGISTER_NUM_LIMBS, RV_B_TYPE_IMM_BITS, + }, + branch_eq::fast_run_eq, + test_utils::get_verification_error, + BranchEqualCoreAir, BranchEqualFiller, Rv32BranchEqualAir, Rv32BranchEqualExecutor, }; -use crate::adapters::{Rv32BranchAdapterChip, RV32_REGISTER_NUM_LIMBS, RV_B_TYPE_IMM_BITS}; type F = BabyBear; +const MAX_INS_CAPACITY: usize = 128; +const ABS_MAX_IMM: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); +type Harness = + TestChipHarness>; + +fn create_test_chip(tester: &mut VmChipTestBuilder) -> Harness { + let air = Rv32BranchEqualAir::new( + Rv32BranchAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + BranchEqualCoreAir::new(BranchEqualOpcode::CLASS_OFFSET, DEFAULT_PC_STEP), + ); + let executor = Rv32BranchEqualExecutor::new( + Rv32BranchAdapterExecutor, + BranchEqualOpcode::CLASS_OFFSET, + DEFAULT_PC_STEP, + ); + let chip = Rv32BranchEqualChip::new( + BranchEqualFiller::new( + Rv32BranchAdapterFiller, + BranchEqualOpcode::CLASS_OFFSET, + DEFAULT_PC_STEP, + ), + tester.memory_helper(), + ); -////////////////////////////////////////////////////////////////////////////////////// -// POSITIVE TESTS -// -// Randomly generate computations and execute, ensuring that the generated trace -// passes all constraints. -////////////////////////////////////////////////////////////////////////////////////// + Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY) +} #[allow(clippy::too_many_arguments)] -fn run_rv32_branch_eq_rand_execute>( +fn set_and_execute( tester: &mut VmChipTestBuilder, - chip: &mut E, - opcode: BranchEqualOpcode, - a: [u32; RV32_REGISTER_NUM_LIMBS], - b: [u32; RV32_REGISTER_NUM_LIMBS], - imm: i32, + harness: &mut Harness, rng: &mut StdRng, + opcode: BranchEqualOpcode, + a: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + b: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + imm: Option, ) { + let a = a.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); + let b = b.unwrap_or(if rng.gen_bool(0.5) { + a + } else { + array::from_fn(|_| rng.gen_range(0..=u8::MAX)) + }); + + let imm = imm.unwrap_or(rng.gen_range((-ABS_MAX_IMM)..ABS_MAX_IMM)); let rs1 = gen_pointer(rng, 4); let rs2 = gen_pointer(rng, 4); - tester.write::(1, rs1, a.map(F::from_canonical_u32)); - tester.write::(1, rs2, b.map(F::from_canonical_u32)); + tester.write::(1, rs1, a.map(F::from_canonical_u8)); + tester.write::(1, rs2, b.map(F::from_canonical_u8)); + let initial_pc = rng.gen_range(imm.unsigned_abs()..(1 << (PC_BITS - 1))); tester.execute_with_pc( - chip, + harness, &Instruction::from_isize( opcode.global_opcode(), rs1 as isize, @@ -61,10 +93,10 @@ fn run_rv32_branch_eq_rand_execute>( 1, 1, ), - rng.gen_range(imm.unsigned_abs()..(1 << (PC_BITS - 1))), + initial_pc, ); - let (cmp_result, _, _) = run_eq::(opcode, &a, &b); + let cmp_result = fast_run_eq(opcode, &a, &b); let from_pc = tester.execution.last_from_pc().as_canonical_u32() as i32; let to_pc = tester.execution.last_to_pc().as_canonical_u32() as i32; let pc_inc = if cmp_result { imm } else { 4 }; @@ -72,183 +104,176 @@ fn run_rv32_branch_eq_rand_execute>( assert_eq!(to_pc, from_pc + pc_inc); } -fn run_rv32_branch_eq_rand_test(opcode: BranchEqualOpcode, num_ops: usize) { - let mut rng = create_seeded_rng(); - const ABS_MAX_BRANCH: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); +////////////////////////////////////////////////////////////////////////////////////// +// POSITIVE TESTS +// +// Randomly generate computations and execute, ensuring that the generated trace +// passes all constraints. +////////////////////////////////////////////////////////////////////////////////////// +#[test_case(BranchEqualOpcode::BEQ, 100)] +#[test_case(BranchEqualOpcode::BNE, 100)] +fn rand_rv32_branch_eq_test(opcode: BranchEqualOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32BranchEqualChip::::new( - Rv32BranchAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - BranchEqualCoreChip::new(BranchEqualOpcode::CLASS_OFFSET, 4), - tester.offline_memory_mutex_arc(), - ); + let mut harness = create_test_chip(&mut tester); for _ in 0..num_ops { - let a = array::from_fn(|_| rng.gen_range(0..F::ORDER_U32)); - let b = if rng.gen_bool(0.5) { - a - } else { - array::from_fn(|_| rng.gen_range(0..F::ORDER_U32)) - }; - let imm = rng.gen_range((-ABS_MAX_BRANCH)..ABS_MAX_BRANCH); - run_rv32_branch_eq_rand_execute(&mut tester, &mut chip, opcode, a, b, imm, &mut rng); + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + opcode, + None, + None, + None, + ); } - let tester = tester.build().load(chip).finalize(); + let tester = tester.build().load(harness).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn rv32_beq_rand_test() { - run_rv32_branch_eq_rand_test(BranchEqualOpcode::BEQ, 100); -} - -#[test] -fn rv32_bne_rand_test() { - run_rv32_branch_eq_rand_test(BranchEqualOpcode::BNE, 100); -} - ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32BranchEqualTestChip = - VmChipWrapper, BranchEqualCoreChip>; - #[allow(clippy::too_many_arguments)] -fn run_rv32_beq_negative_test( +fn run_negative_branch_eq_test( opcode: BranchEqualOpcode, - a: [u32; RV32_REGISTER_NUM_LIMBS], - b: [u32; RV32_REGISTER_NUM_LIMBS], - cmp_result: bool, - diff_inv_marker: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + a: [u8; RV32_REGISTER_NUM_LIMBS], + b: [u8; RV32_REGISTER_NUM_LIMBS], + prank_cmp_result: Option, + prank_diff_inv_marker: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + interaction_error: bool, ) { - let imm = 16u32; + let imm = 16i32; + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32BranchEqualTestChip::::new( - TestAdapterChip::new( - vec![[a.map(F::from_canonical_u32), b.map(F::from_canonical_u32)].concat()], - vec![if cmp_result { Some(imm) } else { None }], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - BranchEqualCoreChip::new(BranchEqualOpcode::CLASS_OFFSET, 4), - tester.offline_memory_mutex_arc(), + let mut harness = create_test_chip(&mut tester); + + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + opcode, + Some(a), + Some(b), + Some(imm), ); - tester.execute( - &mut chip, - &Instruction::from_usize(opcode.global_opcode(), [0, 0, imm as usize, 1, 1]), - ); - - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); - + let adapter_width = BaseAir::::width(&harness.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); let cols: &mut BranchEqualCoreCols = values.split_at_mut(adapter_width).1.borrow_mut(); - cols.cmp_result = F::from_bool(cmp_result); - if let Some(diff_inv_marker) = diff_inv_marker { + if let Some(cmp_result) = prank_cmp_result { + cols.cmp_result = F::from_bool(cmp_result); + } + if let Some(diff_inv_marker) = prank_diff_inv_marker { cols.diff_inv_marker = diff_inv_marker.map(F::from_canonical_u32); } - *trace = RowMajorMatrix::new(values, trace_width); + *trace = RowMajorMatrix::new(values, trace.width()); }; disable_debug_builder(); let tester = tester .build() - .load_and_prank_trace(chip, modify_trace) + .load_and_prank_trace(harness, modify_trace) .finalize(); - tester.simple_test_with_expected_error(VerificationError::OodEvaluationMismatch); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] fn rv32_beq_wrong_cmp_negative_test() { - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BEQ, [0, 0, 7, 0], [0, 0, 0, 7], - true, + Some(true), None, + false, ); - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BEQ, [0, 0, 7, 0], [0, 0, 7, 0], - false, + Some(false), None, + false, ); } #[test] fn rv32_beq_zero_inv_marker_negative_test() { - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BEQ, [0, 0, 7, 0], [0, 0, 0, 7], - true, + Some(true), Some([0, 0, 0, 0]), + false, ); } #[test] fn rv32_beq_invalid_inv_marker_negative_test() { - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BEQ, [0, 0, 7, 0], [0, 0, 7, 0], - false, + Some(false), Some([0, 0, 1, 0]), + false, ); } #[test] fn rv32_bne_wrong_cmp_negative_test() { - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BNE, [0, 0, 7, 0], [0, 0, 0, 7], - false, + Some(false), None, + false, ); - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BNE, [0, 0, 7, 0], [0, 0, 7, 0], - true, + Some(true), None, + false, ); } #[test] fn rv32_bne_zero_inv_marker_negative_test() { - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BNE, [0, 0, 7, 0], [0, 0, 0, 7], - false, + Some(false), Some([0, 0, 0, 0]), + false, ); } #[test] fn rv32_bne_invalid_inv_marker_negative_test() { - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BNE, [0, 0, 7, 0], [0, 0, 7, 0], - true, + Some(true), Some([0, 0, 1, 0]), + false, ); } @@ -259,66 +284,61 @@ fn rv32_bne_invalid_inv_marker_negative_test() { /////////////////////////////////////////////////////////////////////////////////////// #[test] -fn execute_pc_increment_sanity_test() { - let core = - BranchEqualCoreChip::::new(BranchEqualOpcode::CLASS_OFFSET, 4); - - let mut instruction = Instruction:: { - opcode: BranchEqualOpcode::BEQ.global_opcode(), - c: F::from_canonical_u8(8), - ..Default::default() - }; - let x: [F; RV32_REGISTER_NUM_LIMBS] = [19, 4, 1790, 60].map(F::from_canonical_u32); - let y: [F; RV32_REGISTER_NUM_LIMBS] = [19, 32, 1804, 60].map(F::from_canonical_u32); - - let result = as VmCoreChip< - F, - BasicAdapterInterface, 2, 0, RV32_REGISTER_NUM_LIMBS, 0>, - >>::execute_instruction(&core, &instruction, 0, [x, y]); - let (output, _) = result.expect("execute_instruction failed"); - assert!(output.to_pc.is_none()); - - instruction.opcode = BranchEqualOpcode::BNE.global_opcode(); - let result = as VmCoreChip< - F, - BasicAdapterInterface, 2, 0, RV32_REGISTER_NUM_LIMBS, 0>, - >>::execute_instruction(&core, &instruction, 0, [x, y]); - let (output, _) = result.expect("execute_instruction failed"); - assert!(output.to_pc.is_some()); - assert_eq!(output.to_pc.unwrap(), 8); +fn execute_roundtrip_sanity_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let mut harness = create_test_chip(&mut tester); + + let x = [19, 4, 179, 60]; + let y = [19, 32, 180, 60]; + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + BranchEqualOpcode::BEQ, + Some(x), + Some(y), + Some(8), + ); + + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + BranchEqualOpcode::BNE, + Some(x), + Some(y), + Some(8), + ); } #[test] fn run_eq_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [19, 4, 1790, 60]; - let (cmp_result, _, diff_val) = - run_eq::(BranchEqualOpcode::BEQ, &x, &x); + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [19, 4, 17, 60]; + let (cmp_result, _, diff_val) = run_eq::(true, &x, &x); assert!(cmp_result); assert_eq!(diff_val, F::ZERO); - let (cmp_result, _, diff_val) = - run_eq::(BranchEqualOpcode::BNE, &x, &x); + let (cmp_result, _, diff_val) = run_eq::(false, &x, &x); assert!(!cmp_result); assert_eq!(diff_val, F::ZERO); } #[test] fn run_ne_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [19, 4, 1790, 60]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [19, 32, 1804, 60]; - let (cmp_result, diff_idx, diff_val) = - run_eq::(BranchEqualOpcode::BEQ, &x, &y); + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [19, 4, 17, 60]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [19, 32, 18, 60]; + let (cmp_result, diff_idx, diff_val) = run_eq::(true, &x, &y); assert!(!cmp_result); assert_eq!( - diff_val * (F::from_canonical_u32(x[diff_idx]) - F::from_canonical_u32(y[diff_idx])), + diff_val * (F::from_canonical_u8(x[diff_idx]) - F::from_canonical_u8(y[diff_idx])), F::ONE ); - let (cmp_result, diff_idx, diff_val) = - run_eq::(BranchEqualOpcode::BNE, &x, &y); + let (cmp_result, diff_idx, diff_val) = run_eq::(false, &x, &y); assert!(cmp_result); assert_eq!( - diff_val * (F::from_canonical_u32(x[diff_idx]) - F::from_canonical_u32(y[diff_idx])), + diff_val * (F::from_canonical_u8(x[diff_idx]) - F::from_canonical_u8(y[diff_idx])), F::ONE ); } diff --git a/extensions/rv32im/circuit/src/branch_lt/core.rs b/extensions/rv32im/circuit/src/branch_lt/core.rs index 3eebb02146..95b7a6b04c 100644 --- a/extensions/rv32im/circuit/src/branch_lt/core.rs +++ b/extensions/rv32im/circuit/src/branch_lt/core.rs @@ -1,18 +1,21 @@ -use std::{ - array, - borrow::{Borrow, BorrowMut}, -}; - -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, ImmInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use std::borrow::{Borrow, BorrowMut}; + +use openvm_circuit::{ + arch::*, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, LocalOpcode, +}; use openvm_rv32im_transpiler::BranchLessThanOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -20,8 +23,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; #[repr(C)] @@ -53,7 +54,7 @@ pub struct BranchLessThanCoreCols { pub bus: BitwiseOperationLookupBus, offset: usize, @@ -188,183 +189,360 @@ where } #[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct BranchLessThanCoreRecord { - #[serde(with = "BigArray")] - pub a: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - pub cmp_result: T, - pub cmp_lt: T, - pub imm: T, - pub a_msb_f: T, - pub b_msb_f: T, - pub diff_val: T, - pub diff_idx: usize, - pub opcode: BranchLessThanOpcode, +#[derive(AlignedBytesBorrow, Debug)] +pub struct BranchLessThanCoreRecord { + pub a: [u8; NUM_LIMBS], + pub b: [u8; NUM_LIMBS], + pub imm: u32, + pub local_opcode: u8, } -pub struct BranchLessThanCoreChip { - pub air: BranchLessThanCoreAir, +#[derive(Clone, Copy, derive_new::new)] +pub struct BranchLessThanExecutor { + adapter: A, + pub offset: usize, +} + +#[derive(Clone, derive_new::new)] +pub struct BranchLessThanFiller { + adapter: A, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pub offset: usize, } -impl BranchLessThanCoreChip { - pub fn new( - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - offset: usize, - ) -> Self { - Self { - air: BranchLessThanCoreAir { - bus: bitwise_lookup_chip.bus(), - offset, - }, - bitwise_lookup_chip, +impl PreflightExecutor + for BranchLessThanExecutor +where + F: PrimeField32, + A: 'static + AdapterTraceExecutor, WriteData = ()>, + for<'buf> RA: RecordArena< + 'buf, + EmptyAdapterCoreLayout, + ( + A::RecordMut<'buf>, + &'buf mut BranchLessThanCoreRecord, + ), + >, +{ + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + BranchLessThanOpcode::from_usize(opcode - self.offset) + ) + } + + fn execute( + &mut self, + state: VmStateMut, + instruction: &Instruction, + ) -> Result<(), ExecutionError> { + let &Instruction { opcode, c: imm, .. } = instruction; + + let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, &mut adapter_record) + .into(); + + core_record.a = rs1; + core_record.b = rs2; + core_record.imm = imm.as_canonical_u32(); + core_record.local_opcode = opcode.local_opcode_idx(self.offset) as u8; + + if run_cmp::(core_record.local_opcode, &rs1, &rs2).0 { + *state.pc = (F::from_canonical_u32(*state.pc) + imm).as_canonical_u32(); + } else { + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); } + + Ok(()) } } -impl, const NUM_LIMBS: usize, const LIMB_BITS: usize> - VmCoreChip for BranchLessThanCoreChip +impl TraceFiller + for BranchLessThanFiller where - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: Default, + F: PrimeField32, + A: 'static + AdapterTraceFiller, { - type Record = BranchLessThanCoreRecord; - type Air = BranchLessThanCoreAir; + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + let record: &BranchLessThanCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; + + self.adapter.fill_trace_row(mem_helper, adapter_row); + let core_row: &mut BranchLessThanCoreCols = core_row.borrow_mut(); + + let signed = record.local_opcode == BranchLessThanOpcode::BLT as u8 + || record.local_opcode == BranchLessThanOpcode::BGE as u8; + let ge_op = record.local_opcode == BranchLessThanOpcode::BGE as u8 + || record.local_opcode == BranchLessThanOpcode::BGEU as u8; - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, - instruction: &Instruction, - from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let Instruction { opcode, c: imm, .. } = *instruction; - let blt_opcode = BranchLessThanOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let a = data[0].map(|x| x.as_canonical_u32()); - let b = data[1].map(|y| y.as_canonical_u32()); let (cmp_result, diff_idx, a_sign, b_sign) = - run_cmp::(blt_opcode, &a, &b); + run_cmp::(record.local_opcode, &record.a, &record.b); - let signed = matches!( - blt_opcode, - BranchLessThanOpcode::BLT | BranchLessThanOpcode::BGE - ); - let ge_opcode = matches!( - blt_opcode, - BranchLessThanOpcode::BGE | BranchLessThanOpcode::BGEU - ); - let cmp_lt = cmp_result ^ ge_opcode; + let cmp_lt = cmp_result ^ ge_op; // We range check (a_msb_f + 128) and (b_msb_f + 128) if signed, // a_msb_f and b_msb_f if not let (a_msb_f, a_msb_range) = if a_sign { ( - -F::from_canonical_u32((1 << LIMB_BITS) - a[NUM_LIMBS - 1]), - a[NUM_LIMBS - 1] - (1 << (LIMB_BITS - 1)), + -F::from_canonical_u32((1 << LIMB_BITS) - record.a[NUM_LIMBS - 1] as u32), + record.a[NUM_LIMBS - 1] as u32 - (1 << (LIMB_BITS - 1)), ) } else { ( - F::from_canonical_u32(a[NUM_LIMBS - 1]), - a[NUM_LIMBS - 1] + ((signed as u32) << (LIMB_BITS - 1)), + F::from_canonical_u32(record.a[NUM_LIMBS - 1] as u32), + record.a[NUM_LIMBS - 1] as u32 + ((signed as u32) << (LIMB_BITS - 1)), ) }; let (b_msb_f, b_msb_range) = if b_sign { ( - -F::from_canonical_u32((1 << LIMB_BITS) - b[NUM_LIMBS - 1]), - b[NUM_LIMBS - 1] - (1 << (LIMB_BITS - 1)), + -F::from_canonical_u32((1 << LIMB_BITS) - record.b[NUM_LIMBS - 1] as u32), + record.b[NUM_LIMBS - 1] as u32 - (1 << (LIMB_BITS - 1)), ) } else { ( - F::from_canonical_u32(b[NUM_LIMBS - 1]), - b[NUM_LIMBS - 1] + ((signed as u32) << (LIMB_BITS - 1)), + F::from_canonical_u32(record.b[NUM_LIMBS - 1] as u32), + record.b[NUM_LIMBS - 1] as u32 + ((signed as u32) << (LIMB_BITS - 1)), ) }; - self.bitwise_lookup_chip - .request_range(a_msb_range, b_msb_range); - let diff_val = if diff_idx == NUM_LIMBS { - 0 + core_row.diff_val = if diff_idx == NUM_LIMBS { + F::ZERO } else if diff_idx == (NUM_LIMBS - 1) { if cmp_lt { b_msb_f - a_msb_f } else { a_msb_f - b_msb_f } - .as_canonical_u32() } else if cmp_lt { - b[diff_idx] - a[diff_idx] + F::from_canonical_u8(record.b[diff_idx] - record.a[diff_idx]) } else { - a[diff_idx] - b[diff_idx] + F::from_canonical_u8(record.a[diff_idx] - record.b[diff_idx]) }; + self.bitwise_lookup_chip + .request_range(a_msb_range, b_msb_range); + + core_row.diff_marker = [F::ZERO; NUM_LIMBS]; + if diff_idx != NUM_LIMBS { - self.bitwise_lookup_chip.request_range(diff_val - 1, 0); + self.bitwise_lookup_chip + .request_range(core_row.diff_val.as_canonical_u32() - 1, 0); + core_row.diff_marker[diff_idx] = F::ONE; } - let output = AdapterRuntimeContext { - to_pc: cmp_result.then_some((F::from_canonical_u32(from_pc) + imm).as_canonical_u32()), - writes: Default::default(), - }; - let record = BranchLessThanCoreRecord { - opcode: blt_opcode, - a: data[0], - b: data[1], - cmp_result: F::from_bool(cmp_result), - cmp_lt: F::from_bool(cmp_lt), - imm, - a_msb_f, - b_msb_f, - diff_val: F::from_canonical_u32(diff_val), - diff_idx, - }; + core_row.cmp_lt = F::from_bool(cmp_lt); + core_row.b_msb_f = b_msb_f; + core_row.a_msb_f = a_msb_f; + core_row.opcode_bgeu_flag = + F::from_bool(record.local_opcode == BranchLessThanOpcode::BGEU as u8); + core_row.opcode_bge_flag = + F::from_bool(record.local_opcode == BranchLessThanOpcode::BGE as u8); + core_row.opcode_bltu_flag = + F::from_bool(record.local_opcode == BranchLessThanOpcode::BLTU as u8); + core_row.opcode_blt_flag = + F::from_bool(record.local_opcode == BranchLessThanOpcode::BLT as u8); + + core_row.imm = F::from_canonical_u32(record.imm); + core_row.cmp_result = F::from_bool(cmp_result); + core_row.b = record.b.map(F::from_canonical_u8); + core_row.a = record.a.map(F::from_canonical_u8); + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct BranchLePreCompute { + imm: isize, + a: u8, + b: u8, +} - Ok((output, record)) +impl Executor + for BranchLessThanExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() } - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - BranchLessThanOpcode::from_usize(opcode - self.air.offset) - ) + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let data: &mut BranchLePreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, data)?; + let fn_ptr = match local_opcode { + BranchLessThanOpcode::BLT => execute_e1_impl::<_, _, BltOp>, + BranchLessThanOpcode::BLTU => execute_e1_impl::<_, _, BltuOp>, + BranchLessThanOpcode::BGE => execute_e1_impl::<_, _, BgeOp>, + BranchLessThanOpcode::BGEU => execute_e1_impl::<_, _, BgeuOp>, + }; + Ok(fn_ptr) + } +} +impl MeteredExecutor + for BranchLessThanExecutor +where + F: PrimeField32, +{ + fn metered_pre_compute_size(&self) -> usize { + size_of::>() } - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut BranchLessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> = - row_slice.borrow_mut(); - row_slice.a = record.a; - row_slice.b = record.b; - row_slice.cmp_result = record.cmp_result; - row_slice.cmp_lt = record.cmp_lt; - row_slice.imm = record.imm; - row_slice.a_msb_f = record.a_msb_f; - row_slice.b_msb_f = record.b_msb_f; - row_slice.diff_marker = array::from_fn(|i| F::from_bool(i == record.diff_idx)); - row_slice.diff_val = record.diff_val; - row_slice.opcode_blt_flag = F::from_bool(record.opcode == BranchLessThanOpcode::BLT); - row_slice.opcode_bltu_flag = F::from_bool(record.opcode == BranchLessThanOpcode::BLTU); - row_slice.opcode_bge_flag = F::from_bool(record.opcode == BranchLessThanOpcode::BGE); - row_slice.opcode_bgeu_flag = F::from_bool(record.opcode == BranchLessThanOpcode::BGEU); + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; + let fn_ptr = match local_opcode { + BranchLessThanOpcode::BLT => execute_e2_impl::<_, _, BltOp>, + BranchLessThanOpcode::BLTU => execute_e2_impl::<_, _, BltuOp>, + BranchLessThanOpcode::BGE => execute_e2_impl::<_, _, BgeOp>, + BranchLessThanOpcode::BGEU => execute_e2_impl::<_, _, BgeuOp>, + }; + Ok(fn_ptr) } +} - fn air(&self) -> &Self::Air { - &self.air +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &BranchLePreCompute, + vm_state: &mut VmExecState, +) { + let rs1 = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.a as u32); + let rs2 = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let jmp = ::compute(rs1, rs2); + if jmp { + vm_state.pc = (vm_state.pc as isize + pre_compute.imm) as u32; + } else { + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + }; + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &BranchLePreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl + BranchLessThanExecutor +{ + #[inline(always)] + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut BranchLePreCompute, + ) -> Result { + let &Instruction { + opcode, a, b, c, d, .. + } = inst; + let local_opcode = BranchLessThanOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + let c = c.as_canonical_u32(); + let imm = if F::ORDER_U32 - c < c { + -((F::ORDER_U32 - c) as isize) + } else { + c as isize + }; + if d.as_canonical_u32() != RV32_REGISTER_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + *data = BranchLePreCompute { + imm, + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + }; + Ok(local_opcode) + } +} + +trait BranchLessThanOp { + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool; +} +struct BltOp; +struct BltuOp; +struct BgeOp; +struct BgeuOp; + +impl BranchLessThanOp for BltOp { + #[inline(always)] + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool { + let rs1 = i32::from_le_bytes(rs1); + let rs2 = i32::from_le_bytes(rs2); + rs1 < rs2 + } +} +impl BranchLessThanOp for BltuOp { + #[inline(always)] + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool { + let rs1 = u32::from_le_bytes(rs1); + let rs2 = u32::from_le_bytes(rs2); + rs1 < rs2 + } +} +impl BranchLessThanOp for BgeOp { + #[inline(always)] + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool { + let rs1 = i32::from_le_bytes(rs1); + let rs2 = i32::from_le_bytes(rs2); + rs1 >= rs2 + } +} +impl BranchLessThanOp for BgeuOp { + #[inline(always)] + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool { + let rs1 = u32::from_le_bytes(rs1); + let rs2 = u32::from_le_bytes(rs2); + rs1 >= rs2 } } // Returns (cmp_result, diff_idx, x_sign, y_sign) +#[inline(always)] pub(super) fn run_cmp( - local_opcode: BranchLessThanOpcode, - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], + local_opcode: u8, + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], ) -> (bool, usize, bool, bool) { - let signed = - local_opcode == BranchLessThanOpcode::BLT || local_opcode == BranchLessThanOpcode::BGE; - let ge_op = - local_opcode == BranchLessThanOpcode::BGE || local_opcode == BranchLessThanOpcode::BGEU; + let signed = local_opcode == BranchLessThanOpcode::BLT as u8 + || local_opcode == BranchLessThanOpcode::BGE as u8; + let ge_op = local_opcode == BranchLessThanOpcode::BGE as u8 + || local_opcode == BranchLessThanOpcode::BGEU as u8; let x_sign = (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && signed; let y_sign = (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && signed; for i in (0..NUM_LIMBS).rev() { diff --git a/extensions/rv32im/circuit/src/branch_lt/mod.rs b/extensions/rv32im/circuit/src/branch_lt/mod.rs index b0bf8fc417..eea1777bd4 100644 --- a/extensions/rv32im/circuit/src/branch_lt/mod.rs +++ b/extensions/rv32im/circuit/src/branch_lt/mod.rs @@ -1,7 +1,7 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; use super::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; -use crate::adapters::Rv32BranchAdapterChip; +use crate::adapters::{Rv32BranchAdapterAir, Rv32BranchAdapterExecutor, Rv32BranchAdapterFiller}; mod core; pub use core::*; @@ -9,8 +9,13 @@ pub use core::*; #[cfg(test)] mod tests; +pub type Rv32BranchLessThanAir = VmAirWrapper< + Rv32BranchAdapterAir, + BranchLessThanCoreAir, +>; +pub type Rv32BranchLessThanExecutor = + BranchLessThanExecutor; pub type Rv32BranchLessThanChip = VmChipWrapper< F, - Rv32BranchAdapterChip, - BranchLessThanCoreChip, + BranchLessThanFiller, >; diff --git a/extensions/rv32im/circuit/src/branch_lt/tests.rs b/extensions/rv32im/circuit/src/branch_lt/tests.rs index 8c1d7f697a..da5ab3e86f 100644 --- a/extensions/rv32im/circuit/src/branch_lt/tests.rs +++ b/extensions/rv32im/circuit/src/branch_lt/tests.rs @@ -1,15 +1,14 @@ -use std::borrow::BorrowMut; +use std::{array, borrow::BorrowMut, sync::Arc}; use openvm_circuit::{ - arch::{ - testing::{memory::gen_pointer, TestAdapterChip, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - BasicAdapterInterface, ExecutionBridge, ImmInstruction, InstructionExecutor, VmAdapterChip, - VmChipWrapper, VmCoreChip, + arch::testing::{ + memory::gen_pointer, TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, }, - utils::{generate_long_number, i32_to_f}, + utils::i32_to_f, }; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, }; use openvm_instructions::{instruction::Instruction, program::PC_BITS, LocalOpcode}; use openvm_rv32im_transpiler::BranchLessThanOpcode; @@ -21,49 +20,92 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; +use test_case::test_case; -use super::{ - core::{run_cmp, BranchLessThanCoreChip}, - Rv32BranchLessThanChip, -}; +use super::{core::run_cmp, Rv32BranchLessThanChip}; use crate::{ adapters::{ - Rv32BranchAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, RV_B_TYPE_IMM_BITS, + Rv32BranchAdapterAir, Rv32BranchAdapterExecutor, Rv32BranchAdapterFiller, RV32_CELL_BITS, + RV32_REGISTER_NUM_LIMBS, RV_B_TYPE_IMM_BITS, }, branch_lt::BranchLessThanCoreCols, + test_utils::get_verification_error, + BranchLessThanCoreAir, BranchLessThanFiller, Rv32BranchLessThanAir, Rv32BranchLessThanExecutor, }; type F = BabyBear; +const MAX_INS_CAPACITY: usize = 128; +const ABS_MAX_IMM: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); +type Harness = TestChipHarness< + F, + Rv32BranchLessThanExecutor, + Rv32BranchLessThanAir, + Rv32BranchLessThanChip, +>; -////////////////////////////////////////////////////////////////////////////////////// -// POSITIVE TESTS -// -// Randomly generate computations and execute, ensuring that the generated trace -// passes all constraints. -////////////////////////////////////////////////////////////////////////////////////// +fn create_test_chip( + tester: &mut VmChipTestBuilder, +) -> ( + Harness, + ( + BitwiseOperationLookupAir, + SharedBitwiseOperationLookupChip, + ), +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + + let air = Rv32BranchLessThanAir::new( + Rv32BranchAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + BranchLessThanCoreAir::new(bitwise_bus, BranchLessThanOpcode::CLASS_OFFSET), + ); + let executor = Rv32BranchLessThanExecutor::new( + Rv32BranchAdapterExecutor::new(), + BranchLessThanOpcode::CLASS_OFFSET, + ); + let chip = Rv32BranchLessThanChip::new( + BranchLessThanFiller::new( + Rv32BranchAdapterFiller, + bitwise_chip.clone(), + BranchLessThanOpcode::CLASS_OFFSET, + ), + tester.memory_helper(), + ); + let harness = Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + + (harness, (bitwise_chip.air, bitwise_chip)) +} #[allow(clippy::too_many_arguments)] -fn run_rv32_branch_lt_rand_execute>( +fn set_and_execute( tester: &mut VmChipTestBuilder, - chip: &mut E, - opcode: BranchLessThanOpcode, - a: [u32; RV32_REGISTER_NUM_LIMBS], - b: [u32; RV32_REGISTER_NUM_LIMBS], - imm: i32, + harness: &mut Harness, rng: &mut StdRng, + opcode: BranchLessThanOpcode, + a: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + b: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + imm: Option, ) { + let a = a.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); + let b = b.unwrap_or(if rng.gen_bool(0.5) { + a + } else { + array::from_fn(|_| rng.gen_range(0..=u8::MAX)) + }); + + let imm = imm.unwrap_or(rng.gen_range((-ABS_MAX_IMM)..ABS_MAX_IMM)); let rs1 = gen_pointer(rng, 4); let rs2 = gen_pointer(rng, 4); - tester.write::(1, rs1, a.map(F::from_canonical_u32)); - tester.write::(1, rs2, b.map(F::from_canonical_u32)); + tester.write::(1, rs1, a.map(F::from_canonical_u8)); + tester.write::(1, rs2, b.map(F::from_canonical_u8)); tester.execute_with_pc( - chip, + harness, &Instruction::from_isize( opcode.global_opcode(), rs1 as isize, @@ -75,7 +117,8 @@ fn run_rv32_branch_lt_rand_execute>( rng.gen_range(imm.unsigned_abs()..(1 << (PC_BITS - 1))), ); - let (cmp_result, _, _, _) = run_cmp::(opcode, &a, &b); + let (cmp_result, _, _, _) = + run_cmp::(opcode.local_usize() as u8, &a, &b); let from_pc = tester.execution.last_from_pc().as_canonical_u32() as i32; let to_pc = tester.execution.last_to_pc().as_canonical_u32() as i32; let pc_inc = if cmp_result { imm } else { 4 }; @@ -83,93 +126,69 @@ fn run_rv32_branch_lt_rand_execute>( assert_eq!(to_pc, from_pc + pc_inc); } -fn run_rv32_branch_lt_rand_test(opcode: BranchLessThanOpcode, num_ops: usize) { - let mut rng = create_seeded_rng(); - const ABS_MAX_BRANCH: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); - - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); +////////////////////////////////////////////////////////////////////////////////////// +// POSITIVE TESTS +// +// Randomly generate computations and execute, ensuring that the generated trace +// passes all constraints. +////////////////////////////////////////////////////////////////////////////////////// +#[test_case(BranchLessThanOpcode::BLT, 100)] +#[test_case(BranchLessThanOpcode::BLTU, 100)] +#[test_case(BranchLessThanOpcode::BGE, 100)] +#[test_case(BranchLessThanOpcode::BGEU, 100)] +fn rand_branch_lt_test(opcode: BranchLessThanOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32BranchLessThanChip::::new( - Rv32BranchAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - BranchLessThanCoreChip::new(bitwise_chip.clone(), BranchLessThanOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); + let (mut harness, bitwise_chip) = create_test_chip(&mut tester); for _ in 0..num_ops { - let a = generate_long_number::(&mut rng); - let b = if rng.gen_bool(0.5) { - a - } else { - generate_long_number::(&mut rng) - }; - let imm = rng.gen_range((-ABS_MAX_BRANCH)..ABS_MAX_BRANCH); - run_rv32_branch_lt_rand_execute(&mut tester, &mut chip, opcode, a, b, imm, &mut rng); + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + opcode, + None, + None, + None, + ); } // Test special case where b = c - run_rv32_branch_lt_rand_execute( + set_and_execute( &mut tester, - &mut chip, - opcode, - [101, 128, 202, 255], - [101, 128, 202, 255], - 24, + &mut harness, &mut rng, + opcode, + Some([101, 128, 202, 255]), + Some([101, 128, 202, 255]), + Some(24), ); - run_rv32_branch_lt_rand_execute( + set_and_execute( &mut tester, - &mut chip, - opcode, - [36, 0, 0, 0], - [36, 0, 0, 0], - 24, + &mut harness, &mut rng, + opcode, + Some([36, 0, 0, 0]), + Some([36, 0, 0, 0]), + Some(24), ); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + let tester = tester + .build() + .load(harness) + .load_periphery(bitwise_chip) + .finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn rv32_blt_rand_test() { - run_rv32_branch_lt_rand_test(BranchLessThanOpcode::BLT, 10); -} - -#[test] -fn rv32_bltu_rand_test() { - run_rv32_branch_lt_rand_test(BranchLessThanOpcode::BLTU, 12); -} - -#[test] -fn rv32_bge_rand_test() { - run_rv32_branch_lt_rand_test(BranchLessThanOpcode::BGE, 12); -} - -#[test] -fn rv32_bgeu_rand_test() { - run_rv32_branch_lt_rand_test(BranchLessThanOpcode::BGEU, 12); -} - ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32BranchLessThanTestChip = VmChipWrapper< - F, - TestAdapterChip, - BranchLessThanCoreChip, ->; - #[derive(Clone, Copy, Default, PartialEq)] struct BranchLessThanPrankValues { pub a_msb: Option, @@ -179,66 +198,31 @@ struct BranchLessThanPrankValues { } #[allow(clippy::too_many_arguments)] -fn run_rv32_blt_negative_test( +fn run_negative_branch_lt_test( opcode: BranchLessThanOpcode, - a: [u32; RV32_REGISTER_NUM_LIMBS], - b: [u32; RV32_REGISTER_NUM_LIMBS], - cmp_result: bool, + a: [u8; RV32_REGISTER_NUM_LIMBS], + b: [u8; RV32_REGISTER_NUM_LIMBS], + prank_cmp_result: bool, prank_vals: BranchLessThanPrankValues, interaction_error: bool, ) { - let imm = 16u32; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = Rv32BranchLessThanTestChip::::new( - TestAdapterChip::new( - vec![[a.map(F::from_canonical_u32), b.map(F::from_canonical_u32)].concat()], - vec![if cmp_result { Some(imm) } else { None }], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - BranchLessThanCoreChip::new(bitwise_chip.clone(), BranchLessThanOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); + let imm = 16i32; + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut harness, bitwise) = create_test_chip(&mut tester); - tester.execute( - &mut chip, - &Instruction::from_usize(opcode.global_opcode(), [0, 0, imm as usize, 1, 1]), + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + opcode, + Some(a), + Some(b), + Some(imm), ); - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); + let adapter_width = BaseAir::::width(&harness.air.adapter); let ge_opcode = opcode == BranchLessThanOpcode::BGE || opcode == BranchLessThanOpcode::BGEU; - let (_, _, a_sign, b_sign) = run_cmp::(opcode, &a, &b); - - if prank_vals != BranchLessThanPrankValues::default() { - debug_assert!(prank_vals.diff_val.is_some()); - let a_msb = prank_vals.a_msb.unwrap_or( - a[RV32_REGISTER_NUM_LIMBS - 1] as i32 - if a_sign { 1 << RV32_CELL_BITS } else { 0 }, - ); - let b_msb = prank_vals.b_msb.unwrap_or( - b[RV32_REGISTER_NUM_LIMBS - 1] as i32 - if b_sign { 1 << RV32_CELL_BITS } else { 0 }, - ); - let signed_offset = match opcode { - BranchLessThanOpcode::BLT | BranchLessThanOpcode::BGE => 1 << (RV32_CELL_BITS - 1), - _ => 0, - }; - - bitwise_chip.clear(); - bitwise_chip.request_range( - (a_msb + signed_offset) as u8 as u32, - (b_msb + signed_offset) as u8 as u32, - ); - - let diff_val = prank_vals - .diff_val - .unwrap() - .clamp(0, (1 << RV32_CELL_BITS) - 1); - if diff_val > 0 { - bitwise_chip.request_range(diff_val - 1, 0); - } - } let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); @@ -257,23 +241,19 @@ fn run_rv32_blt_negative_test( if let Some(diff_val) = prank_vals.diff_val { cols.diff_val = F::from_canonical_u32(diff_val); } - cols.cmp_result = F::from_bool(cmp_result); - cols.cmp_lt = F::from_bool(ge_opcode ^ cmp_result); + cols.cmp_result = F::from_bool(prank_cmp_result); + cols.cmp_lt = F::from_bool(ge_opcode ^ prank_cmp_result); - *trace = RowMajorMatrix::new(values, trace_width); + *trace = RowMajorMatrix::new(values, trace.width()); }; disable_debug_builder(); let tester = tester .build() - .load_and_prank_trace(chip, modify_trace) - .load(bitwise_chip) + .load_and_prank_trace(harness, modify_trace) + .load_periphery(bitwise) .finalize(); - tester.simple_test_with_expected_error(if interaction_error { - VerificationError::ChallengePhaseError - } else { - VerificationError::OodEvaluationMismatch - }); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -281,10 +261,10 @@ fn rv32_blt_wrong_lt_cmp_negative_test() { let a = [145, 34, 25, 205]; let b = [73, 35, 25, 205]; let prank_vals = Default::default(); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); } #[test] @@ -292,10 +272,10 @@ fn rv32_blt_wrong_ge_cmp_negative_test() { let a = [73, 35, 25, 205]; let b = [145, 34, 25, 205]; let prank_vals = Default::default(); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, false); } #[test] @@ -303,10 +283,10 @@ fn rv32_blt_wrong_eq_cmp_negative_test() { let a = [73, 35, 25, 205]; let b = [73, 35, 25, 205]; let prank_vals = Default::default(); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, false); } #[test] @@ -317,10 +297,10 @@ fn rv32_blt_fake_diff_val_negative_test() { diff_val: Some(F::NEG_ONE.as_canonical_u32()), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, true); } #[test] @@ -332,10 +312,10 @@ fn rv32_blt_zero_diff_val_negative_test() { diff_val: Some(0), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, true); } #[test] @@ -347,10 +327,10 @@ fn rv32_blt_fake_diff_marker_negative_test() { diff_val: Some(72), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); } #[test] @@ -362,10 +342,10 @@ fn rv32_blt_zero_diff_marker_negative_test() { diff_val: Some(0), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); } #[test] @@ -378,8 +358,8 @@ fn rv32_blt_signed_wrong_a_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); } #[test] @@ -392,8 +372,8 @@ fn rv32_blt_signed_wrong_a_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, true); } #[test] @@ -406,8 +386,8 @@ fn rv32_blt_signed_wrong_b_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, false); } #[test] @@ -420,8 +400,8 @@ fn rv32_blt_signed_wrong_b_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, true); } #[test] @@ -434,8 +414,8 @@ fn rv32_blt_unsigned_wrong_a_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, false); } #[test] @@ -448,8 +428,8 @@ fn rv32_blt_unsigned_wrong_a_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, true); } #[test] @@ -462,8 +442,8 @@ fn rv32_blt_unsigned_wrong_b_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); } #[test] @@ -476,8 +456,8 @@ fn rv32_blt_unsigned_wrong_b_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, true); } /////////////////////////////////////////////////////////////////////////////////////// @@ -487,51 +467,52 @@ fn rv32_blt_unsigned_wrong_b_msb_sign_negative_test() { /////////////////////////////////////////////////////////////////////////////////////// #[test] -fn execute_pc_increment_sanity_test() { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let core = BranchLessThanCoreChip::::new( - bitwise_chip, - BranchLessThanOpcode::CLASS_OFFSET, +fn execute_roundtrip_sanity_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut chip, _) = create_test_chip(&mut tester); + + let x = [145, 34, 25, 205]; + set_and_execute( + &mut tester, + &mut chip, + &mut rng, + BranchLessThanOpcode::BLT, + Some(x), + Some(x), + Some(8), ); - let mut instruction = Instruction:: { - opcode: BranchLessThanOpcode::BLT.global_opcode(), - c: F::from_canonical_u8(8), - ..Default::default() - }; - let x: [F; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205].map(F::from_canonical_u32); - - let result = as VmCoreChip< - F, - BasicAdapterInterface, 2, 0, RV32_REGISTER_NUM_LIMBS, 0>, - >>::execute_instruction(&core, &instruction, 0, [x, x]); - let (output, _) = result.expect("execute_instruction failed"); - assert!(output.to_pc.is_none()); - - instruction.opcode = BranchLessThanOpcode::BGE.global_opcode(); - let result = as VmCoreChip< - F, - BasicAdapterInterface, 2, 0, RV32_REGISTER_NUM_LIMBS, 0>, - >>::execute_instruction(&core, &instruction, 0, [x, x]); - let (output, _) = result.expect("execute_instruction failed"); - assert!(output.to_pc.is_some()); - assert_eq!(output.to_pc.unwrap(), 8); + set_and_execute( + &mut tester, + &mut chip, + &mut rng, + BranchLessThanOpcode::BGE, + Some(x), + Some(x), + Some(8), + ); } #[test] fn run_cmp_unsigned_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; - let (cmp_result, diff_idx, x_sign, y_sign) = - run_cmp::(BranchLessThanOpcode::BLTU, &x, &y); + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; + let (cmp_result, diff_idx, x_sign, y_sign) = run_cmp::( + BranchLessThanOpcode::BLTU as u8, + &x, + &y, + ); assert!(cmp_result); assert_eq!(diff_idx, 1); assert!(!x_sign); // unsigned assert!(!y_sign); // unsigned - let (cmp_result, diff_idx, x_sign, y_sign) = - run_cmp::(BranchLessThanOpcode::BGEU, &x, &y); + let (cmp_result, diff_idx, x_sign, y_sign) = run_cmp::( + BranchLessThanOpcode::BGEU as u8, + &x, + &y, + ); assert!(!cmp_result); assert_eq!(diff_idx, 1); assert!(!x_sign); // unsigned @@ -540,17 +521,17 @@ fn run_cmp_unsigned_sanity_test() { #[test] fn run_cmp_same_sign_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; let (cmp_result, diff_idx, x_sign, y_sign) = - run_cmp::(BranchLessThanOpcode::BLT, &x, &y); + run_cmp::(BranchLessThanOpcode::BLT as u8, &x, &y); assert!(cmp_result); assert_eq!(diff_idx, 1); assert!(x_sign); // negative assert!(y_sign); // negative let (cmp_result, diff_idx, x_sign, y_sign) = - run_cmp::(BranchLessThanOpcode::BGE, &x, &y); + run_cmp::(BranchLessThanOpcode::BGE as u8, &x, &y); assert!(!cmp_result); assert_eq!(diff_idx, 1); assert!(x_sign); // negative @@ -559,17 +540,17 @@ fn run_cmp_same_sign_sanity_test() { #[test] fn run_cmp_diff_sign_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [173, 34, 25, 205]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [173, 34, 25, 205]; let (cmp_result, diff_idx, x_sign, y_sign) = - run_cmp::(BranchLessThanOpcode::BLT, &x, &y); + run_cmp::(BranchLessThanOpcode::BLT as u8, &x, &y); assert!(!cmp_result); assert_eq!(diff_idx, 3); assert!(!x_sign); // positive assert!(y_sign); // negative let (cmp_result, diff_idx, x_sign, y_sign) = - run_cmp::(BranchLessThanOpcode::BGE, &x, &y); + run_cmp::(BranchLessThanOpcode::BGE as u8, &x, &y); assert!(cmp_result); assert_eq!(diff_idx, 3); assert!(!x_sign); // positive @@ -578,27 +559,33 @@ fn run_cmp_diff_sign_sanity_test() { #[test] fn run_cmp_eq_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; let (cmp_result, diff_idx, x_sign, y_sign) = - run_cmp::(BranchLessThanOpcode::BLT, &x, &x); + run_cmp::(BranchLessThanOpcode::BLT as u8, &x, &x); assert!(!cmp_result); assert_eq!(diff_idx, RV32_REGISTER_NUM_LIMBS); assert_eq!(x_sign, y_sign); - let (cmp_result, diff_idx, x_sign, y_sign) = - run_cmp::(BranchLessThanOpcode::BLTU, &x, &x); + let (cmp_result, diff_idx, x_sign, y_sign) = run_cmp::( + BranchLessThanOpcode::BLTU as u8, + &x, + &x, + ); assert!(!cmp_result); assert_eq!(diff_idx, RV32_REGISTER_NUM_LIMBS); assert_eq!(x_sign, y_sign); let (cmp_result, diff_idx, x_sign, y_sign) = - run_cmp::(BranchLessThanOpcode::BGE, &x, &x); + run_cmp::(BranchLessThanOpcode::BGE as u8, &x, &x); assert!(cmp_result); assert_eq!(diff_idx, RV32_REGISTER_NUM_LIMBS); assert_eq!(x_sign, y_sign); - let (cmp_result, diff_idx, x_sign, y_sign) = - run_cmp::(BranchLessThanOpcode::BGEU, &x, &x); + let (cmp_result, diff_idx, x_sign, y_sign) = run_cmp::( + BranchLessThanOpcode::BGEU as u8, + &x, + &x, + ); assert!(cmp_result); assert_eq!(diff_idx, RV32_REGISTER_NUM_LIMBS); assert_eq!(x_sign, y_sign); diff --git a/extensions/rv32im/circuit/src/divrem/core.rs b/extensions/rv32im/circuit/src/divrem/core.rs index b21c32345e..70bb96a937 100644 --- a/extensions/rv32im/circuit/src/divrem/core.rs +++ b/extensions/rv32im/circuit/src/divrem/core.rs @@ -5,17 +5,26 @@ use std::{ use num_bigint::BigUint; use num_integer::Integer; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::*, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, utils::{not, select}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; use openvm_rv32im_transpiler::DivRemOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -23,8 +32,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; #[repr(C)] @@ -67,7 +74,7 @@ pub struct DivRemCoreCols { pub opcode_remu_flag: T, } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, derive_new::new)] pub struct DivRemCoreAir { pub bitwise_lookup_bus: BitwiseOperationLookupBus, pub range_tuple_bus: RangeTupleCheckerBus<2>, @@ -342,14 +349,38 @@ where } } -pub struct DivRemCoreChip { - pub air: DivRemCoreAir, +#[derive(Debug, Eq, PartialEq)] +#[repr(u8)] +pub(super) enum DivRemCoreSpecialCase { + None, + ZeroDivisor, + SignedOverflow, +} + +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct DivRemCoreRecord { + pub b: [u8; NUM_LIMBS], + pub c: [u8; NUM_LIMBS], + pub local_opcode: u8, +} + +#[derive(Clone, Copy, derive_new::new)] +pub struct DivRemExecutor { + adapter: A, + pub offset: usize, +} + +pub struct DivRemFiller { + adapter: A, + pub offset: usize, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, pub range_tuple_chip: SharedRangeTupleCheckerChip<2>, } -impl DivRemCoreChip { +impl DivRemFiller { pub fn new( + adapter: A, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, range_tuple_chip: SharedRangeTupleCheckerChip<2>, offset: usize, @@ -369,83 +400,105 @@ impl DivRemCoreChip { - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub c: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub q: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub r: [T; NUM_LIMBS], - pub zero_divisor: T, - pub r_zero: T, - pub b_sign: T, - pub c_sign: T, - pub q_sign: T, - pub sign_xor: T, - pub c_sum_inv: T, - pub r_sum_inv: T, - #[serde(with = "BigArray")] - pub r_prime: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub r_inv: [T; NUM_LIMBS], - pub lt_diff_val: T, - pub lt_diff_idx: usize, - pub opcode: DivRemOpcode, -} - -#[derive(Debug, Eq, PartialEq)] -#[repr(u8)] -pub(super) enum DivRemCoreSpecialCase { - None, - ZeroDivisor, - SignedOverflow, -} - -impl, const NUM_LIMBS: usize, const LIMB_BITS: usize> - VmCoreChip for DivRemCoreChip +impl PreflightExecutor + for DivRemExecutor where - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: From<[[F; NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + AdapterTraceExecutor< + F, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + >, + for<'buf> RA: RecordArena< + 'buf, + EmptyAdapterCoreLayout, + (A::RecordMut<'buf>, &'buf mut DivRemCoreRecord), + >, { - type Record = DivRemCoreRecord; - type Air = DivRemCoreAir; + fn get_opcode_name(&self, opcode: usize) -> String { + format!("{:?}", DivRemOpcode::from_usize(opcode - self.offset)) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + ) -> Result<(), ExecutionError> { let Instruction { opcode, .. } = instruction; - let divrem_opcode = DivRemOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - let is_div = divrem_opcode == DivRemOpcode::DIV || divrem_opcode == DivRemOpcode::DIVU; - let is_signed = divrem_opcode == DivRemOpcode::DIV || divrem_opcode == DivRemOpcode::REM; + let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + core_record.local_opcode = opcode.local_opcode_idx(self.offset) as u8; + + let is_signed = core_record.local_opcode == DivRemOpcode::DIV as u8 + || core_record.local_opcode == DivRemOpcode::REM as u8; + let is_div = core_record.local_opcode == DivRemOpcode::DIV as u8 + || core_record.local_opcode == DivRemOpcode::DIVU as u8; - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let b = data[0].map(|x| x.as_canonical_u32()); - let c = data[1].map(|y| y.as_canonical_u32()); - let (q, r, b_sign, c_sign, q_sign, case) = - run_divrem::(is_signed, &b, &c); + [core_record.b, core_record.c] = self + .adapter + .read(state.memory, instruction, &mut adapter_record) + .into(); - let carries = run_mul_carries::(is_signed, &c, &q, &r, q_sign); + let b = core_record.b.map(u32::from); + let c = core_record.c.map(u32::from); + let (q, r, _, _, _, _) = run_divrem::(is_signed, &b, &c); + + let rd = if is_div { + q.map(|x| x as u8) + } else { + r.map(|x| x as u8) + }; + + self.adapter + .write(state.memory, instruction, [rd].into(), &mut adapter_record); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } +} + +impl TraceFiller + for DivRemFiller +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + let record: &DivRemCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; + let core_row: &mut DivRemCoreCols = core_row.borrow_mut(); + + let opcode = DivRemOpcode::from_usize(record.local_opcode as usize); + let is_signed = opcode == DivRemOpcode::DIV || opcode == DivRemOpcode::REM; + + let (q, r, b_sign, c_sign, q_sign, case) = run_divrem::( + is_signed, + &record.b.map(u32::from), + &record.c.map(u32::from), + ); + + let carries = run_mul_carries::( + is_signed, + &record.c.map(u32::from), + &q, + &r, + q_sign, + ); for i in 0..NUM_LIMBS { self.range_tuple_chip.add_count(&[q[i], carries[i]]); self.range_tuple_chip @@ -464,94 +517,244 @@ where let b_sign_mask = if b_sign { 1 << (LIMB_BITS - 1) } else { 0 }; let c_sign_mask = if c_sign { 1 << (LIMB_BITS - 1) } else { 0 }; self.bitwise_lookup_chip.request_range( - (b[NUM_LIMBS - 1] - b_sign_mask) << 1, - (c[NUM_LIMBS - 1] - c_sign_mask) << 1, + (record.b[NUM_LIMBS - 1] as u32 - b_sign_mask) << 1, + (record.c[NUM_LIMBS - 1] as u32 - c_sign_mask) << 1, ); } - let c_sum_f = data[1].iter().fold(F::ZERO, |acc, c| acc + *c); - let c_sum_inv_f = c_sum_f.try_inverse().unwrap_or(F::ZERO); - - let r_sum_f = r - .iter() - .fold(F::ZERO, |acc, r| acc + F::from_canonical_u32(*r)); - let r_sum_inv_f = r_sum_f.try_inverse().unwrap_or(F::ZERO); + // Write in a reverse order + core_row.opcode_remu_flag = F::from_bool(opcode == DivRemOpcode::REMU); + core_row.opcode_rem_flag = F::from_bool(opcode == DivRemOpcode::REM); + core_row.opcode_divu_flag = F::from_bool(opcode == DivRemOpcode::DIVU); + core_row.opcode_div_flag = F::from_bool(opcode == DivRemOpcode::DIV); - let (lt_diff_idx, lt_diff_val) = if case == DivRemCoreSpecialCase::None && !r_zero { - let idx = run_sltu_diff_idx(&c, &r_prime, c_sign); + core_row.lt_diff = F::ZERO; + core_row.lt_marker = [F::ZERO; NUM_LIMBS]; + if case == DivRemCoreSpecialCase::None && !r_zero { + let idx = run_sltu_diff_idx(&record.c.map(u32::from), &r_prime, c_sign); let val = if c_sign { - r_prime[idx] - c[idx] + r_prime[idx] - record.c[idx] as u32 } else { - c[idx] - r_prime[idx] + record.c[idx] as u32 - r_prime[idx] }; self.bitwise_lookup_chip.request_range(val - 1, 0); - (idx, val) - } else { - (NUM_LIMBS, 0) - }; + core_row.lt_diff = F::from_canonical_u32(val); + core_row.lt_marker[idx] = F::ONE; + } let r_prime_f = r_prime.map(F::from_canonical_u32); - let output = AdapterRuntimeContext::without_pc([ - (if is_div { &q } else { &r }).map(F::from_canonical_u32) - ]); - let record = DivRemCoreRecord { - opcode: divrem_opcode, - b: data[0], - c: data[1], - q: q.map(F::from_canonical_u32), - r: r.map(F::from_canonical_u32), - zero_divisor: F::from_bool(case == DivRemCoreSpecialCase::ZeroDivisor), - r_zero: F::from_bool(r_zero), - b_sign: F::from_bool(b_sign), - c_sign: F::from_bool(c_sign), - q_sign: F::from_bool(q_sign), - sign_xor: F::from_bool(sign_xor), - c_sum_inv: c_sum_inv_f, - r_sum_inv: r_sum_inv_f, - r_prime: r_prime_f, - r_inv: r_prime_f.map(|r| (r - F::from_canonical_u32(256)).inverse()), - lt_diff_val: F::from_canonical_u32(lt_diff_val), - lt_diff_idx, + core_row.r_inv = r_prime_f.map(|r| (r - F::from_canonical_u32(256)).inverse()); + core_row.r_prime = r_prime_f; + + let r_sum_f = r + .iter() + .fold(F::ZERO, |acc, r| acc + F::from_canonical_u32(*r)); + core_row.r_sum_inv = r_sum_f.try_inverse().unwrap_or(F::ZERO); + + let c_sum_f = F::from_canonical_u32(record.c.iter().fold(0, |acc, c| acc + *c as u32)); + core_row.c_sum_inv = c_sum_f.try_inverse().unwrap_or(F::ZERO); + + core_row.sign_xor = F::from_bool(sign_xor); + core_row.q_sign = F::from_bool(q_sign); + core_row.c_sign = F::from_bool(c_sign); + core_row.b_sign = F::from_bool(b_sign); + + core_row.r_zero = F::from_bool(r_zero); + core_row.zero_divisor = F::from_bool(case == DivRemCoreSpecialCase::ZeroDivisor); + + core_row.r = r.map(F::from_canonical_u32); + core_row.q = q.map(F::from_canonical_u32); + core_row.c = record.c.map(F::from_canonical_u8); + core_row.b = record.b.map(F::from_canonical_u8); + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct DivRemPreCompute { + a: u8, + b: u8, + c: u8, +} + +impl Executor + for DivRemExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } + + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let data: &mut DivRemPreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, data)?; + let fn_ptr = match local_opcode { + DivRemOpcode::DIV => execute_e1_impl::<_, _, DivOp>, + DivRemOpcode::DIVU => execute_e1_impl::<_, _, DivuOp>, + DivRemOpcode::REM => execute_e1_impl::<_, _, RemOp>, + DivRemOpcode::REMU => execute_e1_impl::<_, _, RemuOp>, }; + Ok(fn_ptr) + } +} - Ok((output, record)) +impl MeteredExecutor + for DivRemExecutor +where + F: PrimeField32, +{ + fn metered_pre_compute_size(&self) -> usize { + size_of::>() } - fn get_opcode_name(&self, opcode: usize) -> String { - format!("{:?}", DivRemOpcode::from_usize(opcode - self.air.offset)) + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; + let fn_ptr = match local_opcode { + DivRemOpcode::DIV => execute_e2_impl::<_, _, DivOp>, + DivRemOpcode::DIVU => execute_e2_impl::<_, _, DivuOp>, + DivRemOpcode::REM => execute_e2_impl::<_, _, RemOp>, + DivRemOpcode::REMU => execute_e2_impl::<_, _, RemuOp>, + }; + Ok(fn_ptr) + } +} + +unsafe fn execute_e12_impl( + pre_compute: &DivRemPreCompute, + vm_state: &mut VmExecState, +) { + let rs1 = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let rs2 = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.c as u32); + let result = ::compute(rs1, rs2); + vm_state.vm_write::(RV32_REGISTER_AS, pre_compute.a as u32, &result); + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &DivRemPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl DivRemExecutor { + #[inline(always)] + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut DivRemPreCompute, + ) -> Result { + let &Instruction { + opcode, a, b, c, d, .. + } = inst; + let local_opcode = DivRemOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + if d.as_canonical_u32() != RV32_REGISTER_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + let pre_compute: &mut DivRemPreCompute = data.borrow_mut(); + *pre_compute = DivRemPreCompute { + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + c: c.as_canonical_u32() as u8, + }; + Ok(local_opcode) } +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut DivRemCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut(); - row_slice.b = record.b; - row_slice.c = record.c; - row_slice.q = record.q; - row_slice.r = record.r; - row_slice.zero_divisor = record.zero_divisor; - row_slice.r_zero = record.r_zero; - row_slice.b_sign = record.b_sign; - row_slice.c_sign = record.c_sign; - row_slice.q_sign = record.q_sign; - row_slice.sign_xor = record.sign_xor; - row_slice.c_sum_inv = record.c_sum_inv; - row_slice.r_sum_inv = record.r_sum_inv; - row_slice.r_prime = record.r_prime; - row_slice.r_inv = record.r_inv; - row_slice.lt_marker = array::from_fn(|i| F::from_bool(i == record.lt_diff_idx)); - row_slice.lt_diff = record.lt_diff_val; - row_slice.opcode_div_flag = F::from_bool(record.opcode == DivRemOpcode::DIV); - row_slice.opcode_divu_flag = F::from_bool(record.opcode == DivRemOpcode::DIVU); - row_slice.opcode_rem_flag = F::from_bool(record.opcode == DivRemOpcode::REM); - row_slice.opcode_remu_flag = F::from_bool(record.opcode == DivRemOpcode::REMU); +trait DivRemOp { + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4]; +} +struct DivOp; +struct DivuOp; +struct RemOp; +struct RemuOp; +impl DivRemOp for DivOp { + #[inline(always)] + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] { + let rs1_i32 = i32::from_le_bytes(rs1); + let rs2_i32 = i32::from_le_bytes(rs2); + match (rs1_i32, rs2_i32) { + (_, 0) => [u8::MAX; 4], + (i32::MIN, -1) => rs1, + _ => (rs1_i32 / rs2_i32).to_le_bytes(), + } + } +} +impl DivRemOp for DivuOp { + #[inline(always)] + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] { + if rs2 == [0; 4] { + [u8::MAX; 4] + } else { + let rs1 = u32::from_le_bytes(rs1); + let rs2 = u32::from_le_bytes(rs2); + (rs1 / rs2).to_le_bytes() + } } +} - fn air(&self) -> &Self::Air { - &self.air +impl DivRemOp for RemOp { + #[inline(always)] + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] { + let rs1_i32 = i32::from_le_bytes(rs1); + let rs2_i32 = i32::from_le_bytes(rs2); + match (rs1_i32, rs2_i32) { + (_, 0) => rs1, + (i32::MIN, -1) => [0; 4], + _ => (rs1_i32 % rs2_i32).to_le_bytes(), + } + } +} + +impl DivRemOp for RemuOp { + #[inline(always)] + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] { + if rs2 == [0; 4] { + rs1 + } else { + let rs1 = u32::from_le_bytes(rs1); + let rs2 = u32::from_le_bytes(rs2); + (rs1 % rs2).to_le_bytes() + } } } // Returns (quotient, remainder, x_sign, y_sign, q_sign, case) where case = 0 for normal, 1 // for zero divisor, and 2 for signed overflow +#[inline(always)] pub(super) fn run_divrem( signed: bool, x: &[u32; NUM_LIMBS], @@ -628,6 +831,7 @@ pub(super) fn run_divrem( (q, r, x_sign, y_sign, q_sign, DivRemCoreSpecialCase::None) } +#[inline(always)] pub(super) fn run_sltu_diff_idx( x: &[u32; NUM_LIMBS], y: &[u32; NUM_LIMBS], @@ -644,6 +848,7 @@ pub(super) fn run_sltu_diff_idx( } // returns carries of d * q + r +#[inline(always)] pub(super) fn run_mul_carries( signed: bool, d: &[u32; NUM_LIMBS], @@ -684,6 +889,7 @@ pub(super) fn run_mul_carries( carry } +#[inline(always)] fn limbs_to_biguint( x: &[u32; NUM_LIMBS], ) -> BigUint { @@ -696,6 +902,7 @@ fn limbs_to_biguint( res } +#[inline(always)] fn biguint_to_limbs( x: &BigUint, ) -> [u32; NUM_LIMBS] { @@ -711,6 +918,7 @@ fn biguint_to_limbs( res } +#[inline(always)] fn negate( x: &[u32; NUM_LIMBS], ) -> [u32; NUM_LIMBS] { diff --git a/extensions/rv32im/circuit/src/divrem/mod.rs b/extensions/rv32im/circuit/src/divrem/mod.rs index 979ab38dc3..bba2fd9ed2 100644 --- a/extensions/rv32im/circuit/src/divrem/mod.rs +++ b/extensions/rv32im/circuit/src/divrem/mod.rs @@ -1,6 +1,7 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; -use super::adapters::{Rv32MultAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use super::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use crate::adapters::{Rv32MultAdapterAir, Rv32MultAdapterExecutor, Rv32MultAdapterFiller}; mod core; pub use core::*; @@ -8,8 +9,9 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32DivRemChip = VmChipWrapper< - F, - Rv32MultAdapterChip, - DivRemCoreChip, ->; +pub type Rv32DivRemAir = + VmAirWrapper>; +pub type Rv32DivRemExecutor = + DivRemExecutor; +pub type Rv32DivRemChip = + VmChipWrapper>; diff --git a/extensions/rv32im/circuit/src/divrem/tests.rs b/extensions/rv32im/circuit/src/divrem/tests.rs index 41d8a9cc46..806f6e9483 100644 --- a/extensions/rv32im/circuit/src/divrem/tests.rs +++ b/extensions/rv32im/circuit/src/divrem/tests.rs @@ -1,21 +1,24 @@ -use std::{array, borrow::BorrowMut}; +use std::{array, borrow::BorrowMut, sync::Arc}; use openvm_circuit::{ - arch::{ - testing::{ - memory::gen_pointer, TestAdapterChip, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, - RANGE_TUPLE_CHECKER_BUS, - }, - ExecutionBridge, InstructionExecutor, VmAdapterChip, VmChipWrapper, + arch::testing::{ + memory::gen_pointer, TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, + RANGE_TUPLE_CHECKER_BUS, }, utils::generate_long_number, }; use openvm_circuit_primitives::{ - bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, - range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, + bitwise_op_lookup::{ + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, + }, + range_tuple::{ + RangeTupleCheckerAir, RangeTupleCheckerBus, RangeTupleCheckerChip, + SharedRangeTupleCheckerChip, + }, }; use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_rv32im_transpiler::DivRemOpcode; +use openvm_rv32im_transpiler::DivRemOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{Field, FieldAlgebra}, @@ -24,29 +27,29 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; +use test_case::test_case; use super::core::run_divrem; use crate::{ - adapters::{Rv32MultAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, + adapters::{ + Rv32MultAdapterAir, Rv32MultAdapterExecutor, Rv32MultAdapterFiller, RV32_CELL_BITS, + RV32_REGISTER_NUM_LIMBS, + }, divrem::{ - run_mul_carries, run_sltu_diff_idx, DivRemCoreChip, DivRemCoreCols, DivRemCoreSpecialCase, - Rv32DivRemChip, + run_mul_carries, run_sltu_diff_idx, DivRemCoreCols, DivRemCoreSpecialCase, Rv32DivRemChip, }, + test_utils::get_verification_error, + DivRemCoreAir, DivRemFiller, Rv32DivRemAir, Rv32DivRemExecutor, }; type F = BabyBear; - -////////////////////////////////////////////////////////////////////////////////////// -// POSITIVE TESTS -// -// Randomly generate computations and execute, ensuring that the generated trace -// passes all constraints. -////////////////////////////////////////////////////////////////////////////////////// +const MAX_INS_CAPACITY: usize = 128; +// the max number of limbs we currently support MUL for is 32 (i.e. for U256s) +const MAX_NUM_LIMBS: u32 = 32; +type Harness = TestChipHarness>; fn limb_sra( x: [u32; NUM_LIMBS], @@ -57,15 +60,70 @@ fn limb_sra( array::from_fn(|i| if i + shift < NUM_LIMBS { x[i] } else { ext }) } +fn create_test_chip( + tester: &mut VmChipTestBuilder, +) -> ( + Harness, + ( + BitwiseOperationLookupAir, + SharedBitwiseOperationLookupChip, + ), + (RangeTupleCheckerAir<2>, SharedRangeTupleCheckerChip<2>), +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let range_tuple_bus = RangeTupleCheckerBus::new( + RANGE_TUPLE_CHECKER_BUS, + [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], + ); + + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + let range_tuple_chip = + SharedRangeTupleCheckerChip::new(RangeTupleCheckerChip::<2>::new(range_tuple_bus)); + + let air = Rv32DivRemAir::new( + Rv32MultAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + DivRemCoreAir::new(bitwise_bus, range_tuple_bus, DivRemOpcode::CLASS_OFFSET), + ); + let executor = Rv32DivRemExecutor::new(Rv32MultAdapterExecutor, DivRemOpcode::CLASS_OFFSET); + let chip = Rv32DivRemChip::::new( + DivRemFiller::new( + Rv32MultAdapterFiller, + bitwise_chip.clone(), + range_tuple_chip.clone(), + DivRemOpcode::CLASS_OFFSET, + ), + tester.memory_helper(), + ); + + let harness = Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + + ( + harness, + (bitwise_chip.air, bitwise_chip), + (range_tuple_chip.air, range_tuple_chip), + ) +} + #[allow(clippy::too_many_arguments)] -fn run_rv32_divrem_rand_write_execute>( - opcode: DivRemOpcode, +fn set_and_execute( tester: &mut VmChipTestBuilder, - chip: &mut E, - b: [u32; RV32_REGISTER_NUM_LIMBS], - c: [u32; RV32_REGISTER_NUM_LIMBS], + harness: &mut Harness, rng: &mut StdRng, + opcode: DivRemOpcode, + b: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + c: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, ) { + let b = b.unwrap_or(generate_long_number::< + RV32_REGISTER_NUM_LIMBS, + RV32_CELL_BITS, + >(rng)); + let c = c.unwrap_or(limb_sra::( + generate_long_number::(rng), + rng.gen_range(0..(RV32_REGISTER_NUM_LIMBS - 1)), + )); + let rs1 = gen_pointer(rng, 4); let rs2 = gen_pointer(rng, 4); let rd = gen_pointer(rng, 4); @@ -73,13 +131,13 @@ fn run_rv32_divrem_rand_write_execute>( tester.write::(1, rs1, b.map(F::from_canonical_u32)); tester.write::(1, rs2, c.map(F::from_canonical_u32)); - let is_div = opcode == DivRemOpcode::DIV || opcode == DivRemOpcode::DIVU; - let is_signed = opcode == DivRemOpcode::DIV || opcode == DivRemOpcode::REM; + let is_div = opcode == DIV || opcode == DIVU; + let is_signed = opcode == DIV || opcode == REM; let (q, r, _, _, _, _) = run_divrem::(is_signed, &b, &c); tester.execute( - chip, + harness, &Instruction::from_usize(opcode.global_opcode(), [rd, rs1, rs2, 1, 0]), ); @@ -89,136 +147,101 @@ fn run_rv32_divrem_rand_write_execute>( ); } -fn run_rv32_divrem_rand_test(opcode: DivRemOpcode, num_ops: usize) { - // the max number of limbs we currently support MUL for is 32 (i.e. for U256s) - const MAX_NUM_LIMBS: u32 = 32; - let mut rng = create_seeded_rng(); - - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let range_tuple_bus = RangeTupleCheckerBus::new( - RANGE_TUPLE_CHECKER_BUS, - [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], - ); - - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let range_tuple_checker = SharedRangeTupleCheckerChip::new(range_tuple_bus); +////////////////////////////////////////////////////////////////////////////////////// +// POSITIVE TESTS +// +// Randomly generate computations and execute, ensuring that the generated trace +// passes all constraints. +////////////////////////////////////////////////////////////////////////////////////// +#[test_case(DIV, 100)] +#[test_case(DIVU, 100)] +#[test_case(REM, 100)] +#[test_case(REMU, 100)] +fn rand_divrem_test(opcode: DivRemOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32DivRemChip::::new( - Rv32MultAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - DivRemCoreChip::new( - bitwise_chip.clone(), - range_tuple_checker.clone(), - DivRemOpcode::CLASS_OFFSET, - ), - tester.offline_memory_mutex_arc(), - ); + let (mut harness, bitwise, range_tuple) = create_test_chip(&mut tester); for _ in 0..num_ops { - let b = generate_long_number::(&mut rng); - let leading_zeros = rng.gen_range(0..(RV32_REGISTER_NUM_LIMBS - 1)); - let c = limb_sra::( - generate_long_number::(&mut rng), - leading_zeros, - ); - run_rv32_divrem_rand_write_execute(opcode, &mut tester, &mut chip, b, c, &mut rng); + set_and_execute(&mut tester, &mut harness, &mut rng, opcode, None, None); } // Test special cases in addition to random cases (i.e. zero divisor with b > 0, // zero divisor with b < 0, r = 0 (3 cases), and signed overflow). - run_rv32_divrem_rand_write_execute( - opcode, + set_and_execute( &mut tester, - &mut chip, - [98, 188, 163, 127], - [0, 0, 0, 0], + &mut harness, &mut rng, - ); - run_rv32_divrem_rand_write_execute( opcode, + Some([98, 188, 163, 127]), + Some([0, 0, 0, 0]), + ); + set_and_execute( &mut tester, - &mut chip, - [98, 188, 163, 229], - [0, 0, 0, 0], + &mut harness, &mut rng, - ); - run_rv32_divrem_rand_write_execute( opcode, + Some([98, 188, 163, 229]), + Some([0, 0, 0, 0]), + ); + set_and_execute( &mut tester, - &mut chip, - [0, 0, 0, 128], - [0, 1, 0, 0], + &mut harness, &mut rng, - ); - run_rv32_divrem_rand_write_execute( opcode, + Some([0, 0, 0, 128]), + Some([0, 1, 0, 0]), + ); + set_and_execute( &mut tester, - &mut chip, - [0, 0, 0, 127], - [0, 1, 0, 0], + &mut harness, &mut rng, - ); - run_rv32_divrem_rand_write_execute( opcode, + Some([0, 0, 0, 127]), + Some([0, 1, 0, 0]), + ); + set_and_execute( &mut tester, - &mut chip, - [0, 0, 0, 0], - [0, 0, 0, 0], + &mut harness, &mut rng, + opcode, + Some([0, 0, 0, 0]), + Some([0, 0, 0, 0]), ); - run_rv32_divrem_rand_write_execute( + set_and_execute( + &mut tester, + &mut harness, + &mut rng, opcode, + Some([0, 0, 0, 0]), + Some([0, 0, 0, 0]), + ); + set_and_execute( &mut tester, - &mut chip, - [0, 0, 0, 128], - [255, 255, 255, 255], + &mut harness, &mut rng, + opcode, + Some([0, 0, 0, 128]), + Some([255, 255, 255, 255]), ); let tester = tester .build() - .load(chip) - .load(bitwise_chip) - .load(range_tuple_checker) + .load(harness) + .load_periphery(bitwise) + .load_periphery(range_tuple) .finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn rv32_div_rand_test() { - run_rv32_divrem_rand_test(DivRemOpcode::DIV, 100); -} - -#[test] -fn rv32_divu_rand_test() { - run_rv32_divrem_rand_test(DivRemOpcode::DIVU, 100); -} - -#[test] -fn rv32_rem_rand_test() { - run_rv32_divrem_rand_test(DivRemOpcode::REM, 100); -} - -#[test] -fn rv32_remu_rand_test() { - run_rv32_divrem_rand_test(DivRemOpcode::REMU, 100); -} - ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32DivRemTestChip = - VmChipWrapper, DivRemCoreChip>; - #[derive(Default, Clone, Copy)] struct DivRemPrankValues { pub q: Option<[u32; NUM_LIMBS]>, @@ -229,84 +252,27 @@ struct DivRemPrankValues { pub r_zero: Option, } -fn run_rv32_divrem_negative_test( - signed: bool, +fn run_negative_divrem_test( + opcode: DivRemOpcode, b: [u32; RV32_REGISTER_NUM_LIMBS], c: [u32; RV32_REGISTER_NUM_LIMBS], - prank_vals: &DivRemPrankValues, + prank_vals: DivRemPrankValues, interaction_error: bool, ) { - // the max number of limbs we currently support MUL for is 32 (i.e. for U256s) - const MAX_NUM_LIMBS: u32 = 32; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let range_tuple_bus = RangeTupleCheckerBus::new( - RANGE_TUPLE_CHECKER_BUS, - [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], - ); - - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let range_tuple_chip = SharedRangeTupleCheckerChip::new(range_tuple_bus); - + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32DivRemTestChip::::new( - TestAdapterChip::new( - vec![[b.map(F::from_canonical_u32), c.map(F::from_canonical_u32)].concat(); 2], - vec![None], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - DivRemCoreChip::new( - bitwise_chip.clone(), - range_tuple_chip.clone(), - DivRemOpcode::CLASS_OFFSET, - ), - tester.offline_memory_mutex_arc(), - ); + let (mut harness, bitwise, range_tuple) = create_test_chip(&mut tester); - let (div_opcode, rem_opcode) = if signed { - (DivRemOpcode::DIV, DivRemOpcode::REM) - } else { - (DivRemOpcode::DIVU, DivRemOpcode::REMU) - }; - tester.execute( - &mut chip, - &Instruction::from_usize(div_opcode.global_opcode(), [0, 0, 0, 1, 1]), - ); - tester.execute( - &mut chip, - &Instruction::from_usize(rem_opcode.global_opcode(), [0, 0, 0, 1, 1]), + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + opcode, + Some(b), + Some(c), ); - let (q, r, b_sign, c_sign, q_sign, case) = - run_divrem::(signed, &b, &c); - let q = prank_vals.q.unwrap_or(q); - let r = prank_vals.r.unwrap_or(r); - let carries = - run_mul_carries::(signed, &c, &q, &r, q_sign); - - range_tuple_chip.clear(); - for i in 0..RV32_REGISTER_NUM_LIMBS { - range_tuple_chip.add_count(&[q[i], carries[i]]); - range_tuple_chip.add_count(&[r[i], carries[i + RV32_REGISTER_NUM_LIMBS]]); - } - - if let Some(diff_val) = prank_vals.diff_val { - bitwise_chip.clear(); - if signed { - let b_sign_mask = if b_sign { 1 << (RV32_CELL_BITS - 1) } else { 0 }; - let c_sign_mask = if c_sign { 1 << (RV32_CELL_BITS - 1) } else { 0 }; - bitwise_chip.request_range( - (b[RV32_REGISTER_NUM_LIMBS - 1] - b_sign_mask) << 1, - (c[RV32_REGISTER_NUM_LIMBS - 1] - c_sign_mask) << 1, - ); - } - if case == DivRemCoreSpecialCase::None { - bitwise_chip.request_range(diff_val - 1, 0); - } - } - - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); - + let adapter_width = BaseAir::::width(&harness.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); let cols: &mut DivRemCoreCols = @@ -338,21 +304,17 @@ fn run_rv32_divrem_negative_test( cols.r_zero = F::from_bool(r_zero); } - *trace = RowMajorMatrix::new(values, trace_width); + *trace = RowMajorMatrix::new(values, trace.width()); }; disable_debug_builder(); let tester = tester .build() - .load_and_prank_trace(chip, modify_trace) - .load(bitwise_chip) - .load(range_tuple_chip) + .load_and_prank_trace(harness, modify_trace) + .load_periphery(bitwise) + .load_periphery(range_tuple) .finalize(); - tester.simple_test_with_expected_error(if interaction_error { - VerificationError::ChallengePhaseError - } else { - VerificationError::OodEvaluationMismatch - }); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -363,7 +325,8 @@ fn rv32_divrem_unsigned_wrong_q_negative_test() { q: Some([245, 168, 7, 0]), ..Default::default() }; - run_rv32_divrem_negative_test(false, b, c, &prank_vals, true); + run_negative_divrem_test(DIVU, b, c, prank_vals, true); + run_negative_divrem_test(REMU, b, c, prank_vals, true); } #[test] @@ -376,7 +339,8 @@ fn rv32_divrem_unsigned_wrong_r_negative_test() { diff_val: Some(31), ..Default::default() }; - run_rv32_divrem_negative_test(false, b, c, &prank_vals, true); + run_negative_divrem_test(DIVU, b, c, prank_vals, true); + run_negative_divrem_test(REMU, b, c, prank_vals, true); } #[test] @@ -387,7 +351,8 @@ fn rv32_divrem_unsigned_high_mult_negative_test() { q: Some([128, 0, 0, 1]), ..Default::default() }; - run_rv32_divrem_negative_test(false, b, c, &prank_vals, true); + run_negative_divrem_test(DIVU, b, c, prank_vals, true); + run_negative_divrem_test(REMU, b, c, prank_vals, true); } #[test] @@ -400,7 +365,8 @@ fn rv32_divrem_unsigned_zero_divisor_wrong_r_negative_test() { diff_val: Some(255), ..Default::default() }; - run_rv32_divrem_negative_test(false, b, c, &prank_vals, true); + run_negative_divrem_test(DIVU, b, c, prank_vals, true); + run_negative_divrem_test(REMU, b, c, prank_vals, true); } #[test] @@ -411,7 +377,8 @@ fn rv32_divrem_signed_wrong_q_negative_test() { q: Some([74, 61, 255, 255]), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, true); + run_negative_divrem_test(DIV, b, c, prank_vals, true); + run_negative_divrem_test(REM, b, c, prank_vals, true); } #[test] @@ -424,7 +391,8 @@ fn rv32_divrem_signed_wrong_r_negative_test() { diff_val: Some(20), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, true); + run_negative_divrem_test(DIV, b, c, prank_vals, true); + run_negative_divrem_test(REM, b, c, prank_vals, true); } #[test] @@ -435,7 +403,8 @@ fn rv32_divrem_signed_high_mult_negative_test() { q: Some([1, 0, 0, 1]), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, true); + run_negative_divrem_test(DIV, b, c, prank_vals, true); + run_negative_divrem_test(REM, b, c, prank_vals, true); } #[test] @@ -449,7 +418,8 @@ fn rv32_divrem_signed_r_wrong_sign_negative_test() { diff_val: Some(192), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, false); + run_negative_divrem_test(DIV, b, c, prank_vals, false); + run_negative_divrem_test(REM, b, c, prank_vals, false); } #[test] @@ -463,7 +433,8 @@ fn rv32_divrem_signed_r_wrong_prime_negative_test() { diff_val: Some(36), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, false); + run_negative_divrem_test(DIV, b, c, prank_vals, false); + run_negative_divrem_test(REM, b, c, prank_vals, false); } #[test] @@ -476,7 +447,8 @@ fn rv32_divrem_signed_zero_divisor_wrong_r_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, true); + run_negative_divrem_test(DIV, b, c, prank_vals, true); + run_negative_divrem_test(REM, b, c, prank_vals, true); } #[test] @@ -491,8 +463,10 @@ fn rv32_divrem_false_zero_divisor_flag_negative_test() { zero_divisor: Some(true), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, false); - run_rv32_divrem_negative_test(false, b, c, &prank_vals, false); + run_negative_divrem_test(DIVU, b, c, prank_vals, false); + run_negative_divrem_test(REMU, b, c, prank_vals, false); + run_negative_divrem_test(DIV, b, c, prank_vals, false); + run_negative_divrem_test(REM, b, c, prank_vals, false); } #[test] @@ -507,8 +481,10 @@ fn rv32_divrem_false_r_zero_flag_negative_test() { r_zero: Some(true), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, false); - run_rv32_divrem_negative_test(false, b, c, &prank_vals, false); + run_negative_divrem_test(DIVU, b, c, prank_vals, false); + run_negative_divrem_test(REMU, b, c, prank_vals, false); + run_negative_divrem_test(DIV, b, c, prank_vals, false); + run_negative_divrem_test(REM, b, c, prank_vals, false); } #[test] @@ -519,8 +495,10 @@ fn rv32_divrem_unset_zero_divisor_flag_negative_test() { zero_divisor: Some(false), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, false); - run_rv32_divrem_negative_test(false, b, c, &prank_vals, false); + run_negative_divrem_test(DIVU, b, c, prank_vals, false); + run_negative_divrem_test(REMU, b, c, prank_vals, false); + run_negative_divrem_test(DIV, b, c, prank_vals, false); + run_negative_divrem_test(REM, b, c, prank_vals, false); } #[test] @@ -532,8 +510,10 @@ fn rv32_divrem_wrong_r_zero_flag_negative_test() { r_zero: Some(true), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, false); - run_rv32_divrem_negative_test(false, b, c, &prank_vals, false); + run_negative_divrem_test(DIVU, b, c, prank_vals, false); + run_negative_divrem_test(REMU, b, c, prank_vals, false); + run_negative_divrem_test(DIV, b, c, prank_vals, false); + run_negative_divrem_test(REM, b, c, prank_vals, false); } #[test] @@ -544,8 +524,10 @@ fn rv32_divrem_unset_r_zero_flag_negative_test() { r_zero: Some(false), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, false); - run_rv32_divrem_negative_test(false, b, c, &prank_vals, false); + run_negative_divrem_test(DIVU, b, c, prank_vals, false); + run_negative_divrem_test(REMU, b, c, prank_vals, false); + run_negative_divrem_test(DIV, b, c, prank_vals, false); + run_negative_divrem_test(REM, b, c, prank_vals, false); } /////////////////////////////////////////////////////////////////////////////////////// diff --git a/extensions/rv32im/circuit/src/extension.rs b/extensions/rv32im/circuit/src/extension.rs index f8dd2fbf54..af3264a120 100644 --- a/extensions/rv32im/circuit/src/extension.rs +++ b/extensions/rv32im/circuit/src/extension.rs @@ -1,108 +1,43 @@ +use std::sync::Arc; + use derive_more::derive::From; use openvm_circuit::{ arch::{ - InitFileGenerator, SystemConfig, SystemPort, VmExtension, VmInventory, VmInventoryBuilder, - VmInventoryError, + AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, ExecutionBridge, + ExecutorInventoryBuilder, ExecutorInventoryError, RowMajorMatrixArena, VmCircuitExtension, + VmExecutionExtension, VmProverExtension, }, - system::phantom::PhantomChip, + system::{memory::SharedMemoryHelper, SystemPort}, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; +use openvm_circuit_derive::{AnyEnum, MeteredExecutor}; use openvm_circuit_primitives::{ - bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, - range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, + bitwise_op_lookup::{ + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, + }, + range_tuple::{ + RangeTupleCheckerAir, RangeTupleCheckerBus, RangeTupleCheckerChip, + SharedRangeTupleCheckerChip, + }, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::{program::DEFAULT_PC_STEP, LocalOpcode, PhantomDiscriminant}; use openvm_rv32im_transpiler::{ BaseAluOpcode, BranchEqualOpcode, BranchLessThanOpcode, DivRemOpcode, LessThanOpcode, MulHOpcode, MulOpcode, Rv32AuipcOpcode, Rv32HintStoreOpcode, Rv32JalLuiOpcode, Rv32JalrOpcode, Rv32LoadStoreOpcode, Rv32Phantom, ShiftOpcode, }; -use openvm_stark_backend::p3_field::PrimeField32; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + engine::StarkEngine, + p3_field::PrimeField32, + prover::cpu::{CpuBackend, CpuDevice}, +}; use serde::{Deserialize, Serialize}; use strum::IntoEnumIterator; use crate::{adapters::*, *}; -/// Config for a VM with base extension and IO extension -#[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] -pub struct Rv32IConfig { - #[system] - pub system: SystemConfig, - #[extension] - pub base: Rv32I, - #[extension] - pub io: Rv32Io, -} - -// Default implementation uses no init file -impl InitFileGenerator for Rv32IConfig {} - -/// Config for a VM with base extension, IO extension, and multiplication extension -#[derive(Clone, Debug, Default, VmConfig, derive_new::new, Serialize, Deserialize)] -pub struct Rv32ImConfig { - #[config] - pub rv32i: Rv32IConfig, - #[extension] - pub mul: Rv32M, -} - -// Default implementation uses no init file -impl InitFileGenerator for Rv32ImConfig {} - -impl Default for Rv32IConfig { - fn default() -> Self { - let system = SystemConfig::default().with_continuations(); - Self { - system, - base: Default::default(), - io: Default::default(), - } - } -} - -impl Rv32IConfig { - pub fn with_public_values(public_values: usize) -> Self { - let system = SystemConfig::default() - .with_continuations() - .with_public_values(public_values); - Self { - system, - base: Default::default(), - io: Default::default(), - } - } - - pub fn with_public_values_and_segment_len(public_values: usize, segment_len: usize) -> Self { - let system = SystemConfig::default() - .with_continuations() - .with_public_values(public_values) - .with_max_segment_len(segment_len); - Self { - system, - base: Default::default(), - io: Default::default(), - } - } -} - -impl Rv32ImConfig { - pub fn with_public_values(public_values: usize) -> Self { - Self { - rv32i: Rv32IConfig::with_public_values(public_values), - mul: Default::default(), - } - } - - pub fn with_public_values_and_segment_len(public_values: usize, segment_len: usize) -> Self { - Self { - rv32i: Rv32IConfig::with_public_values_and_segment_len(public_values, segment_len), - mul: Default::default(), - } - } -} - -// ============ Extension Implementations ============ +// ============ Extension Struct Definitions ============ /// RISC-V 32-bit Base (RV32I) Extension #[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] @@ -134,361 +69,630 @@ fn default_range_tuple_checker_sizes() -> [u32; 2] { // ============ Executor and Periphery Enums for Extension ============ /// RISC-V 32-bit Base (RV32I) Instruction Executors -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] -pub enum Rv32IExecutor { +#[derive(Clone, From, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)] +pub enum Rv32IExecutor { // Rv32 (for standard 32-bit integers): - BaseAlu(Rv32BaseAluChip), - LessThan(Rv32LessThanChip), - Shift(Rv32ShiftChip), - LoadStore(Rv32LoadStoreChip), - LoadSignExtend(Rv32LoadSignExtendChip), - BranchEqual(Rv32BranchEqualChip), - BranchLessThan(Rv32BranchLessThanChip), - JalLui(Rv32JalLuiChip), - Jalr(Rv32JalrChip), - Auipc(Rv32AuipcChip), + BaseAlu(Rv32BaseAluExecutor), + LessThan(Rv32LessThanExecutor), + Shift(Rv32ShiftExecutor), + LoadStore(Rv32LoadStoreExecutor), + LoadSignExtend(Rv32LoadSignExtendExecutor), + BranchEqual(Rv32BranchEqualExecutor), + BranchLessThan(Rv32BranchLessThanExecutor), + JalLui(Rv32JalLuiExecutor), + Jalr(Rv32JalrExecutor), + Auipc(Rv32AuipcExecutor), } /// RISC-V 32-bit Multiplication Extension (RV32M) Instruction Executors -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] -pub enum Rv32MExecutor { - Multiplication(Rv32MultiplicationChip), - MultiplicationHigh(Rv32MulHChip), - DivRem(Rv32DivRemChip), +#[derive(Clone, From, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)] +pub enum Rv32MExecutor { + Multiplication(Rv32MultiplicationExecutor), + MultiplicationHigh(Rv32MulHExecutor), + DivRem(Rv32DivRemExecutor), } /// RISC-V 32-bit Io Instruction Executors -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] -pub enum Rv32IoExecutor { - HintStore(Rv32HintStoreChip), +#[derive(Clone, Copy, From, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)] +pub enum Rv32IoExecutor { + HintStore(Rv32HintStoreExecutor), } -#[derive(From, ChipUsageGetter, Chip, AnyEnum)] -pub enum Rv32IPeriphery { - BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), - // We put this only to get the generic to work - Phantom(PhantomChip), -} +// ============ VmExtension Implementations ============ -#[derive(From, ChipUsageGetter, Chip, AnyEnum)] -pub enum Rv32MPeriphery { - BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), - /// Only needed for multiplication extension - RangeTupleChecker(SharedRangeTupleCheckerChip<2>), - // We put this only to get the generic to work - Phantom(PhantomChip), -} +impl VmExecutionExtension for Rv32I { + type Executor = Rv32IExecutor; -#[derive(From, ChipUsageGetter, Chip, AnyEnum)] -pub enum Rv32IoPeriphery { - BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), - // We put this only to get the generic to work - Phantom(PhantomChip), -} + fn extend_execution( + &self, + inventory: &mut ExecutorInventoryBuilder, + ) -> Result<(), ExecutorInventoryError> { + let pointer_max_bits = inventory.pointer_max_bits(); -// ============ VmExtension Implementations ============ + let base_alu = + Rv32BaseAluExecutor::new(Rv32BaseAluAdapterExecutor, BaseAluOpcode::CLASS_OFFSET); + inventory.add_executor(base_alu, BaseAluOpcode::iter().map(|x| x.global_opcode()))?; -impl VmExtension for Rv32I { - type Executor = Rv32IExecutor; - type Periphery = Rv32IPeriphery; + let lt = LessThanExecutor::new(Rv32BaseAluAdapterExecutor, LessThanOpcode::CLASS_OFFSET); + inventory.add_executor(lt, LessThanOpcode::iter().map(|x| x.global_opcode()))?; - fn build( - &self, - builder: &mut VmInventoryBuilder, - ) -> Result, Rv32IPeriphery>, VmInventoryError> { - let mut inventory = VmInventory::new(); + let shift = ShiftExecutor::new(Rv32BaseAluAdapterExecutor, ShiftOpcode::CLASS_OFFSET); + inventory.add_executor(shift, ShiftOpcode::iter().map(|x| x.global_opcode()))?; + + let load_store = LoadStoreExecutor::new( + Rv32LoadStoreAdapterExecutor::new(pointer_max_bits), + Rv32LoadStoreOpcode::CLASS_OFFSET, + ); + inventory.add_executor( + load_store, + Rv32LoadStoreOpcode::iter() + .take(Rv32LoadStoreOpcode::STOREB as usize + 1) + .map(|x| x.global_opcode()), + )?; + + let load_sign_extend = + LoadSignExtendExecutor::new(Rv32LoadStoreAdapterExecutor::new(pointer_max_bits)); + inventory.add_executor( + load_sign_extend, + [Rv32LoadStoreOpcode::LOADB, Rv32LoadStoreOpcode::LOADH].map(|x| x.global_opcode()), + )?; + + let beq = BranchEqualExecutor::new( + Rv32BranchAdapterExecutor, + BranchEqualOpcode::CLASS_OFFSET, + DEFAULT_PC_STEP, + ); + inventory.add_executor(beq, BranchEqualOpcode::iter().map(|x| x.global_opcode()))?; + + let blt = BranchLessThanExecutor::new( + Rv32BranchAdapterExecutor, + BranchLessThanOpcode::CLASS_OFFSET, + ); + inventory.add_executor(blt, BranchLessThanOpcode::iter().map(|x| x.global_opcode()))?; + + let jal_lui = Rv32JalLuiExecutor::new(Rv32CondRdWriteAdapterExecutor::new( + Rv32RdWriteAdapterExecutor, + )); + inventory.add_executor(jal_lui, Rv32JalLuiOpcode::iter().map(|x| x.global_opcode()))?; + + let jalr = Rv32JalrExecutor::new(Rv32JalrAdapterExecutor); + inventory.add_executor(jalr, Rv32JalrOpcode::iter().map(|x| x.global_opcode()))?; + + let auipc = Rv32AuipcExecutor::new(Rv32RdWriteAdapterExecutor); + inventory.add_executor(auipc, Rv32AuipcOpcode::iter().map(|x| x.global_opcode()))?; + + // There is no downside to adding phantom sub-executors, so we do it in the base extension. + inventory.add_phantom_sub_executor( + phantom::Rv32HintInputSubEx, + PhantomDiscriminant(Rv32Phantom::HintInput as u16), + )?; + inventory.add_phantom_sub_executor( + phantom::Rv32HintRandomSubEx, + PhantomDiscriminant(Rv32Phantom::HintRandom as u16), + )?; + inventory.add_phantom_sub_executor( + phantom::Rv32PrintStrSubEx, + PhantomDiscriminant(Rv32Phantom::PrintStr as u16), + )?; + inventory.add_phantom_sub_executor( + phantom::Rv32HintLoadByKeySubEx, + PhantomDiscriminant(Rv32Phantom::HintLoadByKey as u16), + )?; + + Ok(()) + } +} + +impl VmCircuitExtension for Rv32I { + fn extend_circuit(&self, inventory: &mut AirInventory) -> Result<(), AirInventoryError> { let SystemPort { execution_bus, program_bus, memory_bridge, - } = builder.system_port(); - - let range_checker = builder.system_base().range_checker_chip.clone(); - let offline_memory = builder.system_base().offline_memory(); - let pointer_max_bits = builder.system_config().memory_config.pointer_max_bits; - - let bitwise_lu_chip = if let Some(&chip) = builder - .find_chip::>() - .first() - { - chip.clone() - } else { - let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); - let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); - inventory.add_periphery_chip(chip.clone()); - chip + } = inventory.system().port(); + + let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); + let range_checker = inventory.range_checker().bus; + let pointer_max_bits = inventory.pointer_max_bits(); + + let bitwise_lu = { + // A trick to get around Rust's borrow rules + let existing_air = inventory.find_air::>().next(); + if let Some(air) = existing_air { + air.bus + } else { + let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx()); + let air = BitwiseOperationLookupAir::<8>::new(bus); + inventory.add_air(air); + air.bus + } }; - let base_alu_chip = Rv32BaseAluChip::new( - Rv32BaseAluAdapterChip::new( - execution_bus, - program_bus, + let base_alu = Rv32BaseAluAir::new( + Rv32BaseAluAdapterAir::new(exec_bridge, memory_bridge, bitwise_lu), + BaseAluCoreAir::new(bitwise_lu, BaseAluOpcode::CLASS_OFFSET), + ); + inventory.add_air(base_alu); + + let lt = Rv32LessThanAir::new( + Rv32BaseAluAdapterAir::new(exec_bridge, memory_bridge, bitwise_lu), + LessThanCoreAir::new(bitwise_lu, LessThanOpcode::CLASS_OFFSET), + ); + inventory.add_air(lt); + + let shift = Rv32ShiftAir::new( + Rv32BaseAluAdapterAir::new(exec_bridge, memory_bridge, bitwise_lu), + ShiftCoreAir::new(bitwise_lu, range_checker, ShiftOpcode::CLASS_OFFSET), + ); + inventory.add_air(shift); + + let load_store = Rv32LoadStoreAir::new( + Rv32LoadStoreAdapterAir::new( memory_bridge, - bitwise_lu_chip.clone(), + exec_bridge, + range_checker, + pointer_max_bits, ), - BaseAluCoreChip::new(bitwise_lu_chip.clone(), BaseAluOpcode::CLASS_OFFSET), - offline_memory.clone(), + LoadStoreCoreAir::new(Rv32LoadStoreOpcode::CLASS_OFFSET), ); - inventory.add_executor( - base_alu_chip, - BaseAluOpcode::iter().map(|x| x.global_opcode()), - )?; + inventory.add_air(load_store); - let lt_chip = Rv32LessThanChip::new( - Rv32BaseAluAdapterChip::new( - execution_bus, - program_bus, + let load_sign_extend = Rv32LoadSignExtendAir::new( + Rv32LoadStoreAdapterAir::new( memory_bridge, - bitwise_lu_chip.clone(), + exec_bridge, + range_checker, + pointer_max_bits, ), - LessThanCoreChip::new(bitwise_lu_chip.clone(), LessThanOpcode::CLASS_OFFSET), - offline_memory.clone(), + LoadSignExtendCoreAir::new(range_checker), ); - inventory.add_executor(lt_chip, LessThanOpcode::iter().map(|x| x.global_opcode()))?; + inventory.add_air(load_sign_extend); - let shift_chip = Rv32ShiftChip::new( - Rv32BaseAluAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - bitwise_lu_chip.clone(), + let beq = Rv32BranchEqualAir::new( + Rv32BranchAdapterAir::new(exec_bridge, memory_bridge), + BranchEqualCoreAir::new(BranchEqualOpcode::CLASS_OFFSET, DEFAULT_PC_STEP), + ); + inventory.add_air(beq); + + let blt = Rv32BranchLessThanAir::new( + Rv32BranchAdapterAir::new(exec_bridge, memory_bridge), + BranchLessThanCoreAir::new(bitwise_lu, BranchLessThanOpcode::CLASS_OFFSET), + ); + inventory.add_air(blt); + + let jal_lui = Rv32JalLuiAir::new( + Rv32CondRdWriteAdapterAir::new(Rv32RdWriteAdapterAir::new(memory_bridge, exec_bridge)), + Rv32JalLuiCoreAir::new(bitwise_lu), + ); + inventory.add_air(jal_lui); + + let jalr = Rv32JalrAir::new( + Rv32JalrAdapterAir::new(memory_bridge, exec_bridge), + Rv32JalrCoreAir::new(bitwise_lu, range_checker), + ); + inventory.add_air(jalr); + + let auipc = Rv32AuipcAir::new( + Rv32RdWriteAdapterAir::new(memory_bridge, exec_bridge), + Rv32AuipcCoreAir::new(bitwise_lu), + ); + inventory.add_air(auipc); + + Ok(()) + } +} + +pub struct Rv32ImCpuProverExt; +// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker, +// BitwiseOperationLookupChip) are specific to CpuBackend. +impl VmProverExtension for Rv32ImCpuProverExt +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + RA: RowMajorMatrixArena>, + Val: PrimeField32, +{ + fn extend_prover( + &self, + _: &Rv32I, + inventory: &mut ChipInventory>, + ) -> Result<(), ChipInventoryError> { + let range_checker = inventory.range_checker()?.clone(); + let timestamp_max_bits = inventory.timestamp_max_bits(); + let pointer_max_bits = inventory.airs().pointer_max_bits(); + let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits); + + let bitwise_lu = { + let existing_chip = inventory + .find_chip::>() + .next(); + if let Some(chip) = existing_chip { + chip.clone() + } else { + let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?; + let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus)); + inventory.add_periphery_chip(chip.clone()); + chip + } + }; + + // These calls to next_air are not strictly necessary to construct the chips, but provide a + // safeguard to ensure that chip construction matches the circuit definition + inventory.next_air::()?; + let base_alu = Rv32BaseAluChip::new( + BaseAluFiller::new( + Rv32BaseAluAdapterFiller::new(bitwise_lu.clone()), + bitwise_lu.clone(), + BaseAluOpcode::CLASS_OFFSET, ), - ShiftCoreChip::new( - bitwise_lu_chip.clone(), + mem_helper.clone(), + ); + inventory.add_executor_chip(base_alu); + + inventory.next_air::()?; + let lt = Rv32LessThanChip::new( + LessThanFiller::new( + Rv32BaseAluAdapterFiller::new(bitwise_lu.clone()), + bitwise_lu.clone(), + LessThanOpcode::CLASS_OFFSET, + ), + mem_helper.clone(), + ); + inventory.add_executor_chip(lt); + + inventory.next_air::()?; + let shift = Rv32ShiftChip::new( + ShiftFiller::new( + Rv32BaseAluAdapterFiller::new(bitwise_lu.clone()), + bitwise_lu.clone(), range_checker.clone(), ShiftOpcode::CLASS_OFFSET, ), - offline_memory.clone(), + mem_helper.clone(), ); - inventory.add_executor(shift_chip, ShiftOpcode::iter().map(|x| x.global_opcode()))?; + inventory.add_executor_chip(shift); + inventory.next_air::()?; let load_store_chip = Rv32LoadStoreChip::new( - Rv32LoadStoreAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - pointer_max_bits, - range_checker.clone(), + LoadStoreFiller::new( + Rv32LoadStoreAdapterFiller::new(pointer_max_bits, range_checker.clone()), + Rv32LoadStoreOpcode::CLASS_OFFSET, ), - LoadStoreCoreChip::new(Rv32LoadStoreOpcode::CLASS_OFFSET), - offline_memory.clone(), + mem_helper.clone(), ); - inventory.add_executor( - load_store_chip, - Rv32LoadStoreOpcode::iter() - .take(Rv32LoadStoreOpcode::STOREB as usize + 1) - .map(|x| x.global_opcode()), - )?; + inventory.add_executor_chip(load_store_chip); - let load_sign_extend_chip = Rv32LoadSignExtendChip::new( - Rv32LoadStoreAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - pointer_max_bits, + inventory.next_air::()?; + let load_sign_extend = Rv32LoadSignExtendChip::new( + LoadSignExtendFiller::new( + Rv32LoadStoreAdapterFiller::new(pointer_max_bits, range_checker.clone()), range_checker.clone(), ), - LoadSignExtendCoreChip::new(range_checker.clone()), - offline_memory.clone(), + mem_helper.clone(), ); - inventory.add_executor( - load_sign_extend_chip, - [Rv32LoadStoreOpcode::LOADB, Rv32LoadStoreOpcode::LOADH].map(|x| x.global_opcode()), - )?; - - let beq_chip = Rv32BranchEqualChip::new( - Rv32BranchAdapterChip::new(execution_bus, program_bus, memory_bridge), - BranchEqualCoreChip::new(BranchEqualOpcode::CLASS_OFFSET, DEFAULT_PC_STEP), - offline_memory.clone(), + inventory.add_executor_chip(load_sign_extend); + + inventory.next_air::()?; + let beq = Rv32BranchEqualChip::new( + BranchEqualFiller::new( + Rv32BranchAdapterFiller, + BranchEqualOpcode::CLASS_OFFSET, + DEFAULT_PC_STEP, + ), + mem_helper.clone(), ); - inventory.add_executor( - beq_chip, - BranchEqualOpcode::iter().map(|x| x.global_opcode()), - )?; + inventory.add_executor_chip(beq); - let blt_chip = Rv32BranchLessThanChip::new( - Rv32BranchAdapterChip::new(execution_bus, program_bus, memory_bridge), - BranchLessThanCoreChip::new( - bitwise_lu_chip.clone(), + inventory.next_air::()?; + let blt = Rv32BranchLessThanChip::new( + BranchLessThanFiller::new( + Rv32BranchAdapterFiller, + bitwise_lu.clone(), BranchLessThanOpcode::CLASS_OFFSET, ), - offline_memory.clone(), + mem_helper.clone(), ); - inventory.add_executor( - blt_chip, - BranchLessThanOpcode::iter().map(|x| x.global_opcode()), - )?; + inventory.add_executor_chip(blt); - let jal_lui_chip = Rv32JalLuiChip::new( - Rv32CondRdWriteAdapterChip::new(execution_bus, program_bus, memory_bridge), - Rv32JalLuiCoreChip::new(bitwise_lu_chip.clone()), - offline_memory.clone(), + inventory.next_air::()?; + let jal_lui = Rv32JalLuiChip::new( + Rv32JalLuiFiller::new( + Rv32CondRdWriteAdapterFiller::new(Rv32RdWriteAdapterFiller), + bitwise_lu.clone(), + ), + mem_helper.clone(), ); - inventory.add_executor( - jal_lui_chip, - Rv32JalLuiOpcode::iter().map(|x| x.global_opcode()), - )?; + inventory.add_executor_chip(jal_lui); - let jalr_chip = Rv32JalrChip::new( - Rv32JalrAdapterChip::new(execution_bus, program_bus, memory_bridge), - Rv32JalrCoreChip::new(bitwise_lu_chip.clone(), range_checker.clone()), - offline_memory.clone(), + inventory.next_air::()?; + let jalr = Rv32JalrChip::new( + Rv32JalrFiller::new( + Rv32JalrAdapterFiller, + bitwise_lu.clone(), + range_checker.clone(), + ), + mem_helper.clone(), ); - inventory.add_executor(jalr_chip, Rv32JalrOpcode::iter().map(|x| x.global_opcode()))?; + inventory.add_executor_chip(jalr); - let auipc_chip = Rv32AuipcChip::new( - Rv32RdWriteAdapterChip::new(execution_bus, program_bus, memory_bridge), - Rv32AuipcCoreChip::new(bitwise_lu_chip.clone()), - offline_memory.clone(), + inventory.next_air::()?; + let auipc = Rv32AuipcChip::new( + Rv32AuipcFiller::new(Rv32RdWriteAdapterFiller, bitwise_lu.clone()), + mem_helper.clone(), ); - inventory.add_executor( - auipc_chip, - Rv32AuipcOpcode::iter().map(|x| x.global_opcode()), - )?; + inventory.add_executor_chip(auipc); - // There is no downside to adding phantom sub-executors, so we do it in the base extension. - builder.add_phantom_sub_executor( - phantom::Rv32HintInputSubEx, - PhantomDiscriminant(Rv32Phantom::HintInput as u16), - )?; - builder.add_phantom_sub_executor( - phantom::Rv32HintRandomSubEx::new(), - PhantomDiscriminant(Rv32Phantom::HintRandom as u16), - )?; - builder.add_phantom_sub_executor( - phantom::Rv32PrintStrSubEx, - PhantomDiscriminant(Rv32Phantom::PrintStr as u16), - )?; - builder.add_phantom_sub_executor( - phantom::Rv32HintLoadByKeySubEx, - PhantomDiscriminant(Rv32Phantom::HintLoadByKey as u16), - )?; - - Ok(inventory) + Ok(()) } } -impl VmExtension for Rv32M { - type Executor = Rv32MExecutor; - type Periphery = Rv32MPeriphery; +impl VmExecutionExtension for Rv32M { + type Executor = Rv32MExecutor; - fn build( + fn extend_execution( &self, - builder: &mut VmInventoryBuilder, - ) -> Result, Rv32MPeriphery>, VmInventoryError> { - let mut inventory = VmInventory::new(); + inventory: &mut ExecutorInventoryBuilder, + ) -> Result<(), ExecutorInventoryError> { + let mult = + Rv32MultiplicationExecutor::new(Rv32MultAdapterExecutor, MulOpcode::CLASS_OFFSET); + inventory.add_executor(mult, MulOpcode::iter().map(|x| x.global_opcode()))?; + + let mul_h = Rv32MulHExecutor::new(Rv32MultAdapterExecutor, MulHOpcode::CLASS_OFFSET); + inventory.add_executor(mul_h, MulHOpcode::iter().map(|x| x.global_opcode()))?; + + let div_rem = Rv32DivRemExecutor::new(Rv32MultAdapterExecutor, DivRemOpcode::CLASS_OFFSET); + inventory.add_executor(div_rem, DivRemOpcode::iter().map(|x| x.global_opcode()))?; + + Ok(()) + } +} + +impl VmCircuitExtension for Rv32M { + fn extend_circuit(&self, inventory: &mut AirInventory) -> Result<(), AirInventoryError> { let SystemPort { execution_bus, program_bus, memory_bridge, - } = builder.system_port(); - let offline_memory = builder.system_base().offline_memory(); - - let bitwise_lu_chip = if let Some(&chip) = builder - .find_chip::>() - .first() - { - chip.clone() - } else { - let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); - let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); - inventory.add_periphery_chip(chip.clone()); - chip + } = inventory.system().port(); + let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); + + let bitwise_lu = { + let existing_air = inventory.find_air::>().next(); + if let Some(air) = existing_air { + air.bus + } else { + let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx()); + let air = BitwiseOperationLookupAir::<8>::new(bus); + inventory.add_air(air); + air.bus + } + }; + + let range_tuple_checker = { + let existing_air = inventory.find_air::>().find(|c| { + c.bus.sizes[0] >= self.range_tuple_checker_sizes[0] + && c.bus.sizes[1] >= self.range_tuple_checker_sizes[1] + }); + if let Some(air) = existing_air { + air.bus + } else { + let bus = RangeTupleCheckerBus::new( + inventory.new_bus_idx(), + self.range_tuple_checker_sizes, + ); + let air = RangeTupleCheckerAir { bus }; + inventory.add_air(air); + air.bus + } }; - let range_tuple_checker = if let Some(chip) = builder - .find_chip::>() - .into_iter() - .find(|c| { - c.bus().sizes[0] >= self.range_tuple_checker_sizes[0] - && c.bus().sizes[1] >= self.range_tuple_checker_sizes[1] - }) { - chip.clone() - } else { - let range_tuple_bus = - RangeTupleCheckerBus::new(builder.new_bus_idx(), self.range_tuple_checker_sizes); - let chip = SharedRangeTupleCheckerChip::new(range_tuple_bus); - inventory.add_periphery_chip(chip.clone()); - chip + let mult = Rv32MultiplicationAir::new( + Rv32MultAdapterAir::new(exec_bridge, memory_bridge), + MultiplicationCoreAir::new(range_tuple_checker, MulOpcode::CLASS_OFFSET), + ); + inventory.add_air(mult); + + let mul_h = Rv32MulHAir::new( + Rv32MultAdapterAir::new(exec_bridge, memory_bridge), + MulHCoreAir::new(bitwise_lu, range_tuple_checker), + ); + inventory.add_air(mul_h); + + let div_rem = Rv32DivRemAir::new( + Rv32MultAdapterAir::new(exec_bridge, memory_bridge), + DivRemCoreAir::new(bitwise_lu, range_tuple_checker, DivRemOpcode::CLASS_OFFSET), + ); + inventory.add_air(div_rem); + + Ok(()) + } +} + +// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker, +// BitwiseOperationLookupChip) are specific to CpuBackend. +impl VmProverExtension for Rv32ImCpuProverExt +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + RA: RowMajorMatrixArena>, + Val: PrimeField32, +{ + fn extend_prover( + &self, + extension: &Rv32M, + inventory: &mut ChipInventory>, + ) -> Result<(), ChipInventoryError> { + let range_checker = inventory.range_checker()?.clone(); + let timestamp_max_bits = inventory.timestamp_max_bits(); + let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits); + + let bitwise_lu = { + let existing_chip = inventory + .find_chip::>() + .next(); + if let Some(chip) = existing_chip { + chip.clone() + } else { + let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?; + let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus)); + inventory.add_periphery_chip(chip.clone()); + chip + } }; - let mul_chip = Rv32MultiplicationChip::new( - Rv32MultAdapterChip::new(execution_bus, program_bus, memory_bridge), - MultiplicationCoreChip::new(range_tuple_checker.clone(), MulOpcode::CLASS_OFFSET), - offline_memory.clone(), + let range_tuple_checker = { + let existing_chip = inventory + .find_chip::>() + .find(|c| { + c.bus().sizes[0] >= extension.range_tuple_checker_sizes[0] + && c.bus().sizes[1] >= extension.range_tuple_checker_sizes[1] + }); + if let Some(chip) = existing_chip { + chip.clone() + } else { + let air: &RangeTupleCheckerAir<2> = inventory.next_air()?; + let chip = SharedRangeTupleCheckerChip::new(RangeTupleCheckerChip::new(air.bus)); + inventory.add_periphery_chip(chip.clone()); + chip + } + }; + + // These calls to next_air are not strictly necessary to construct the chips, but provide a + // safeguard to ensure that chip construction matches the circuit definition + inventory.next_air::()?; + let mult = Rv32MultiplicationChip::new( + MultiplicationFiller::new( + Rv32MultAdapterFiller, + range_tuple_checker.clone(), + MulOpcode::CLASS_OFFSET, + ), + mem_helper.clone(), ); - inventory.add_executor(mul_chip, MulOpcode::iter().map(|x| x.global_opcode()))?; + inventory.add_executor_chip(mult); - let mul_h_chip = Rv32MulHChip::new( - Rv32MultAdapterChip::new(execution_bus, program_bus, memory_bridge), - MulHCoreChip::new(bitwise_lu_chip.clone(), range_tuple_checker.clone()), - offline_memory.clone(), + inventory.next_air::()?; + let mul_h = Rv32MulHChip::new( + MulHFiller::new( + Rv32MultAdapterFiller, + bitwise_lu.clone(), + range_tuple_checker.clone(), + ), + mem_helper.clone(), ); - inventory.add_executor(mul_h_chip, MulHOpcode::iter().map(|x| x.global_opcode()))?; + inventory.add_executor_chip(mul_h); - let div_rem_chip = Rv32DivRemChip::new( - Rv32MultAdapterChip::new(execution_bus, program_bus, memory_bridge), - DivRemCoreChip::new( - bitwise_lu_chip.clone(), + inventory.next_air::()?; + let div_rem = Rv32DivRemChip::new( + DivRemFiller::new( + Rv32MultAdapterFiller, + bitwise_lu.clone(), range_tuple_checker.clone(), DivRemOpcode::CLASS_OFFSET, ), - offline_memory.clone(), + mem_helper.clone(), ); - inventory.add_executor( - div_rem_chip, - DivRemOpcode::iter().map(|x| x.global_opcode()), - )?; + inventory.add_executor_chip(div_rem); - Ok(inventory) + Ok(()) } } -impl VmExtension for Rv32Io { - type Executor = Rv32IoExecutor; - type Periphery = Rv32IoPeriphery; +impl VmExecutionExtension for Rv32Io { + type Executor = Rv32IoExecutor; - fn build( + fn extend_execution( &self, - builder: &mut VmInventoryBuilder, - ) -> Result, VmInventoryError> { - let mut inventory = VmInventory::new(); + inventory: &mut ExecutorInventoryBuilder, + ) -> Result<(), ExecutorInventoryError> { + let pointer_max_bits = inventory.pointer_max_bits(); + let hint_store = + Rv32HintStoreExecutor::new(pointer_max_bits, Rv32HintStoreOpcode::CLASS_OFFSET); + inventory.add_executor( + hint_store, + Rv32HintStoreOpcode::iter().map(|x| x.global_opcode()), + )?; + + Ok(()) + } +} + +impl VmCircuitExtension for Rv32Io { + fn extend_circuit(&self, inventory: &mut AirInventory) -> Result<(), AirInventoryError> { let SystemPort { execution_bus, program_bus, memory_bridge, - } = builder.system_port(); - let offline_memory = builder.system_base().offline_memory(); - - let bitwise_lu_chip = if let Some(&chip) = builder - .find_chip::>() - .first() - { - chip.clone() - } else { - let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); - let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); - inventory.add_periphery_chip(chip.clone()); - chip + } = inventory.system().port(); + + let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); + let pointer_max_bits = inventory.pointer_max_bits(); + + let bitwise_lu = { + let existing_air = inventory.find_air::>().next(); + if let Some(air) = existing_air { + air.bus + } else { + let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx()); + let air = BitwiseOperationLookupAir::<8>::new(bus); + inventory.add_air(air); + air.bus + } }; - let mut hintstore_chip = Rv32HintStoreChip::new( - execution_bus, - program_bus, - bitwise_lu_chip.clone(), + let hint_store = Rv32HintStoreAir::new( + exec_bridge, memory_bridge, - offline_memory.clone(), - builder.system_config().memory_config.pointer_max_bits, + bitwise_lu, Rv32HintStoreOpcode::CLASS_OFFSET, + pointer_max_bits, ); - hintstore_chip.set_streams(builder.streams().clone()); + inventory.add_air(hint_store); - inventory.add_executor( - hintstore_chip, - Rv32HintStoreOpcode::iter().map(|x| x.global_opcode()), - )?; + Ok(()) + } +} - Ok(inventory) +// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker, +// BitwiseOperationLookupChip) are specific to CpuBackend. +impl VmProverExtension for Rv32ImCpuProverExt +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + RA: RowMajorMatrixArena>, + Val: PrimeField32, +{ + fn extend_prover( + &self, + _: &Rv32Io, + inventory: &mut ChipInventory>, + ) -> Result<(), ChipInventoryError> { + let range_checker = inventory.range_checker()?.clone(); + let timestamp_max_bits = inventory.timestamp_max_bits(); + let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits); + let pointer_max_bits = inventory.airs().pointer_max_bits(); + + let bitwise_lu = { + let existing_chip = inventory + .find_chip::>() + .next(); + if let Some(chip) = existing_chip { + chip.clone() + } else { + let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?; + let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus)); + inventory.add_periphery_chip(chip.clone()); + chip + } + }; + + inventory.next_air::()?; + let hint_store = Rv32HintStoreChip::new( + Rv32HintStoreFiller::new(pointer_max_bits, bitwise_lu.clone()), + mem_helper.clone(), + ); + inventory.add_executor_chip(hint_store); + + Ok(()) } } @@ -497,34 +701,28 @@ mod phantom { use eyre::bail; use openvm_circuit::{ arch::{PhantomSubExecutor, Streams}, - system::memory::MemoryController, + system::memory::online::GuestMemory, }; use openvm_instructions::PhantomDiscriminant; use openvm_stark_backend::p3_field::{Field, PrimeField32}; - use rand::{rngs::OsRng, Rng}; + use rand::{rngs::StdRng, Rng}; - use crate::adapters::unsafe_read_rv32_register; + use crate::adapters::{memory_read, read_rv32_register}; pub struct Rv32HintInputSubEx; - pub struct Rv32HintRandomSubEx { - rng: OsRng, - } - impl Rv32HintRandomSubEx { - pub fn new() -> Self { - Self { rng: OsRng } - } - } + pub struct Rv32HintRandomSubEx; pub struct Rv32PrintStrSubEx; pub struct Rv32HintLoadByKeySubEx; impl PhantomSubExecutor for Rv32HintInputSubEx { fn phantom_execute( - &mut self, - _: &MemoryController, + &self, + _: &GuestMemory, streams: &mut Streams, + _: &mut StdRng, _: PhantomDiscriminant, - _: F, - _: F, + _: u32, + _: u32, _: u16, ) -> eyre::Result<()> { let mut hint = match streams.input_stream.pop_front() { @@ -550,18 +748,19 @@ mod phantom { impl PhantomSubExecutor for Rv32HintRandomSubEx { fn phantom_execute( - &mut self, - memory: &MemoryController, + &self, + memory: &GuestMemory, streams: &mut Streams, + rng: &mut StdRng, _: PhantomDiscriminant, - a: F, - _: F, + a: u32, + _: u32, _: u16, ) -> eyre::Result<()> { - let len = unsafe_read_rv32_register(memory, a) as usize; + let len = read_rv32_register(memory, a) as usize; streams.hint_stream.clear(); streams.hint_stream.extend( - std::iter::repeat_with(|| F::from_canonical_u8(self.rng.gen::())).take(len * 4), + std::iter::repeat_with(|| F::from_canonical_u8(rng.gen::())).take(len * 4), ); Ok(()) } @@ -569,23 +768,20 @@ mod phantom { impl PhantomSubExecutor for Rv32PrintStrSubEx { fn phantom_execute( - &mut self, - memory: &MemoryController, + &self, + memory: &GuestMemory, _: &mut Streams, + _: &mut StdRng, _: PhantomDiscriminant, - a: F, - b: F, + a: u32, + b: u32, _: u16, ) -> eyre::Result<()> { - let rd = unsafe_read_rv32_register(memory, a); - let rs1 = unsafe_read_rv32_register(memory, b); + let rd = read_rv32_register(memory, a); + let rs1 = read_rv32_register(memory, b); let bytes = (0..rs1) - .map(|i| -> eyre::Result { - let val = memory.unsafe_read_cell(F::TWO, F::from_canonical_u32(rd + i)); - let byte: u8 = val.as_canonical_u32().try_into()?; - Ok(byte) - }) - .collect::>>()?; + .map(|i| memory_read::<1>(memory, 2, rd + i)[0]) + .collect::>(); let peeked_str = String::from_utf8(bytes)?; print!("{peeked_str}"); Ok(()) @@ -594,22 +790,19 @@ mod phantom { impl PhantomSubExecutor for Rv32HintLoadByKeySubEx { fn phantom_execute( - &mut self, - memory: &MemoryController, + &self, + memory: &GuestMemory, streams: &mut Streams, + _: &mut StdRng, _: PhantomDiscriminant, - a: F, - b: F, + a: u32, + b: u32, _: u16, ) -> eyre::Result<()> { - let ptr = unsafe_read_rv32_register(memory, a); - let len = unsafe_read_rv32_register(memory, b); + let ptr = read_rv32_register(memory, a); + let len = read_rv32_register(memory, b); let key: Vec = (0..len) - .map(|i| { - memory - .unsafe_read_cell(F::TWO, F::from_canonical_u32(ptr + i)) - .as_canonical_u32() as u8 - }) + .map(|i| memory_read::<1>(memory, 2, ptr + i)[0]) .collect(); if let Some(val) = streams.kv_store.get(&key) { let to_push = hint_load_by_key_decode::(val); diff --git a/extensions/rv32im/circuit/src/hintstore/mod.rs b/extensions/rv32im/circuit/src/hintstore/mod.rs index d566292207..35f71bead5 100644 --- a/extensions/rv32im/circuit/src/hintstore/mod.rs +++ b/extensions/rv32im/circuit/src/hintstore/mod.rs @@ -1,25 +1,21 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - sync::{Arc, Mutex, OnceLock}, -}; +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ - arch::{ - ExecutionBridge, ExecutionBus, ExecutionError, ExecutionState, InstructionExecutor, Streams, - }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryAuxColsFactory, MemoryController, OfflineMemory, RecordId, + arch::*, + system::memory::{ + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteBytesAuxRecord, }, - program::ProgramBus, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, }, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, - utils::{next_power_of_two_or_zero, not}, + utils::not, }; -use openvm_circuit_primitives_derive::AlignedBorrow; +use openvm_circuit_primitives_derive::{AlignedBorrow, AlignedBytesBorrow}; use openvm_instructions::{ instruction::Instruction, program::DEFAULT_PC_STEP, @@ -31,18 +27,15 @@ use openvm_rv32im_transpiler::{ Rv32HintStoreOpcode::{HINT_BUFFER, HINT_STOREW}, }; use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, interaction::InteractionBuilder, p3_air::{Air, AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, p3_matrix::{dense::RowMajorMatrix, Matrix}, - prover::types::AirProofInput, - rap::{AnyRap, BaseAirWithPublicValues, PartitionedBaseAir}, - Chip, ChipUsageGetter, + p3_maybe_rayon::prelude::*, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, }; -use serde::{Deserialize, Serialize}; -use crate::adapters::{compose, decompose}; +use crate::adapters::{read_rv32_register, tracing_read, tracing_write}; #[cfg(test)] mod tests; @@ -70,7 +63,7 @@ pub struct Rv32HintStoreCols { pub num_words_aux_cols: MemoryReadAuxCols, } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, derive_new::new)] pub struct Rv32HintStoreAir { pub execution_bridge: ExecutionBridge, pub memory_bridge: MemoryBridge, @@ -182,7 +175,6 @@ impl Air for Rv32HintStoreAir { &local_cols.write_aux, ) .eval(builder, is_valid.clone()); - let expected_opcode = (local_cols.is_single * AB::F::from_canonical_usize(HINT_STOREW as usize + self.offset)) + (local_cols.is_buffer @@ -264,265 +256,485 @@ impl Air for Rv32HintStoreAir { } } -#[derive(Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct Rv32HintStoreRecord { - pub from_state: ExecutionState, - pub instruction: Instruction, - pub mem_ptr_read: RecordId, - pub mem_ptr: u32, +#[derive(Copy, Clone, Debug)] +pub struct Rv32HintStoreMetadata { + num_words: usize, +} + +impl MultiRowMetadata for Rv32HintStoreMetadata { + #[inline(always)] + fn get_num_rows(&self) -> usize { + self.num_words + } +} + +pub type Rv32HintStoreLayout = MultiRowLayout; + +// This is the part of the record that we keep only once per instruction +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32HintStoreRecordHeader { pub num_words: u32, - pub num_words_read: Option, - pub hints: Vec<([F; RV32_REGISTER_NUM_LIMBS], RecordId)>, + pub from_pc: u32, + pub timestamp: u32, + + pub mem_ptr_ptr: u32, + pub mem_ptr: u32, + pub mem_ptr_aux_record: MemoryReadAuxRecord, + + // will set `num_words_ptr` to `u32::MAX` in case of single hint + pub num_words_ptr: u32, + pub num_words_read: MemoryReadAuxRecord, } -pub struct Rv32HintStoreChip { - air: Rv32HintStoreAir, - pub records: Vec>, - pub height: usize, - offline_memory: Arc>>, - pub streams: OnceLock>>>, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, +// This is the part of the record that we keep `num_words` times per instruction +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32HintStoreVar { + pub data_write_aux: MemoryWriteBytesAuxRecord, + pub data: [u8; RV32_REGISTER_NUM_LIMBS], } -impl Rv32HintStoreChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - memory_bridge: MemoryBridge, - offline_memory: Arc>>, - pointer_max_bits: usize, - offset: usize, - ) -> Self { - let air = Rv32HintStoreAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bitwise_operation_lookup_bus: bitwise_lookup_chip.bus(), - offset, - pointer_max_bits, - }; - Self { - records: vec![], - air, - height: 0, - offline_memory, - streams: OnceLock::new(), - bitwise_lookup_chip, +/// **SAFETY**: the order of the fields in `Rv32HintStoreRecord` and `Rv32HintStoreVar` is +/// important. The chip also assumes that the offset of the fields `write_aux` and `data` in +/// `Rv32HintStoreCols` is bigger than `size_of::()` +#[derive(Debug)] +pub struct Rv32HintStoreRecordMut<'a> { + pub inner: &'a mut Rv32HintStoreRecordHeader, + pub var: &'a mut [Rv32HintStoreVar], +} + +/// Custom borrowing that splits the buffer into a fixed `Rv32HintStoreRecord` header +/// followed by a slice of `Rv32HintStoreVar`'s of length `num_words` provided at runtime. +/// Uses `align_to_mut()` to make sure the slice is properly aligned to `Rv32HintStoreVar`. +/// Has debug assertions to make sure the above works as expected. +impl<'a> CustomBorrow<'a, Rv32HintStoreRecordMut<'a>, Rv32HintStoreLayout> for [u8] { + fn custom_borrow(&'a mut self, layout: Rv32HintStoreLayout) -> Rv32HintStoreRecordMut<'a> { + let (header_buf, rest) = + unsafe { self.split_at_mut_unchecked(size_of::()) }; + + let (_, vars, _) = unsafe { rest.align_to_mut::() }; + Rv32HintStoreRecordMut { + inner: header_buf.borrow_mut(), + var: &mut vars[..layout.metadata.num_words], } } - pub fn set_streams(&mut self, streams: Arc>>) { - self.streams - .set(streams) - .map_err(|_| "streams have already been set.") - .unwrap(); + + unsafe fn extract_layout(&self) -> Rv32HintStoreLayout { + let header: &Rv32HintStoreRecordHeader = self.borrow(); + MultiRowLayout::new(Rv32HintStoreMetadata { + num_words: header.num_words as usize, + }) + } +} + +impl SizedRecord for Rv32HintStoreRecordMut<'_> { + fn size(layout: &Rv32HintStoreLayout) -> usize { + let mut total_len = size_of::(); + // Align the pointer to the alignment of `Rv32HintStoreVar` + total_len = total_len.next_multiple_of(align_of::()); + total_len += size_of::() * layout.metadata.num_words; + total_len } + + fn alignment(_layout: &Rv32HintStoreLayout) -> usize { + align_of::() + } +} + +#[derive(Clone, Copy, derive_new::new)] +pub struct Rv32HintStoreExecutor { + pub pointer_max_bits: usize, + pub offset: usize, +} + +#[derive(Clone, derive_new::new)] +pub struct Rv32HintStoreFiller { + pointer_max_bits: usize, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } -impl InstructionExecutor for Rv32HintStoreChip { +impl PreflightExecutor for Rv32HintStoreExecutor +where + F: PrimeField32, + for<'buf> RA: + RecordArena<'buf, MultiRowLayout, Rv32HintStoreRecordMut<'buf>>, +{ + fn get_opcode_name(&self, opcode: usize) -> String { + if opcode == HINT_STOREW.global_opcode().as_usize() { + String::from("HINT_STOREW") + } else if opcode == HINT_BUFFER.global_opcode().as_usize() { + String::from("HINT_BUFFER") + } else { + unreachable!("unsupported opcode: {}", opcode) + } + } + fn execute( &mut self, - memory: &mut MemoryController, + state: VmStateMut, instruction: &Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { + ) -> Result<(), ExecutionError> { let &Instruction { - opcode, - a: num_words_ptr, - b: mem_ptr_ptr, - d, - e, - .. + opcode, a, b, d, e, .. } = instruction; + + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); - let local_opcode = - Rv32HintStoreOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - let (mem_ptr_read, mem_ptr_limbs) = memory.read::(d, mem_ptr_ptr); - let (num_words, num_words_read) = if local_opcode == HINT_STOREW { - memory.increment_timestamp(); - (1, None) + let local_opcode = Rv32HintStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + + // We do untraced read of `num_words` in order to allocate the record first + let num_words = if local_opcode == HINT_STOREW { + 1 } else { - let (num_words_read, num_words_limbs) = - memory.read::(d, num_words_ptr); - (compose(num_words_limbs), Some(num_words_read)) + read_rv32_register(state.memory.data(), a) }; - debug_assert_ne!(num_words, 0); - debug_assert!(num_words <= (1 << self.air.pointer_max_bits)); - let mem_ptr = compose(mem_ptr_limbs); + let record = state.ctx.alloc(MultiRowLayout::new(Rv32HintStoreMetadata { + num_words: num_words as usize, + })); - debug_assert!(mem_ptr <= (1 << self.air.pointer_max_bits)); + record.inner.from_pc = *state.pc; + record.inner.timestamp = state.memory.timestamp; + record.inner.mem_ptr_ptr = b; - let mut streams = self.streams.get().unwrap().lock().unwrap(); - if streams.hint_stream.len() < RV32_REGISTER_NUM_LIMBS * num_words as usize { - return Err(ExecutionError::HintOutOfBounds { pc: from_state.pc }); - } + record.inner.mem_ptr = u32::from_le_bytes(tracing_read( + state.memory, + RV32_REGISTER_AS, + b, + &mut record.inner.mem_ptr_aux_record.prev_timestamp, + )); - let mut record = Rv32HintStoreRecord { - from_state, - instruction: instruction.clone(), - mem_ptr_read, - mem_ptr, - num_words, - num_words_read, - hints: vec![], + debug_assert!(record.inner.mem_ptr <= (1 << self.pointer_max_bits)); + debug_assert_ne!(num_words, 0); + debug_assert!(num_words <= (1 << self.pointer_max_bits)); + + record.inner.num_words = num_words; + if local_opcode == HINT_STOREW { + state.memory.increment_timestamp(); + record.inner.num_words_ptr = u32::MAX; + } else { + record.inner.num_words_ptr = a; + tracing_read::( + state.memory, + RV32_REGISTER_AS, + record.inner.num_words_ptr, + &mut record.inner.num_words_read.prev_timestamp, + ); }; - for word_index in 0..num_words { - if word_index != 0 { - memory.increment_timestamp(); - memory.increment_timestamp(); + if state.streams.hint_stream.len() < RV32_REGISTER_NUM_LIMBS * num_words as usize { + return Err(ExecutionError::HintOutOfBounds { pc: *state.pc }); + } + + for idx in 0..(num_words as usize) { + if idx != 0 { + state.memory.increment_timestamp(); + state.memory.increment_timestamp(); } - let data: [F; RV32_REGISTER_NUM_LIMBS] = - std::array::from_fn(|_| streams.hint_stream.pop_front().unwrap()); - let (write, _) = memory.write( - e, - F::from_canonical_u32(mem_ptr + (RV32_REGISTER_NUM_LIMBS as u32 * word_index)), + let data_f: [F; RV32_REGISTER_NUM_LIMBS] = + std::array::from_fn(|_| state.streams.hint_stream.pop_front().unwrap()); + let data: [u8; RV32_REGISTER_NUM_LIMBS] = + data_f.map(|byte| byte.as_canonical_u32() as u8); + + record.var[idx].data = data; + + tracing_write( + state.memory, + RV32_MEMORY_AS, + record.inner.mem_ptr + (RV32_REGISTER_NUM_LIMBS * idx) as u32, data, + &mut record.var[idx].data_write_aux.prev_timestamp, + &mut record.var[idx].data_write_aux.prev_data, ); - record.hints.push((data, write)); } + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - self.height += record.hints.len(); - self.records.push(record); - - let next_state = ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }; - Ok(next_state) + Ok(()) } +} - fn get_opcode_name(&self, opcode: usize) -> String { - if opcode == HINT_STOREW.global_opcode().as_usize() { - String::from("HINT_STOREW") - } else if opcode == HINT_BUFFER.global_opcode().as_usize() { - String::from("HINT_BUFFER") - } else { - unreachable!("unsupported opcode: {}", opcode) +impl TraceFiller for Rv32HintStoreFiller { + fn fill_trace( + &self, + mem_helper: &MemoryAuxColsFactory, + trace: &mut RowMajorMatrix, + rows_used: usize, + ) { + if rows_used == 0 { + return; } + + let width = trace.width; + debug_assert_eq!(width, size_of::>()); + let mut trace = &mut trace.values[..width * rows_used]; + let mut sizes = Vec::with_capacity(rows_used); + let mut chunks = Vec::with_capacity(rows_used); + + while !trace.is_empty() { + let record: &Rv32HintStoreRecordHeader = + unsafe { get_record_from_slice(&mut trace, ()) }; + let (chunk, rest) = trace.split_at_mut(width * record.num_words as usize); + sizes.push(record.num_words); + chunks.push(chunk); + trace = rest; + } + + let msl_rshift: u32 = ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS) as u32; + let msl_lshift: u32 = + (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.pointer_max_bits) as u32; + + chunks + .par_iter_mut() + .zip(sizes.par_iter()) + .for_each(|(chunk, &num_words)| { + let record: Rv32HintStoreRecordMut = unsafe { + get_record_from_slice( + chunk, + MultiRowLayout::new(Rv32HintStoreMetadata { + num_words: num_words as usize, + }), + ) + }; + self.bitwise_lookup_chip.request_range( + (record.inner.mem_ptr >> msl_rshift) << msl_lshift, + (num_words >> msl_rshift) << msl_lshift, + ); + + let mut timestamp = record.inner.timestamp + num_words * 3; + let mut mem_ptr = record.inner.mem_ptr + num_words * RV32_REGISTER_NUM_LIMBS as u32; + + // Assuming that `num_words` is usually small (e.g. 1 for `HINT_STOREW`) + // it is better to do a serial pass of the rows per instruction (going from the last + // row to the first row) instead of a parallel pass, since need to + // copy the record to a new buffer in parallel case. + chunk + .rchunks_exact_mut(width) + .zip(record.var.iter().enumerate().rev()) + .for_each(|(row, (idx, var))| { + for pair in var.data.chunks_exact(2) { + self.bitwise_lookup_chip + .request_range(pair[0] as u32, pair[1] as u32); + } + + let cols: &mut Rv32HintStoreCols = row.borrow_mut(); + let is_single = record.inner.num_words_ptr == u32::MAX; + timestamp -= 3; + if idx == 0 && !is_single { + mem_helper.fill( + record.inner.num_words_read.prev_timestamp, + timestamp + 1, + cols.num_words_aux_cols.as_mut(), + ); + cols.num_words_ptr = F::from_canonical_u32(record.inner.num_words_ptr); + } else { + mem_helper.fill_zero(cols.num_words_aux_cols.as_mut()); + cols.num_words_ptr = F::ZERO; + } + + cols.is_buffer_start = F::from_bool(idx == 0 && !is_single); + + // Note: writing in reverse + cols.data = var.data.map(|x| F::from_canonical_u8(x)); + + cols.write_aux.set_prev_data( + var.data_write_aux + .prev_data + .map(|x| F::from_canonical_u8(x)), + ); + mem_helper.fill( + var.data_write_aux.prev_timestamp, + timestamp + 2, + cols.write_aux.as_mut(), + ); + + if idx == 0 { + mem_helper.fill( + record.inner.mem_ptr_aux_record.prev_timestamp, + timestamp, + cols.mem_ptr_aux_cols.as_mut(), + ); + } else { + mem_helper.fill_zero(cols.mem_ptr_aux_cols.as_mut()); + } + + mem_ptr -= RV32_REGISTER_NUM_LIMBS as u32; + cols.mem_ptr_limbs = mem_ptr.to_le_bytes().map(|x| F::from_canonical_u8(x)); + cols.mem_ptr_ptr = F::from_canonical_u32(record.inner.mem_ptr_ptr); + + cols.from_state.timestamp = F::from_canonical_u32(timestamp); + cols.from_state.pc = F::from_canonical_u32(record.inner.from_pc); + + cols.rem_words_limbs = (num_words - idx as u32) + .to_le_bytes() + .map(|x| F::from_canonical_u8(x)); + cols.is_buffer = F::from_bool(!is_single); + cols.is_single = F::from_bool(is_single); + }); + }) } } -impl ChipUsageGetter for Rv32HintStoreChip { - fn air_name(&self) -> String { - "Rv32HintStoreAir".to_string() - } +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct HintStorePreCompute { + c: u32, + a: u8, + b: u8, +} - fn current_trace_height(&self) -> usize { - self.height +impl Executor for Rv32HintStoreExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() } - fn trace_width(&self) -> usize { - Rv32HintStoreCols::::width() + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut HintStorePreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, pre_compute)?; + let fn_ptr = match local_opcode { + HINT_STOREW => execute_e1_impl::<_, _, true>, + HINT_BUFFER => execute_e1_impl::<_, _, false>, + }; + Ok(fn_ptr) } } -impl Rv32HintStoreChip { - // returns number of used u32s - fn record_to_rows( - record: Rv32HintStoreRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, - bitwise_lookup_chip: &SharedBitwiseOperationLookupChip, - pointer_max_bits: usize, - ) -> usize { - let width = Rv32HintStoreCols::::width(); - let cols: &mut Rv32HintStoreCols = slice[..width].borrow_mut(); - - cols.is_single = F::from_bool(record.num_words_read.is_none()); - cols.is_buffer = F::from_bool(record.num_words_read.is_some()); - cols.is_buffer_start = cols.is_buffer; - - cols.from_state = record.from_state.map(F::from_canonical_u32); - cols.mem_ptr_ptr = record.instruction.b; - aux_cols_factory.generate_read_aux( - memory.record_by_id(record.mem_ptr_read), - &mut cols.mem_ptr_aux_cols, - ); +impl MeteredExecutor for Rv32HintStoreExecutor +where + F: PrimeField32, +{ + fn metered_pre_compute_size(&self) -> usize { + size_of::>() + } - cols.num_words_ptr = record.instruction.a; - if let Some(num_words_read) = record.num_words_read { - aux_cols_factory.generate_read_aux( - memory.record_by_id(num_words_read), - &mut cols.num_words_aux_cols, - ); - } + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + let fn_ptr = match local_opcode { + HINT_STOREW => execute_e2_impl::<_, _, true>, + HINT_BUFFER => execute_e2_impl::<_, _, false>, + }; + Ok(fn_ptr) + } +} - let mut mem_ptr = record.mem_ptr; - let mut rem_words = record.num_words; - let mut used_u32s = 0; +/// Return the number of used rows. +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &HintStorePreCompute, + vm_state: &mut VmExecState, +) -> u32 { + let mem_ptr_limbs = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let mem_ptr = u32::from_le_bytes(mem_ptr_limbs); + + let num_words = if IS_HINT_STOREW { + 1 + } else { + let num_words_limbs = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.a as u32); + u32::from_le_bytes(num_words_limbs) + }; + debug_assert_ne!(num_words, 0); + + if vm_state.streams.hint_stream.len() < RV32_REGISTER_NUM_LIMBS * num_words as usize { + vm_state.exit_code = Err(ExecutionError::HintOutOfBounds { pc: vm_state.pc }); + return 0; + } - let mem_ptr_msl = mem_ptr >> ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS); - let rem_words_msl = rem_words >> ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS); - bitwise_lookup_chip.request_range( - mem_ptr_msl << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - pointer_max_bits), - rem_words_msl << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - pointer_max_bits), + for word_index in 0..num_words { + let data: [u8; RV32_REGISTER_NUM_LIMBS] = std::array::from_fn(|_| { + vm_state + .streams + .hint_stream + .pop_front() + .unwrap() + .as_canonical_u32() as u8 + }); + vm_state.vm_write( + RV32_MEMORY_AS, + mem_ptr + (RV32_REGISTER_NUM_LIMBS as u32 * word_index), + &data, ); - for (i, &(data, write)) in record.hints.iter().enumerate() { - for half in 0..(RV32_REGISTER_NUM_LIMBS / 2) { - bitwise_lookup_chip.request_range( - data[2 * half].as_canonical_u32(), - data[2 * half + 1].as_canonical_u32(), - ); - } - - let cols: &mut Rv32HintStoreCols = slice[used_u32s..used_u32s + width].borrow_mut(); - cols.from_state.timestamp = - F::from_canonical_u32(record.from_state.timestamp + (3 * i as u32)); - cols.data = data; - aux_cols_factory.generate_write_aux(memory.record_by_id(write), &mut cols.write_aux); - cols.rem_words_limbs = decompose(rem_words); - cols.mem_ptr_limbs = decompose(mem_ptr); - if i != 0 { - cols.is_buffer = F::ONE; - } - used_u32s += width; - mem_ptr += RV32_REGISTER_NUM_LIMBS as u32; - rem_words -= 1; - } - - used_u32s } - fn generate_trace(self) -> RowMajorMatrix { - let width = self.trace_width(); - let height = next_power_of_two_or_zero(self.height); - let mut flat_trace = F::zero_vec(width * height); + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; + num_words +} - let memory = self.offline_memory.lock().unwrap(); +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &HintStorePreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} - let aux_cols_factory = memory.aux_cols_factory(); +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + let height_delta = execute_e12_impl::(&pre_compute.data, vm_state); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height_delta); +} - let mut used_u32s = 0; - for record in self.records { - used_u32s += Self::record_to_rows( - record, - &aux_cols_factory, - &mut flat_trace[used_u32s..], - &memory, - &self.bitwise_lookup_chip, - self.air.pointer_max_bits, - ); +impl Rv32HintStoreExecutor { + #[inline(always)] + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut HintStorePreCompute, + ) -> Result { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + if d.as_canonical_u32() != RV32_REGISTER_AS || e.as_canonical_u32() != RV32_MEMORY_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); } - // padding rows can just be all zeros - RowMajorMatrix::new(flat_trace, width) + *data = { + HintStorePreCompute { + c: c.as_canonical_u32(), + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + } + }; + Ok(Rv32HintStoreOpcode::from_usize( + opcode.local_opcode_idx(self.offset), + )) } } -impl Chip for Rv32HintStoreChip> -where - Val: PrimeField32, -{ - fn air(&self) -> Arc> { - Arc::new(self.air) - } - fn generate_air_proof_input(self) -> AirProofInput { - AirProofInput::simple_no_pis(self.generate_trace()) - } -} +pub type Rv32HintStoreChip = VmChipWrapper; diff --git a/extensions/rv32im/circuit/src/hintstore/tests.rs b/extensions/rv32im/circuit/src/hintstore/tests.rs index 204070762c..363fabab96 100644 --- a/extensions/rv32im/circuit/src/hintstore/tests.rs +++ b/extensions/rv32im/circuit/src/hintstore/tests.rs @@ -1,20 +1,17 @@ -use std::{ - array, - borrow::BorrowMut, - sync::{Arc, Mutex}, -}; +use std::{borrow::BorrowMut, sync::Arc}; use openvm_circuit::arch::{ - testing::{memory::gen_pointer, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - Streams, + testing::{memory::gen_pointer, TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, + Arena, DenseRecordArena, MatrixRecordArena, PreflightExecutor, }; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, }; use openvm_instructions::{ instruction::Instruction, - riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, - VmOpcode, + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, }; use openvm_rv32im_transpiler::Rv32HintStoreOpcode::{self, *}; use openvm_stark_backend::{ @@ -24,104 +21,100 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, }; -use openvm_stark_sdk::{config::setup_tracing, p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::{rngs::StdRng, Rng}; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use rand::{rngs::StdRng, Rng, RngCore}; -use super::{Rv32HintStoreChip, Rv32HintStoreCols}; -use crate::adapters::decompose; +use super::{Rv32HintStoreAir, Rv32HintStoreChip, Rv32HintStoreCols, Rv32HintStoreExecutor}; +use crate::{test_utils::get_verification_error, Rv32HintStoreFiller, Rv32HintStoreLayout}; type F = BabyBear; +const MAX_INS_CAPACITY: usize = 4096; +type Harness = + TestChipHarness, RA>; -fn set_and_execute( +fn create_test_chip( tester: &mut VmChipTestBuilder, - chip: &mut Rv32HintStoreChip, - rng: &mut StdRng, - opcode: Rv32HintStoreOpcode, +) -> ( + Harness, + ( + BitwiseOperationLookupAir, + SharedBitwiseOperationLookupChip, + ), ) { - let mem_ptr = rng.gen_range( - 0..(1 - << (tester - .memory_controller() - .borrow() - .mem_config() - .pointer_max_bits - - 2)), - ) << 2; - let b = gen_pointer(rng, 4); - - tester.write(1, b, decompose(mem_ptr)); - - let read_data: [F; RV32_REGISTER_NUM_LIMBS] = - array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..(1 << RV32_CELL_BITS)))); - for data in read_data { - chip.streams - .get() - .unwrap() - .lock() - .unwrap() - .hint_stream - .push_back(data); - } + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); - tester.execute( - chip, - &Instruction::from_usize(VmOpcode::from_usize(opcode as usize), [0, b, 0, 1, 2]), + let air = Rv32HintStoreAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_chip.bus(), + Rv32HintStoreOpcode::CLASS_OFFSET, + tester.address_bits(), ); + let executor = + Rv32HintStoreExecutor::new(tester.address_bits(), Rv32HintStoreOpcode::CLASS_OFFSET); + let chip = Rv32HintStoreChip::::new( + Rv32HintStoreFiller::new(tester.address_bits(), bitwise_chip.clone()), + tester.memory_helper(), + ); + + let harness = Harness::::with_capacity(executor, air, chip, MAX_INS_CAPACITY); - let write_data = read_data; - assert_eq!(write_data, tester.read::<4>(2, mem_ptr as usize)); + (harness, (bitwise_chip.air, bitwise_chip)) } -fn set_and_execute_buffer( +fn set_and_execute( tester: &mut VmChipTestBuilder, - chip: &mut Rv32HintStoreChip, + harness: &mut Harness, rng: &mut StdRng, opcode: Rv32HintStoreOpcode, -) { - let mem_ptr = rng.gen_range( - 0..(1 - << (tester - .memory_controller() - .borrow() - .mem_config() - .pointer_max_bits - - 2)), - ) << 2; - let b = gen_pointer(rng, 4); - - tester.write(1, b, decompose(mem_ptr)); - - let num_words = rng.gen_range(1..20); - let a = gen_pointer(rng, 4); - tester.write(1, a, decompose(num_words)); - - let data: Vec<[F; RV32_REGISTER_NUM_LIMBS]> = (0..num_words) - .map(|_| array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..(1 << RV32_CELL_BITS))))) - .collect(); - for i in 0..num_words { - for datum in data[i as usize] { - chip.streams - .get() - .unwrap() - .lock() - .unwrap() - .hint_stream - .push_back(datum); - } +) where + Rv32HintStoreExecutor: PreflightExecutor, +{ + let num_words = match opcode { + HINT_STOREW => 1, + HINT_BUFFER => rng.gen_range(1..28), + } as u32; + + let a = if opcode == HINT_BUFFER { + let a = gen_pointer(rng, RV32_REGISTER_NUM_LIMBS); + tester.write( + RV32_REGISTER_AS as usize, + a, + num_words.to_le_bytes().map(F::from_canonical_u8), + ); + a + } else { + 0 + }; + + let mem_ptr = gen_pointer(rng, 4) as u32; + let b = gen_pointer(rng, RV32_REGISTER_NUM_LIMBS); + tester.write(1, b, mem_ptr.to_le_bytes().map(F::from_canonical_u8)); + + let mut input = Vec::with_capacity(num_words as usize * 4); + for _ in 0..num_words { + let data = rng.next_u32().to_le_bytes().map(F::from_canonical_u8); + input.extend(data); + tester.streams.hint_stream.extend(data); } tester.execute( - chip, - &Instruction::from_usize(VmOpcode::from_usize(opcode as usize), [a, b, 0, 1, 2]), + harness, + &Instruction::from_usize( + opcode.global_opcode(), + [a, b, 0, RV32_REGISTER_AS as usize, RV32_MEMORY_AS as usize], + ), ); - for i in 0..num_words { - assert_eq!( - data[i as usize], - tester.read::<4>(2, mem_ptr as usize + (i as usize * RV32_REGISTER_NUM_LIMBS)) - ); + for idx in 0..num_words as usize { + let data = tester.read::<4>(RV32_MEMORY_AS as usize, mem_ptr as usize + idx * 4); + + let expected: [F; 4] = input[idx * 4..(idx + 1) * 4].try_into().unwrap(); + assert_eq!(data, expected); } } @@ -131,39 +124,28 @@ fn set_and_execute_buffer( /// Randomly generate computations and execute, ensuring that the generated trace /// passes all constraints. /////////////////////////////////////////////////////////////////////////////////////// + #[test] fn rand_hintstore_test() { - setup_tracing(); let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - - let mut chip = Rv32HintStoreChip::::new( - tester.execution_bus(), - tester.program_bus(), - bitwise_chip.clone(), - tester.memory_bridge(), - tester.offline_memory_mutex_arc(), - tester.address_bits(), - 0, - ); - chip.set_streams(Arc::new(Mutex::new(Streams::default()))); - - let num_tests: usize = 8; - for _ in 0..num_tests { - if rng.gen_bool(0.5) { - set_and_execute(&mut tester, &mut chip, &mut rng, HINT_STOREW); + let (mut harness, bitwise) = create_test_chip(&mut tester); + let num_ops: usize = 100; + for _ in 0..num_ops { + let opcode = if rng.gen_bool(0.5) { + HINT_STOREW } else { - set_and_execute_buffer(&mut tester, &mut chip, &mut rng, HINT_BUFFER); - } + HINT_BUFFER + }; + set_and_execute(&mut tester, &mut harness, &mut rng, opcode); } - drop(range_checker_chip); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + let tester = tester + .build() + .load(harness) + .load_periphery(bitwise) + .finalize(); tester.simple_test().expect("Verification failed"); } @@ -171,64 +153,44 @@ fn rand_hintstore_test() { // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adaptor is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// #[allow(clippy::too_many_arguments)] fn run_negative_hintstore_test( opcode: Rv32HintStoreOpcode, - data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, - expected_error: VerificationError, + prank_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + interaction_error: bool, ) { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); + let (mut harness, bitwise) = create_test_chip(&mut tester); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - - let mut chip = Rv32HintStoreChip::::new( - tester.execution_bus(), - tester.program_bus(), - bitwise_chip.clone(), - tester.memory_bridge(), - tester.offline_memory_mutex_arc(), - tester.address_bits(), - 0, - ); - chip.set_streams(Arc::new(Mutex::new(Streams::default()))); - - set_and_execute(&mut tester, &mut chip, &mut rng, opcode); + set_and_execute(&mut tester, &mut harness, &mut rng, opcode); let modify_trace = |trace: &mut DenseMatrix| { let mut trace_row = trace.row_slice(0).to_vec(); let cols: &mut Rv32HintStoreCols = trace_row.as_mut_slice().borrow_mut(); - if let Some(data) = data { + if let Some(data) = prank_data { cols.data = data.map(F::from_canonical_u32); } *trace = RowMajorMatrix::new(trace_row, trace.width()); }; - drop(range_checker_chip); disable_debug_builder(); let tester = tester .build() - .load_and_prank_trace(chip, modify_trace) - .load(bitwise_chip) + .load_and_prank_trace(harness, modify_trace) + .load_periphery(bitwise) .finalize(); - tester.simple_test_with_expected_error(expected_error); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] fn negative_hintstore_tests() { - run_negative_hintstore_test( - HINT_STOREW, - Some([92, 187, 45, 280]), - VerificationError::ChallengePhaseError, - ); + run_negative_hintstore_test(HINT_STOREW, Some([92, 187, 45, 280]), true); } + /////////////////////////////////////////////////////////////////////////////////////// /// SANITY TESTS /// @@ -239,22 +201,47 @@ fn execute_roundtrip_sanity_test() { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let (mut harness, _) = create_test_chip::>(&mut tester); - let mut chip = Rv32HintStoreChip::::new( - tester.execution_bus(), - tester.program_bus(), - bitwise_chip.clone(), - tester.memory_bridge(), - tester.offline_memory_mutex_arc(), - tester.address_bits(), - 0, - ); - chip.set_streams(Arc::new(Mutex::new(Streams::default()))); + let num_ops: usize = 10; + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut harness, &mut rng, HINT_STOREW); + } +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// DENSE TESTS +/// +/// Ensure that the chip works as expected with dense records. +/// We first execute some instructions with a [DenseRecordArena] and transfer the records +/// to a [MatrixRecordArena]. After transferring we generate the trace and make sure that +/// all the constraints pass. +/////////////////////////////////////////////////////////////////////////////////////// + +#[test] +fn dense_record_arena_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut sparse_harness, bitwise) = create_test_chip::>(&mut tester); + + { + let mut dense_harness = create_test_chip::(&mut tester).0; - let num_tests: usize = 100; - for _ in 0..num_tests { - set_and_execute(&mut tester, &mut chip, &mut rng, HINT_STOREW); + let num_ops: usize = 100; + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut dense_harness, &mut rng, HINT_STOREW); + } + + let mut record_interpreter = dense_harness + .arena + .get_record_seeker::<_, Rv32HintStoreLayout>(); + record_interpreter.transfer_to_matrix_arena(&mut sparse_harness.arena); } + + let tester = tester + .build() + .load(sparse_harness) + .load_periphery(bitwise) + .finalize(); + tester.simple_test().expect("Verification failed"); } diff --git a/extensions/rv32im/circuit/src/jal_lui/core.rs b/extensions/rv32im/circuit/src/jal_lui/core.rs index 2ba10e615e..5e856deabe 100644 --- a/extensions/rv32im/circuit/src/jal_lui/core.rs +++ b/extensions/rv32im/circuit/src/jal_lui/core.rs @@ -1,19 +1,21 @@ -use std::{ - array, - borrow::{Borrow, BorrowMut}, +use std::borrow::{Borrow, BorrowMut}; + +use openvm_circuit::{ + arch::*, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; - -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, ImmInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, -}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ instruction::Instruction, program::{DEFAULT_PC_STEP, PC_BITS}, + riscv::RV32_REGISTER_AS, LocalOpcode, }; use openvm_rv32im_transpiler::Rv32JalLuiOpcode::{self, *}; @@ -23,9 +25,13 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; -use crate::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, RV_J_TYPE_IMM_BITS}; +use crate::adapters::{ + Rv32CondRdWriteAdapterExecutor, Rv32CondRdWriteAdapterFiller, RV32_CELL_BITS, + RV32_REGISTER_NUM_LIMBS, RV_J_TYPE_IMM_BITS, +}; + +pub(super) const ADDITIONAL_BITS: u32 = 0b11000000; #[repr(C)] #[derive(Debug, Clone, AlignedBorrow)] @@ -36,7 +42,7 @@ pub struct Rv32JalLuiCoreCols { pub is_lui: T, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy, derive_new::new)] pub struct Rv32JalLuiCoreAir { pub bus: BitwiseOperationLookupBus, } @@ -141,134 +147,286 @@ where } #[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct Rv32JalLuiCoreRecord { - pub rd_data: [F; RV32_REGISTER_NUM_LIMBS], - pub imm: F, +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32JalLuiCoreRecord { + pub imm: u32, + pub rd_data: [u8; RV32_REGISTER_NUM_LIMBS], pub is_jal: bool, - pub is_lui: bool, } -pub struct Rv32JalLuiCoreChip { - pub air: Rv32JalLuiCoreAir, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, +#[derive(Clone, Copy, derive_new::new)] +pub struct Rv32JalLuiExecutor { + adapter: A, } -impl Rv32JalLuiCoreChip { - pub fn new(bitwise_lookup_chip: SharedBitwiseOperationLookupChip) -> Self { - Self { - air: Rv32JalLuiCoreAir { - bus: bitwise_lookup_chip.bus(), - }, - bitwise_lookup_chip, - } - } +#[derive(Clone, derive_new::new)] +pub struct Rv32JalLuiFiller { + adapter: A, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } -impl> VmCoreChip for Rv32JalLuiCoreChip +impl PreflightExecutor for Rv32JalLuiExecutor where - I::Writes: From<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceExecutor, + for<'buf> RA: RecordArena< + 'buf, + EmptyAdapterCoreLayout, + (A::RecordMut<'buf>, &'buf mut Rv32JalLuiCoreRecord), + >, { - type Record = Rv32JalLuiCoreRecord; - type Air = Rv32JalLuiCoreAir; + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + Rv32JalLuiOpcode::from_usize(opcode - Rv32JalLuiOpcode::CLASS_OFFSET) + ) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, instruction: &Instruction, - from_pc: u32, - _reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let local_opcode = Rv32JalLuiOpcode::from_usize( - instruction - .opcode - .local_opcode_idx(Rv32JalLuiOpcode::CLASS_OFFSET), - ); - let imm = instruction.c; - - let signed_imm = match local_opcode { - JAL => { - // Note: signed_imm is a signed integer and imm is a field element - (imm + F::from_canonical_u32(1 << (RV_J_TYPE_IMM_BITS - 1))).as_canonical_u32() - as i32 - - (1 << (RV_J_TYPE_IMM_BITS - 1)) - } - LUI => imm.as_canonical_u32() as i32, - }; - let (to_pc, rd_data) = run_jal_lui(local_opcode, from_pc, signed_imm); + ) -> Result<(), ExecutionError> { + let &Instruction { opcode, c: imm, .. } = instruction; - for i in 0..(RV32_REGISTER_NUM_LIMBS / 2) { + let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + let is_jal = opcode.local_opcode_idx(Rv32JalLuiOpcode::CLASS_OFFSET) == JAL as usize; + let signed_imm = get_signed_imm(is_jal, imm); + + let (to_pc, rd_data) = run_jal_lui(is_jal, *state.pc, signed_imm); + + core_record.imm = imm.as_canonical_u32(); + core_record.rd_data = rd_data; + core_record.is_jal = is_jal; + + self.adapter + .write(state.memory, instruction, rd_data, &mut adapter_record); + + *state.pc = to_pc; + + Ok(()) + } +} + +impl TraceFiller for Rv32JalLuiFiller +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + let record: &Rv32JalLuiCoreRecord = unsafe { get_record_from_slice(&mut core_row, ()) }; + let core_row: &mut Rv32JalLuiCoreCols = core_row.borrow_mut(); + + for pair in record.rd_data.chunks_exact(2) { self.bitwise_lookup_chip - .request_range(rd_data[i * 2], rd_data[i * 2 + 1]); + .request_range(pair[0] as u32, pair[1] as u32); } - - if local_opcode == JAL { - let last_limb_bits = PC_BITS - RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1); - let additional_bits = (last_limb_bits..RV32_CELL_BITS).fold(0, |acc, x| acc + (1 << x)); + if record.is_jal { self.bitwise_lookup_chip - .request_xor(rd_data[3], additional_bits); + .request_xor(record.rd_data[3] as u32, ADDITIONAL_BITS); } - let rd_data = rd_data.map(F::from_canonical_u32); + // Writing in reverse order + core_row.is_lui = F::from_bool(!record.is_jal); + core_row.is_jal = F::from_bool(record.is_jal); + core_row.rd_data = record.rd_data.map(F::from_canonical_u8); + core_row.imm = F::from_canonical_u32(record.imm); + } +} - let output = AdapterRuntimeContext { - to_pc: Some(to_pc), - writes: [rd_data].into(), +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct JalLuiPreCompute { + signed_imm: i32, + a: u8, +} + +impl Executor for Rv32JalLuiExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } + + fn pre_compute( + &self, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let data: &mut JalLuiPreCompute = data.borrow_mut(); + let (is_jal, enabled) = self.pre_compute_impl(inst, data)?; + let fn_ptr = match (is_jal, enabled) { + (true, true) => execute_e1_impl::<_, _, true, true>, + (true, false) => execute_e1_impl::<_, _, true, false>, + (false, true) => execute_e1_impl::<_, _, false, true>, + (false, false) => execute_e1_impl::<_, _, false, false>, }; + Ok(fn_ptr) + } +} - Ok(( - output, - Rv32JalLuiCoreRecord { - rd_data, - imm, - is_jal: local_opcode == JAL, - is_lui: local_opcode == LUI, - }, - )) +impl MeteredExecutor for Rv32JalLuiExecutor +where + F: PrimeField32, +{ + fn metered_pre_compute_size(&self) -> usize { + size_of::>() } - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - Rv32JalLuiOpcode::from_usize(opcode - Rv32JalLuiOpcode::CLASS_OFFSET) - ) + fn metered_pre_compute( + &self, + chip_idx: usize, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let (is_jal, enabled) = self.pre_compute_impl(inst, &mut data.data)?; + let fn_ptr = match (is_jal, enabled) { + (true, true) => execute_e2_impl::<_, _, true, true>, + (true, false) => execute_e2_impl::<_, _, true, false>, + (false, true) => execute_e2_impl::<_, _, false, true>, + (false, false) => execute_e2_impl::<_, _, false, false>, + }; + Ok(fn_ptr) + } +} + +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const IS_JAL: bool, + const ENABLED: bool, +>( + pre_compute: &JalLuiPreCompute, + vm_state: &mut VmExecState, +) { + let JalLuiPreCompute { a, signed_imm } = *pre_compute; + + let rd = if IS_JAL { + let rd_data = (vm_state.pc + DEFAULT_PC_STEP).to_le_bytes(); + let next_pc = vm_state.pc as i32 + signed_imm; + debug_assert!(next_pc >= 0); + vm_state.pc = next_pc as u32; + rd_data + } else { + let imm = signed_imm as u32; + let rd = imm << 12; + vm_state.pc += DEFAULT_PC_STEP; + rd.to_le_bytes() + }; + + if ENABLED { + vm_state.vm_write(RV32_REGISTER_AS, a as u32, &rd); } - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let core_cols: &mut Rv32JalLuiCoreCols = row_slice.borrow_mut(); - core_cols.rd_data = record.rd_data; - core_cols.imm = record.imm; - core_cols.is_jal = F::from_bool(record.is_jal); - core_cols.is_lui = F::from_bool(record.is_lui); + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const IS_JAL: bool, + const ENABLED: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &JalLuiPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const IS_JAL: bool, + const ENABLED: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl Rv32JalLuiExecutor { + /// Return (IS_JAL, ENABLED) + #[inline(always)] + fn pre_compute_impl( + &self, + inst: &Instruction, + data: &mut JalLuiPreCompute, + ) -> Result<(bool, bool), StaticProgramError> { + let local_opcode = Rv32JalLuiOpcode::from_usize( + inst.opcode.local_opcode_idx(Rv32JalLuiOpcode::CLASS_OFFSET), + ); + let is_jal = local_opcode == JAL; + let imm_f = inst.c.as_canonical_u32(); + let signed_imm = if is_jal { + if imm_f < (1 << (RV_J_TYPE_IMM_BITS - 1)) { + imm_f as i32 + } else { + let neg_imm_f = F::ORDER_U32 - imm_f; + assert!(neg_imm_f < (1 << (RV_J_TYPE_IMM_BITS - 1))); + -(neg_imm_f as i32) + } + } else { + imm_f as i32 + }; + + *data = JalLuiPreCompute { + signed_imm, + a: inst.a.as_canonical_u32() as u8, + }; + let enabled = !inst.f.is_zero(); + Ok((is_jal, enabled)) } +} - fn air(&self) -> &Self::Air { - &self.air +// returns the canonical signed representation of the immediate +// `imm` can be "negative" as a field element +pub(super) fn get_signed_imm(is_jal: bool, imm: F) -> i32 { + let imm_f = imm.as_canonical_u32(); + if is_jal { + if imm_f < (1 << (RV_J_TYPE_IMM_BITS - 1)) { + imm_f as i32 + } else { + let neg_imm_f = F::ORDER_U32 - imm_f; + debug_assert!(neg_imm_f < (1 << (RV_J_TYPE_IMM_BITS - 1))); + -(neg_imm_f as i32) + } + } else { + imm_f as i32 } } // returns (to_pc, rd_data) -pub(super) fn run_jal_lui( - opcode: Rv32JalLuiOpcode, - pc: u32, - imm: i32, -) -> (u32, [u32; RV32_REGISTER_NUM_LIMBS]) { - match opcode { - JAL => { - let rd_data = array::from_fn(|i| { - ((pc + DEFAULT_PC_STEP) >> (8 * i)) & ((1 << RV32_CELL_BITS) - 1) - }); - let next_pc = pc as i32 + imm; - assert!(next_pc >= 0); - (next_pc as u32, rd_data) - } - LUI => { - let imm = imm as u32; - let rd = imm << 12; - let rd_data = - array::from_fn(|i| (rd >> (RV32_CELL_BITS * i)) & ((1 << RV32_CELL_BITS) - 1)); - (pc + DEFAULT_PC_STEP, rd_data) - } +#[inline(always)] +pub(super) fn run_jal_lui(is_jal: bool, pc: u32, imm: i32) -> (u32, [u8; RV32_REGISTER_NUM_LIMBS]) { + if is_jal { + let rd_data = (pc + DEFAULT_PC_STEP).to_le_bytes(); + let next_pc = pc as i32 + imm; + debug_assert!(next_pc >= 0); + (next_pc as u32, rd_data) + } else { + let imm = imm as u32; + let rd = imm << 12; + (pc + DEFAULT_PC_STEP, rd.to_le_bytes()) } } diff --git a/extensions/rv32im/circuit/src/jal_lui/mod.rs b/extensions/rv32im/circuit/src/jal_lui/mod.rs index 779b710bea..4a403aa36a 100644 --- a/extensions/rv32im/circuit/src/jal_lui/mod.rs +++ b/extensions/rv32im/circuit/src/jal_lui/mod.rs @@ -1,6 +1,6 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; -use crate::adapters::Rv32CondRdWriteAdapterChip; +use crate::adapters::Rv32CondRdWriteAdapterAir; mod core; pub use core::*; @@ -8,4 +8,5 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32JalLuiChip = VmChipWrapper, Rv32JalLuiCoreChip>; +pub type Rv32JalLuiAir = VmAirWrapper; +pub type Rv32JalLuiChip = VmChipWrapper; diff --git a/extensions/rv32im/circuit/src/jal_lui/tests.rs b/extensions/rv32im/circuit/src/jal_lui/tests.rs index 35e258cbfb..2751b9eedd 100644 --- a/extensions/rv32im/circuit/src/jal_lui/tests.rs +++ b/extensions/rv32im/circuit/src/jal_lui/tests.rs @@ -1,41 +1,85 @@ -use std::borrow::BorrowMut; +use std::{borrow::BorrowMut, sync::Arc}; -use openvm_circuit::arch::{ - testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - VmAdapterChip, -}; +use openvm_circuit::arch::testing::{TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, }; use openvm_instructions::{instruction::Instruction, program::PC_BITS, LocalOpcode}; use openvm_rv32im_transpiler::Rv32JalLuiOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{FieldAlgebra, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, utils::disable_debug_builder, - verifier::VerificationError, - Chip, ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; +use test_case::test_case; -use super::{run_jal_lui, Rv32JalLuiChip, Rv32JalLuiCoreChip}; +use super::{run_jal_lui, Rv32JalLuiChip, Rv32JalLuiCoreAir, Rv32JalLuiExecutor}; use crate::{ adapters::{ - Rv32CondRdWriteAdapterChip, Rv32CondRdWriteAdapterCols, RV32_CELL_BITS, - RV32_REGISTER_NUM_LIMBS, RV_IS_TYPE_IMM_BITS, + Rv32CondRdWriteAdapterAir, Rv32CondRdWriteAdapterCols, Rv32CondRdWriteAdapterExecutor, + Rv32CondRdWriteAdapterFiller, Rv32RdWriteAdapterAir, Rv32RdWriteAdapterExecutor, + Rv32RdWriteAdapterFiller, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, RV_IS_TYPE_IMM_BITS, }, - jal_lui::Rv32JalLuiCoreCols, + jal_lui::{Rv32JalLuiCoreCols, ADDITIONAL_BITS}, + test_utils::get_verification_error, + Rv32JalLuiAir, Rv32JalLuiFiller, }; const IMM_BITS: usize = 20; const LIMB_MAX: u32 = (1 << RV32_CELL_BITS) - 1; +const MAX_INS_CAPACITY: usize = 128; +type Harness = TestChipHarness>; + type F = BabyBear; +fn create_test_chip( + tester: &VmChipTestBuilder, +) -> ( + Harness, + ( + BitwiseOperationLookupAir, + SharedBitwiseOperationLookupChip, + ), +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + + let air = Rv32JalLuiAir::new( + Rv32CondRdWriteAdapterAir::new(Rv32RdWriteAdapterAir::new( + tester.memory_bridge(), + tester.execution_bridge(), + )), + Rv32JalLuiCoreAir::new(bitwise_bus), + ); + let executor = Rv32JalLuiExecutor::new(Rv32CondRdWriteAdapterExecutor::new( + Rv32RdWriteAdapterExecutor, + )); + let chip = Rv32JalLuiChip::::new( + Rv32JalLuiFiller::new( + Rv32CondRdWriteAdapterFiller::new(Rv32RdWriteAdapterFiller), + bitwise_chip.clone(), + ), + tester.memory_helper(), + ); + + let harness = Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + + (harness, (bitwise_chip.air, bitwise_chip)) +} + fn set_and_execute( tester: &mut VmChipTestBuilder, - chip: &mut Rv32JalLuiChip, + harness: &mut Harness, rng: &mut StdRng, opcode: Rv32JalLuiOpcode, imm: Option, @@ -51,7 +95,7 @@ fn set_and_execute( let needs_write = a != 0 || opcode == LUI; tester.execute_with_pc( - chip, + harness, &Instruction::large_from_isize( opcode.global_opcode(), a as isize, @@ -67,11 +111,11 @@ fn set_and_execute( let initial_pc = tester.execution.last_from_pc().as_canonical_u32(); let final_pc = tester.execution.last_to_pc().as_canonical_u32(); - let (next_pc, rd_data) = run_jal_lui(opcode, initial_pc, imm); + let (next_pc, rd_data) = run_jal_lui(opcode == JAL, initial_pc, imm); let rd_data = if needs_write { rd_data } else { [0; 4] }; assert_eq!(next_pc, final_pc); - assert_eq!(rd_data.map(F::from_canonical_u32), tester.read::<4>(1, a)); + assert_eq!(rd_data.map(F::from_canonical_u8), tester.read::<4>(1, a)); } /////////////////////////////////////////////////////////////////////////////////////// @@ -81,118 +125,98 @@ fn set_and_execute( /// passes all constraints. /////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn rand_jal_lui_test() { +#[test_case(JAL, 100)] +#[test_case(LUI, 100)] +fn rand_jal_lui_test(opcode: Rv32JalLuiOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let adapter = Rv32CondRdWriteAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let core = Rv32JalLuiCoreChip::new(bitwise_chip.clone()); - let mut chip = Rv32JalLuiChip::::new(adapter, core, tester.offline_memory_mutex_arc()); + let (mut harness, bitwise) = create_test_chip(&tester); - let num_tests: usize = 100; - for _ in 0..num_tests { - set_and_execute(&mut tester, &mut chip, &mut rng, JAL, None, None); - set_and_execute(&mut tester, &mut chip, &mut rng, LUI, None, None); + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut harness, &mut rng, opcode, None, None); } - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + let tester = tester + .build() + .load(harness) + .load_periphery(bitwise) + .finalize(); tester.simple_test().expect("Verification failed"); } ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adaptor is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// +#[derive(Clone, Copy, Default, PartialEq)] +struct JalLuiPrankValues { + pub rd_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + pub imm: Option, + pub is_jal: Option, + pub is_lui: Option, + pub needs_write: Option, +} + #[allow(clippy::too_many_arguments)] fn run_negative_jal_lui_test( opcode: Rv32JalLuiOpcode, initial_imm: Option, initial_pc: Option, - rd_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, - imm: Option, - is_jal: Option, - is_lui: Option, - needs_write: Option, - expected_error: VerificationError, + prank_vals: JalLuiPrankValues, + interaction_error: bool, ) { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let adapter = Rv32CondRdWriteAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let adapter_width = BaseAir::::width(adapter.air()); - let core = Rv32JalLuiCoreChip::new(bitwise_chip.clone()); - let mut chip = Rv32JalLuiChip::::new(adapter, core, tester.offline_memory_mutex_arc()); + let (mut harness, bitwise) = create_test_chip(&tester); set_and_execute( &mut tester, - &mut chip, + &mut harness, &mut rng, opcode, initial_imm, initial_pc, ); - let tester = tester.build(); - - let jal_lui_trace_width = chip.trace_width(); - let air = chip.air(); - let mut chip_input = chip.generate_air_proof_input(); - let jal_lui_trace = chip_input.raw.common_main.as_mut().unwrap(); - { - let mut trace_row = jal_lui_trace.row_slice(0).to_vec(); - + let adapter_width = BaseAir::::width(&harness.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { + let mut trace_row = trace.row_slice(0).to_vec(); let (adapter_row, core_row) = trace_row.split_at_mut(adapter_width); - let adapter_cols: &mut Rv32CondRdWriteAdapterCols = adapter_row.borrow_mut(); let core_cols: &mut Rv32JalLuiCoreCols = core_row.borrow_mut(); - if let Some(data) = rd_data { + if let Some(data) = prank_vals.rd_data { core_cols.rd_data = data.map(F::from_canonical_u32); } - - if let Some(imm) = imm { + if let Some(imm) = prank_vals.imm { core_cols.imm = if imm < 0 { F::NEG_ONE * F::from_canonical_u32((-imm) as u32) } else { F::from_canonical_u32(imm as u32) }; } - if let Some(is_jal) = is_jal { + if let Some(is_jal) = prank_vals.is_jal { core_cols.is_jal = F::from_bool(is_jal); } - if let Some(is_lui) = is_lui { + if let Some(is_lui) = prank_vals.is_lui { core_cols.is_lui = F::from_bool(is_lui); } - - if let Some(needs_write) = needs_write { + if let Some(needs_write) = prank_vals.needs_write { adapter_cols.needs_write = F::from_bool(needs_write); } - *jal_lui_trace = RowMajorMatrix::new(trace_row, jal_lui_trace_width); - } + *trace = RowMajorMatrix::new(trace_row, trace.width()); + }; disable_debug_builder(); let tester = tester - .load_air_proof_input((air, chip_input)) - .load(bitwise_chip) + .build() + .load_and_prank_trace(harness, modify_trace) + .load_periphery(bitwise) .finalize(); - tester.simple_test_with_expected_error(expected_error); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -201,34 +225,35 @@ fn opcode_flag_negative_test() { JAL, None, None, - None, - None, - Some(false), - Some(true), - None, - VerificationError::OodEvaluationMismatch, + JalLuiPrankValues { + is_jal: Some(false), + is_lui: Some(true), + ..Default::default() + }, + false, ); run_negative_jal_lui_test( JAL, None, None, - None, - None, - Some(false), - Some(false), - Some(false), - VerificationError::ChallengePhaseError, + JalLuiPrankValues { + is_jal: Some(false), + is_lui: Some(false), + needs_write: Some(false), + ..Default::default() + }, + true, ); run_negative_jal_lui_test( LUI, None, None, - None, - None, - Some(true), - Some(false), - None, - VerificationError::OodEvaluationMismatch, + JalLuiPrankValues { + is_jal: Some(true), + is_lui: Some(false), + ..Default::default() + }, + false, ); } @@ -238,67 +263,61 @@ fn overflow_negative_tests() { JAL, None, None, - Some([LIMB_MAX, LIMB_MAX, LIMB_MAX, LIMB_MAX]), - None, - None, - None, - None, - VerificationError::OodEvaluationMismatch, + JalLuiPrankValues { + rd_data: Some([LIMB_MAX, LIMB_MAX, LIMB_MAX, LIMB_MAX]), + ..Default::default() + }, + false, ); run_negative_jal_lui_test( LUI, None, None, - Some([LIMB_MAX, LIMB_MAX, LIMB_MAX, LIMB_MAX]), - None, - None, - None, - None, - VerificationError::OodEvaluationMismatch, + JalLuiPrankValues { + rd_data: Some([LIMB_MAX, LIMB_MAX, LIMB_MAX, LIMB_MAX]), + ..Default::default() + }, + false, ); run_negative_jal_lui_test( LUI, None, None, - Some([0, LIMB_MAX, LIMB_MAX, LIMB_MAX + 1]), - None, - None, - None, - None, - VerificationError::OodEvaluationMismatch, + JalLuiPrankValues { + rd_data: Some([0, LIMB_MAX, LIMB_MAX, LIMB_MAX + 1]), + ..Default::default() + }, + false, ); run_negative_jal_lui_test( LUI, None, None, - None, - Some(-1), - None, - None, - None, - VerificationError::OodEvaluationMismatch, + JalLuiPrankValues { + imm: Some(-1), + ..Default::default() + }, + false, ); run_negative_jal_lui_test( LUI, None, None, - None, - Some(-28), - None, - None, - None, - VerificationError::OodEvaluationMismatch, + JalLuiPrankValues { + imm: Some(-28), + ..Default::default() + }, + false, ); run_negative_jal_lui_test( JAL, None, Some(251), - Some([F::NEG_ONE.as_canonical_u32(), 1, 0, 0]), - None, - None, - None, - None, - VerificationError::ChallengePhaseError, + JalLuiPrankValues { + rd_data: Some([F::NEG_ONE.as_canonical_u32(), 1, 0, 0]), + ..Default::default() + }, + true, ); } @@ -307,29 +326,16 @@ fn overflow_negative_tests() { /// /// Ensure that solve functions produce the correct results. /////////////////////////////////////////////////////////////////////////////////////// + #[test] fn execute_roundtrip_sanity_test() { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let adapter = Rv32CondRdWriteAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let core = Rv32JalLuiCoreChip::new(bitwise_chip); - let mut chip = Rv32JalLuiChip::::new(adapter, core, tester.offline_memory_mutex_arc()); - let num_tests: usize = 10; - for _ in 0..num_tests { - set_and_execute(&mut tester, &mut chip, &mut rng, JAL, None, None); - set_and_execute(&mut tester, &mut chip, &mut rng, LUI, None, None); - } + let (mut harness, _) = create_test_chip(&tester); set_and_execute( &mut tester, - &mut chip, + &mut harness, &mut rng, LUI, Some((1 << IMM_BITS) - 1), @@ -337,7 +343,7 @@ fn execute_roundtrip_sanity_test() { ); set_and_execute( &mut tester, - &mut chip, + &mut harness, &mut rng, JAL, Some((1 << RV_IS_TYPE_IMM_BITS) - 1), @@ -347,20 +353,25 @@ fn execute_roundtrip_sanity_test() { #[test] fn run_jal_sanity_test() { - let opcode = JAL; let initial_pc = 28120; let imm = -2048; - let (next_pc, rd_data) = run_jal_lui(opcode, initial_pc, imm); + let (next_pc, rd_data) = run_jal_lui(true, initial_pc, imm); assert_eq!(next_pc, 26072); assert_eq!(rd_data, [220, 109, 0, 0]); } #[test] fn run_lui_sanity_test() { - let opcode = LUI; let initial_pc = 456789120; let imm = 853679; - let (next_pc, rd_data) = run_jal_lui(opcode, initial_pc, imm); + let (next_pc, rd_data) = run_jal_lui(false, initial_pc, imm); assert_eq!(next_pc, 456789124); assert_eq!(rd_data, [0, 240, 106, 208]); } + +#[test] +fn test_additional_bits() { + let last_limb_bits = PC_BITS - RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1); + let additional_bits = (last_limb_bits..RV32_CELL_BITS).fold(0, |acc, x| acc + (1u32 << x)); + assert_eq!(additional_bits, ADDITIONAL_BITS); +} diff --git a/extensions/rv32im/circuit/src/jalr/core.rs b/extensions/rv32im/circuit/src/jalr/core.rs index fd89c1e317..02ff8f3726 100644 --- a/extensions/rv32im/circuit/src/jalr/core.rs +++ b/extensions/rv32im/circuit/src/jalr/core.rs @@ -3,18 +3,23 @@ use std::{ borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, Result, SignedImmInstruction, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::*, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ instruction::Instruction, program::{DEFAULT_PC_STEP, PC_BITS}, + riscv::RV32_REGISTER_AS, LocalOpcode, }; use openvm_rv32im_transpiler::Rv32JalrOpcode::{self, *}; @@ -24,11 +29,10 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; -use crate::adapters::{compose, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; - -const RV32_LIMB_MAX: u32 = (1 << RV32_CELL_BITS) - 1; +use crate::adapters::{ + Rv32JalrAdapterExecutor, Rv32JalrAdapterFiller, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, +}; #[repr(C)] #[derive(Debug, Clone, AlignedBorrow)] @@ -46,18 +50,7 @@ pub struct Rv32JalrCoreCols { pub imm_sign: T, } -#[repr(C)] -#[derive(Serialize, Deserialize)] -pub struct Rv32JalrCoreRecord { - pub imm: F, - pub rs1_data: [F; RV32_REGISTER_NUM_LIMBS], - pub rd_data: [F; RV32_REGISTER_NUM_LIMBS - 1], - pub to_pc_least_sig_bit: F, - pub to_pc_limbs: [u32; 2], - pub imm_sign: F, -} - -#[derive(Debug, Clone)] +#[derive(Debug, Clone, derive_new::new)] pub struct Rv32JalrCoreAir { pub bitwise_lookup_bus: BitwiseOperationLookupBus, pub range_bus: VariableRangeCheckerBus, @@ -181,127 +174,281 @@ where } } -pub struct Rv32JalrCoreChip { - pub air: Rv32JalrCoreAir, +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32JalrCoreRecord { + pub imm: u16, + pub from_pc: u32, + pub rs1_val: u32, + pub imm_sign: bool, +} + +#[derive(Clone, Copy, derive_new::new)] +pub struct Rv32JalrExecutor { + adapter: A, +} + +#[derive(Clone)] +pub struct Rv32JalrFiller { + adapter: A, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, pub range_checker_chip: SharedVariableRangeCheckerChip, } -impl Rv32JalrCoreChip { +impl Rv32JalrFiller { pub fn new( + adapter: A, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, range_checker_chip: SharedVariableRangeCheckerChip, ) -> Self { assert!(range_checker_chip.range_max_bits() >= 16); Self { - air: Rv32JalrCoreAir { - bitwise_lookup_bus: bitwise_lookup_chip.bus(), - range_bus: range_checker_chip.bus(), - }, + adapter, bitwise_lookup_chip, range_checker_chip, } } } -impl> VmCoreChip for Rv32JalrCoreChip +impl PreflightExecutor for Rv32JalrExecutor where - I::Reads: Into<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>, - I::Writes: From<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + AdapterTraceExecutor< + F, + ReadData = [u8; RV32_REGISTER_NUM_LIMBS], + WriteData = [u8; RV32_REGISTER_NUM_LIMBS], + >, + for<'buf> RA: RecordArena< + 'buf, + EmptyAdapterCoreLayout, + (A::RecordMut<'buf>, &'buf mut Rv32JalrCoreRecord), + >, { - type Record = Rv32JalrCoreRecord; - type Air = Rv32JalrCoreAir; + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + Rv32JalrOpcode::from_usize(opcode - Rv32JalrOpcode::CLASS_OFFSET) + ) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, instruction: &Instruction, - from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + ) -> Result<(), ExecutionError> { let Instruction { opcode, c, g, .. } = *instruction; - let local_opcode = - Rv32JalrOpcode::from_usize(opcode.local_opcode_idx(Rv32JalrOpcode::CLASS_OFFSET)); - let imm = c.as_canonical_u32(); - let imm_sign = g.as_canonical_u32(); - let imm_extended = imm + imm_sign * 0xffff0000; + debug_assert_eq!( + opcode.local_opcode_idx(Rv32JalrOpcode::CLASS_OFFSET), + JALR as usize + ); - let rs1 = reads.into()[0]; - let rs1_val = compose(rs1); + let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new()); - let (to_pc, rd_data) = run_jalr(local_opcode, from_pc, imm_extended, rs1_val); + A::start(*state.pc, state.memory, &mut adapter_record); + core_record.rs1_val = u32::from_le_bytes(self.adapter.read( + state.memory, + instruction, + &mut adapter_record, + )); + + core_record.imm = c.as_canonical_u32() as u16; + core_record.imm_sign = g.is_one(); + core_record.from_pc = *state.pc; + + let (to_pc, rd_data) = run_jalr( + core_record.from_pc, + core_record.rs1_val, + core_record.imm, + core_record.imm_sign, + ); + + self.adapter + .write(state.memory, instruction, rd_data, &mut adapter_record); + + // RISC-V spec explicitly sets the least significant bit of `to_pc` to 0 + *state.pc = to_pc & !1; + + Ok(()) + } +} +impl TraceFiller for Rv32JalrFiller +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + self.adapter.fill_trace_row(mem_helper, adapter_row); + let record: &Rv32JalrCoreRecord = unsafe { get_record_from_slice(&mut core_row, ()) }; + + let core_row: &mut Rv32JalrCoreCols = core_row.borrow_mut(); + + let (to_pc, rd_data) = + run_jalr(record.from_pc, record.rs1_val, record.imm, record.imm_sign); + let to_pc_limbs = [(to_pc & ((1 << 16) - 1)) >> 1, to_pc >> 16]; + self.range_checker_chip.add_count(to_pc_limbs[0], 15); + self.range_checker_chip + .add_count(to_pc_limbs[1], PC_BITS - 16); self.bitwise_lookup_chip - .request_range(rd_data[0], rd_data[1]); + .request_range(rd_data[0] as u32, rd_data[1] as u32); + self.range_checker_chip - .add_count(rd_data[2], RV32_CELL_BITS); + .add_count(rd_data[2] as u32, RV32_CELL_BITS); self.range_checker_chip - .add_count(rd_data[3], PC_BITS - RV32_CELL_BITS * 3); - - let mask = (1 << 15) - 1; - let to_pc_least_sig_bit = rs1_val.wrapping_add(imm_extended) & 1; - - let to_pc_limbs = array::from_fn(|i| ((to_pc >> (1 + i * 15)) & mask)); + .add_count(rd_data[3] as u32, PC_BITS - RV32_CELL_BITS * 3); + + // Write in reverse order + core_row.imm_sign = F::from_bool(record.imm_sign); + core_row.to_pc_limbs = to_pc_limbs.map(F::from_canonical_u32); + core_row.to_pc_least_sig_bit = F::from_bool(to_pc & 1 == 1); + // fill_trace_row is called only on valid rows + core_row.is_valid = F::ONE; + core_row.rs1_data = record.rs1_val.to_le_bytes().map(F::from_canonical_u8); + core_row + .rd_data + .iter_mut() + .rev() + .zip(rd_data.iter().skip(1).rev()) + .for_each(|(dst, src)| { + *dst = F::from_canonical_u8(*src); + }); + core_row.imm = F::from_canonical_u16(record.imm); + } +} - let rd_data = rd_data.map(F::from_canonical_u32); +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct JalrPreCompute { + imm_extended: u32, + a: u8, + b: u8, +} - let output = AdapterRuntimeContext { - to_pc: Some(to_pc), - writes: [rd_data].into(), +impl Executor for Rv32JalrExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let data: &mut JalrPreCompute = data.borrow_mut(); + let enabled = self.pre_compute_impl(pc, inst, data)?; + let fn_ptr = if enabled { + execute_e1_impl::<_, _, true> + } else { + execute_e1_impl::<_, _, false> }; + Ok(fn_ptr) + } +} - Ok(( - output, - Rv32JalrCoreRecord { - imm: c, - rd_data: array::from_fn(|i| rd_data[i + 1]), - rs1_data: rs1, - to_pc_least_sig_bit: F::from_canonical_u32(to_pc_least_sig_bit), - to_pc_limbs, - imm_sign: g, - }, - )) +impl MeteredExecutor for Rv32JalrExecutor +where + F: PrimeField32, +{ + fn metered_pre_compute_size(&self) -> usize { + size_of::>() } - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - Rv32JalrOpcode::from_usize(opcode - Rv32JalrOpcode::CLASS_OFFSET) - ) + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let enabled = self.pre_compute_impl(pc, inst, &mut data.data)?; + let fn_ptr = if enabled { + execute_e2_impl::<_, _, true> + } else { + execute_e2_impl::<_, _, false> + }; + Ok(fn_ptr) } +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - self.range_checker_chip.add_count(record.to_pc_limbs[0], 15); - self.range_checker_chip.add_count(record.to_pc_limbs[1], 14); - - let core_cols: &mut Rv32JalrCoreCols = row_slice.borrow_mut(); - core_cols.imm = record.imm; - core_cols.rd_data = record.rd_data; - core_cols.rs1_data = record.rs1_data; - core_cols.to_pc_least_sig_bit = record.to_pc_least_sig_bit; - core_cols.to_pc_limbs = record.to_pc_limbs.map(F::from_canonical_u32); - core_cols.imm_sign = record.imm_sign; - core_cols.is_valid = F::ONE; +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &JalrPreCompute, + vm_state: &mut VmExecState, +) { + let rs1 = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let rs1 = u32::from_le_bytes(rs1); + let to_pc = rs1.wrapping_add(pre_compute.imm_extended); + let to_pc = to_pc - (to_pc & 1); + debug_assert!(to_pc < (1 << PC_BITS)); + let rd = (vm_state.pc + DEFAULT_PC_STEP).to_le_bytes(); + + if ENABLED { + vm_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd); } - fn air(&self) -> &Self::Air { - &self.air + vm_state.pc = to_pc; + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &JalrPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl Rv32JalrExecutor { + /// Return true if enabled. + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut JalrPreCompute, + ) -> Result { + let imm_extended = inst.c.as_canonical_u32() + inst.g.as_canonical_u32() * 0xffff0000; + if inst.d.as_canonical_u32() != RV32_REGISTER_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + *data = JalrPreCompute { + imm_extended, + a: inst.a.as_canonical_u32() as u8, + b: inst.b.as_canonical_u32() as u8, + }; + let enabled = !inst.f.is_zero(); + Ok(enabled) } } // returns (to_pc, rd_data) -pub(super) fn run_jalr( - _opcode: Rv32JalrOpcode, - pc: u32, - imm: u32, - rs1: u32, -) -> (u32, [u32; RV32_REGISTER_NUM_LIMBS]) { - let to_pc = rs1.wrapping_add(imm); - let to_pc = to_pc - (to_pc & 1); +#[inline(always)] +pub(super) fn run_jalr(pc: u32, rs1: u32, imm: u16, imm_sign: bool) -> (u32, [u8; 4]) { + let to_pc = rs1.wrapping_add(imm as u32 + (imm_sign as u32 * 0xffff0000)); assert!(to_pc < (1 << PC_BITS)); - ( - to_pc, - array::from_fn(|i: usize| ((pc + DEFAULT_PC_STEP) >> (RV32_CELL_BITS * i)) & RV32_LIMB_MAX), - ) + (to_pc, pc.wrapping_add(DEFAULT_PC_STEP).to_le_bytes()) } diff --git a/extensions/rv32im/circuit/src/jalr/mod.rs b/extensions/rv32im/circuit/src/jalr/mod.rs index 1d85dcbe4a..c5102f260f 100644 --- a/extensions/rv32im/circuit/src/jalr/mod.rs +++ b/extensions/rv32im/circuit/src/jalr/mod.rs @@ -1,6 +1,6 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; -use crate::adapters::Rv32JalrAdapterChip; +use crate::adapters::Rv32JalrAdapterAir; mod core; pub use core::*; @@ -8,4 +8,5 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32JalrChip = VmChipWrapper, Rv32JalrCoreChip>; +pub type Rv32JalrAir = VmAirWrapper; +pub type Rv32JalrChip = VmChipWrapper; diff --git a/extensions/rv32im/circuit/src/jalr/tests.rs b/extensions/rv32im/circuit/src/jalr/tests.rs index e22d97967f..f0ae372e7f 100644 --- a/extensions/rv32im/circuit/src/jalr/tests.rs +++ b/extensions/rv32im/circuit/src/jalr/tests.rs @@ -1,41 +1,83 @@ -use std::{array, borrow::BorrowMut}; +use std::{array, borrow::BorrowMut, sync::Arc}; -use openvm_circuit::arch::{ - testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - VmAdapterChip, -}; +use openvm_circuit::arch::testing::{TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, }; use openvm_instructions::{instruction::Instruction, program::PC_BITS, LocalOpcode}; use openvm_rv32im_transpiler::Rv32JalrOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{FieldAlgebra, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, utils::disable_debug_builder, - verifier::VerificationError, - Chip, ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; +use super::Rv32JalrCoreAir; use crate::{ - adapters::{compose, Rv32JalrAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, - jalr::{run_jalr, Rv32JalrChip, Rv32JalrCoreChip, Rv32JalrCoreCols}, + adapters::{ + compose, Rv32JalrAdapterAir, Rv32JalrAdapterExecutor, Rv32JalrAdapterFiller, + RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + }, + jalr::{run_jalr, Rv32JalrChip, Rv32JalrCoreCols, Rv32JalrExecutor}, + test_utils::get_verification_error, + Rv32JalrAir, Rv32JalrFiller, }; const IMM_BITS: usize = 16; +const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; +type Harness = TestChipHarness>; fn into_limbs(num: u32) -> [u32; 4] { array::from_fn(|i| (num >> (8 * i)) & 255) } +fn create_test_chip( + tester: &mut VmChipTestBuilder, +) -> ( + Harness, + ( + BitwiseOperationLookupAir, + SharedBitwiseOperationLookupChip, + ), +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + + let range_checker_chip = tester.range_checker().clone(); + + let air = Rv32JalrAir::new( + Rv32JalrAdapterAir::new(tester.memory_bridge(), tester.execution_bridge()), + Rv32JalrCoreAir::new(bitwise_bus, range_checker_chip.bus()), + ); + let executor = Rv32JalrExecutor::new(Rv32JalrAdapterExecutor); + let chip = Rv32JalrChip::::new( + Rv32JalrFiller::new( + Rv32JalrAdapterFiller::new(), + bitwise_chip.clone(), + range_checker_chip.clone(), + ), + tester.memory_helper(), + ); + + let harness = Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + + (harness, (bitwise_chip.air, bitwise_chip)) +} + #[allow(clippy::too_many_arguments)] fn set_and_execute( tester: &mut VmChipTestBuilder, - chip: &mut Rv32JalrChip, + harness: &mut Harness, rng: &mut StdRng, opcode: Rv32JalrOpcode, initial_imm: Option, @@ -45,7 +87,7 @@ fn set_and_execute( ) { let imm = initial_imm.unwrap_or(rng.gen_range(0..(1 << IMM_BITS))); let imm_sign = initial_imm_sign.unwrap_or(rng.gen_range(0..2)); - let imm_ext = imm + imm_sign * (0xffffffff ^ ((1 << IMM_BITS) - 1)); + let imm_ext = imm + (imm_sign * 0xffff0000); let a = rng.gen_range(0..32) << 2; let b = rng.gen_range(1..32) << 2; let to_pc = rng.gen_range(0..(1 << PC_BITS)); @@ -55,8 +97,9 @@ fn set_and_execute( tester.write(1, b, rs1); + let initial_pc = initial_pc.unwrap_or(rng.gen_range(0..(1 << PC_BITS))); tester.execute_with_pc( - chip, + harness, &Instruction::from_usize( opcode.global_opcode(), [ @@ -69,18 +112,17 @@ fn set_and_execute( imm_sign as usize, ], ), - initial_pc.unwrap_or(rng.gen_range(0..(1 << PC_BITS))), + initial_pc, ); - let initial_pc = tester.execution.last_from_pc().as_canonical_u32(); let final_pc = tester.execution.last_to_pc().as_canonical_u32(); let rs1 = compose(rs1); - let (next_pc, rd_data) = run_jalr(opcode, initial_pc, imm_ext, rs1); + let (next_pc, rd_data) = run_jalr(initial_pc, rs1, imm as u16, imm_sign == 1); let rd_data = if a == 0 { [0; 4] } else { rd_data }; - assert_eq!(next_pc, final_pc); - assert_eq!(rd_data.map(F::from_canonical_u32), tester.read::<4>(1, a)); + assert_eq!(next_pc & !1, final_pc); + assert_eq!(rd_data.map(F::from_canonical_u8), tester.read::<4>(1, a)); } /////////////////////////////////////////////////////////////////////////////////////// @@ -92,24 +134,14 @@ fn set_and_execute( #[test] fn rand_jalr_test() { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); + let (mut harness, bitwise) = create_test_chip(&mut tester); - let adapter = Rv32JalrAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let inner = Rv32JalrCoreChip::new(bitwise_chip.clone(), range_checker_chip.clone()); - let mut chip = Rv32JalrChip::::new(adapter, inner, tester.offline_memory_mutex_arc()); - - let num_tests: usize = 100; - for _ in 0..num_tests { + let num_ops = 100; + for _ in 0..num_ops { set_and_execute( &mut tester, - &mut chip, + &mut harness, &mut rng, JALR, None, @@ -119,8 +151,11 @@ fn rand_jalr_test() { ); } - drop(range_checker_chip); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + let tester = tester + .build() + .load(harness) + .load_periphery(bitwise) + .finalize(); tester.simple_test().expect("Verification failed"); } @@ -128,10 +163,18 @@ fn rand_jalr_test() { // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adaptor is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// +#[derive(Clone, Copy, Default, PartialEq)] +struct JalrPrankValues { + pub rd_data: Option<[u32; RV32_REGISTER_NUM_LIMBS - 1]>, + pub rs1_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + pub to_pc_least_sig_bit: Option, + pub to_pc_limbs: Option<[u32; 2]>, + pub imm_sign: Option, +} + #[allow(clippy::too_many_arguments)] fn run_negative_jalr_test( opcode: Rv32JalrOpcode, @@ -139,31 +182,17 @@ fn run_negative_jalr_test( initial_rs1: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, initial_imm: Option, initial_imm_sign: Option, - rd_data: Option<[u32; RV32_REGISTER_NUM_LIMBS - 1]>, - rs1_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, - to_pc_least_sig_bit: Option, - to_pc_limbs: Option<[u32; 2]>, - imm_sign: Option, - expected_error: VerificationError, + prank_vals: JalrPrankValues, + interaction_error: bool, ) { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let adapter = Rv32JalrAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let adapter_width = BaseAir::::width(adapter.air()); - let inner = Rv32JalrCoreChip::new(bitwise_chip.clone(), range_checker_chip.clone()); - let mut chip = Rv32JalrChip::::new(adapter, inner, tester.offline_memory_mutex_arc()); + let (mut harness, bitwise) = create_test_chip(&mut tester); set_and_execute( &mut tester, - &mut chip, + &mut harness, &mut rng, opcode, initial_imm, @@ -172,49 +201,38 @@ fn run_negative_jalr_test( initial_rs1, ); - let tester = tester.build(); - - let jalr_trace_width = chip.trace_width(); - let air = chip.air(); - let mut chip_input = chip.generate_air_proof_input(); - let jalr_trace = chip_input.raw.common_main.as_mut().unwrap(); - { - let mut trace_row = jalr_trace.row_slice(0).to_vec(); - + let adapter_width = BaseAir::::width(&harness.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { + let mut trace_row = trace.row_slice(0).to_vec(); let (_, core_row) = trace_row.split_at_mut(adapter_width); - let core_cols: &mut Rv32JalrCoreCols = core_row.borrow_mut(); - if let Some(data) = rd_data { + if let Some(data) = prank_vals.rd_data { core_cols.rd_data = data.map(F::from_canonical_u32); } - - if let Some(data) = rs1_data { + if let Some(data) = prank_vals.rs1_data { core_cols.rs1_data = data.map(F::from_canonical_u32); } - - if let Some(data) = to_pc_least_sig_bit { + if let Some(data) = prank_vals.to_pc_least_sig_bit { core_cols.to_pc_least_sig_bit = F::from_canonical_u32(data); } - - if let Some(data) = to_pc_limbs { + if let Some(data) = prank_vals.to_pc_limbs { core_cols.to_pc_limbs = data.map(F::from_canonical_u32); } - - if let Some(data) = imm_sign { + if let Some(data) = prank_vals.imm_sign { core_cols.imm_sign = F::from_canonical_u32(data); } - *jalr_trace = RowMajorMatrix::new(trace_row, jalr_trace_width); - } + *trace = RowMajorMatrix::new(trace_row, trace.width()); + }; - drop(range_checker_chip); disable_debug_builder(); let tester = tester - .load_air_proof_input((air, chip_input)) - .load(bitwise_chip) + .build() + .load_and_prank_trace(harness, modify_trace) + .load_periphery(bitwise) .finalize(); - tester.simple_test_with_expected_error(expected_error); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -225,12 +243,11 @@ fn invalid_cols_negative_tests() { None, Some(15362), Some(0), - None, - None, - None, - None, - Some(1), - VerificationError::OodEvaluationMismatch, + JalrPrankValues { + imm_sign: Some(1), + ..Default::default() + }, + false, ); run_negative_jalr_test( @@ -239,12 +256,11 @@ fn invalid_cols_negative_tests() { None, Some(15362), Some(1), - None, - None, - None, - None, - Some(0), - VerificationError::OodEvaluationMismatch, + JalrPrankValues { + imm_sign: Some(0), + ..Default::default() + }, + false, ); run_negative_jalr_test( @@ -253,12 +269,11 @@ fn invalid_cols_negative_tests() { Some([23, 154, 67, 28]), Some(42512), Some(1), - None, - None, - Some(0), - None, - None, - VerificationError::OodEvaluationMismatch, + JalrPrankValues { + to_pc_least_sig_bit: Some(0), + ..Default::default() + }, + false, ); } @@ -270,12 +285,11 @@ fn overflow_negative_tests() { None, None, None, - Some([1, 0, 0]), - None, - None, - None, - None, - VerificationError::ChallengePhaseError, + JalrPrankValues { + rd_data: Some([1, 0, 0]), + ..Default::default() + }, + true, ); run_negative_jalr_test( @@ -284,15 +298,14 @@ fn overflow_negative_tests() { Some([0, 0, 0, 0]), Some((1 << 15) - 2), Some(0), - None, - None, - None, - Some([ - (F::NEG_ONE * F::from_canonical_u32((1 << 14) + 1)).as_canonical_u32(), - 1, - ]), - None, - VerificationError::ChallengePhaseError, + JalrPrankValues { + to_pc_limbs: Some([ + (F::NEG_ONE * F::from_canonical_u32((1 << 14) + 1)).as_canonical_u32(), + 1, + ]), + ..Default::default() + }, + true, ); } @@ -301,44 +314,13 @@ fn overflow_negative_tests() { /// /// Ensure that solve functions produce the correct results. /////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn execute_roundtrip_sanity_test() { - let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - - let adapter = Rv32JalrAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let inner = Rv32JalrCoreChip::new(bitwise_chip, range_checker_chip); - let mut chip = Rv32JalrChip::::new(adapter, inner, tester.offline_memory_mutex_arc()); - - let num_tests: usize = 10; - for _ in 0..num_tests { - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - JALR, - None, - None, - None, - None, - ); - } -} #[test] fn run_jalr_sanity_test() { - let opcode = JALR; let initial_pc = 789456120; let imm = -1235_i32 as u32; let rs1 = 736482910; - let (next_pc, rd_data) = run_jalr(opcode, initial_pc, imm, rs1); - assert_eq!(next_pc, 736481674); + let (next_pc, rd_data) = run_jalr(initial_pc, rs1, imm as u16, true); + assert_eq!(next_pc & !1, 736481674); assert_eq!(rd_data, [252, 36, 14, 47]); } diff --git a/extensions/rv32im/circuit/src/less_than/core.rs b/extensions/rv32im/circuit/src/less_than/core.rs index a605dc43de..51d486d705 100644 --- a/extensions/rv32im/circuit/src/less_than/core.rs +++ b/extensions/rv32im/circuit/src/less_than/core.rs @@ -3,16 +3,25 @@ use std::{ borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::*, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_IMM_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; use openvm_rv32im_transpiler::LessThanOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -20,12 +29,12 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; +use crate::adapters::imm_to_bytes; + #[repr(C)] -#[derive(AlignedBorrow)] +#[derive(AlignedBorrow, Debug)] pub struct LessThanCoreCols { pub b: [T; NUM_LIMBS], pub c: [T; NUM_LIMBS], @@ -45,7 +54,7 @@ pub struct LessThanCoreCols { pub diff_val: T, } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, derive_new::new)] pub struct LessThanCoreAir { pub bus: BitwiseOperationLookupBus, offset: usize, @@ -164,162 +173,342 @@ where } #[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "T: Serialize + DeserializeOwned")] -pub struct LessThanCoreRecord { - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub c: [T; NUM_LIMBS], - pub cmp_result: T, - pub b_msb_f: T, - pub c_msb_f: T, - pub diff_val: T, - pub diff_idx: usize, - pub opcode: LessThanOpcode, +#[derive(AlignedBytesBorrow, Debug)] +pub struct LessThanCoreRecord { + pub b: [u8; NUM_LIMBS], + pub c: [u8; NUM_LIMBS], + pub local_opcode: u8, } -pub struct LessThanCoreChip { - pub air: LessThanCoreAir, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, +#[derive(Clone, Copy, derive_new::new)] +pub struct LessThanExecutor { + adapter: A, + pub offset: usize, } -impl LessThanCoreChip { - pub fn new( - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - offset: usize, - ) -> Self { - Self { - air: LessThanCoreAir { - bus: bitwise_lookup_chip.bus(), - offset, - }, - bitwise_lookup_chip, - } - } +#[derive(Clone, derive_new::new)] +pub struct LessThanFiller { + adapter: A, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pub offset: usize, } -impl, const NUM_LIMBS: usize, const LIMB_BITS: usize> - VmCoreChip for LessThanCoreChip +impl PreflightExecutor + for LessThanExecutor where - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: From<[[F; NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + AdapterTraceExecutor< + F, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + >, + for<'buf> RA: RecordArena< + 'buf, + EmptyAdapterCoreLayout, + ( + A::RecordMut<'buf>, + &'buf mut LessThanCoreRecord, + ), + >, { - type Record = LessThanCoreRecord; - type Air = LessThanCoreAir; + fn get_opcode_name(&self, opcode: usize) -> String { + format!("{:?}", LessThanOpcode::from_usize(opcode - self.offset)) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + ) -> Result<(), ExecutionError> { + debug_assert!(LIMB_BITS <= 8); let Instruction { opcode, .. } = instruction; - let less_than_opcode = LessThanOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let b = data[0].map(|x| x.as_canonical_u32()); - let c = data[1].map(|y| y.as_canonical_u32()); + let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new()); + A::start(*state.pc, state.memory, &mut adapter_record); + + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, &mut adapter_record) + .into(); + + core_record.b = rs1; + core_record.c = rs2; + core_record.local_opcode = opcode.local_opcode_idx(self.offset) as u8; + + let (cmp_result, _, _, _) = run_less_than::( + core_record.local_opcode == LessThanOpcode::SLT as u8, + &rs1, + &rs2, + ); + + let mut output = [0u8; NUM_LIMBS]; + output[0] = cmp_result as u8; + + self.adapter.write( + state.memory, + instruction, + [output].into(), + &mut adapter_record, + ); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } +} + +impl TraceFiller + for LessThanFiller +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + let record: &LessThanCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; + + let core_row: &mut LessThanCoreCols = core_row.borrow_mut(); + + let is_slt = record.local_opcode == LessThanOpcode::SLT as u8; let (cmp_result, diff_idx, b_sign, c_sign) = - run_less_than::(less_than_opcode, &b, &c); + run_less_than::(is_slt, &record.b, &record.c); // We range check (b_msb_f + 128) and (c_msb_f + 128) if signed, // b_msb_f and c_msb_f if not let (b_msb_f, b_msb_range) = if b_sign { ( - -F::from_canonical_u32((1 << LIMB_BITS) - b[NUM_LIMBS - 1]), - b[NUM_LIMBS - 1] - (1 << (LIMB_BITS - 1)), + -F::from_canonical_u16((1u16 << LIMB_BITS) - record.b[NUM_LIMBS - 1] as u16), + record.b[NUM_LIMBS - 1] - (1u8 << (LIMB_BITS - 1)), ) } else { ( - F::from_canonical_u32(b[NUM_LIMBS - 1]), - b[NUM_LIMBS - 1] - + (((less_than_opcode == LessThanOpcode::SLT) as u32) << (LIMB_BITS - 1)), + F::from_canonical_u8(record.b[NUM_LIMBS - 1]), + record.b[NUM_LIMBS - 1] + ((is_slt as u8) << (LIMB_BITS - 1)), ) }; let (c_msb_f, c_msb_range) = if c_sign { ( - -F::from_canonical_u32((1 << LIMB_BITS) - c[NUM_LIMBS - 1]), - c[NUM_LIMBS - 1] - (1 << (LIMB_BITS - 1)), + -F::from_canonical_u16((1u16 << LIMB_BITS) - record.c[NUM_LIMBS - 1] as u16), + record.c[NUM_LIMBS - 1] - (1u8 << (LIMB_BITS - 1)), ) } else { ( - F::from_canonical_u32(c[NUM_LIMBS - 1]), - c[NUM_LIMBS - 1] - + (((less_than_opcode == LessThanOpcode::SLT) as u32) << (LIMB_BITS - 1)), + F::from_canonical_u8(record.c[NUM_LIMBS - 1]), + record.c[NUM_LIMBS - 1] + ((is_slt as u8) << (LIMB_BITS - 1)), ) }; - self.bitwise_lookup_chip - .request_range(b_msb_range, c_msb_range); - let diff_val = if diff_idx == NUM_LIMBS { - 0 + core_row.diff_val = if diff_idx == NUM_LIMBS { + F::ZERO } else if diff_idx == (NUM_LIMBS - 1) { if cmp_result { c_msb_f - b_msb_f } else { b_msb_f - c_msb_f } - .as_canonical_u32() } else if cmp_result { - c[diff_idx] - b[diff_idx] + F::from_canonical_u8(record.c[diff_idx] - record.b[diff_idx]) } else { - b[diff_idx] - c[diff_idx] + F::from_canonical_u8(record.b[diff_idx] - record.c[diff_idx]) }; + self.bitwise_lookup_chip + .request_range(b_msb_range as u32, c_msb_range as u32); + + core_row.diff_marker = [F::ZERO; NUM_LIMBS]; if diff_idx != NUM_LIMBS { - self.bitwise_lookup_chip.request_range(diff_val - 1, 0); + self.bitwise_lookup_chip + .request_range(core_row.diff_val.as_canonical_u32() - 1, 0); + core_row.diff_marker[diff_idx] = F::ONE; } - let mut writes = [0u32; NUM_LIMBS]; - writes[0] = cmp_result as u32; - - let output = AdapterRuntimeContext::without_pc([writes.map(F::from_canonical_u32)]); - let record = LessThanCoreRecord { - opcode: less_than_opcode, - b: data[0], - c: data[1], - cmp_result: F::from_bool(cmp_result), - b_msb_f, - c_msb_f, - diff_val: F::from_canonical_u32(diff_val), - diff_idx, - }; + core_row.c_msb_f = c_msb_f; + core_row.b_msb_f = b_msb_f; + core_row.opcode_sltu_flag = F::from_bool(!is_slt); + core_row.opcode_slt_flag = F::from_bool(is_slt); + core_row.cmp_result = F::from_bool(cmp_result); + core_row.c = record.c.map(F::from_canonical_u8); + core_row.b = record.b.map(F::from_canonical_u8); + } +} - Ok((output, record)) +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct LessThanPreCompute { + c: u32, + a: u8, + b: u8, +} + +impl Executor + for LessThanExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() } - fn get_opcode_name(&self, opcode: usize) -> String { - format!("{:?}", LessThanOpcode::from_usize(opcode - self.air.offset)) + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut LessThanPreCompute = data.borrow_mut(); + let (is_imm, is_sltu) = self.pre_compute_impl(pc, inst, pre_compute)?; + let fn_ptr = match (is_imm, is_sltu) { + (true, true) => execute_e1_impl::<_, _, true, true>, + (true, false) => execute_e1_impl::<_, _, true, false>, + (false, true) => execute_e1_impl::<_, _, false, true>, + (false, false) => execute_e1_impl::<_, _, false, false>, + }; + Ok(fn_ptr) } +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut LessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut(); - row_slice.b = record.b; - row_slice.c = record.c; - row_slice.cmp_result = record.cmp_result; - row_slice.b_msb_f = record.b_msb_f; - row_slice.c_msb_f = record.c_msb_f; - row_slice.diff_val = record.diff_val; - row_slice.opcode_slt_flag = F::from_bool(record.opcode == LessThanOpcode::SLT); - row_slice.opcode_sltu_flag = F::from_bool(record.opcode == LessThanOpcode::SLTU); - row_slice.diff_marker = array::from_fn(|i| F::from_bool(i == record.diff_idx)); +impl MeteredExecutor + for LessThanExecutor +where + F: PrimeField32, +{ + fn metered_pre_compute_size(&self) -> usize { + size_of::>() } - fn air(&self) -> &Self::Air { - &self.air + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + let (is_imm, is_sltu) = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + let fn_ptr = match (is_imm, is_sltu) { + (true, true) => execute_e2_impl::<_, _, true, true>, + (true, false) => execute_e2_impl::<_, _, true, false>, + (false, true) => execute_e2_impl::<_, _, false, true>, + (false, false) => execute_e2_impl::<_, _, false, false>, + }; + Ok(fn_ptr) + } +} + +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const E_IS_IMM: bool, + const IS_U32: bool, +>( + pre_compute: &LessThanPreCompute, + vm_state: &mut VmExecState, +) { + let rs1 = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let rs2 = if E_IS_IMM { + pre_compute.c.to_le_bytes() + } else { + vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.c) + }; + let cmp_result = if IS_U32 { + u32::from_le_bytes(rs1) < u32::from_le_bytes(rs2) + } else { + i32::from_le_bytes(rs1) < i32::from_le_bytes(rs2) + }; + let mut rd = [0u8; RV32_REGISTER_NUM_LIMBS]; + rd[0] = cmp_result as u8; + vm_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd); + + vm_state.pc += DEFAULT_PC_STEP; + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const E_IS_IMM: bool, + const IS_U32: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &LessThanPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} +unsafe fn execute_e2_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const E_IS_IMM: bool, + const IS_U32: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl LessThanExecutor { + #[inline(always)] + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut LessThanPreCompute, + ) -> Result<(bool, bool), StaticProgramError> { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS + || !(e_u32 == RV32_IMM_AS || e_u32 == RV32_REGISTER_AS) + { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + let local_opcode = LessThanOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + let is_imm = e_u32 == RV32_IMM_AS; + let c_u32 = c.as_canonical_u32(); + + *data = LessThanPreCompute { + c: if is_imm { + u32::from_le_bytes(imm_to_bytes(c_u32)) + } else { + c_u32 + }, + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + }; + Ok((is_imm, local_opcode == LessThanOpcode::SLTU)) } } // Returns (cmp_result, diff_idx, x_sign, y_sign) +#[inline(always)] pub(super) fn run_less_than( - opcode: LessThanOpcode, - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], + is_slt: bool, + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], ) -> (bool, usize, bool, bool) { - let x_sign = (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && opcode == LessThanOpcode::SLT; - let y_sign = (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && opcode == LessThanOpcode::SLT; + let x_sign = (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && is_slt; + let y_sign = (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && is_slt; for i in (0..NUM_LIMBS).rev() { if x[i] != y[i] { return ((x[i] < y[i]) ^ x_sign ^ y_sign, i, x_sign, y_sign); diff --git a/extensions/rv32im/circuit/src/less_than/mod.rs b/extensions/rv32im/circuit/src/less_than/mod.rs index f8247d2d33..9819af55f7 100644 --- a/extensions/rv32im/circuit/src/less_than/mod.rs +++ b/extensions/rv32im/circuit/src/less_than/mod.rs @@ -1,6 +1,9 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; -use super::adapters::{Rv32BaseAluAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use super::adapters::{ + Rv32BaseAluAdapterAir, Rv32BaseAluAdapterExecutor, Rv32BaseAluAdapterFiller, RV32_CELL_BITS, + RV32_REGISTER_NUM_LIMBS, +}; mod core; pub use core::*; @@ -8,8 +11,18 @@ pub use core::*; #[cfg(test)] mod tests; +pub type Rv32LessThanAir = + VmAirWrapper>; +pub type Rv32LessThanExecutor = LessThanExecutor< + Rv32BaseAluAdapterExecutor, + RV32_REGISTER_NUM_LIMBS, + RV32_CELL_BITS, +>; pub type Rv32LessThanChip = VmChipWrapper< F, - Rv32BaseAluAdapterChip, - LessThanCoreChip, + LessThanFiller< + Rv32BaseAluAdapterFiller, + RV32_REGISTER_NUM_LIMBS, + RV32_CELL_BITS, + >, >; diff --git a/extensions/rv32im/circuit/src/less_than/tests.rs b/extensions/rv32im/circuit/src/less_than/tests.rs index 18d64bf5f6..c23ac9aba1 100644 --- a/extensions/rv32im/circuit/src/less_than/tests.rs +++ b/extensions/rv32im/circuit/src/less_than/tests.rs @@ -1,17 +1,15 @@ -use std::borrow::BorrowMut; +use std::{array, borrow::BorrowMut, sync::Arc}; use openvm_circuit::{ - arch::{ - testing::{TestAdapterChip, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - ExecutionBridge, VmAdapterChip, VmChipWrapper, - }, - utils::{generate_long_number, i32_to_f}, + arch::testing::{TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, + utils::i32_to_f, }; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, }; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_rv32im_transpiler::LessThanOpcode; +use openvm_instructions::LocalOpcode; +use openvm_rv32im_transpiler::LessThanOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{FieldAlgebra, PrimeField32}, @@ -20,20 +18,105 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::Rng; +use rand::{rngs::StdRng, Rng}; +use test_case::test_case; -use super::{core::run_less_than, LessThanCoreChip, Rv32LessThanChip}; +use super::{core::run_less_than, LessThanCoreAir, Rv32LessThanChip}; use crate::{ - adapters::{Rv32BaseAluAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, + adapters::{ + Rv32BaseAluAdapterAir, Rv32BaseAluAdapterExecutor, Rv32BaseAluAdapterFiller, + RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + }, less_than::LessThanCoreCols, - test_utils::{generate_rv32_is_type_immediate, rv32_rand_write_register_or_imm}, + test_utils::{ + generate_rv32_is_type_immediate, get_verification_error, rv32_rand_write_register_or_imm, + }, + LessThanFiller, Rv32LessThanAir, Rv32LessThanExecutor, }; type F = BabyBear; +const MAX_INS_CAPACITY: usize = 128; +type Harness = TestChipHarness>; + +fn create_test_chip( + tester: &VmChipTestBuilder, +) -> ( + Harness, + ( + BitwiseOperationLookupAir, + SharedBitwiseOperationLookupChip, + ), +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + let air = Rv32LessThanAir::new( + Rv32BaseAluAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + ), + LessThanCoreAir::new(bitwise_bus, LessThanOpcode::CLASS_OFFSET), + ); + let executor = + Rv32LessThanExecutor::new(Rv32BaseAluAdapterExecutor, LessThanOpcode::CLASS_OFFSET); + let chip = Rv32LessThanChip::::new( + LessThanFiller::new( + Rv32BaseAluAdapterFiller::new(bitwise_chip.clone()), + bitwise_chip.clone(), + LessThanOpcode::CLASS_OFFSET, + ), + tester.memory_helper(), + ); + + let harness = Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + (harness, (bitwise_chip.air, bitwise_chip)) +} + +#[allow(clippy::too_many_arguments)] +fn set_and_execute( + tester: &mut VmChipTestBuilder, + harness: &mut Harness, + rng: &mut StdRng, + opcode: LessThanOpcode, + b: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + is_imm: Option, + c: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, +) { + let b = b.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); + let (c_imm, c) = if is_imm.unwrap_or(rng.gen_bool(0.5)) { + let (imm, c) = if let Some(c) = c { + ((u32::from_le_bytes(c) & 0xFFFFFF) as usize, c) + } else { + generate_rv32_is_type_immediate(rng) + }; + (Some(imm), c) + } else { + ( + None, + c.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))), + ) + }; + + let (instruction, rd) = rv32_rand_write_register_or_imm( + tester, + b, + c, + c_imm, + opcode.global_opcode().as_usize(), + rng, + ); + tester.execute(harness, &instruction); + + let (cmp, _, _, _) = + run_less_than::(opcode == SLT, &b, &c); + let mut a = [F::ZERO; RV32_REGISTER_NUM_LIMBS]; + a[0] = F::from_bool(cmp); + assert_eq!(a, tester.read::(1, rd)); +} ////////////////////////////////////////////////////////////////////////////////////// // POSITIVE TESTS @@ -42,100 +125,63 @@ type F = BabyBear; // passes all constraints. ////////////////////////////////////////////////////////////////////////////////////// +#[test_case(SLT, 100)] +#[test_case(SLTU, 100)] fn run_rv32_lt_rand_test(opcode: LessThanOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32LessThanChip::::new( - Rv32BaseAluAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - bitwise_chip.clone(), - ), - LessThanCoreChip::new(bitwise_chip.clone(), LessThanOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); + let (mut harness, bitwise) = create_test_chip(&tester); for _ in 0..num_ops { - let b = generate_long_number::(&mut rng); - let (c_imm, c) = if rng.gen_bool(0.5) { - ( - None, - generate_long_number::(&mut rng), - ) - } else { - let (imm, c) = generate_rv32_is_type_immediate(&mut rng); - (Some(imm), c) - }; - - let (instruction, rd) = rv32_rand_write_register_or_imm( + set_and_execute( &mut tester, - b, - c, - c_imm, - opcode.global_opcode().as_usize(), + &mut harness, &mut rng, + opcode, + None, + None, + None, ); - tester.execute(&mut chip, &instruction); - - let (cmp, _, _, _) = - run_less_than::(opcode, &b, &c); - let mut a = [F::ZERO; RV32_REGISTER_NUM_LIMBS]; - a[0] = F::from_bool(cmp); - assert_eq!(a, tester.read::(1, rd)); } // Test special case where b = c let b = [101, 128, 202, 255]; - let (instruction, _) = rv32_rand_write_register_or_imm( + set_and_execute( &mut tester, - b, - b, - None, - opcode.global_opcode().as_usize(), + &mut harness, &mut rng, + opcode, + Some(b), + Some(false), + Some(b), ); - tester.execute(&mut chip, &instruction); let b = [36, 0, 0, 0]; - let (instruction, _) = rv32_rand_write_register_or_imm( + set_and_execute( &mut tester, - b, - b, - Some(36), - opcode.global_opcode().as_usize(), + &mut harness, &mut rng, + opcode, + Some(b), + Some(true), + Some(b), ); - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + let tester = tester + .build() + .load(harness) + .load_periphery(bitwise) + .finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn rv32_slt_rand_test() { - run_rv32_lt_rand_test(LessThanOpcode::SLT, 100); -} - -#[test] -fn rv32_sltu_rand_test() { - run_rv32_lt_rand_test(LessThanOpcode::SLTU, 100); -} - ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32LessThanTestChip = - VmChipWrapper, LessThanCoreChip>; - #[derive(Clone, Copy, Default, PartialEq)] struct LessThanPrankValues { pub b_msb: Option, @@ -145,67 +191,29 @@ struct LessThanPrankValues { } #[allow(clippy::too_many_arguments)] -fn run_rv32_lt_negative_test( +fn run_negative_less_than_test( opcode: LessThanOpcode, - b: [u32; RV32_REGISTER_NUM_LIMBS], - c: [u32; RV32_REGISTER_NUM_LIMBS], - cmp_result: bool, + b: [u8; RV32_REGISTER_NUM_LIMBS], + c: [u8; RV32_REGISTER_NUM_LIMBS], + prank_cmp_result: bool, prank_vals: LessThanPrankValues, interaction_error: bool, ) { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - + let mut rng = create_seeded_rng(); let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = Rv32LessThanTestChip::::new( - TestAdapterChip::new( - vec![[b.map(F::from_canonical_u32), c.map(F::from_canonical_u32)].concat()], - vec![None], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - LessThanCoreChip::new(bitwise_chip.clone(), LessThanOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); + let (mut harness, bitwise) = create_test_chip(&tester); - tester.execute( - &mut chip, - &Instruction::from_usize(opcode.global_opcode(), [0, 0, 0, 1, 1]), + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + opcode, + Some(b), + Some(false), + Some(c), ); - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); - let (_, _, b_sign, c_sign) = - run_less_than::(opcode, &b, &c); - - if prank_vals != LessThanPrankValues::default() { - debug_assert!(prank_vals.diff_val.is_some()); - let b_msb = prank_vals.b_msb.unwrap_or( - b[RV32_REGISTER_NUM_LIMBS - 1] as i32 - if b_sign { 1 << RV32_CELL_BITS } else { 0 }, - ); - let c_msb = prank_vals.c_msb.unwrap_or( - c[RV32_REGISTER_NUM_LIMBS - 1] as i32 - if c_sign { 1 << RV32_CELL_BITS } else { 0 }, - ); - let sign_offset = if opcode == LessThanOpcode::SLT { - 1 << (RV32_CELL_BITS - 1) - } else { - 0 - }; - - bitwise_chip.clear(); - bitwise_chip.request_range( - (b_msb + sign_offset) as u8 as u32, - (c_msb + sign_offset) as u8 as u32, - ); - - let diff_val = prank_vals - .diff_val - .unwrap() - .clamp(0, (1 << RV32_CELL_BITS) - 1); - if diff_val > 0 { - bitwise_chip.request_range(diff_val - 1, 0); - } - }; - + let adapter_width = BaseAir::::width(&harness.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); let cols: &mut LessThanCoreCols = @@ -223,22 +231,18 @@ fn run_rv32_lt_negative_test( if let Some(diff_val) = prank_vals.diff_val { cols.diff_val = F::from_canonical_u32(diff_val); } - cols.cmp_result = F::from_bool(cmp_result); + cols.cmp_result = F::from_bool(prank_cmp_result); - *trace = RowMajorMatrix::new(values, trace_width); + *trace = RowMajorMatrix::new(values, trace.width()); }; disable_debug_builder(); let tester = tester .build() - .load_and_prank_trace(chip, modify_trace) - .load(bitwise_chip) + .load_and_prank_trace(harness, modify_trace) + .load_periphery(bitwise) .finalize(); - tester.simple_test_with_expected_error(if interaction_error { - VerificationError::ChallengePhaseError - } else { - VerificationError::OodEvaluationMismatch - }); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -246,8 +250,8 @@ fn rv32_lt_wrong_false_cmp_negative_test() { let b = [145, 34, 25, 205]; let c = [73, 35, 25, 205]; let prank_vals = Default::default(); - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, false, prank_vals, false); - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, false, prank_vals, false); + run_negative_less_than_test(SLT, b, c, false, prank_vals, false); + run_negative_less_than_test(SLTU, b, c, false, prank_vals, false); } #[test] @@ -255,8 +259,8 @@ fn rv32_lt_wrong_true_cmp_negative_test() { let b = [73, 35, 25, 205]; let c = [145, 34, 25, 205]; let prank_vals = Default::default(); - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, true, prank_vals, false); - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, true, prank_vals, false); + run_negative_less_than_test(SLT, b, c, true, prank_vals, false); + run_negative_less_than_test(SLTU, b, c, true, prank_vals, false); } #[test] @@ -264,8 +268,8 @@ fn rv32_lt_wrong_eq_negative_test() { let b = [73, 35, 25, 205]; let c = [73, 35, 25, 205]; let prank_vals = Default::default(); - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, true, prank_vals, false); - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, true, prank_vals, false); + run_negative_less_than_test(SLT, b, c, true, prank_vals, false); + run_negative_less_than_test(SLTU, b, c, true, prank_vals, false); } #[test] @@ -276,8 +280,8 @@ fn rv32_lt_fake_diff_val_negative_test() { diff_val: Some(F::NEG_ONE.as_canonical_u32()), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, false, prank_vals, true); - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, false, prank_vals, true); + run_negative_less_than_test(SLT, b, c, false, prank_vals, true); + run_negative_less_than_test(SLTU, b, c, false, prank_vals, true); } #[test] @@ -289,8 +293,8 @@ fn rv32_lt_zero_diff_val_negative_test() { diff_val: Some(0), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, false, prank_vals, true); - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, false, prank_vals, true); + run_negative_less_than_test(SLT, b, c, false, prank_vals, true); + run_negative_less_than_test(SLTU, b, c, false, prank_vals, true); } #[test] @@ -302,8 +306,8 @@ fn rv32_lt_fake_diff_marker_negative_test() { diff_val: Some(72), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, false, prank_vals, false); - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, false, prank_vals, false); + run_negative_less_than_test(SLT, b, c, false, prank_vals, false); + run_negative_less_than_test(SLTU, b, c, false, prank_vals, false); } #[test] @@ -315,8 +319,8 @@ fn rv32_lt_zero_diff_marker_negative_test() { diff_val: Some(0), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, false, prank_vals, false); - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, false, prank_vals, false); + run_negative_less_than_test(SLT, b, c, false, prank_vals, false); + run_negative_less_than_test(SLTU, b, c, false, prank_vals, false); } #[test] @@ -329,7 +333,7 @@ fn rv32_slt_wrong_b_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, false, prank_vals, false); + run_negative_less_than_test(SLT, b, c, false, prank_vals, false); } #[test] @@ -342,7 +346,7 @@ fn rv32_slt_wrong_b_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, false, prank_vals, true); + run_negative_less_than_test(SLT, b, c, false, prank_vals, true); } #[test] @@ -355,7 +359,7 @@ fn rv32_slt_wrong_c_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, true, prank_vals, false); + run_negative_less_than_test(SLT, b, c, true, prank_vals, false); } #[test] @@ -368,7 +372,7 @@ fn rv32_slt_wrong_c_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, true, prank_vals, true); + run_negative_less_than_test(SLT, b, c, true, prank_vals, true); } #[test] @@ -381,7 +385,7 @@ fn rv32_sltu_wrong_b_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, true, prank_vals, false); + run_negative_less_than_test(SLTU, b, c, true, prank_vals, false); } #[test] @@ -394,7 +398,7 @@ fn rv32_sltu_wrong_b_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, true, prank_vals, true); + run_negative_less_than_test(SLTU, b, c, true, prank_vals, true); } #[test] @@ -407,7 +411,7 @@ fn rv32_sltu_wrong_c_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, false, prank_vals, false); + run_negative_less_than_test(SLTU, b, c, false, prank_vals, false); } #[test] @@ -420,7 +424,7 @@ fn rv32_sltu_wrong_c_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, false, prank_vals, true); + run_negative_less_than_test(SLTU, b, c, false, prank_vals, true); } /////////////////////////////////////////////////////////////////////////////////////// @@ -431,10 +435,10 @@ fn rv32_sltu_wrong_c_msb_sign_negative_test() { #[test] fn run_sltu_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; let (cmp_result, diff_idx, x_sign, y_sign) = - run_less_than::(LessThanOpcode::SLTU, &x, &y); + run_less_than::(false, &x, &y); assert!(cmp_result); assert_eq!(diff_idx, 1); assert!(!x_sign); // unsigned @@ -443,10 +447,10 @@ fn run_sltu_sanity_test() { #[test] fn run_slt_same_sign_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; let (cmp_result, diff_idx, x_sign, y_sign) = - run_less_than::(LessThanOpcode::SLT, &x, &y); + run_less_than::(true, &x, &y); assert!(cmp_result); assert_eq!(diff_idx, 1); assert!(x_sign); // negative @@ -455,10 +459,10 @@ fn run_slt_same_sign_sanity_test() { #[test] fn run_slt_diff_sign_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [173, 34, 25, 205]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [173, 34, 25, 205]; let (cmp_result, diff_idx, x_sign, y_sign) = - run_less_than::(LessThanOpcode::SLT, &x, &y); + run_less_than::(true, &x, &y); assert!(!cmp_result); assert_eq!(diff_idx, 3); assert!(!x_sign); // positive @@ -467,9 +471,9 @@ fn run_slt_diff_sign_sanity_test() { #[test] fn run_less_than_equal_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; let (cmp_result, diff_idx, x_sign, y_sign) = - run_less_than::(LessThanOpcode::SLT, &x, &x); + run_less_than::(true, &x, &x); assert!(!cmp_result); assert_eq!(diff_idx, RV32_REGISTER_NUM_LIMBS); assert!(!x_sign); // positive diff --git a/extensions/rv32im/circuit/src/lib.rs b/extensions/rv32im/circuit/src/lib.rs index 2006b27038..d8c34deec5 100644 --- a/extensions/rv32im/circuit/src/lib.rs +++ b/extensions/rv32im/circuit/src/lib.rs @@ -1,5 +1,20 @@ -pub mod adapters; +use openvm_circuit::{ + arch::{ + AirInventory, ChipInventoryError, InitFileGenerator, MatrixRecordArena, SystemConfig, + VmBuilder, VmChipComplex, VmProverExtension, + }, + system::{SystemChipInventory, SystemCpuBuilder, SystemExecutor}, +}; +use openvm_circuit_derive::{Executor, PreflightExecutor, VmConfig}; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + engine::StarkEngine, + p3_field::PrimeField32, + prover::cpu::{CpuBackend, CpuDevice}, +}; +use serde::{Deserialize, Serialize}; +pub mod adapters; mod auipc; mod base_alu; mod branch_eq; @@ -35,3 +50,140 @@ pub use extension::*; #[cfg(any(test, feature = "test-utils"))] mod test_utils; + +// Config for a VM with base extension and IO extension +#[derive(Clone, Debug, derive_new::new, VmConfig, Serialize, Deserialize)] +pub struct Rv32IConfig { + #[config(executor = "SystemExecutor")] + pub system: SystemConfig, + #[extension] + pub base: Rv32I, + #[extension] + pub io: Rv32Io, +} + +// Default implementation uses no init file +impl InitFileGenerator for Rv32IConfig {} + +/// Config for a VM with base extension, IO extension, and multiplication extension +#[derive(Clone, Debug, Default, VmConfig, derive_new::new, Serialize, Deserialize)] +pub struct Rv32ImConfig { + #[config] + pub rv32i: Rv32IConfig, + #[extension] + pub mul: Rv32M, +} + +// Default implementation uses no init file +impl InitFileGenerator for Rv32ImConfig {} + +impl Default for Rv32IConfig { + fn default() -> Self { + let system = SystemConfig::default().with_continuations(); + Self { + system, + base: Default::default(), + io: Default::default(), + } + } +} + +impl Rv32IConfig { + pub fn with_public_values(public_values: usize) -> Self { + let system = SystemConfig::default() + .with_continuations() + .with_public_values(public_values); + Self { + system, + base: Default::default(), + io: Default::default(), + } + } + + pub fn with_public_values_and_segment_len(public_values: usize, segment_len: usize) -> Self { + let system = SystemConfig::default() + .with_continuations() + .with_public_values(public_values) + .with_max_segment_len(segment_len); + Self { + system, + base: Default::default(), + io: Default::default(), + } + } +} + +impl Rv32ImConfig { + pub fn with_public_values(public_values: usize) -> Self { + Self { + rv32i: Rv32IConfig::with_public_values(public_values), + mul: Default::default(), + } + } + + pub fn with_public_values_and_segment_len(public_values: usize, segment_len: usize) -> Self { + Self { + rv32i: Rv32IConfig::with_public_values_and_segment_len(public_values, segment_len), + mul: Default::default(), + } + } +} + +#[derive(Clone)] +pub struct Rv32ICpuBuilder; + +impl VmBuilder for Rv32ICpuBuilder +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + Val: PrimeField32, +{ + type VmConfig = Rv32IConfig; + type SystemChipInventory = SystemChipInventory; + type RecordArena = MatrixRecordArena>; + + fn create_chip_complex( + &self, + config: &Rv32IConfig, + circuit: AirInventory, + ) -> Result< + VmChipComplex, + ChipInventoryError, + > { + let mut chip_complex = + VmBuilder::::create_chip_complex(&SystemCpuBuilder, &config.system, circuit)?; + let inventory = &mut chip_complex.inventory; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.base, inventory)?; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.io, inventory)?; + Ok(chip_complex) + } +} + +#[derive(Clone)] +pub struct Rv32ImCpuBuilder; + +impl VmBuilder for Rv32ImCpuBuilder +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + Val: PrimeField32, +{ + type VmConfig = Rv32ImConfig; + type SystemChipInventory = SystemChipInventory; + type RecordArena = MatrixRecordArena>; + + fn create_chip_complex( + &self, + config: &Self::VmConfig, + circuit: AirInventory, + ) -> Result< + VmChipComplex, + ChipInventoryError, + > { + let mut chip_complex = + VmBuilder::::create_chip_complex(&Rv32ICpuBuilder, &config.rv32i, circuit)?; + let inventory = &mut chip_complex.inventory; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.mul, inventory)?; + Ok(chip_complex) + } +} diff --git a/extensions/rv32im/circuit/src/load_sign_extend/core.rs b/extensions/rv32im/circuit/src/load_sign_extend/core.rs index 2284d6815c..4c0f0a87f3 100644 --- a/extensions/rv32im/circuit/src/load_sign_extend/core.rs +++ b/extensions/rv32im/circuit/src/load_sign_extend/core.rs @@ -3,15 +3,25 @@ use std::{ borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, Result, VmAdapterInterface, VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::*, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, POINTER_MAX_BITS, + }, }; use openvm_circuit_primitives::{ utils::select, var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_CELL_BITS, RV32_IMM_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *}; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -19,10 +29,8 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_big_array::BigArray; -use crate::adapters::LoadStoreInstruction; +use crate::adapters::{LoadStoreInstruction, Rv32LoadStoreAdapterFiller}; /// LoadSignExtend Core Chip handles byte/halfword into word conversions through sign extend /// This chip uses read_data to construct write_data @@ -46,20 +54,7 @@ pub struct LoadSignExtendCoreCols { pub prev_data: [T; NUM_CELLS], } -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Serialize + DeserializeOwned")] -pub struct LoadSignExtendCoreRecord { - #[serde(with = "BigArray")] - pub shifted_read_data: [F; NUM_CELLS], - #[serde(with = "BigArray")] - pub prev_data: [F; NUM_CELLS], - pub opcode: Rv32LoadStoreOpcode, - pub shift_amount: u32, - pub most_sig_bit: bool, -} - -#[derive(Debug, Clone)] +#[derive(Debug, Clone, derive_new::new)] pub struct LoadSignExtendCoreAir { pub range_bus: VariableRangeCheckerBus, } @@ -178,135 +173,354 @@ where } } -pub struct LoadSignExtendCoreChip { - pub air: LoadSignExtendCoreAir, - pub range_checker_chip: SharedVariableRangeCheckerChip, +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct LoadSignExtendCoreRecord { + pub is_byte: bool, + pub shift_amount: u8, + pub read_data: [u8; NUM_CELLS], + pub prev_data: [u8; NUM_CELLS], } -impl LoadSignExtendCoreChip { - pub fn new(range_checker_chip: SharedVariableRangeCheckerChip) -> Self { - Self { - air: LoadSignExtendCoreAir:: { - range_bus: range_checker_chip.bus(), - }, - range_checker_chip, - } - } +#[derive(Clone, Copy, derive_new::new)] +pub struct LoadSignExtendExecutor { + adapter: A, +} + +#[derive(Clone, derive_new::new)] +pub struct LoadSignExtendFiller< + A = Rv32LoadStoreAdapterFiller, + const NUM_CELLS: usize = RV32_REGISTER_NUM_LIMBS, + const LIMB_BITS: usize = RV32_CELL_BITS, +> { + adapter: A, + pub range_checker_chip: SharedVariableRangeCheckerChip, } -impl, const NUM_CELLS: usize, const LIMB_BITS: usize> - VmCoreChip for LoadSignExtendCoreChip +impl PreflightExecutor + for LoadSignExtendExecutor where - I::Reads: Into<([[F; NUM_CELLS]; 2], F)>, - I::Writes: From<[[F; NUM_CELLS]; 1]>, + F: PrimeField32, + A: 'static + + AdapterTraceExecutor< + F, + ReadData = (([u32; NUM_CELLS], [u8; NUM_CELLS]), u8), + WriteData = [u32; NUM_CELLS], + >, + for<'buf> RA: RecordArena< + 'buf, + EmptyAdapterCoreLayout, + ( + A::RecordMut<'buf>, + &'buf mut LoadSignExtendCoreRecord, + ), + >, { - type Record = LoadSignExtendCoreRecord; - type Air = LoadSignExtendCoreAir; + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + Rv32LoadStoreOpcode::from_usize(opcode - Rv32LoadStoreOpcode::CLASS_OFFSET) + ) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + ) -> Result<(), ExecutionError> { + let Instruction { opcode, .. } = instruction; + let local_opcode = Rv32LoadStoreOpcode::from_usize( - instruction - .opcode - .local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET), + opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET), ); - let (data, shift_amount) = reads.into(); - let shift_amount = shift_amount.as_canonical_u32(); - let write_data: [F; NUM_CELLS] = run_write_data_sign_extend::<_, NUM_CELLS, LIMB_BITS>( + let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + let tmp = self + .adapter + .read(state.memory, instruction, &mut adapter_record); + + core_record.is_byte = local_opcode == LOADB; + core_record.prev_data = tmp.0 .0.map(|x| x as u8); + core_record.read_data = tmp.0 .1; + core_record.shift_amount = tmp.1; + + let write_data = run_write_data_sign_extend( local_opcode, - data[1], - data[0], - shift_amount, + core_record.read_data, + core_record.shift_amount as usize, ); - let output = AdapterRuntimeContext::without_pc([write_data]); - let most_sig_limb = match local_opcode { - LOADB => write_data[0], - LOADH => write_data[NUM_CELLS / 2 - 1], - _ => unreachable!(), - } - .as_canonical_u32(); + self.adapter.write( + state.memory, + instruction, + write_data.map(u32::from), + &mut adapter_record, + ); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } +} + +impl TraceFiller + for LoadSignExtendFiller +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + let record: &LoadSignExtendCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; + + let core_row: &mut LoadSignExtendCoreCols = core_row.borrow_mut(); + + let shift = record.shift_amount; + let most_sig_limb = if record.is_byte { + record.read_data[shift as usize] + } else { + record.read_data[NUM_CELLS / 2 - 1 + shift as usize] + }; - let most_sig_bit = most_sig_limb & (1 << (LIMB_BITS - 1)); + let most_sig_bit = most_sig_limb & (1 << 7); self.range_checker_chip - .add_count(most_sig_limb - most_sig_bit, LIMB_BITS - 1); - - let read_shift = shift_amount & 2; - - Ok(( - output, - LoadSignExtendCoreRecord { - opcode: local_opcode, - most_sig_bit: most_sig_bit != 0, - prev_data: data[0], - shifted_read_data: array::from_fn(|i| { - data[1][(i + read_shift as usize) % NUM_CELLS] - }), - shift_amount, - }, - )) + .add_count((most_sig_limb - most_sig_bit) as u32, 7); + + core_row.prev_data = record.prev_data.map(F::from_canonical_u8); + core_row.shifted_read_data = record.read_data.map(F::from_canonical_u8); + core_row.shifted_read_data.rotate_left((shift & 2) as usize); + + core_row.data_most_sig_bit = F::from_bool(most_sig_bit != 0); + core_row.shift_most_sig_bit = F::from_bool(shift & 2 == 2); + core_row.opcode_loadh_flag = F::from_bool(!record.is_byte); + core_row.opcode_loadb_flag1 = F::from_bool(record.is_byte && ((shift & 1) == 1)); + core_row.opcode_loadb_flag0 = F::from_bool(record.is_byte && ((shift & 1) == 0)); } +} - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - Rv32LoadStoreOpcode::from_usize(opcode - Rv32LoadStoreOpcode::CLASS_OFFSET) - ) +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct LoadSignExtendPreCompute { + imm_extended: u32, + a: u8, + b: u8, + e: u8, +} + +impl Executor + for LoadSignExtendExecutor +where + F: PrimeField32, +{ + fn pre_compute_size(&self) -> usize { + size_of::() } - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let core_cols: &mut LoadSignExtendCoreCols = row_slice.borrow_mut(); - let opcode = record.opcode; - let shift = record.shift_amount; - core_cols.opcode_loadb_flag0 = F::from_bool(opcode == LOADB && (shift & 1) == 0); - core_cols.opcode_loadb_flag1 = F::from_bool(opcode == LOADB && (shift & 1) == 1); - core_cols.opcode_loadh_flag = F::from_bool(opcode == LOADH); - core_cols.shift_most_sig_bit = F::from_canonical_u32((shift & 2) >> 1); - core_cols.data_most_sig_bit = F::from_bool(record.most_sig_bit); - core_cols.prev_data = record.prev_data; - core_cols.shifted_read_data = record.shifted_read_data; + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut LoadSignExtendPreCompute = data.borrow_mut(); + let (is_loadb, enabled) = self.pre_compute_impl(pc, inst, pre_compute)?; + let fn_ptr = match (is_loadb, enabled) { + (true, true) => execute_e1_impl::<_, _, true, true>, + (true, false) => execute_e1_impl::<_, _, true, false>, + (false, true) => execute_e1_impl::<_, _, false, true>, + (false, false) => execute_e1_impl::<_, _, false, false>, + }; + Ok(fn_ptr) } +} - fn air(&self) -> &Self::Air { - &self.air +impl MeteredExecutor + for LoadSignExtendExecutor +where + F: PrimeField32, +{ + fn metered_pre_compute_size(&self) -> usize { + size_of::>() + } + + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + let (is_loadb, enabled) = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + let fn_ptr = match (is_loadb, enabled) { + (true, true) => execute_e2_impl::<_, _, true, true>, + (true, false) => execute_e2_impl::<_, _, true, false>, + (false, true) => execute_e2_impl::<_, _, false, true>, + (false, false) => execute_e2_impl::<_, _, false, false>, + }; + Ok(fn_ptr) } } -pub(super) fn run_write_data_sign_extend< +#[inline(always)] +unsafe fn execute_e12_impl< F: PrimeField32, - const NUM_CELLS: usize, - const LIMB_BITS: usize, + CTX: E1ExecutionCtx, + const IS_LOADB: bool, + const ENABLED: bool, >( + pre_compute: &LoadSignExtendPreCompute, + vm_state: &mut VmExecState, +) { + let rs1_bytes: [u8; RV32_REGISTER_NUM_LIMBS] = + vm_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32); + let rs1_val = u32::from_le_bytes(rs1_bytes); + let ptr_val = rs1_val.wrapping_add(pre_compute.imm_extended); + // sign_extend([r32{c,g}(b):2]_e)` + debug_assert!(ptr_val < (1 << POINTER_MAX_BITS)); + let shift_amount = ptr_val % 4; + let ptr_val = ptr_val - shift_amount; // aligned ptr + + let read_data: [u8; RV32_REGISTER_NUM_LIMBS] = vm_state.vm_read(pre_compute.e as u32, ptr_val); + + let write_data = if IS_LOADB { + let byte = read_data[shift_amount as usize]; + let sign_extended = (byte as i8) as i32; + sign_extended.to_le_bytes() + } else { + if shift_amount != 0 && shift_amount != 2 { + vm_state.exit_code = Err(ExecutionError::Fail { + pc: vm_state.pc, + msg: "LoadSignExtend invalid shift amount", + }); + return; + } + let half: [u8; 2] = array::from_fn(|i| read_data[shift_amount as usize + i]); + (i16::from_le_bytes(half) as i32).to_le_bytes() + }; + + if ENABLED { + vm_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &write_data); + } + + vm_state.pc += DEFAULT_PC_STEP; + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const IS_LOADB: bool, + const ENABLED: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &LoadSignExtendPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const IS_LOADB: bool, + const ENABLED: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl LoadSignExtendExecutor { + /// Return (is_loadb, enabled) + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut LoadSignExtendPreCompute, + ) -> Result<(bool, bool), StaticProgramError> { + let Instruction { + opcode, + a, + b, + c, + d, + e, + f, + g, + .. + } = inst; + + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 == RV32_IMM_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + let local_opcode = Rv32LoadStoreOpcode::from_usize( + opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET), + ); + match local_opcode { + LOADB | LOADH => {} + _ => unreachable!("LoadSignExtendExecutor should only handle LOADB/LOADH opcodes"), + } + + let imm = c.as_canonical_u32(); + let imm_sign = g.as_canonical_u32(); + let imm_extended = imm + imm_sign * 0xffff0000; + + *data = LoadSignExtendPreCompute { + imm_extended, + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + e: e_u32 as u8, + }; + let enabled = !f.is_zero(); + Ok((local_opcode == LOADB, enabled)) + } +} + +// Returns write_data +#[inline(always)] +pub(super) fn run_write_data_sign_extend( opcode: Rv32LoadStoreOpcode, - read_data: [F; NUM_CELLS], - _prev_data: [F; NUM_CELLS], - shift: u32, -) -> [F; NUM_CELLS] { - let shift = shift as usize; - let mut write_data = read_data; + read_data: [u8; NUM_CELLS], + shift: usize, +) -> [u8; NUM_CELLS] { match (opcode, shift) { (LOADH, 0) | (LOADH, 2) => { - let ext = read_data[NUM_CELLS / 2 - 1 + shift].as_canonical_u32(); - let ext = (ext >> (LIMB_BITS - 1)) * ((1 << LIMB_BITS) - 1); - for cell in write_data.iter_mut().take(NUM_CELLS).skip(NUM_CELLS / 2) { - *cell = F::from_canonical_u32(ext); - } - write_data[0..NUM_CELLS / 2] - .copy_from_slice(&read_data[shift..(NUM_CELLS / 2 + shift)]); + let ext = (read_data[NUM_CELLS / 2 - 1 + shift] >> 7) * u8::MAX; + array::from_fn(|i| { + if i < NUM_CELLS / 2 { + read_data[i + shift] + } else { + ext + } + }) } (LOADB, 0) | (LOADB, 1) | (LOADB, 2) | (LOADB, 3) => { - let ext = read_data[shift].as_canonical_u32(); - let ext = (ext >> (LIMB_BITS - 1)) * ((1 << LIMB_BITS) - 1); - for cell in write_data.iter_mut().take(NUM_CELLS).skip(1) { - *cell = F::from_canonical_u32(ext); - } - write_data[0] = read_data[shift]; + let ext = (read_data[shift] >> 7) * u8::MAX; + array::from_fn(|i| { + if i == 0 { + read_data[i + shift] + } else { + ext + } + }) } // Currently the adapter AIR requires `ptr_val` to be aligned to the data size in bytes. // The circuit requires that `shift = ptr_val % 4` so that `ptr_val - shift` is a multiple of 4. @@ -314,6 +528,5 @@ pub(super) fn run_write_data_sign_extend< _ => unreachable!( "unaligned memory access not supported by this execution environment: {opcode:?}, shift: {shift}" ), - }; - write_data + } } diff --git a/extensions/rv32im/circuit/src/load_sign_extend/mod.rs b/extensions/rv32im/circuit/src/load_sign_extend/mod.rs index 79efbe912e..978e875ba0 100644 --- a/extensions/rv32im/circuit/src/load_sign_extend/mod.rs +++ b/extensions/rv32im/circuit/src/load_sign_extend/mod.rs @@ -1,7 +1,7 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; use super::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; -use crate::adapters::Rv32LoadStoreAdapterChip; +use crate::adapters::{Rv32LoadStoreAdapterAir, Rv32LoadStoreAdapterExecutor}; mod core; pub use core::*; @@ -9,8 +9,10 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32LoadSignExtendChip = VmChipWrapper< - F, - Rv32LoadStoreAdapterChip, - LoadSignExtendCoreChip, +pub type Rv32LoadSignExtendAir = VmAirWrapper< + Rv32LoadStoreAdapterAir, + LoadSignExtendCoreAir, >; +pub type Rv32LoadSignExtendExecutor = + LoadSignExtendExecutor; +pub type Rv32LoadSignExtendChip = VmChipWrapper; diff --git a/extensions/rv32im/circuit/src/load_sign_extend/tests.rs b/extensions/rv32im/circuit/src/load_sign_extend/tests.rs index 0fe6d859d1..39c1f378ae 100644 --- a/extensions/rv32im/circuit/src/load_sign_extend/tests.rs +++ b/extensions/rv32im/circuit/src/load_sign_extend/tests.rs @@ -1,9 +1,6 @@ use std::{array, borrow::BorrowMut}; -use openvm_circuit::arch::{ - testing::{memory::gen_pointer, VmChipTestBuilder}, - VmAdapterChip, -}; +use openvm_circuit::arch::testing::{memory::gen_pointer, TestChipHarness, VmChipTestBuilder}; use openvm_instructions::{instruction::Instruction, LocalOpcode}; use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *}; use openvm_stark_backend::{ @@ -14,82 +11,104 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, }; -use openvm_stark_sdk::{config::setup_tracing, p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; +use test_case::test_case; -use super::run_write_data_sign_extend; +use super::{run_write_data_sign_extend, LoadSignExtendCoreAir}; use crate::{ - adapters::{compose, Rv32LoadStoreAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, + adapters::{ + Rv32LoadStoreAdapterAir, Rv32LoadStoreAdapterExecutor, Rv32LoadStoreAdapterFiller, + RV32_REGISTER_NUM_LIMBS, + }, load_sign_extend::LoadSignExtendCoreCols, - LoadSignExtendCoreChip, Rv32LoadSignExtendChip, + test_utils::get_verification_error, + LoadSignExtendFiller, Rv32LoadSignExtendAir, Rv32LoadSignExtendChip, + Rv32LoadSignExtendExecutor, }; const IMM_BITS: usize = 16; - +const MAX_INS_CAPACITY: usize = 128; +type Harness = TestChipHarness< + F, + Rv32LoadSignExtendExecutor, + Rv32LoadSignExtendAir, + Rv32LoadSignExtendChip, +>; type F = BabyBear; -fn into_limbs(num: u32) -> [u32; NUM_LIMBS] { - array::from_fn(|i| (num >> (LIMB_BITS * i)) & ((1 << LIMB_BITS) - 1)) +fn create_test_chip(tester: &mut VmChipTestBuilder) -> Harness { + let range_checker_chip = tester.range_checker().clone(); + let air = Rv32LoadSignExtendAir::new( + Rv32LoadStoreAdapterAir::new( + tester.memory_bridge(), + tester.execution_bridge(), + range_checker_chip.bus(), + tester.address_bits(), + ), + LoadSignExtendCoreAir::new(range_checker_chip.bus()), + ); + let executor = + Rv32LoadSignExtendExecutor::new(Rv32LoadStoreAdapterExecutor::new(tester.address_bits())); + let chip = Rv32LoadSignExtendChip::::new( + LoadSignExtendFiller::new( + Rv32LoadStoreAdapterFiller::new(tester.address_bits(), range_checker_chip.clone()), + range_checker_chip.clone(), + ), + tester.memory_helper(), + ); + + Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY) } #[allow(clippy::too_many_arguments)] fn set_and_execute( tester: &mut VmChipTestBuilder, - chip: &mut Rv32LoadSignExtendChip, + harness: &mut Harness, rng: &mut StdRng, opcode: Rv32LoadStoreOpcode, - read_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, - rs1: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + read_data: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + rs1: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, imm: Option, imm_sign: Option, ) { let imm = imm.unwrap_or(rng.gen_range(0..(1 << IMM_BITS))); let imm_sign = imm_sign.unwrap_or(rng.gen_range(0..2)); - let imm_ext = imm + imm_sign * (0xffffffff ^ ((1 << IMM_BITS) - 1)); + let imm_ext = imm + imm_sign * (0xffff0000); let alignment = match opcode { LOADB => 0, LOADH => 1, _ => unreachable!(), }; - let ptr_val = rng.gen_range( - 0..(1 - << (tester - .memory_controller() - .borrow() - .mem_config() - .pointer_max_bits - - alignment)), - ) << alignment; - - let rs1 = rs1 - .unwrap_or(into_limbs::( - (ptr_val as u32).wrapping_sub(imm_ext), - )) - .map(F::from_canonical_u32); + + let ptr_val: u32 = rng.gen_range(0..(1 << (tester.address_bits() - alignment))) << alignment; + let rs1 = rs1.unwrap_or(ptr_val.wrapping_sub(imm_ext).to_le_bytes()); + let ptr_val = imm_ext.wrapping_add(u32::from_le_bytes(rs1)); let a = gen_pointer(rng, 4); let b = gen_pointer(rng, 4); - let ptr_val = imm_ext.wrapping_add(compose(rs1)); let shift_amount = ptr_val % 4; - tester.write(1, b, rs1); + tester.write(1, b, rs1.map(F::from_canonical_u8)); let some_prev_data: [F; RV32_REGISTER_NUM_LIMBS] = if a != 0 { - array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..(1 << RV32_CELL_BITS)))) + array::from_fn(|_| F::from_canonical_u8(rng.gen())) } else { [F::ZERO; RV32_REGISTER_NUM_LIMBS] }; - let read_data: [F; RV32_REGISTER_NUM_LIMBS] = read_data - .unwrap_or(array::from_fn(|_| rng.gen_range(0..(1 << RV32_CELL_BITS)))) - .map(F::from_canonical_u32); + let read_data: [u8; RV32_REGISTER_NUM_LIMBS] = + read_data.unwrap_or(array::from_fn(|_| rng.gen())); tester.write(1, a, some_prev_data); - tester.write(2, (ptr_val - shift_amount) as usize, read_data); + tester.write( + 2, + (ptr_val - shift_amount) as usize, + read_data.map(F::from_canonical_u8), + ); tester.execute( - chip, + harness, &Instruction::from_usize( opcode.global_opcode(), [ @@ -104,16 +123,11 @@ fn set_and_execute( ), ); - let write_data = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - opcode, - read_data, - some_prev_data, - shift_amount, - ); + let write_data = run_write_data_sign_extend(opcode, read_data, shift_amount as usize); if a != 0 { - assert_eq!(write_data, tester.read::<4>(1, a)); + assert_eq!(write_data.map(F::from_canonical_u8), tester.read::<4>(1, a)); } else { - assert_eq!([F::ZERO; RV32_REGISTER_NUM_LIMBS], tester.read::<4>(1, a)); + assert_eq!([F::ZERO; 4], tester.read::<4>(1, a)); } } @@ -123,40 +137,19 @@ fn set_and_execute( /// Randomly generate computations and execute, ensuring that the generated trace /// passes all constraints. /////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn rand_load_sign_extend_test() { - setup_tracing(); +#[test_case(LOADB, 100)] +#[test_case(LOADH, 100)] +fn rand_load_sign_extend_test(opcode: Rv32LoadStoreOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let adapter = Rv32LoadStoreAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - range_checker_chip.clone(), - ); - let core = LoadSignExtendCoreChip::new(range_checker_chip); - let mut chip = - Rv32LoadSignExtendChip::::new(adapter, core, tester.offline_memory_mutex_arc()); - let num_tests: usize = 100; - for _ in 0..num_tests { + let mut harness = create_test_chip(&mut tester); + for _ in 0..num_ops { set_and_execute( &mut tester, - &mut chip, + &mut harness, &mut rng, - LOADB, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADH, + opcode, None, None, None, @@ -164,7 +157,7 @@ fn rand_load_sign_extend_test() { ); } - let tester = tester.build().load(chip).finalize(); + let tester = tester.build().load(harness).finalize(); tester.simple_test().expect("Verification failed"); } @@ -172,40 +165,33 @@ fn rand_load_sign_extend_test() { // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adaptor is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -#[allow(clippy::too_many_arguments)] -fn run_negative_loadstore_test( - opcode: Rv32LoadStoreOpcode, - read_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, +#[derive(Clone, Copy, Default, PartialEq)] +struct LoadSignExtPrankValues { data_most_sig_bit: Option, shift_most_sig_bit: Option, opcode_flags: Option<[bool; 3]>, - rs1: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, +} + +#[allow(clippy::too_many_arguments)] +fn run_negative_load_sign_extend_test( + opcode: Rv32LoadStoreOpcode, + read_data: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + rs1: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, imm: Option, imm_sign: Option, - expected_error: VerificationError, + prank_vals: LoadSignExtPrankValues, + interaction_error: bool, ) { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let adapter = Rv32LoadStoreAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - range_checker_chip.clone(), - ); - let core = LoadSignExtendCoreChip::new(range_checker_chip.clone()); - let adapter_width = BaseAir::::width(adapter.air()); - let mut chip = - Rv32LoadSignExtendChip::::new(adapter, core, tester.offline_memory_mutex_arc()); + let mut harness = create_test_chip(&mut tester); set_and_execute( &mut tester, - &mut chip, + &mut harness, &mut rng, opcode, read_data, @@ -214,78 +200,78 @@ fn run_negative_loadstore_test( imm_sign, ); + let adapter_width = BaseAir::::width(&harness.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut trace_row = trace.row_slice(0).to_vec(); - let (_, core_row) = trace_row.split_at_mut(adapter_width); let core_cols: &mut LoadSignExtendCoreCols = core_row.borrow_mut(); - if let Some(shifted_read_data) = read_data { - core_cols.shifted_read_data = shifted_read_data.map(F::from_canonical_u32); + core_cols.shifted_read_data = shifted_read_data.map(F::from_canonical_u8); } - - if let Some(data_most_sig_bit) = data_most_sig_bit { + if let Some(data_most_sig_bit) = prank_vals.data_most_sig_bit { core_cols.data_most_sig_bit = F::from_canonical_u32(data_most_sig_bit); } - if let Some(shift_most_sig_bit) = shift_most_sig_bit { + if let Some(shift_most_sig_bit) = prank_vals.shift_most_sig_bit { core_cols.shift_most_sig_bit = F::from_canonical_u32(shift_most_sig_bit); } - - if let Some(opcode_flags) = opcode_flags { + if let Some(opcode_flags) = prank_vals.opcode_flags { core_cols.opcode_loadb_flag0 = F::from_bool(opcode_flags[0]); core_cols.opcode_loadb_flag1 = F::from_bool(opcode_flags[1]); core_cols.opcode_loadh_flag = F::from_bool(opcode_flags[2]); } + *trace = RowMajorMatrix::new(trace_row, trace.width()); }; - drop(range_checker_chip); disable_debug_builder(); let tester = tester .build() - .load_and_prank_trace(chip, modify_trace) + .load_and_prank_trace(harness, modify_trace) .finalize(); - tester.simple_test_with_expected_error(expected_error); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] fn loadstore_negative_tests() { - run_negative_loadstore_test( + run_negative_load_sign_extend_test( LOADB, Some([233, 187, 145, 238]), - Some(0), None, None, None, - None, - None, - VerificationError::ChallengePhaseError, + LoadSignExtPrankValues { + data_most_sig_bit: Some(0), + ..Default::default() + }, + true, ); - run_negative_loadstore_test( + run_negative_load_sign_extend_test( LOADH, None, - None, - Some(0), - None, Some([202, 109, 183, 26]), Some(31212), None, - VerificationError::ChallengePhaseError, + LoadSignExtPrankValues { + shift_most_sig_bit: Some(0), + ..Default::default() + }, + true, ); - run_negative_loadstore_test( + run_negative_load_sign_extend_test( LOADB, None, - None, - None, - Some([true, false, false]), Some([250, 132, 77, 5]), Some(47741), None, - VerificationError::ChallengePhaseError, + LoadSignExtPrankValues { + opcode_flags: Some([true, false, false]), + ..Default::default() + }, + true, ); } @@ -294,119 +280,51 @@ fn loadstore_negative_tests() { /// /// Ensure that solve functions produce the correct results. /////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn execute_roundtrip_sanity_test() { - let mut rng = create_seeded_rng(); - let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let adapter = Rv32LoadStoreAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - range_checker_chip.clone(), - ); - let core = LoadSignExtendCoreChip::new(range_checker_chip); - let mut chip = - Rv32LoadSignExtendChip::::new(adapter, core, tester.offline_memory_mutex_arc()); - - let num_tests: usize = 10; - for _ in 0..num_tests { - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADB, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADH, - None, - None, - None, - None, - ); - } -} #[test] fn solve_loadh_extend_sign_sanity_test() { - let read_data = [34, 159, 237, 151].map(F::from_canonical_u32); - let prev_data = [94, 183, 56, 241].map(F::from_canonical_u32); - let write_data0 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADH, read_data, prev_data, 0, - ); - let write_data2 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADH, read_data, prev_data, 2, - ); + let read_data = [34, 159, 237, 151]; + let write_data0 = run_write_data_sign_extend::(LOADH, read_data, 0); + let write_data2 = run_write_data_sign_extend::(LOADH, read_data, 2); - assert_eq!(write_data0, [34, 159, 255, 255].map(F::from_canonical_u32)); - assert_eq!(write_data2, [237, 151, 255, 255].map(F::from_canonical_u32)); + assert_eq!(write_data0, [34, 159, 255, 255]); + assert_eq!(write_data2, [237, 151, 255, 255]); } #[test] fn solve_loadh_extend_zero_sanity_test() { - let read_data = [34, 121, 237, 97].map(F::from_canonical_u32); - let prev_data = [94, 183, 56, 241].map(F::from_canonical_u32); - let write_data0 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADH, read_data, prev_data, 0, - ); - let write_data2 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADH, read_data, prev_data, 2, - ); + let read_data = [34, 121, 237, 97]; + let write_data0 = run_write_data_sign_extend::(LOADH, read_data, 0); + let write_data2 = run_write_data_sign_extend::(LOADH, read_data, 2); - assert_eq!(write_data0, [34, 121, 0, 0].map(F::from_canonical_u32)); - assert_eq!(write_data2, [237, 97, 0, 0].map(F::from_canonical_u32)); + assert_eq!(write_data0, [34, 121, 0, 0]); + assert_eq!(write_data2, [237, 97, 0, 0]); } #[test] fn solve_loadb_extend_sign_sanity_test() { - let read_data = [45, 82, 99, 127].map(F::from_canonical_u32); - let prev_data = [53, 180, 29, 244].map(F::from_canonical_u32); - let write_data0 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADB, read_data, prev_data, 0, - ); - let write_data1 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADB, read_data, prev_data, 1, - ); - let write_data2 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADB, read_data, prev_data, 2, - ); - let write_data3 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADB, read_data, prev_data, 3, - ); - - assert_eq!(write_data0, [45, 0, 0, 0].map(F::from_canonical_u32)); - assert_eq!(write_data1, [82, 0, 0, 0].map(F::from_canonical_u32)); - assert_eq!(write_data2, [99, 0, 0, 0].map(F::from_canonical_u32)); - assert_eq!(write_data3, [127, 0, 0, 0].map(F::from_canonical_u32)); + let read_data = [45, 82, 99, 127]; + let write_data0 = run_write_data_sign_extend::(LOADB, read_data, 0); + let write_data1 = run_write_data_sign_extend::(LOADB, read_data, 1); + let write_data2 = run_write_data_sign_extend::(LOADB, read_data, 2); + let write_data3 = run_write_data_sign_extend::(LOADB, read_data, 3); + + assert_eq!(write_data0, [45, 0, 0, 0]); + assert_eq!(write_data1, [82, 0, 0, 0]); + assert_eq!(write_data2, [99, 0, 0, 0]); + assert_eq!(write_data3, [127, 0, 0, 0]); } #[test] fn solve_loadb_extend_zero_sanity_test() { - let read_data = [173, 210, 227, 255].map(F::from_canonical_u32); - let prev_data = [53, 180, 29, 244].map(F::from_canonical_u32); - let write_data0 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADB, read_data, prev_data, 0, - ); - let write_data1 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADB, read_data, prev_data, 1, - ); - let write_data2 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADB, read_data, prev_data, 2, - ); - let write_data3 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADB, read_data, prev_data, 3, - ); - - assert_eq!(write_data0, [173, 255, 255, 255].map(F::from_canonical_u32)); - assert_eq!(write_data1, [210, 255, 255, 255].map(F::from_canonical_u32)); - assert_eq!(write_data2, [227, 255, 255, 255].map(F::from_canonical_u32)); - assert_eq!(write_data3, [255, 255, 255, 255].map(F::from_canonical_u32)); + let read_data = [173, 210, 227, 255]; + let write_data0 = run_write_data_sign_extend::(LOADB, read_data, 0); + let write_data1 = run_write_data_sign_extend::(LOADB, read_data, 1); + let write_data2 = run_write_data_sign_extend::(LOADB, read_data, 2); + let write_data3 = run_write_data_sign_extend::(LOADB, read_data, 3); + + assert_eq!(write_data0, [173, 255, 255, 255]); + assert_eq!(write_data1, [210, 255, 255, 255]); + assert_eq!(write_data2, [227, 255, 255, 255]); + assert_eq!(write_data3, [255, 255, 255, 255]); } diff --git a/extensions/rv32im/circuit/src/loadstore/core.rs b/extensions/rv32im/circuit/src/loadstore/core.rs index 36beb10629..e2b9c1972f 100644 --- a/extensions/rv32im/circuit/src/loadstore/core.rs +++ b/extensions/rv32im/circuit/src/loadstore/core.rs @@ -1,10 +1,23 @@ -use std::borrow::{Borrow, BorrowMut}; +use std::{ + array, + borrow::{Borrow, BorrowMut}, + fmt::Debug, +}; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, Result, VmAdapterInterface, VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::*, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, POINTER_MAX_BITS, + }, +}; +use openvm_circuit_primitives::{AlignedBorrow, AlignedBytesBorrow}; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_IMM_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, NATIVE_AS, }; -use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *}; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -12,10 +25,8 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_big_array::BigArray; -use crate::adapters::LoadStoreInstruction; +use crate::adapters::{LoadStoreInstruction, Rv32LoadStoreAdapterFiller}; #[derive(Debug, Clone, Copy)] enum InstructionOpcode { @@ -56,21 +67,7 @@ pub struct LoadStoreCoreCols { pub write_data: [T; NUM_CELLS], } -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Serialize + DeserializeOwned")] -pub struct LoadStoreCoreRecord { - pub opcode: Rv32LoadStoreOpcode, - pub shift: u32, - #[serde(with = "BigArray")] - pub read_data: [F; NUM_CELLS], - #[serde(with = "BigArray")] - pub prev_data: [F; NUM_CELLS], - #[serde(with = "BigArray")] - pub write_data: [F; NUM_CELLS], -} - -#[derive(Debug, Clone)] +#[derive(Debug, Clone, derive_new::new)] pub struct LoadStoreCoreAir { pub offset: usize, } @@ -246,70 +243,115 @@ where } } -#[derive(Debug)] -pub struct LoadStoreCoreChip { - pub air: LoadStoreCoreAir, +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct LoadStoreCoreRecord { + pub local_opcode: u8, + pub shift_amount: u8, + pub read_data: [u8; NUM_CELLS], + // Note: `prev_data` can be from native address space, so we need to use u32 + pub prev_data: [u32; NUM_CELLS], } -impl LoadStoreCoreChip { - pub fn new(offset: usize) -> Self { - Self { - air: LoadStoreCoreAir { offset }, - } - } +#[derive(Clone, Copy, derive_new::new)] +pub struct LoadStoreExecutor { + adapter: A, + pub offset: usize, } -impl, const NUM_CELLS: usize> VmCoreChip - for LoadStoreCoreChip +#[derive(Clone, derive_new::new)] +pub struct LoadStoreFiller< + A = Rv32LoadStoreAdapterFiller, + const NUM_CELLS: usize = RV32_REGISTER_NUM_LIMBS, +> { + adapter: A, + pub offset: usize, +} + +impl PreflightExecutor for LoadStoreExecutor where - I::Reads: Into<([[F; NUM_CELLS]; 2], F)>, - I::Writes: From<[[F; NUM_CELLS]; 1]>, + F: PrimeField32, + A: 'static + + AdapterTraceExecutor< + F, + ReadData = (([u32; NUM_CELLS], [u8; NUM_CELLS]), u8), + WriteData = [u32; NUM_CELLS], + >, + for<'buf> RA: RecordArena< + 'buf, + EmptyAdapterCoreLayout, + (A::RecordMut<'buf>, &'buf mut LoadStoreCoreRecord), + >, { - type Record = LoadStoreCoreRecord; - type Air = LoadStoreCoreAir; - - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, - instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let local_opcode = - Rv32LoadStoreOpcode::from_usize(instruction.opcode.local_opcode_idx(self.air.offset)); - - let (reads, shift_amount) = reads.into(); - let shift = shift_amount.as_canonical_u32(); - let prev_data = reads[0]; - let read_data = reads[1]; - let write_data = run_write_data(local_opcode, read_data, prev_data, shift); - let output = AdapterRuntimeContext::without_pc([write_data]); - - Ok(( - output, - LoadStoreCoreRecord { - opcode: local_opcode, - shift, - prev_data, - read_data, - write_data, - }, - )) - } - fn get_opcode_name(&self, opcode: usize) -> String { format!( "{:?}", - Rv32LoadStoreOpcode::from_usize(opcode - self.air.offset) + Rv32LoadStoreOpcode::from_usize(opcode - self.offset) ) } - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let core_cols: &mut LoadStoreCoreCols = row_slice.borrow_mut(); - let opcode = record.opcode; - let flags = &mut core_cols.flags; + fn execute( + &mut self, + state: VmStateMut, + instruction: &Instruction, + ) -> Result<(), ExecutionError> { + let Instruction { opcode, .. } = instruction; + + let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + ( + (core_record.prev_data, core_record.read_data), + core_record.shift_amount, + ) = self + .adapter + .read(state.memory, instruction, &mut adapter_record); + + let local_opcode = Rv32LoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + core_record.local_opcode = local_opcode as u8; + + let write_data = run_write_data( + local_opcode, + core_record.read_data, + core_record.prev_data, + core_record.shift_amount as usize, + ); + self.adapter + .write(state.memory, instruction, write_data, &mut adapter_record); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } +} + +impl TraceFiller for LoadStoreFiller +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + + let record: &LoadStoreCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; + let core_row: &mut LoadStoreCoreCols = core_row.borrow_mut(); + + let opcode = Rv32LoadStoreOpcode::from_usize(record.local_opcode as usize); + let shift = record.shift_amount; + + let write_data = run_write_data(opcode, record.read_data, record.prev_data, shift as usize); + // Writing in reverse order + core_row.write_data = write_data.map(F::from_canonical_u32); + core_row.prev_data = record.prev_data.map(F::from_canonical_u32); + core_row.read_data = record.read_data.map(F::from_canonical_u8); + core_row.is_load = F::from_bool([LOADW, LOADHU, LOADBU].contains(&opcode)); + core_row.is_valid = F::ONE; + let flags = &mut core_row.flags; *flags = [F::ZERO; 4]; - match (opcode, record.shift) { + match (opcode, shift) { (LOADW, 0) => flags[0] = F::TWO, (LOADHU, 0) => flags[1] = F::TWO, (LOADHU, 2) => flags[2] = F::TWO, @@ -328,51 +370,445 @@ where (STOREB, 3) => (flags[2], flags[3]) = (F::ONE, F::ONE), _ => unreachable!(), }; - core_cols.prev_data = record.prev_data; - core_cols.read_data = record.read_data; - core_cols.is_valid = F::ONE; - core_cols.is_load = F::from_bool([LOADW, LOADHU, LOADBU].contains(&opcode)); - core_cols.write_data = record.write_data; } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct LoadStorePreCompute { + imm_extended: u32, + a: u8, + b: u8, + e: u8, +} + +impl Executor for LoadStoreExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } + + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut LoadStorePreCompute = data.borrow_mut(); + let (local_opcode, enabled, is_native_store) = + self.pre_compute_impl(pc, inst, pre_compute)?; + let fn_ptr = match (local_opcode, enabled, is_native_store) { + (LOADW, true, _) => execute_e1_impl::<_, _, U8, LoadWOp, true>, + (LOADW, false, _) => execute_e1_impl::<_, _, U8, LoadWOp, false>, + (LOADHU, true, _) => execute_e1_impl::<_, _, U8, LoadHUOp, true>, + (LOADHU, false, _) => execute_e1_impl::<_, _, U8, LoadHUOp, false>, + (LOADBU, true, _) => execute_e1_impl::<_, _, U8, LoadBUOp, true>, + (LOADBU, false, _) => execute_e1_impl::<_, _, U8, LoadBUOp, false>, + (STOREW, true, false) => execute_e1_impl::<_, _, U8, StoreWOp, true>, + (STOREW, false, false) => execute_e1_impl::<_, _, U8, StoreWOp, false>, + (STOREW, true, true) => execute_e1_impl::<_, _, F, StoreWOp, true>, + (STOREW, false, true) => execute_e1_impl::<_, _, F, StoreWOp, false>, + (STOREH, true, false) => execute_e1_impl::<_, _, U8, StoreHOp, true>, + (STOREH, false, false) => execute_e1_impl::<_, _, U8, StoreHOp, false>, + (STOREH, true, true) => execute_e1_impl::<_, _, F, StoreHOp, true>, + (STOREH, false, true) => execute_e1_impl::<_, _, F, StoreHOp, false>, + (STOREB, true, false) => execute_e1_impl::<_, _, U8, StoreBOp, true>, + (STOREB, false, false) => execute_e1_impl::<_, _, U8, StoreBOp, false>, + (STOREB, true, true) => execute_e1_impl::<_, _, F, StoreBOp, true>, + (STOREB, false, true) => execute_e1_impl::<_, _, F, StoreBOp, false>, + (_, _, _) => unreachable!(), + }; + Ok(fn_ptr) + } +} + +impl MeteredExecutor for LoadStoreExecutor +where + F: PrimeField32, +{ + fn metered_pre_compute_size(&self) -> usize { + size_of::>() + } + + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + let (local_opcode, enabled, is_native_store) = + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + let fn_ptr = match (local_opcode, enabled, is_native_store) { + (LOADW, true, _) => execute_e2_impl::<_, _, U8, LoadWOp, true>, + (LOADW, false, _) => execute_e2_impl::<_, _, U8, LoadWOp, false>, + (LOADHU, true, _) => execute_e2_impl::<_, _, U8, LoadHUOp, true>, + (LOADHU, false, _) => execute_e2_impl::<_, _, U8, LoadHUOp, false>, + (LOADBU, true, _) => execute_e2_impl::<_, _, U8, LoadBUOp, true>, + (LOADBU, false, _) => execute_e2_impl::<_, _, U8, LoadBUOp, false>, + (STOREW, true, false) => execute_e2_impl::<_, _, U8, StoreWOp, true>, + (STOREW, false, false) => execute_e2_impl::<_, _, U8, StoreWOp, false>, + (STOREW, true, true) => execute_e2_impl::<_, _, F, StoreWOp, true>, + (STOREW, false, true) => execute_e2_impl::<_, _, F, StoreWOp, false>, + (STOREH, true, false) => execute_e2_impl::<_, _, U8, StoreHOp, true>, + (STOREH, false, false) => execute_e2_impl::<_, _, U8, StoreHOp, false>, + (STOREH, true, true) => execute_e2_impl::<_, _, F, StoreHOp, true>, + (STOREH, false, true) => execute_e2_impl::<_, _, F, StoreHOp, false>, + (STOREB, true, false) => execute_e2_impl::<_, _, U8, StoreBOp, true>, + (STOREB, false, false) => execute_e2_impl::<_, _, U8, StoreBOp, false>, + (STOREB, true, true) => execute_e2_impl::<_, _, F, StoreBOp, true>, + (STOREB, false, true) => execute_e2_impl::<_, _, F, StoreBOp, false>, + (_, _, _) => unreachable!(), + }; + Ok(fn_ptr) + } +} + +#[inline(always)] +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + T: Copy + Debug + Default, + OP: LoadStoreOp, + const ENABLED: bool, +>( + pre_compute: &LoadStorePreCompute, + vm_state: &mut VmExecState, +) { + let rs1_bytes: [u8; RV32_REGISTER_NUM_LIMBS] = + vm_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32); + let rs1_val = u32::from_le_bytes(rs1_bytes); + let ptr_val = rs1_val.wrapping_add(pre_compute.imm_extended); + // sign_extend([r32{c,g}(b):2]_e)` + debug_assert!(ptr_val < (1 << POINTER_MAX_BITS)); + let shift_amount = ptr_val % 4; + let ptr_val = ptr_val - shift_amount; // aligned ptr + + let read_data: [u8; RV32_REGISTER_NUM_LIMBS] = if OP::IS_LOAD { + vm_state.vm_read(pre_compute.e as u32, ptr_val) + } else { + vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32) + }; + + // We need to write 4 u32s for STORE. + let mut write_data: [T; RV32_REGISTER_NUM_LIMBS] = if OP::HOST_READ { + vm_state.host_read(pre_compute.e as u32, ptr_val) + } else { + [T::default(); RV32_REGISTER_NUM_LIMBS] + }; + + if !OP::compute_write_data(&mut write_data, read_data, shift_amount as usize) { + vm_state.exit_code = Err(ExecutionError::Fail { + pc: vm_state.pc, + msg: "Invalid LoadStoreOp", + }); + return; + } + + if ENABLED { + if OP::IS_LOAD { + vm_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &write_data); + } else { + vm_state.vm_write(pre_compute.e as u32, ptr_val, &write_data); + } + } + + vm_state.pc += DEFAULT_PC_STEP; + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + T: Copy + Debug + Default, + OP: LoadStoreOp, + const ENABLED: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &LoadStorePreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + T: Copy + Debug + Default, + OP: LoadStoreOp, + const ENABLED: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl LoadStoreExecutor { + /// Return (local_opcode, enabled, is_native_store) + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut LoadStorePreCompute, + ) -> Result<(Rv32LoadStoreOpcode, bool, bool), StaticProgramError> { + let Instruction { + opcode, + a, + b, + c, + d, + e, + f, + g, + .. + } = inst; + let enabled = !f.is_zero(); + + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 == RV32_IMM_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + let local_opcode = Rv32LoadStoreOpcode::from_usize( + opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET), + ); + match local_opcode { + LOADW | LOADBU | LOADHU => {} + STOREW | STOREH | STOREB => { + if !enabled { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + } + _ => unreachable!("LoadStoreExecutor should not handle LOADB/LOADH opcodes"), + } + + let imm = c.as_canonical_u32(); + let imm_sign = g.as_canonical_u32(); + let imm_extended = imm + imm_sign * 0xffff0000; + let is_native_store = e_u32 == NATIVE_AS; + + *data = LoadStorePreCompute { + imm_extended, + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + e: e_u32 as u8, + }; + Ok((local_opcode, enabled, is_native_store)) + } +} + +trait LoadStoreOp { + const IS_LOAD: bool; + const HOST_READ: bool; - fn air(&self) -> &Self::Air { - &self.air + /// Return if the operation is valid. + fn compute_write_data( + write_data: &mut [T; RV32_REGISTER_NUM_LIMBS], + read_data: [u8; RV32_REGISTER_NUM_LIMBS], + shift_amount: usize, + ) -> bool; +} +/// Wrapper type for u8 so we can implement `LoadStoreOp` for `F: PrimeField32`. +/// For memory read/write, this type behaves as same as `u8`. +#[allow(dead_code)] +#[derive(Copy, Clone, Debug, Default)] +struct U8(u8); +struct LoadWOp; +struct LoadHUOp; +struct LoadBUOp; +struct StoreWOp; +struct StoreHOp; +struct StoreBOp; +impl LoadStoreOp for LoadWOp { + const IS_LOAD: bool = true; + const HOST_READ: bool = false; + + #[inline(always)] + fn compute_write_data( + write_data: &mut [U8; RV32_REGISTER_NUM_LIMBS], + read_data: [u8; RV32_REGISTER_NUM_LIMBS], + _shift_amount: usize, + ) -> bool { + *write_data = read_data.map(U8); + true } } -pub(super) fn run_write_data( +impl LoadStoreOp for LoadHUOp { + const IS_LOAD: bool = true; + const HOST_READ: bool = false; + #[inline(always)] + fn compute_write_data( + write_data: &mut [U8; RV32_REGISTER_NUM_LIMBS], + read_data: [u8; RV32_REGISTER_NUM_LIMBS], + shift_amount: usize, + ) -> bool { + if shift_amount != 0 && shift_amount != 2 { + return false; + } + write_data[0] = U8(read_data[shift_amount]); + write_data[1] = U8(read_data[shift_amount + 1]); + true + } +} +impl LoadStoreOp for LoadBUOp { + const IS_LOAD: bool = true; + const HOST_READ: bool = false; + #[inline(always)] + fn compute_write_data( + write_data: &mut [U8; RV32_REGISTER_NUM_LIMBS], + read_data: [u8; RV32_REGISTER_NUM_LIMBS], + shift_amount: usize, + ) -> bool { + write_data[0] = U8(read_data[shift_amount]); + true + } +} + +impl LoadStoreOp for StoreWOp { + const IS_LOAD: bool = false; + const HOST_READ: bool = false; + #[inline(always)] + fn compute_write_data( + write_data: &mut [U8; RV32_REGISTER_NUM_LIMBS], + read_data: [u8; RV32_REGISTER_NUM_LIMBS], + _shift_amount: usize, + ) -> bool { + *write_data = read_data.map(U8); + true + } +} +impl LoadStoreOp for StoreHOp { + const IS_LOAD: bool = false; + const HOST_READ: bool = true; + + #[inline(always)] + fn compute_write_data( + write_data: &mut [U8; RV32_REGISTER_NUM_LIMBS], + read_data: [u8; RV32_REGISTER_NUM_LIMBS], + shift_amount: usize, + ) -> bool { + if shift_amount != 0 && shift_amount != 2 { + return false; + } + write_data[shift_amount] = U8(read_data[0]); + write_data[shift_amount + 1] = U8(read_data[1]); + true + } +} +impl LoadStoreOp for StoreBOp { + const IS_LOAD: bool = false; + const HOST_READ: bool = true; + #[inline(always)] + fn compute_write_data( + write_data: &mut [U8; RV32_REGISTER_NUM_LIMBS], + read_data: [u8; RV32_REGISTER_NUM_LIMBS], + shift_amount: usize, + ) -> bool { + write_data[shift_amount] = U8(read_data[0]); + true + } +} + +impl LoadStoreOp for StoreWOp { + const IS_LOAD: bool = false; + const HOST_READ: bool = false; + #[inline(always)] + fn compute_write_data( + write_data: &mut [F; RV32_REGISTER_NUM_LIMBS], + read_data: [u8; RV32_REGISTER_NUM_LIMBS], + _shift_amount: usize, + ) -> bool { + *write_data = read_data.map(F::from_canonical_u8); + true + } +} +impl LoadStoreOp for StoreHOp { + const IS_LOAD: bool = false; + const HOST_READ: bool = true; + + #[inline(always)] + fn compute_write_data( + write_data: &mut [F; RV32_REGISTER_NUM_LIMBS], + read_data: [u8; RV32_REGISTER_NUM_LIMBS], + shift_amount: usize, + ) -> bool { + if shift_amount != 0 && shift_amount != 2 { + return false; + } + write_data[shift_amount] = F::from_canonical_u8(read_data[0]); + write_data[shift_amount + 1] = F::from_canonical_u8(read_data[1]); + true + } +} +impl LoadStoreOp for StoreBOp { + const IS_LOAD: bool = false; + const HOST_READ: bool = true; + #[inline(always)] + fn compute_write_data( + write_data: &mut [F; RV32_REGISTER_NUM_LIMBS], + read_data: [u8; RV32_REGISTER_NUM_LIMBS], + shift_amount: usize, + ) -> bool { + write_data[shift_amount] = F::from_canonical_u8(read_data[0]); + true + } +} + +// Returns the write data +#[inline(always)] +pub(super) fn run_write_data( opcode: Rv32LoadStoreOpcode, - read_data: [F; NUM_CELLS], - prev_data: [F; NUM_CELLS], - shift: u32, -) -> [F; NUM_CELLS] { - let shift = shift as usize; - let mut write_data = read_data; + read_data: [u8; NUM_CELLS], + prev_data: [u32; NUM_CELLS], + shift: usize, +) -> [u32; NUM_CELLS] { match (opcode, shift) { - (LOADW, 0) => (), + (LOADW, 0) => { + read_data.map(|x| x as u32) + }, (LOADBU, 0) | (LOADBU, 1) | (LOADBU, 2) | (LOADBU, 3) => { - for cell in write_data.iter_mut().take(NUM_CELLS).skip(1) { - *cell = F::ZERO; - } - write_data[0] = read_data[shift]; + let mut wrie_data = [0; NUM_CELLS]; + wrie_data[0] = read_data[shift] as u32; + wrie_data } (LOADHU, 0) | (LOADHU, 2) => { - for cell in write_data.iter_mut().take(NUM_CELLS).skip(NUM_CELLS / 2) { - *cell = F::ZERO; - } + let mut write_data = [0; NUM_CELLS]; for (i, cell) in write_data.iter_mut().take(NUM_CELLS / 2).enumerate() { - *cell = read_data[i + shift]; + *cell = read_data[i + shift] as u32; } + write_data } - (STOREW, 0) => (), + (STOREW, 0) => { + read_data.map(|x| x as u32) + }, (STOREB, 0) | (STOREB, 1) | (STOREB, 2) | (STOREB, 3) => { - write_data = prev_data; - write_data[shift] = read_data[0]; + let mut write_data = prev_data; + write_data[shift] = read_data[0] as u32; + write_data } (STOREH, 0) | (STOREH, 2) => { - write_data = prev_data; - write_data[shift..(NUM_CELLS / 2 + shift)] - .copy_from_slice(&read_data[..(NUM_CELLS / 2)]); + array::from_fn(|i| { + if i >= shift && i < (NUM_CELLS / 2 + shift){ + read_data[i - shift] as u32 + } else { + prev_data[i] + } + }) } // Currently the adapter AIR requires `ptr_val` to be aligned to the data size in bytes. // The circuit requires that `shift = ptr_val % 4` so that `ptr_val - shift` is a multiple of 4. @@ -380,6 +816,5 @@ pub(super) fn run_write_data( _ => unreachable!( "unaligned memory access not supported by this execution environment: {opcode:?}, shift: {shift}" ), - }; - write_data + } } diff --git a/extensions/rv32im/circuit/src/loadstore/mod.rs b/extensions/rv32im/circuit/src/loadstore/mod.rs index 825f82166c..9d3056e83d 100644 --- a/extensions/rv32im/circuit/src/loadstore/mod.rs +++ b/extensions/rv32im/circuit/src/loadstore/mod.rs @@ -2,12 +2,16 @@ mod core; pub use core::*; -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; -use super::adapters::{Rv32LoadStoreAdapterChip, RV32_REGISTER_NUM_LIMBS}; +use super::adapters::RV32_REGISTER_NUM_LIMBS; +use crate::adapters::{Rv32LoadStoreAdapterAir, Rv32LoadStoreAdapterExecutor}; #[cfg(test)] mod tests; -pub type Rv32LoadStoreChip = - VmChipWrapper, LoadStoreCoreChip>; +pub type Rv32LoadStoreAir = + VmAirWrapper>; +pub type Rv32LoadStoreExecutor = + LoadStoreExecutor; +pub type Rv32LoadStoreChip = VmChipWrapper; diff --git a/extensions/rv32im/circuit/src/loadstore/tests.rs b/extensions/rv32im/circuit/src/loadstore/tests.rs index 0fbfa137b9..00b2dcc46f 100644 --- a/extensions/rv32im/circuit/src/loadstore/tests.rs +++ b/extensions/rv32im/circuit/src/loadstore/tests.rs @@ -2,50 +2,83 @@ use std::{array, borrow::BorrowMut}; use openvm_circuit::{ arch::{ - testing::{memory::gen_pointer, VmChipTestBuilder}, - VmAdapterChip, + testing::{memory::gen_pointer, TestChipHarness, VmChipTestBuilder}, + MemoryConfig, }, - utils::u32_into_limbs, + system::memory::merkle::public_values::PUBLIC_VALUES_AS, }; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{instruction::Instruction, riscv::RV32_REGISTER_AS, LocalOpcode}; use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, - p3_field::FieldAlgebra, + p3_field::{FieldAlgebra, PrimeField32}, p3_matrix::{ dense::{DenseMatrix, RowMajorMatrix}, Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, }; -use openvm_stark_sdk::{config::setup_tracing, p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, seq::SliceRandom, Rng}; +use test_case::test_case; -use super::{run_write_data, LoadStoreCoreChip, Rv32LoadStoreChip}; +use super::{run_write_data, LoadStoreCoreAir, Rv32LoadStoreChip}; use crate::{ - adapters::{compose, Rv32LoadStoreAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, + adapters::{ + Rv32LoadStoreAdapterAir, Rv32LoadStoreAdapterCols, Rv32LoadStoreAdapterExecutor, + Rv32LoadStoreAdapterFiller, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + }, loadstore::LoadStoreCoreCols, + test_utils::get_verification_error, + LoadStoreFiller, Rv32LoadStoreAir, Rv32LoadStoreExecutor, }; const IMM_BITS: usize = 16; +const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; +type Harness = TestChipHarness>; + +fn create_test_chip(tester: &mut VmChipTestBuilder) -> Harness { + let range_checker_chip = tester.range_checker(); + + let air = Rv32LoadStoreAir::new( + Rv32LoadStoreAdapterAir::new( + tester.memory_bridge(), + tester.execution_bridge(), + range_checker_chip.bus(), + tester.address_bits(), + ), + LoadStoreCoreAir::new(Rv32LoadStoreOpcode::CLASS_OFFSET), + ); + let executor = Rv32LoadStoreExecutor::new( + Rv32LoadStoreAdapterExecutor::new(tester.address_bits()), + Rv32LoadStoreOpcode::CLASS_OFFSET, + ); + let chip = Rv32LoadStoreChip::::new( + LoadStoreFiller::new( + Rv32LoadStoreAdapterFiller::new(tester.address_bits(), range_checker_chip.clone()), + Rv32LoadStoreOpcode::CLASS_OFFSET, + ), + tester.memory_helper(), + ); + Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY) +} #[allow(clippy::too_many_arguments)] fn set_and_execute( tester: &mut VmChipTestBuilder, - chip: &mut Rv32LoadStoreChip, + harness: &mut Harness, rng: &mut StdRng, opcode: Rv32LoadStoreOpcode, - rs1: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + rs1: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, imm: Option, imm_sign: Option, mem_as: Option, ) { let imm = imm.unwrap_or(rng.gen_range(0..(1 << IMM_BITS))); let imm_sign = imm_sign.unwrap_or(rng.gen_range(0..2)); - let imm_ext = imm + imm_sign * (0xffffffff ^ ((1 << IMM_BITS) - 1)); + let imm_ext = imm + imm_sign * 0xffff0000; let alignment = match opcode { LOADW | STOREW => 2, @@ -54,33 +87,21 @@ fn set_and_execute( _ => unreachable!(), }; - let ptr_val = rng.gen_range( - 0..(1 - << (tester - .memory_controller() - .borrow() - .mem_config() - .pointer_max_bits - - alignment)), - ) << alignment; - - let rs1 = rs1 - .unwrap_or(u32_into_limbs::( - (ptr_val as u32).wrapping_sub(imm_ext), - )) - .map(F::from_canonical_u32); + let ptr_val: u32 = rng.gen_range(0..(1 << (tester.address_bits() - alignment))) << alignment; + let rs1 = rs1.unwrap_or(ptr_val.wrapping_sub(imm_ext).to_le_bytes()); + let ptr_val = imm_ext.wrapping_add(u32::from_le_bytes(rs1)); let a = gen_pointer(rng, 4); let b = gen_pointer(rng, 4); + let is_load = [LOADW, LOADHU, LOADBU].contains(&opcode); let mem_as = mem_as.unwrap_or(if is_load { - *[1, 2].choose(rng).unwrap() + 2 } else { *[2, 3, 4].choose(rng).unwrap() }); - let ptr_val = imm_ext.wrapping_add(compose(rs1)); let shift_amount = ptr_val % 4; - tester.write(1, b, rs1); + tester.write(1, b, rs1.map(F::from_canonical_u8)); let mut some_prev_data: [F; RV32_REGISTER_NUM_LIMBS] = array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..(1 << RV32_CELL_BITS)))); @@ -92,11 +113,11 @@ fn set_and_execute( some_prev_data = [F::ZERO; RV32_REGISTER_NUM_LIMBS]; } tester.write(1, a, some_prev_data); - if mem_as == 1 && ptr_val - shift_amount == 0 { - read_data = [F::ZERO; RV32_REGISTER_NUM_LIMBS]; - } tester.write(mem_as, (ptr_val - shift_amount) as usize, read_data); } else { + if mem_as == 4 { + some_prev_data = array::from_fn(|_| rng.gen()); + } if a == 0 { read_data = [F::ZERO; RV32_REGISTER_NUM_LIMBS]; } @@ -107,7 +128,7 @@ fn set_and_execute( let enabled_write = !(is_load & (a == 0)); tester.execute( - chip, + harness, &Instruction::from_usize( opcode.global_opcode(), [ @@ -122,7 +143,13 @@ fn set_and_execute( ), ); - let write_data = run_write_data(opcode, read_data, some_prev_data, shift_amount); + let write_data = run_write_data( + opcode, + read_data.map(|x| x.as_canonical_u32() as u8), + some_prev_data.map(|x| x.as_canonical_u32()), + shift_amount as usize, + ) + .map(F::from_canonical_u32); if is_load { if enabled_write { assert_eq!(write_data, tester.read::<4>(1, a)); @@ -143,80 +170,28 @@ fn set_and_execute( /// Randomly generate computations and execute, ensuring that the generated trace /// passes all constraints. /////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn rand_loadstore_test() { - setup_tracing(); +#[test_case(LOADW, 100)] +#[test_case(LOADBU, 100)] +#[test_case(LOADHU, 100)] +#[test_case(STOREW, 100)] +#[test_case(STOREB, 100)] +#[test_case(STOREH, 100)] +fn rand_loadstore_test(opcode: Rv32LoadStoreOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); - let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let adapter = Rv32LoadStoreAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - range_checker_chip.clone(), - ); - - let core = LoadStoreCoreChip::new(Rv32LoadStoreOpcode::CLASS_OFFSET); - let mut chip = Rv32LoadStoreChip::::new(adapter, core, tester.offline_memory_mutex_arc()); + let mut mem_config = MemoryConfig::default(); + mem_config.addr_spaces[RV32_REGISTER_AS as usize].num_cells = 1 << 29; + if [STOREW, STOREB, STOREH].contains(&opcode) { + mem_config.addr_spaces[PUBLIC_VALUES_AS as usize].num_cells = 1 << 29; + } + let mut tester = VmChipTestBuilder::volatile(mem_config); + let mut harness = create_test_chip(&mut tester); - let num_tests: usize = 100; - for _ in 0..num_tests { - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADW, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADBU, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADHU, - None, - None, - None, - None, - ); + for _ in 0..num_ops { set_and_execute( &mut tester, - &mut chip, + &mut harness, &mut rng, - STOREW, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - STOREB, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - STOREH, + opcode, None, None, None, @@ -224,8 +199,7 @@ fn rand_loadstore_test() { ); } - drop(range_checker_chip); - let tester = tester.build().load(chip).finalize(); + let tester = tester.build().load(harness).finalize(); tester.simple_test().expect("Verification failed"); } @@ -233,79 +207,84 @@ fn rand_loadstore_test() { // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adaptor is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -#[allow(clippy::too_many_arguments)] -fn run_negative_loadstore_test( - opcode: Rv32LoadStoreOpcode, +#[derive(Clone, Copy, Default, PartialEq)] +struct LoadStorePrankValues { read_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, prev_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, write_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, flags: Option<[u32; 4]>, is_load: Option, - rs1: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + mem_as: Option, +} + +#[allow(clippy::too_many_arguments)] +fn run_negative_loadstore_test( + opcode: Rv32LoadStoreOpcode, + rs1: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, imm: Option, imm_sign: Option, - mem_as: Option, - expected_error: VerificationError, + prank_vals: LoadStorePrankValues, + interaction_error: bool, ) { let mut rng = create_seeded_rng(); - let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let adapter = Rv32LoadStoreAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - range_checker_chip.clone(), - ); - - let core = LoadStoreCoreChip::new(Rv32LoadStoreOpcode::CLASS_OFFSET); - let adapter_width = BaseAir::::width(adapter.air()); - let mut chip = Rv32LoadStoreChip::::new(adapter, core, tester.offline_memory_mutex_arc()); + let mut mem_config = MemoryConfig::default(); + mem_config.addr_spaces[RV32_REGISTER_AS as usize].num_cells = 1 << 29; + if [STOREW, STOREB, STOREH].contains(&opcode) { + mem_config.addr_spaces[PUBLIC_VALUES_AS as usize].num_cells = 1 << 29; + } + let mut tester = VmChipTestBuilder::volatile(mem_config); + let mut harness = create_test_chip(&mut tester); set_and_execute( &mut tester, - &mut chip, + &mut harness, &mut rng, opcode, rs1, imm, imm_sign, - mem_as, + None, ); + let adapter_width = BaseAir::::width(&harness.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { let mut trace_row = trace.row_slice(0).to_vec(); - let (_, core_row) = trace_row.split_at_mut(adapter_width); + let (adapter_row, core_row) = trace_row.split_at_mut(adapter_width); + let adapter_cols: &mut Rv32LoadStoreAdapterCols = adapter_row.borrow_mut(); let core_cols: &mut LoadStoreCoreCols = core_row.borrow_mut(); - if let Some(read_data) = read_data { + + if let Some(read_data) = prank_vals.read_data { core_cols.read_data = read_data.map(F::from_canonical_u32); } - if let Some(prev_data) = prev_data { + if let Some(prev_data) = prank_vals.prev_data { core_cols.prev_data = prev_data.map(F::from_canonical_u32); } - if let Some(write_data) = write_data { + if let Some(write_data) = prank_vals.write_data { core_cols.write_data = write_data.map(F::from_canonical_u32); } - if let Some(flags) = flags { + if let Some(flags) = prank_vals.flags { core_cols.flags = flags.map(F::from_canonical_u32); } - if let Some(is_load) = is_load { + if let Some(is_load) = prank_vals.is_load { core_cols.is_load = F::from_bool(is_load); } + if let Some(mem_as) = prank_vals.mem_as { + adapter_cols.mem_as = F::from_canonical_u32(mem_as); + } + *trace = RowMajorMatrix::new(trace_row, trace.width()); }; - drop(range_checker_chip); disable_debug_builder(); let tester = tester .build() - .load_and_prank_trace(chip, modify_trace) + .load_and_prank_trace(harness, modify_trace) .finalize(); - tester.simple_test_with_expected_error(expected_error); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -315,41 +294,36 @@ fn negative_wrong_opcode_tests() { None, None, None, - None, - Some(false), - None, - None, - None, - None, - VerificationError::OodEvaluationMismatch, + LoadStorePrankValues { + is_load: Some(false), + ..Default::default() + }, + false, ); run_negative_loadstore_test( LOADBU, - None, - None, - None, - Some([0, 0, 0, 2]), - None, Some([4, 0, 0, 0]), Some(1), None, - None, - VerificationError::OodEvaluationMismatch, + LoadStorePrankValues { + flags: Some([0, 0, 0, 2]), + ..Default::default() + }, + false, ); run_negative_loadstore_test( STOREH, - None, - None, - None, - Some([1, 0, 1, 0]), - Some(true), Some([11, 169, 76, 28]), Some(37121), None, - None, - VerificationError::OodEvaluationMismatch, + LoadStorePrankValues { + flags: Some([1, 0, 1, 0]), + is_load: Some(true), + ..Default::default() + }, + false, ); } @@ -357,30 +331,34 @@ fn negative_wrong_opcode_tests() { fn negative_write_data_tests() { run_negative_loadstore_test( LOADHU, - Some([175, 33, 198, 250]), - Some([90, 121, 64, 205]), - Some([175, 33, 0, 0]), - Some([0, 2, 0, 0]), - Some(true), Some([13, 11, 156, 23]), Some(43641), None, - None, - VerificationError::ChallengePhaseError, + LoadStorePrankValues { + read_data: Some([175, 33, 198, 250]), + prev_data: Some([90, 121, 64, 205]), + write_data: Some([175, 33, 0, 0]), + flags: Some([0, 2, 0, 0]), + is_load: Some(true), + mem_as: None, + }, + true, ); run_negative_loadstore_test( STOREB, - Some([175, 33, 198, 250]), - Some([90, 121, 64, 205]), - Some([175, 121, 64, 205]), - Some([0, 0, 1, 1]), - None, Some([45, 123, 87, 24]), Some(28122), Some(0), - None, - VerificationError::OodEvaluationMismatch, + LoadStorePrankValues { + read_data: Some([175, 33, 198, 250]), + prev_data: Some([90, 121, 64, 205]), + write_data: Some([175, 121, 64, 205]), + flags: Some([0, 0, 1, 1]), + is_load: None, + mem_as: None, + }, + false, ); } @@ -391,39 +369,35 @@ fn negative_wrong_address_space_tests() { None, None, None, - None, - None, - None, - None, - None, - Some(3), - VerificationError::OodEvaluationMismatch, + LoadStorePrankValues { + mem_as: Some(3), + ..Default::default() + }, + false, ); + run_negative_loadstore_test( LOADW, None, None, None, - None, - None, - None, - None, - None, - Some(4), - VerificationError::OodEvaluationMismatch, + LoadStorePrankValues { + mem_as: Some(4), + ..Default::default() + }, + false, ); + run_negative_loadstore_test( STOREW, None, None, None, - None, - None, - None, - None, - None, - Some(1), - VerificationError::OodEvaluationMismatch, + LoadStorePrankValues { + mem_as: Some(1), + ..Default::default() + }, + false, ); } @@ -432,140 +406,60 @@ fn negative_wrong_address_space_tests() { /// /// Ensure that solve functions produce the correct results. /////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn execute_roundtrip_sanity_test() { - let mut rng = create_seeded_rng(); - let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let adapter = Rv32LoadStoreAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - range_checker_chip.clone(), - ); - let core = LoadStoreCoreChip::new(Rv32LoadStoreOpcode::CLASS_OFFSET); - let mut chip = Rv32LoadStoreChip::::new(adapter, core, tester.offline_memory_mutex_arc()); - - let num_tests: usize = 100; - for _ in 0..num_tests { - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADW, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADBU, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADHU, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - STOREW, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - STOREB, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - STOREH, - None, - None, - None, - None, - ); - } -} - #[test] fn run_loadw_storew_sanity_test() { - let read_data = [138, 45, 202, 76].map(F::from_canonical_u32); - let prev_data = [159, 213, 89, 34].map(F::from_canonical_u32); + let read_data = [138, 45, 202, 76]; + let prev_data = [159, 213, 89, 34]; let store_write_data = run_write_data(STOREW, read_data, prev_data, 0); let load_write_data = run_write_data(LOADW, read_data, prev_data, 0); - assert_eq!(store_write_data, read_data); - assert_eq!(load_write_data, read_data); + assert_eq!(store_write_data, read_data.map(u32::from)); + assert_eq!(load_write_data, read_data.map(u32::from)); } #[test] fn run_storeh_sanity_test() { - let read_data = [250, 123, 67, 198].map(F::from_canonical_u32); - let prev_data = [144, 56, 175, 92].map(F::from_canonical_u32); + let read_data = [250, 123, 67, 198]; + let prev_data = [144, 56, 175, 92]; let write_data = run_write_data(STOREH, read_data, prev_data, 0); let write_data2 = run_write_data(STOREH, read_data, prev_data, 2); - assert_eq!(write_data, [250, 123, 175, 92].map(F::from_canonical_u32)); - assert_eq!(write_data2, [144, 56, 250, 123].map(F::from_canonical_u32)); + assert_eq!(write_data, [250, 123, 175, 92]); + assert_eq!(write_data2, [144, 56, 250, 123]); } #[test] fn run_storeb_sanity_test() { - let read_data = [221, 104, 58, 147].map(F::from_canonical_u32); - let prev_data = [199, 83, 243, 12].map(F::from_canonical_u32); + let read_data = [221, 104, 58, 147]; + let prev_data = [199, 83, 243, 12]; let write_data = run_write_data(STOREB, read_data, prev_data, 0); let write_data1 = run_write_data(STOREB, read_data, prev_data, 1); let write_data2 = run_write_data(STOREB, read_data, prev_data, 2); let write_data3 = run_write_data(STOREB, read_data, prev_data, 3); - assert_eq!(write_data, [221, 83, 243, 12].map(F::from_canonical_u32)); - assert_eq!(write_data1, [199, 221, 243, 12].map(F::from_canonical_u32)); - assert_eq!(write_data2, [199, 83, 221, 12].map(F::from_canonical_u32)); - assert_eq!(write_data3, [199, 83, 243, 221].map(F::from_canonical_u32)); + assert_eq!(write_data, [221, 83, 243, 12]); + assert_eq!(write_data1, [199, 221, 243, 12]); + assert_eq!(write_data2, [199, 83, 221, 12]); + assert_eq!(write_data3, [199, 83, 243, 221]); } #[test] fn run_loadhu_sanity_test() { - let read_data = [175, 33, 198, 250].map(F::from_canonical_u32); - let prev_data = [90, 121, 64, 205].map(F::from_canonical_u32); + let read_data = [175, 33, 198, 250]; + let prev_data = [90, 121, 64, 205]; let write_data = run_write_data(LOADHU, read_data, prev_data, 0); let write_data2 = run_write_data(LOADHU, read_data, prev_data, 2); - assert_eq!(write_data, [175, 33, 0, 0].map(F::from_canonical_u32)); - assert_eq!(write_data2, [198, 250, 0, 0].map(F::from_canonical_u32)); + assert_eq!(write_data, [175, 33, 0, 0]); + assert_eq!(write_data2, [198, 250, 0, 0]); } #[test] fn run_loadbu_sanity_test() { - let read_data = [131, 74, 186, 29].map(F::from_canonical_u32); - let prev_data = [142, 67, 210, 88].map(F::from_canonical_u32); + let read_data = [131, 74, 186, 29]; + let prev_data = [142, 67, 210, 88]; let write_data = run_write_data(LOADBU, read_data, prev_data, 0); let write_data1 = run_write_data(LOADBU, read_data, prev_data, 1); let write_data2 = run_write_data(LOADBU, read_data, prev_data, 2); let write_data3 = run_write_data(LOADBU, read_data, prev_data, 3); - assert_eq!(write_data, [131, 0, 0, 0].map(F::from_canonical_u32)); - assert_eq!(write_data1, [74, 0, 0, 0].map(F::from_canonical_u32)); - assert_eq!(write_data2, [186, 0, 0, 0].map(F::from_canonical_u32)); - assert_eq!(write_data3, [29, 0, 0, 0].map(F::from_canonical_u32)); + assert_eq!(write_data, [131, 0, 0, 0]); + assert_eq!(write_data1, [74, 0, 0, 0]); + assert_eq!(write_data2, [186, 0, 0, 0]); + assert_eq!(write_data3, [29, 0, 0, 0]); } diff --git a/extensions/rv32im/circuit/src/mul/core.rs b/extensions/rv32im/circuit/src/mul/core.rs index fa65a6cf09..c23551d73a 100644 --- a/extensions/rv32im/circuit/src/mul/core.rs +++ b/extensions/rv32im/circuit/src/mul/core.rs @@ -3,13 +3,24 @@ use std::{ borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::*, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, +}; +use openvm_circuit_primitives::{ + range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, + AlignedBytesBorrow, }; -use openvm_circuit_primitives::range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; use openvm_rv32im_transpiler::MulOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -17,8 +28,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_big_array::BigArray; #[repr(C)] #[derive(AlignedBorrow)] @@ -29,7 +38,7 @@ pub struct MultiplicationCoreCols { pub bus: RangeTupleCheckerBus<2>, pub offset: usize, @@ -109,14 +118,34 @@ where } } -#[derive(Debug)] -pub struct MultiplicationCoreChip { - pub air: MultiplicationCoreAir, +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct MultiplicationCoreRecord { + pub b: [u8; NUM_LIMBS], + pub c: [u8; NUM_LIMBS], +} + +#[derive(Clone, Copy, derive_new::new)] +pub struct MultiplicationExecutor { + adapter: A, + pub offset: usize, +} + +#[derive(Clone, Debug)] +pub struct MultiplicationFiller { + adapter: A, + pub offset: usize, pub range_tuple_chip: SharedRangeTupleCheckerChip<2>, } -impl MultiplicationCoreChip { - pub fn new(range_tuple_chip: SharedRangeTupleCheckerChip<2>, offset: usize) -> Self { +impl + MultiplicationFiller +{ + pub fn new( + adapter: A, + range_tuple_chip: SharedRangeTupleCheckerChip<2>, + offset: usize, + ) -> Self { // The RangeTupleChecker is used to range check (a[i], carry[i]) pairs where 0 <= i // < NUM_LIMBS. a[i] must have LIMB_BITS bits and carry[i] is the sum of i + 1 bytes // (with LIMB_BITS bits). @@ -132,102 +161,234 @@ impl MultiplicationCoreChip { - #[serde(with = "BigArray")] - pub a: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub c: [T; NUM_LIMBS], -} - -impl, const NUM_LIMBS: usize, const LIMB_BITS: usize> - VmCoreChip for MultiplicationCoreChip +impl PreflightExecutor + for MultiplicationExecutor where - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: From<[[F; NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + AdapterTraceExecutor< + F, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + >, + for<'buf> RA: RecordArena< + 'buf, + EmptyAdapterCoreLayout, + ( + A::RecordMut<'buf>, + &'buf mut MultiplicationCoreRecord, + ), + >, { - type Record = MultiplicationCoreRecord; - type Air = MultiplicationCoreAir; + fn get_opcode_name(&self, opcode: usize) -> String { + format!("{:?}", MulOpcode::from_usize(opcode - self.offset)) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + ) -> Result<(), ExecutionError> { let Instruction { opcode, .. } = instruction; - assert_eq!( - MulOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)), + + debug_assert_eq!( + MulOpcode::from_usize(opcode.local_opcode_idx(self.offset)), MulOpcode::MUL ); + let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, &mut adapter_record) + .into(); + + let (a, _) = run_mul::(&rs1, &rs2); + + core_record.b = rs1; + core_record.c = rs2; + + self.adapter + .write(state.memory, instruction, [a].into(), &mut adapter_record); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + Ok(()) + } +} +impl TraceFiller + for MultiplicationFiller +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + + let record: &MultiplicationCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let b = data[0].map(|x| x.as_canonical_u32()); - let c = data[1].map(|y| y.as_canonical_u32()); - let (a, carry) = run_mul::(&b, &c); + let core_row: &mut MultiplicationCoreCols = core_row.borrow_mut(); + + let (a, carry) = run_mul::(&record.b, &record.c); for (a, carry) in a.iter().zip(carry.iter()) { - self.range_tuple_chip.add_count(&[*a, *carry]); + self.range_tuple_chip.add_count(&[*a as u32, *carry]); } - let output = AdapterRuntimeContext::without_pc([a.map(F::from_canonical_u32)]); - let record = MultiplicationCoreRecord { - a: a.map(F::from_canonical_u32), - b: data[0], - c: data[1], - }; + // write in reverse order + core_row.is_valid = F::ONE; + core_row.c = record.c.map(F::from_canonical_u8); + core_row.b = record.b.map(F::from_canonical_u8); + core_row.a = a.map(F::from_canonical_u8); + } +} - Ok((output, record)) +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct MultiPreCompute { + a: u8, + b: u8, + c: u8, +} + +impl Executor + for MultiplicationExecutor +where + F: PrimeField32, +{ + fn pre_compute_size(&self) -> usize { + size_of::() + } + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E1ExecutionCtx, + { + let pre_compute: &mut MultiPreCompute = data.borrow_mut(); + self.pre_compute_impl(pc, inst, pre_compute)?; + Ok(execute_e1_impl) } +} - fn get_opcode_name(&self, opcode: usize) -> String { - format!("{:?}", MulOpcode::from_usize(opcode - self.air.offset)) +impl MeteredExecutor + for MultiplicationExecutor +where + F: PrimeField32, +{ + fn metered_pre_compute_size(&self) -> usize { + size_of::>() } - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut MultiplicationCoreCols<_, NUM_LIMBS, LIMB_BITS> = - row_slice.borrow_mut(); - row_slice.a = record.a; - row_slice.b = record.b; - row_slice.c = record.c; - row_slice.is_valid = F::ONE; + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + Ok(execute_e2_impl) } +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &MultiPreCompute, + vm_state: &mut VmExecState, +) { + let rs1: [u8; RV32_REGISTER_NUM_LIMBS] = + vm_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32); + let rs2: [u8; RV32_REGISTER_NUM_LIMBS] = + vm_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32); + let rs1 = u32::from_le_bytes(rs1); + let rs2 = u32::from_le_bytes(rs2); + let rd = rs1.wrapping_mul(rs2); + vm_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd.to_le_bytes()); + + vm_state.pc += DEFAULT_PC_STEP; + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &MultiPreCompute = pre_compute.borrow(); + execute_e12_impl(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl(&pre_compute.data, vm_state); +} - fn air(&self) -> &Self::Air { - &self.air +impl MultiplicationExecutor { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut MultiPreCompute, + ) -> Result<(), StaticProgramError> { + assert_eq!( + MulOpcode::from_usize(inst.opcode.local_opcode_idx(self.offset)), + MulOpcode::MUL + ); + if inst.d.as_canonical_u32() != RV32_REGISTER_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + *data = MultiPreCompute { + a: inst.a.as_canonical_u32() as u8, + b: inst.b.as_canonical_u32() as u8, + c: inst.c.as_canonical_u32() as u8, + }; + Ok(()) } } // returns mul, carry +#[inline(always)] pub(super) fn run_mul( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> ([u32; NUM_LIMBS], [u32; NUM_LIMBS]) { - let mut result = [0; NUM_LIMBS]; - let mut carry = [0; NUM_LIMBS]; + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> ([u8; NUM_LIMBS], [u32; NUM_LIMBS]) { + let mut result = [0u8; NUM_LIMBS]; + let mut carry = [0u32; NUM_LIMBS]; for i in 0..NUM_LIMBS { + let mut res = 0u32; if i > 0 { - result[i] = carry[i - 1]; + res = carry[i - 1]; } for j in 0..=i { - result[i] += x[j] * y[i - j]; + res += (x[j] as u32) * (y[i - j] as u32); } - carry[i] = result[i] >> LIMB_BITS; - result[i] %= 1 << LIMB_BITS; + carry[i] = res >> LIMB_BITS; + res %= 1u32 << LIMB_BITS; + result[i] = res as u8; } (result, carry) } diff --git a/extensions/rv32im/circuit/src/mul/mod.rs b/extensions/rv32im/circuit/src/mul/mod.rs index 5f28439977..ab654eefdb 100644 --- a/extensions/rv32im/circuit/src/mul/mod.rs +++ b/extensions/rv32im/circuit/src/mul/mod.rs @@ -1,6 +1,7 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; -use super::adapters::{Rv32MultAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use super::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use crate::adapters::{Rv32MultAdapterAir, Rv32MultAdapterExecutor, Rv32MultAdapterFiller}; mod core; pub use core::*; @@ -8,8 +9,13 @@ pub use core::*; #[cfg(test)] mod tests; +pub type Rv32MultiplicationAir = VmAirWrapper< + Rv32MultAdapterAir, + MultiplicationCoreAir, +>; +pub type Rv32MultiplicationExecutor = + MultiplicationExecutor; pub type Rv32MultiplicationChip = VmChipWrapper< F, - Rv32MultAdapterChip, - MultiplicationCoreChip, + MultiplicationFiller, >; diff --git a/extensions/rv32im/circuit/src/mul/tests.rs b/extensions/rv32im/circuit/src/mul/tests.rs index b942c24cc3..e2d9bd42a3 100644 --- a/extensions/rv32im/circuit/src/mul/tests.rs +++ b/extensions/rv32im/circuit/src/mul/tests.rs @@ -1,15 +1,11 @@ -use std::borrow::BorrowMut; +use std::{array, borrow::BorrowMut}; -use openvm_circuit::{ - arch::{ - testing::{TestAdapterChip, VmChipTestBuilder, RANGE_TUPLE_CHECKER_BUS}, - ExecutionBridge, VmAdapterChip, VmChipWrapper, - }, - utils::generate_long_number, +use openvm_circuit::arch::testing::{TestChipHarness, VmChipTestBuilder, RANGE_TUPLE_CHECKER_BUS}; +use openvm_circuit_primitives::range_tuple::{ + RangeTupleCheckerAir, RangeTupleCheckerBus, RangeTupleCheckerChip, SharedRangeTupleCheckerChip, }; -use openvm_circuit_primitives::range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_rv32im_transpiler::MulOpcode; +use openvm_instructions::LocalOpcode; +use openvm_rv32im_transpiler::MulOpcode::{self, MUL}; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::FieldAlgebra, @@ -18,19 +14,88 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use rand::{rngs::StdRng, Rng}; use super::core::run_mul; use crate::{ - adapters::{Rv32MultAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, - mul::{MultiplicationCoreChip, MultiplicationCoreCols, Rv32MultiplicationChip}, - test_utils::rv32_rand_write_register_or_imm, + adapters::{ + Rv32MultAdapterAir, Rv32MultAdapterExecutor, Rv32MultAdapterFiller, RV32_CELL_BITS, + RV32_REGISTER_NUM_LIMBS, + }, + mul::{MultiplicationCoreCols, Rv32MultiplicationChip}, + test_utils::{get_verification_error, rv32_rand_write_register_or_imm}, + MultiplicationCoreAir, MultiplicationFiller, Rv32MultiplicationAir, Rv32MultiplicationExecutor, }; +const MAX_INS_CAPACITY: usize = 128; +// the max number of limbs we currently support MUL for is 32 (i.e. for U256s) +const MAX_NUM_LIMBS: u32 = 32; type F = BabyBear; +type Harness = TestChipHarness< + F, + Rv32MultiplicationExecutor, + Rv32MultiplicationAir, + Rv32MultiplicationChip, +>; + +fn create_test_chip( + tester: &mut VmChipTestBuilder, +) -> ( + Harness, + (RangeTupleCheckerAir<2>, SharedRangeTupleCheckerChip<2>), +) { + let range_tuple_bus = RangeTupleCheckerBus::new( + RANGE_TUPLE_CHECKER_BUS, + [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], + ); + let range_tuple_chip = + SharedRangeTupleCheckerChip::new(RangeTupleCheckerChip::<2>::new(range_tuple_bus)); + + let air = Rv32MultiplicationAir::new( + Rv32MultAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + MultiplicationCoreAir::new(range_tuple_bus, MulOpcode::CLASS_OFFSET), + ); + let executor = + Rv32MultiplicationExecutor::new(Rv32MultAdapterExecutor, MulOpcode::CLASS_OFFSET); + let chip = Rv32MultiplicationChip::::new( + MultiplicationFiller::new( + Rv32MultAdapterFiller, + range_tuple_chip.clone(), + MulOpcode::CLASS_OFFSET, + ), + tester.memory_helper(), + ); + let harness = Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + + (harness, (range_tuple_chip.air, range_tuple_chip)) +} + +#[allow(clippy::too_many_arguments)] +fn set_and_execute( + tester: &mut VmChipTestBuilder, + harness: &mut Harness, + rng: &mut StdRng, + opcode: MulOpcode, + b: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + c: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, +) { + let b = b.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); + let c = c.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); + + let (mut instruction, rd) = + rv32_rand_write_register_or_imm(tester, b, c, None, opcode.global_opcode().as_usize(), rng); + + instruction.e = F::ZERO; + tester.execute(harness, &instruction); + + let (a, _) = run_mul::(&b, &c); + assert_eq!( + a.map(F::from_canonical_u8), + tester.read::(1, rd) + ) +} ////////////////////////////////////////////////////////////////////////////////////// // POSITIVE TESTS @@ -39,144 +104,77 @@ type F = BabyBear; // passes all constraints. ////////////////////////////////////////////////////////////////////////////////////// -fn run_rv32_mul_rand_test(num_ops: usize) { - // the max number of limbs we currently support MUL for is 32 (i.e. for U256s) - const MAX_NUM_LIMBS: u32 = 32; +#[test] +fn run_rv32_mul_rand_test() { let mut rng = create_seeded_rng(); - - let range_tuple_bus = RangeTupleCheckerBus::new( - RANGE_TUPLE_CHECKER_BUS, - [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], - ); - let range_tuple_checker = SharedRangeTupleCheckerChip::new(range_tuple_bus); - let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32MultiplicationChip::::new( - Rv32MultAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - MultiplicationCoreChip::new(range_tuple_checker.clone(), MulOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); + let (mut harness, range_tuple) = create_test_chip(&mut tester); + let num_ops = 100; for _ in 0..num_ops { - let b = generate_long_number::(&mut rng); - let c = generate_long_number::(&mut rng); - - let (mut instruction, rd) = rv32_rand_write_register_or_imm( - &mut tester, - b, - c, - None, - MulOpcode::MUL.global_opcode().as_usize(), - &mut rng, - ); - instruction.e = F::ZERO; - tester.execute(&mut chip, &instruction); - - let (a, _) = run_mul::(&b, &c); - assert_eq!( - a.map(F::from_canonical_u32), - tester.read::(1, rd) - ) + set_and_execute(&mut tester, &mut harness, &mut rng, MUL, None, None); } let tester = tester .build() - .load(chip) - .load(range_tuple_checker) + .load(harness) + .load_periphery(range_tuple) .finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn rv32_mul_rand_test() { - run_rv32_mul_rand_test(1); -} - ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32MultiplicationTestChip = VmChipWrapper< - F, - TestAdapterChip, - MultiplicationCoreChip, ->; - #[allow(clippy::too_many_arguments)] -fn run_rv32_mul_negative_test( - a: [u32; RV32_REGISTER_NUM_LIMBS], - b: [u32; RV32_REGISTER_NUM_LIMBS], - c: [u32; RV32_REGISTER_NUM_LIMBS], - is_valid: bool, +fn run_negative_mul_test( + opcode: MulOpcode, + prank_a: [u32; RV32_REGISTER_NUM_LIMBS], + b: [u8; RV32_REGISTER_NUM_LIMBS], + c: [u8; RV32_REGISTER_NUM_LIMBS], + prank_is_valid: bool, interaction_error: bool, ) { - const MAX_NUM_LIMBS: u32 = 32; - let range_tuple_bus = RangeTupleCheckerBus::new( - RANGE_TUPLE_CHECKER_BUS, - [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], - ); - let range_tuple_chip = SharedRangeTupleCheckerChip::new(range_tuple_bus); - + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32MultiplicationTestChip::::new( - TestAdapterChip::new( - vec![[b.map(F::from_canonical_u32), c.map(F::from_canonical_u32)].concat()], - vec![None], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - MultiplicationCoreChip::new(range_tuple_chip.clone(), MulOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); - - tester.execute( - &mut chip, - &Instruction::from_usize(MulOpcode::MUL.global_opcode(), [0, 0, 0, 1, 0]), + let (mut harness, range_tuple) = create_test_chip(&mut tester); + + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + opcode, + Some(b), + Some(c), ); - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); - let (_, carry) = run_mul::(&b, &c); - - range_tuple_chip.clear(); - if is_valid { - for (a, carry) in a.iter().zip(carry.iter()) { - range_tuple_chip.add_count(&[*a, *carry]); - } - } - + let adapter_width = BaseAir::::width(&harness.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); let cols: &mut MultiplicationCoreCols = values.split_at_mut(adapter_width).1.borrow_mut(); - cols.a = a.map(F::from_canonical_u32); - cols.is_valid = F::from_bool(is_valid); - *trace = RowMajorMatrix::new(values, trace_width); + cols.a = prank_a.map(F::from_canonical_u32); + cols.is_valid = F::from_bool(prank_is_valid); + *trace = RowMajorMatrix::new(values, trace.width()); }; disable_debug_builder(); let tester = tester .build() - .load_and_prank_trace(chip, modify_trace) - .load(range_tuple_chip) + .load_and_prank_trace(harness, modify_trace) + .load_periphery(range_tuple) .finalize(); - tester.simple_test_with_expected_error(if interaction_error { - VerificationError::ChallengePhaseError - } else { - VerificationError::OodEvaluationMismatch - }); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] fn rv32_mul_wrong_negative_test() { - run_rv32_mul_negative_test( + run_negative_mul_test( + MUL, [63, 247, 125, 234], [51, 109, 78, 142], [197, 85, 150, 32], @@ -187,7 +185,8 @@ fn rv32_mul_wrong_negative_test() { #[test] fn rv32_mul_is_valid_false_negative_test() { - run_rv32_mul_negative_test( + run_negative_mul_test( + MUL, [63, 247, 125, 234], [51, 109, 78, 142], [197, 85, 150, 32], @@ -204,9 +203,9 @@ fn rv32_mul_is_valid_false_negative_test() { #[test] fn run_mul_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [197, 85, 150, 32]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [51, 109, 78, 142]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [63, 247, 125, 232]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [197, 85, 150, 32]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [51, 109, 78, 142]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [63, 247, 125, 232]; let c: [u32; RV32_REGISTER_NUM_LIMBS] = [39, 100, 126, 205]; let (result, carry) = run_mul::(&x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { diff --git a/extensions/rv32im/circuit/src/mulh/core.rs b/extensions/rv32im/circuit/src/mulh/core.rs index 16aa8fd550..a14021c477 100644 --- a/extensions/rv32im/circuit/src/mulh/core.rs +++ b/extensions/rv32im/circuit/src/mulh/core.rs @@ -3,16 +3,25 @@ use std::{ borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::*, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; use openvm_rv32im_transpiler::MulHOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -20,8 +29,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; #[repr(C)] @@ -40,7 +47,7 @@ pub struct MulHCoreCols { pub opcode_mulhu_flag: T, } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, derive_new::new)] pub struct MulHCoreAir { pub bitwise_lookup_bus: BitwiseOperationLookupBus, pub range_tuple_bus: RangeTupleCheckerBus<2>, @@ -183,14 +190,29 @@ where } } -pub struct MulHCoreChip { - pub air: MulHCoreAir, +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct MulHCoreRecord { + pub b: [u8; NUM_LIMBS], + pub c: [u8; NUM_LIMBS], + pub local_opcode: u8, +} + +#[derive(Clone, Copy, derive_new::new)] +pub struct MulHExecutor { + adapter: A, + pub offset: usize, +} +#[derive(Clone)] +pub struct MulHFiller { + adapter: A, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, pub range_tuple_chip: SharedRangeTupleCheckerChip<2>, } -impl MulHCoreChip { +impl MulHFiller { pub fn new( + adapter: A, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, range_tuple_chip: SharedRangeTupleCheckerChip<2>, ) -> Self { @@ -209,55 +231,93 @@ impl MulHCoreChip { - pub opcode: MulHOpcode, - #[serde(with = "BigArray")] - pub a: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub c: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub a_mul: [T; NUM_LIMBS], - pub b_ext: T, - pub c_ext: T, -} - -impl, const NUM_LIMBS: usize, const LIMB_BITS: usize> - VmCoreChip for MulHCoreChip +impl PreflightExecutor + for MulHExecutor where - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: From<[[F; NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + AdapterTraceExecutor< + F, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + >, + for<'buf> RA: RecordArena< + 'buf, + EmptyAdapterCoreLayout, + ( + A::RecordMut<'buf>, + &'buf mut MulHCoreRecord, + ), + >, { - type Record = MulHCoreRecord; - type Air = MulHCoreAir; + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + MulHOpcode::from_usize(opcode - MulHOpcode::CLASS_OFFSET) + ) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + ) -> Result<(), ExecutionError> { let Instruction { opcode, .. } = instruction; - let mulh_opcode = MulHOpcode::from_usize(opcode.local_opcode_idx(MulHOpcode::CLASS_OFFSET)); - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let b = data[0].map(|x| x.as_canonical_u32()); - let c = data[1].map(|y| y.as_canonical_u32()); - let (a, a_mul, carry, b_ext, c_ext) = run_mulh::(mulh_opcode, &b, &c); + let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + core_record.local_opcode = opcode.local_opcode_idx(MulHOpcode::CLASS_OFFSET) as u8; + let mulh_opcode = MulHOpcode::from_usize(core_record.local_opcode as usize); + + [core_record.b, core_record.c] = self + .adapter + .read(state.memory, instruction, &mut adapter_record) + .into(); + + let (a, _, _, _, _) = run_mulh::( + mulh_opcode, + &core_record.b.map(u32::from), + &core_record.c.map(u32::from), + ); + + let a = a.map(|x| x as u8); + self.adapter + .write(state.memory, instruction, [a].into(), &mut adapter_record); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } +} + +impl TraceFiller + for MulHFiller +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + let record: &MulHCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; + let core_row: &mut MulHCoreCols = core_row.borrow_mut(); + + let opcode = MulHOpcode::from_usize(record.local_opcode as usize); + let (a, a_mul, carry, b_ext, c_ext) = run_mulh::( + opcode, + &record.b.map(u32::from), + &record.c.map(u32::from), + ); for i in 0..NUM_LIMBS { self.range_tuple_chip.add_count(&[a_mul[i], carry[i]]); @@ -265,55 +325,182 @@ where .add_count(&[a[i], carry[NUM_LIMBS + i]]); } - if mulh_opcode != MulHOpcode::MULHU { + if opcode != MulHOpcode::MULHU { let b_sign_mask = if b_ext == 0 { 0 } else { 1 << (LIMB_BITS - 1) }; let c_sign_mask = if c_ext == 0 { 0 } else { 1 << (LIMB_BITS - 1) }; self.bitwise_lookup_chip.request_range( - (b[NUM_LIMBS - 1] - b_sign_mask) << 1, - (c[NUM_LIMBS - 1] - c_sign_mask) << ((mulh_opcode == MulHOpcode::MULH) as u32), + (record.b[NUM_LIMBS - 1] as u32 - b_sign_mask) << 1, + (record.c[NUM_LIMBS - 1] as u32 - c_sign_mask) + << ((opcode == MulHOpcode::MULH) as u32), ); } - let output = AdapterRuntimeContext::without_pc([a.map(F::from_canonical_u32)]); - let record = MulHCoreRecord { - opcode: mulh_opcode, - a: a.map(F::from_canonical_u32), - b: data[0], - c: data[1], - a_mul: a_mul.map(F::from_canonical_u32), - b_ext: F::from_canonical_u32(b_ext), - c_ext: F::from_canonical_u32(c_ext), + // Write in reverse order + core_row.opcode_mulhu_flag = F::from_bool(opcode == MulHOpcode::MULHU); + core_row.opcode_mulhsu_flag = F::from_bool(opcode == MulHOpcode::MULHSU); + core_row.opcode_mulh_flag = F::from_bool(opcode == MulHOpcode::MULH); + core_row.c_ext = F::from_canonical_u32(c_ext); + core_row.b_ext = F::from_canonical_u32(b_ext); + core_row.a_mul = a_mul.map(F::from_canonical_u32); + core_row.c = record.c.map(F::from_canonical_u8); + core_row.b = record.b.map(F::from_canonical_u8); + core_row.a = a.map(F::from_canonical_u32); + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct MulHPreCompute { + a: u8, + b: u8, + c: u8, +} + +impl Executor + for MulHExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } + + #[inline(always)] + fn pre_compute( + &self, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut MulHPreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_e1(inst, pre_compute)?; + let fn_ptr = match local_opcode { + MulHOpcode::MULH => execute_e1_impl::<_, _, MulHOp>, + MulHOpcode::MULHSU => execute_e1_impl::<_, _, MulHSuOp>, + MulHOpcode::MULHU => execute_e1_impl::<_, _, MulHUOp>, }; + Ok(fn_ptr) + } +} - Ok((output, record)) +impl MeteredExecutor + for MulHExecutor +where + F: PrimeField32, +{ + fn metered_pre_compute_size(&self) -> usize { + size_of::>() } - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - MulHOpcode::from_usize(opcode - MulHOpcode::CLASS_OFFSET) - ) + fn metered_pre_compute( + &self, + chip_idx: usize, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_e1(inst, &mut pre_compute.data)?; + let fn_ptr = match local_opcode { + MulHOpcode::MULH => execute_e2_impl::<_, _, MulHOp>, + MulHOpcode::MULHSU => execute_e2_impl::<_, _, MulHSuOp>, + MulHOpcode::MULHU => execute_e2_impl::<_, _, MulHUOp>, + }; + Ok(fn_ptr) } +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &MulHPreCompute, + vm_state: &mut VmExecState, +) { + let rs1: [u8; RV32_REGISTER_NUM_LIMBS] = + vm_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32); + let rs2: [u8; RV32_REGISTER_NUM_LIMBS] = + vm_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32); + let rd = ::compute(rs1, rs2); + vm_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd); + + vm_state.pc += DEFAULT_PC_STEP; + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &MulHPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut MulHCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut(); - row_slice.a = record.a; - row_slice.b = record.b; - row_slice.c = record.c; - row_slice.a_mul = record.a_mul; - row_slice.b_ext = record.b_ext; - row_slice.c_ext = record.c_ext; - row_slice.opcode_mulh_flag = F::from_bool(record.opcode == MulHOpcode::MULH); - row_slice.opcode_mulhsu_flag = F::from_bool(record.opcode == MulHOpcode::MULHSU); - row_slice.opcode_mulhu_flag = F::from_bool(record.opcode == MulHOpcode::MULHU); +impl MulHExecutor { + #[inline(always)] + fn pre_compute_e1( + &self, + inst: &Instruction, + data: &mut MulHPreCompute, + ) -> Result { + *data = MulHPreCompute { + a: inst.a.as_canonical_u32() as u8, + b: inst.b.as_canonical_u32() as u8, + c: inst.c.as_canonical_u32() as u8, + }; + Ok(MulHOpcode::from_usize( + inst.opcode.local_opcode_idx(MulHOpcode::CLASS_OFFSET), + )) } +} - fn air(&self) -> &Self::Air { - &self.air +trait MulHOperation { + fn compute(rs1: [u8; 4], rs1: [u8; 4]) -> [u8; 4]; +} +struct MulHOp; +struct MulHSuOp; +struct MulHUOp; +impl MulHOperation for MulHOp { + #[inline(always)] + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] { + let rs1 = i32::from_le_bytes(rs1) as i64; + let rs2 = i32::from_le_bytes(rs2) as i64; + ((rs1.wrapping_mul(rs2) >> 32) as u32).to_le_bytes() + } +} +impl MulHOperation for MulHSuOp { + #[inline(always)] + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] { + let rs1 = i32::from_le_bytes(rs1) as i64; + let rs2 = u32::from_le_bytes(rs2) as i64; + ((rs1.wrapping_mul(rs2) >> 32) as u32).to_le_bytes() + } +} +impl MulHOperation for MulHUOp { + #[inline(always)] + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] { + let rs1 = u32::from_le_bytes(rs1) as i64; + let rs2 = u32::from_le_bytes(rs2) as i64; + ((rs1.wrapping_mul(rs2) >> 32) as u32).to_le_bytes() } } // returns mulh[[s]u], mul, carry, x_ext, y_ext +#[inline(always)] pub(super) fn run_mulh( opcode: MulHOpcode, x: &[u32; NUM_LIMBS], diff --git a/extensions/rv32im/circuit/src/mulh/mod.rs b/extensions/rv32im/circuit/src/mulh/mod.rs index 284b77191a..26dd83fb31 100644 --- a/extensions/rv32im/circuit/src/mulh/mod.rs +++ b/extensions/rv32im/circuit/src/mulh/mod.rs @@ -1,6 +1,7 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; -use super::adapters::{Rv32MultAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use super::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use crate::adapters::{Rv32MultAdapterAir, Rv32MultAdapterExecutor, Rv32MultAdapterFiller}; mod core; pub use core::*; @@ -8,5 +9,9 @@ pub use core::*; #[cfg(test)] mod tests; +pub type Rv32MulHAir = + VmAirWrapper>; +pub type Rv32MulHExecutor = + MulHExecutor; pub type Rv32MulHChip = - VmChipWrapper, MulHCoreChip>; + VmChipWrapper>; diff --git a/extensions/rv32im/circuit/src/mulh/tests.rs b/extensions/rv32im/circuit/src/mulh/tests.rs index 1c7cf5b5cb..161e999987 100644 --- a/extensions/rv32im/circuit/src/mulh/tests.rs +++ b/extensions/rv32im/circuit/src/mulh/tests.rs @@ -1,21 +1,24 @@ -use std::borrow::BorrowMut; +use std::{borrow::BorrowMut, sync::Arc}; use openvm_circuit::{ - arch::{ - testing::{ - memory::gen_pointer, TestAdapterChip, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, - RANGE_TUPLE_CHECKER_BUS, - }, - ExecutionBridge, InstructionExecutor, VmAdapterChip, VmChipWrapper, + arch::testing::{ + memory::gen_pointer, TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, + RANGE_TUPLE_CHECKER_BUS, }, utils::generate_long_number, }; use openvm_circuit_primitives::{ - bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, - range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, + bitwise_op_lookup::{ + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, + }, + range_tuple::{ + RangeTupleCheckerAir, RangeTupleCheckerBus, RangeTupleCheckerChip, + SharedRangeTupleCheckerChip, + }, }; use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_rv32im_transpiler::MulHOpcode; +use openvm_rv32im_transpiler::MulHOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::FieldAlgebra, @@ -24,36 +27,90 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::rngs::StdRng; +use test_case::test_case; use super::core::run_mulh; use crate::{ - adapters::{Rv32MultAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, - mulh::{MulHCoreChip, MulHCoreCols, Rv32MulHChip}, + adapters::{ + Rv32MultAdapterAir, Rv32MultAdapterExecutor, Rv32MultAdapterFiller, RV32_CELL_BITS, + RV32_REGISTER_NUM_LIMBS, + }, + mulh::{MulHCoreCols, Rv32MulHChip}, + test_utils::get_verification_error, + MulHCoreAir, MulHFiller, Rv32MulHAir, Rv32MulHExecutor, }; +const MAX_INS_CAPACITY: usize = 128; +// the max number of limbs we currently support MUL for is 32 (i.e. for U256s) +const MAX_NUM_LIMBS: u32 = 32; type F = BabyBear; +type Harness = TestChipHarness>; -////////////////////////////////////////////////////////////////////////////////////// -// POSITIVE TESTS -// -// Randomly generate computations and execute, ensuring that the generated trace -// passes all constraints. -////////////////////////////////////////////////////////////////////////////////////// +fn create_test_chip( + tester: &mut VmChipTestBuilder, +) -> ( + Harness, + ( + BitwiseOperationLookupAir, + SharedBitwiseOperationLookupChip, + ), + (RangeTupleCheckerAir<2>, SharedRangeTupleCheckerChip<2>), +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let range_tuple_bus = RangeTupleCheckerBus::new( + RANGE_TUPLE_CHECKER_BUS, + [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], + ); + + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + let range_tuple_chip = + SharedRangeTupleCheckerChip::new(RangeTupleCheckerChip::<2>::new(range_tuple_bus)); + + let air = Rv32MulHAir::new( + Rv32MultAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + MulHCoreAir::new(bitwise_bus, range_tuple_bus), + ); + let executor = Rv32MulHExecutor::new(Rv32MultAdapterExecutor, MulHOpcode::CLASS_OFFSET); + let chip = Rv32MulHChip::::new( + MulHFiller::new( + Rv32MultAdapterFiller, + bitwise_chip.clone(), + range_tuple_chip.clone(), + ), + tester.memory_helper(), + ); + let harness = Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + + ( + harness, + (bitwise_chip.air, bitwise_chip), + (range_tuple_chip.air, range_tuple_chip), + ) +} #[allow(clippy::too_many_arguments)] -fn run_rv32_mulh_rand_write_execute>( - opcode: MulHOpcode, +fn set_and_execute( tester: &mut VmChipTestBuilder, - chip: &mut E, - b: [u32; RV32_REGISTER_NUM_LIMBS], - c: [u32; RV32_REGISTER_NUM_LIMBS], + harness: &mut Harness, rng: &mut StdRng, + opcode: MulHOpcode, + b: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + c: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, ) { + let b = b.unwrap_or(generate_long_number::< + RV32_REGISTER_NUM_LIMBS, + RV32_CELL_BITS, + >(rng)); + let c = c.unwrap_or(generate_long_number::< + RV32_REGISTER_NUM_LIMBS, + RV32_CELL_BITS, + >(rng)); + let rs1 = gen_pointer(rng, 4); let rs2 = gen_pointer(rng, 4); let rd = gen_pointer(rng, 4); @@ -61,160 +118,103 @@ fn run_rv32_mulh_rand_write_execute>( tester.write::(1, rs1, b.map(F::from_canonical_u32)); tester.write::(1, rs2, c.map(F::from_canonical_u32)); - let (a, _, _, _, _) = run_mulh::(opcode, &b, &c); tester.execute( - chip, + harness, &Instruction::from_usize(opcode.global_opcode(), [rd, rs1, rs2, 1, 0]), ); + let (a, _, _, _, _) = run_mulh::(opcode, &b, &c); assert_eq!( a.map(F::from_canonical_u32), tester.read::(1, rd) ); } +////////////////////////////////////////////////////////////////////////////////////// +// POSITIVE TESTS +// +// Randomly generate computations and execute, ensuring that the generated trace +// passes all constraints. +////////////////////////////////////////////////////////////////////////////////////// + +#[test_case(MULH, 100)] +#[test_case(MULHSU, 100)] +#[test_case(MULHU, 100)] fn run_rv32_mulh_rand_test(opcode: MulHOpcode, num_ops: usize) { - // the max number of limbs we currently support MUL for is 32 (i.e. for U256s) - const MAX_NUM_LIMBS: u32 = 32; let mut rng = create_seeded_rng(); - - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let range_tuple_bus = RangeTupleCheckerBus::new( - RANGE_TUPLE_CHECKER_BUS, - [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], - ); - - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let range_tuple_checker = SharedRangeTupleCheckerChip::new(range_tuple_bus); - let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32MulHChip::::new( - Rv32MultAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - MulHCoreChip::new(bitwise_chip.clone(), range_tuple_checker.clone()), - tester.offline_memory_mutex_arc(), - ); + let (mut harness, bitwise, range_tuple) = create_test_chip(&mut tester); for _ in 0..num_ops { - let b = generate_long_number::(&mut rng); - let c = generate_long_number::(&mut rng); - run_rv32_mulh_rand_write_execute(opcode, &mut tester, &mut chip, b, c, &mut rng); + set_and_execute(&mut tester, &mut harness, &mut rng, opcode, None, None); } let tester = tester .build() - .load(chip) - .load(bitwise_chip) - .load(range_tuple_checker) + .load(harness) + .load_periphery(bitwise) + .load_periphery(range_tuple) .finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn rv32_mulh_rand_test() { - run_rv32_mulh_rand_test(MulHOpcode::MULH, 100); -} - -#[test] -fn rv32_mulhsu_rand_test() { - run_rv32_mulh_rand_test(MulHOpcode::MULHSU, 100); -} - -#[test] -fn rv32_mulhu_rand_test() { - run_rv32_mulh_rand_test(MulHOpcode::MULHU, 100); -} - ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32MulHTestChip = - VmChipWrapper, MulHCoreChip>; - #[allow(clippy::too_many_arguments)] -fn run_rv32_mulh_negative_test( +fn run_negative_mulh_test( opcode: MulHOpcode, - a: [u32; RV32_REGISTER_NUM_LIMBS], + prank_a: [u32; RV32_REGISTER_NUM_LIMBS], b: [u32; RV32_REGISTER_NUM_LIMBS], c: [u32; RV32_REGISTER_NUM_LIMBS], - a_mul: [u32; RV32_REGISTER_NUM_LIMBS], - b_ext: u32, - c_ext: u32, + prank_a_mul: [u32; RV32_REGISTER_NUM_LIMBS], + prank_b_ext: u32, + prank_c_ext: u32, interaction_error: bool, ) { - const MAX_NUM_LIMBS: u32 = 32; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let range_tuple_bus = RangeTupleCheckerBus::new( - RANGE_TUPLE_CHECKER_BUS, - [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], - ); - - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let range_tuple_chip = SharedRangeTupleCheckerChip::new(range_tuple_bus); - + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32MulHTestChip::::new( - TestAdapterChip::new( - vec![[b.map(F::from_canonical_u32), c.map(F::from_canonical_u32)].concat()], - vec![None], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - MulHCoreChip::new(bitwise_chip.clone(), range_tuple_chip.clone()), - tester.offline_memory_mutex_arc(), - ); - - tester.execute( - &mut chip, - &Instruction::from_usize(opcode.global_opcode(), [0, 0, 0, 1, 0]), + let (mut harness, bitwise, range_tuple) = create_test_chip(&mut tester); + + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + opcode, + Some(b), + Some(c), ); - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); - let (_, _, carry, _, _) = run_mulh::(opcode, &b, &c); - - range_tuple_chip.clear(); - for i in 0..RV32_REGISTER_NUM_LIMBS { - range_tuple_chip.add_count(&[a_mul[i], carry[i]]); - range_tuple_chip.add_count(&[a[i], carry[RV32_REGISTER_NUM_LIMBS + i]]); - } - + let adapter_width = BaseAir::::width(&harness.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); let cols: &mut MulHCoreCols = values.split_at_mut(adapter_width).1.borrow_mut(); - cols.a = a.map(F::from_canonical_u32); - cols.a_mul = a_mul.map(F::from_canonical_u32); - cols.b_ext = F::from_canonical_u32(b_ext); - cols.c_ext = F::from_canonical_u32(c_ext); - *trace = RowMajorMatrix::new(values, trace_width); + cols.a = prank_a.map(F::from_canonical_u32); + cols.a_mul = prank_a_mul.map(F::from_canonical_u32); + cols.b_ext = F::from_canonical_u32(prank_b_ext); + cols.c_ext = F::from_canonical_u32(prank_c_ext); + *trace = RowMajorMatrix::new(values, trace.width()); }; disable_debug_builder(); let tester = tester .build() - .load_and_prank_trace(chip, modify_trace) - .load(bitwise_chip) - .load(range_tuple_chip) + .load_and_prank_trace(harness, modify_trace) + .load_periphery(bitwise) + .load_periphery(range_tuple) .finalize(); - tester.simple_test_with_expected_error(if interaction_error { - VerificationError::ChallengePhaseError - } else { - VerificationError::OodEvaluationMismatch - }); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] fn rv32_mulh_wrong_a_mul_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULH, + run_negative_mulh_test( + MULH, [130, 9, 135, 241], [197, 85, 150, 32], [51, 109, 78, 142], @@ -227,8 +227,8 @@ fn rv32_mulh_wrong_a_mul_negative_test() { #[test] fn rv32_mulh_wrong_a_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULH, + run_negative_mulh_test( + MULH, [130, 9, 135, 242], [197, 85, 150, 32], [51, 109, 78, 142], @@ -241,8 +241,8 @@ fn rv32_mulh_wrong_a_negative_test() { #[test] fn rv32_mulh_wrong_ext_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULH, + run_negative_mulh_test( + MULH, [1, 0, 0, 0], [0, 0, 0, 128], [2, 0, 0, 0], @@ -255,8 +255,8 @@ fn rv32_mulh_wrong_ext_negative_test() { #[test] fn rv32_mulh_invalid_ext_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULH, + run_negative_mulh_test( + MULH, [3, 2, 2, 2], [0, 0, 0, 128], [2, 0, 0, 0], @@ -269,8 +269,8 @@ fn rv32_mulh_invalid_ext_negative_test() { #[test] fn rv32_mulhsu_wrong_a_mul_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULHSU, + run_negative_mulh_test( + MULHSU, [174, 40, 246, 202], [197, 85, 150, 160], [51, 109, 78, 142], @@ -283,8 +283,8 @@ fn rv32_mulhsu_wrong_a_mul_negative_test() { #[test] fn rv32_mulhsu_wrong_a_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULHSU, + run_negative_mulh_test( + MULHSU, [174, 40, 246, 201], [197, 85, 150, 160], [51, 109, 78, 142], @@ -297,8 +297,8 @@ fn rv32_mulhsu_wrong_a_negative_test() { #[test] fn rv32_mulhsu_wrong_b_ext_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULHSU, + run_negative_mulh_test( + MULHSU, [1, 0, 0, 0], [0, 0, 0, 128], [2, 0, 0, 0], @@ -311,8 +311,8 @@ fn rv32_mulhsu_wrong_b_ext_negative_test() { #[test] fn rv32_mulhsu_wrong_c_ext_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULHSU, + run_negative_mulh_test( + MULHSU, [0, 0, 0, 64], [0, 0, 0, 128], [0, 0, 0, 128], @@ -325,8 +325,8 @@ fn rv32_mulhsu_wrong_c_ext_negative_test() { #[test] fn rv32_mulhu_wrong_a_mul_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULHU, + run_negative_mulh_test( + MULHU, [130, 9, 135, 241], [197, 85, 150, 32], [51, 109, 78, 142], @@ -339,8 +339,8 @@ fn rv32_mulhu_wrong_a_mul_negative_test() { #[test] fn rv32_mulhu_wrong_a_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULHU, + run_negative_mulh_test( + MULHU, [130, 9, 135, 240], [197, 85, 150, 32], [51, 109, 78, 142], @@ -353,8 +353,8 @@ fn rv32_mulhu_wrong_a_negative_test() { #[test] fn rv32_mulhu_wrong_ext_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULHU, + run_negative_mulh_test( + MULHU, [255, 255, 255, 255], [0, 0, 0, 128], [2, 0, 0, 0], @@ -380,7 +380,7 @@ fn run_mulh_sanity_test() { let c: [u32; RV32_REGISTER_NUM_LIMBS] = [303, 375, 449, 463]; let c_mul: [u32; RV32_REGISTER_NUM_LIMBS] = [39, 100, 126, 205]; let (res, res_mul, carry, x_ext, y_ext) = - run_mulh::(MulHOpcode::MULH, &x, &y); + run_mulh::(MULH, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], res[i]); assert_eq!(z_mul[i], res_mul[i]); @@ -400,7 +400,7 @@ fn run_mulhu_sanity_test() { let c: [u32; RV32_REGISTER_NUM_LIMBS] = [107, 93, 18, 0]; let c_mul: [u32; RV32_REGISTER_NUM_LIMBS] = [39, 100, 126, 205]; let (res, res_mul, carry, x_ext, y_ext) = - run_mulh::(MulHOpcode::MULHU, &x, &y); + run_mulh::(MULHU, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], res[i]); assert_eq!(z_mul[i], res_mul[i]); @@ -420,7 +420,7 @@ fn run_mulhsu_pos_sanity_test() { let c: [u32; RV32_REGISTER_NUM_LIMBS] = [107, 93, 18, 0]; let c_mul: [u32; RV32_REGISTER_NUM_LIMBS] = [39, 100, 126, 205]; let (res, res_mul, carry, x_ext, y_ext) = - run_mulh::(MulHOpcode::MULHSU, &x, &y); + run_mulh::(MULHSU, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], res[i]); assert_eq!(z_mul[i], res_mul[i]); @@ -440,7 +440,7 @@ fn run_mulhsu_neg_sanity_test() { let c: [u32; RV32_REGISTER_NUM_LIMBS] = [212, 292, 326, 379]; let c_mul: [u32; RV32_REGISTER_NUM_LIMBS] = [39, 100, 126, 231]; let (res, res_mul, carry, x_ext, y_ext) = - run_mulh::(MulHOpcode::MULHSU, &x, &y); + run_mulh::(MULHSU, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], res[i]); assert_eq!(z_mul[i], res_mul[i]); diff --git a/extensions/rv32im/circuit/src/shift/core.rs b/extensions/rv32im/circuit/src/shift/core.rs index cada97685e..5b74e8a547 100644 --- a/extensions/rv32im/circuit/src/shift/core.rs +++ b/extensions/rv32im/circuit/src/shift/core.rs @@ -3,17 +3,26 @@ use std::{ borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::*, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_IMM_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; use openvm_rv32im_transpiler::ShiftOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -21,10 +30,10 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; +use crate::adapters::imm_to_bytes; + #[repr(C)] #[derive(AlignedBorrow, Clone, Copy, Debug)] pub struct ShiftCoreCols { @@ -51,7 +60,10 @@ pub struct ShiftCoreCols { pub bit_shift_carry: [T; NUM_LIMBS], } -#[derive(Copy, Clone, Debug)] +/// RV32 shift AIR. +/// Note: when the shift amount from operand is greater than the number of bits, only shift +/// `shift_amount % num_bits` bits. This matches the RV32 specs for SLL/SRL/SRA. +#[derive(Copy, Clone, Debug, derive_new::new)] pub struct ShiftCoreAir { pub bitwise_lookup_bus: BitwiseOperationLookupBus, pub range_bus: VariableRangeCheckerBus, @@ -238,154 +250,370 @@ where } #[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "T: Serialize + DeserializeOwned")] -pub struct ShiftCoreRecord { - #[serde(with = "BigArray")] - pub a: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub c: [T; NUM_LIMBS], - pub b_sign: T, - #[serde(with = "BigArray")] - pub bit_shift_carry: [u32; NUM_LIMBS], - pub bit_shift: usize, - pub limb_shift: usize, - pub opcode: ShiftOpcode, +#[derive(AlignedBytesBorrow, Debug)] +pub struct ShiftCoreRecord { + pub b: [u8; NUM_LIMBS], + pub c: [u8; NUM_LIMBS], + pub local_opcode: u8, } -pub struct ShiftCoreChip { - pub air: ShiftCoreAir, +#[derive(Clone, Copy)] +pub struct ShiftExecutor { + adapter: A, + pub offset: usize, +} +#[derive(Clone)] +pub struct ShiftFiller { + adapter: A, + pub offset: usize, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, pub range_checker_chip: SharedVariableRangeCheckerChip, } -impl ShiftCoreChip { +impl ShiftExecutor { + pub fn new(adapter: A, offset: usize) -> Self { + assert_eq!(NUM_LIMBS % 2, 0, "Number of limbs must be divisible by 2"); + Self { adapter, offset } + } +} + +impl ShiftFiller { pub fn new( + adapter: A, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, range_checker_chip: SharedVariableRangeCheckerChip, offset: usize, ) -> Self { assert_eq!(NUM_LIMBS % 2, 0, "Number of limbs must be divisible by 2"); Self { - air: ShiftCoreAir { - bitwise_lookup_bus: bitwise_lookup_chip.bus(), - range_bus: range_checker_chip.bus(), - offset, - }, + adapter, + offset, bitwise_lookup_chip, range_checker_chip, } } } -impl, const NUM_LIMBS: usize, const LIMB_BITS: usize> - VmCoreChip for ShiftCoreChip +impl PreflightExecutor + for ShiftExecutor where - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: From<[[F; NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + AdapterTraceExecutor< + F, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + >, + for<'buf> RA: RecordArena< + 'buf, + EmptyAdapterCoreLayout, + ( + A::RecordMut<'buf>, + &'buf mut ShiftCoreRecord, + ), + >, { - type Record = ShiftCoreRecord; - type Air = ShiftCoreAir; + fn get_opcode_name(&self, opcode: usize) -> String { + format!("{:?}", ShiftOpcode::from_usize(opcode - self.offset)) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + ) -> Result<(), ExecutionError> { let Instruction { opcode, .. } = instruction; - let shift_opcode = ShiftOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let b = data[0].map(|x| x.as_canonical_u32()); - let c = data[1].map(|y| y.as_canonical_u32()); - let (a, limb_shift, bit_shift) = run_shift::(shift_opcode, &b, &c); + let local_opcode = ShiftOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - let bit_shift_carry = array::from_fn(|i| match shift_opcode { - ShiftOpcode::SLL => b[i] >> (LIMB_BITS - bit_shift), - _ => b[i] % (1 << bit_shift), - }); + let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new()); - let mut b_sign = 0; - if shift_opcode == ShiftOpcode::SRA { - b_sign = b[NUM_LIMBS - 1] >> (LIMB_BITS - 1); - self.bitwise_lookup_chip - .request_xor(b[NUM_LIMBS - 1], 1 << (LIMB_BITS - 1)); - } + A::start(*state.pc, state.memory, &mut adapter_record); - for i in 0..(NUM_LIMBS / 2) { - self.bitwise_lookup_chip - .request_range(a[i * 2], a[i * 2 + 1]); - } + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, &mut adapter_record) + .into(); - let output = AdapterRuntimeContext::without_pc([a.map(F::from_canonical_u32)]); - let record = ShiftCoreRecord { - opcode: shift_opcode, - a: a.map(F::from_canonical_u32), - b: data[0], - c: data[1], - bit_shift_carry, - bit_shift, - limb_shift, - b_sign: F::from_canonical_u32(b_sign), - }; + let (output, _, _) = run_shift::(local_opcode, &rs1, &rs2); - Ok((output, record)) - } + core_record.b = rs1; + core_record.c = rs2; + core_record.local_opcode = local_opcode as u8; - fn get_opcode_name(&self, opcode: usize) -> String { - format!("{:?}", ShiftOpcode::from_usize(opcode - self.air.offset)) + self.adapter.write( + state.memory, + instruction, + [output].into(), + &mut adapter_record, + ); + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) } +} + +impl TraceFiller + for ShiftFiller +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + + let record: &ShiftCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - for carry_val in record.bit_shift_carry { - self.range_checker_chip - .add_count(carry_val, record.bit_shift); + let core_row: &mut ShiftCoreCols = core_row.borrow_mut(); + + let opcode = ShiftOpcode::from_usize(record.local_opcode as usize); + let (a, limb_shift, bit_shift) = + run_shift::(opcode, &record.b, &record.c); + + for pair in a.chunks_exact(2) { + self.bitwise_lookup_chip + .request_range(pair[0] as u32, pair[1] as u32); } let num_bits_log = (NUM_LIMBS * LIMB_BITS).ilog2(); self.range_checker_chip.add_count( - (((record.c[0].as_canonical_u32() as usize) - - record.bit_shift - - record.limb_shift * LIMB_BITS) - >> num_bits_log) as u32, + ((record.c[0] as usize - bit_shift - limb_shift * LIMB_BITS) >> num_bits_log) as u32, LIMB_BITS - num_bits_log as usize, ); - let row_slice: &mut ShiftCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut(); - row_slice.a = record.a; - row_slice.b = record.b; - row_slice.c = record.c; - row_slice.bit_multiplier_left = match record.opcode { - ShiftOpcode::SLL => F::from_canonical_usize(1 << record.bit_shift), - _ => F::ZERO, + core_row.bit_shift_carry = if bit_shift == 0 { + for _ in 0..NUM_LIMBS { + self.range_checker_chip.add_count(0, 0); + } + [F::ZERO; NUM_LIMBS] + } else { + array::from_fn(|i| { + let carry = match opcode { + ShiftOpcode::SLL => record.b[i] >> (LIMB_BITS - bit_shift), + _ => record.b[i] % (1 << bit_shift), + }; + self.range_checker_chip.add_count(carry as u32, bit_shift); + F::from_canonical_u8(carry) + }) }; - row_slice.bit_multiplier_right = match record.opcode { + + core_row.limb_shift_marker = [F::ZERO; NUM_LIMBS]; + core_row.limb_shift_marker[limb_shift] = F::ONE; + core_row.bit_shift_marker = [F::ZERO; LIMB_BITS]; + core_row.bit_shift_marker[bit_shift] = F::ONE; + + core_row.b_sign = F::ZERO; + if opcode == ShiftOpcode::SRA { + core_row.b_sign = F::from_canonical_u8(record.b[NUM_LIMBS - 1] >> (LIMB_BITS - 1)); + self.bitwise_lookup_chip + .request_xor(record.b[NUM_LIMBS - 1] as u32, 1 << (LIMB_BITS - 1)); + } + + core_row.bit_multiplier_right = match opcode { ShiftOpcode::SLL => F::ZERO, - _ => F::from_canonical_usize(1 << record.bit_shift), + _ => F::from_canonical_usize(1 << bit_shift), + }; + core_row.bit_multiplier_left = match opcode { + ShiftOpcode::SLL => F::from_canonical_usize(1 << bit_shift), + _ => F::ZERO, }; - row_slice.b_sign = record.b_sign; - row_slice.bit_shift_marker = array::from_fn(|i| F::from_bool(i == record.bit_shift)); - row_slice.limb_shift_marker = array::from_fn(|i| F::from_bool(i == record.limb_shift)); - row_slice.bit_shift_carry = record.bit_shift_carry.map(F::from_canonical_u32); - row_slice.opcode_sll_flag = F::from_bool(record.opcode == ShiftOpcode::SLL); - row_slice.opcode_srl_flag = F::from_bool(record.opcode == ShiftOpcode::SRL); - row_slice.opcode_sra_flag = F::from_bool(record.opcode == ShiftOpcode::SRA); + + core_row.opcode_sra_flag = F::from_bool(opcode == ShiftOpcode::SRA); + core_row.opcode_srl_flag = F::from_bool(opcode == ShiftOpcode::SRL); + core_row.opcode_sll_flag = F::from_bool(opcode == ShiftOpcode::SLL); + + core_row.c = record.c.map(F::from_canonical_u8); + core_row.b = record.b.map(F::from_canonical_u8); + core_row.a = a.map(F::from_canonical_u8); + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct ShiftPreCompute { + c: u32, + a: u8, + b: u8, +} + +impl Executor + for ShiftExecutor +where + F: PrimeField32, +{ + fn pre_compute_size(&self) -> usize { + size_of::() + } + + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let data: &mut ShiftPreCompute = data.borrow_mut(); + let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, data)?; + // `d` is always expected to be RV32_REGISTER_AS. + let fn_ptr = match (is_imm, shift_opcode) { + (true, ShiftOpcode::SLL) => execute_e1_impl::<_, _, true, SllOp>, + (false, ShiftOpcode::SLL) => execute_e1_impl::<_, _, false, SllOp>, + (true, ShiftOpcode::SRL) => execute_e1_impl::<_, _, true, SrlOp>, + (false, ShiftOpcode::SRL) => execute_e1_impl::<_, _, false, SrlOp>, + (true, ShiftOpcode::SRA) => execute_e1_impl::<_, _, true, SraOp>, + (false, ShiftOpcode::SRA) => execute_e1_impl::<_, _, false, SraOp>, + }; + Ok(fn_ptr) + } +} + +impl MeteredExecutor + for ShiftExecutor +where + F: PrimeField32, +{ + fn metered_pre_compute_size(&self) -> usize { + size_of::>() } - fn air(&self) -> &Self::Air { - &self.air + #[inline(always)] + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, &mut data.data)?; + // `d` is always expected to be RV32_REGISTER_AS. + let fn_ptr = match (is_imm, shift_opcode) { + (true, ShiftOpcode::SLL) => execute_e2_impl::<_, _, true, SllOp>, + (false, ShiftOpcode::SLL) => execute_e2_impl::<_, _, false, SllOp>, + (true, ShiftOpcode::SRL) => execute_e2_impl::<_, _, true, SrlOp>, + (false, ShiftOpcode::SRL) => execute_e2_impl::<_, _, false, SrlOp>, + (true, ShiftOpcode::SRA) => execute_e2_impl::<_, _, true, SraOp>, + (false, ShiftOpcode::SRA) => execute_e2_impl::<_, _, false, SraOp>, + }; + Ok(fn_ptr) } } +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const IS_IMM: bool, + OP: ShiftOp, +>( + pre_compute: &ShiftPreCompute, + state: &mut VmExecState, +) { + let rs1 = state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let rs2 = if IS_IMM { + pre_compute.c.to_le_bytes() + } else { + state.vm_read::(RV32_REGISTER_AS, pre_compute.c) + }; + let rs2 = u32::from_le_bytes(rs2); + + // Execute the shift operation + let rd = ::compute(rs1, rs2); + // Write the result back to memory + state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd); + + state.instret += 1; + state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + state: &mut VmExecState, +) { + let pre_compute: &ShiftPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + state.ctx.on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, state); +} + +impl ShiftExecutor { + #[inline(always)] + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut ShiftPreCompute, + ) -> Result<(bool, ShiftOpcode), StaticProgramError> { + let Instruction { + opcode, a, b, c, e, .. + } = inst; + let shift_opcode = ShiftOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + let e_u32 = e.as_canonical_u32(); + if inst.d.as_canonical_u32() != RV32_REGISTER_AS + || !(e_u32 == RV32_IMM_AS || e_u32 == RV32_REGISTER_AS) + { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + let is_imm = e_u32 == RV32_IMM_AS; + let c_u32 = c.as_canonical_u32(); + *data = ShiftPreCompute { + c: if is_imm { + u32::from_le_bytes(imm_to_bytes(c_u32)) + } else { + c_u32 + }, + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + }; + // `d` is always expected to be RV32_REGISTER_AS. + Ok((is_imm, shift_opcode)) + } +} + +trait ShiftOp { + fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4]; +} +struct SllOp; +struct SrlOp; +struct SraOp; +impl ShiftOp for SllOp { + fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] { + let rs1 = u32::from_le_bytes(rs1); + // `rs2`'s other bits are ignored. + (rs1 << (rs2 & 0x1F)).to_le_bytes() + } +} +impl ShiftOp for SrlOp { + fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] { + let rs1 = u32::from_le_bytes(rs1); + // `rs2`'s other bits are ignored. + (rs1 >> (rs2 & 0x1F)).to_le_bytes() + } +} +impl ShiftOp for SraOp { + fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] { + let rs1 = i32::from_le_bytes(rs1); + // `rs2`'s other bits are ignored. + (rs1 >> (rs2 & 0x1F)).to_le_bytes() + } +} + +// Returns (result, limb_shift, bit_shift) +#[inline(always)] pub(super) fn run_shift( opcode: ShiftOpcode, - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> ([u32; NUM_LIMBS], usize, usize) { + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> ([u8; NUM_LIMBS], usize, usize) { match opcode { ShiftOpcode::SLL => run_shift_left::(x, y), ShiftOpcode::SRL => run_shift_right::(x, y, true), @@ -393,53 +621,60 @@ pub(super) fn run_shift( } } +#[inline(always)] fn run_shift_left( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> ([u32; NUM_LIMBS], usize, usize) { - let mut result = [0u32; NUM_LIMBS]; + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> ([u8; NUM_LIMBS], usize, usize) { + let mut result = [0u8; NUM_LIMBS]; let (limb_shift, bit_shift) = get_shift::(y); for i in limb_shift..NUM_LIMBS { result[i] = if i > limb_shift { - ((x[i - limb_shift] << bit_shift) + (x[i - limb_shift - 1] >> (LIMB_BITS - bit_shift))) - % (1 << LIMB_BITS) + (((x[i - limb_shift] as u16) << bit_shift) + | ((x[i - limb_shift - 1] as u16) >> (LIMB_BITS - bit_shift))) + % (1u16 << LIMB_BITS) } else { - (x[i - limb_shift] << bit_shift) % (1 << LIMB_BITS) - }; + ((x[i - limb_shift] as u16) << bit_shift) % (1u16 << LIMB_BITS) + } as u8; } (result, limb_shift, bit_shift) } +#[inline(always)] fn run_shift_right( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], logical: bool, -) -> ([u32; NUM_LIMBS], usize, usize) { +) -> ([u8; NUM_LIMBS], usize, usize) { let fill = if logical { 0 } else { - ((1 << LIMB_BITS) - 1) * (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1)) + (((1u16 << LIMB_BITS) - 1) as u8) * (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1)) }; let mut result = [fill; NUM_LIMBS]; let (limb_shift, bit_shift) = get_shift::(y); for i in 0..(NUM_LIMBS - limb_shift) { - result[i] = if i + limb_shift + 1 < NUM_LIMBS { - ((x[i + limb_shift] >> bit_shift) + (x[i + limb_shift + 1] << (LIMB_BITS - bit_shift))) - % (1 << LIMB_BITS) + let res = if i + limb_shift + 1 < NUM_LIMBS { + (((x[i + limb_shift] >> bit_shift) as u16) + | ((x[i + limb_shift + 1] as u16) << (LIMB_BITS - bit_shift))) + % (1u16 << LIMB_BITS) } else { - ((x[i + limb_shift] >> bit_shift) + (fill << (LIMB_BITS - bit_shift))) - % (1 << LIMB_BITS) - } + (((x[i + limb_shift] >> bit_shift) as u16) | ((fill as u16) << (LIMB_BITS - bit_shift))) + % (1u16 << LIMB_BITS) + }; + result[i] = res as u8; } (result, limb_shift, bit_shift) } -fn get_shift(y: &[u32]) -> (usize, usize) { - // We assume `NUM_LIMBS * LIMB_BITS <= 2^LIMB_BITS` so so the shift is defined +#[inline(always)] +fn get_shift(y: &[u8]) -> (usize, usize) { + debug_assert!(NUM_LIMBS * LIMB_BITS <= (1 << LIMB_BITS)); + // We assume `NUM_LIMBS * LIMB_BITS <= 2^LIMB_BITS` so the shift is defined // entirely in y[0]. let shift = (y[0] as usize) % (NUM_LIMBS * LIMB_BITS); (shift / LIMB_BITS, shift % LIMB_BITS) diff --git a/extensions/rv32im/circuit/src/shift/mod.rs b/extensions/rv32im/circuit/src/shift/mod.rs index 58d5ad022b..3f585a5773 100644 --- a/extensions/rv32im/circuit/src/shift/mod.rs +++ b/extensions/rv32im/circuit/src/shift/mod.rs @@ -1,6 +1,9 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; -use super::adapters::{Rv32BaseAluAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use super::adapters::{ + Rv32BaseAluAdapterAir, Rv32BaseAluAdapterExecutor, Rv32BaseAluAdapterFiller, RV32_CELL_BITS, + RV32_REGISTER_NUM_LIMBS, +}; mod core; pub use core::*; @@ -8,8 +11,14 @@ pub use core::*; #[cfg(test)] mod tests; +pub type Rv32ShiftAir = + VmAirWrapper>; +pub type Rv32ShiftExecutor = ShiftExecutor< + Rv32BaseAluAdapterExecutor, + RV32_REGISTER_NUM_LIMBS, + RV32_CELL_BITS, +>; pub type Rv32ShiftChip = VmChipWrapper< F, - Rv32BaseAluAdapterChip, - ShiftCoreChip, + ShiftFiller, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>, >; diff --git a/extensions/rv32im/circuit/src/shift/tests.rs b/extensions/rv32im/circuit/src/shift/tests.rs index 7a3ef6e72c..e1051a164b 100644 --- a/extensions/rv32im/circuit/src/shift/tests.rs +++ b/extensions/rv32im/circuit/src/shift/tests.rs @@ -1,17 +1,12 @@ -use std::{array, borrow::BorrowMut}; +use std::{array, borrow::BorrowMut, sync::Arc}; -use openvm_circuit::{ - arch::{ - testing::{TestAdapterChip, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - ExecutionBridge, VmAdapterChip, VmChipWrapper, - }, - utils::generate_long_number, -}; +use openvm_circuit::arch::testing::{TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, }; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_rv32im_transpiler::ShiftOpcode; +use openvm_instructions::LocalOpcode; +use openvm_rv32im_transpiler::ShiftOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::FieldAlgebra, @@ -20,108 +15,147 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::Rng; +use rand::{rngs::StdRng, Rng}; +use test_case::test_case; -use super::{core::run_shift, Rv32ShiftChip, ShiftCoreChip}; +use super::{core::run_shift, Rv32ShiftChip, ShiftCoreAir, ShiftCoreCols}; use crate::{ - adapters::{Rv32BaseAluAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, - shift::ShiftCoreCols, - test_utils::{generate_rv32_is_type_immediate, rv32_rand_write_register_or_imm}, + adapters::{ + Rv32BaseAluAdapterAir, Rv32BaseAluAdapterExecutor, Rv32BaseAluAdapterFiller, + RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + }, + test_utils::{ + generate_rv32_is_type_immediate, get_verification_error, rv32_rand_write_register_or_imm, + }, + Rv32ShiftAir, Rv32ShiftExecutor, ShiftFiller, }; type F = BabyBear; - -////////////////////////////////////////////////////////////////////////////////////// -// POSITIVE TESTS -// -// Randomly generate computations and execute, ensuring that the generated trace -// passes all constraints. -////////////////////////////////////////////////////////////////////////////////////// - -fn run_rv32_shift_rand_test(opcode: ShiftOpcode, num_ops: usize) { - let mut rng = create_seeded_rng(); +const MAX_INS_CAPACITY: usize = 128; +type Harness = TestChipHarness>; + +fn create_test_chip( + tester: &VmChipTestBuilder, +) -> ( + Harness, + ( + BitwiseOperationLookupAir, + SharedBitwiseOperationLookupChip, + ), +) { + let range_checker = tester.range_checker().clone(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); - let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32ShiftChip::::new( - Rv32BaseAluAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), + let air = Rv32ShiftAir::new( + Rv32BaseAluAdapterAir::new( + tester.execution_bridge(), tester.memory_bridge(), - bitwise_chip.clone(), + bitwise_bus, ), - ShiftCoreChip::new( + ShiftCoreAir::new(bitwise_bus, range_checker.bus(), ShiftOpcode::CLASS_OFFSET), + ); + let executor = Rv32ShiftExecutor::new(Rv32BaseAluAdapterExecutor, ShiftOpcode::CLASS_OFFSET); + let chip = Rv32ShiftChip::::new( + ShiftFiller::new( + Rv32BaseAluAdapterFiller::new(bitwise_chip.clone()), bitwise_chip.clone(), - tester.memory_controller().borrow().range_checker.clone(), + range_checker.clone(), ShiftOpcode::CLASS_OFFSET, ), - tester.offline_memory_mutex_arc(), + tester.memory_helper(), ); + let harness = Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); - for _ in 0..num_ops { - let b = generate_long_number::(&mut rng); - let (c_imm, c) = if rng.gen_bool(0.5) { - ( - None, - generate_long_number::(&mut rng), - ) + (harness, (bitwise_chip.air, bitwise_chip)) +} + +#[allow(clippy::too_many_arguments)] +fn set_and_execute( + tester: &mut VmChipTestBuilder, + harness: &mut Harness, + rng: &mut StdRng, + opcode: ShiftOpcode, + b: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + is_imm: Option, + c: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, +) { + let b = b.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); + let (c_imm, c) = if is_imm.unwrap_or(rng.gen_bool(0.5)) { + let (imm, c) = if let Some(c) = c { + ((u32::from_le_bytes(c) & 0xFFFFFF) as usize, c) } else { - let (imm, c) = generate_rv32_is_type_immediate(&mut rng); - (Some(imm), c) + generate_rv32_is_type_immediate(rng) }; + (Some(imm), c) + } else { + ( + None, + c.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))), + ) + }; + let (instruction, rd) = rv32_rand_write_register_or_imm( + tester, + b, + c, + c_imm, + opcode.global_opcode().as_usize(), + rng, + ); + tester.execute(harness, &instruction); + + let (a, _, _) = run_shift::(opcode, &b, &c); + assert_eq!( + a.map(F::from_canonical_u8), + tester.read::(1, rd) + ) +} + +////////////////////////////////////////////////////////////////////////////////////// +// POSITIVE TESTS +// +// Randomly generate computations and execute, ensuring that the generated trace +// passes all constraints. +////////////////////////////////////////////////////////////////////////////////////// +#[test_case(SLL, 100)] +#[test_case(SRL, 100)] +#[test_case(SRA, 100)] +fn run_rv32_shift_rand_test(opcode: ShiftOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut harness, bitwise_chip) = create_test_chip(&tester); - let (instruction, rd) = rv32_rand_write_register_or_imm( + for _ in 0..num_ops { + set_and_execute( &mut tester, - b, - c, - c_imm, - opcode.global_opcode().as_usize(), + &mut harness, &mut rng, + opcode, + None, + None, + None, ); - tester.execute(&mut chip, &instruction); - - let (a, _, _) = run_shift::(opcode, &b, &c); - assert_eq!( - a.map(F::from_canonical_u32), - tester.read::(1, rd) - ) } - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + let tester = tester + .build() + .load(harness) + .load_periphery(bitwise_chip) + .finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn rv32_shift_sll_rand_test() { - run_rv32_shift_rand_test(ShiftOpcode::SLL, 100); -} - -#[test] -fn rv32_shift_srl_rand_test() { - run_rv32_shift_rand_test(ShiftOpcode::SRL, 100); -} - -#[test] -fn rv32_shift_sra_rand_test() { - run_rv32_shift_rand_test(ShiftOpcode::SRA, 100); -} - ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32ShiftTestChip = - VmChipWrapper, ShiftCoreChip>; - #[derive(Clone, Copy, Default, PartialEq)] struct ShiftPrankValues { pub bit_shift: Option, @@ -134,63 +168,35 @@ struct ShiftPrankValues { } #[allow(clippy::too_many_arguments)] -fn run_rv32_shift_negative_test( +fn run_negative_shift_test( opcode: ShiftOpcode, - a: [u32; RV32_REGISTER_NUM_LIMBS], - b: [u32; RV32_REGISTER_NUM_LIMBS], - c: [u32; RV32_REGISTER_NUM_LIMBS], + prank_a: [u32; RV32_REGISTER_NUM_LIMBS], + b: [u8; RV32_REGISTER_NUM_LIMBS], + c: [u8; RV32_REGISTER_NUM_LIMBS], prank_vals: ShiftPrankValues, interaction_error: bool, ) { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let mut rng = create_seeded_rng(); let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let mut chip = Rv32ShiftTestChip::::new( - TestAdapterChip::new( - vec![[b.map(F::from_canonical_u32), c.map(F::from_canonical_u32)].concat()], - vec![None], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - ShiftCoreChip::new( - bitwise_chip.clone(), - range_checker_chip.clone(), - ShiftOpcode::CLASS_OFFSET, - ), - tester.offline_memory_mutex_arc(), - ); - - tester.execute( - &mut chip, - &Instruction::from_usize(opcode.global_opcode(), [0, 0, 0, 1, 1]), + let (mut harness, bitwise) = create_test_chip(&tester); + + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + opcode, + Some(b), + Some(false), + Some(c), ); - let bit_shift = prank_vals - .bit_shift - .unwrap_or(c[0] % (RV32_CELL_BITS as u32)); - let bit_shift_carry = prank_vals - .bit_shift_carry - .unwrap_or(array::from_fn(|i| match opcode { - ShiftOpcode::SLL => b[i] >> ((RV32_CELL_BITS as u32) - bit_shift), - _ => b[i] % (1 << bit_shift), - })); - - range_checker_chip.clear(); - range_checker_chip.add_count(bit_shift, RV32_CELL_BITS.ilog2() as usize); - for (a_val, carry_val) in a.iter().zip(bit_shift_carry.iter()) { - range_checker_chip.add_count(*a_val, RV32_CELL_BITS); - range_checker_chip.add_count(*carry_val, bit_shift as usize); - } - - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); - + let adapter_width = BaseAir::::width(&harness.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); let cols: &mut ShiftCoreCols = values.split_at_mut(adapter_width).1.borrow_mut(); - cols.a = a.map(F::from_canonical_u32); + cols.a = prank_a.map(F::from_canonical_u32); if let Some(bit_multiplier_left) = prank_vals.bit_multiplier_left { cols.bit_multiplier_left = F::from_canonical_u32(bit_multiplier_left); } @@ -210,21 +216,16 @@ fn run_rv32_shift_negative_test( cols.bit_shift_carry = bit_shift_carry.map(F::from_canonical_u32); } - *trace = RowMajorMatrix::new(values, trace_width); + *trace = RowMajorMatrix::new(values, trace.width()); }; - drop(range_checker_chip); disable_debug_builder(); let tester = tester .build() - .load_and_prank_trace(chip, modify_trace) - .load(bitwise_chip) + .load_and_prank_trace(harness, modify_trace) + .load_periphery(bitwise) .finalize(); - tester.simple_test_with_expected_error(if interaction_error { - VerificationError::ChallengePhaseError - } else { - VerificationError::OodEvaluationMismatch - }); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -233,9 +234,9 @@ fn rv32_shift_wrong_negative_test() { let b = [1, 0, 0, 0]; let c = [1, 0, 0, 0]; let prank_vals = Default::default(); - run_rv32_shift_negative_test(ShiftOpcode::SLL, a, b, c, prank_vals, false); - run_rv32_shift_negative_test(ShiftOpcode::SRL, a, b, c, prank_vals, false); - run_rv32_shift_negative_test(ShiftOpcode::SRA, a, b, c, prank_vals, false); + run_negative_shift_test(SLL, a, b, c, prank_vals, false); + run_negative_shift_test(SRL, a, b, c, prank_vals, false); + run_negative_shift_test(SRA, a, b, c, prank_vals, false); } #[test] @@ -249,7 +250,7 @@ fn rv32_sll_wrong_bit_shift_negative_test() { bit_shift_marker: Some([0, 0, 1, 0, 0, 0, 0, 0]), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SLL, a, b, c, prank_vals, true); + run_negative_shift_test(SLL, a, b, c, prank_vals, true); } #[test] @@ -261,7 +262,7 @@ fn rv32_sll_wrong_limb_shift_negative_test() { limb_shift_marker: Some([0, 0, 1, 0]), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SLL, a, b, c, prank_vals, true); + run_negative_shift_test(SLL, a, b, c, prank_vals, true); } #[test] @@ -273,7 +274,7 @@ fn rv32_sll_wrong_bit_carry_negative_test() { bit_shift_carry: Some([0, 0, 0, 0]), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SLL, a, b, c, prank_vals, true); + run_negative_shift_test(SLL, a, b, c, prank_vals, true); } #[test] @@ -286,7 +287,7 @@ fn rv32_sll_wrong_bit_mult_side_negative_test() { bit_multiplier_right: Some(1), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SLL, a, b, c, prank_vals, false); + run_negative_shift_test(SLL, a, b, c, prank_vals, false); } #[test] @@ -300,7 +301,7 @@ fn rv32_srl_wrong_bit_shift_negative_test() { bit_shift_marker: Some([0, 0, 1, 0, 0, 0, 0, 0]), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SRL, a, b, c, prank_vals, false); + run_negative_shift_test(SRL, a, b, c, prank_vals, false); } #[test] @@ -312,7 +313,7 @@ fn rv32_srl_wrong_limb_shift_negative_test() { limb_shift_marker: Some([0, 1, 0, 0]), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SRL, a, b, c, prank_vals, false); + run_negative_shift_test(SRL, a, b, c, prank_vals, false); } #[test] @@ -325,8 +326,8 @@ fn rv32_srx_wrong_bit_mult_side_negative_test() { bit_multiplier_right: Some(0), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SRL, a, b, c, prank_vals, false); - run_rv32_shift_negative_test(ShiftOpcode::SRA, a, b, c, prank_vals, false); + run_negative_shift_test(SRL, a, b, c, prank_vals, false); + run_negative_shift_test(SRA, a, b, c, prank_vals, false); } #[test] @@ -340,7 +341,7 @@ fn rv32_sra_wrong_bit_shift_negative_test() { bit_shift_marker: Some([0, 0, 1, 0, 0, 0, 0, 0]), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SRA, a, b, c, prank_vals, false); + run_negative_shift_test(SRA, a, b, c, prank_vals, false); } #[test] @@ -352,7 +353,7 @@ fn rv32_sra_wrong_limb_shift_negative_test() { limb_shift_marker: Some([0, 1, 0, 0]), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SRA, a, b, c, prank_vals, false); + run_negative_shift_test(SRA, a, b, c, prank_vals, false); } #[test] @@ -364,7 +365,7 @@ fn rv32_sra_wrong_sign_negative_test() { b_sign: Some(0), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SRA, a, b, c, prank_vals, true); + run_negative_shift_test(SRA, a, b, c, prank_vals, true); } /////////////////////////////////////////////////////////////////////////////////////// @@ -375,11 +376,11 @@ fn rv32_sra_wrong_sign_negative_test() { #[test] fn run_sll_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [45, 7, 61, 186]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [91, 0, 100, 0]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [0, 0, 0, 104]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [45, 7, 61, 186]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [91, 0, 100, 0]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [0, 0, 0, 104]; let (result, limb_shift, bit_shift) = - run_shift::(ShiftOpcode::SLL, &x, &y); + run_shift::(SLL, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } @@ -390,11 +391,11 @@ fn run_sll_sanity_test() { #[test] fn run_srl_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [31, 190, 221, 200]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [49, 190, 190, 190]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [110, 100, 0, 0]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [31, 190, 221, 200]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [49, 190, 190, 190]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [110, 100, 0, 0]; let (result, limb_shift, bit_shift) = - run_shift::(ShiftOpcode::SRL, &x, &y); + run_shift::(SRL, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } @@ -405,11 +406,11 @@ fn run_srl_sanity_test() { #[test] fn run_sra_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [31, 190, 221, 200]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [113, 20, 50, 80]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [110, 228, 255, 255]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [31, 190, 221, 200]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [113, 20, 50, 80]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [110, 228, 255, 255]; let (result, limb_shift, bit_shift) = - run_shift::(ShiftOpcode::SRA, &x, &y); + run_shift::(SRA, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } diff --git a/extensions/rv32im/circuit/src/test_utils.rs b/extensions/rv32im/circuit/src/test_utils.rs index 8a105ff990..f018b0d845 100644 --- a/extensions/rv32im/circuit/src/test_utils.rs +++ b/extensions/rv32im/circuit/src/test_utils.rs @@ -1,6 +1,6 @@ use openvm_circuit::arch::testing::{memory::gen_pointer, VmChipTestBuilder}; use openvm_instructions::{instruction::Instruction, VmOpcode}; -use openvm_stark_backend::p3_field::FieldAlgebra; +use openvm_stark_backend::{p3_field::FieldAlgebra, verifier::VerificationError}; use openvm_stark_sdk::p3_baby_bear::BabyBear; use rand::{rngs::StdRng, Rng}; @@ -10,8 +10,8 @@ use super::adapters::{RV32_REGISTER_NUM_LIMBS, RV_IS_TYPE_IMM_BITS}; #[cfg_attr(all(feature = "test-utils", not(test)), allow(dead_code))] pub fn rv32_rand_write_register_or_imm( tester: &mut VmChipTestBuilder, - rs1_writes: [u32; NUM_LIMBS], - rs2_writes: [u32; NUM_LIMBS], + rs1_writes: [u8; NUM_LIMBS], + rs2_writes: [u8; NUM_LIMBS], imm: Option, opcode_with_offset: usize, rng: &mut StdRng, @@ -22,9 +22,9 @@ pub fn rv32_rand_write_register_or_imm( let rs2 = imm.unwrap_or_else(|| gen_pointer(rng, NUM_LIMBS)); let rd = gen_pointer(rng, NUM_LIMBS); - tester.write::(1, rs1, rs1_writes.map(BabyBear::from_canonical_u32)); + tester.write::(1, rs1, rs1_writes.map(BabyBear::from_canonical_u8)); if !rs2_is_imm { - tester.write::(1, rs2, rs2_writes.map(BabyBear::from_canonical_u32)); + tester.write::(1, rs2, rs2_writes.map(BabyBear::from_canonical_u8)); } ( @@ -37,9 +37,7 @@ pub fn rv32_rand_write_register_or_imm( } #[cfg_attr(all(feature = "test-utils", not(test)), allow(dead_code))] -pub fn generate_rv32_is_type_immediate( - rng: &mut StdRng, -) -> (usize, [u32; RV32_REGISTER_NUM_LIMBS]) { +pub fn generate_rv32_is_type_immediate(rng: &mut StdRng) -> (usize, [u8; RV32_REGISTER_NUM_LIMBS]) { let mut imm: u32 = rng.gen_range(0..(1 << RV_IS_TYPE_IMM_BITS)); if (imm & 0x800) != 0 { imm |= !0xFFF @@ -51,7 +49,17 @@ pub fn generate_rv32_is_type_immediate( (imm >> 8) as u8, (imm >> 16) as u8, (imm >> 16) as u8, - ] - .map(|x| x as u32), + ], ) } + +/// Returns the corresponding verification error based on whether +/// an interaction error or a constraint error is expected +#[cfg_attr(all(feature = "test-utils", not(test)), allow(dead_code))] +pub fn get_verification_error(is_interaction_error: bool) -> VerificationError { + if is_interaction_error { + VerificationError::ChallengePhaseError + } else { + VerificationError::OodEvaluationMismatch + } +} diff --git a/extensions/rv32im/tests/Cargo.toml b/extensions/rv32im/tests/Cargo.toml index 2e68359532..45eb4c1654 100644 --- a/extensions/rv32im/tests/Cargo.toml +++ b/extensions/rv32im/tests/Cargo.toml @@ -20,6 +20,7 @@ openvm-toolchain-tests = { path = "../../../crates/toolchain/tests" } eyre.workspace = true test-case.workspace = true serde = { workspace = true, features = ["alloc"] } +strum.workspace = true [features] default = ["parallel"] diff --git a/extensions/rv32im/tests/src/lib.rs b/extensions/rv32im/tests/src/lib.rs index a4de516462..7902b38aa8 100644 --- a/extensions/rv32im/tests/src/lib.rs +++ b/extensions/rv32im/tests/src/lib.rs @@ -5,14 +5,17 @@ mod tests { use eyre::Result; use openvm_circuit::{ arch::{hasher::poseidon2::vm_poseidon2_hasher, ExecutionError, Streams, VmExecutor}, - system::memory::tree::public_values::UserPublicValuesProof, - utils::{air_test, air_test_with_min_segments}, + system::memory::merkle::public_values::UserPublicValuesProof, + utils::{air_test, air_test_with_min_segments, test_system_config_with_continuations}, }; - use openvm_instructions::exe::VmExe; - use openvm_rv32im_circuit::{Rv32IConfig, Rv32ImConfig}; + use openvm_instructions::{exe::VmExe, instruction::Instruction, LocalOpcode, SystemOpcode}; + #[cfg(test)] + use openvm_rv32im_circuit::Rv32ImCpuBuilder; + use openvm_rv32im_circuit::{Rv32IConfig, Rv32ICpuBuilder, Rv32ImConfig}; use openvm_rv32im_guest::hint_load_by_key_encode; use openvm_rv32im_transpiler::{ - Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, + DivRemOpcode, MulHOpcode, MulOpcode, Rv32ITranspilerExtension, Rv32IoTranspilerExtension, + Rv32MTranspilerExtension, }; use openvm_stark_sdk::{openvm_stark_backend::p3_field::FieldAlgebra, p3_baby_bear::BabyBear}; use openvm_toolchain_tests::{ @@ -20,28 +23,42 @@ mod tests { get_programs_dir, }; use openvm_transpiler::{transpiler::Transpiler, FromElf}; + use strum::IntoEnumIterator; use test_case::test_case; type F = BabyBear; + #[cfg(test)] + fn test_rv32im_config() -> Rv32ImConfig { + Rv32ImConfig { + rv32i: Rv32IConfig { + system: test_system_config_with_continuations(), + ..Default::default() + }, + ..Default::default() + } + } + #[test_case("fibonacci", 1)] fn test_rv32i(example_name: &str, min_segments: usize) -> Result<()> { let config = Rv32IConfig::default(); let elf = build_example_program_at_path(get_programs_dir!(), example_name, &config)?; - let exe = VmExe::from_elf( + let mut exe = VmExe::from_elf( elf, Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension), )?; - air_test_with_min_segments(config, exe, vec![], min_segments); + change_rv32m_insn_to_nop(&mut exe); + air_test_with_min_segments(Rv32ICpuBuilder, config, exe, vec![], min_segments); Ok(()) } + #[test_case("fibonacci", 1)] #[test_case("collatz", 1)] fn test_rv32im(example_name: &str, min_segments: usize) -> Result<()> { - let config = Rv32ImConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path(get_programs_dir!(), example_name, &config)?; let exe = VmExe::from_elf( elf, @@ -50,14 +67,14 @@ mod tests { .with_extension(Rv32IoTranspilerExtension) .with_extension(Rv32MTranspilerExtension), )?; - air_test_with_min_segments(config, exe, vec![], min_segments); + air_test_with_min_segments(Rv32ImCpuBuilder, config, exe, vec![], min_segments); Ok(()) } - // #[test_case("fibonacci", 1)] + #[test_case("fibonacci", 1)] #[test_case("collatz", 1)] fn test_rv32im_std(example_name: &str, min_segments: usize) -> Result<()> { - let config = Rv32ImConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path_with_features( get_programs_dir!(), example_name, @@ -71,13 +88,13 @@ mod tests { .with_extension(Rv32IoTranspilerExtension) .with_extension(Rv32MTranspilerExtension), )?; - air_test_with_min_segments(config, exe, vec![], min_segments); + air_test_with_min_segments(Rv32ImCpuBuilder, config, exe, vec![], min_segments); Ok(()) } #[test] fn test_read_vec() -> Result<()> { - let config = Rv32IConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path(get_programs_dir!(), "hint", &config)?; let exe = VmExe::from_elf( elf, @@ -87,13 +104,13 @@ mod tests { .with_extension(Rv32IoTranspilerExtension), )?; let input = vec![[0, 1, 2, 3].map(F::from_canonical_u8).to_vec()]; - air_test_with_min_segments(config, exe, input, 1); + air_test_with_min_segments(Rv32ImCpuBuilder, config, exe, input, 1); Ok(()) } #[test] fn test_hint_load_by_key() -> Result<()> { - let config = Rv32IConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path(get_programs_dir!(), "hint_load_by_key", &config)?; let exe = VmExe::from_elf( elf, @@ -110,13 +127,13 @@ mod tests { "key".as_bytes().to_vec(), hint_load_by_key_encode(&input), )])); - air_test_with_min_segments(config, exe, streams, 1); + air_test_with_min_segments(Rv32ImCpuBuilder, config, exe, streams, 1); Ok(()) } #[test] fn test_read() -> Result<()> { - let config = Rv32IConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path(get_programs_dir!(), "read", &config)?; let exe = VmExe::from_elf( elf, @@ -141,13 +158,13 @@ mod tests { .flat_map(|w| w.to_le_bytes()) .map(F::from_canonical_u8) .collect(); - air_test_with_min_segments(config, exe, vec![input], 1); + air_test_with_min_segments(Rv32ImCpuBuilder, config, exe, vec![input], 1); Ok(()) } #[test] fn test_reveal() -> Result<()> { - let config = Rv32IConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path(get_programs_dir!(), "reveal", &config)?; let exe = VmExe::from_elf( elf, @@ -156,11 +173,14 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension), )?; - let executor = VmExecutor::::new(config.clone()); - let final_memory = executor.execute(exe, vec![])?.unwrap(); - let hasher = vm_poseidon2_hasher(); + + let executor = VmExecutor::new(config.clone())?; + let instance = executor.instance(&exe)?; + let state = instance.execute(vec![], None)?; + let final_memory = state.memory.memory; + let hasher = vm_poseidon2_hasher::(); let pv_proof = UserPublicValuesProof::compute( - config.system.memory_config.memory_dimensions(), + config.as_ref().memory_config.memory_dimensions(), 64, &hasher, &final_memory, @@ -186,7 +206,7 @@ mod tests { #[test] fn test_print() -> Result<()> { - let config = Rv32IConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path(get_programs_dir!(), "print", &config)?; let exe = VmExe::from_elf( elf, @@ -195,13 +215,13 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension), )?; - air_test(config, exe); + air_test(Rv32ImCpuBuilder, config, exe); Ok(()) } #[test] fn test_heap_overflow() -> Result<()> { - let config = Rv32ImConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path(get_programs_dir!(), "heap_overflow", &config)?; let exe = VmExe::from_elf( elf, @@ -211,8 +231,10 @@ mod tests { .with_extension(Rv32IoTranspilerExtension), )?; - let executor = VmExecutor::::new(config.clone()); - match executor.execute(exe, vec![[0, 0, 0, 1].map(F::from_canonical_u8).to_vec()]) { + let executor = VmExecutor::new(config)?; + let instance = executor.instance(&exe)?; + let input = vec![[0, 0, 0, 1].map(F::from_canonical_u8).to_vec()]; + match instance.execute(input.clone(), None) { Err(ExecutionError::FailedWithExitCode(_)) => Ok(()), Err(_) => panic!("should fail with `FailedWithExitCode`"), Ok(_) => panic!("should fail"), @@ -221,7 +243,7 @@ mod tests { #[test] fn test_hashmap() -> Result<()> { - let config = Rv32ImConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path_with_features( get_programs_dir!(), "hashmap", @@ -235,13 +257,13 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension), )?; - air_test(config, exe); + air_test(Rv32ImCpuBuilder, config, exe); Ok(()) } #[test] fn test_tiny_mem_test() -> Result<()> { - let config = Rv32ImConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path_with_features( get_programs_dir!(), "tiny-mem-test", @@ -255,14 +277,14 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension), )?; - air_test(config, exe); + air_test(Rv32ImCpuBuilder, config, exe); Ok(()) } #[test] #[should_panic] fn test_load_x0() { - let config = Rv32ImConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path(get_programs_dir!(), "load_x0", &config).unwrap(); let exe = VmExe::from_elf( elf, @@ -272,8 +294,9 @@ mod tests { .with_extension(Rv32IoTranspilerExtension), ) .unwrap(); - let executor = VmExecutor::::new(config.clone()); - executor.execute(exe, vec![]).unwrap(); + let executor = VmExecutor::new(config).unwrap(); + let instance = executor.instance(&exe).unwrap(); + instance.execute(vec![], None).unwrap(); } #[test_case("getrandom", vec!["getrandom", "getrandom-unsupported"])] @@ -281,7 +304,7 @@ mod tests { #[test_case("getrandom_v02", vec!["getrandom-v02", "getrandom-unsupported"])] #[test_case("getrandom_v02", vec!["getrandom-v02/custom"])] fn test_getrandom_unsupported(program: &str, features: Vec<&str>) { - let config = Rv32ImConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path_with_features( get_programs_dir!(), program, @@ -297,6 +320,26 @@ mod tests { .with_extension(Rv32IoTranspilerExtension), ) .unwrap(); - air_test(config, exe); + air_test(Rv32ImCpuBuilder, config, exe); + } + + // For testing programs that should only execute RV32I: + // The ELF might still have Mul instructions even though the program doesn't use them. We + // mask those to NOP here. + fn change_rv32m_insn_to_nop(exe: &mut VmExe) { + for (insn, _) in exe + .program + .instructions_and_debug_infos + .iter_mut() + .flatten() + { + if MulOpcode::iter().any(|op| op.global_opcode() == insn.opcode) + || MulHOpcode::iter().any(|op| op.global_opcode() == insn.opcode) + || DivRemOpcode::iter().any(|op| op.global_opcode() == insn.opcode) + { + *insn = Instruction::default(); + insn.opcode = SystemOpcode::PHANTOM.global_opcode(); + } + } } } diff --git a/extensions/sha256/circuit/Cargo.toml b/extensions/sha256/circuit/Cargo.toml index 95c87b0871..109978f016 100644 --- a/extensions/sha256/circuit/Cargo.toml +++ b/extensions/sha256/circuit/Cargo.toml @@ -9,7 +9,6 @@ description = "OpenVM circuit extension for sha256" openvm-stark-backend = { workspace = true } openvm-stark-sdk = { workspace = true } openvm-circuit-primitives = { workspace = true } -openvm-circuit-primitives-derive = { workspace = true } openvm-circuit-derive = { workspace = true } openvm-circuit = { workspace = true } openvm-instructions = { workspace = true } diff --git a/extensions/sha256/circuit/src/extension.rs b/extensions/sha256/circuit/src/extension.rs index 783bc54f63..e58dc3c29b 100644 --- a/extensions/sha256/circuit/src/extension.rs +++ b/extensions/sha256/circuit/src/extension.rs @@ -1,105 +1,128 @@ +use std::{result::Result, sync::Arc}; + use derive_more::derive::From; use openvm_circuit::{ arch::{ - InitFileGenerator, SystemConfig, VmExtension, VmInventory, VmInventoryBuilder, - VmInventoryError, + AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, + ExecutorInventoryBuilder, ExecutorInventoryError, RowMajorMatrixArena, VmCircuitExtension, + VmExecutionExtension, VmProverExtension, }, - system::phantom::PhantomChip, + system::memory::SharedMemoryHelper, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; +use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor}; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::*; -use openvm_rv32im_circuit::{ - Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, - Rv32MExecutor, Rv32MPeriphery, -}; use openvm_sha256_transpiler::Rv32Sha256Opcode; -use openvm_stark_backend::p3_field::PrimeField32; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + p3_field::PrimeField32, + prover::cpu::{CpuBackend, CpuDevice}, +}; +use openvm_stark_sdk::engine::StarkEngine; use serde::{Deserialize, Serialize}; use strum::IntoEnumIterator; use crate::*; -#[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] -pub struct Sha256Rv32Config { - #[system] - pub system: SystemConfig, - #[extension] - pub rv32i: Rv32I, - #[extension] - pub rv32m: Rv32M, - #[extension] - pub io: Rv32Io, - #[extension] - pub sha256: Sha256, +// =================================== VM Extension Implementation ================================= +#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] +pub struct Sha256; + +#[derive(Clone, From, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)] +pub enum Sha256Executor { + Sha256(Sha256VmExecutor), } -impl Default for Sha256Rv32Config { - fn default() -> Self { - Self { - system: SystemConfig::default().with_continuations(), - rv32i: Rv32I, - rv32m: Rv32M::default(), - io: Rv32Io, - sha256: Sha256, - } +impl VmExecutionExtension for Sha256 { + type Executor = Sha256Executor; + + fn extend_execution( + &self, + inventory: &mut ExecutorInventoryBuilder, + ) -> Result<(), ExecutorInventoryError> { + let pointer_max_bits = inventory.pointer_max_bits(); + let sha256_step = Sha256VmExecutor::new(Rv32Sha256Opcode::CLASS_OFFSET, pointer_max_bits); + inventory.add_executor( + sha256_step, + Rv32Sha256Opcode::iter().map(|x| x.global_opcode()), + )?; + + Ok(()) } } -// Default implementation uses no init file -impl InitFileGenerator for Sha256Rv32Config {} +impl VmCircuitExtension for Sha256 { + fn extend_circuit(&self, inventory: &mut AirInventory) -> Result<(), AirInventoryError> { + let pointer_max_bits = inventory.pointer_max_bits(); -#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] -pub struct Sha256; + let bitwise_lu = { + let existing_air = inventory.find_air::>().next(); + if let Some(air) = existing_air { + air.bus + } else { + let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx()); + let air = BitwiseOperationLookupAir::<8>::new(bus); + inventory.add_air(air); + air.bus + } + }; -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] -pub enum Sha256Executor { - Sha256(Sha256VmChip), -} + let sha256 = Sha256VmAir::new( + inventory.system().port(), + bitwise_lu, + pointer_max_bits, + inventory.new_bus_idx(), + ); + inventory.add_air(sha256); -#[derive(From, ChipUsageGetter, Chip, AnyEnum)] -pub enum Sha256Periphery { - BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), - Phantom(PhantomChip), + Ok(()) + } } -impl VmExtension for Sha256 { - type Executor = Sha256Executor; - type Periphery = Sha256Periphery; - - fn build( +pub struct Sha2CpuProverExt; +// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker, +// BitwiseOperationLookupChip) are specific to CpuBackend. +impl VmProverExtension for Sha2CpuProverExt +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + RA: RowMajorMatrixArena>, + Val: PrimeField32, +{ + fn extend_prover( &self, - builder: &mut VmInventoryBuilder, - ) -> Result, VmInventoryError> { - let mut inventory = VmInventory::new(); - let bitwise_lu_chip = if let Some(&chip) = builder - .find_chip::>() - .first() - { - chip.clone() - } else { - let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); - let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); - inventory.add_periphery_chip(chip.clone()); - chip + _: &Sha256, + inventory: &mut ChipInventory>, + ) -> Result<(), ChipInventoryError> { + let range_checker = inventory.range_checker()?.clone(); + let timestamp_max_bits = inventory.timestamp_max_bits(); + let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits); + let pointer_max_bits = inventory.airs().pointer_max_bits(); + + let bitwise_lu = { + let existing_chip = inventory + .find_chip::>() + .next(); + if let Some(chip) = existing_chip { + chip.clone() + } else { + let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?; + let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus)); + inventory.add_periphery_chip(chip.clone()); + chip + } }; - let sha256_chip = Sha256VmChip::new( - builder.system_port(), - builder.system_config().memory_config.pointer_max_bits, - bitwise_lu_chip, - builder.new_bus_idx(), - Rv32Sha256Opcode::CLASS_OFFSET, - builder.system_base().offline_memory(), + inventory.next_air::()?; + let sha256 = Sha256VmChip::new( + Sha256VmFiller::new(bitwise_lu, pointer_max_bits), + mem_helper, ); - inventory.add_executor( - sha256_chip, - Rv32Sha256Opcode::iter().map(|x| x.global_opcode()), - )?; + inventory.add_executor_chip(sha256); - Ok(inventory) + Ok(()) } } diff --git a/extensions/sha256/circuit/src/lib.rs b/extensions/sha256/circuit/src/lib.rs index fe0844f902..58b3ff0f13 100644 --- a/extensions/sha256/circuit/src/lib.rs +++ b/extensions/sha256/circuit/src/lib.rs @@ -1,5 +1,87 @@ +use std::result::Result; + +use openvm_circuit::{ + arch::{ + AirInventory, ChipInventoryError, InitFileGenerator, MatrixRecordArena, SystemConfig, + VmBuilder, VmChipComplex, VmProverExtension, + }, + system::{SystemChipInventory, SystemCpuBuilder, SystemExecutor}, +}; +use openvm_circuit_derive::VmConfig; +use openvm_rv32im_circuit::{ + Rv32I, Rv32IExecutor, Rv32ImCpuProverExt, Rv32Io, Rv32IoExecutor, Rv32M, Rv32MExecutor, +}; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + p3_field::PrimeField32, + prover::cpu::{CpuBackend, CpuDevice}, +}; +use openvm_stark_sdk::engine::StarkEngine; +use serde::{Deserialize, Serialize}; + mod sha256_chip; pub use sha256_chip::*; mod extension; pub use extension::*; + +#[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] +pub struct Sha256Rv32Config { + #[config(executor = "SystemExecutor")] + pub system: SystemConfig, + #[extension] + pub rv32i: Rv32I, + #[extension] + pub rv32m: Rv32M, + #[extension] + pub io: Rv32Io, + #[extension] + pub sha256: Sha256, +} + +impl Default for Sha256Rv32Config { + fn default() -> Self { + Self { + system: SystemConfig::default().with_continuations(), + rv32i: Rv32I, + rv32m: Rv32M::default(), + io: Rv32Io, + sha256: Sha256, + } + } +} + +// Default implementation uses no init file +impl InitFileGenerator for Sha256Rv32Config {} + +#[derive(Clone)] +pub struct Sha256Rv32CpuBuilder; + +impl VmBuilder for Sha256Rv32CpuBuilder +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + Val: PrimeField32, +{ + type VmConfig = Sha256Rv32Config; + type SystemChipInventory = SystemChipInventory; + type RecordArena = MatrixRecordArena>; + + fn create_chip_complex( + &self, + config: &Sha256Rv32Config, + circuit: AirInventory, + ) -> Result< + VmChipComplex, + ChipInventoryError, + > { + let mut chip_complex = + VmBuilder::::create_chip_complex(&SystemCpuBuilder, &config.system, circuit)?; + let inventory = &mut chip_complex.inventory; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.rv32i, inventory)?; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.rv32m, inventory)?; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.io, inventory)?; + VmProverExtension::::extend_prover(&Sha2CpuProverExt, &config.sha256, inventory)?; + Ok(chip_complex) + } +} diff --git a/extensions/sha256/circuit/src/sha256_chip/air.rs b/extensions/sha256/circuit/src/sha256_chip/air.rs index f4f1df34eb..2fe1cb26c0 100644 --- a/extensions/sha256/circuit/src/sha256_chip/air.rs +++ b/extensions/sha256/circuit/src/sha256_chip/air.rs @@ -2,7 +2,10 @@ use std::{array, borrow::Borrow, cmp::min}; use openvm_circuit::{ arch::ExecutionBridge, - system::memory::{offline_checker::MemoryBridge, MemoryAddress}, + system::{ + memory::{offline_checker::MemoryBridge, MemoryAddress}, + SystemPort, + }, }; use openvm_circuit_primitives::{ bitwise_op_lookup::BitwiseOperationLookupBus, encoder::Encoder, utils::not, SubAir, @@ -17,7 +20,7 @@ use openvm_sha256_air::{ }; use openvm_sha256_transpiler::Rv32Sha256Opcode; use openvm_stark_backend::{ - interaction::InteractionBuilder, + interaction::{BusIndex, InteractionBuilder}, p3_air::{Air, AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra}, p3_matrix::Matrix, @@ -31,7 +34,7 @@ use super::{ /// Sha256VmAir does all constraints related to message padding and /// the Sha256Air subair constrains the actual hash -#[derive(Clone, Debug, derive_new::new)] +#[derive(Clone, Debug)] pub struct Sha256VmAir { pub execution_bridge: ExecutionBridge, pub memory_bridge: MemoryBridge, @@ -44,6 +47,28 @@ pub struct Sha256VmAir { pub(super) padding_encoder: Encoder, } +impl Sha256VmAir { + pub fn new( + SystemPort { + execution_bus, + program_bus, + memory_bridge, + }: SystemPort, + bitwise_lookup_bus: BitwiseOperationLookupBus, + ptr_max_bits: usize, + self_bus_idx: BusIndex, + ) -> Self { + Self { + execution_bridge: ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + bitwise_lookup_bus, + ptr_max_bits, + sha256_subair: Sha256Air::new(bitwise_lookup_bus, self_bus_idx), + padding_encoder: Encoder::new(PaddingFlags::COUNT, 2, false), + } + } +} + impl BaseAirWithPublicValues for Sha256VmAir {} impl PartitionedBaseAir for Sha256VmAir {} impl BaseAir for Sha256VmAir { diff --git a/extensions/sha256/circuit/src/sha256_chip/mod.rs b/extensions/sha256/circuit/src/sha256_chip/mod.rs index 4c40eca5d8..54b07c8bff 100644 --- a/extensions/sha256/circuit/src/sha256_chip/mod.rs +++ b/extensions/sha256/circuit/src/sha256_chip/mod.rs @@ -1,16 +1,11 @@ //! Sha256 hasher. Handles full sha256 hashing with padding. //! variable length inputs read from VM memory. -use std::{ - array, - cmp::{max, min}, - sync::{Arc, Mutex}, -}; -use openvm_circuit::arch::{ - ExecutionBridge, ExecutionError, ExecutionState, InstructionExecutor, SystemPort, -}; +use std::borrow::{Borrow, BorrowMut}; + +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; use openvm_circuit_primitives::{ - bitwise_op_lookup::SharedBitwiseOperationLookupChip, encoder::Encoder, + bitwise_op_lookup::SharedBitwiseOperationLookupChip, encoder::Encoder, AlignedBytesBorrow, }; use openvm_instructions::{ instruction::Instruction, @@ -18,11 +13,11 @@ use openvm_instructions::{ riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS}, LocalOpcode, }; -use openvm_rv32im_circuit::adapters::read_rv32_register; -use openvm_sha256_air::{Sha256Air, SHA256_BLOCK_BITS}; +use openvm_sha256_air::{ + get_sha256_num_blocks, Sha256FillerHelper, SHA256_BLOCK_BITS, SHA256_ROWS_PER_BLOCK, +}; use openvm_sha256_transpiler::Rv32Sha256Opcode; -use openvm_stark_backend::{interaction::BusIndex, p3_field::PrimeField32}; -use serde::{Deserialize, Serialize}; +use openvm_stark_backend::p3_field::PrimeField32; use sha2::{Digest, Sha256}; mod air; @@ -31,7 +26,7 @@ mod trace; pub use air::*; pub use columns::*; -use openvm_circuit::system::memory::{MemoryController, OfflineMemory, RecordId}; +pub use trace::*; #[cfg(test)] mod tests; @@ -47,65 +42,155 @@ const SHA256_WRITE_SIZE: usize = 32; pub const SHA256_BLOCK_CELLS: usize = SHA256_BLOCK_BITS / RV32_CELL_BITS; /// Number of rows we will do a read on for each SHA256 block pub const SHA256_NUM_READ_ROWS: usize = SHA256_BLOCK_CELLS / SHA256_READ_SIZE; -pub struct Sha256VmChip { - pub air: Sha256VmAir, - /// IO and memory data necessary for each opcode call - pub records: Vec>, - pub offline_memory: Arc>>, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, - - offset: usize, +/// Maximum message length that this chip supports in bytes +pub const SHA256_MAX_MESSAGE_LEN: usize = 1 << 29; + +pub type Sha256VmChip = VmChipWrapper; + +#[derive(derive_new::new, Clone)] +pub struct Sha256VmExecutor { + pub offset: usize, + pub pointer_max_bits: usize, } -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct Sha256Record { - pub from_state: ExecutionState, - pub dst_read: RecordId, - pub src_read: RecordId, - pub len_read: RecordId, - pub input_records: Vec<[RecordId; SHA256_NUM_READ_ROWS]>, - pub input_message: Vec<[[u8; SHA256_READ_SIZE]; SHA256_NUM_READ_ROWS]>, - pub digest_write: RecordId, +pub struct Sha256VmFiller { + pub inner: Sha256FillerHelper, + pub padding_encoder: Encoder, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pub pointer_max_bits: usize, } -impl Sha256VmChip { +impl Sha256VmFiller { pub fn new( - SystemPort { - execution_bus, - program_bus, - memory_bridge, - }: SystemPort, - address_bits: usize, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, - self_bus_idx: BusIndex, - offset: usize, - offline_memory: Arc>>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pointer_max_bits: usize, ) -> Self { Self { - air: Sha256VmAir::new( - ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bitwise_lookup_chip.bus(), - address_bits, - Sha256Air::new(bitwise_lookup_chip.bus(), self_bus_idx), - Encoder::new(PaddingFlags::COUNT, 2, false), - ), + inner: Sha256FillerHelper::new(), + padding_encoder: Encoder::new(PaddingFlags::COUNT, 2, false), bitwise_lookup_chip, - records: Vec::new(), - offset, - offline_memory, + pointer_max_bits, } } } -impl InstructionExecutor for Sha256VmChip { - fn execute( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { - let &Instruction { +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct ShaPreCompute { + a: u8, + b: u8, + c: u8, +} + +impl Executor for Sha256VmExecutor { + fn pre_compute_size(&self) -> usize { + size_of::() + } + + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E1ExecutionCtx, + { + let data: &mut ShaPreCompute = data.borrow_mut(); + self.pre_compute_impl(pc, inst, data)?; + Ok(execute_e1_impl::<_, _>) + } +} +impl MeteredExecutor for Sha256VmExecutor { + fn metered_pre_compute_size(&self) -> usize { + size_of::>() + } + + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + self.pre_compute_impl(pc, inst, &mut data.data)?; + Ok(execute_e2_impl::<_, _>) + } +} + +unsafe fn execute_e12_impl( + pre_compute: &ShaPreCompute, + vm_state: &mut VmExecState, +) -> u32 { + let dst = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32); + let src = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32); + let len = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32); + let dst_u32 = u32::from_le_bytes(dst); + let src_u32 = u32::from_le_bytes(src); + let len_u32 = u32::from_le_bytes(len); + + let (output, height) = if IS_E1 { + // SAFETY: RV32_MEMORY_AS is memory address space of type u8 + let message = vm_state.vm_read_slice(RV32_MEMORY_AS, src_u32, len_u32 as usize); + let output = sha256_solve(message); + (output, 0) + } else { + let num_blocks = get_sha256_num_blocks(len_u32); + let mut message = Vec::with_capacity(len_u32 as usize); + for block_idx in 0..num_blocks as usize { + // Reads happen on the first 4 rows of each block + for row in 0..SHA256_NUM_READ_ROWS { + let read_idx = block_idx * SHA256_NUM_READ_ROWS + row; + let row_input: [u8; SHA256_READ_SIZE] = vm_state.vm_read( + RV32_MEMORY_AS, + src_u32 + (read_idx * SHA256_READ_SIZE) as u32, + ); + message.extend_from_slice(&row_input); + } + } + let output = sha256_solve(&message[..len_u32 as usize]); + let height = num_blocks * SHA256_ROWS_PER_BLOCK as u32; + (output, height) + }; + vm_state.vm_write(RV32_MEMORY_AS, dst_u32, &output); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; + + height +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &ShaPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + let height = execute_e12_impl::(&pre_compute.data, vm_state); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height); +} + +impl Sha256VmExecutor { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut ShaPreCompute, + ) -> Result<(), StaticProgramError> { + let Instruction { opcode, a, b, @@ -113,87 +198,18 @@ impl InstructionExecutor for Sha256VmChip { d, e, .. - } = instruction; - let local_opcode = opcode.local_opcode_idx(self.offset); - debug_assert_eq!(local_opcode, Rv32Sha256Opcode::SHA256.local_usize()); - debug_assert_eq!(d, F::from_canonical_u32(RV32_REGISTER_AS)); - debug_assert_eq!(e, F::from_canonical_u32(RV32_MEMORY_AS)); - - debug_assert_eq!(from_state.timestamp, memory.timestamp()); - - let (dst_read, dst) = read_rv32_register(memory, d, a); - let (src_read, src) = read_rv32_register(memory, d, b); - let (len_read, len) = read_rv32_register(memory, d, c); - - #[cfg(debug_assertions)] - { - assert!(dst < (1 << self.air.ptr_max_bits)); - assert!(src < (1 << self.air.ptr_max_bits)); - assert!(len < (1 << self.air.ptr_max_bits)); - } - - // need to pad with one 1 bit, 64 bits for the message length and then pad until the length - // is divisible by [SHA256_BLOCK_BITS] - let num_blocks = ((len << 3) as usize + 1 + 64).div_ceil(SHA256_BLOCK_BITS); - - // we will read [num_blocks] * [SHA256_BLOCK_CELLS] cells but only [len] cells will be used - debug_assert!( - src as usize + num_blocks * SHA256_BLOCK_CELLS <= (1 << self.air.ptr_max_bits) - ); - let mut hasher = Sha256::new(); - let mut input_records = Vec::with_capacity(num_blocks * SHA256_NUM_READ_ROWS); - let mut input_message = Vec::with_capacity(num_blocks * SHA256_NUM_READ_ROWS); - let mut read_ptr = src; - for _ in 0..num_blocks { - let block_reads_records = array::from_fn(|i| { - memory.read( - e, - F::from_canonical_u32(read_ptr + (i * SHA256_READ_SIZE) as u32), - ) - }); - let block_reads_bytes = array::from_fn(|i| { - // we add to the hasher only the bytes that are part of the message - let num_reads = min( - SHA256_READ_SIZE, - (max(read_ptr, src + len) - read_ptr) as usize, - ); - let row_input = block_reads_records[i] - .1 - .map(|x| x.as_canonical_u32().try_into().unwrap()); - hasher.update(&row_input[..num_reads]); - read_ptr += SHA256_READ_SIZE as u32; - row_input - }); - input_records.push(block_reads_records.map(|x| x.0)); - input_message.push(block_reads_bytes); + } = inst; + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); } - - let mut digest = [0u8; SHA256_WRITE_SIZE]; - digest.copy_from_slice(hasher.finalize().as_ref()); - let (digest_write, _) = memory.write( - e, - F::from_canonical_u32(dst), - digest.map(|b| F::from_canonical_u8(b)), - ); - - self.records.push(Sha256Record { - from_state: from_state.map(F::from_canonical_u32), - dst_read, - src_read, - len_read, - input_records, - input_message, - digest_write, - }); - - Ok(ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }) - } - - fn get_opcode_name(&self, _: usize) -> String { - "SHA256".to_string() + *data = ShaPreCompute { + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + c: c.as_canonical_u32() as u8, + }; + assert_eq!(&Rv32Sha256Opcode::SHA256.global_opcode(), opcode); + Ok(()) } } diff --git a/extensions/sha256/circuit/src/sha256_chip/tests.rs b/extensions/sha256/circuit/src/sha256_chip/tests.rs index 55bc076e2c..3a2a533c39 100644 --- a/extensions/sha256/circuit/src/sha256_chip/tests.rs +++ b/extensions/sha256/circuit/src/sha256_chip/tests.rs @@ -1,31 +1,74 @@ -use openvm_circuit::arch::{ - testing::{memory::gen_pointer, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - SystemPort, +use std::{array, sync::Arc}; + +use openvm_circuit::{ + arch::{ + testing::{memory::gen_pointer, TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, + Arena, DenseRecordArena, MatrixRecordArena, PreflightExecutor, + }, + utils::get_random_message, }; use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, }; use openvm_instructions::{instruction::Instruction, riscv::RV32_CELL_BITS, LocalOpcode}; -use openvm_sha256_air::get_random_message; use openvm_sha256_transpiler::Rv32Sha256Opcode::{self, *}; use openvm_stark_backend::{interaction::BusIndex, p3_field::FieldAlgebra}; use openvm_stark_sdk::{config::setup_tracing, p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; -use super::Sha256VmChip; -use crate::{sha256_solve, Sha256VmDigestCols, Sha256VmRoundCols}; +use super::{Sha256VmAir, Sha256VmChip, Sha256VmExecutor}; +use crate::{ + sha256_chip::trace::Sha256VmRecordLayout, sha256_solve, Sha256VmDigestCols, Sha256VmFiller, + Sha256VmRoundCols, +}; type F = BabyBear; -const BUS_IDX: BusIndex = 28; -fn set_and_execute( +const SELF_BUS_IDX: BusIndex = 28; +const MAX_INS_CAPACITY: usize = 4096; +type Harness = TestChipHarness, RA>; + +fn create_test_chips( tester: &mut VmChipTestBuilder, - chip: &mut Sha256VmChip, +) -> ( + Harness, + ( + BitwiseOperationLookupAir, + SharedBitwiseOperationLookupChip, + ), +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + + let air = Sha256VmAir::new( + tester.system_port(), + bitwise_chip.bus(), + tester.address_bits(), + SELF_BUS_IDX, + ); + let executor = Sha256VmExecutor::new(Rv32Sha256Opcode::CLASS_OFFSET, tester.address_bits()); + let chip = Sha256VmChip::new( + Sha256VmFiller::new(bitwise_chip.clone(), tester.address_bits()), + tester.memory_helper(), + ); + + let harness = Harness::::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + (harness, (bitwise_chip.air, bitwise_chip)) +} + +fn set_and_execute( + tester: &mut VmChipTestBuilder, + harness: &mut Harness, rng: &mut StdRng, opcode: Rv32Sha256Opcode, message: Option<&[u8]>, len: Option, -) { - let len = len.unwrap_or(rng.gen_range(1..100000)); +) where + Sha256VmExecutor: PreflightExecutor, +{ + let len = len.unwrap_or(rng.gen_range(1..3000)); let tmp = get_random_message(rng, len); let message: &[u8] = message.unwrap_or(&tmp); let len = message.len(); @@ -34,12 +77,7 @@ fn set_and_execute( let rs1 = gen_pointer(rng, 4); let rs2 = gen_pointer(rng, 4); - let max_mem_ptr: u32 = 1 - << tester - .memory_controller() - .borrow() - .mem_config() - .pointer_max_bits; + let max_mem_ptr: u32 = 1 << tester.address_bits(); let dst_ptr = rng.gen_range(0..max_mem_ptr); let dst_ptr = dst_ptr ^ (dst_ptr & 3); tester.write(1, rd, dst_ptr.to_le_bytes().map(F::from_canonical_u8)); @@ -48,12 +86,17 @@ fn set_and_execute( tester.write(1, rs1, src_ptr.to_le_bytes().map(F::from_canonical_u8)); tester.write(1, rs2, len.to_le_bytes().map(F::from_canonical_u8)); - for (i, &byte) in message.iter().enumerate() { - tester.write(2, src_ptr as usize + i, [F::from_canonical_u8(byte)]); - } + message.chunks(4).enumerate().for_each(|(i, chunk)| { + let chunk: [&u8; 4] = array::from_fn(|i| chunk.get(i).unwrap_or(&0)); + tester.write( + 2, + src_ptr as usize + i * 4, + chunk.map(|&x| F::from_canonical_u8(x)), + ); + }); tester.execute( - chip, + harness, &Instruction::from_usize(opcode.global_opcode(), [rd, rs1, rs2, 1, 2]), ); @@ -75,27 +118,18 @@ fn rand_sha256_test() { setup_tracing(); let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut chip = Sha256VmChip::new( - SystemPort { - execution_bus: tester.execution_bus(), - program_bus: tester.program_bus(), - memory_bridge: tester.memory_bridge(), - }, - tester.address_bits(), - bitwise_chip.clone(), - BUS_IDX, - Rv32Sha256Opcode::CLASS_OFFSET, - tester.offline_memory_mutex_arc(), - ); + let (mut harness, bitwise) = create_test_chips(&mut tester); - let num_tests: usize = 3; - for _ in 0..num_tests { - set_and_execute(&mut tester, &mut chip, &mut rng, SHA256, None, None); + let num_ops: usize = 10; + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut harness, &mut rng, SHA256, None, None); } - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + let tester = tester + .build() + .load(harness) + .load_periphery(bitwise) + .finalize(); tester.simple_test().expect("Verification failed"); } @@ -108,20 +142,7 @@ fn rand_sha256_test() { fn execute_roundtrip_sanity_test() { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut chip = Sha256VmChip::new( - SystemPort { - execution_bus: tester.execution_bus(), - program_bus: tester.program_bus(), - memory_bridge: tester.memory_bridge(), - }, - tester.address_bits(), - bitwise_chip.clone(), - BUS_IDX, - Rv32Sha256Opcode::CLASS_OFFSET, - tester.offline_memory_mutex_arc(), - ); + let (mut harness, _) = create_test_chips::>(&mut tester); println!( "Sha256VmDigestCols::width(): {}", @@ -133,7 +154,7 @@ fn execute_roundtrip_sanity_test() { ); let num_tests: usize = 1; for _ in 0..num_tests { - set_and_execute(&mut tester, &mut chip, &mut rng, SHA256, None, None); + set_and_execute(&mut tester, &mut harness, &mut rng, SHA256, None, None); } } @@ -147,3 +168,47 @@ fn sha256_solve_sanity_check() { ]; assert_eq!(output, expected); } + +/////////////////////////////////////////////////////////////////////////////////////// +/// DENSE TESTS +/// +/// Ensure that the chip works as expected with dense records. +/// We first execute some instructions with a [DenseRecordArena] and transfer the records +/// to a [MatrixRecordArena]. After transferring we generate the trace and make sure that +/// all the constraints pass. +/////////////////////////////////////////////////////////////////////////////////////// + +#[test] +fn dense_record_arena_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut sparse_harness, bitwise) = create_test_chips(&mut tester); + + { + let mut dense_harness = create_test_chips::(&mut tester).0; + + let num_ops: usize = 10; + for _ in 0..num_ops { + set_and_execute( + &mut tester, + &mut dense_harness, + &mut rng, + SHA256, + None, + None, + ); + } + + let mut record_interpreter = dense_harness + .arena + .get_record_seeker::<_, Sha256VmRecordLayout>(); + record_interpreter.transfer_to_matrix_arena(&mut sparse_harness.arena); + } + + let tester = tester + .build() + .load(sparse_harness) + .load_periphery(bitwise) + .finalize(); + tester.simple_test().expect("Verification failed"); +} diff --git a/extensions/sha256/circuit/src/sha256_chip/trace.rs b/extensions/sha256/circuit/src/sha256_chip/trace.rs index c02cd00dd8..486d295bf5 100644 --- a/extensions/sha256/circuit/src/sha256_chip/trace.rs +++ b/extensions/sha256/circuit/src/sha256_chip/trace.rs @@ -1,351 +1,594 @@ -use std::{array, borrow::BorrowMut, sync::Arc}; +use std::{ + array, + borrow::{Borrow, BorrowMut}, + cmp::min, +}; -use openvm_circuit_primitives::utils::next_power_of_two_or_zero; -use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; -use openvm_rv32im_circuit::adapters::compose; +use openvm_circuit::{ + arch::*, + system::memory::{ + offline_checker::{MemoryReadAuxRecord, MemoryWriteBytesAuxRecord}, + online::TracingMemory, + MemoryAuxColsFactory, + }, +}; +use openvm_circuit_primitives::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; +use openvm_rv32im_circuit::adapters::{read_rv32_register, tracing_read, tracing_write}; use openvm_sha256_air::{ - get_flag_pt_array, limbs_into_u32, Sha256Air, SHA256_BLOCK_WORDS, SHA256_BUFFER_SIZE, SHA256_H, - SHA256_HASH_WORDS, SHA256_ROWS_PER_BLOCK, SHA256_WORD_U8S, + get_flag_pt_array, get_sha256_num_blocks, Sha256FillerHelper, SHA256_BLOCK_BITS, SHA256_H, + SHA256_ROWS_PER_BLOCK, }; +use openvm_sha256_transpiler::Rv32Sha256Opcode; use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - p3_air::BaseAir, - p3_field::{FieldAlgebra, PrimeField32}, - p3_matrix::dense::RowMajorMatrix, + p3_field::PrimeField32, + p3_matrix::{dense::RowMajorMatrix, Matrix}, p3_maybe_rayon::prelude::*, - prover::types::AirProofInput, - rap::get_air_name, - AirRef, Chip, ChipUsageGetter, }; use super::{ - Sha256VmChip, Sha256VmDigestCols, Sha256VmRoundCols, SHA256VM_CONTROL_WIDTH, - SHA256VM_DIGEST_WIDTH, SHA256VM_ROUND_WIDTH, + Sha256VmDigestCols, Sha256VmExecutor, Sha256VmRoundCols, SHA256VM_CONTROL_WIDTH, + SHA256VM_DIGEST_WIDTH, }; use crate::{ - sha256_chip::{PaddingFlags, SHA256_READ_SIZE}, - SHA256_BLOCK_CELLS, + sha256_chip::{PaddingFlags, SHA256_READ_SIZE, SHA256_REGISTER_READS, SHA256_WRITE_SIZE}, + sha256_solve, Sha256VmControlCols, Sha256VmFiller, SHA256VM_ROUND_WIDTH, SHA256VM_WIDTH, + SHA256_BLOCK_CELLS, SHA256_MAX_MESSAGE_LEN, SHA256_NUM_READ_ROWS, }; -impl Chip for Sha256VmChip> +#[derive(Clone, Copy)] +pub struct Sha256VmMetadata { + pub num_blocks: u32, +} + +impl MultiRowMetadata for Sha256VmMetadata { + #[inline(always)] + fn get_num_rows(&self) -> usize { + self.num_blocks as usize * SHA256_ROWS_PER_BLOCK + } +} + +pub(crate) type Sha256VmRecordLayout = MultiRowLayout; + +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug, Clone)] +pub struct Sha256VmRecordHeader { + pub from_pc: u32, + pub timestamp: u32, + pub rd_ptr: u32, + pub rs1_ptr: u32, + pub rs2_ptr: u32, + pub dst_ptr: u32, + pub src_ptr: u32, + pub len: u32, + + pub register_reads_aux: [MemoryReadAuxRecord; SHA256_REGISTER_READS], + pub write_aux: MemoryWriteBytesAuxRecord, +} + +pub struct Sha256VmRecordMut<'a> { + pub inner: &'a mut Sha256VmRecordHeader, + // Having a continuous slice of the input is useful for fast hashing in `execute` + pub input: &'a mut [u8], + pub read_aux: &'a mut [MemoryReadAuxRecord], +} + +/// Custom borrowing that splits the buffer into a fixed `Sha256VmRecord` header +/// followed by a slice of `u8`'s of length `SHA256_BLOCK_CELLS * num_blocks` where `num_blocks` is +/// provided at runtime, followed by a slice of `MemoryReadAuxRecord`'s of length +/// `SHA256_NUM_READ_ROWS * num_blocks`. Uses `align_to_mut()` to make sure the slice is properly +/// aligned to `MemoryReadAuxRecord`. Has debug assertions that check the size and alignment of the +/// slices. +impl<'a> CustomBorrow<'a, Sha256VmRecordMut<'a>, Sha256VmRecordLayout> for [u8] { + fn custom_borrow(&'a mut self, layout: Sha256VmRecordLayout) -> Sha256VmRecordMut<'a> { + let (header_buf, rest) = + unsafe { self.split_at_mut_unchecked(size_of::()) }; + + // Using `split_at_mut_unchecked` for perf reasons + // input is a slice of `u8`'s of length `SHA256_BLOCK_CELLS * num_blocks`, so the alignment + // is always satisfied + let (input, rest) = unsafe { + rest.split_at_mut_unchecked((layout.metadata.num_blocks as usize) * SHA256_BLOCK_CELLS) + }; + + // Using `align_to_mut` to make sure the returned slice is properly aligned to + // `MemoryReadAuxRecord` Additionally, Rust's subslice operation (a few lines below) + // will verify that the buffer has enough capacity + let (_, read_aux_buf, _) = unsafe { rest.align_to_mut::() }; + Sha256VmRecordMut { + inner: header_buf.borrow_mut(), + input, + read_aux: &mut read_aux_buf + [..(layout.metadata.num_blocks as usize) * SHA256_NUM_READ_ROWS], + } + } + + unsafe fn extract_layout(&self) -> Sha256VmRecordLayout { + let header: &Sha256VmRecordHeader = self.borrow(); + Sha256VmRecordLayout { + metadata: Sha256VmMetadata { + num_blocks: get_sha256_num_blocks(header.len), + }, + } + } +} + +impl SizedRecord for Sha256VmRecordMut<'_> { + fn size(layout: &Sha256VmRecordLayout) -> usize { + let mut total_len = size_of::(); + total_len += layout.metadata.num_blocks as usize * SHA256_BLOCK_CELLS; + // Align the pointer to the alignment of `MemoryReadAuxRecord` + total_len = total_len.next_multiple_of(align_of::()); + total_len += layout.metadata.num_blocks as usize + * SHA256_NUM_READ_ROWS + * size_of::(); + total_len + } + + fn alignment(_layout: &Sha256VmRecordLayout) -> usize { + align_of::() + } +} + +impl PreflightExecutor for Sha256VmExecutor where - Val: PrimeField32, + F: PrimeField32, + for<'buf> RA: RecordArena<'buf, Sha256VmRecordLayout, Sha256VmRecordMut<'buf>>, { - fn air(&self) -> AirRef { - Arc::new(self.air.clone()) + fn get_opcode_name(&self, _: usize) -> String { + format!("{:?}", Rv32Sha256Opcode::SHA256) } - fn generate_air_proof_input(self) -> AirProofInput { - let non_padded_height = self.current_trace_height(); - let height = next_power_of_two_or_zero(non_padded_height); - let width = self.trace_width(); - let mut values = Val::::zero_vec(height * width); - if height == 0 { - return AirProofInput::simple_no_pis(RowMajorMatrix::new(values, width)); + fn execute( + &mut self, + state: VmStateMut, + instruction: &Instruction, + ) -> Result<(), ExecutionError> { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = instruction; + debug_assert_eq!(*opcode, Rv32Sha256Opcode::SHA256.global_opcode()); + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); + + // Reading the length first to allocate a record of correct size + let len = read_rv32_register(state.memory.data(), c.as_canonical_u32()); + + let num_blocks = get_sha256_num_blocks(len); + let record = state.ctx.alloc(MultiRowLayout { + metadata: Sha256VmMetadata { num_blocks }, + }); + + record.inner.from_pc = *state.pc; + record.inner.timestamp = state.memory.timestamp(); + record.inner.rd_ptr = a.as_canonical_u32(); + record.inner.rs1_ptr = b.as_canonical_u32(); + record.inner.rs2_ptr = c.as_canonical_u32(); + + record.inner.dst_ptr = u32::from_le_bytes(tracing_read( + state.memory, + RV32_REGISTER_AS, + record.inner.rd_ptr, + &mut record.inner.register_reads_aux[0].prev_timestamp, + )); + record.inner.src_ptr = u32::from_le_bytes(tracing_read( + state.memory, + RV32_REGISTER_AS, + record.inner.rs1_ptr, + &mut record.inner.register_reads_aux[1].prev_timestamp, + )); + record.inner.len = u32::from_le_bytes(tracing_read( + state.memory, + RV32_REGISTER_AS, + record.inner.rs2_ptr, + &mut record.inner.register_reads_aux[2].prev_timestamp, + )); + + // we will read [num_blocks] * [SHA256_BLOCK_CELLS] cells but only [len] cells will be used + debug_assert!( + record.inner.src_ptr as usize + num_blocks as usize * SHA256_BLOCK_CELLS + <= (1 << self.pointer_max_bits) + ); + debug_assert!( + record.inner.dst_ptr as usize + SHA256_WRITE_SIZE <= (1 << self.pointer_max_bits) + ); + // We don't support messages longer than 2^29 bytes + debug_assert!(record.inner.len < SHA256_MAX_MESSAGE_LEN as u32); + + for block_idx in 0..num_blocks as usize { + // Reads happen on the first 4 rows of each block + for row in 0..SHA256_NUM_READ_ROWS { + let read_idx = block_idx * SHA256_NUM_READ_ROWS + row; + let row_input: [u8; SHA256_READ_SIZE] = tracing_read( + state.memory, + RV32_MEMORY_AS, + record.inner.src_ptr + (read_idx * SHA256_READ_SIZE) as u32, + &mut record.read_aux[read_idx].prev_timestamp, + ); + record.input[read_idx * SHA256_READ_SIZE..(read_idx + 1) * SHA256_READ_SIZE] + .copy_from_slice(&row_input); + } + } + + let output = sha256_solve(&record.input[..len as usize]); + tracing_write( + state.memory, + RV32_MEMORY_AS, + record.inner.dst_ptr, + output, + &mut record.inner.write_aux.prev_timestamp, + &mut record.inner.write_aux.prev_data, + ); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } +} + +impl TraceFiller for Sha256VmFiller { + fn fill_trace( + &self, + mem_helper: &MemoryAuxColsFactory, + trace_matrix: &mut RowMajorMatrix, + rows_used: usize, + ) { + if rows_used == 0 { + return; } - let records = self.records; - let offline_memory = self.offline_memory.lock().unwrap(); - let memory_aux_cols_factory = offline_memory.aux_cols_factory(); - - let mem_ptr_shift: u32 = - 1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.air.ptr_max_bits); - - let mut states = Vec::with_capacity(height.div_ceil(SHA256_ROWS_PER_BLOCK)); - let mut global_block_idx = 0; - for (record_idx, record) in records.iter().enumerate() { - let dst_read = offline_memory.record_by_id(record.dst_read); - let src_read = offline_memory.record_by_id(record.src_read); - let len_read = offline_memory.record_by_id(record.len_read); - - self.bitwise_lookup_chip.request_range( - dst_read - .data_at(RV32_REGISTER_NUM_LIMBS - 1) - .as_canonical_u32() - * mem_ptr_shift, - src_read - .data_at(RV32_REGISTER_NUM_LIMBS - 1) - .as_canonical_u32() - * mem_ptr_shift, - ); - let len = compose(len_read.data_slice().try_into().unwrap()); - let mut state = &None; - for (i, input_message) in record.input_message.iter().enumerate() { - let input_message = input_message - .iter() - .flatten() - .copied() - .collect::>() - .try_into() - .unwrap(); - states.push(Some(Self::generate_state( - state, - input_message, - record_idx, - len, - i == record.input_records.len() - 1, - ))); - state = &states[global_block_idx]; - global_block_idx += 1; + + let mut chunks = Vec::with_capacity(trace_matrix.height() / SHA256_ROWS_PER_BLOCK); + let mut sizes = Vec::with_capacity(trace_matrix.height() / SHA256_ROWS_PER_BLOCK); + let mut trace = &mut trace_matrix.values[..]; + let mut num_blocks_so_far = 0; + + // First pass over the trace to get the number of blocks for each instruction + // and divide the matrix into chunks of needed sizes + loop { + if num_blocks_so_far * SHA256_ROWS_PER_BLOCK >= rows_used { + // Push all the padding rows as a single chunk and break + chunks.push(trace); + sizes.push((0, num_blocks_so_far)); + break; + } else { + let record: &Sha256VmRecordHeader = + unsafe { get_record_from_slice(&mut trace, ()) }; + let num_blocks = ((record.len << 3) as usize + 1 + 64).div_ceil(SHA256_BLOCK_BITS); + let (chunk, rest) = + trace.split_at_mut(SHA256VM_WIDTH * SHA256_ROWS_PER_BLOCK * num_blocks); + chunks.push(chunk); + sizes.push((num_blocks, num_blocks_so_far)); + num_blocks_so_far += num_blocks; + trace = rest; } } - states.extend(std::iter::repeat_n( - None, - (height - non_padded_height).div_ceil(SHA256_ROWS_PER_BLOCK), - )); // During the first pass we will fill out most of the matrix // But there are some cells that can't be generated by the first pass so we will do a second - // pass over the matrix - values - .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK) - .zip(states.into_par_iter().enumerate()) - .for_each(|(block, (global_block_idx, state))| { - // Fill in a valid block - if let Some(state) = state { - let mut has_padding_occurred = - state.local_block_idx * SHA256_BLOCK_CELLS > state.message_len as usize; - let message_left = if has_padding_occurred { - 0 - } else { - state.message_len as usize - state.local_block_idx * SHA256_BLOCK_CELLS - }; - let is_last_block = state.is_last_block; - let buffer: [[Val; SHA256_BUFFER_SIZE]; 4] = array::from_fn(|j| { - array::from_fn(|k| { - Val::::from_canonical_u8( - state.block_input_message[j * SHA256_BUFFER_SIZE + k], - ) - }) + // pass over the matrix later + chunks.par_iter_mut().zip(sizes.par_iter()).for_each( + |(slice, (num_blocks, global_block_offset))| { + if global_block_offset * SHA256_ROWS_PER_BLOCK >= rows_used { + // Fill in the invalid rows + slice.par_chunks_mut(SHA256VM_WIDTH).for_each(|row| { + // Need to get rid of the accidental garbage data that might overflow the + // F's prime field. Unfortunately, there is no good way around this + unsafe { + std::ptr::write_bytes( + row.as_mut_ptr() as *mut u8, + 0, + SHA256VM_WIDTH * size_of::(), + ); + } + let cols: &mut Sha256VmRoundCols = + row[..SHA256VM_ROUND_WIDTH].borrow_mut(); + self.inner.generate_default_row(&mut cols.inner); }); + return; + } - let padded_message: [u32; SHA256_BLOCK_WORDS] = array::from_fn(|j| { - limbs_into_u32::(array::from_fn(|k| { - state.block_padded_message[(j + 1) * SHA256_WORD_U8S - k - 1] as u32 - })) - }); + let record: Sha256VmRecordMut = unsafe { + get_record_from_slice( + slice, + Sha256VmRecordLayout { + metadata: Sha256VmMetadata { + num_blocks: *num_blocks as u32, + }, + }, + ) + }; - self.air.sha256_subair.generate_block_trace::>( - block, - width, - SHA256VM_CONTROL_WIDTH, - &padded_message, - self.bitwise_lookup_chip.clone(), - &state.hash, - is_last_block, - global_block_idx as u32 + 1, - state.local_block_idx as u32, - &buffer, - ); - - let block_reads = records[state.message_idx].input_records - [state.local_block_idx] - .map(|record_id| offline_memory.record_by_id(record_id)); - - let mut read_ptr = block_reads[0].pointer; - let mut cur_timestamp = Val::::from_canonical_u32(block_reads[0].timestamp); - - let read_size = Val::::from_canonical_usize(SHA256_READ_SIZE); - for row in 0..SHA256_ROWS_PER_BLOCK { - let row_slice = &mut block[row * width..(row + 1) * width]; - if row < 16 { - let cols: &mut Sha256VmRoundCols> = - row_slice[..SHA256VM_ROUND_WIDTH].borrow_mut(); - cols.control.len = Val::::from_canonical_u32(state.message_len); - cols.control.read_ptr = read_ptr; - cols.control.cur_timestamp = cur_timestamp; - if row < 4 { - read_ptr += read_size; - cur_timestamp += Val::::ONE; - memory_aux_cols_factory - .generate_read_aux(block_reads[row], &mut cols.read_aux); - - if (row + 1) * SHA256_READ_SIZE <= message_left { - cols.control.pad_flags = get_flag_pt_array( - &self.air.padding_encoder, - PaddingFlags::NotPadding as usize, - ) - .map(Val::::from_canonical_u32); - } else if !has_padding_occurred { - has_padding_occurred = true; - let len = message_left - row * SHA256_READ_SIZE; - cols.control.pad_flags = get_flag_pt_array( - &self.air.padding_encoder, - if row == 3 && is_last_block { - PaddingFlags::FirstPadding0_LastRow - } else { - PaddingFlags::FirstPadding0 - } as usize - + len, - ) - .map(Val::::from_canonical_u32); - } else { - cols.control.pad_flags = get_flag_pt_array( - &self.air.padding_encoder, - if row == 3 && is_last_block { - PaddingFlags::EntirePaddingLastRow - } else { - PaddingFlags::EntirePadding - } as usize, - ) - .map(Val::::from_canonical_u32); - } - } else { - cols.control.pad_flags = get_flag_pt_array( - &self.air.padding_encoder, - PaddingFlags::NotConsidered as usize, - ) - .map(Val::::from_canonical_u32); - } - cols.control.padding_occurred = - Val::::from_bool(has_padding_occurred); - } else { - if is_last_block { - has_padding_occurred = false; - } - let cols: &mut Sha256VmDigestCols> = - row_slice[..SHA256VM_DIGEST_WIDTH].borrow_mut(); - cols.control.len = Val::::from_canonical_u32(state.message_len); - cols.control.read_ptr = read_ptr; - cols.control.cur_timestamp = cur_timestamp; - cols.control.pad_flags = get_flag_pt_array( - &self.air.padding_encoder, - PaddingFlags::NotConsidered as usize, - ) - .map(Val::::from_canonical_u32); - if is_last_block { - let record = &records[state.message_idx]; - let dst_read = offline_memory.record_by_id(record.dst_read); - let src_read = offline_memory.record_by_id(record.src_read); - let len_read = offline_memory.record_by_id(record.len_read); - let digest_write = offline_memory.record_by_id(record.digest_write); - cols.from_state = record.from_state; - cols.rd_ptr = dst_read.pointer; - cols.rs1_ptr = src_read.pointer; - cols.rs2_ptr = len_read.pointer; - cols.dst_ptr.copy_from_slice(dst_read.data_slice()); - cols.src_ptr.copy_from_slice(src_read.data_slice()); - cols.len_data.copy_from_slice(len_read.data_slice()); - memory_aux_cols_factory - .generate_read_aux(dst_read, &mut cols.register_reads_aux[0]); - memory_aux_cols_factory - .generate_read_aux(src_read, &mut cols.register_reads_aux[1]); - memory_aux_cols_factory - .generate_read_aux(len_read, &mut cols.register_reads_aux[2]); - memory_aux_cols_factory - .generate_write_aux(digest_write, &mut cols.writes_aux); - } - cols.control.padding_occurred = - Val::::from_bool(has_padding_occurred); - } - } - } - // Fill in the invalid rows - else { - block.par_chunks_mut(width).for_each(|row| { - let cols: &mut Sha256VmRoundCols> = row.borrow_mut(); - self.air.sha256_subair.generate_default_row(&mut cols.inner); - }) + let mut input: Vec = Vec::with_capacity(SHA256_BLOCK_CELLS * num_blocks); + input.extend_from_slice(record.input); + let mut padded_input = input.clone(); + let len = record.inner.len as usize; + let padded_input_len = padded_input.len(); + padded_input[len] = 1 << (RV32_CELL_BITS - 1); + padded_input[len + 1..padded_input_len - 4].fill(0); + padded_input[padded_input_len - 4..] + .copy_from_slice(&((len as u32) << 3).to_be_bytes()); + + let mut prev_hashes = Vec::with_capacity(*num_blocks); + prev_hashes.push(SHA256_H); + for i in 0..*num_blocks - 1 { + prev_hashes.push(Sha256FillerHelper::get_block_hash( + &prev_hashes[i], + padded_input[i * SHA256_BLOCK_CELLS..(i + 1) * SHA256_BLOCK_CELLS] + .try_into() + .unwrap(), + )); } - }); + // Copy the read aux records and input to another place to safely fill in the trace + // matrix without overwriting the record + let mut read_aux_records = Vec::with_capacity(SHA256_NUM_READ_ROWS * num_blocks); + read_aux_records.extend_from_slice(record.read_aux); + let vm_record = record.inner.clone(); + + slice + .par_chunks_exact_mut(SHA256VM_WIDTH * SHA256_ROWS_PER_BLOCK) + .enumerate() + .for_each(|(block_idx, block_slice)| { + // Need to get rid of the accidental garbage data that might overflow the + // F's prime field. Unfortunately, there is no good way around this + unsafe { + std::ptr::write_bytes( + block_slice.as_mut_ptr() as *mut u8, + 0, + SHA256_ROWS_PER_BLOCK * SHA256VM_WIDTH * size_of::(), + ); + } + self.fill_block_trace::( + block_slice, + &vm_record, + &read_aux_records[block_idx * SHA256_NUM_READ_ROWS + ..(block_idx + 1) * SHA256_NUM_READ_ROWS], + &input[block_idx * SHA256_BLOCK_CELLS + ..(block_idx + 1) * SHA256_BLOCK_CELLS], + &padded_input[block_idx * SHA256_BLOCK_CELLS + ..(block_idx + 1) * SHA256_BLOCK_CELLS], + block_idx == *num_blocks - 1, + *global_block_offset + block_idx, + block_idx, + prev_hashes[block_idx], + mem_helper, + ); + }); + }, + ); // Do a second pass over the trace to fill in the missing values // Note, we need to skip the very first row - values[width..] - .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK) - .take(non_padded_height / SHA256_ROWS_PER_BLOCK) + trace_matrix.values[SHA256VM_WIDTH..] + .par_chunks_mut(SHA256VM_WIDTH * SHA256_ROWS_PER_BLOCK) + .take(rows_used / SHA256_ROWS_PER_BLOCK) .for_each(|chunk| { - self.air - .sha256_subair - .generate_missing_cells(chunk, width, SHA256VM_CONTROL_WIDTH); + self.inner + .generate_missing_cells(chunk, SHA256VM_WIDTH, SHA256VM_CONTROL_WIDTH); }); - - AirProofInput::simple_no_pis(RowMajorMatrix::new(values, width)) } } -impl ChipUsageGetter for Sha256VmChip { - fn air_name(&self) -> String { - get_air_name(&self.air) - } - fn current_trace_height(&self) -> usize { - self.records.iter().fold(0, |acc, record| { - acc + record.input_records.len() * SHA256_ROWS_PER_BLOCK - }) - } +impl Sha256VmFiller { + #[allow(clippy::too_many_arguments)] + fn fill_block_trace( + &self, + block_slice: &mut [F], + record: &Sha256VmRecordHeader, + read_aux_records: &[MemoryReadAuxRecord], + input: &[u8], + padded_input: &[u8], + is_last_block: bool, + global_block_idx: usize, + local_block_idx: usize, + prev_hash: [u32; 8], + mem_helper: &MemoryAuxColsFactory, + ) { + debug_assert_eq!(input.len(), SHA256_BLOCK_CELLS); + debug_assert_eq!(padded_input.len(), SHA256_BLOCK_CELLS); + debug_assert_eq!(read_aux_records.len(), SHA256_NUM_READ_ROWS); - fn trace_width(&self) -> usize { - BaseAir::::width(&self.air) - } -} + let padded_input = array::from_fn(|i| { + u32::from_be_bytes(padded_input[i * 4..(i + 1) * 4].try_into().unwrap()) + }); -/// This is the state information that a block will use to generate its trace -#[derive(Debug, Clone)] -struct Sha256State { - hash: [u32; SHA256_HASH_WORDS], - local_block_idx: usize, - message_len: u32, - block_input_message: [u8; SHA256_BLOCK_CELLS], - block_padded_message: [u8; SHA256_BLOCK_CELLS], - message_idx: usize, - is_last_block: bool, -} + let block_start_timestamp = record.timestamp + + (SHA256_REGISTER_READS + SHA256_NUM_READ_ROWS * local_block_idx) as u32; -impl Sha256VmChip { - fn generate_state( - prev_state: &Option, - block_input_message: [u8; SHA256_BLOCK_CELLS], - message_idx: usize, - message_len: u32, - is_last_block: bool, - ) -> Sha256State { - let local_block_idx = if let Some(prev_state) = prev_state { - prev_state.local_block_idx + 1 - } else { + let read_cells = (SHA256_BLOCK_CELLS * local_block_idx) as u32; + let block_start_read_ptr = record.src_ptr + read_cells; + + let message_left = if record.len <= read_cells { 0 + } else { + (record.len - read_cells) as usize }; - let has_padding_occurred = local_block_idx * SHA256_BLOCK_CELLS > message_len as usize; - let message_left = if has_padding_occurred { - 0 + + // -1 means that padding occurred before the start of the block + // 18 means that no padding occurred on this block + let first_padding_row = if record.len < read_cells { + -1 + } else if message_left < SHA256_BLOCK_CELLS { + (message_left / SHA256_READ_SIZE) as i32 } else { - message_len as usize - local_block_idx * SHA256_BLOCK_CELLS + 18 }; - let padded_message_bytes: [u8; SHA256_BLOCK_CELLS] = array::from_fn(|j| { - if j < message_left { - block_input_message[j] - } else if j == message_left && !has_padding_occurred { - 1 << (RV32_CELL_BITS - 1) - } else if !is_last_block || j < SHA256_BLOCK_CELLS - 4 { - 0u8 - } else { - let shift_amount = (SHA256_BLOCK_CELLS - j - 1) * RV32_CELL_BITS; - ((message_len * RV32_CELL_BITS as u32) - .checked_shr(shift_amount as u32) - .unwrap_or(0) - & ((1 << RV32_CELL_BITS) - 1)) as u8 - } - }); + // Fill in the VM columns first because the inner `carry_or_buffer` needs to be filled in + block_slice + .par_chunks_exact_mut(SHA256VM_WIDTH) + .enumerate() + .for_each(|(row_idx, row_slice)| { + // Handle round rows and digest row separately + if row_idx == SHA256_ROWS_PER_BLOCK - 1 { + // This is a digest row + let digest_cols: &mut Sha256VmDigestCols = + row_slice[..SHA256VM_DIGEST_WIDTH].borrow_mut(); + digest_cols.from_state.timestamp = F::from_canonical_u32(record.timestamp); + digest_cols.from_state.pc = F::from_canonical_u32(record.from_pc); + digest_cols.rd_ptr = F::from_canonical_u32(record.rd_ptr); + digest_cols.rs1_ptr = F::from_canonical_u32(record.rs1_ptr); + digest_cols.rs2_ptr = F::from_canonical_u32(record.rs2_ptr); + digest_cols.dst_ptr = record.dst_ptr.to_le_bytes().map(F::from_canonical_u8); + digest_cols.src_ptr = record.src_ptr.to_le_bytes().map(F::from_canonical_u8); + digest_cols.len_data = record.len.to_le_bytes().map(F::from_canonical_u8); + if is_last_block { + digest_cols + .register_reads_aux + .iter_mut() + .zip(record.register_reads_aux.iter()) + .enumerate() + .for_each(|(idx, (cols_read, record_read))| { + mem_helper.fill( + record_read.prev_timestamp, + record.timestamp + idx as u32, + cols_read.as_mut(), + ); + }); + digest_cols + .writes_aux + .set_prev_data(record.write_aux.prev_data.map(F::from_canonical_u8)); + // In the last block we do `SHA256_NUM_READ_ROWS` reads and then write the + // result thus the timestamp of the write is + // `block_start_timestamp + SHA256_NUM_READ_ROWS` + mem_helper.fill( + record.write_aux.prev_timestamp, + block_start_timestamp + SHA256_NUM_READ_ROWS as u32, + digest_cols.writes_aux.as_mut(), + ); + // Need to range check the destination and source pointers + let msl_rshift: u32 = + ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS) as u32; + let msl_lshift: u32 = (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS + - self.pointer_max_bits) + as u32; + self.bitwise_lookup_chip.request_range( + (record.dst_ptr >> msl_rshift) << msl_lshift, + (record.src_ptr >> msl_rshift) << msl_lshift, + ); + } else { + // Filling in zeros to make sure the accidental garbage data doesn't + // overflow the prime + digest_cols.register_reads_aux.iter_mut().for_each(|aux| { + mem_helper.fill_zero(aux.as_mut()); + }); + digest_cols + .writes_aux + .set_prev_data([F::ZERO; SHA256_WRITE_SIZE]); + mem_helper.fill_zero(digest_cols.writes_aux.as_mut()); + } + digest_cols.inner.flags.is_last_block = F::from_bool(is_last_block); + digest_cols.inner.flags.is_digest_row = F::from_bool(true); + } else { + // This is a round row + let round_cols: &mut Sha256VmRoundCols = + row_slice[..SHA256VM_ROUND_WIDTH].borrow_mut(); + // Take care of the first 4 round rows (aka read rows) + if row_idx < SHA256_NUM_READ_ROWS { + round_cols + .inner + .message_schedule + .carry_or_buffer + .as_flattened_mut() + .iter_mut() + .zip( + input[row_idx * SHA256_READ_SIZE..(row_idx + 1) * SHA256_READ_SIZE] + .iter(), + ) + .for_each(|(cell, data)| { + *cell = F::from_canonical_u8(*data); + }); + mem_helper.fill( + read_aux_records[row_idx].prev_timestamp, + block_start_timestamp + row_idx as u32, + round_cols.read_aux.as_mut(), + ); + } else { + mem_helper.fill_zero(round_cols.read_aux.as_mut()); + } + } + // Fill in the control cols, doesn't matter if it is a round or digest row + let control_cols: &mut Sha256VmControlCols = + row_slice[..SHA256VM_CONTROL_WIDTH].borrow_mut(); + control_cols.len = F::from_canonical_u32(record.len); + // Only the first `SHA256_NUM_READ_ROWS` rows increment the timestamp and read ptr + control_cols.cur_timestamp = F::from_canonical_u32( + block_start_timestamp + min(row_idx, SHA256_NUM_READ_ROWS) as u32, + ); + control_cols.read_ptr = F::from_canonical_u32( + block_start_read_ptr + + (SHA256_READ_SIZE * min(row_idx, SHA256_NUM_READ_ROWS)) as u32, + ); - if let Some(prev_state) = prev_state { - Sha256State { - hash: Sha256Air::get_block_hash(&prev_state.hash, prev_state.block_padded_message), - local_block_idx, - message_len, - block_input_message, - block_padded_message: padded_message_bytes, - message_idx, - is_last_block, - } - } else { - Sha256State { - hash: SHA256_H, - local_block_idx: 0, - message_len, - block_input_message, - block_padded_message: padded_message_bytes, - message_idx, - is_last_block, - } - } + // Fill in the padding flags + if row_idx < SHA256_NUM_READ_ROWS { + #[allow(clippy::comparison_chain)] + if (row_idx as i32) < first_padding_row { + control_cols.pad_flags = get_flag_pt_array( + &self.padding_encoder, + PaddingFlags::NotPadding as usize, + ) + .map(F::from_canonical_u32); + } else if row_idx as i32 == first_padding_row { + let len = message_left - row_idx * SHA256_READ_SIZE; + control_cols.pad_flags = get_flag_pt_array( + &self.padding_encoder, + if row_idx == 3 && is_last_block { + PaddingFlags::FirstPadding0_LastRow + } else { + PaddingFlags::FirstPadding0 + } as usize + + len, + ) + .map(F::from_canonical_u32); + } else { + control_cols.pad_flags = get_flag_pt_array( + &self.padding_encoder, + if row_idx == 3 && is_last_block { + PaddingFlags::EntirePaddingLastRow + } else { + PaddingFlags::EntirePadding + } as usize, + ) + .map(F::from_canonical_u32); + } + } else { + control_cols.pad_flags = get_flag_pt_array( + &self.padding_encoder, + PaddingFlags::NotConsidered as usize, + ) + .map(F::from_canonical_u32); + } + if is_last_block && row_idx == SHA256_ROWS_PER_BLOCK - 1 { + // If last digest row, then we set padding_occurred = 0 + control_cols.padding_occurred = F::ZERO; + } else { + control_cols.padding_occurred = + F::from_bool((row_idx as i32) >= first_padding_row); + } + }); + + // Fill in the inner trace when the `buffer_or_carry` is filled in + self.inner.generate_block_trace::( + block_slice, + SHA256VM_WIDTH, + SHA256VM_CONTROL_WIDTH, + &padded_input, + self.bitwise_lookup_chip.as_ref(), + &prev_hash, + is_last_block, + global_block_idx as u32 + 1, // global block index is 1-indexed + local_block_idx as u32, + ); } } diff --git a/extensions/sha256/guest/src/lib.rs b/extensions/sha256/guest/src/lib.rs index 1c51a272fd..8f7c072f4a 100644 --- a/extensions/sha256/guest/src/lib.rs +++ b/extensions/sha256/guest/src/lib.rs @@ -1,11 +1,15 @@ #![no_std] +#[cfg(target_os = "zkvm")] +use openvm_platform::alloc::AlignedBuf; + /// This is custom-0 defined in RISC-V spec document pub const OPCODE: u8 = 0x0b; pub const SHA256_FUNCT3: u8 = 0b100; pub const SHA256_FUNCT7: u8 = 0x1; -/// zkvm native implementation of sha256 +/// Native hook for sha256 +/// /// # Safety /// /// The VM accepts the preimage by pointer and length, and writes the @@ -13,10 +17,53 @@ pub const SHA256_FUNCT7: u8 = 0x1; /// - `bytes` must point to an input buffer at least `len` long. /// - `output` must point to a buffer that is at least 32-bytes long. /// -/// [`sha2-256`]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf +/// [`sha2`]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf #[cfg(target_os = "zkvm")] #[inline(always)] #[no_mangle] pub extern "C" fn zkvm_sha256_impl(bytes: *const u8, len: usize, output: *mut u8) { + // SAFETY: assuming safety assumptions of the inputs, we handle all cases where `bytes` or + // `output` are not aligned to 4 bytes. + // The minimum alignment required for the input and output buffers + const MIN_ALIGN: usize = 4; + // The preferred alignment for the input buffer, since the input is read in chunks of 16 bytes + const INPUT_ALIGN: usize = 16; + // The preferred alignment for the output buffer, since the output is written in chunks of 32 + // bytes + const OUTPUT_ALIGN: usize = 32; + unsafe { + if bytes as usize % MIN_ALIGN != 0 { + let aligned_buff = AlignedBuf::new(bytes, len, INPUT_ALIGN); + if output as usize % MIN_ALIGN != 0 { + let aligned_out = AlignedBuf::uninit(32, OUTPUT_ALIGN); + __native_sha256(aligned_buff.ptr, len, aligned_out.ptr); + core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32); + } else { + __native_sha256(aligned_buff.ptr, len, output); + } + } else { + if output as usize % MIN_ALIGN != 0 { + let aligned_out = AlignedBuf::uninit(32, OUTPUT_ALIGN); + __native_sha256(bytes, len, aligned_out.ptr); + core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32); + } else { + __native_sha256(bytes, len, output); + } + }; + } +} + +/// sha256 intrinsic binding +/// +/// # Safety +/// +/// The VM accepts the preimage by pointer and length, and writes the +/// 32-byte hash. +/// - `bytes` must point to an input buffer at least `len` long. +/// - `output` must point to a buffer that is at least 32-bytes long. +/// - `bytes` and `output` must be 4-byte aligned. +#[cfg(target_os = "zkvm")] +#[inline(always)] +fn __native_sha256(bytes: *const u8, len: usize, output: *mut u8) { openvm_platform::custom_insn_r!(opcode = OPCODE, funct3 = SHA256_FUNCT3, funct7 = SHA256_FUNCT7, rd = In output, rs1 = In bytes, rs2 = In len); } diff --git a/guest-libs/ff_derive/Cargo.toml b/guest-libs/ff_derive/Cargo.toml index 54d4628897..a4d9c24579 100644 --- a/guest-libs/ff_derive/Cargo.toml +++ b/guest-libs/ff_derive/Cargo.toml @@ -27,7 +27,7 @@ syn = { version = "1", features = ["full"] } [dev-dependencies] openvm-instructions = { workspace = true } -openvm-stark-sdk = { workspace = true } +openvm-stark-sdk = { workspace = true } openvm-circuit = { workspace = true, features = ["test-utils", "parallel"]} openvm-transpiler = { workspace = true } openvm-algebra-transpiler = { workspace = true } @@ -37,4 +37,3 @@ openvm-toolchain-tests = { workspace = true } eyre = { workspace = true } num-bigint = { workspace = true } - diff --git a/guest-libs/ff_derive/src/lib.rs b/guest-libs/ff_derive/src/lib.rs index 8a64062c33..10a8b64cd2 100644 --- a/guest-libs/ff_derive/src/lib.rs +++ b/guest-libs/ff_derive/src/lib.rs @@ -1,4 +1,5 @@ #![recursion_limit = "1024"] +#![allow(clippy::manual_repeat_n)] extern crate proc_macro; extern crate proc_macro2; diff --git a/guest-libs/ff_derive/tests/lib.rs b/guest-libs/ff_derive/tests/lib.rs index 6df9a1d675..8e85b316e8 100644 --- a/guest-libs/ff_derive/tests/lib.rs +++ b/guest-libs/ff_derive/tests/lib.rs @@ -4,9 +4,9 @@ mod tests { use eyre::Result; use num_bigint::BigUint; - use openvm_algebra_circuit::Rv32ModularConfig; + use openvm_algebra_circuit::{Rv32ModularConfig, Rv32ModularCpuBuilder}; use openvm_algebra_transpiler::ModularTranspilerExtension; - use openvm_circuit::utils::air_test; + use openvm_circuit::utils::{air_test, test_system_config_with_continuations}; use openvm_instructions::exe::VmExe; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, @@ -20,11 +20,18 @@ mod tests { type F = BabyBear; + #[cfg(test)] + fn test_rv32modular_config(moduli: Vec) -> Rv32ModularConfig { + let mut config = Rv32ModularConfig::new(moduli); + config.system = test_system_config_with_continuations(); + config + } + #[test] fn test_full_limbs() -> Result<()> { let moduli = ["39402006196394479212279040100143613805079739270465446667948293404245721771496870329047266088258938001861606973112319"] .map(|s| BigUint::from_str(s).unwrap()); - let config = Rv32ModularConfig::new(moduli.to_vec()); + let config = test_rv32modular_config(moduli.to_vec()); let elf = build_example_program_at_path( get_programs_dir!("tests/programs"), "full_limbs", @@ -39,14 +46,14 @@ mod tests { .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32ModularCpuBuilder, config, openvm_exe); Ok(()) } #[test] fn test_fermat() -> Result<()> { let moduli = ["65537"].map(|s| BigUint::from_str(s).unwrap()); - let config = Rv32ModularConfig::new(moduli.to_vec()); + let config = test_rv32modular_config(moduli.to_vec()); let elf = build_example_program_at_path(get_programs_dir!("tests/programs"), "fermat", &config)?; let openvm_exe = VmExe::from_elf( @@ -58,14 +65,14 @@ mod tests { .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32ModularCpuBuilder, config, openvm_exe); Ok(()) } #[test] fn test_sqrt() -> Result<()> { let moduli = ["357686312646216567629137"].map(|s| BigUint::from_str(s).unwrap()); - let config = Rv32ModularConfig::new(moduli.to_vec()); + let config = test_rv32modular_config(moduli.to_vec()); let elf = build_example_program_at_path(get_programs_dir!("tests/programs"), "sqrt", &config)?; let openvm_exe = VmExe::from_elf( @@ -77,7 +84,7 @@ mod tests { .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32ModularCpuBuilder, config, openvm_exe); Ok(()) } @@ -86,7 +93,7 @@ mod tests { let moduli = ["52435875175126190479447740508185965837690552500527637822603658699938581184513"] .map(|s| BigUint::from_str(s).unwrap()); - let config = Rv32ModularConfig::new(moduli.to_vec()); + let config = test_rv32modular_config(moduli.to_vec()); let elf = build_example_program_at_path( get_programs_dir!("tests/programs"), "constants", @@ -101,7 +108,7 @@ mod tests { .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32ModularCpuBuilder, config, openvm_exe); Ok(()) } @@ -110,7 +117,7 @@ mod tests { let moduli = ["52435875175126190479447740508185965837690552500527637822603658699938581184513"] .map(|s| BigUint::from_str(s).unwrap()); - let config = Rv32ModularConfig::new(moduli.to_vec()); + let config = test_rv32modular_config(moduli.to_vec()); let elf = build_example_program_at_path( get_programs_dir!("tests/programs"), "from_u128", @@ -125,7 +132,7 @@ mod tests { .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32ModularCpuBuilder, config, openvm_exe); Ok(()) } @@ -134,7 +141,7 @@ mod tests { let moduli = ["52435875175126190479447740508185965837690552500527637822603658699938581184513"] .map(|s| BigUint::from_str(s).unwrap()); - let config = Rv32ModularConfig::new(moduli.to_vec()); + let config = test_rv32modular_config(moduli.to_vec()); let elf = build_example_program_at_path_with_features( get_programs_dir!("tests/programs"), "batch_inversion", @@ -150,7 +157,7 @@ mod tests { .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32ModularCpuBuilder, config, openvm_exe); Ok(()) } @@ -159,7 +166,7 @@ mod tests { let moduli = ["52435875175126190479447740508185965837690552500527637822603658699938581184513"] .map(|s| BigUint::from_str(s).unwrap()); - let config = Rv32ModularConfig::new(moduli.to_vec()); + let config = test_rv32modular_config(moduli.to_vec()); let elf = build_example_program_at_path( get_programs_dir!("tests/programs"), "operations", @@ -174,7 +181,7 @@ mod tests { .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32ModularCpuBuilder, config, openvm_exe); Ok(()) } } diff --git a/guest-libs/k256/Cargo.toml b/guest-libs/k256/Cargo.toml index 362df43b6f..d4862bd79c 100644 --- a/guest-libs/k256/Cargo.toml +++ b/guest-libs/k256/Cargo.toml @@ -32,19 +32,18 @@ num-bigint = { workspace = true } [dev-dependencies] openvm-circuit = { workspace = true, features = ["test-utils", "parallel"] } openvm-transpiler.workspace = true -openvm-algebra-circuit.workspace = true openvm-algebra-transpiler.workspace = true openvm-ecc-transpiler.workspace = true openvm-ecc-circuit.workspace = true openvm-sha256-circuit.workspace = true openvm-sha256-transpiler.workspace = true -openvm-rv32im-circuit.workspace = true openvm-rv32im-transpiler.workspace = true openvm-toolchain-tests.workspace = true openvm-stark-backend.workspace = true openvm-stark-sdk.workspace = true +rand = { workspace = true } serde.workspace = true eyre.workspace = true derive_more = { workspace = true, features = ["from"] } @@ -84,4 +83,5 @@ ignored = [ "derive_more", "signature", "once_cell", + "rand", ] diff --git a/guest-libs/k256/src/internal.rs b/guest-libs/k256/src/internal.rs index b8f8857dc9..868bce2cd5 100644 --- a/guest-libs/k256/src/internal.rs +++ b/guest-libs/k256/src/internal.rs @@ -4,8 +4,8 @@ use hex_literal::hex; use openvm_algebra_guest::IntMod; use openvm_algebra_moduli_macros::moduli_declare; use openvm_ecc_guest::{ - weierstrass::{CachedMulTable, IntrinsicCurve, WeierstrassPoint}, - CyclicGroup, Group, + weierstrass::{CachedMulTable, WeierstrassPoint}, + CyclicGroup, Group, IntrinsicCurve, }; use openvm_ecc_sw_macros::sw_declare; diff --git a/guest-libs/k256/src/point.rs b/guest-libs/k256/src/point.rs index b854ef582b..5e66303284 100644 --- a/guest-libs/k256/src/point.rs +++ b/guest-libs/k256/src/point.rs @@ -14,10 +14,7 @@ use elliptic_curve::{ FieldBytesEncoding, }; use openvm_algebra_guest::IntMod; -use openvm_ecc_guest::{ - weierstrass::{IntrinsicCurve, WeierstrassPoint}, - CyclicGroup, -}; +use openvm_ecc_guest::{weierstrass::WeierstrassPoint, CyclicGroup, IntrinsicCurve}; use crate::{ internal::{Secp256k1Coord, Secp256k1Point, Secp256k1Scalar}, @@ -181,7 +178,7 @@ impl MulByGenerator for Secp256k1Point {} impl DecompressPoint for Secp256k1Point { /// Note that this is not constant time fn decompress(x_bytes: &FieldBytes, y_is_odd: Choice) -> CtOption { - use openvm_ecc_guest::weierstrass::FromCompressed; + use openvm_ecc_guest::FromCompressed; let x = Secp256k1Coord::from_be_bytes_unchecked(x_bytes.as_slice()); let rec_id = y_is_odd.unwrap_u8(); diff --git a/guest-libs/k256/tests/lib.rs b/guest-libs/k256/tests/lib.rs index e38675aa09..b42ad6043b 100644 --- a/guest-libs/k256/tests/lib.rs +++ b/guest-libs/k256/tests/lib.rs @@ -2,8 +2,13 @@ mod guest_tests { use ecdsa_config::EcdsaConfig; use eyre::Result; use openvm_algebra_transpiler::ModularTranspilerExtension; - use openvm_circuit::{arch::instructions::exe::VmExe, utils::air_test}; - use openvm_ecc_circuit::{Rv32WeierstrassConfig, SECP256K1_CONFIG}; + use openvm_circuit::{ + arch::instructions::exe::VmExe, + utils::{air_test, test_system_config_with_continuations}, + }; + #[cfg(test)] + use openvm_ecc_circuit::SwCurveCoeffs; + use openvm_ecc_circuit::{CurveConfig, Rv32EccConfig, Rv32EccCpuBuilder, SECP256K1_CONFIG}; use openvm_ecc_transpiler::EccTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, @@ -13,11 +18,20 @@ mod guest_tests { use openvm_toolchain_tests::{build_example_program_at_path, get_programs_dir}; use openvm_transpiler::{transpiler::Transpiler, FromElf}; + use crate::guest_tests::ecdsa_config::EcdsaCpuBuilder; + type F = BabyBear; + #[cfg(test)] + fn test_rv32ecc_config(sw_curves: Vec>) -> Rv32EccConfig { + let mut config = Rv32EccConfig::new(sw_curves, vec![]); + *config.as_mut() = test_system_config_with_continuations(); + config + } + #[test] fn test_add() -> Result<()> { - let config = Rv32WeierstrassConfig::new(vec![SECP256K1_CONFIG.clone()]); + let config = test_rv32ecc_config(vec![SECP256K1_CONFIG.clone()]); let elf = build_example_program_at_path(get_programs_dir!("tests/programs"), "add", &config)?; let openvm_exe = VmExe::from_elf( @@ -29,13 +43,13 @@ mod guest_tests { .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32EccCpuBuilder, config, openvm_exe); Ok(()) } #[test] fn test_mul() -> Result<()> { - let config = Rv32WeierstrassConfig::new(vec![SECP256K1_CONFIG.clone()]); + let config = test_rv32ecc_config(vec![SECP256K1_CONFIG.clone()]); let elf = build_example_program_at_path(get_programs_dir!("tests/programs"), "mul", &config)?; let openvm_exe = VmExe::from_elf( @@ -47,13 +61,13 @@ mod guest_tests { .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32EccCpuBuilder, config, openvm_exe); Ok(()) } #[test] fn test_linear_combination() -> Result<()> { - let config = Rv32WeierstrassConfig::new(vec![SECP256K1_CONFIG.clone()]); + let config = test_rv32ecc_config(vec![SECP256K1_CONFIG.clone()]); let elf = build_example_program_at_path( get_programs_dir!("tests/programs"), "linear_combination", @@ -68,62 +82,44 @@ mod guest_tests { .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32EccCpuBuilder, config, openvm_exe); Ok(()) } + // TODO[jpw]: switch to using SDK to avoid this mod ecdsa_config { - use eyre::Result; - use openvm_algebra_circuit::{ - ModularExtension, ModularExtensionExecutor, ModularExtensionPeriphery, - }; use openvm_circuit::{ - arch::{InitFileGenerator, SystemConfig}, + arch::{ + AirInventory, ChipInventoryError, InitFileGenerator, MatrixRecordArena, + SystemConfig, VmBuilder, VmChipComplex, VmProverExtension, + }, derive::VmConfig, + system::SystemChipInventory, }; use openvm_ecc_circuit::{ - CurveConfig, WeierstrassExtension, WeierstrassExtensionExecutor, - WeierstrassExtensionPeriphery, + CurveConfig, Rv32EccConfig, Rv32EccConfigExecutor, Rv32EccCpuBuilder, SwCurveCoeffs, }; - use openvm_rv32im_circuit::{ - Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, - Rv32MExecutor, Rv32MPeriphery, + use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha2CpuProverExt}; + use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + engine::StarkEngine, + p3_field::PrimeField32, + prover::cpu::{CpuBackend, CpuDevice}, }; - use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha256Periphery}; - use openvm_stark_backend::p3_field::PrimeField32; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, VmConfig, Serialize, Deserialize)] pub struct EcdsaConfig { - #[system] - pub system: SystemConfig, - #[extension] - pub base: Rv32I, - #[extension] - pub mul: Rv32M, - #[extension] - pub io: Rv32Io, - #[extension] - pub modular: ModularExtension, - #[extension] - pub weierstrass: WeierstrassExtension, + #[config(generics = true)] + pub ecc: Rv32EccConfig, #[extension] pub sha256: Sha256, } impl EcdsaConfig { - pub fn new(curves: Vec) -> Self { - let primes: Vec<_> = curves - .iter() - .flat_map(|c| [c.modulus.clone(), c.scalar.clone()]) - .collect(); + pub fn new(curves: Vec>) -> Self { Self { - system: SystemConfig::default().with_continuations(), - base: Default::default(), - mul: Default::default(), - io: Default::default(), - modular: ModularExtension::new(primes), - weierstrass: WeierstrassExtension::new(curves), + ecc: Rv32EccConfig::new(curves, vec![]), sha256: Default::default(), } } @@ -133,11 +129,44 @@ mod guest_tests { fn generate_init_file_contents(&self) -> Option { Some(format!( "// This file is automatically generated by cargo openvm. Do not rename or edit.\n{}\n{}\n", - self.modular.generate_moduli_init(), - self.weierstrass.generate_sw_init() + self.ecc.modular.modular.generate_moduli_init(), + self.ecc.ecc.generate_ecc_init() )) } } + + #[derive(Clone)] + pub struct EcdsaCpuBuilder; + + impl VmBuilder for EcdsaCpuBuilder + where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + Val: PrimeField32, + { + type VmConfig = EcdsaConfig; + type SystemChipInventory = SystemChipInventory; + type RecordArena = MatrixRecordArena>; + + fn create_chip_complex( + &self, + config: &EcdsaConfig, + circuit: AirInventory, + ) -> Result< + VmChipComplex, + ChipInventoryError, + > { + let mut chip_complex = + VmBuilder::::create_chip_complex(&Rv32EccCpuBuilder, &config.ecc, circuit)?; + let inventory = &mut chip_complex.inventory; + VmProverExtension::::extend_prover( + &Sha2CpuProverExt, + &config.sha256, + inventory, + )?; + Ok(chip_complex) + } + } } #[test] @@ -156,13 +185,13 @@ mod guest_tests { .with_extension(ModularTranspilerExtension) .with_extension(Sha256TranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(EcdsaCpuBuilder, config, openvm_exe); Ok(()) } #[test] fn test_scalar_sqrt() -> Result<()> { - let config = Rv32WeierstrassConfig::new(vec![SECP256K1_CONFIG.clone()]); + let config = test_rv32ecc_config(vec![SECP256K1_CONFIG.clone()]); let elf = build_example_program_at_path( get_programs_dir!("tests/programs"), "scalar_sqrt", @@ -177,7 +206,7 @@ mod guest_tests { .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32EccCpuBuilder, config, openvm_exe); Ok(()) } } diff --git a/guest-libs/k256/tests/programs/examples/add.rs b/guest-libs/k256/tests/programs/examples/add.rs index c6989f44f5..7f7a9d735f 100644 --- a/guest-libs/k256/tests/programs/examples/add.rs +++ b/guest-libs/k256/tests/programs/examples/add.rs @@ -10,7 +10,7 @@ use openvm_k256::Secp256k1Point; mod test_vectors; use test_vectors::ADD_TEST_VECTORS; -openvm::init!("openvm_init_simple.rs"); +openvm::init!("openvm_init_add.rs"); openvm::entry!(main); diff --git a/guest-libs/k256/tests/programs/examples/mul.rs b/guest-libs/k256/tests/programs/examples/mul.rs index 65cb74fb22..e15128a190 100644 --- a/guest-libs/k256/tests/programs/examples/mul.rs +++ b/guest-libs/k256/tests/programs/examples/mul.rs @@ -10,7 +10,7 @@ use openvm_k256::Secp256k1Point; mod test_vectors; use test_vectors::{ADD_TEST_VECTORS, MUL_TEST_VECTORS}; -openvm::init!("openvm_init_simple.rs"); +openvm::init!("openvm_init_mul.rs"); openvm::entry!(main); diff --git a/guest-libs/k256/tests/programs/openvm_init_add.rs b/guest-libs/k256/tests/programs/openvm_init_add.rs index bec9f527e9..0905f21c53 100644 --- a/guest-libs/k256/tests/programs/openvm_init_add.rs +++ b/guest-libs/k256/tests/programs/openvm_init_add.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337" } -openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point } +openvm_ecc_guest::sw_macros::sw_init! { "Secp256k1Point" } +openvm_ecc_guest::te_macros::te_init! {} diff --git a/guest-libs/k256/tests/programs/openvm_init_ecdsa.rs b/guest-libs/k256/tests/programs/openvm_init_ecdsa.rs index bec9f527e9..0905f21c53 100644 --- a/guest-libs/k256/tests/programs/openvm_init_ecdsa.rs +++ b/guest-libs/k256/tests/programs/openvm_init_ecdsa.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337" } -openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point } +openvm_ecc_guest::sw_macros::sw_init! { "Secp256k1Point" } +openvm_ecc_guest::te_macros::te_init! {} diff --git a/guest-libs/k256/tests/programs/openvm_init_linear_combination.rs b/guest-libs/k256/tests/programs/openvm_init_linear_combination.rs index bec9f527e9..0905f21c53 100644 --- a/guest-libs/k256/tests/programs/openvm_init_linear_combination.rs +++ b/guest-libs/k256/tests/programs/openvm_init_linear_combination.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337" } -openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point } +openvm_ecc_guest::sw_macros::sw_init! { "Secp256k1Point" } +openvm_ecc_guest::te_macros::te_init! {} diff --git a/guest-libs/k256/tests/programs/openvm_init_mul.rs b/guest-libs/k256/tests/programs/openvm_init_mul.rs index bec9f527e9..0905f21c53 100644 --- a/guest-libs/k256/tests/programs/openvm_init_mul.rs +++ b/guest-libs/k256/tests/programs/openvm_init_mul.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337" } -openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point } +openvm_ecc_guest::sw_macros::sw_init! { "Secp256k1Point" } +openvm_ecc_guest::te_macros::te_init! {} diff --git a/guest-libs/k256/tests/programs/openvm_init_scalar_sqrt.rs b/guest-libs/k256/tests/programs/openvm_init_scalar_sqrt.rs index bec9f527e9..0905f21c53 100644 --- a/guest-libs/k256/tests/programs/openvm_init_scalar_sqrt.rs +++ b/guest-libs/k256/tests/programs/openvm_init_scalar_sqrt.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337" } -openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point } +openvm_ecc_guest::sw_macros::sw_init! { "Secp256k1Point" } +openvm_ecc_guest::te_macros::te_init! {} diff --git a/guest-libs/k256/tests/programs/openvm_init_simple.rs b/guest-libs/k256/tests/programs/openvm_init_simple.rs deleted file mode 100644 index bec9f527e9..0000000000 --- a/guest-libs/k256/tests/programs/openvm_init_simple.rs +++ /dev/null @@ -1,3 +0,0 @@ -// This file is automatically generated by cargo openvm. Do not rename or edit. -openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337" } -openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point } diff --git a/guest-libs/keccak256/tests/lib.rs b/guest-libs/keccak256/tests/lib.rs index 836d158a4c..3c000d8e2b 100644 --- a/guest-libs/keccak256/tests/lib.rs +++ b/guest-libs/keccak256/tests/lib.rs @@ -3,7 +3,7 @@ mod tests { use eyre::Result; use openvm_circuit::utils::air_test; use openvm_instructions::exe::VmExe; - use openvm_keccak256_circuit::Keccak256Rv32Config; + use openvm_keccak256_circuit::{Keccak256Rv32Config, Keccak256Rv32CpuBuilder}; use openvm_keccak256_transpiler::Keccak256TranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, @@ -27,7 +27,7 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Keccak256Rv32CpuBuilder, config, openvm_exe); Ok(()) } } diff --git a/guest-libs/p256/Cargo.toml b/guest-libs/p256/Cargo.toml index e54a7d22d6..3b2210f400 100644 --- a/guest-libs/p256/Cargo.toml +++ b/guest-libs/p256/Cargo.toml @@ -29,19 +29,18 @@ ff = { workspace = true } [dev-dependencies] openvm-circuit = { workspace = true, features = ["test-utils", "parallel"] } openvm-transpiler.workspace = true -openvm-algebra-circuit.workspace = true openvm-algebra-transpiler.workspace = true openvm-ecc-transpiler.workspace = true openvm-ecc-circuit.workspace = true openvm-sha256-circuit.workspace = true openvm-sha256-transpiler.workspace = true -openvm-rv32im-circuit.workspace = true openvm-rv32im-transpiler.workspace = true openvm-toolchain-tests.workspace = true openvm-stark-backend.workspace = true openvm-stark-sdk.workspace = true +rand = { workspace = true } serde.workspace = true eyre.workspace = true derive_more = { workspace = true, features = ["from"] } @@ -70,4 +69,4 @@ voprf = ["elliptic-curve/voprf"] num-bigint = { workspace = true } [package.metadata.cargo-shear] -ignored = ["openvm", "serde", "num-bigint", "derive_more"] +ignored = ["openvm", "serde", "num-bigint", "derive_more", "rand"] diff --git a/guest-libs/p256/src/internal.rs b/guest-libs/p256/src/internal.rs index b98c401c8c..7db8f868c6 100644 --- a/guest-libs/p256/src/internal.rs +++ b/guest-libs/p256/src/internal.rs @@ -4,8 +4,8 @@ use hex_literal::hex; use openvm_algebra_guest::IntMod; use openvm_algebra_moduli_macros::moduli_declare; use openvm_ecc_guest::{ - weierstrass::{CachedMulTable, IntrinsicCurve, WeierstrassPoint}, - CyclicGroup, Group, + weierstrass::{CachedMulTable, WeierstrassPoint}, + CyclicGroup, Group, IntrinsicCurve, }; use openvm_ecc_sw_macros::sw_declare; diff --git a/guest-libs/p256/src/point.rs b/guest-libs/p256/src/point.rs index ee87396c74..3d4030d807 100644 --- a/guest-libs/p256/src/point.rs +++ b/guest-libs/p256/src/point.rs @@ -14,10 +14,7 @@ use elliptic_curve::{ FieldBytesEncoding, }; use openvm_algebra_guest::IntMod; -use openvm_ecc_guest::{ - weierstrass::{IntrinsicCurve, WeierstrassPoint}, - CyclicGroup, -}; +use openvm_ecc_guest::{weierstrass::WeierstrassPoint, CyclicGroup, IntrinsicCurve}; use crate::{ internal::{P256Coord, P256Point, P256Scalar}, @@ -177,7 +174,7 @@ impl MulByGenerator for P256Point {} impl DecompressPoint for P256Point { /// Note that this is not constant time fn decompress(x_bytes: &FieldBytes, y_is_odd: Choice) -> CtOption { - use openvm_ecc_guest::weierstrass::FromCompressed; + use openvm_ecc_guest::FromCompressed; let x = P256Coord::from_be_bytes_unchecked(x_bytes.as_slice()); let rec_id = y_is_odd.unwrap_u8(); diff --git a/guest-libs/p256/tests/lib.rs b/guest-libs/p256/tests/lib.rs index f11cb63325..b3583cca2f 100644 --- a/guest-libs/p256/tests/lib.rs +++ b/guest-libs/p256/tests/lib.rs @@ -2,8 +2,13 @@ mod guest_tests { use ecdsa_config::EcdsaConfig; use eyre::Result; use openvm_algebra_transpiler::ModularTranspilerExtension; - use openvm_circuit::{arch::instructions::exe::VmExe, utils::air_test}; - use openvm_ecc_circuit::{Rv32WeierstrassConfig, P256_CONFIG}; + use openvm_circuit::{ + arch::instructions::exe::VmExe, + utils::{air_test, test_system_config_with_continuations}, + }; + #[cfg(test)] + use openvm_ecc_circuit::SwCurveCoeffs; + use openvm_ecc_circuit::{CurveConfig, Rv32EccConfig, Rv32EccCpuBuilder, P256_CONFIG}; use openvm_ecc_transpiler::EccTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, @@ -13,11 +18,20 @@ mod guest_tests { use openvm_toolchain_tests::{build_example_program_at_path, get_programs_dir}; use openvm_transpiler::{transpiler::Transpiler, FromElf}; + use crate::guest_tests::ecdsa_config::EcdsaCpuBuilder; + type F = BabyBear; + #[cfg(test)] + fn test_rv32ecc_config(sw_curves: Vec>) -> Rv32EccConfig { + let mut config = Rv32EccConfig::new(sw_curves, vec![]); + *config.as_mut() = test_system_config_with_continuations(); + config + } + #[test] fn test_add() -> Result<()> { - let config = Rv32WeierstrassConfig::new(vec![P256_CONFIG.clone()]); + let config = test_rv32ecc_config(vec![P256_CONFIG.clone()]); let elf = build_example_program_at_path(get_programs_dir!("tests/programs"), "add", &config)?; let openvm_exe = VmExe::from_elf( @@ -29,13 +43,13 @@ mod guest_tests { .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32EccCpuBuilder, config, openvm_exe); Ok(()) } #[test] fn test_mul() -> Result<()> { - let config = Rv32WeierstrassConfig::new(vec![P256_CONFIG.clone()]); + let config = test_rv32ecc_config(vec![P256_CONFIG.clone()]); let elf = build_example_program_at_path(get_programs_dir!("tests/programs"), "mul", &config)?; let openvm_exe = VmExe::from_elf( @@ -47,13 +61,13 @@ mod guest_tests { .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32EccCpuBuilder, config, openvm_exe); Ok(()) } #[test] fn test_linear_combination() -> Result<()> { - let config = Rv32WeierstrassConfig::new(vec![P256_CONFIG.clone()]); + let config = test_rv32ecc_config(vec![P256_CONFIG.clone()]); let elf = build_example_program_at_path( get_programs_dir!("tests/programs"), "linear_combination", @@ -68,62 +82,44 @@ mod guest_tests { .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32EccCpuBuilder, config, openvm_exe); Ok(()) } + // TODO[jpw]: switch to using SDK to avoid this mod ecdsa_config { - use eyre::Result; - use openvm_algebra_circuit::{ - ModularExtension, ModularExtensionExecutor, ModularExtensionPeriphery, - }; use openvm_circuit::{ - arch::{InitFileGenerator, SystemConfig}, + arch::{ + AirInventory, ChipInventoryError, InitFileGenerator, MatrixRecordArena, + SystemConfig, VmBuilder, VmChipComplex, VmProverExtension, + }, derive::VmConfig, + system::SystemChipInventory, }; use openvm_ecc_circuit::{ - CurveConfig, WeierstrassExtension, WeierstrassExtensionExecutor, - WeierstrassExtensionPeriphery, + CurveConfig, Rv32EccConfig, Rv32EccConfigExecutor, Rv32EccCpuBuilder, SwCurveCoeffs, }; - use openvm_rv32im_circuit::{ - Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, - Rv32MExecutor, Rv32MPeriphery, + use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha2CpuProverExt}; + use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + engine::StarkEngine, + p3_field::PrimeField32, + prover::cpu::{CpuBackend, CpuDevice}, }; - use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha256Periphery}; - use openvm_stark_backend::p3_field::PrimeField32; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, VmConfig, Serialize, Deserialize)] pub struct EcdsaConfig { - #[system] - pub system: SystemConfig, - #[extension] - pub base: Rv32I, - #[extension] - pub mul: Rv32M, - #[extension] - pub io: Rv32Io, - #[extension] - pub modular: ModularExtension, - #[extension] - pub weierstrass: WeierstrassExtension, + #[config(generics = true)] + pub ecc: Rv32EccConfig, #[extension] pub sha256: Sha256, } impl EcdsaConfig { - pub fn new(curves: Vec) -> Self { - let primes: Vec<_> = curves - .iter() - .flat_map(|c| [c.modulus.clone(), c.scalar.clone()]) - .collect(); + pub fn new(curves: Vec>) -> Self { Self { - system: SystemConfig::default().with_continuations(), - base: Default::default(), - mul: Default::default(), - io: Default::default(), - modular: ModularExtension::new(primes), - weierstrass: WeierstrassExtension::new(curves), + ecc: Rv32EccConfig::new(curves, vec![]), sha256: Default::default(), } } @@ -133,11 +129,44 @@ mod guest_tests { fn generate_init_file_contents(&self) -> Option { Some(format!( "// This file is automatically generated by cargo openvm. Do not rename or edit.\n{}\n{}\n", - self.modular.generate_moduli_init(), - self.weierstrass.generate_sw_init() + self.ecc.modular.modular.generate_moduli_init(), + self.ecc.ecc.generate_ecc_init() )) } } + + #[derive(Clone)] + pub struct EcdsaCpuBuilder; + + impl VmBuilder for EcdsaCpuBuilder + where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + Val: PrimeField32, + { + type VmConfig = EcdsaConfig; + type SystemChipInventory = SystemChipInventory; + type RecordArena = MatrixRecordArena>; + + fn create_chip_complex( + &self, + config: &EcdsaConfig, + circuit: AirInventory, + ) -> Result< + VmChipComplex, + ChipInventoryError, + > { + let mut chip_complex = + VmBuilder::::create_chip_complex(&Rv32EccCpuBuilder, &config.ecc, circuit)?; + let inventory = &mut chip_complex.inventory; + VmProverExtension::::extend_prover( + &Sha2CpuProverExt, + &config.sha256, + inventory, + )?; + Ok(chip_complex) + } + } } #[test] @@ -156,13 +185,13 @@ mod guest_tests { .with_extension(ModularTranspilerExtension) .with_extension(Sha256TranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(EcdsaCpuBuilder, config, openvm_exe); Ok(()) } #[test] fn test_scalar_sqrt() -> Result<()> { - let config = Rv32WeierstrassConfig::new(vec![P256_CONFIG.clone()]); + let config = test_rv32ecc_config(vec![P256_CONFIG.clone()]); let elf = build_example_program_at_path( get_programs_dir!("tests/programs"), "scalar_sqrt", @@ -177,7 +206,7 @@ mod guest_tests { .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32EccCpuBuilder, config, openvm_exe); Ok(()) } } diff --git a/guest-libs/p256/tests/programs/examples/add.rs b/guest-libs/p256/tests/programs/examples/add.rs index b1c4d62fb4..c21c43ab32 100644 --- a/guest-libs/p256/tests/programs/examples/add.rs +++ b/guest-libs/p256/tests/programs/examples/add.rs @@ -7,7 +7,7 @@ use openvm_p256::NistP256; #[allow(unused)] use openvm_p256::P256Point; -openvm::init!("openvm_init_simple.rs"); +openvm::init!("openvm_init_add.rs"); openvm::entry!(main); diff --git a/guest-libs/p256/tests/programs/openvm_init_add.rs b/guest-libs/p256/tests/programs/openvm_init_add.rs index 02f8b5c05d..a2cd709ba0 100644 --- a/guest-libs/p256/tests/programs/openvm_init_add.rs +++ b/guest-libs/p256/tests/programs/openvm_init_add.rs @@ -1,3 +1,3 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089210356248762697446949407573530086143415290314195533631308867097853951", "115792089210356248762697446949407573529996955224135760342422259061068512044369" } -openvm_ecc_guest::sw_macros::sw_init! { P256Point } +openvm_ecc_guest::sw_macros::sw_init! { "P256Point" } diff --git a/guest-libs/p256/tests/programs/openvm_init_ecdsa.rs b/guest-libs/p256/tests/programs/openvm_init_ecdsa.rs index 02f8b5c05d..a2cd709ba0 100644 --- a/guest-libs/p256/tests/programs/openvm_init_ecdsa.rs +++ b/guest-libs/p256/tests/programs/openvm_init_ecdsa.rs @@ -1,3 +1,3 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089210356248762697446949407573530086143415290314195533631308867097853951", "115792089210356248762697446949407573529996955224135760342422259061068512044369" } -openvm_ecc_guest::sw_macros::sw_init! { P256Point } +openvm_ecc_guest::sw_macros::sw_init! { "P256Point" } diff --git a/guest-libs/p256/tests/programs/openvm_init_linear_combination.rs b/guest-libs/p256/tests/programs/openvm_init_linear_combination.rs index 02f8b5c05d..a2cd709ba0 100644 --- a/guest-libs/p256/tests/programs/openvm_init_linear_combination.rs +++ b/guest-libs/p256/tests/programs/openvm_init_linear_combination.rs @@ -1,3 +1,3 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089210356248762697446949407573530086143415290314195533631308867097853951", "115792089210356248762697446949407573529996955224135760342422259061068512044369" } -openvm_ecc_guest::sw_macros::sw_init! { P256Point } +openvm_ecc_guest::sw_macros::sw_init! { "P256Point" } diff --git a/guest-libs/p256/tests/programs/openvm_init_mul.rs b/guest-libs/p256/tests/programs/openvm_init_mul.rs index 02f8b5c05d..a2cd709ba0 100644 --- a/guest-libs/p256/tests/programs/openvm_init_mul.rs +++ b/guest-libs/p256/tests/programs/openvm_init_mul.rs @@ -1,3 +1,3 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089210356248762697446949407573530086143415290314195533631308867097853951", "115792089210356248762697446949407573529996955224135760342422259061068512044369" } -openvm_ecc_guest::sw_macros::sw_init! { P256Point } +openvm_ecc_guest::sw_macros::sw_init! { "P256Point" } diff --git a/guest-libs/p256/tests/programs/openvm_init_scalar_sqrt.rs b/guest-libs/p256/tests/programs/openvm_init_scalar_sqrt.rs index 02f8b5c05d..a2cd709ba0 100644 --- a/guest-libs/p256/tests/programs/openvm_init_scalar_sqrt.rs +++ b/guest-libs/p256/tests/programs/openvm_init_scalar_sqrt.rs @@ -1,3 +1,3 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089210356248762697446949407573530086143415290314195533631308867097853951", "115792089210356248762697446949407573529996955224135760342422259061068512044369" } -openvm_ecc_guest::sw_macros::sw_init! { P256Point } +openvm_ecc_guest::sw_macros::sw_init! { "P256Point" } diff --git a/guest-libs/p256/tests/programs/openvm_init_simple.rs b/guest-libs/p256/tests/programs/openvm_init_simple.rs deleted file mode 100644 index 02f8b5c05d..0000000000 --- a/guest-libs/p256/tests/programs/openvm_init_simple.rs +++ /dev/null @@ -1,3 +0,0 @@ -// This file is automatically generated by cargo openvm. Do not rename or edit. -openvm_algebra_guest::moduli_macros::moduli_init! { "115792089210356248762697446949407573530086143415290314195533631308867097853951", "115792089210356248762697446949407573529996955224135760342422259061068512044369" } -openvm_ecc_guest::sw_macros::sw_init! { P256Point } diff --git a/guest-libs/pairing/Cargo.toml b/guest-libs/pairing/Cargo.toml index 1e0bcbc80b..cccc2bb7d5 100644 --- a/guest-libs/pairing/Cargo.toml +++ b/guest-libs/pairing/Cargo.toml @@ -53,6 +53,7 @@ rand.workspace = true num-bigint.workspace = true num-traits.workspace = true halo2curves-axiom = { workspace = true } +openvm-pairing = { path = ".", features = ["halo2curves"] } [features] default = [] diff --git a/guest-libs/pairing/src/bls12_381/mod.rs b/guest-libs/pairing/src/bls12_381/mod.rs index 0a7c150e1c..d3557ba61a 100644 --- a/guest-libs/pairing/src/bls12_381/mod.rs +++ b/guest-libs/pairing/src/bls12_381/mod.rs @@ -4,7 +4,7 @@ use core::ops::Neg; use openvm_algebra_guest::IntMod; use openvm_algebra_moduli_macros::moduli_declare; -use openvm_ecc_guest::{weierstrass::IntrinsicCurve, CyclicGroup, Group}; +use openvm_ecc_guest::{CyclicGroup, Group, IntrinsicCurve}; mod fp12; mod fp2; diff --git a/guest-libs/pairing/src/bn254/mod.rs b/guest-libs/pairing/src/bn254/mod.rs index 8384b8b3e8..a8d3f99f68 100644 --- a/guest-libs/pairing/src/bn254/mod.rs +++ b/guest-libs/pairing/src/bn254/mod.rs @@ -5,10 +5,7 @@ use core::ops::{Add, Neg}; use hex_literal::hex; use openvm_algebra_guest::IntMod; use openvm_algebra_moduli_macros::moduli_declare; -use openvm_ecc_guest::{ - weierstrass::{CachedMulTable, IntrinsicCurve}, - CyclicGroup, Group, -}; +use openvm_ecc_guest::{weierstrass::CachedMulTable, CyclicGroup, Group, IntrinsicCurve}; use openvm_ecc_sw_macros::sw_declare; use openvm_pairing_guest::pairing::PairingIntrinsics; diff --git a/guest-libs/pairing/tests/lib.rs b/guest-libs/pairing/tests/lib.rs index 6e55834b77..0981a282b6 100644 --- a/guest-libs/pairing/tests/lib.rs +++ b/guest-libs/pairing/tests/lib.rs @@ -9,20 +9,23 @@ mod bn254 { bn256::{Fq12, Fq2, Fr, G1Affine, G2Affine}, ff::Field, }; - use openvm_algebra_circuit::{Fp2Extension, ModularExtension}; + use openvm_algebra_circuit::{Fp2Extension, Rv32ModularConfig}; use openvm_algebra_transpiler::{Fp2TranspilerExtension, ModularTranspilerExtension}; - use openvm_circuit::{ - arch::SystemConfig, - utils::{air_test, air_test_impl, air_test_with_min_segments}, + use openvm_circuit::utils::{ + air_test, air_test_impl, air_test_with_min_segments, test_system_config_with_continuations, + }; + use openvm_ecc_circuit::{ + CurveConfig, EccExtension, Rv32EccConfig, Rv32EccCpuBuilder, SwCurveCoeffs, }; - use openvm_ecc_circuit::{Rv32WeierstrassConfig, WeierstrassExtension}; use openvm_ecc_guest::{ algebra::{field::FieldExtension, IntMod}, AffinePoint, }; use openvm_ecc_transpiler::EccTranspilerExtension; use openvm_instructions::exe::VmExe; - use openvm_pairing_circuit::{PairingCurve, PairingExtension, Rv32PairingConfig}; + use openvm_pairing_circuit::{ + PairingCurve, PairingExtension, Rv32PairingConfig, Rv32PairingCpuBuilder, + }; use openvm_pairing_guest::{ bn254::{BN254_COMPLEX_STRUCT_NAME, BN254_MODULUS}, halo2curves_shims::bn254::Bn254, @@ -32,7 +35,11 @@ mod bn254 { use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; - use openvm_stark_sdk::{openvm_stark_backend::p3_field::FieldAlgebra, p3_baby_bear::BabyBear}; + use openvm_stark_sdk::{ + config::{baby_bear_poseidon2::BabyBearPoseidon2Engine, FriParameters}, + openvm_stark_backend::p3_field::FieldAlgebra, + p3_baby_bear::BabyBear, + }; use openvm_toolchain_tests::{build_example_program_at_path_with_features, get_programs_dir}; use openvm_transpiler::{transpiler::Transpiler, FromElf}; use rand::SeedableRng; @@ -48,21 +55,26 @@ mod bn254 { .zip(primes.clone()) .collect::>(); Rv32PairingConfig { - system: SystemConfig::default().with_continuations(), - base: Default::default(), - mul: Default::default(), - io: Default::default(), - modular: ModularExtension::new(primes.to_vec()), + modular: Rv32ModularConfig::new(primes.to_vec()), fp2: Fp2Extension::new(primes_with_names), - weierstrass: WeierstrassExtension::new(vec![]), + ecc: EccExtension::new(vec![], vec![]), pairing: PairingExtension::new(vec![PairingCurve::Bn254]), } } + #[cfg(test)] + fn test_rv32ecc_config(sw_curves: Vec>) -> Rv32EccConfig { + use openvm_ecc_circuit::Rv32EccConfig; + + let mut config = Rv32EccConfig::new(sw_curves, vec![]); + *config.as_mut() = test_system_config_with_continuations(); + config + } + #[test] fn test_bn_ec() -> Result<()> { let curve = PairingCurve::Bn254.curve_config(); - let config = Rv32WeierstrassConfig::new(vec![curve]); + let config = test_rv32ecc_config(vec![curve]); let elf = build_example_program_at_path_with_features( get_programs_dir!("tests/programs"), "bn_ec", @@ -78,7 +90,7 @@ mod bn254 { .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32EccCpuBuilder, config, openvm_exe); Ok(()) } @@ -111,10 +123,10 @@ mod bn254 { .into_iter() .flat_map(|fp12| fp12.to_coeffs()) .flat_map(|fp2| fp2.to_bytes()) - .map(FieldAlgebra::from_canonical_u8) + .map(F::from_canonical_u8) .collect::>(); - air_test_with_min_segments(config, openvm_exe, vec![io], 1); + air_test_with_min_segments(Rv32PairingCpuBuilder, config, openvm_exe, vec![io], 1); Ok(()) } @@ -155,7 +167,7 @@ mod bn254 { .chain(r0) .flat_map(|fp2| fp2.to_coeffs()) .flat_map(|fp| fp.to_bytes()) - .map(FieldAlgebra::from_canonical_u8) + .map(F::from_canonical_u8) .collect::>(); // Test mul_by_01234 @@ -167,12 +179,12 @@ mod bn254 { .chain(r1.to_coeffs()) .flat_map(|fp2| fp2.to_coeffs()) .flat_map(|fp| fp.to_bytes()) - .map(FieldAlgebra::from_canonical_u8) + .map(F::from_canonical_u8) .collect::>(); let io_all = io0.into_iter().chain(io1).collect::>(); - air_test_with_min_segments(config, openvm_exe, vec![io_all], 1); + air_test_with_min_segments(Rv32PairingCpuBuilder, config, openvm_exe, vec![io_all], 1); Ok(()) } @@ -208,7 +220,7 @@ mod bn254 { let io0 = [s.x, s.y, pt.x, pt.y, l.b, l.c] .into_iter() .flat_map(|fp| fp.to_bytes()) - .map(FieldAlgebra::from_canonical_u8) + .map(F::from_canonical_u8) .collect::>(); // Test miller_double_and_add_step @@ -216,12 +228,12 @@ mod bn254 { let io1 = [s.x, s.y, q.x, q.y, pt.x, pt.y, l0.b, l0.c, l1.b, l1.c] .into_iter() .flat_map(|fp| fp.to_bytes()) - .map(FieldAlgebra::from_canonical_u8) + .map(F::from_canonical_u8) .collect::>(); let io_all = io0.into_iter().chain(io1).collect::>(); - air_test_with_min_segments(config, openvm_exe, vec![io_all], 1); + air_test_with_min_segments(Rv32PairingCpuBuilder, config, openvm_exe, vec![io_all], 1); Ok(()) } @@ -260,7 +272,7 @@ mod bn254 { let io0 = s .into_iter() .flat_map(|pt| [pt.x, pt.y].into_iter().flat_map(|fp| fp.to_bytes())) - .map(FieldAlgebra::from_canonical_u8) + .map(F::from_canonical_u8) .collect::>(); let io1 = q @@ -269,12 +281,12 @@ mod bn254 { .chain(f.to_coeffs()) .flat_map(|fp2| fp2.to_coeffs()) .flat_map(|fp| fp.to_bytes()) - .map(FieldAlgebra::from_canonical_u8) + .map(F::from_canonical_u8) .collect::>(); let io_all = io0.into_iter().chain(io1).collect::>(); - air_test_with_min_segments(config, openvm_exe, vec![io_all], 1); + air_test_with_min_segments(Rv32PairingCpuBuilder, config, openvm_exe, vec![io_all], 1); Ok(()) } @@ -318,7 +330,7 @@ mod bn254 { let io0 = s .into_iter() .flat_map(|pt| [pt.x, pt.y].into_iter().flat_map(|fp| fp.to_bytes())) - .map(FieldAlgebra::from_canonical_u8) + .map(F::from_canonical_u8) .collect::>(); let io1 = q @@ -326,12 +338,12 @@ mod bn254 { .flat_map(|pt| [pt.x, pt.y].into_iter()) .flat_map(|fp2| fp2.to_coeffs()) .flat_map(|fp| fp.to_bytes()) - .map(FieldAlgebra::from_canonical_u8) + .map(F::from_canonical_u8) .collect::>(); let io_all = io0.into_iter().chain(io1).collect::>(); - air_test_with_min_segments(config, openvm_exe, vec![io_all], 1); + air_test_with_min_segments(Rv32PairingCpuBuilder, config, openvm_exe, vec![io_all], 1); Ok(()) } @@ -375,7 +387,7 @@ mod bn254 { let io0 = s .into_iter() .flat_map(|pt| [pt.x, pt.y].into_iter().flat_map(|fp| fp.to_bytes())) - .map(FieldAlgebra::from_canonical_u8) + .map(F::from_canonical_u8) .collect::>(); let io1 = q @@ -383,12 +395,20 @@ mod bn254 { .flat_map(|pt| [pt.x, pt.y].into_iter()) .flat_map(|fp2| fp2.to_coeffs()) .flat_map(|fp| fp.to_bytes()) - .map(FieldAlgebra::from_canonical_u8) + .map(F::from_canonical_u8) .collect::>(); let io_all = io0.into_iter().chain(io1).collect::>(); // Don't run debugger because it's slow - air_test_impl(get_testing_config(), openvm_exe, vec![io_all], 1, false); + air_test_impl::( + FriParameters::new_for_testing(1), + Rv32PairingCpuBuilder, + get_testing_config(), + openvm_exe, + vec![io_all], + 1, + false, + )?; Ok(()) } @@ -442,7 +462,7 @@ mod bn254 { .flat_map(|w| w.to_le_bytes()) .map(F::from_canonical_u8) .collect(); - air_test_with_min_segments(config, openvm_exe, vec![io], 1); + air_test_with_min_segments(Rv32PairingCpuBuilder, config, openvm_exe, vec![io], 1); Ok(()) } } @@ -456,19 +476,24 @@ mod bls12_381 { }; use num_bigint::BigUint; use num_traits::{self, FromPrimitive}; - use openvm_algebra_circuit::{Fp2Extension, ModularExtension}; + use openvm_algebra_circuit::{Fp2Extension, Rv32ModularConfig}; use openvm_algebra_transpiler::{Fp2TranspilerExtension, ModularTranspilerExtension}; use openvm_circuit::{ - arch::{instructions::exe::VmExe, SystemConfig}, - utils::{air_test, air_test_impl, air_test_with_min_segments}, + arch::instructions::exe::VmExe, + utils::{ + air_test, air_test_impl, air_test_with_min_segments, + test_system_config_with_continuations, + }, }; - use openvm_ecc_circuit::{CurveConfig, Rv32WeierstrassConfig, WeierstrassExtension}; + use openvm_ecc_circuit::{CurveConfig, Rv32EccConfig, Rv32EccCpuBuilder, SwCurveCoeffs}; use openvm_ecc_guest::{ algebra::{field::FieldExtension, IntMod}, AffinePoint, }; use openvm_ecc_transpiler::EccTranspilerExtension; - use openvm_pairing_circuit::{PairingCurve, PairingExtension, Rv32PairingConfig}; + use openvm_pairing_circuit::{ + PairingCurve, PairingExtension, Rv32PairingConfig, Rv32PairingCpuBuilder, + }; use openvm_pairing_guest::{ bls12_381::{ BLS12_381_COMPLEX_STRUCT_NAME, BLS12_381_ECC_STRUCT_NAME, BLS12_381_MODULUS, @@ -481,7 +506,11 @@ mod bls12_381 { use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; - use openvm_stark_sdk::{openvm_stark_backend::p3_field::FieldAlgebra, p3_baby_bear::BabyBear}; + use openvm_stark_sdk::{ + config::{baby_bear_poseidon2::BabyBearPoseidon2Engine, FriParameters}, + openvm_stark_backend::p3_field::FieldAlgebra, + p3_baby_bear::BabyBear, + }; use openvm_toolchain_tests::{build_example_program_at_path_with_features, get_programs_dir}; use openvm_transpiler::{transpiler::Transpiler, FromElf}; use rand::SeedableRng; @@ -490,6 +519,8 @@ mod bls12_381 { #[cfg(test)] pub fn get_testing_config() -> Rv32PairingConfig { + use openvm_ecc_circuit::EccExtension; + let primes = [BLS12_381_MODULUS.clone()]; let complex_struct_names = [BLS12_381_COMPLEX_STRUCT_NAME.to_string()]; let primes_with_names = complex_struct_names @@ -497,27 +528,32 @@ mod bls12_381 { .zip(primes.clone()) .collect::>(); Rv32PairingConfig { - system: SystemConfig::default().with_continuations(), - base: Default::default(), - mul: Default::default(), - io: Default::default(), - modular: ModularExtension::new(primes.to_vec()), + modular: Rv32ModularConfig::new(primes.to_vec()), fp2: Fp2Extension::new(primes_with_names), - weierstrass: WeierstrassExtension::new(vec![]), + ecc: EccExtension::new(vec![], vec![]), pairing: PairingExtension::new(vec![PairingCurve::Bls12_381]), } } + #[cfg(test)] + fn test_rv32ecc_config(sw_curves: Vec>) -> Rv32EccConfig { + let mut config = Rv32EccConfig::new(sw_curves, vec![]); + *config.as_mut() = test_system_config_with_continuations(); + config + } + #[test] fn test_bls_ec() -> Result<()> { let curve = CurveConfig { struct_name: BLS12_381_ECC_STRUCT_NAME.to_string(), modulus: BLS12_381_MODULUS.clone(), scalar: BLS12_381_ORDER.clone(), - a: BigUint::ZERO, - b: BigUint::from_u8(4).unwrap(), + coeffs: SwCurveCoeffs { + a: BigUint::ZERO, + b: BigUint::from_u8(4).unwrap(), + }, }; - let config = Rv32WeierstrassConfig::new(vec![curve]); + let config = test_rv32ecc_config(vec![curve]); let elf = build_example_program_at_path_with_features( get_programs_dir!("tests/programs"), "bls_ec", @@ -533,7 +569,7 @@ mod bls12_381 { .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Rv32EccCpuBuilder, config, openvm_exe); Ok(()) } @@ -566,10 +602,10 @@ mod bls12_381 { .into_iter() .flat_map(|fp12| fp12.to_coeffs()) .flat_map(|fp2| fp2.to_bytes()) - .map(FieldAlgebra::from_canonical_u8) + .map(F::from_canonical_u8) .collect::>(); - air_test_with_min_segments(config, openvm_exe, vec![io], 1); + air_test_with_min_segments(Rv32PairingCpuBuilder, config, openvm_exe, vec![io], 1); Ok(()) } @@ -610,7 +646,7 @@ mod bls12_381 { .chain(r0) .flat_map(|fp2| fp2.to_coeffs()) .flat_map(|fp| fp.to_bytes()) - .map(FieldAlgebra::from_canonical_u8) + .map(F::from_canonical_u8) .collect::>(); // Test mul_by_02345 @@ -623,12 +659,12 @@ mod bls12_381 { .chain(r1.to_coeffs()) .flat_map(|fp2| fp2.to_coeffs()) .flat_map(|fp| fp.to_bytes()) - .map(FieldAlgebra::from_canonical_u8) + .map(F::from_canonical_u8) .collect::>(); let io_all = io0.into_iter().chain(io1).collect::>(); - air_test_with_min_segments(config, openvm_exe, vec![io_all], 1); + air_test_with_min_segments(Rv32PairingCpuBuilder, config, openvm_exe, vec![io_all], 1); Ok(()) } @@ -664,7 +700,7 @@ mod bls12_381 { let io0 = [s.x, s.y, pt.x, pt.y, l.b, l.c] .into_iter() .flat_map(|fp| fp.to_bytes()) - .map(FieldAlgebra::from_canonical_u8) + .map(F::from_canonical_u8) .collect::>(); // Test miller_double_and_add_step @@ -672,12 +708,12 @@ mod bls12_381 { let io1 = [s.x, s.y, q.x, q.y, pt.x, pt.y, l0.b, l0.c, l1.b, l1.c] .into_iter() .flat_map(|fp| fp.to_bytes()) - .map(FieldAlgebra::from_canonical_u8) + .map(F::from_canonical_u8) .collect::>(); let io_all = io0.into_iter().chain(io1).collect::>(); - air_test_with_min_segments(config, openvm_exe, vec![io_all], 1); + air_test_with_min_segments(Rv32PairingCpuBuilder, config, openvm_exe, vec![io_all], 1); Ok(()) } @@ -722,7 +758,7 @@ mod bls12_381 { let io0 = s .into_iter() .flat_map(|pt| [pt.x, pt.y].into_iter().flat_map(|fp| fp.to_bytes())) - .map(FieldAlgebra::from_canonical_u8) + .map(F::from_canonical_u8) .collect::>(); let io1 = q @@ -731,12 +767,12 @@ mod bls12_381 { .chain(f.to_coeffs()) .flat_map(|fp2| fp2.to_coeffs()) .flat_map(|fp| fp.to_bytes()) - .map(FieldAlgebra::from_canonical_u8) + .map(F::from_canonical_u8) .collect::>(); let io_all = io0.into_iter().chain(io1).collect::>(); - air_test_with_min_segments(config, openvm_exe, vec![io_all], 1); + air_test_with_min_segments(Rv32PairingCpuBuilder, config, openvm_exe, vec![io_all], 1); Ok(()) } @@ -779,7 +815,7 @@ mod bls12_381 { let io0 = s .into_iter() .flat_map(|pt| [pt.x, pt.y].into_iter().flat_map(|fp| fp.to_bytes())) - .map(FieldAlgebra::from_canonical_u8) + .map(F::from_canonical_u8) .collect::>(); let io1 = q @@ -787,12 +823,12 @@ mod bls12_381 { .flat_map(|pt| [pt.x, pt.y].into_iter()) .flat_map(|fp2| fp2.to_coeffs()) .flat_map(|fp| fp.to_bytes()) - .map(FieldAlgebra::from_canonical_u8) + .map(F::from_canonical_u8) .collect::>(); let io_all = io0.into_iter().chain(io1).collect::>(); - air_test_with_min_segments(config, openvm_exe, vec![io_all], 1); + air_test_with_min_segments(Rv32PairingCpuBuilder, config, openvm_exe, vec![io_all], 1); Ok(()) } @@ -836,7 +872,7 @@ mod bls12_381 { let io0 = s .into_iter() .flat_map(|pt| [pt.x, pt.y].into_iter().flat_map(|fp| fp.to_bytes())) - .map(FieldAlgebra::from_canonical_u8) + .map(F::from_canonical_u8) .collect::>(); let io1 = q @@ -844,12 +880,20 @@ mod bls12_381 { .flat_map(|pt| [pt.x, pt.y].into_iter()) .flat_map(|fp2| fp2.to_coeffs()) .flat_map(|fp| fp.to_bytes()) - .map(FieldAlgebra::from_canonical_u8) + .map(F::from_canonical_u8) .collect::>(); let io_all = io0.into_iter().chain(io1).collect::>(); // Don't run debugger because it's slow - air_test_impl(get_testing_config(), openvm_exe, vec![io_all], 1, false); + air_test_impl::( + FriParameters::new_for_testing(1), + Rv32PairingCpuBuilder, + get_testing_config(), + openvm_exe, + vec![io_all], + 1, + false, + )?; Ok(()) } @@ -903,7 +947,7 @@ mod bls12_381 { .flat_map(|w| w.to_le_bytes()) .map(F::from_canonical_u8) .collect(); - air_test_with_min_segments(config, openvm_exe, vec![io], 1); + air_test_with_min_segments(Rv32PairingCpuBuilder, config, openvm_exe, vec![io], 1); Ok(()) } } diff --git a/guest-libs/pairing/tests/programs/openvm_init_bls_ec_bls12_381.rs b/guest-libs/pairing/tests/programs/openvm_init_bls_ec_bls12_381.rs index 95a4e46fd3..e8bbe47154 100644 --- a/guest-libs/pairing/tests/programs/openvm_init_bls_ec_bls12_381.rs +++ b/guest-libs/pairing/tests/programs/openvm_init_bls_ec_bls12_381.rs @@ -1,3 +1,3 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787", "52435875175126190479447740508185965837690552500527637822603658699938581184513" } -openvm_ecc_guest::sw_macros::sw_init! { Bls12_381G1Affine } +openvm_ecc_guest::sw_macros::sw_init! { "Bls12_381G1Affine" } diff --git a/guest-libs/pairing/tests/programs/openvm_init_bls_final_exp_hint_bls12_381.rs b/guest-libs/pairing/tests/programs/openvm_init_bls_final_exp_hint_bls12_381.rs index 00181b03ef..ec5d5ed804 100644 --- a/guest-libs/pairing/tests/programs/openvm_init_bls_final_exp_hint_bls12_381.rs +++ b/guest-libs/pairing/tests/programs/openvm_init_bls_final_exp_hint_bls12_381.rs @@ -1,4 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787" } -openvm_algebra_guest::complex_macros::complex_init! { Bls12_381Fp2 { mod_idx = 0 } } -openvm_ecc_guest::sw_macros::sw_init! { } +openvm_algebra_guest::complex_macros::complex_init! { "Bls12_381Fp2" { mod_idx = 0 } } +openvm_ecc_guest::sw_macros::sw_init! {} diff --git a/guest-libs/pairing/tests/programs/openvm_init_bn_ec_bn254.rs b/guest-libs/pairing/tests/programs/openvm_init_bn_ec_bn254.rs index 64de28e83a..e8911a04a3 100644 --- a/guest-libs/pairing/tests/programs/openvm_init_bn_ec_bn254.rs +++ b/guest-libs/pairing/tests/programs/openvm_init_bn_ec_bn254.rs @@ -1,3 +1,3 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "21888242871839275222246405745257275088696311157297823662689037894645226208583", "21888242871839275222246405745257275088548364400416034343698204186575808495617" } -openvm_ecc_guest::sw_macros::sw_init! { Bn254G1Affine } +openvm_ecc_guest::sw_macros::sw_init! { "Bn254G1Affine" } diff --git a/guest-libs/pairing/tests/programs/openvm_init_bn_final_exp_hint_bn254.rs b/guest-libs/pairing/tests/programs/openvm_init_bn_final_exp_hint_bn254.rs index c130859ad8..1a1e1f95ea 100644 --- a/guest-libs/pairing/tests/programs/openvm_init_bn_final_exp_hint_bn254.rs +++ b/guest-libs/pairing/tests/programs/openvm_init_bn_final_exp_hint_bn254.rs @@ -1,4 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "21888242871839275222246405745257275088696311157297823662689037894645226208583" } -openvm_algebra_guest::complex_macros::complex_init! { Bn254Fp2 { mod_idx = 0 } } -openvm_ecc_guest::sw_macros::sw_init! { } +openvm_algebra_guest::complex_macros::complex_init! { "Bn254Fp2" { mod_idx = 0 } } +openvm_ecc_guest::sw_macros::sw_init! {} diff --git a/guest-libs/pairing/tests/programs/openvm_init_fp12_mul_bls12_381.rs b/guest-libs/pairing/tests/programs/openvm_init_fp12_mul_bls12_381.rs index 00181b03ef..ec5d5ed804 100644 --- a/guest-libs/pairing/tests/programs/openvm_init_fp12_mul_bls12_381.rs +++ b/guest-libs/pairing/tests/programs/openvm_init_fp12_mul_bls12_381.rs @@ -1,4 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787" } -openvm_algebra_guest::complex_macros::complex_init! { Bls12_381Fp2 { mod_idx = 0 } } -openvm_ecc_guest::sw_macros::sw_init! { } +openvm_algebra_guest::complex_macros::complex_init! { "Bls12_381Fp2" { mod_idx = 0 } } +openvm_ecc_guest::sw_macros::sw_init! {} diff --git a/guest-libs/pairing/tests/programs/openvm_init_fp12_mul_bn254.rs b/guest-libs/pairing/tests/programs/openvm_init_fp12_mul_bn254.rs index c130859ad8..1a1e1f95ea 100644 --- a/guest-libs/pairing/tests/programs/openvm_init_fp12_mul_bn254.rs +++ b/guest-libs/pairing/tests/programs/openvm_init_fp12_mul_bn254.rs @@ -1,4 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "21888242871839275222246405745257275088696311157297823662689037894645226208583" } -openvm_algebra_guest::complex_macros::complex_init! { Bn254Fp2 { mod_idx = 0 } } -openvm_ecc_guest::sw_macros::sw_init! { } +openvm_algebra_guest::complex_macros::complex_init! { "Bn254Fp2" { mod_idx = 0 } } +openvm_ecc_guest::sw_macros::sw_init! {} diff --git a/guest-libs/pairing/tests/programs/openvm_init_pairing_check_bls12_381.rs b/guest-libs/pairing/tests/programs/openvm_init_pairing_check_bls12_381.rs index 00181b03ef..ec5d5ed804 100644 --- a/guest-libs/pairing/tests/programs/openvm_init_pairing_check_bls12_381.rs +++ b/guest-libs/pairing/tests/programs/openvm_init_pairing_check_bls12_381.rs @@ -1,4 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787" } -openvm_algebra_guest::complex_macros::complex_init! { Bls12_381Fp2 { mod_idx = 0 } } -openvm_ecc_guest::sw_macros::sw_init! { } +openvm_algebra_guest::complex_macros::complex_init! { "Bls12_381Fp2" { mod_idx = 0 } } +openvm_ecc_guest::sw_macros::sw_init! {} diff --git a/guest-libs/pairing/tests/programs/openvm_init_pairing_check_bn254.rs b/guest-libs/pairing/tests/programs/openvm_init_pairing_check_bn254.rs index c130859ad8..1a1e1f95ea 100644 --- a/guest-libs/pairing/tests/programs/openvm_init_pairing_check_bn254.rs +++ b/guest-libs/pairing/tests/programs/openvm_init_pairing_check_bn254.rs @@ -1,4 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "21888242871839275222246405745257275088696311157297823662689037894645226208583" } -openvm_algebra_guest::complex_macros::complex_init! { Bn254Fp2 { mod_idx = 0 } } -openvm_ecc_guest::sw_macros::sw_init! { } +openvm_algebra_guest::complex_macros::complex_init! { "Bn254Fp2" { mod_idx = 0 } } +openvm_ecc_guest::sw_macros::sw_init! {} diff --git a/guest-libs/pairing/tests/programs/openvm_init_pairing_check_fallback_bls12_381.rs b/guest-libs/pairing/tests/programs/openvm_init_pairing_check_fallback_bls12_381.rs index 00181b03ef..ec5d5ed804 100644 --- a/guest-libs/pairing/tests/programs/openvm_init_pairing_check_fallback_bls12_381.rs +++ b/guest-libs/pairing/tests/programs/openvm_init_pairing_check_fallback_bls12_381.rs @@ -1,4 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787" } -openvm_algebra_guest::complex_macros::complex_init! { Bls12_381Fp2 { mod_idx = 0 } } -openvm_ecc_guest::sw_macros::sw_init! { } +openvm_algebra_guest::complex_macros::complex_init! { "Bls12_381Fp2" { mod_idx = 0 } } +openvm_ecc_guest::sw_macros::sw_init! {} diff --git a/guest-libs/pairing/tests/programs/openvm_init_pairing_check_fallback_bn254.rs b/guest-libs/pairing/tests/programs/openvm_init_pairing_check_fallback_bn254.rs index c130859ad8..1a1e1f95ea 100644 --- a/guest-libs/pairing/tests/programs/openvm_init_pairing_check_fallback_bn254.rs +++ b/guest-libs/pairing/tests/programs/openvm_init_pairing_check_fallback_bn254.rs @@ -1,4 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "21888242871839275222246405745257275088696311157297823662689037894645226208583" } -openvm_algebra_guest::complex_macros::complex_init! { Bn254Fp2 { mod_idx = 0 } } -openvm_ecc_guest::sw_macros::sw_init! { } +openvm_algebra_guest::complex_macros::complex_init! { "Bn254Fp2" { mod_idx = 0 } } +openvm_ecc_guest::sw_macros::sw_init! {} diff --git a/guest-libs/pairing/tests/programs/openvm_init_pairing_line_bls12_381.rs b/guest-libs/pairing/tests/programs/openvm_init_pairing_line_bls12_381.rs index 00181b03ef..ec5d5ed804 100644 --- a/guest-libs/pairing/tests/programs/openvm_init_pairing_line_bls12_381.rs +++ b/guest-libs/pairing/tests/programs/openvm_init_pairing_line_bls12_381.rs @@ -1,4 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787" } -openvm_algebra_guest::complex_macros::complex_init! { Bls12_381Fp2 { mod_idx = 0 } } -openvm_ecc_guest::sw_macros::sw_init! { } +openvm_algebra_guest::complex_macros::complex_init! { "Bls12_381Fp2" { mod_idx = 0 } } +openvm_ecc_guest::sw_macros::sw_init! {} diff --git a/guest-libs/pairing/tests/programs/openvm_init_pairing_line_bn254.rs b/guest-libs/pairing/tests/programs/openvm_init_pairing_line_bn254.rs index c130859ad8..1a1e1f95ea 100644 --- a/guest-libs/pairing/tests/programs/openvm_init_pairing_line_bn254.rs +++ b/guest-libs/pairing/tests/programs/openvm_init_pairing_line_bn254.rs @@ -1,4 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "21888242871839275222246405745257275088696311157297823662689037894645226208583" } -openvm_algebra_guest::complex_macros::complex_init! { Bn254Fp2 { mod_idx = 0 } } -openvm_ecc_guest::sw_macros::sw_init! { } +openvm_algebra_guest::complex_macros::complex_init! { "Bn254Fp2" { mod_idx = 0 } } +openvm_ecc_guest::sw_macros::sw_init! {} diff --git a/guest-libs/pairing/tests/programs/openvm_init_pairing_miller_loop_bls12_381.rs b/guest-libs/pairing/tests/programs/openvm_init_pairing_miller_loop_bls12_381.rs index 00181b03ef..ec5d5ed804 100644 --- a/guest-libs/pairing/tests/programs/openvm_init_pairing_miller_loop_bls12_381.rs +++ b/guest-libs/pairing/tests/programs/openvm_init_pairing_miller_loop_bls12_381.rs @@ -1,4 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787" } -openvm_algebra_guest::complex_macros::complex_init! { Bls12_381Fp2 { mod_idx = 0 } } -openvm_ecc_guest::sw_macros::sw_init! { } +openvm_algebra_guest::complex_macros::complex_init! { "Bls12_381Fp2" { mod_idx = 0 } } +openvm_ecc_guest::sw_macros::sw_init! {} diff --git a/guest-libs/pairing/tests/programs/openvm_init_pairing_miller_loop_bn254.rs b/guest-libs/pairing/tests/programs/openvm_init_pairing_miller_loop_bn254.rs index c130859ad8..1a1e1f95ea 100644 --- a/guest-libs/pairing/tests/programs/openvm_init_pairing_miller_loop_bn254.rs +++ b/guest-libs/pairing/tests/programs/openvm_init_pairing_miller_loop_bn254.rs @@ -1,4 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "21888242871839275222246405745257275088696311157297823662689037894645226208583" } -openvm_algebra_guest::complex_macros::complex_init! { Bn254Fp2 { mod_idx = 0 } } -openvm_ecc_guest::sw_macros::sw_init! { } +openvm_algebra_guest::complex_macros::complex_init! { "Bn254Fp2" { mod_idx = 0 } } +openvm_ecc_guest::sw_macros::sw_init! {} diff --git a/guest-libs/pairing/tests/programs/openvm_init_pairing_miller_step_bls12_381.rs b/guest-libs/pairing/tests/programs/openvm_init_pairing_miller_step_bls12_381.rs index 00181b03ef..ec5d5ed804 100644 --- a/guest-libs/pairing/tests/programs/openvm_init_pairing_miller_step_bls12_381.rs +++ b/guest-libs/pairing/tests/programs/openvm_init_pairing_miller_step_bls12_381.rs @@ -1,4 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787" } -openvm_algebra_guest::complex_macros::complex_init! { Bls12_381Fp2 { mod_idx = 0 } } -openvm_ecc_guest::sw_macros::sw_init! { } +openvm_algebra_guest::complex_macros::complex_init! { "Bls12_381Fp2" { mod_idx = 0 } } +openvm_ecc_guest::sw_macros::sw_init! {} diff --git a/guest-libs/pairing/tests/programs/openvm_init_pairing_miller_step_bn254.rs b/guest-libs/pairing/tests/programs/openvm_init_pairing_miller_step_bn254.rs index c130859ad8..1a1e1f95ea 100644 --- a/guest-libs/pairing/tests/programs/openvm_init_pairing_miller_step_bn254.rs +++ b/guest-libs/pairing/tests/programs/openvm_init_pairing_miller_step_bn254.rs @@ -1,4 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "21888242871839275222246405745257275088696311157297823662689037894645226208583" } -openvm_algebra_guest::complex_macros::complex_init! { Bn254Fp2 { mod_idx = 0 } } -openvm_ecc_guest::sw_macros::sw_init! { } +openvm_algebra_guest::complex_macros::complex_init! { "Bn254Fp2" { mod_idx = 0 } } +openvm_ecc_guest::sw_macros::sw_init! {} diff --git a/guest-libs/ruint/ruint-macro/src/lib.rs b/guest-libs/ruint/ruint-macro/src/lib.rs index 86dc5afcf0..67660292f4 100644 --- a/guest-libs/ruint/ruint-macro/src/lib.rs +++ b/guest-libs/ruint/ruint-macro/src/lib.rs @@ -1,5 +1,6 @@ #![doc = include_str!("../README.md")] #![warn(clippy::all, clippy::pedantic, clippy::nursery)] +#![allow(clippy::manual_div_ceil)] use proc_macro::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree}; use std::fmt::{self, Write}; diff --git a/guest-libs/ruint/tests/lib.rs b/guest-libs/ruint/tests/lib.rs index 3db2697775..fec78d47f7 100644 --- a/guest-libs/ruint/tests/lib.rs +++ b/guest-libs/ruint/tests/lib.rs @@ -1,7 +1,7 @@ #[cfg(test)] mod tests { use eyre::Result; - use openvm_bigint_circuit::Int256Rv32Config; + use openvm_bigint_circuit::{Int256Rv32Config, Int256Rv32CpuBuilder}; use openvm_bigint_transpiler::Int256TranspilerExtension; use openvm_circuit::utils::air_test; use openvm_instructions::exe::VmExe; @@ -30,7 +30,7 @@ mod tests { .with_extension(Rv32IoTranspilerExtension) .with_extension(Int256TranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Int256Rv32CpuBuilder, config, openvm_exe); Ok(()) } } diff --git a/guest-libs/ruint/tests/programs/examples/matrix_power.rs b/guest-libs/ruint/tests/programs/examples/matrix_power.rs index 95826d32de..6a874bc35e 100644 --- a/guest-libs/ruint/tests/programs/examples/matrix_power.rs +++ b/guest-libs/ruint/tests/programs/examples/matrix_power.rs @@ -123,6 +123,11 @@ pub fn main() { panic!(); } + if U256::from_limbs([u64::MAX; 4]) + one != zero { + print("FAIL: U256::MAX == 0 test failed"); + panic!(); + } + if two_to_200 != two_to_200 { print("FAIL: 2^200 clone test failed"); panic!(); diff --git a/guest-libs/sha2/tests/lib.rs b/guest-libs/sha2/tests/lib.rs index 9ebab5ac02..669c2c3db6 100644 --- a/guest-libs/sha2/tests/lib.rs +++ b/guest-libs/sha2/tests/lib.rs @@ -6,7 +6,7 @@ mod tests { use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; - use openvm_sha256_circuit::Sha256Rv32Config; + use openvm_sha256_circuit::{Sha256Rv32Config, Sha256Rv32CpuBuilder}; use openvm_sha256_transpiler::Sha256TranspilerExtension; use openvm_stark_sdk::p3_baby_bear::BabyBear; use openvm_toolchain_tests::{build_example_program_at_path, get_programs_dir}; @@ -27,7 +27,7 @@ mod tests { .with_extension(Rv32IoTranspilerExtension) .with_extension(Sha256TranspilerExtension), )?; - air_test(config, openvm_exe); + air_test(Sha256Rv32CpuBuilder, config, openvm_exe); Ok(()) } } diff --git a/guest-libs/verify_stark/Cargo.toml b/guest-libs/verify_stark/Cargo.toml index 070083edad..66f13731d2 100644 --- a/guest-libs/verify_stark/Cargo.toml +++ b/guest-libs/verify_stark/Cargo.toml @@ -21,4 +21,4 @@ openvm-circuit = { workspace = true, features = ["parallel"] } openvm-stark-sdk = { workspace = true } openvm-native-compiler.workspace = true openvm-verify-stark.workspace = true -eyre.workspace = true \ No newline at end of file +eyre.workspace = true diff --git a/guest-libs/verify_stark/tests/integration_test.rs b/guest-libs/verify_stark/tests/integration_test.rs index e05b5f12a1..3eaa148e27 100644 --- a/guest-libs/verify_stark/tests/integration_test.rs +++ b/guest-libs/verify_stark/tests/integration_test.rs @@ -7,7 +7,7 @@ mod tests { use openvm_native_compiler::conversion::CompilerOptions; use openvm_sdk::{ commit::AppExecutionCommit, - config::{AggStarkConfig, AppConfig, SdkSystemConfig, SdkVmConfig}, + config::{AggStarkConfig, AppConfig, SdkSystemConfig, SdkVmConfig, SdkVmCpuBuilder}, keygen::AggStarkProvingKey, Sdk, StdIn, }; @@ -71,7 +71,7 @@ mod tests { ..Default::default() }, root_max_constraint_degree: (1 << ROOT_LOG_BLOWUP) + 1, - }); + })?; let asm = sdk.generate_root_verifier_asm(&agg_pk); let asm_path = format!( "{}/examples/verify_openvm_stark/{}", @@ -81,6 +81,7 @@ mod tests { std::fs::write(asm_path, asm)?; let e2e_stark_proof = sdk.generate_e2e_stark_proof( + SdkVmCpuBuilder, Arc::new(app_pk), committed_app_exe, agg_pk, @@ -100,8 +101,8 @@ mod tests { sdk.transpile(elf, vm_config.transpiler())? }; - // app_exe publishes 7th and 8th fibonacci numbers. - let pvs: Vec = [13u32, 21, 0, 0, 0, 0, 0, 0] + // app_exe publishes 31st and 32nd fibonacci numbers. + let pvs: Vec = [1346269, 2178309, 0, 0, 0, 0, 0, 0u32] .iter() .flat_map(|x| x.to_le_bytes()) .collect(); diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 35e9b966ed..8825102061 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "1.85.1" +channel = "1.86.0" components = ["clippy", "rustfmt"]