Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/lints.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,5 @@ jobs:
run: taplo --version || cargo install taplo-cli
- name: Run taplo
run: taplo fmt --check --diff
- name: Ensure Cargo.lock not modified by build
run: git diff --exit-code Cargo.lock
16 changes: 8 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

91 changes: 81 additions & 10 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1147,7 +1147,7 @@ Hints:
.into()
};

for ((w_rlc_expr, annotation), _) in (cs
for ((w_rlc_expr, annotation), (ram_type_expr, _)) in (cs
.w_expressions
.iter()
.chain(cs.w_table_expressions.iter().map(|expr| &expr.expr)))
Expand All @@ -1157,8 +1157,19 @@ Hints:
.chain(cs.w_table_expressions_namespace_map.iter()),
)
.zip_eq(cs.w_ram_types.iter())
.filter(|((_, _), (ram_type, _))| *ram_type == $ram_type)
{
let ram_type_mle = wit_infer_by_expr(
ram_type_expr,
cs.num_witin,
cs.num_structural_witin,
cs.num_fixed as WitnessId,
fixed,
witness,
structural_witness,
&pi_mles,
&challenges,
);
let ram_type_vec = ram_type_mle.get_ext_field_vec();
let write_rlc_records = wit_infer_by_expr(
w_rlc_expr,
cs.num_witin,
Expand All @@ -1170,13 +1181,32 @@ Hints:
&pi_mles,
&challenges,
);
let w_selector_vec = w_selector.get_base_field_vec();
let write_rlc_records =
filter_mle_by_selector_mle(write_rlc_records, w_selector.clone());
filter_mle_by_predicate(write_rlc_records, |i, _v| {
ram_type_vec[i] == E::from_canonical_u32($ram_type as u32)
&& w_selector_vec[i] == E::BaseField::ONE
});
if write_rlc_records.is_empty() {
continue;
}

let mut records = vec![];
let mut writes_within_expr_dedup = HashSet::new();
for (row, record_rlc) in enumerate(write_rlc_records) {
// TODO: report error
assert_eq!(writes.insert(record_rlc), true);
assert_eq!(
writes_within_expr_dedup.insert(record_rlc),
true,
"within expression write duplicated on RAMType {:?}",
$ram_type
);
assert_eq!(
writes.insert(record_rlc),
true,
"crossing-chip write duplicated on RAMType {:?}",
$ram_type
);
records.push((record_rlc, row));
}
writes_grp_by_annotations
Expand Down Expand Up @@ -1212,7 +1242,7 @@ Hints:
)
.into()
};
for ((r_rlc_expr, annotation), (_, r_exprs)) in (cs
for ((r_rlc_expr, annotation), (ram_type_expr, r_exprs)) in (cs
.r_expressions
.iter()
.chain(cs.r_table_expressions.iter().map(|expr| &expr.expr)))
Expand All @@ -1222,8 +1252,19 @@ Hints:
.chain(cs.r_table_expressions_namespace_map.iter()),
)
.zip_eq(cs.r_ram_types.iter())
.filter(|((_, _), (ram_type, _))| *ram_type == $ram_type)
{
let ram_type_mle = wit_infer_by_expr(
ram_type_expr,
cs.num_witin,
cs.num_structural_witin,
cs.num_fixed as WitnessId,
fixed,
witness,
structural_witness,
&pi_mles,
&challenges,
);
let ram_type_vec = ram_type_mle.get_ext_field_vec();
let read_records = wit_infer_by_expr(
r_rlc_expr,
cs.num_witin,
Expand All @@ -1235,8 +1276,14 @@ Hints:
&pi_mles,
&challenges,
);
let read_records =
filter_mle_by_selector_mle(read_records, r_selector.clone());
let r_selector_vec = r_selector.get_base_field_vec();
let read_records = filter_mle_by_predicate(read_records, |i, _v| {
ram_type_vec[i] == E::from_canonical_u32($ram_type as u32)
&& r_selector_vec[i] == E::BaseField::ONE
});
if read_records.is_empty() {
continue;
}

if $ram_type == RAMType::GlobalState {
// r_exprs = [GlobalState, pc, timestamp]
Expand Down Expand Up @@ -1269,9 +1316,21 @@ Hints:
};

let mut records = vec![];
let mut reads_within_expr_dedup = HashSet::new();
for (row, record) in enumerate(read_records) {
// TODO: return error
assert_eq!(reads.insert(record), true);
assert_eq!(
reads_within_expr_dedup.insert(record),
true,
"within expression read duplicated on RAMType {:?}",
$ram_type
);
assert_eq!(
reads.insert(record),
true,
"crossing-chip read duplicated on RAMType {:?}",
$ram_type
);
records.push((record, row));
}
reads_grp_by_annotations
Expand Down Expand Up @@ -1467,6 +1526,19 @@ fn print_errors<E: ExtensionField, K: LkMultiplicityKey>(
}
}

fn filter_mle_by_predicate<E, F>(target_mle: ArcMultilinearExtension<E>, mut predicate: F) -> Vec<E>
where
E: ExtensionField,
F: FnMut(usize, &E) -> bool,
{
target_mle
.get_ext_field_vec()
.iter()
.enumerate()
.filter_map(|(i, v)| if predicate(i, v) { Some(*v) } else { None })
.collect_vec()
}

fn filter_mle_by_selector_mle<E: ExtensionField>(
target_mle: ArcMultilinearExtension<E>,
selector: ArcMultilinearExtension<E>,
Expand All @@ -1487,7 +1559,6 @@ fn filter_mle_by_selector_mle<E: ExtensionField>(

#[cfg(test)]
mod tests {

use super::*;
use crate::{
ROMType,
Expand Down
94 changes: 78 additions & 16 deletions gkr_iop/src/circuit_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,14 @@ pub struct ConstraintSystem<E: ExtensionField> {
pub r_expressions_namespace_map: Vec<String>,
// for each read expression we store its ram type and original value before doing RLC
// the original value will be used for debugging
pub r_ram_types: Vec<(RAMType, Vec<Expression<E>>)>,
pub r_ram_types: Vec<(Expression<E>, Vec<Expression<E>>)>,

pub w_selector: Option<SelectorType<E>>,
pub w_expressions: Vec<Expression<E>>,
pub w_expressions_namespace_map: Vec<String>,
// for each write expression we store its ram type and original value before doing RLC
// the original value will be used for debugging
pub w_ram_types: Vec<(RAMType, Vec<Expression<E>>)>,
pub w_ram_types: Vec<(Expression<E>, Vec<Expression<E>>)>,

/// init/final ram expression
pub r_table_expressions: Vec<SetTableExpression<E>>,
Expand Down Expand Up @@ -329,12 +329,27 @@ impl<E: ExtensionField> ConstraintSystem<E> {
N: FnOnce() -> NR,
{
let rlc_record = self.rlc_chip_record(record.clone());
assert_eq!(
rlc_record.degree(),
1,
"rlc record degree {} != 1",
rlc_record.degree()
);
self.r_table_rlc_record(
name_fn,
(ram_type as u64).into(),
table_spec,
record,
rlc_record,
)
}

pub fn r_table_rlc_record<NR, N>(
&mut self,
name_fn: N,
ram_type: Expression<E>,
table_spec: SetTableSpec,
record: Vec<Expression<E>>,
rlc_record: Expression<E>,
) -> Result<(), CircuitBuilderError>
where
NR: Into<String>,
N: FnOnce() -> NR,
{
self.r_table_expressions.push(SetTableExpression {
expr: rlc_record,
table_spec,
Expand All @@ -358,12 +373,27 @@ impl<E: ExtensionField> ConstraintSystem<E> {
N: FnOnce() -> NR,
{
let rlc_record = self.rlc_chip_record(record.clone());
assert_eq!(
rlc_record.degree(),
1,
"rlc record degree {} != 1",
rlc_record.degree()
);
self.w_table_rlc_record(
name_fn,
(ram_type as u64).into(),
table_spec,
record,
rlc_record,
)
}

pub fn w_table_rlc_record<NR, N>(
&mut self,
name_fn: N,
ram_type: Expression<E>,
table_spec: SetTableSpec,
record: Vec<Expression<E>>,
rlc_record: Expression<E>,
) -> Result<(), CircuitBuilderError>
where
NR: Into<String>,
N: FnOnce() -> NR,
{
self.w_table_expressions.push(SetTableExpression {
expr: rlc_record,
table_spec,
Expand All @@ -387,7 +417,7 @@ impl<E: ExtensionField> ConstraintSystem<E> {
self.r_expressions_namespace_map.push(path);
// Since r_expression is RLC(record) and when we're debugging
// it's helpful to recover the value of record itself.
self.r_ram_types.push((ram_type, record));
self.r_ram_types.push(((ram_type as u64).into(), record));
Ok(())
}

Expand All @@ -401,7 +431,7 @@ impl<E: ExtensionField> ConstraintSystem<E> {
self.w_expressions.push(rlc_record);
let path = self.ns.compute_path(name_fn().into());
self.w_expressions_namespace_map.push(path);
self.w_ram_types.push((ram_type, record));
self.w_ram_types.push(((ram_type as u64).into(), record));
Ok(())
}

Expand Down Expand Up @@ -579,6 +609,22 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
.r_table_record(name_fn, ram_type, table_spec, record)
}

pub fn r_table_rlc_record<NR, N>(
&mut self,
name_fn: N,
ram_type: Expression<E>,
table_spec: SetTableSpec,
record: Vec<Expression<E>>,
rlc_record: Expression<E>,
) -> Result<(), CircuitBuilderError>
where
NR: Into<String>,
N: FnOnce() -> NR,
{
self.cs
.r_table_rlc_record(name_fn, ram_type, table_spec, record, rlc_record)
}

pub fn w_table_record<NR, N>(
&mut self,
name_fn: N,
Expand All @@ -594,6 +640,22 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
.w_table_record(name_fn, ram_type, table_spec, record)
}

pub fn w_table_rlc_record<NR, N>(
&mut self,
name_fn: N,
ram_type: Expression<E>,
table_spec: SetTableSpec,
record: Vec<Expression<E>>,
rlc_record: Expression<E>,
) -> Result<(), CircuitBuilderError>
where
NR: Into<String>,
N: FnOnce() -> NR,
{
self.cs
.w_table_rlc_record(name_fn, ram_type, table_spec, record, rlc_record)
}

pub fn read_record<NR, N>(
&mut self,
name_fn: N,
Expand Down