From 24a9cba2cbc97d2b20d53abbe356186a39ee40ac Mon Sep 17 00:00:00 2001 From: leonace924 Date: Fri, 9 Jan 2026 12:19:02 -0500 Subject: [PATCH] perf(epoch): reduce memory allocation in weight divergence check --- crates/epoch/src/commit_reveal.rs | 131 +++++++++++++++--------------- 1 file changed, 64 insertions(+), 67 deletions(-) diff --git a/crates/epoch/src/commit_reveal.rs b/crates/epoch/src/commit_reveal.rs index e8aa622a..535d679a 100644 --- a/crates/epoch/src/commit_reveal.rs +++ b/crates/epoch/src/commit_reveal.rs @@ -130,30 +130,40 @@ impl CommitRevealState { ); } - // Collect valid submissions - let submissions: Vec> = - self.reveals.values().map(|r| r.weights.clone()).collect(); - - if submissions.len() < min_validators { + // Check we have enough validators (without cloning) + let reveal_count = self.reveals.len(); + if reveal_count < min_validators { return Err(CommitRevealError::InsufficientValidators { required: min_validators, - got: submissions.len(), + got: reveal_count, }); } - // Validate that all submissions are consistent + // Get first submission's weights (we need this for the final output) + let first = self + .reveals + .values() + .next() + .map(|r| &r.weights) + .ok_or(CommitRevealError::InsufficientValidators { + required: min_validators, + got: 0, + })?; + + // Validate that all submissions are consistent using references (no cloning) // All validators read from shared chain DB, so submissions should be identical - let first = &submissions[0]; - let divergence_detected = self.check_submission_divergence(&submissions); + let divergence_detected = + check_submission_divergence(self.reveals.values().map(|r| &r.weights)); if divergence_detected { error!( "Epoch {}: Weight submissions diverged across {} validators! Using first submission.", self.epoch, - submissions.len() + reveal_count ); } + // Clone only the first submission for the final output let aggregated = weights::normalize_weights(first.clone()); let participating: Vec = self.reveals.keys().cloned().collect(); @@ -199,69 +209,56 @@ impl CommitRevealState { self.reveals.contains_key(validator) } - /// Check if submissions from different validators have diverged. - /// Returns true if divergence is detected. - fn check_submission_divergence(&self, submissions: &[Vec]) -> bool { - if submissions.len() <= 1 { - return false; - } - - let first = &submissions[0]; - - // Build a map of hotkey -> weight for the first submission - let first_weights: HashMap<&str, f64> = first - .iter() - .map(|w| (w.hotkey.as_str(), w.weight)) - .collect(); - - // Tolerance for floating-point comparison (0.1% difference allowed) - const WEIGHT_TOLERANCE: f64 = 0.001; - - for (idx, submission) in submissions.iter().enumerate().skip(1) { - // Check if same number of weight assignments - if submission.len() != first.len() { - warn!( - "Epoch {}: Submission {} has {} weights, first has {}", - self.epoch, - idx, - submission.len(), - first.len() - ); - return true; - } +} - // Check if same hotkeys with similar weights - for weight in submission { - match first_weights.get(weight.hotkey.as_str()) { - None => { - warn!( - "Epoch {}: Submission {} has hotkey {} not in first submission", - self.epoch, - idx, - &weight.hotkey[..16.min(weight.hotkey.len())] - ); - return true; - } - Some(&first_weight) => { - let diff = (weight.weight - first_weight).abs(); - if diff > WEIGHT_TOLERANCE { - warn!( - "Epoch {}: Weight divergence for hotkey {}: {} vs {} (diff: {:.4})", - self.epoch, - &weight.hotkey[..16.min(weight.hotkey.len())], - first_weight, - weight.weight, - diff - ); - return true; - } - } +/// Tolerance for floating-point weight comparison (0.1% difference allowed) +const WEIGHT_TOLERANCE: f64 = 0.001; + +/// Check if two weight vectors match within tolerance. +/// Returns true if weights match, false if they diverge. +fn weights_match(first: &[WeightAssignment], second: &[WeightAssignment], tolerance: f64) -> bool { + if first.len() != second.len() { + return false; + } + + // Build a map of hotkey -> weight for the first submission + let first_weights: HashMap<&str, f64> = first + .iter() + .map(|w| (w.hotkey.as_str(), w.weight)) + .collect(); + + // Check all weights in second exist in first with similar values + for weight in second { + match first_weights.get(weight.hotkey.as_str()) { + None => return false, + Some(&first_weight) => { + if (weight.weight - first_weight).abs() > tolerance { + return false; } } } + } - false + true +} + +/// Check if submissions from different validators have diverged (reference-based, no cloning). +/// Uses an iterator of references to avoid O(N * M) memory allocation. +fn check_submission_divergence<'a>( + mut submissions: impl Iterator>, +) -> bool { + let first = match submissions.next() { + Some(f) => f, + None => return false, + }; + + for submission in submissions { + if !weights_match(first, submission, WEIGHT_TOLERANCE) { + return true; + } } + + false } /// Errors for commit-reveal