diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs
index b9d34507..0eba6cb7 100644
--- a/crates/lean_compiler/src/a_simplify_lang.rs
+++ b/crates/lean_compiler/src/a_simplify_lang.rs
@@ -476,7 +476,8 @@ fn compile_time_transform_in_lines(
             value,
             location,
         } = line
-            && let Some(expanded) = try_expand_match_range(value, targets, *location)?
+            && let Some(expanded) =
+                try_expand_match_range(value, targets, *location, const_arrays, &vector_len_tracker)?
         {
             lines.splice(i..=i, expanded);
             continue;
@@ -642,10 +643,13 @@ fn compile_time_transform_in_lines(
                 end,
                 body,
                 unroll: true,
-                ..
+                location,
             } => {
                 let (Some(start), Some(end)) = (start.as_scalar(), end.as_scalar()) else {
-                    return Err("Cannot unroll loop with non-constant bounds".to_string());
+                    return Err(format!(
+                        "line {}: Cannot unroll loop with non-constant bounds",
+                        location
+                    ));
                 };
                 let unroll_index = unroll_counter.get_next();
                 let (internal_vars, _) = find_variable_usage(body, const_arrays);
@@ -686,6 +690,8 @@ fn try_expand_match_range(
     value: &Expression,
     targets: &[AssignmentTarget],
     location: SourceLocation,
+    const_arrays: &BTreeMap,
+    vector_len: &VectorLenTracker,
 ) -> Result>, String> {
     let Expression::FunctionCall {
         function_name, args, ..
@@ -746,12 +752,12 @@ fn try_expand_match_range(
        return Err("match_range: expected range(start, end)".into());
     }
     let start = ra[0]
-        .as_scalar()
-        .ok_or("match_range: range start must be constant")?
+        .compile_time_eval(const_arrays, vector_len)
+        .ok_or(format!("match_range: range start must be constant (at {location})"))?
         .to_usize();
     let end = ra[1]
-        .as_scalar()
-        .ok_or("match_range: range end must be constant")?
+        .compile_time_eval(const_arrays, vector_len)
+        .ok_or(format!("match_range: range end must be constant (at {location})"))?
         .to_usize();

     // Parse lambda
diff --git a/crates/lean_prover/src/lib.rs b/crates/lean_prover/src/lib.rs
index c78e2005..e90fd389 100644
--- a/crates/lean_prover/src/lib.rs
+++ b/crates/lean_prover/src/lib.rs
@@ -19,17 +19,20 @@ use trace_gen::*;

 pub const SECURITY_BITS: usize = 123; // TODO 128 bits security? (with Poseidon over 20 field elements or with a more subtle soundness analysis (cf. https://eprint.iacr.org/2021/188.pdf))
 pub const GRINDING_BITS: usize = 18;
+pub const MAX_NUM_VARIABLES_TO_SEND_COEFFS: usize = 8;
+pub const WHIR_INITIAL_FOLDING_FACTOR: usize = 7;
+pub const WHIR_SUBSEQUENT_FOLDING_FACTOR: usize = 5;

 pub fn default_whir_config(starting_log_inv_rate: usize, prox_gaps_conjecture: bool) -> WhirConfigBuilder {
     WhirConfigBuilder {
-        folding_factor: FoldingFactor::new(7, 5),
+        folding_factor: FoldingFactor::new(WHIR_INITIAL_FOLDING_FACTOR, WHIR_SUBSEQUENT_FOLDING_FACTOR),
         soundness_type: if prox_gaps_conjecture {
             SecurityAssumption::CapacityBound // TODO update formula with State of the Art Conjecture
         } else {
             SecurityAssumption::JohnsonBound
         },
         pow_bits: GRINDING_BITS,
-        max_num_variables_to_send_coeffs: 9,
+        max_num_variables_to_send_coeffs: MAX_NUM_VARIABLES_TO_SEND_COEFFS,
         rs_domain_initial_reduction_factor: 5,
         security_level: SECURITY_BITS,
         starting_log_inv_rate,
diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs
index 72660624..db7525bf 100644
--- a/crates/lean_prover/src/prove_execution.rs
+++ b/crates/lean_prover/src/prove_execution.rs
@@ -142,7 +142,11 @@ pub fn prove_execution(
     let air_alpha = prover_state.sample();
     let air_alpha_powers: Vec = air_alpha.powers().collect_n(max_air_constraints() + 1);

-    for (table, trace) in traces.iter() {
+    let tables_log_heights: BTreeMap<Table, usize> =
+        traces.iter().map(|(table, trace)| (*table, trace.log_n_rows)).collect();
+    let tables_sorted = sort_tables_by_height(&tables_log_heights);
+    for (table, _) in &tables_sorted {
+        let trace = &traces[table];
         let this_air_claims = prove_bus_and_air(
             &mut prover_state,
             table,
@@ -185,7 +189,6 @@ pub fn prove_execution(
         ),
     ];

-    let tables_log_heights = traces.iter().map(|(table, trace)| (*table, trace.log_n_rows)).collect();
     let global_statements_base = stacked_pcs_global_statements(
         stacked_pcs_witness.stacked_n_vars,
         log2_strict_usize(memory.len()),
diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs
index 1b257369..b0a7f48c 100644
--- a/crates/lean_prover/src/verify_execution.rs
+++ b/crates/lean_prover/src/verify_execution.rs
@@ -97,7 +97,8 @@ pub fn verify_execution(
     let air_alpha = verifier_state.sample();
     let air_alpha_powers: Vec = air_alpha.powers().collect_n(max_air_constraints() + 1);

-    for (table, log_n_rows) in &table_n_vars {
+    let tables_sorted = sort_tables_by_height(&table_n_vars);
+    for (table, log_n_rows) in &tables_sorted {
         let this_air_claims = verify_bus_and_air(
             &mut verifier_state,
             table,
diff --git a/crates/rec_aggregation/fiat_shamir.py b/crates/rec_aggregation/fiat_shamir.py
index e5494f95..7b66058c 100644
--- a/crates/rec_aggregation/fiat_shamir.py
+++ b/crates/rec_aggregation/fiat_shamir.py
@@ -95,12 +95,25 @@ def fs_receive_chunks(fs, n_chunks: Const):


 @inline
-def fs_receive_ef(fs, n):
+def fs_receive_ef_inlined(fs, n):
     new_fs, ef_ptr = fs_receive_chunks(fs, div_ceil(n * DIM, 8))
     for i in unroll(n * DIM, next_multiple_of(n * DIM, 8)):
         assert ef_ptr[i] == 0
     return new_fs, ef_ptr

+def fs_receive_ef_by_log_dynamic(fs, log_n, min_value: Const, max_value: Const):
+    debug_assert(log_n < max_value)
+    debug_assert(min_value <= log_n)
+    new_fs: Imu
+    ef_ptr: Imu
+    new_fs, ef_ptr = match_range(log_n, range(min_value, max_value), lambda ln: fs_receive_ef(fs, 2**ln))
+    return new_fs, ef_ptr
+
+def fs_receive_ef(fs, n: Const):
+    new_fs, ef_ptr = fs_receive_chunks(fs, div_ceil(n * DIM, 8))
+    for i in unroll(n * DIM, next_multiple_of(n * DIM, 8)):
+        assert ef_ptr[i] == 0
+    return new_fs, ef_ptr

 def fs_print_state(fs_state):
     for i in unroll(0, 9):
diff --git a/crates/rec_aggregation/recursion.py b/crates/rec_aggregation/recursion.py
index 58c908e2..4b831488 100644
--- a/crates/rec_aggregation/recursion.py
+++ b/crates/rec_aggregation/recursion.py
@@ -3,8 +3,6 @@

 N_TABLES = N_TABLES_PLACEHOLDER

-MIN_WHIR_LOG_INV_RATE = MIN_WHIR_LOG_INV_RATE_PLACEHOLDER
-MAX_WHIR_LOG_INV_RATE = MAX_WHIR_LOG_INV_RATE_PLACEHOLDER
 MIN_LOG_N_ROWS_PER_TABLE = MIN_LOG_N_ROWS_PER_TABLE_PLACEHOLDER
 MAX_LOG_N_ROWS_PER_TABLE = MAX_LOG_N_ROWS_PER_TABLE_PLACEHOLDER
 MIN_LOG_MEMORY_SIZE = MIN_LOG_MEMORY_SIZE_PLACEHOLDER
@@ -92,9 +90,11 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
     assert log_memory <= GUEST_BYTECODE_LEN

     stacked_n_vars = compute_stacked_n_vars(log_memory, log_bytecode_padded, table_heights)
-    assert stacked_n_vars <= TWO_ADICITY + WHIR_FOLDING_FACTORS[0] - whir_log_inv_rate
+    assert stacked_n_vars <= TWO_ADICITY + WHIR_INITIAL_FOLDING_FACTOR - whir_log_inv_rate

-    fs, whir_base_root, whir_base_ood_points, whir_base_ood_evals = parse_whir_commitment_const(fs, WHIR_NUM_OOD_COMMIT)
+    num_oods = get_num_oods(whir_log_inv_rate, stacked_n_vars)
+    num_ood_at_commitment = num_oods[0]
+    fs, whir_base_root, whir_base_ood_points, whir_base_ood_evals = parse_commitment(fs, num_ood_at_commitment)

     fs, logup_c = fs_sample_ef(fs)

@@ -111,8 +111,8 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
     memory_and_acc_prefix = multilinear_location_prefix(0, n_vars_logup_gkr - log_memory, point_gkr)

-    fs, value_acc = fs_receive_ef(fs, 1)
-    fs, value_memory = fs_receive_ef(fs, 1)
+    fs, value_acc = fs_receive_ef_inlined(fs, 1)
+    fs, value_memory = fs_receive_ef_inlined(fs, 1)

     retrieved_numerators_value: Mut = opposite_extension_ret(mul_extension_ret(memory_and_acc_prefix, value_acc))
@@ -146,7 +146,7 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
             bytecode_value_corrected, one_minus_self_extension_ret(logup_alphas + i * DIM)
         )

-    fs, value_bytecode_acc = fs_receive_ef(fs, 1)
+    fs, value_bytecode_acc = fs_receive_ef_inlined(fs, 1)
     retrieved_numerators_value = sub_extension_ret(
         retrieved_numerators_value, mul_extension_ret(bytecode_multilinear_location_prefix, value_bytecode_acc)
     )
@@ -183,6 +183,17 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
     )
     offset += powers_of_two(log_bytecode_padded)

+    # Dispatch based on table height ordering (sorted by descending height)
+    if maximum(table_log_heights[1], table_log_heights[2]) == table_log_heights[1]:
+        continue_recursion_ordered(1, 2, fs, offset, retrieved_numerators_value, retrieved_denominators_value, table_heights, table_log_heights, point_gkr, n_vars_logup_gkr, logup_alphas_eq_poly, logup_c, numerators_value, denominators_value, log_memory, inner_public_memory, inner_public_memory_log_size, stacked_n_vars, whir_log_inv_rate, whir_base_root, whir_base_ood_points, whir_base_ood_evals, num_ood_at_commitment, log_n_cycles, log_bytecode_padded, bytecode_and_acc_point, value_memory, value_acc, value_bytecode_acc)
+    else:
+        continue_recursion_ordered(2, 1, fs, offset, retrieved_numerators_value, retrieved_denominators_value, table_heights, table_log_heights, point_gkr, n_vars_logup_gkr, logup_alphas_eq_poly, logup_c, numerators_value, denominators_value, log_memory, inner_public_memory, inner_public_memory_log_size, stacked_n_vars, whir_log_inv_rate, whir_base_root, whir_base_ood_points, whir_base_ood_evals, num_ood_at_commitment, log_n_cycles, log_bytecode_padded, bytecode_and_acc_point, value_memory, value_acc, value_bytecode_acc)
+
+    return
+
+
+@inline
+def continue_recursion_ordered(second_table, third_table, fs, offset, retrieved_numerators_value, retrieved_denominators_value, table_heights, table_log_heights, point_gkr, n_vars_logup_gkr, logup_alphas_eq_poly, logup_c, numerators_value, denominators_value, log_memory, inner_public_memory, inner_public_memory_log_size, stacked_n_vars, whir_log_inv_rate, whir_base_root, whir_base_ood_points, whir_base_ood_evals, num_ood_at_commitment, log_n_cycles, log_bytecode_padded, bytecode_and_acc_point, value_memory, value_acc, value_bytecode_acc):
     bus_numerators_values = DynArray([])
     bus_denominators_values = DynArray([])
     pcs_points = DynArray([])  # [[_; N]; N_TABLES]
@@ -196,7 +207,14 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
         for _ in unroll(0, total_num_cols):
             pcs_values[i][0].push(DynArray([]))

-    for table_index in unroll(0, N_TABLES):
+    for sorted_pos in unroll(0, N_TABLES):
+        table_index: Imu
+        if sorted_pos == 0:
+            table_index = EXECUTION_TABLE_INDEX
+        if sorted_pos == 1:
+            table_index = second_table
+        if sorted_pos == 2:
+            table_index = third_table

         # I] Bus (data flow between tables)
         log_n_rows = table_log_heights[table_index]
@@ -208,9 +226,9 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip

         # 0] Bytecode lookup
         bytecode_prefix = multilinear_location_prefix(offset / n_rows, n_vars_logup_gkr - log_n_rows, point_gkr)
-        fs, eval_on_pc = fs_receive_ef(fs, 1)
+        fs, eval_on_pc = fs_receive_ef_inlined(fs, 1)
         pcs_values[EXECUTION_TABLE_INDEX][0][COL_PC].push(eval_on_pc)
-        fs, instr_evals = fs_receive_ef(fs, N_INSTRUCTION_COLUMNS)
+        fs, instr_evals = fs_receive_ef_inlined(fs, N_INSTRUCTION_COLUMNS)
         for i in unroll(0, N_INSTRUCTION_COLUMNS):
             global_index = N_COMMITTED_EXEC_COLUMNS + i
             pcs_values[EXECUTION_TABLE_INDEX][0][global_index].push(instr_evals + i * DIM)
@@ -224,12 +242,12 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip

            prefix = multilinear_location_prefix(offset / n_rows, n_vars_logup_gkr - log_n_rows, point_gkr)

-            fs, eval_on_selector = fs_receive_ef(fs, 1)
+            fs, eval_on_selector = fs_receive_ef_inlined(fs, 1)
             retrieved_numerators_value = add_extension_ret(
                 retrieved_numerators_value, mul_extension_ret(prefix, eval_on_selector)
             )

-            fs, eval_on_data = fs_receive_ef(fs, 1)
+            fs, eval_on_data = fs_receive_ef_inlined(fs, 1)
             retrieved_denominators_value = add_extension_ret(
                 retrieved_denominators_value, mul_extension_ret(prefix, eval_on_data)
             )
@@ -244,11 +262,11 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip

         for lookup_f_index in unroll(0, len(LOOKUPS_F_INDEXES[table_index])):
             col_index = LOOKUPS_F_INDEXES[table_index][lookup_f_index]
-            fs, index_eval = fs_receive_ef(fs, 1)
+            fs, index_eval = fs_receive_ef_inlined(fs, 1)
             debug_assert(len(pcs_values[table_index][0][col_index]) == 0)
             pcs_values[table_index][0][col_index].push(index_eval)
             for i in unroll(0, len(LOOKUPS_F_VALUES[table_index][lookup_f_index])):
-                fs, value_eval = fs_receive_ef(fs, 1)
+                fs, value_eval = fs_receive_ef_inlined(fs, 1)
                 col_index = LOOKUPS_F_VALUES[table_index][lookup_f_index][i]
                 debug_assert(len(pcs_values[table_index][0][col_index]) == 0)
                 pcs_values[table_index][0][col_index].push(value_eval)
@@ -272,7 +290,7 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip

         for lookup_ef_index in unroll(0, len(LOOKUPS_EF_INDEXES[table_index])):
             col_index = LOOKUPS_EF_INDEXES[table_index][lookup_ef_index]
-            fs, index_eval = fs_receive_ef(fs, 1)
+            fs, index_eval = fs_receive_ef_inlined(fs, 1)
             if len(pcs_values[table_index][0][col_index]) == 0:
                 pcs_values[table_index][0][col_index].push(index_eval)
             else:
@@ -280,7 +298,7 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
                 copy_5(index_eval, pcs_values[table_index][0][col_index][0])

             for i in unroll(0, DIM):
-                fs, value_eval = fs_receive_ef(fs, 1)
+                fs, value_eval = fs_receive_ef_inlined(fs, 1)
                 pref = multilinear_location_prefix(
                     offset / n_rows, n_vars_logup_gkr - log_n_rows, point_gkr
                 )  # TODO there is some duplication here
@@ -322,10 +340,17 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
     fs, air_alpha = fs_sample_ef(fs)
     air_alpha_powers = powers_const(air_alpha, MAX_NUM_AIR_CONSTRAINTS + 1)

-    for table_index in unroll(0, N_TABLES):
+    for sorted_pos in unroll(0, N_TABLES):
+        table_index: Imu
+        if sorted_pos == 0:
+            table_index = EXECUTION_TABLE_INDEX
+        if sorted_pos == 1:
+            table_index = second_table
+        if sorted_pos == 2:
+            table_index = third_table
         log_n_rows = table_log_heights[table_index]
-        bus_numerator_value = bus_numerators_values[table_index]
-        bus_denominator_value = bus_denominators_values[table_index]
+        bus_numerator_value = bus_numerators_values[sorted_pos]
+        bus_denominator_value = bus_denominators_values[sorted_pos]
         total_num_cols = NUM_COLS_F_AIR[table_index] + DIM * NUM_COLS_EF_AIR[table_index]

         bus_final_value: Mut = bus_numerator_value
@@ -346,7 +371,7 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
         n_down_columns_ef = len(AIR_DOWN_COLUMNS_EF[table_index])
         n_up_columns = n_up_columns_f + n_up_columns_ef
         n_down_columns = n_down_columns_f + n_down_columns_ef
-        fs, inner_evals = fs_receive_ef(fs, n_up_columns + n_down_columns)
+        fs, inner_evals = fs_receive_ef_inlined(fs, n_up_columns + n_down_columns)

         air_constraints_eval = evaluate_air_constraints(
             table_index, inner_evals, air_alpha_powers, bus_beta, logup_alphas_eq_poly
@@ -378,14 +403,14 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip

         matrix_down_sc_eval = next_mle(outer_point, inner_point, log_n_rows)

-        fs, evals_f_on_down_columns = fs_receive_ef(fs, n_down_columns_f)
+        fs, evals_f_on_down_columns = fs_receive_ef_inlined(fs, n_down_columns_f)
         batched_col_down_sc_eval: Mut = dot_product_ret(
             evals_f_on_down_columns, batching_scalar_powers, n_down_columns_f, EE
         )

         evals_ef_on_down_columns: Imu
         if n_down_columns_ef != 0:
-            fs, evals_ef_on_down_columns = fs_receive_ef(fs, n_down_columns_ef)
+            fs, evals_ef_on_down_columns = fs_receive_ef_inlined(fs, n_down_columns_ef)
             batched_col_down_sc_eval = add_extension_ret(
                 batched_col_down_sc_eval,
                 dot_product_ret(
@@ -411,7 +436,7 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
                     evals_f_on_down_columns + i * DIM
                 )
             for i in unroll(0, n_down_columns_ef):
-                fs, transposed = fs_receive_ef(fs, DIM)
+                fs, transposed = fs_receive_ef_inlined(fs, DIM)
                 copy_5(
                     evals_ef_on_down_columns + i * DIM,
                     dot_product_with_the_base_vectors(transposed),
@@ -429,7 +454,7 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
                 pcs_values[table_index][last_index_2][i].push(inner_evals + i * DIM)

             for i in unroll(0, n_up_columns_ef):
-                fs, transposed = fs_receive_ef(fs, DIM)
+                fs, transposed = fs_receive_ef_inlined(fs, DIM)
                 copy_5(
                     inner_evals + (n_up_columns_f + n_down_columns_f + i) * DIM,
                     dot_product_with_the_base_vectors(transposed),
@@ -455,11 +480,12 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
     # WHIR BASE
     combination_randomness_gen: Mut
     fs, combination_randomness_gen = fs_sample_ef(fs)
-    combination_randomness_powers: Mut = powers_const(
-        combination_randomness_gen, WHIR_NUM_OOD_COMMIT + TOTAL_WHIR_STATEMENTS
+    combination_randomness_powers: Mut = powers(
+        combination_randomness_gen, num_ood_at_commitment + TOTAL_WHIR_STATEMENTS
     )
-    whir_sum: Mut = dot_product_ret(whir_base_ood_evals, combination_randomness_powers, WHIR_NUM_OOD_COMMIT, EE)
-    curr_randomness: Mut = combination_randomness_powers + WHIR_NUM_OOD_COMMIT * DIM
+    whir_sum: Mut = Array(DIM)
+    dot_product_ee_dynamic(whir_base_ood_evals, combination_randomness_powers, whir_sum, num_ood_at_commitment)
+    curr_randomness: Mut = combination_randomness_powers + num_ood_at_commitment * DIM

     whir_sum = add_extension_ret(mul_extension_ret(value_memory, curr_randomness), whir_sum)
     curr_randomness += DIM
@@ -475,7 +501,14 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
     whir_sum = add_extension_ret(mul_extension_ret(embed_in_ef(ENDING_PC), curr_randomness), whir_sum)
     curr_randomness += DIM

-    for table_index in unroll(0, N_TABLES):
+    for sorted_pos in unroll(0, N_TABLES):
+        table_index: Imu
+        if sorted_pos == 0:
+            table_index = EXECUTION_TABLE_INDEX
+        if sorted_pos == 1:
+            table_index = second_table
+        if sorted_pos == 2:
+            table_index = third_table
         debug_assert(len(pcs_points[table_index]) == len(pcs_values[table_index]))
         for i in unroll(0, len(pcs_values[table_index])):
             for j in unroll(0, len(pcs_values[table_index][i])):
@@ -501,7 +534,7 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
         whir_sum,
     )

-    curr_randomness = combination_randomness_powers + WHIR_NUM_OOD_COMMIT * DIM
+    curr_randomness = combination_randomness_powers + num_ood_at_commitment * DIM

     eq_memory_and_acc_point = eq_mle_extension(
         folding_randomness_global + (stacked_n_vars - log_memory) * DIM,
@@ -571,7 +604,14 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
     s = add_extension_ret(s, mul_extension_ret(curr_randomness, prefix_pc_end))
     curr_randomness += DIM

-    for table_index in unroll(0, N_TABLES):
+    for sorted_pos in unroll(0, N_TABLES):
+        table_index: Imu
+        if sorted_pos == 0:
+            table_index = EXECUTION_TABLE_INDEX
+        if sorted_pos == 1:
+            table_index = second_table
+        if sorted_pos == 2:
+            table_index = third_table
         log_n_rows = table_log_heights[table_index]
         n_rows = table_heights[table_index]
         total_num_cols = NUM_COLS_F_AIR[table_index] + DIM * NUM_COLS_EF_AIR[table_index]
@@ -597,7 +637,6 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
         offset += n_rows * total_num_cols

     copy_5(mul_extension_ret(s, final_value), end_sum)
-    return


@@ -629,8 +668,8 @@ def fingerprint_bytecode(instr_evals, eval_on_pc, logup_alphas_eq_poly):


 def verify_gkr_quotient(fs: Mut, n_vars):
-    fs, nums = fs_receive_ef(fs, 2)
-    fs, denoms = fs_receive_ef(fs, 2)
+    fs, nums = fs_receive_ef_inlined(fs, 2)
+    fs, denoms = fs_receive_ef_inlined(fs, 2)

     q1 = div_extension_ret(nums, denoms)
     q2 = div_extension_ret(nums + DIM, denoms + DIM)
@@ -669,7 +708,7 @@ def verify_gkr_quotient_step(fs: Mut, n_vars, point, claim_num, claim_den):
     num_plus_alpha_mul_claim_den = add_extension_ret(claim_num, alpha_mul_claim_den)
     postponed_point = Array((n_vars + 1) * DIM)
     fs, postponed_value = sumcheck_verify_helper(fs, n_vars, num_plus_alpha_mul_claim_den, 3, postponed_point + DIM)
-    fs, inner_evals = fs_receive_ef(fs, 4)
+    fs, inner_evals = fs_receive_ef_inlined(fs, 4)
     a_num = inner_evals
     b_num = inner_evals + DIM
     a_den = inner_evals + 2 * DIM
@@ -726,6 +765,6 @@ def evaluate_air_constraints(table_index, inner_evals, air_alpha_powers, bus_bet
         case 2:
             res = evaluate_air_constraints_table_2(inner_evals, air_alpha_powers, bus_beta, logup_alphas_eq_poly)
     return res
-
+

 EVALUATE_AIR_FUNCTIONS_PLACEHOLDER
diff --git a/crates/rec_aggregation/src/recursion.rs b/crates/rec_aggregation/src/recursion.rs
index 63879afc..7155a04f 100644
--- a/crates/rec_aggregation/src/recursion.rs
+++ b/crates/rec_aggregation/src/recursion.rs
@@ -4,24 +4,20 @@
 use std::rc::Rc;
 use std::time::Instant;

 use lean_compiler::{CompilationFlags, ProgramSource, compile_program, compile_program_with_flags};
-use lean_prover::default_whir_config;
 use lean_prover::prove_execution::prove_execution;
 use lean_prover::verify_execution::verify_execution;
+use lean_prover::{
+    MAX_NUM_VARIABLES_TO_SEND_COEFFS, WHIR_INITIAL_FOLDING_FACTOR, WHIR_SUBSEQUENT_FOLDING_FACTOR, default_whir_config,
+};
 use lean_vm::*;
 use multilinear_toolkit::prelude::symbolic::{
     SymbolicExpression, SymbolicOperation, get_symbolic_constraints_and_bus_data_values,
 };
 use multilinear_toolkit::prelude::*;
+use sub_protocols::min_stacked_n_vars;
 use utils::{BYTECODE_TABLE_INDEX, Counter, MEMORY_TABLE_INDEX};

 pub fn run_recursion_benchmark(count: usize, log_inv_rate: usize, prox_gaps_conjecture: bool, tracing: bool) {
-    let filepath = Path::new(env!("CARGO_MANIFEST_DIR"))
-        .join("recursion.py")
-        .to_str()
-        .unwrap()
-        .to_string();
-
-    let inner_whir_config = default_whir_config(log_inv_rate, prox_gaps_conjecture);
     let program_to_prove = r#"
 DIM = 5
 POSEIDON_OF_ZERO = POSEIDON_OF_ZERO_PLACEHOLDER
@@ -50,6 +46,45 @@ def main():
     return
 "#
     .replace("POSEIDON_OF_ZERO_PLACEHOLDER", &POSEIDON_16_NULL_HASH_PTR.to_string());
+    run_recursion_benchmark_with_program(count, log_inv_rate, prox_gaps_conjecture, tracing, &program_to_prove);
+}
+
+#[test]
+fn test_end2end_recursion_poseidon_heavy() {
+    // Poseidon table larger than dot_product table (reversed ordering)
+    let program_to_prove = r#"
+DIM = 5
+POSEIDON_OF_ZERO = POSEIDON_OF_ZERO_PLACEHOLDER
+BE = 1
+
+def main():
+    for i in range(0, 1000):
+        null_ptr = ZERO_VEC_PTR
+        poseidon_of_zero = POSEIDON_OF_ZERO
+        poseidon16(null_ptr, null_ptr, poseidon_of_zero)
+        poseidon16(null_ptr, null_ptr, poseidon_of_zero)
+        poseidon16(null_ptr, null_ptr, poseidon_of_zero)
+        dot_product(ZERO_VEC_PTR, ZERO_VEC_PTR, ZERO_VEC_PTR, 2, BE)
+    return
+"#
+    .replace("POSEIDON_OF_ZERO_PLACEHOLDER", &POSEIDON_16_NULL_HASH_PTR.to_string());
+    run_recursion_benchmark_with_program(1, 2, false, false, &program_to_prove);
+}
+
+fn run_recursion_benchmark_with_program(
+    count: usize,
+    log_inv_rate: usize,
+    prox_gaps_conjecture: bool,
+    tracing: bool,
+    program_to_prove: &str,
+) {
+    let filepath = Path::new(env!("CARGO_MANIFEST_DIR"))
+        .join("recursion.py")
+        .to_str()
+        .unwrap()
+        .to_string();
+
+    let inner_whir_config = default_whir_config(log_inv_rate, prox_gaps_conjecture);
     let bytecode_to_prove = compile_program(&ProgramSource::Raw(program_to_prove.to_string()));
     precompute_dft_twiddles::(1 << 24);
     let inner_public_input = vec![];
@@ -73,20 +108,65 @@ def main():
     let mut replacements = whir_recursion_placeholder_replacements(&outer_whir_config);

-    assert!(
-        verif_details.log_memory >= verif_details.table_n_vars[&Table::execution()]
-            && verif_details
-                .table_n_vars
-                .values()
+    let log_bytecode = log2_ceil_usize(bytecode_to_prove.instructions.len());
+    let min_stacked = min_stacked_n_vars(log_bytecode);
+
+    let mut all_potential_num_queries = vec![];
+    let mut all_potential_grinding = vec![];
+    let mut all_potential_num_oods = vec![];
+    for log_inv_rate in MIN_WHIR_LOG_INV_RATE..=MAX_WHIR_LOG_INV_RATE {
+        let max_n_vars = F::TWO_ADICITY + WHIR_INITIAL_FOLDING_FACTOR - log_inv_rate;
+        let whir_config_builder = default_whir_config(log_inv_rate, prox_gaps_conjecture);
+        let whir_config = WhirConfig::::new(&whir_config_builder, max_n_vars);
+        let mut num_queries = vec![];
+        let mut grinding_bits = vec![];
+        for round in &whir_config.round_parameters {
+            num_queries.push(round.num_queries);
+            grinding_bits.push(round.pow_bits);
+        }
+        num_queries.push(whir_config.final_queries);
+        grinding_bits.push(whir_config.final_pow_bits);
+        all_potential_num_queries.push(format!(
+            "[{}]",
+            num_queries.iter().map(|q| q.to_string()).collect::<Vec<_>>().join(", ")
+        ));
+        all_potential_grinding.push(format!(
+            "[{}]",
+            grinding_bits
+                .iter()
+                .map(|q| q.to_string())
                 .collect::<Vec<_>>()
-                .windows(2)
-                .all(|w| w[0] >= w[1]),
-        "TODO a more general recursion program",
+                .join(", ")
+        ));
+
+        // OOD samples for each possible stacked_n_vars
+        let mut oods_for_rate = vec![];
+        for n_vars in min_stacked..=max_n_vars {
+            let cfg = WhirConfig::::new(&whir_config_builder, n_vars);
+            let mut oods = vec![cfg.committment_ood_samples];
+            for round in &cfg.round_parameters {
+                oods.push(round.ood_samples);
+            }
+            oods_for_rate.push(format!(
+                "[{}]",
+                oods.iter().map(|o| o.to_string()).collect::<Vec<_>>().join(", ")
+            ));
+        }
+        all_potential_num_oods.push(format!("[{}]", oods_for_rate.join(", ")));
+    }
+    replacements.insert(
+        "WHIR_ALL_POTENTIAL_NUM_QUERIES_PLACEHOLDER".to_string(),
+        format!("[{}]", all_potential_num_queries.join(", ")),
     );
-    assert_eq!(
-        verif_details.table_n_vars.keys().copied().collect::<Vec<_>>(),
-        vec![Table::execution(), Table::dot_product(), Table::poseidon16()]
+    replacements.insert(
+        "WHIR_ALL_POTENTIAL_GRINDING_PLACEHOLDER".to_string(),
+        format!("[{}]", all_potential_grinding.join(", ")),
     );
+    replacements.insert(
+        "WHIR_ALL_POTENTIAL_NUM_OODS_PLACEHOLDER".to_string(),
+        format!("[{}]", all_potential_num_oods.join(", ")),
+    );
+    replacements.insert("MIN_STACKED_N_VARS_PLACEHOLDER".to_string(), min_stacked.to_string());

     // VM recursion parameters (different from WHIR)
     replacements.insert("N_TABLES_PLACEHOLDER".to_string(), N_TABLES.to_string());
@@ -106,6 +186,18 @@
         "MAX_WHIR_LOG_INV_RATE_PLACEHOLDER".to_string(),
         MAX_WHIR_LOG_INV_RATE.to_string(),
     );
+    replacements.insert(
+        "MAX_NUM_VARIABLES_TO_SEND_COEFFS_PLACEHOLDER".to_string(),
+        MAX_NUM_VARIABLES_TO_SEND_COEFFS.to_string(),
+    );
+    replacements.insert(
+        "WHIR_INITIAL_FOLDING_FACTOR_PLACEHOLDER".to_string(),
+        WHIR_INITIAL_FOLDING_FACTOR.to_string(),
+    );
+    replacements.insert(
+        "WHIR_SUBSEQUENT_FOLDING_FACTOR_PLACEHOLDER".to_string(),
+        WHIR_SUBSEQUENT_FOLDING_FACTOR.to_string(),
+    );
     replacements.insert(
         "MAX_LOG_N_ROWS_PER_TABLE_PLACEHOLDER".to_string(),
         format!(
@@ -357,42 +449,20 @@
 pub(crate) fn whir_recursion_placeholder_replacements(whir_config: &WhirConfig) -> BTreeMap {
     let mut num_queries = vec![];
-    let mut ood_samples = vec![];
-    let mut grinding_bits = vec![];
     let mut folding_factors = vec![];
     for round in &whir_config.round_parameters {
         num_queries.push(round.num_queries.to_string());
-        ood_samples.push(round.ood_samples.to_string());
-        grinding_bits.push(round.pow_bits.to_string());
         folding_factors.push(round.folding_factor.to_string());
     }
     folding_factors.push(whir_config.final_round_config().folding_factor.to_string());
-    grinding_bits.push(whir_config.final_pow_bits.to_string());
     num_queries.push(whir_config.final_queries.to_string());

     let end = "_PLACEHOLDER";
     let mut replacements = BTreeMap::new();
-    replacements.insert(
-        format!("WHIR_NUM_QUERIES{}", end),
-        format!("[{}]", num_queries.join(", ")),
-    );
-    replacements.insert(
-        format!("WHIR_NUM_OOD_COMMIT{}", end),
-        whir_config.committment_ood_samples.to_string(),
-    );
-    replacements.insert(format!("WHIR_NUM_OODS{}", end), format!("[{}]", ood_samples.join(", ")));
-    replacements.insert(
-        format!("WHIR_GRINDING_BITS{}", end),
-        format!("[{}]", grinding_bits.join(", ")),
-    );
     replacements.insert(
         format!("WHIR_FOLDING_FACTORS{}", end),
         format!("[{}]", folding_factors.join(", ")),
     );
-    replacements.insert(
-        format!("WHIR_FINAL_VARS{}", end),
-        whir_config.n_vars_of_final_polynomial().to_string(),
-    );
     replacements.insert(
         format!("WHIR_FIRST_RS_REDUCTION_FACTOR{}", end),
         whir_config.rs_domain_initial_reduction_factor.to_string(),
diff --git a/crates/rec_aggregation/utils.py b/crates/rec_aggregation/utils.py
index 1eaec292..2d183999 100644
--- a/crates/rec_aggregation/utils.py
+++ b/crates/rec_aggregation/utils.py
@@ -25,9 +25,9 @@ def div_ceil_dynamic(a, b: Const):
 def powers(alpha, n):
     # alpha: EF
     # n: F
-    assert n < 128
+    assert n < 256
     assert 0 < n
-    res = match_range(n, range(1, 128), lambda i: powers_const(alpha, i))
+    res = match_range(n, range(1, 256), lambda i: powers_const(alpha, i))
     return res

@@ -59,8 +59,8 @@ def unit_root_pow_const(domain_size: Const, index_bits):


 def poly_eq_extension_dynamic(point, n):
-    debug_assert(n < 8)
-    res = match_range(n, range(0, 1), lambda i: ONE_VEC_PTR, range(1, 8), lambda i: poly_eq_extension(point, i))
+    debug_assert(n < 9)
+    res = match_range(n, range(0, 1), lambda i: ONE_VEC_PTR, range(1, 9), lambda i: poly_eq_extension(point, i))
     return res

@@ -82,8 +82,7 @@ def poly_eq_extension(point, n: Const):
     return res + (2**n - 1) * DIM


-@inline
-def poly_eq_base(point, n):
+def poly_eq_base(point, n: Const):
     # Example: for n = 2: eq(x, y) = [(1 - x)(1 - y), (1 - x)y, x(1 - y), xy]
     res = Array((2 ** (n + 1) - 1))

@@ -184,11 +183,9 @@ def expand_from_univariate_ext(alpha, n):


 def dot_product_be_dynamic(a, b, res, n):
-    for i in unroll(6, 10):
-        if n == 2**i:
-            dot_product(a, b, res, 2**i, BE)
-            return
-    assert False, "dot_product_be_dynamic called with unsupported n"
+    debug_assert(n <= 256)
+    match_range(n, range(1, 257), lambda i: dot_product(a, b, res, i, BE))
+    return


 def dot_product_ee_dynamic(a, b, res, n):
diff --git a/crates/rec_aggregation/whir.py b/crates/rec_aggregation/whir.py
index 4e4de9f5..05dc65e2 100644
--- a/crates/rec_aggregation/whir.py
+++ b/crates/rec_aggregation/whir.py
@@ -1,15 +1,17 @@
 from snark_lib import *
 from fiat_shamir import *

-WHIR_FOLDING_FACTORS = WHIR_FOLDING_FACTORS_PLACEHOLDER
-WHIR_FINAL_VARS = WHIR_FINAL_VARS_PLACEHOLDER
+WHIR_INITIAL_FOLDING_FACTOR = WHIR_INITIAL_FOLDING_FACTOR_PLACEHOLDER
+WHIR_SUBSEQUENT_FOLDING_FACTOR = WHIR_SUBSEQUENT_FOLDING_FACTOR_PLACEHOLDER
 WHIR_FIRST_RS_REDUCTION_FACTOR = WHIR_FIRST_RS_REDUCTION_FACTOR_PLACEHOLDER
-WHIR_NUM_OOD_COMMIT = WHIR_NUM_OOD_COMMIT_PLACEHOLDER
-WHIR_NUM_OODS = WHIR_NUM_OODS_PLACEHOLDER
-WHIR_GRINDING_BITS = WHIR_GRINDING_BITS_PLACEHOLDER
-WHIR_NUM_QUERIES = WHIR_NUM_QUERIES_PLACEHOLDER
-WHIR_N_ROUNDS = len(WHIR_NUM_QUERIES) - 1
+MIN_WHIR_LOG_INV_RATE = MIN_WHIR_LOG_INV_RATE_PLACEHOLDER
+MAX_WHIR_LOG_INV_RATE = MAX_WHIR_LOG_INV_RATE_PLACEHOLDER
+MAX_NUM_VARIABLES_TO_SEND_COEFFS = MAX_NUM_VARIABLES_TO_SEND_COEFFS_PLACEHOLDER
+WHIR_ALL_POTENTIAL_NUM_QUERIES = WHIR_ALL_POTENTIAL_NUM_QUERIES_PLACEHOLDER
+WHIR_ALL_POTENTIAL_GRINDING = WHIR_ALL_POTENTIAL_GRINDING_PLACEHOLDER
+WHIR_ALL_POTENTIAL_NUM_OODS = WHIR_ALL_POTENTIAL_NUM_OODS_PLACEHOLDER
+MIN_STACKED_N_VARS = MIN_STACKED_N_VARS_PLACEHOLDER


 def whir_open(
     fs: Mut,
@@ -20,13 +22,20 @@
     combination_randomness_powers_0,
     claimed_sum: Mut,
 ):
-    all_folding_randomness = Array(WHIR_N_ROUNDS + 2)
-    all_ood_points = Array(WHIR_N_ROUNDS)
-    all_circle_values = Array(WHIR_N_ROUNDS + 1)
-    all_combination_randomness_powers = Array(WHIR_N_ROUNDS)
+    n_rounds, n_final_vars, num_queries, num_oods, grinding_bits = get_whir_params(n_vars, initial_log_inv_rate)
+    n_final_coeffs = powers_of_two(n_final_vars)
+    folding_factors = Array(n_rounds + 1)
+    folding_factors[0] = WHIR_INITIAL_FOLDING_FACTOR
+    for i in range(1, n_rounds + 1):
+        folding_factors[i] = WHIR_SUBSEQUENT_FOLDING_FACTOR
+
+    all_folding_randomness = Array(n_rounds + 2)
+    all_ood_points = Array(n_rounds)
+    all_circle_values = Array(n_rounds + 1)
+    all_combination_randomness_powers = Array(n_rounds)

     domain_sz: Mut = n_vars + initial_log_inv_rate
-    for r in unroll(0, WHIR_N_ROUNDS):
+    for r in range(0, n_rounds):
         is_first_round: Imu
         if r == 0:
             is_first_round = 1
@@ -43,84 +52,84 @@
         ) = whir_round(
             fs,
             root,
-            WHIR_FOLDING_FACTORS[r],
-            2 ** WHIR_FOLDING_FACTORS[r],
+            folding_factors[r],
+            powers_of_two(folding_factors[r]),
             is_first_round,
-            WHIR_NUM_QUERIES[r],
+            num_queries[r],
             domain_sz,
             claimed_sum,
-            WHIR_GRINDING_BITS[r],
-            WHIR_NUM_OODS[r],
+            grinding_bits[r],
+            num_oods[r + 1],
         )
         if r == 0:
             domain_sz -= WHIR_FIRST_RS_REDUCTION_FACTOR
         else:
             domain_sz -= 1

-    fs, all_folding_randomness[WHIR_N_ROUNDS], claimed_sum = sumcheck_verify(
-        fs, WHIR_FOLDING_FACTORS[WHIR_N_ROUNDS], claimed_sum, 2
+    fs, all_folding_randomness[n_rounds], claimed_sum = sumcheck_verify(
+        fs, WHIR_SUBSEQUENT_FOLDING_FACTOR, claimed_sum, 2
     )
-    fs, final_coeffcients = fs_receive_ef(fs, 2**WHIR_FINAL_VARS)
+    fs, final_coeffcients = fs_receive_ef_by_log_dynamic(fs, n_final_vars, MAX_NUM_VARIABLES_TO_SEND_COEFFS - WHIR_SUBSEQUENT_FOLDING_FACTOR, MAX_NUM_VARIABLES_TO_SEND_COEFFS + 1)

-    fs, all_circle_values[WHIR_N_ROUNDS], final_folds = sample_stir_indexes_and_fold(
+    fs, all_circle_values[n_rounds], final_folds = sample_stir_indexes_and_fold(
         fs,
-        WHIR_NUM_QUERIES[WHIR_N_ROUNDS],
+        num_queries[n_rounds],
         0,
-        WHIR_FOLDING_FACTORS[WHIR_N_ROUNDS],
-        2 ** WHIR_FOLDING_FACTORS[WHIR_N_ROUNDS],
+        WHIR_SUBSEQUENT_FOLDING_FACTOR,
+        2 ** WHIR_SUBSEQUENT_FOLDING_FACTOR,
         domain_sz,
         root,
-        all_folding_randomness[WHIR_N_ROUNDS],
-        WHIR_GRINDING_BITS[WHIR_N_ROUNDS],
+        all_folding_randomness[n_rounds],
+        grinding_bits[n_rounds],
     )

-    final_circle_values = all_circle_values[WHIR_N_ROUNDS]
-    for i in range(0, WHIR_NUM_QUERIES[WHIR_N_ROUNDS]):
-        powers_of_2_rev = expand_from_univariate_base_const(final_circle_values[i], WHIR_FINAL_VARS)
-        poly_eq = poly_eq_base(powers_of_2_rev, WHIR_FINAL_VARS)
+    final_circle_values = all_circle_values[n_rounds]
+    for i in range(0, num_queries[n_rounds]):
+        powers_of_2_rev = expand_from_univariate_base(final_circle_values[i], n_final_vars)
+        poly_eq = match_range(n_final_vars, range(MAX_NUM_VARIABLES_TO_SEND_COEFFS - WHIR_SUBSEQUENT_FOLDING_FACTOR, MAX_NUM_VARIABLES_TO_SEND_COEFFS + 1), lambda n: poly_eq_base(powers_of_2_rev, n))
         final_pol_evaluated_on_circle = Array(DIM)
-        dot_product(
+        dot_product_be_dynamic(
             poly_eq,
             final_coeffcients,
             final_pol_evaluated_on_circle,
-            2**WHIR_FINAL_VARS,
-            BE,
+            n_final_coeffs,
         )
         copy_5(final_pol_evaluated_on_circle, final_folds + i * DIM)

-    fs, all_folding_randomness[WHIR_N_ROUNDS + 1], end_sum = sumcheck_verify(fs, WHIR_FINAL_VARS, claimed_sum, 2)
+    fs, all_folding_randomness[n_rounds + 1], end_sum = sumcheck_verify(fs, n_final_vars, claimed_sum, 2)

     folding_randomness_global = Array(n_vars * DIM)
     start: Mut = folding_randomness_global
-    for i in unroll(0, WHIR_N_ROUNDS + 1):
-        for j in unroll(0, WHIR_FOLDING_FACTORS[i]):
+    for i in range(0, n_rounds + 1):
+        for j in range(0, folding_factors[i]):
             copy_5(all_folding_randomness[i] + j * DIM, start + j * DIM)
-        start += WHIR_FOLDING_FACTORS[i] * DIM
-    for j in unroll(0, WHIR_FINAL_VARS):
-        copy_5(all_folding_randomness[WHIR_N_ROUNDS + 1] + j * DIM, start + j * DIM)
+        start += folding_factors[i] * DIM
+    for j in range(0, n_final_vars):
+        copy_5(all_folding_randomness[n_rounds + 1] + j * DIM, start + j * DIM)

-    all_ood_recovered_evals = Array(WHIR_NUM_OOD_COMMIT * DIM)
-    for i in unroll(0, WHIR_NUM_OOD_COMMIT):
+    all_ood_recovered_evals = Array(num_oods[0] * DIM)
+    for i in range(0, num_oods[0]):
         expanded_from_univariate = expand_from_univariate_ext(ood_points_commit + i * DIM, n_vars)
         ood_rec = eq_mle_extension(expanded_from_univariate, folding_randomness_global, n_vars)
         copy_5(ood_rec, all_ood_recovered_evals + i * DIM)

-    s: Mut = dot_product_ret(
+    s: Mut = Array(DIM)
+    dot_product_ee_dynamic(
         all_ood_recovered_evals,
         combination_randomness_powers_0,
-        WHIR_NUM_OOD_COMMIT,
-        EE,
+        s,
+        num_oods[0],
     )

     n_vars_remaining: Mut = n_vars
     my_folding_randomness: Mut = folding_randomness_global
-    for i in unroll(0, WHIR_N_ROUNDS):
-        n_vars_remaining -= WHIR_FOLDING_FACTORS[i]
-        my_ood_recovered_evals = Array(WHIR_NUM_OODS[i] * DIM)
+    for i in range(0, n_rounds):
+        n_vars_remaining -= folding_factors[i]
+        my_ood_recovered_evals = Array(num_oods[i + 1] * DIM)
         combination_randomness_powers = all_combination_randomness_powers[i]
-        my_folding_randomness += WHIR_FOLDING_FACTORS[i] * DIM
-        for j in unroll(0, WHIR_NUM_OODS[i]):
+        my_folding_randomness += folding_factors[i] * DIM
+        for j in range(0, num_oods[i + 1]):
             expanded_from_univariate = expand_from_univariate_ext(all_ood_points[i] + j * DIM, n_vars_remaining)
             ood_rec = eq_mle_extension(expanded_from_univariate, my_folding_randomness, n_vars_remaining)
             copy_5(ood_rec, my_ood_recovered_evals + j * DIM)
@@ -129,25 +138,27 @@
             my_ood_recovered_evals,
             combination_randomness_powers,
             summed_ood,
-            WHIR_NUM_OODS[i],
+            num_oods[i + 1],
         )

-        s6s = Array((WHIR_NUM_QUERIES[i]) * DIM)
+        s6s = Array((num_queries[i]) * DIM)
         circle_value_i = all_circle_values[i]
-        for j in range(0, WHIR_NUM_QUERIES[i]):  # unroll ?
+        for j in range(0, num_queries[i]):  # unroll ?
             expanded_from_univariate = expand_from_univariate_base(circle_value_i[j], n_vars_remaining)
             temp = eq_mle_base_extension(expanded_from_univariate, my_folding_randomness, n_vars_remaining)
             copy_5(temp, s6s + j * DIM)
-        s7 = dot_product_ret(
+        s7 = Array(DIM)
+        dot_product_ee_dynamic(
             s6s,
-            combination_randomness_powers + WHIR_NUM_OODS[i] * DIM,
-            WHIR_NUM_QUERIES[i],
-            EE,
+            combination_randomness_powers + num_oods[i + 1] * DIM,
+            s7,
+            num_queries[i],
         )
         s = add_extension_ret(s, s7)
         s = add_extension_ret(summed_ood, s)

-    poly_eq_final = poly_eq_extension(all_folding_randomness[WHIR_N_ROUNDS + 1], WHIR_FINAL_VARS)
-    final_value = dot_product_ret(poly_eq_final, final_coeffcients, 2**WHIR_FINAL_VARS, EE)
+    poly_eq_final = poly_eq_extension_dynamic(all_folding_randomness[n_rounds + 1], n_final_vars)
+    final_value = Array(DIM)
+    dot_product_ee_dynamic(poly_eq_final, final_coeffcients, final_value, n_final_coeffs)
     # copy_5(mul_extension_ret(s, final_value), end_sum);

     return fs, folding_randomness_global, s, final_value, end_sum
@@ -161,7 +172,7 @@ def sumcheck_verify(fs: Mut, n_steps, claimed_sum, degree: Const):

 def sumcheck_verify_helper(fs: Mut, n_steps, claimed_sum: Mut, degree: Const, challenges):
     for sc_round in range(0, n_steps):
-        fs, poly = fs_receive_ef(fs, degree + 1)
+        fs, poly = fs_receive_ef_inlined(fs, degree + 1)
         sum_over_boolean_hypercube = polynomial_sum_at_0_and_1(poly, degree)
         copy_5(sum_over_boolean_hypercube, claimed_sum)
         fs, rand = fs_sample_ef(fs)
@@ -223,7 +234,7 @@ def sample_stir_indexes_and_fold(

     if merkle_leaves_in_basefield == 1:
         for i in range(0, num_queries):
-            dot_product(answers[i], poly_eq, folds + i * DIM, 2 ** WHIR_FOLDING_FACTORS[0], BE)
+            dot_product_be_dynamic(answers[i], poly_eq, folds + i * DIM, two_pow_folding_factor)
     else:
         for i in range(0, num_queries):
             dot_product_ee_dynamic(answers[i], poly_eq, folds + i * DIM, two_pow_folding_factor)
@@ -306,22 +317,63 @@ def parse_commitment(fs: Mut, num_ood):
     ood_evals: Imu
     debug_assert(num_ood < 4)
     debug_assert(num_ood != 0)
-    match num_ood:
-        case 0:
-            _ = 0  # unreachable
-        case 1:
-            fs, root, ood_points, ood_evals = parse_whir_commitment_const(fs, 1)
-        case 2:
-            fs, root, ood_points, ood_evals = parse_whir_commitment_const(fs, 2)
-        case 3:
-            fs, root, ood_points, ood_evals = parse_whir_commitment_const(fs, 3)
-        case 4:
-            fs, root, ood_points, ood_evals = parse_whir_commitment_const(fs, 4)
+    fs, root, ood_points, ood_evals = match_range(num_ood, range(1, 4), lambda n: parse_whir_commitment_const(fs, n))
     return fs, root, ood_points, ood_evals


 def parse_whir_commitment_const(fs: Mut, num_ood: Const):
     fs, root = fs_receive_chunks(fs, 1)
     fs, ood_points = fs_sample_many_ef(fs, num_ood)
-    fs, ood_evals = fs_receive_ef(fs, num_ood)
+    fs, ood_evals = fs_receive_ef_inlined(fs, num_ood)
     return fs, root, ood_points, ood_evals
+
+
+@inline
+def get_whir_params(n_vars, log_inv_rate):
+    debug_assert(WHIR_INITIAL_FOLDING_FACTOR < n_vars)
+    nv_except_first_round = n_vars - WHIR_INITIAL_FOLDING_FACTOR
+    debug_assert(MAX_NUM_VARIABLES_TO_SEND_COEFFS < nv_except_first_round)
+    n_rounds = div_ceil_dynamic(nv_except_first_round - MAX_NUM_VARIABLES_TO_SEND_COEFFS, WHIR_SUBSEQUENT_FOLDING_FACTOR)
+    final_vars = nv_except_first_round - n_rounds * WHIR_SUBSEQUENT_FOLDING_FACTOR
+
+    debug_assert(MIN_WHIR_LOG_INV_RATE <= log_inv_rate)
+    debug_assert(log_inv_rate <= MAX_WHIR_LOG_INV_RATE)
+    num_queries: Imu
+    num_queries = match_range(log_inv_rate, range(MIN_WHIR_LOG_INV_RATE, MAX_WHIR_LOG_INV_RATE + 1), lambda r: get_num_queries(r))
+
+    grinding_bits: Imu
+    grinding_bits = match_range(log_inv_rate, range(MIN_WHIR_LOG_INV_RATE, MAX_WHIR_LOG_INV_RATE + 1), lambda r: get_grinding_bits(r))
+
+    num_oods = get_num_oods(log_inv_rate, n_vars)
+
+    return n_rounds, final_vars, num_queries, num_oods, grinding_bits
+
+
+def get_num_queries(log_inv_rate: Const):
+    max = len(WHIR_ALL_POTENTIAL_NUM_QUERIES[log_inv_rate - MIN_WHIR_LOG_INV_RATE])
+    num_queries = Array(max)
+    for i in unroll(0, max):
+        num_queries[i] = WHIR_ALL_POTENTIAL_NUM_QUERIES[log_inv_rate - MIN_WHIR_LOG_INV_RATE][i]
+    return num_queries
+
+
+def get_grinding_bits(log_inv_rate: Const):
+    max = len(WHIR_ALL_POTENTIAL_GRINDING[log_inv_rate - MIN_WHIR_LOG_INV_RATE])
+    grinding_bits = Array(max)
+    for i in unroll(0, max):
+        grinding_bits[i] = WHIR_ALL_POTENTIAL_GRINDING[log_inv_rate - MIN_WHIR_LOG_INV_RATE][i]
+    return grinding_bits
+
+def get_num_oods(log_inv_rate, n_vars):
+    res = match_range(log_inv_rate, range(MIN_WHIR_LOG_INV_RATE, MAX_WHIR_LOG_INV_RATE + 1), lambda r: get_num_oods_const_rate(r, n_vars))
+    return res
+
+def get_num_oods_const_rate(log_inv_rate: Const, n_vars):
+    res = match_range(n_vars, range(MIN_STACKED_N_VARS, TWO_ADICITY + WHIR_INITIAL_FOLDING_FACTOR - log_inv_rate + 1), lambda nv: get_num_oods_const(log_inv_rate, nv))
+    return res
+
+
+def get_num_oods_const(log_inv_rate: Const, n_vars: Const):
+    max = len(WHIR_ALL_POTENTIAL_NUM_OODS[log_inv_rate - MIN_WHIR_LOG_INV_RATE][n_vars - MIN_STACKED_N_VARS])
+    num_oods = Array(max)
+    for i in unroll(0, max):
+        num_oods[i] = WHIR_ALL_POTENTIAL_NUM_OODS[log_inv_rate - MIN_WHIR_LOG_INV_RATE][n_vars - MIN_STACKED_N_VARS][i]
+    return num_oods
\ No newline at end of file
diff --git a/crates/sub_protocols/src/stacked_pcs.rs b/crates/sub_protocols/src/stacked_pcs.rs
index 4c1dd19b..480a4644 100644
--- a/crates/sub_protocols/src/stacked_pcs.rs
+++ b/crates/sub_protocols/src/stacked_pcs.rs
@@ -1,4 +1,7 @@
-use lean_vm::{COL_PC, CommittedStatements, ENDING_PC, STARTING_PC, sort_tables_by_height};
+use lean_vm::{
+    ALL_TABLES, COL_PC, CommittedStatements, ENDING_PC, MIN_LOG_MEMORY_SIZE, MIN_LOG_N_ROWS_PER_TABLE, STARTING_PC,
+    sort_tables_by_height,
+};
 use lean_vm::{EF, F, Table, TableT, TableTrace};
 use multilinear_toolkit::prelude::*;
 use owo_colors::OwoColorize;
@@ -178,14 +181,22 @@ pub fn stacked_pcs_parse_commitment(
 fn compute_stacked_n_vars(
     log_memory: usize,
     log_bytecode: usize,
-    tables_heights: &BTreeMap<Table, usize>,
+    tables_log_heights: &BTreeMap<Table, usize>,
 ) -> VarCount {
-    let max_table_log_n_rows = tables_heights.values().copied().max().unwrap();
+    let max_table_log_n_rows = tables_log_heights.values().copied().max().unwrap();
     let total_len = (2 << log_memory)
         + (1 << log_bytecode.max(max_table_log_n_rows))
-        + tables_heights
+        + tables_log_heights
             .iter()
             .map(|(table, log_n_rows)| table.n_committed_columns() << log_n_rows)
             .sum::<usize>();
     log2_ceil_usize(total_len)
 }
+
+pub fn min_stacked_n_vars(log_bytecode: usize) -> usize {
+    let mut min_tables_log_heights = BTreeMap::new();
+    for table in ALL_TABLES {
+        min_tables_log_heights.insert(table, MIN_LOG_N_ROWS_PER_TABLE);
+    }
+    compute_stacked_n_vars(MIN_LOG_MEMORY_SIZE, log_bytecode, &min_tables_log_heights)
+}