Skip to content

Commit 08783ca

Browse files
committed
wip convert local final ram circuit to gkr-iop circuit
1 parent 40887fc commit 08783ca

File tree

11 files changed

+229
-313
lines changed

11 files changed

+229
-313
lines changed

ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ impl<E: ExtensionField> MmuConfig<'_, E> {
4646
let hints_config = cs.register_table_circuit::<HintsCircuit<E>>();
4747
let stack_init_config = cs.register_table_circuit::<StackInitCircuit<E>>();
4848
let heap_init_config = cs.register_table_circuit::<HeapInitCircuit<E>>();
49+
println!("register LocalFinalCircuit");
4950
let local_final_circuit = cs.register_table_circuit::<LocalFinalCircuit<E>>();
51+
println!("end register LocalFinalCircuit");
5052
let ram_bus_circuit = cs.register_table_circuit::<RBCircuit<E>>();
5153

5254
Self {

ceno_zkvm/src/scheme/cpu/mod.rs

Lines changed: 13 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -681,98 +681,19 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> MainSumcheckProver<C
681681
Some(gkr_proof),
682682
))
683683
} else {
684-
let (wits_in_evals, fixed_in_evals, main_sumcheck_proof, rt) =
685-
if next_pow2_instance_padding(num_instances) == num_instances {
686-
let span = entered_span!("fixed::evals + witin::evals");
687-
let mut evals = input
688-
.witness
689-
.par_iter()
690-
.chain(input.fixed.par_iter())
691-
.map(|poly| poly.evaluate(&rt_tower[..poly.num_vars()]))
692-
.collect::<Vec<_>>();
693-
let fixed_in_evals = evals.split_off(input.witness.len());
694-
let wits_in_evals = evals;
695-
exit_span!(span);
696-
(wits_in_evals, fixed_in_evals, None, rt_tower)
697-
} else {
698-
assert!(cs.r_table_expressions.len() <= 1);
699-
assert!(cs.w_table_expressions.len() <= 1);
700-
701-
let sel_type = SelectorType::Prefix(E::BaseField::ZERO, 0.into());
702-
let mut sel_mle = sel_type.compute(&rt_tower, num_instances).unwrap();
703-
704-
// `wit` := witin ++ fixed
705-
// we concat eq in between `wit` := witin ++ eqs ++ fixed
706-
let all_witins = input
707-
.witness
708-
.iter()
709-
.map(|mle| Either::Left(mle.as_ref()))
710-
.chain(vec![Either::Right(&mut sel_mle)])
711-
.chain(input.fixed.iter().map(|mle| Either::Left(mle.as_ref())))
712-
.collect_vec();
713-
assert_eq!(
714-
all_witins.len() as WitnessId,
715-
cs.num_witin + cs.num_structural_witin + cs.num_fixed as WitnessId,
716-
"all_witins.len() {} != layer.n_witin {} + layer.n_structural_witin {} + layer.n_fixed {}",
717-
all_witins.len(),
718-
cs.num_witin,
719-
cs.num_structural_witin,
720-
cs.num_fixed,
721-
);
722-
let builder = VirtualPolynomialsBuilder::new_with_mles(
723-
num_threads,
724-
rt_tower.len(),
725-
all_witins,
726-
);
727-
728-
let alpha_pows_expr = (2..)
729-
.take(cs.r_table_expressions.len() + cs.w_table_expressions.len())
730-
.map(|id| Expression::Challenge(id as ChallengeId, 1, E::ONE, E::ZERO))
731-
.collect_vec();
732-
let zero_check_expr: Expression<E> = cs
733-
.r_table_expressions
734-
.iter()
735-
.take(1)
736-
.chain(cs.w_table_expressions.iter().take(1))
737-
.zip_eq(&alpha_pows_expr)
738-
.map(|(expr, alpha)| alpha * expr.expr.expr())
739-
.sum();
740-
let zero_check_monomial = monomialize_expr_to_wit_terms(
741-
&zero_check_expr,
742-
cs.num_witin as WitnessId,
743-
cs.num_structural_witin as WitnessId,
744-
cs.num_fixed as WitnessId,
745-
);
746-
let main_sumcheck_challenges = chain!(
747-
challenges.iter().copied(),
748-
get_challenge_pows(
749-
cs.w_table_expressions.len() + cs.r_table_expressions.len(),
750-
transcript,
751-
)
752-
)
753-
.collect_vec();
754-
755-
let span = entered_span!("IOPProverState::prove", profiling_4 = true);
756-
let (proof, prover_state) = IOPProverState::prove(
757-
builder.to_virtual_polys_with_monomial_terms(
758-
&zero_check_monomial,
759-
&[],
760-
&main_sumcheck_challenges,
761-
),
762-
transcript,
763-
);
764-
exit_span!(span);
765-
let rt = prover_state
766-
.challenges
767-
.iter()
768-
.map(|c| c.elements)
769-
.collect_vec();
770-
let mut wits_in_evals = prover_state.get_mle_flatten_final_evaluations();
771-
let mut rest = wits_in_evals.split_off(cs.num_witin as usize);
772-
let rest = rest.split_off(cs.num_structural_witin as usize);
773-
let fixed_in_evals = rest;
774-
(wits_in_evals, fixed_in_evals, Some(proof.proofs), rt)
775-
};
684+
let (wits_in_evals, fixed_in_evals, main_sumcheck_proof, rt) = {
685+
let span = entered_span!("fixed::evals + witin::evals");
686+
let mut evals = input
687+
.witness
688+
.par_iter()
689+
.chain(input.fixed.par_iter())
690+
.map(|poly| poly.evaluate(&rt_tower[..poly.num_vars()]))
691+
.collect::<Vec<_>>();
692+
let fixed_in_evals = evals.split_off(input.witness.len());
693+
let wits_in_evals = evals;
694+
exit_span!(span);
695+
(wits_in_evals, fixed_in_evals, None, rt_tower)
696+
};
776697

777698
Ok((
778699
rt,

ceno_zkvm/src/scheme/prover.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ impl<
210210
let (points, evaluations) = self.pk.circuit_pks.iter().enumerate().try_fold(
211211
(vec![], vec![]),
212212
|(mut points, mut evaluations), (index, (circuit_name, pk))| {
213+
println!("prove circuit_name {circuit_name}");
213214
let num_instances = circuit_name_num_instances_mapping
214215
.get(&circuit_name)
215216
.copied()

ceno_zkvm/src/scheme/utils.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -345,12 +345,16 @@ pub fn build_main_witness<
345345
}
346346

347347
if let Some(gkr_circuit) = gkr_circuit {
348-
// opcode must have at least one read/write/lookup
348+
// circuit must have at least one read/write/lookup
349349
assert!(
350-
cs.lk_expressions.is_empty()
351-
|| !cs.r_expressions.is_empty()
352-
|| !cs.w_expressions.is_empty(),
353-
"assert opcode circuit"
350+
cs.r_expressions.len()
351+
+ cs.w_expressions.len()
352+
+ cs.lk_expressions.len()
353+
+ cs.r_table_expressions.len()
354+
+ cs.w_table_expressions.len()
355+
+ cs.lk_table_expressions.len()
356+
> 0,
357+
"assert circuit"
354358
);
355359

356360
let (_, gkr_circuit_out) = gkr_witness::<E, PCS, PB, PD>(

ceno_zkvm/src/scheme/verifier.rs

Lines changed: 7 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
164164
for (index, proof) in &vm_proof.chip_proofs {
165165
assert!(proof.num_instances > 0);
166166
let circuit_name = &self.vk.circuit_index_to_name[index];
167+
println!("verify circuit_name {circuit_name}");
167168
let circuit_vk = &self.vk.circuit_vks[circuit_name];
168169

169170
// check chip proof is well-formed
@@ -356,9 +357,9 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
356357
} = &composed_cs;
357358
let num_instances = proof.num_instances;
358359
let (r_counts_per_instance, w_counts_per_instance, lk_counts_per_instance) = (
359-
cs.r_expressions.len(),
360-
cs.w_expressions.len(),
361-
cs.lk_expressions.len(),
360+
cs.r_expressions.len() + cs.r_table_expressions.len(),
361+
cs.w_expressions.len() + cs.w_table_expressions.len(),
362+
cs.lk_expressions.len() + cs.lk_table_expressions.len() * 2,
362363
);
363364
let num_batched = r_counts_per_instance + w_counts_per_instance + lk_counts_per_instance;
364365

@@ -529,9 +530,8 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
529530
"[prod_record] mismatch length"
530531
);
531532

532-
let input_opening_point = if next_pow2_instance_padding(proof.num_instances)
533-
== proof.num_instances
534-
{
533+
let ram_bus_circuit = false;
534+
let input_opening_point = if !ram_bus_circuit {
535535
// evaluate the evaluation of structural mles at input_opening_point by verifier
536536
let structural_evals = if with_rw {
537537
// only iterate r set, as read/write set round should match
@@ -620,73 +620,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
620620
}
621621
rt_tower
622622
} else {
623-
// TODO LocalFinalTable goes here, merge flow into gkr_iop
624-
assert_eq!(cs.lk_table_expressions.len(), 0);
625-
assert!(proof.main_sumcheck_proofs.is_some());
626-
assert_eq!(cs.num_structural_witin, 1);
627-
assert_eq!(prod_point_and_eval.len(), 1);
628-
629-
// verify opening same point layer sumcheck
630-
let alpha_pow = get_challenge_pows(
631-
cs.r_table_expressions.len() + cs.w_table_expressions.len(),
632-
transcript,
633-
);
634-
635-
// \sum_i alpha_{i} * (out_r_eval{i} - ONE)
636-
// + \sum_i alpha_{i} * (out_w_eval{i} - ONE)
637-
let claim_sum = prod_point_and_eval
638-
.iter()
639-
.zip_eq(alpha_pow.iter())
640-
.map(|(point_and_eval, alpha)| *alpha * (point_and_eval.eval - E::ONE))
641-
.sum::<E>();
642-
let sel_subclaim = IOPVerifierState::verify(
643-
claim_sum,
644-
&IOPProof {
645-
proofs: proof.main_sumcheck_proofs.clone().unwrap(),
646-
},
647-
&VPAuxInfo {
648-
max_degree: SEL_DEGREE,
649-
max_num_variables: expected_max_rounds,
650-
phantom: PhantomData,
651-
},
652-
transcript,
653-
);
654-
let (input_opening_point, sumcheck_eval) = (
655-
sel_subclaim.point.iter().map(|c| c.elements).collect_vec(),
656-
sel_subclaim.expected_evaluation,
657-
);
658-
let structural_evals = vec![eq_eval_less_or_equal_than(
659-
proof.num_instances - 1,
660-
&prod_point_and_eval[0].point,
661-
&input_opening_point,
662-
)];
663-
664-
let expected_evals = interleave(
665-
&cs.r_table_expressions, // r
666-
&cs.w_table_expressions, // w
667-
)
668-
.map(|rw| &rw.expr)
669-
.zip(alpha_pow.iter())
670-
.map(|(expr, alpha)| {
671-
*alpha
672-
* eval_by_expr_with_instance(
673-
&proof.fixed_in_evals,
674-
&proof.wits_in_evals,
675-
&structural_evals,
676-
pi,
677-
challenges,
678-
expr,
679-
)
680-
.right()
681-
.unwrap()
682-
})
683-
.sum::<E>();
684-
if expected_evals != sumcheck_eval {
685-
return Err(ZKVMError::VerifyError(
686-
"sel evaluation verify failed".into(),
687-
));
688-
}
689-
input_opening_point
623+
unimplemented!("shard ram bus circuit go here");
690624
};
691625

692626
// assume public io is tiny vector, so we evaluate it directly without PCS

ceno_zkvm/src/structs.rs

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,7 @@ impl<E: ExtensionField> ComposedConstrainSystem<E> {
135135
}
136136

137137
pub fn is_opcode_circuit(&self) -> bool {
138-
self.zkvm_v1_css.lk_table_expressions.is_empty()
139-
&& self.zkvm_v1_css.r_table_expressions.is_empty()
140-
&& self.zkvm_v1_css.w_table_expressions.is_empty()
138+
self.gkr_circuit.is_some()
141139
}
142140

143141
/// return number of lookup operation
@@ -219,18 +217,13 @@ impl<E: ExtensionField> ZKVMConstraintSystem<E> {
219217
pub fn register_table_circuit<TC: TableCircuit<E>>(&mut self) -> TC::TableConfig {
220218
let mut cs = ConstraintSystem::new(|| format!("riscv_table/{}", TC::name()));
221219
let mut circuit_builder = CircuitBuilder::<E>::new(&mut cs);
222-
let config = TC::construct_circuit(&mut circuit_builder, &self.params).unwrap();
223-
assert!(
224-
self.circuit_css
225-
.insert(
226-
TC::name(),
227-
ComposedConstrainSystem {
228-
zkvm_v1_css: cs,
229-
gkr_circuit: None
230-
}
231-
)
232-
.is_none()
233-
);
220+
let (config, gkr_iop_circuit) =
221+
TC::build_gkr_iop_circuit(&mut circuit_builder, &self.params).unwrap();
222+
let cs = ComposedConstrainSystem {
223+
zkvm_v1_css: cs,
224+
gkr_circuit: gkr_iop_circuit,
225+
};
226+
assert!(self.circuit_css.insert(TC::name(), cs).is_none());
234227
config
235228
}
236229

ceno_zkvm/src/tables/mod.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, structs::ProgramParams};
22
use ff_ext::ExtensionField;
33
use std::collections::HashMap;
4+
use itertools::Itertools;
5+
use multilinear_extensions::{StructuralWitInType, ToExpr};
46
use witness::RowMajorMatrix;
7+
use gkr_iop::chip::Chip;
8+
use gkr_iop::gkr::GKRCircuit;
9+
use gkr_iop::gkr::layer::Layer;
10+
use gkr_iop::selector::SelectorType;
11+
512
mod range;
613
pub use range::*;
714

@@ -29,6 +36,14 @@ pub trait TableCircuit<E: ExtensionField> {
2936
params: &ProgramParams,
3037
) -> Result<Self::TableConfig, ZKVMError>;
3138

39+
fn build_gkr_iop_circuit(
40+
cb: &mut CircuitBuilder<E>,
41+
param: &ProgramParams,
42+
) -> Result<(Self::TableConfig, Option<GKRCircuit<E>>), ZKVMError> {
43+
let config = Self::construct_circuit(cb, param)?;
44+
Ok((config, None))
45+
}
46+
3247
fn generate_fixed_traces(
3348
config: &Self::TableConfig,
3449
num_fixed: usize,

0 commit comments

Comments
 (0)