From d9b4cd5d873e7e6c3c24d783d431cef3563e0453 Mon Sep 17 00:00:00 2001 From: Prasad Kumkar Date: Wed, 11 Feb 2026 20:48:36 +0530 Subject: [PATCH 1/3] feat: cythonised poly operations, added 1024 ring size benchmarks, improved ring & root class structure --- .dockerignore | 104 ++++++-- Dockerfile | 43 ++- docs/BENCHMARK.md | 31 ++- dot_ring/__init__.py | 3 + dot_ring/ring_proof/columns/columns.py | 68 +---- dot_ring/ring_proof/params.py | 67 +++++ dot_ring/ring_proof/pcs/kzg.py | 2 +- dot_ring/ring_proof/polynomial/ops.py | 76 +----- dot_ring/ring_proof/polynomial/poly_ops.pyx | 244 ++++++++++++++++++ dot_ring/ring_proof/proof/aggregation_poly.py | 4 +- .../ring_proof/proof/linearization_poly.py | 19 +- dot_ring/scripts/export_python_proof.py | 18 +- dot_ring/vrf/ring/ring_root.py | 86 +++++- dot_ring/vrf/ring/ring_vrf.py | 138 ++++------ pyproject.toml | 2 +- setup.py | 6 + tests/benchmark/bench_ring_large.py | 179 +++++++++++++ tests/benchmark/bench_ring_proof.py | 23 +- tests/benchmark/test_bench_ring.py | 18 +- tests/benchmark/test_bench_ring_proof.py | 12 +- tests/test_bandersnatch_ark.py | 12 +- tests/test_coverage/test_columns.py | 32 ++- tests/test_coverage/test_ops.py | 7 +- tests/test_curve_ops/test_gaps.py | 55 ++-- tests/test_ring_vrf/test_ring_vrf.py | 12 +- tests/test_vectors.py | 29 ++- 26 files changed, 914 insertions(+), 376 deletions(-) create mode 100644 dot_ring/ring_proof/polynomial/poly_ops.pyx create mode 100644 tests/benchmark/bench_ring_large.py diff --git a/.dockerignore b/.dockerignore index ceafc86..42c7a18 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,23 +1,93 @@ +# Git .git .gitignore -.dockerignore -.blst -.venv +.gitattributes +.github + +# Python __pycache__ -*.pyc -*.pyo -*.pyd +*.py[cod] +*$py.class +*.so .Python -env -venv -ENV -env.bak -venv.bak -build -dist -*.egg-info -.mypy_cache -.pytest_cache +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Virtual environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +testenv/ + +# Testing +.pytest_cache/ .coverage -htmlcov +.coverage.* +htmlcov/ +.tox/ +.nox/ +coverage.xml +*.cover +.hypothesis/ + +# IDEs +.vscode/ +.idea/ +*.swp +*.swo +*.swn .DS_Store + +# Build artifacts +.blst/ +*.o +*.a +*.so +*.dylib +*.dll +*.pyd +*.pyc + +# Documentation +docs/_build/ + +# Output and temporary files +output/ +*.log +*.tmp +perf/ + +# Docker +Dockerfile +docker-compose.yml +.dockerignore + +# CI/CD +.github/ +.gitlab-ci.yml +.travis.yml + +# Other +*.md +!README.md +LICENSE diff --git a/Dockerfile b/Dockerfile index 6e79e4d..8a99b30 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,20 +1,11 @@ -# Base Python image FROM python:3.12-slim -# Set working directory WORKDIR /app -# Install system dependencies needed for blst Python binding and Cython -RUN apt-get update && apt-get install -y \ - git \ - gcc \ - g++ \ - python3-dev \ - make \ - swig \ - libgmp-dev \ - libmpfr-dev \ - libmpc-dev \ +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + git gcc g++ make swig \ + python3-dev libgmp-dev libmpfr-dev libmpc-dev \ && rm -rf /var/lib/apt/lists/* # Install uv @@ -23,18 +14,20 @@ COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv # Copy project files COPY . . -# Install dependencies -RUN uv sync --extra dev +# Set version for setuptools-scm +ENV SETUPTOOLS_SCM_PRETEND_VERSION=1.0.0 -# Setup environment (blst + cython) -RUN uv run python scripts/setup_env.py +# Install build tools first +RUN uv pip install --system setuptools wheel Cython build + +# Build blst and Cython extensions +RUN python scripts/setup_env.py + +# Install project with all dependencies from pyproject.toml +RUN uv pip install --system --no-build-isolation -e ".[dev]" # Run tests -RUN uv run pytest tests/ \ - --cov=dot_ring \ - --cov-report=term-missing \ - --cov-report=html \ - -v \ - --tb=short - -CMD ["uv", "run", "python"] +RUN uv run pytest tests/ -v --tb=short + +# Default: run tests +CMD ["uv", "run", "pytest", "tests/", "-v"] diff --git a/docs/BENCHMARK.md b/docs/BENCHMARK.md index 58350cf..eb47ca8 100644 --- a/docs/BENCHMARK.md +++ b/docs/BENCHMARK.md @@ -36,15 +36,25 @@ VRF with Pedersen commitment for public key blinding. ## Ring VRF -Ring VRF with SNARK-based ring membership proof (8-member ring). +Ring VRF with SNARK-based ring membership proof. +**Proof size**: 784 bytes (constant across all ring sizes) + +### 8-member ring (domain size: 512) + +| Operation | Min | Mean | Stddev | +|-----------|-----|------|--------| +| Ring Root Construction | 28.07 ms | 28.28 ms | 0.14 ms | +| Proof Generation | 153.35 ms | 155.18 ms | 1.42 ms | +| Verification | 4.05 ms | 4.35 ms | 0.19 ms | + +### 1023-member ring (domain size: 2048) | Operation | Min | Mean | Stddev | |-----------|-----|------|--------| -| Ring Root Construction | 28.11 ms | 28.54 ms | 0.42 ms | -| Proof Generation | 251.16 ms | 253.68 ms | 1.88 ms | -| Verification | 3.98 ms | 4.16 ms | 0.12 ms | +| Ring Root Construction | 330.76 ms | 334.71 ms | 5.07 ms | +| Proof Generation | 525.28 ms | 543.04 ms | 29.13 ms | +| Verification | 4.09 ms | 4.22 ms | 0.14 ms | -**Proof size**: 784 bytes --- @@ -52,11 +62,14 @@ Ring VRF with SNARK-based ring membership proof (8-member ring). ```bash # IETF VRF -uv run python tests/bench_ietf.py +uv run python tests/benchmark/bench_ietf.py # Pedersen VRF -uv run python tests/bench_pedersen.py +uv run python tests/benchmark/bench_pedersen.py + +# Ring VRF (8-member ring, domain size 512) +uv run python tests/benchmark/bench_ring_proof.py -# Ring VRF -uv run python tests/bench_ring_proof.py +# Ring VRF (1023-member ring, domain size 2048) +uv run python tests/benchmark/bench_ring_large.py ``` \ No newline at end of file diff --git a/dot_ring/__init__.py b/dot_ring/__init__.py index 74e07b5..0263618 100644 --- a/dot_ring/__init__.py +++ b/dot_ring/__init__.py @@ -51,6 +51,7 @@ from dot_ring.keygen import secret_from_seed from dot_ring.vrf.ietf.ietf import IETF_VRF from dot_ring.vrf.pedersen.pedersen import PedersenVRF +from dot_ring.vrf.ring.ring_root import Ring, RingRoot from dot_ring.vrf.ring.ring_vrf import RingVRF # ============================================================================= @@ -116,4 +117,6 @@ "BLS12_381_G2_RO", "BLS12_381_G2_NU", "secret_from_seed", + "Ring", + "RingRoot", ] diff --git a/dot_ring/ring_proof/columns/columns.py b/dot_ring/ring_proof/columns/columns.py index c257a18..9c5bfbc 100644 --- a/dot_ring/ring_proof/columns/columns.py +++ b/dot_ring/ring_proof/columns/columns.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from typing import cast -from dot_ring.ring_proof.constants import DEFAULT_SIZE, MAX_RING_SIZE, OMEGAS, S_PRIME, Blinding_Base, PaddingPoint, SeedPoint +from dot_ring.ring_proof.constants import DEFAULT_SIZE, MAX_RING_SIZE, OMEGAS, S_PRIME, SeedPoint from dot_ring.ring_proof.curve.bandersnatch import TwistedEdwardCurve as TE from dot_ring.ring_proof.helpers import Helpers as H from dot_ring.ring_proof.params import RingProofParams @@ -42,72 +42,6 @@ def commit(self) -> None: self.commitment = KZG.commit(self.coeffs) -@dataclass(slots=True) -class PublicColumnBuilder: - size: int = DEFAULT_SIZE - prime: int = S_PRIME - omega: int = OMEGAS[DEFAULT_SIZE] - max_ring_size: int = MAX_RING_SIZE - padding_rows: int = 4 - - @classmethod - def from_params(cls, params: RingProofParams) -> PublicColumnBuilder: - return cls( - size=params.domain_size, - prime=params.prime, - omega=params.omega, - max_ring_size=params.max_ring_size, - padding_rows=params.padding_rows, - ) - - def _pad_ring_with_padding_point(self, pk_ring: list[tuple[int, int]]) -> list[tuple[int, int]]: - """Pad ring in‑place with the special padding point until size.""" - # padding_sw = sw.from_twisted_edwards(PaddingPoint) - padding_sw = PaddingPoint - while len(pk_ring) < self.max_ring_size: - pk_ring.append(padding_sw) - return pk_ring - - def _h_vector(self, blinding_base: tuple[int, int] = Blinding_Base) -> list[tuple[int, int]]: - """Return `[2⁰·H, 2¹·H, …]` in short‑Weierstrass coords.""" - # sw_bb = sw.from_twisted_edwards(blinding_base) - sw_bb = blinding_base - # print("Blinding Base:",sw_bb) - res = [cast(tuple[int, int], TE.mul(pow(2, i, S_PRIME), sw_bb)) for i in range(self.size)] - return res # B_Neck - - def build(self, ring_pk: list[tuple[int, int]]) -> tuple[Column, Column, Column]: - """Return (Px, Py, s) columns fully committed.""" - if len(ring_pk) < self.max_ring_size: - ring_pk = self._pad_ring_with_padding_point(ring_pk) - if len(ring_pk) > self.size - self.padding_rows: - raise ValueError(f"ring size {len(ring_pk)} exceeds max supported size {self.size - self.padding_rows}") - # 1. ensure ring size - fill_count = self.size - self.padding_rows - len(ring_pk) - if fill_count > 0: - if self.size == len(_H_VEC_DEFAULT): - h_vec = _H_VEC_DEFAULT - else: - h_vec = self._h_vector() - ring_pk.extend(h_vec[:fill_count]) - if self.padding_rows > 0: - ring_pk.extend([(0, 0)] * self.padding_rows) - - # 2. unzip into x/y vectors - px, py = H.unzip(ring_pk) - - # 3. selector vector - sel = [1 if i < self.max_ring_size else 0 for i in range(self.size)] - - # 4. Columns - col_px = Column("Px", px, size=self.size) - col_py = Column("Py", py, size=self.size) - col_s = Column("s", sel, size=self.size) - for col in (col_px, col_py, col_s): - col.interpolate(self.omega, self.prime) - col.commit() - return col_px, col_py, col_s - @dataclass(slots=True) class WitnessColumnBuilder: diff --git a/dot_ring/ring_proof/params.py b/dot_ring/ring_proof/params.py index 5ee8d4a..855734d 100644 --- a/dot_ring/ring_proof/params.py +++ b/dot_ring/ring_proof/params.py @@ -2,7 +2,10 @@ from dataclasses import dataclass from functools import lru_cache +from typing import ClassVar +from dot_ring.curve.curve import CurveVariant +from dot_ring.curve.specs.bandersnatch import Bandersnatch from dot_ring.ring_proof.constants import D_2048, DEFAULT_SIZE, MAX_RING_SIZE, OMEGA_2048, S_PRIME @@ -10,6 +13,15 @@ def _is_power_of_two(n: int) -> bool: return n > 0 and (n & (n - 1)) == 0 +def _next_power_of_two(n: int) -> int: + """Find the next power of 2 greater than or equal to n.""" + if n <= 0: + return 1 + if _is_power_of_two(n): + return n + return 1 << (n.bit_length()) + + def _omega_for_domain(domain_size: int, prime: int = S_PRIME, base_root: int = OMEGA_2048, base_size: int = 2048) -> int: if base_size % domain_size != 0: raise ValueError(f"Domain size {domain_size} must divide {base_size}") @@ -90,6 +102,7 @@ class RingProofParams: prime: int = S_PRIME base_root: int = OMEGA_2048 base_root_size: int = 2048 + cv: ClassVar[CurveVariant] = Bandersnatch def __post_init__(self) -> None: radix_domain_size = self.radix_domain_size @@ -144,3 +157,57 @@ def last_index(self) -> int: @property def max_effective_ring_size(self) -> int: return self.domain_size - self.padding_rows + + @classmethod + def from_ring_size( + cls, + ring_size: int, + padding_rows: int = 4, + prime: int = S_PRIME, + base_root: int = OMEGA_2048, + base_root_size: int = 2048, + ) -> RingProofParams: + """ + Automatically construct RingProofParams based on ring size. + + Calculates the minimum domain size needed to accommodate the ring + and constructs appropriate parameters. + + Args: + ring_size: Number of members in the ring + padding_rows: Number of padding rows (default: 4) + prime: Field prime (default: S_PRIME) + base_root: Base root of unity (default: OMEGA_2048) + base_root_size: Base root size (default: 2048) + + Returns: + RingProofParams configured for the given ring size + """ + if ring_size <= 0: + raise ValueError(f"ring_size must be positive, got {ring_size}") + + # Calculate minimum domain size needed: + # domain_size >= ring_size + padding_rows + min_domain_size = ring_size + padding_rows + + # Round up to next power of 2 + domain_size = _next_power_of_two(min_domain_size) + + # Ensure domain_size is reasonable (between 16 and 8192) + if domain_size < 16: + domain_size = 16 + elif domain_size > 8192: + raise ValueError( + f"Ring size {ring_size} requires domain size {domain_size}, " + f"which exceeds maximum supported size of 8192. " + f"Maximum ring size is {8192 - padding_rows}." + ) + + return cls( + domain_size=domain_size, + max_ring_size=ring_size, + padding_rows=padding_rows, + prime=prime, + base_root=base_root, + base_root_size=base_root_size, + ) diff --git a/dot_ring/ring_proof/pcs/kzg.py b/dot_ring/ring_proof/pcs/kzg.py index d668020..279cbe9 100644 --- a/dot_ring/ring_proof/pcs/kzg.py +++ b/dot_ring/ring_proof/pcs/kzg.py @@ -8,7 +8,7 @@ import dot_ring.blst as _blst # type: ignore[import-untyped] -from ..polynomial.ops import poly_evaluate_single +from ..polynomial.poly_ops import poly_evaluate_single from .pairing import blst_final_verify, blst_miller_loop from .srs import G1Point, srs from .utils import CoeffVector, Scalar, blst_p1_to_fq_tuple, g1_to_blst, synthetic_div diff --git a/dot_ring/ring_proof/polynomial/ops.py b/dot_ring/ring_proof/polynomial/ops.py index e1d1000..451285f 100644 --- a/dot_ring/ring_proof/polynomial/ops.py +++ b/dot_ring/ring_proof/polynomial/ops.py @@ -3,6 +3,14 @@ from dot_ring.ring_proof.constants import D_512, D_2048, OMEGA_2048 from dot_ring.ring_proof.constants import OMEGA_512 as OMEGA from dot_ring.ring_proof.polynomial.fft import evaluate_poly_fft, inverse_fft +from dot_ring.ring_proof.polynomial.poly_ops import ( + poly_evaluate_single, + poly_mul_linear, + poly_multiply_naive, +) +from dot_ring.ring_proof.polynomial.poly_ops import ( + poly_scalar_mul as poly_scalar_mul_fast, +) def mod_inverse(val: int, prime: int) -> int: @@ -12,33 +20,6 @@ def mod_inverse(val: int, prime: int) -> int: return pow(val, prime - 2, prime) -def poly_add(poly1: list | Sequence[int], poly2: list | Sequence[int], prime: int) -> list[int]: - """Add two polynomials in a prime field.""" - # Make them the same length - result_len = max(len(poly1), len(poly2)) - result = [0] * result_len - - for i in range(len(poly1)): - result[i] = poly1[i] - for i in range(len(poly2)): - result[i] = (result[i] + poly2[i]) % prime - return result - - -def poly_subtract(poly1: list[int], poly2: list[int], prime: int) -> list[int]: - """Subtract poly2 from poly1 in a prime field.""" - # Make them the same length - result_len = max(len(poly1), len(poly2)) - result = [0] * result_len - - for i in range(len(poly1)): - result[i] = poly1[i] - - for i in range(len(poly2)): - result[i] = (result[i] - poly2[i]) % prime - return result - - GENERATOR = 5 _root_of_unity_cache: dict[tuple[int, int], int] = {} @@ -84,11 +65,7 @@ def poly_multiply(poly1: list[int], poly2: list[int], prime: int) -> list[int]: # Truncate to expected length (higher terms should be 0) return coeffs[:result_len] - result = [0] * result_len - for i in range(len(poly1)): - for j in range(len(poly2)): - result[i + j] = (result[i + j] + poly1[i] * poly2[j]) % prime - return result + return poly_multiply_naive(poly1, poly2, prime) def poly_division_general(coeffs: list[int], domain_size: int) -> list[int]: @@ -127,18 +104,6 @@ def poly_division_general(coeffs: list[int], domain_size: int) -> list[int]: return quotient -def poly_scalar(poly: list | Sequence[int], scalar: int, prime: int) -> list[int]: - """Multiply a polynomial by a scalar in a prime field.""" - result = [(coef * scalar) % prime for coef in poly] - return result - - -def poly_evaluate_single(poly: list | Sequence[int], x: int, prime: int) -> int: - result = 0 - for coef in reversed(poly): - result = (result * x + coef) % prime - return result - def poly_evaluate(poly: list | Sequence[int], xs: list | int | Sequence[int], prime: int) -> list[int] | int: """Evaluate polynomial at points xs. @@ -172,27 +137,6 @@ def poly_evaluate(poly: list | Sequence[int], xs: list | int | Sequence[int], pr return results -def poly_mul_linear(poly: list[int], a: int, b: int, prime: int) -> list[int]: - """Multiply poly by (ax + b) in O(n) time.""" - # result = poly * (ax + b) = a * (poly * x) + b * poly - # poly * x is [0] + poly - # result[i] = a * poly[i-1] + b * poly[i] - - n = len(poly) - result = [0] * (n + 1) - - # Handle first element (i=0): result[0] = b * poly[0] - result[0] = (b * poly[0]) % prime - - for i in range(1, n): - result[i] = (a * poly[i - 1] + b * poly[i]) % prime - - # Handle last element (i=n): result[n] = a * poly[n-1] - result[n] = (a * poly[n - 1]) % prime - - return result - - def lagrange_basis_polynomial(x_coords: list[int], i: int, prime: int) -> list[int]: """ Compute the i-th Lagrange basis polynomial. @@ -239,7 +183,7 @@ def lagrange_basis_polynomial(x_coords: list[int], i: int, prime: int) -> list[i # inv_denominator = mod_inverse(denominator, prime) inv_denominator = pow(denominator, -1, prime) # Scale the numerator polynomial - basis_poly = poly_scalar(numerator, inv_denominator, prime) + basis_poly = poly_scalar_mul_fast(numerator, inv_denominator, prime) return basis_poly diff --git a/dot_ring/ring_proof/polynomial/poly_ops.pyx b/dot_ring/ring_proof/polynomial/poly_ops.pyx new file mode 100644 index 0000000..db576db --- /dev/null +++ b/dot_ring/ring_proof/polynomial/poly_ops.pyx @@ -0,0 +1,244 @@ +# cython: language_level=3 +# cython: boundscheck=False +# cython: wraparound=False +# cython: cdivision=True +# cython: initializedcheck=False + +def poly_add(list poly1, list poly2, object prime): + """Add two polynomials in a prime field (optimized Cython version).""" + cdef int len1 = len(poly1) + cdef int len2 = len(poly2) + cdef int result_len = max(len1, len2) + cdef int i + cdef object val + cdef list result = [0] * result_len + + # Add poly1 coefficients + for i in range(len1): + result[i] = poly1[i] + + # Add poly2 coefficients + for i in range(len2): + val = result[i] + poly2[i] + if val >= prime: + val -= prime + result[i] = val + + return result + + +def poly_subtract(list poly1, list poly2, object prime): + """Subtract poly2 from poly1 in a prime field (optimized Cython version).""" + cdef int len1 = len(poly1) + cdef int len2 = len(poly2) + cdef int result_len = max(len1, len2) + cdef int i + cdef object val, val1, val2 + cdef list result = [0] * result_len + + # Add poly1 coefficients + for i in range(len1): + result[i] = poly1[i] + + # Subtract poly2 coefficients + for i in range(len2): + val1 = result[i] + val2 = poly2[i] + if val1 >= val2: + result[i] = val1 - val2 + else: + result[i] = prime - (val2 - val1) + + return result + + +def poly_scalar_mul(list poly, object scalar, object prime): + """Multiply a polynomial by a scalar in a prime field (optimized Cython version).""" + cdef int n = len(poly) + cdef int i + cdef object coef + cdef list result = [None] * n + + # Normalize scalar + if scalar >= prime: + scalar = scalar % prime + + # Special case: scalar is 0 or 1 + if scalar == 0: + return [0] * n + elif scalar == 1: + return [poly[i] % prime if poly[i] >= prime else poly[i] for i in range(n)] + + # General case + for i in range(n): + coef = poly[i] + if coef >= prime: + coef = coef % prime + result[i] = (coef * scalar) % prime + + return result + + +def poly_evaluate_single(list poly, object x, object prime): + """Evaluate polynomial at point x using Horner's method (optimized Cython version).""" + cdef int n = len(poly) + cdef int i + cdef object result = 0 + cdef object coef + + # Normalize x + if x >= prime: + x = x % prime + + # Horner's method: evaluate from highest degree to lowest + for i in range(n - 1, -1, -1): + coef = poly[i] + if coef >= prime: + coef = coef % prime + result = (result * x + coef) % prime + + return result + + +def poly_multiply_naive(list poly1, list poly2, object prime): + """Multiply two polynomials using naive O(n²) algorithm (optimized for small polynomials).""" + cdef int len1 = len(poly1) + cdef int len2 = len(poly2) + cdef int result_len = len1 + len2 - 1 + cdef int i, j + cdef object val1, val2, prod, current + cdef list result = [0] * result_len + + for i in range(len1): + val1 = poly1[i] + if val1 >= prime: + val1 = val1 % prime + + if val1 == 0: + continue + + for j in range(len2): + val2 = poly2[j] + if val2 >= prime: + val2 = val2 % prime + + if val2 == 0: + continue + + prod = (val1 * val2) % prime + current = result[i + j] + prod + if current >= prime: + current -= prime + result[i + j] = current + + return result + + +def poly_eval_domain(list poly, list domain, object prime): + """Evaluate polynomial at multiple points (for non-FFT domains).""" + cdef int n_points = len(domain) + cdef int i + cdef object x + cdef list result = [None] * n_points + + for i in range(n_points): + x = domain[i] + if x >= prime: + x = x % prime + result[i] = poly_evaluate_single(poly, x, prime) + + return result + + +def vect_scalar_mul_inplace(list vec, object scalar, object prime): + """Multiply each element by scalar modulo prime (in-place version for reduced allocations).""" + cdef int n = len(vec) + cdef int i + cdef object val + + if scalar >= prime: + scalar = scalar % prime + + if scalar == 0: + for i in range(n): + vec[i] = 0 + elif scalar != 1: + for i in range(n): + val = vec[i] + if val >= prime: + val = val % prime + vec[i] = (val * scalar) % prime + else: + # scalar == 1, just normalize + for i in range(n): + if vec[i] >= prime: + vec[i] = vec[i] % prime + + return vec + + +def vect_add_inplace(list a, list b, object prime): + """Add vector b to vector a in-place (modifies a).""" + cdef int n = len(a) + cdef int m = len(b) + cdef int i + cdef object val_a, val_b, result + + if m != n: + raise ValueError("Vector lengths must match") + + for i in range(n): + val_a = a[i] + val_b = b[i] + if val_a >= prime: + val_a = val_a % prime + if val_b >= prime: + val_b = val_b % prime + result = val_a + val_b + if result >= prime: + result -= prime + a[i] = result + + return a + + +def poly_mul_linear(list poly, object a, object b, object prime): + """Multiply poly by (ax + b) in O(n) time (optimized version).""" + cdef int n = len(poly) + cdef int i + cdef object coef, term1, term2, prev_coef + cdef list result = [None] * (n + 1) + + # Normalize a and b + if a >= prime: + a = a % prime + if b >= prime: + b = b % prime + + # Handle first element: result[0] = b * poly[0] + coef = poly[0] + if coef >= prime: + coef = coef % prime + result[0] = (b * coef) % prime + + # Handle middle elements: result[i] = a * poly[i-1] + b * poly[i] + for i in range(1, n): + prev_coef = poly[i - 1] + if prev_coef >= prime: + prev_coef = prev_coef % prime + + coef = poly[i] + if coef >= prime: + coef = coef % prime + + term1 = (a * prev_coef) % prime + term2 = (b * coef) % prime + result[i] = (term1 + term2) % prime + + # Handle last element: result[n] = a * poly[n-1] + coef = poly[n - 1] + if coef >= prime: + coef = coef % prime + result[n] = (a * coef) % prime + + return result diff --git a/dot_ring/ring_proof/proof/aggregation_poly.py b/dot_ring/ring_proof/proof/aggregation_poly.py index 3fbd39e..904b878 100644 --- a/dot_ring/ring_proof/proof/aggregation_poly.py +++ b/dot_ring/ring_proof/proof/aggregation_poly.py @@ -1,6 +1,6 @@ from dot_ring.ring_proof.constants import S_PRIME from dot_ring.ring_proof.pcs.kzg import KZG -from dot_ring.ring_proof.polynomial.ops import poly_add, poly_scalar +from dot_ring.ring_proof.polynomial.poly_ops import poly_add, poly_scalar_mul class AggPoly: @@ -20,7 +20,7 @@ def aggregated_poly(cls, fixed_cols: list, witness_cols: list, Q_p: list[int], c V_list = cf_vectors agg_poly = [0] for i in range(len(poly_I)): - agg_poly = poly_add(agg_poly, poly_scalar(poly_I[i], V_list[i], S_PRIME), S_PRIME) + agg_poly = poly_add(agg_poly, poly_scalar_mul(poly_I[i], V_list[i], S_PRIME), S_PRIME) return agg_poly # two proof openings diff --git a/dot_ring/ring_proof/proof/linearization_poly.py b/dot_ring/ring_proof/proof/linearization_poly.py index cd5c2ce..b531dc3 100644 --- a/dot_ring/ring_proof/proof/linearization_poly.py +++ b/dot_ring/ring_proof/proof/linearization_poly.py @@ -4,11 +4,9 @@ from dot_ring.curve.specs.bandersnatch import BandersnatchParams from dot_ring.ring_proof.columns.columns import Column from dot_ring.ring_proof.constants import S_PRIME -from dot_ring.ring_proof.polynomial.ops import ( - poly_add, - poly_evaluate, - poly_scalar, -) +from dot_ring.ring_proof.polynomial.ops import poly_evaluate +from dot_ring.ring_proof.polynomial.poly_ops import poly_add +from dot_ring.ring_proof.polynomial.poly_ops import poly_scalar_mul as poly_scalar from dot_ring.ring_proof.transcript.phases import phase2_eval_point from dot_ring.ring_proof.transcript.transcript import Transcript @@ -71,11 +69,6 @@ def compute_l2(self) -> list[int]: res = poly_add(term1, term2, S_PRIME) return res - # inner = (self.b_zeta * pow((self.acc_x_zeta - self.P_x_zeta) % S_PRIME, 2, S_PRIME)) % S_PRIME - # left = poly_scalar(self.wts[1].coeffs, inner, S_PRIME) - # right = poly_scalar(self.wts[2].coeffs, (1 - self.b_zeta) % S_PRIME, S_PRIME) - # return poly_scalar(poly_add(left, right, S_PRIME), self.scalar_term, S_PRIME) - def compute_l3(self) -> list[int]: b = self.b_zeta x1, y1 = self.acc_x_zeta, self.acc_y_zeta @@ -92,12 +85,6 @@ def compute_l3(self) -> list[int]: res = poly_add(term1, term2, S_PRIME) return res - # term1_scalar = (self.b_zeta * ((self.acc_y_zeta - self.P_y_zeta) % S_PRIME) + (1 - self.b_zeta)) % S_PRIME - # term2_scalar = (self.b_zeta * ((self.acc_x_zeta - self.P_x_zeta) % S_PRIME)) % S_PRIME - # term1 = poly_scalar(self.wts[1].coeffs, term1_scalar, S_PRIME) - # term2 = poly_scalar(self.wts[2].coeffs, term2_scalar, S_PRIME) - # return poly_scalar(poly_add(term1, term2, S_PRIME), self.scalar_term, S_PRIME) - def linearize(self, l1: list[int], l2: list[int], l3: list[int]) -> list[int]: l_agg = [0] for i, li in enumerate([l1, l2, l3]): diff --git a/dot_ring/scripts/export_python_proof.py b/dot_ring/scripts/export_python_proof.py index 704595b..f8e5ec9 100644 --- a/dot_ring/scripts/export_python_proof.py +++ b/dot_ring/scripts/export_python_proof.py @@ -19,10 +19,10 @@ from py_ecc.optimized_bls12_381 import normalize as nm from dot_ring.curve.specs.bandersnatch import Bandersnatch, BandersnatchParams, BandersnatchPoint -from dot_ring.ring_proof.columns.columns import PublicColumnBuilder as PC from dot_ring.ring_proof.curve.bandersnatch import TwistedEdwardCurve from dot_ring.ring_proof.params import RingProofParams from dot_ring.ring_proof.pcs.srs import srs +from dot_ring.vrf.ring.ring_root import Ring, RingRoot from dot_ring.vrf.ring.ring_vrf import RingVRF from tests.utils.python_to_rust_serde import ( serialize_bls12_381_g1, @@ -140,9 +140,9 @@ def export_variant(variant: VariantSpec, output_dir: Path) -> dict[str, Any]: prover_index = min(variant.prover_index, variant.ring_size - 1) keys_bytes, keys_points, prover_index = generate_test_keys(num_keys=variant.ring_size, prover_index=prover_index) - # Build fixed columns (this mutates the list, so use a copy) - ring_keys_for_columns = list(keys_points) - fixed_cols = PC.from_params(params).build(ring_keys_for_columns) + # Build ring and ring root + ring = Ring(keys_bytes, params) + ring_root = RingRoot.from_ring(ring, params) # Producer key producer_key_bytes = keys_bytes[prover_index] @@ -152,9 +152,9 @@ def export_variant(variant: VariantSpec, output_dir: Path) -> dict[str, Any]: proof_components = RingVRF[Bandersnatch].generate_bls_signature( blinding_factor=blinding_factor, producer_key=producer_key_bytes, - keys=keys_bytes, + ring=ring, transcript_challenge=b"w3f-ring-proof-test", - params=params, + ring_root=ring_root, ) # Compute result point (blinded public key) @@ -206,9 +206,9 @@ def export_variant(variant: VariantSpec, output_dir: Path) -> dict[str, Any]: # Serialize fixed column commitments fixed_cols_cmts_affine = [ - nm(fixed_cols[0].commitment), - nm(fixed_cols[1].commitment), - nm(fixed_cols[2].commitment), + nm(ring_root.px.commitment), + nm(ring_root.py.commitment), + nm(ring_root.s.commitment), ] verifier_key_bytes = bytearray() diff --git a/dot_ring/vrf/ring/ring_root.py b/dot_ring/vrf/ring/ring_root.py index 5f0dbb9..bd6bb7f 100644 --- a/dot_ring/vrf/ring/ring_root.py +++ b/dot_ring/vrf/ring/ring_root.py @@ -1,24 +1,98 @@ from dataclasses import dataclass +from functools import cache from typing import Any, cast from dot_ring.ring_proof.columns.columns import Column +from dot_ring.ring_proof.constants import DEFAULT_SIZE, S_PRIME, Blinding_Base, PaddingPoint +from dot_ring.ring_proof.curve.bandersnatch import TwistedEdwardCurve as TE from dot_ring.ring_proof.helpers import Helpers as H +from dot_ring.ring_proof.params import RingProofParams +@cache +def _h_vector(blinding_base: tuple[int, int] = Blinding_Base, size: int = DEFAULT_SIZE) -> list[tuple[int, int]]: + """Return `[2⁰·H, 2¹·H, …]` in short‑Weierstrass coords.""" + res = [cast(tuple[int, int], TE.mul(pow(2, i, S_PRIME), blinding_base)) for i in range(size)] + return res + + +class Ring: + nm_points: list[tuple[int, int]] + params: RingProofParams + + def __init__(self, keys: list[bytes], params: RingProofParams | None = None) -> None: + """ + Initialize a Ring from a list of public keys. + + Args: + keys: List of public keys (as bytes) for ring members + params: Ring proof parameters. If None, automatically constructed based on ring size. + + Example: + >>> # Auto-construct params based on ring size + >>> ring = Ring(keys) # Will use appropriate domain size + >>> + >>> # Or specify params explicitly + >>> params = RingProofParams(domain_size=2048, max_ring_size=1023) + >>> ring = Ring(keys, params) + """ + # Auto-construct params if not provided + if params is None: + params = RingProofParams.from_ring_size(len(keys)) + + self.params = params + + if len(keys) > params.domain_size - params.padding_rows: + raise ValueError(f"ring size {len(keys)} exceeds max supported size {params.domain_size - params.padding_rows}") + + self.nm_points = [] + for key in keys: + if isinstance(key, (str, bytes)): + point = params.cv.point.string_to_point(key) + if isinstance(point, str): + # Handle invalid point string + continue + self.nm_points.append((cast(int, point.x), cast(int, point.y))) + else: + # Handle non-string/bytes keys if necessary, or skip/raise + continue + + # Pad with special point if needed + while len(self.nm_points) < params.max_ring_size: + self.nm_points.append(PaddingPoint) + + # Ensure ring size + fill_count = params.domain_size - params.padding_rows - len(self.nm_points) + if fill_count > 0: + h_vec = _h_vector(size=params.domain_size) + self.nm_points.extend(h_vec[:fill_count]) + if params.padding_rows > 0: + self.nm_points.extend([(0, 0)] * params.padding_rows) + + @dataclass class RingRoot: px: Column py: Column s: Column + + @classmethod + def from_ring(cls, ring: Ring, params: RingProofParams): + # Px, Py, s points + px, py = H.unzip(ring.nm_points) + selector_vec = [1 if i < params.max_ring_size else 0 for i in range(params.domain_size)] + # Columns + px_col = Column("Px", px, size=params.domain_size) + py_col = Column("Py", py, size=params.domain_size) + s_col = Column("s", selector_vec, size=params.domain_size) + for col in (px_col, py_col, s_col): + col.interpolate(params.omega, params.prime) + col.commit() + return cls(px=px_col, py=py_col, s=s_col) def to_bytes(self) -> bytes: - # Assuming H.bls_g1_compress expects a tuple representation of the commitment - # and that self.px.commitment, etc., are convertible to such a tuple. - # If the commitment is already a G1 point object, casting to tuple might be incorrect - # or require a specific conversion method not shown here. - # This implementation assumes 'ring_root' refers to the commitments themselves. comm_keys = ( - H.bls_g1_compress(cast(Any, self.px.commitment)), # Cast to Any or a more specific tuple type if known + H.bls_g1_compress(cast(Any, self.px.commitment)), H.bls_g1_compress(cast(Any, self.py.commitment)), H.bls_g1_compress(cast(Any, self.s.commitment)), ) diff --git a/dot_ring/vrf/ring/ring_vrf.py b/dot_ring/vrf/ring/ring_vrf.py index d41961f..75787ce 100644 --- a/dot_ring/vrf/ring/ring_vrf.py +++ b/dot_ring/vrf/ring/ring_vrf.py @@ -5,7 +5,6 @@ from dot_ring.curve.point import CurvePoint from dot_ring.ring_proof.columns.columns import Column, WitnessColumnBuilder -from dot_ring.ring_proof.columns.columns import PublicColumnBuilder as PC from dot_ring.ring_proof.constants import ( S_PRIME, Blinding_Base, @@ -16,7 +15,6 @@ from dot_ring.ring_proof.constraints.constraints import RingConstraintBuilder from dot_ring.ring_proof.curve.bandersnatch import TwistedEdwardCurve from dot_ring.ring_proof.helpers import Helpers as H -from dot_ring.ring_proof.params import RingProofParams from dot_ring.ring_proof.pcs.kzg import Opening from dot_ring.ring_proof.pcs.srs import srs from dot_ring.ring_proof.proof.aggregation_poly import AggPoly @@ -28,7 +26,7 @@ from dot_ring.vrf.pedersen.pedersen import PedersenVRF from ..vrf import VRF -from .ring_root import RingRoot +from .ring_root import Ring, RingRoot @dataclass @@ -177,9 +175,9 @@ def generate_bls_signature( cls, blinding_factor: int, producer_key: bytes | str, - keys: list[Any] | str | bytes, + ring: Ring, transcript_challenge: bytes = b"Bandersnatch_SHA-512_ELL2", - params: RingProofParams | None = None, + ring_root: RingRoot | None = None, ) -> tuple[ Column, Column, @@ -198,9 +196,17 @@ def generate_bls_signature( Any, ]: """ - Returns the Ring Proof as an output + Returns the Ring Proof as an output. + + Args: + blinding_factor: Blinding factor from Pedersen VRF + producer_key: Public key of the prover + ring: Ring object containing member keys and params + transcript_challenge: Challenge for Fiat-Shamir + ring_root: Optional pre-computed ring root for performance """ - params = params or RingProofParams() + # Use params from the ring object + params = ring.params producer_key_point = cls.cv.point.string_to_point(producer_key) if isinstance(producer_key_point, str) or producer_key_point.is_identity(): @@ -210,25 +216,14 @@ def generate_bls_signature( cast(int, producer_key_point.x), cast(int, producer_key_point.y), ) - keys_as_bs_points = [] - - for key in keys: - if isinstance(key, (str, bytes)): - point = cls.cv.point.string_to_point(key) - if isinstance(point, str): - # Handle invalid point string - continue - keys_as_bs_points.append((cast(int, point.x), cast(int, point.y))) - else: - # Handle non-string/bytes keys if necessary, or skip/raise - continue - - ring_root = PC.from_params(params) # ring_root builder - fixed_cols = ring_root.build(keys_as_bs_points) - s_v = fixed_cols[-1].evals - producer_index = keys_as_bs_points.index(producer_key_pt) + + if not ring_root: + ring_root = RingRoot.from_ring(ring, params) # ring_root builder + + s_v = ring_root.s.evals + producer_index = ring.nm_points.index(producer_key_pt) witness_obj = WitnessColumnBuilder.from_params( - keys_as_bs_points, + ring.nm_points, s_v, producer_index, blinding_factor, @@ -239,9 +234,9 @@ def generate_bls_signature( Result_plus_Seed = witness_obj.result_p_seed(witness_relation_res) constraints = RingConstraintBuilder( Result_plus_Seed=Result_plus_Seed, # type: ignore - px=cast(list[int], fixed_cols[0].coeffs), - py=cast(list[int], fixed_cols[1].coeffs), - s=cast(list[int], fixed_cols[2].coeffs), + px=cast(list[int], ring_root.px.coeffs), + py=cast(list[int], ring_root.py.coeffs), + s=cast(list[int], ring_root.s.coeffs), b=cast(list[int], witness_res[0].coeffs), acc_x=cast(list[int], witness_res[1].coeffs), acc_y=cast(list[int], witness_res[2].coeffs), @@ -251,9 +246,9 @@ def generate_bls_signature( constraint_dict = constraints.compute() fixed_col_commits = [ - H.to_int(nm(fixed_cols[0].commitment)), - H.to_int(nm(fixed_cols[1].commitment)), - H.to_int(nm(fixed_cols[2].commitment)), + H.to_int(nm(ring_root.px.commitment)), + H.to_int(nm(ring_root.py.commitment)), + H.to_int(nm(ring_root.s.commitment)), ] ws = witness_res @@ -282,7 +277,7 @@ def generate_bls_signature( l_obj = LAggPoly( t, list(H.to_int(C_q_nm)), - list(fixed_cols), + list([ring_root.px, ring_root.py, ring_root.s]), list(ws), alpha, domain=params.domain, @@ -294,7 +289,7 @@ def generate_bls_signature( zeta, zeta_omega, l_agg, - list(fixed_cols), + list([ring_root.px, ring_root.py, ring_root.s]), list(ws), Q_p, phase3_nu_vector(current_t, list(rel_poly_evals.values()), l_zw), @@ -330,16 +325,12 @@ def generate_bls_signature( def verify_ring_proof( self, message: bytes | CurvePoint, - ring_root: RingRoot | bytes, - params: RingProofParams | None = None, + ring: Ring, + ring_root: RingRoot, ) -> bool: """ Verifies the Ring Proof """ - params = params or RingProofParams() - # Decompress ring_root once at the start - if isinstance(ring_root, bytes): - ring_root = RingRoot.from_bytes(ring_root) fixed_cols_cmts = [ ring_root.px.commitment, ring_root.py.commitment, @@ -391,37 +382,10 @@ def verify_ring_proof( rltn, res_plus_seed, SeedPoint, - params.domain, - padding_rows=params.padding_rows, + ring.params.domain, + padding_rows=ring.params.padding_rows, ).is_valid() - @classmethod - def construct_ring_root( - cls, - keys: list[bytes], - params: RingProofParams | None = None, - ) -> RingRoot: - """ - Constructs the Ring Root - """ - params = params or RingProofParams() - keys_as_bs_points = [] - for key in keys: - if not isinstance(key, (str, bytes)): - continue - point = cls.cv.point.string_to_point(key) - - if isinstance(point, str) or point.is_identity(): - keys_as_bs_points.append((PaddingPoint[0], PaddingPoint[1])) - - else: - keys_as_bs_points.append((cast(int, point.x), cast(int, point.y))) - - ring_root = PC.from_params(params) # ring_root builder - fixed_cols = ring_root.build(keys_as_bs_points) - - return RingRoot(*fixed_cols) - @classmethod def prove( cls, @@ -429,17 +393,32 @@ def prove( ad: bytes, secret_key: bytes, producer_key: bytes, - keys: list[bytes], - params: RingProofParams | None = None, + ring: Ring, + ring_root: RingRoot | None = None, ) -> "RingVRF": """ - Generate ring VRF proof (pedersen vrf proof + ring_proof) + Generate ring VRF proof (pedersen vrf proof + ring_proof). + + Args: + alpha: VRF input + ad: Additional data + secret_key: Prover's secret key + producer_key: Prover's public key + ring: Ring object containing member keys. Params are auto-constructed if not provided to Ring. + ring_root: Pre-computed ring root. If provided, skips expensive ring column construction (~335ms for 1023 members). + + Returns: + RingVRF proof + + Examples: + >>> ring = Ring(keys) # Automatically determines optimal domain size + >>> proof = RingVRF[Bandersnatch].prove(alpha, ad, sk, pk, ring) """ - # pedersen_proof pedersen_proof = PedersenVRF[cast(Any, cls).cv].prove(alpha, secret_key, ad) # type: ignore[misc] - # ring_proof - ring_proof = cls.generate_bls_signature(pedersen_proof._blinding_factor, producer_key, keys, params=params) + ring_proof = cls.generate_bls_signature( + pedersen_proof._blinding_factor, producer_key, ring=ring, ring_root=ring_root + ) return cls(pedersen_proof, *ring_proof) @@ -459,20 +438,15 @@ def verify( self, input: bytes, ad_data: bytes, - ring_root: RingRoot | bytes, - params: RingProofParams | None = None, + ring: Ring, + ring_root: RingRoot ) -> bool: """ Verify ring VRF proof (pedersen_proof + ring_proof) """ - # Decompress ring_root once at the start - if isinstance(ring_root, bytes): - ring_root = RingRoot.from_bytes(ring_root) - - # is pedersen proof valid if self.pedersen_proof is None: raise ValueError("Pedersen proof is missing") p_proof_valid = self.pedersen_proof.verify(input, ad_data) - ring_proof_valid = self.verify_ring_proof(self.pedersen_proof.blinded_pk, ring_root, params=params) + ring_proof_valid = self.verify_ring_proof(self.pedersen_proof.blinded_pk, ring, ring_root) return p_proof_valid and ring_proof_valid diff --git a/pyproject.toml b/pyproject.toml index 23257c8..42c7e80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -124,7 +124,7 @@ disallow_untyped_defs = true exclude = ["build/"] [[tool.mypy.overrides]] -module = ["dot_ring.curve.field_arithmetic", "dot_ring.curve.fast_math", "dot_ring.ring_proof.polynomial.ntt", "dot_ring.blst.*", "py_ecc.*"] +module = ["dot_ring.curve.field_arithmetic", "dot_ring.curve.fast_math", "dot_ring.ring_proof.polynomial.ntt", "dot_ring.ring_proof.polynomial.poly_ops", "dot_ring.blst.*", "py_ecc.*"] ignore_missing_imports = true disallow_untyped_defs = false check_untyped_defs = false diff --git a/setup.py b/setup.py index 56086f2..09b5d64 100644 --- a/setup.py +++ b/setup.py @@ -54,6 +54,12 @@ def build_cython_extensions() -> list[Extension]: include_dirs=["dot_ring/curve/native_field"], extra_compile_args=["-O3", "-ffast-math"], ), + Extension( + "dot_ring.ring_proof.polynomial.poly_ops", + ["dot_ring/ring_proof/polynomial/poly_ops.pyx"], + extra_compile_args=compile_args, + extra_link_args=[] if sys.platform == "win32" else ["-flto"], + ), Extension( "dot_ring.curve.native_field.scalar", [ diff --git a/tests/benchmark/bench_ring_large.py b/tests/benchmark/bench_ring_large.py new file mode 100644 index 0000000..f95c28b --- /dev/null +++ b/tests/benchmark/bench_ring_large.py @@ -0,0 +1,179 @@ +import statistics +import sys +import time +from pathlib import Path + +# Add blst to path if needed +sys.path.insert(0, str(Path(__file__).parent / "blst" / "bindings" / "python")) + +from dot_ring import Bandersnatch, Ring, RingRoot, RingVRF +from dot_ring.keygen import secret_from_seed +from dot_ring.ring_proof.params import RingProofParams + + +def generate_ring_keys(ring_size: int) -> tuple[list[bytes], bytes, bytes]: + """Generate ring keys for benchmarking.""" + print(f"Generating {ring_size} keys for the ring...") + + keys = [] + prover_sk = None + prover_pk = None + + # Generate keys deterministically from seeds + for i in range(ring_size): + seed = f"ring_member_{i}".encode() + pk, sk = secret_from_seed(seed, Bandersnatch) + keys.append(pk) + + # Use the middle key as the prover + if i == ring_size // 2: + prover_sk = sk + prover_pk = pk + + assert prover_sk is not None + assert prover_pk is not None + print(f"Generated {len(keys)} keys") + + return keys, prover_sk, prover_pk + + +def benchmark_large_ring_proof( + ring_size: int = 1023, + domain_size: int = 2048, + warmup_iters: int = 4, + bench_iters: int = 5, +): + """Benchmark 1024 ring-sized proof generation and verification over 2048 domain.""" + + print("=" * 60) + print("1024 Ring VRF Proof Benchmark") + print("=" * 60) + print() + print(f"Ring size: {ring_size} members") + print(f"Domain size: {domain_size}") + print(f"Warmup iterations: {warmup_iters}") + print(f"Benchmark iterations: {bench_iters}") + print() + + # Generate test data + keys, s_k, p_k = generate_ring_keys(ring_size) + alpha = b"test_alpha" + ad = b"" + + # Create parameters for large ring + params = RingProofParams(domain_size=domain_size, max_ring_size=ring_size) + + print("Parameters configured:") + print(f" Domain size: {params.domain_size}") + print(f" Max ring size: {params.max_ring_size}") + print(f" Padding rows: {params.padding_rows}") + print() + + # ========================================================================= + # Construct Ring Root (one-time setup) + # ========================================================================= + start = time.perf_counter() + ring = Ring(keys, params) + ring_root = RingRoot.from_ring(ring, params) + ring_root_time = (time.perf_counter() - start) * 1000 + print(f"Ring root constructed in {ring_root_time:.2f} ms") + print() + + # ========================================================================= + # Warmup + # ========================================================================= + print("Warming up...") + for i in range(warmup_iters): + print(f" Warmup iteration {i+1}/{warmup_iters}...") + # Pass ring_root to avoid rebuilding it + ring = Ring(keys, params) + ring_root = RingRoot.from_ring(ring, params) + ring_vrf_proof = RingVRF[Bandersnatch].prove(alpha, ad, s_k, p_k, ring, ring_root) + ring_vrf_proof.verify(alpha, ad, ring, ring_root) + print("Warmup complete") + + # ========================================================================= + # Benchmark Ring Root Construction + # ========================================================================= + print("\nBenchmarking ring root construction...") + root_const_times = [] + for _ in range(bench_iters): + start = time.perf_counter() + ring = Ring(keys, params) + ring_root = RingRoot.from_ring(ring, params) + elapsed = (time.perf_counter() - start) * 1000 + root_const_times.append(elapsed) + + # ========================================================================= + # Benchmark Proof Generation + # ========================================================================= + print("\nBenchmarking Proof Generation...") + proof_times = [] + proofs = [] + for i in range(bench_iters): + print(f" Iteration {i+1}/{bench_iters}...") + start = time.perf_counter() + # Pass ring_root to avoid rebuilding - this is the key optimization! + ring_vrf_proof = RingVRF[Bandersnatch].prove(alpha, ad, s_k, p_k, ring, ring_root) + elapsed = (time.perf_counter() - start) * 1000 + proof_times.append(elapsed) + proofs.append(ring_vrf_proof) + + # ========================================================================= + # Benchmark Verification + # ========================================================================= + print("\nBenchmarking Verification...") + verify_times = [] + for i, proof in enumerate(proofs): + print(f" Iteration {i+1}/{bench_iters}...") + start = time.perf_counter() + result = proof.verify(alpha, ad, ring, ring_root) + elapsed = (time.perf_counter() - start) * 1000 + verify_times.append(elapsed) + assert result, "Verification failed!" + + # ========================================================================= + # Results + # ========================================================================= + print() + print("=" * 60) + print("RESULTS") + print("=" * 60) + print() + + def print_stats(name: str, times: list): + min_t = min(times) + mean_t = statistics.mean(times) + std_t = statistics.stdev(times) if len(times) > 1 else 0 + print(f"{name}:") + print(f" Min: {min_t:8.2f} ms") + print(f" Mean: {mean_t:8.2f} ms") + print(f" Stddev: {std_t:8.2f} ms") + print() + return min_t, mean_t + + print() + proof_min, proof_mean = print_stats("Ring Root Construction", root_const_times) + proof_min, proof_mean = print_stats("Proof Generation", proof_times) + verify_min, verify_mean = print_stats("Verification", verify_times) + + print("-" * 60) + print(f"Total (Proof + Verify) Min: {proof_min + verify_min:8.2f} ms") + print(f"Total (Proof + Verify) Mean: {proof_mean + verify_mean:8.2f} ms") + print() + + # Proof size + proof_bytes = ring_vrf_proof.to_bytes() + print(f"Proof size: {len(proof_bytes)} bytes") + print() + + return { + "ring_root_time": ring_root_time, + "proof": proof_times, + "verify": verify_times, + "proof_size": len(proof_bytes), + } + + +if __name__ == "__main__": + benchmark_large_ring_proof(ring_size=1023, domain_size=2048, warmup_iters=1, bench_iters=3) diff --git a/tests/benchmark/bench_ring_proof.py b/tests/benchmark/bench_ring_proof.py index e45c422..65f290d 100644 --- a/tests/benchmark/bench_ring_proof.py +++ b/tests/benchmark/bench_ring_proof.py @@ -20,12 +20,14 @@ sys.path.insert(0, str(Path(__file__).parent / "blst" / "bindings" / "python")) from dot_ring import Bandersnatch +from dot_ring.ring_proof.params import RingProofParams +from dot_ring.vrf.ring.ring_root import Ring, RingRoot from dot_ring.vrf.ring.ring_vrf import RingVRF def load_test_data(): """Load test vector data.""" - vector_path = Path(__file__).parent / "vectors" / "ark-vrf" / "bandersnatch_ed_sha512_ell2_ring.json" + vector_path = Path(__file__).parent.parent / "vectors" / "ark-vrf" / "bandersnatch_ed_sha512_ell2_ring.json" with open(vector_path) as f: return json.load(f)[0] @@ -34,7 +36,7 @@ def benchmark_ring_proof(warmup_iters: int = 3, bench_iters: int = 10): """Benchmark ring proof generation and verification.""" print("=" * 60) - print("Ring VRF Proof Benchmark (Python/Cython + gmpy2 + BLST)") + print("Ring VRF Proof Benchmark") print("=" * 60) print() @@ -58,10 +60,12 @@ def benchmark_ring_proof(warmup_iters: int = 3, bench_iters: int = 10): # Warmup # ========================================================================= print("Warming up...") + params = RingProofParams() for _ in range(warmup_iters): - ring_vrf_proof = RingVRF[Bandersnatch].prove(alpha, ad, s_k, p_k, keys) - ring_root = RingVRF[Bandersnatch].construct_ring_root(keys) - ring_vrf_proof.verify(alpha, ad, ring_root) + ring = Ring(keys, params) + ring_root = RingRoot.from_ring(ring, params) + ring_vrf_proof = RingVRF[Bandersnatch].prove(alpha, ad, s_k, p_k, ring, ring_root=ring_root) + ring_vrf_proof.verify(alpha, ad, ring, ring_root) # ========================================================================= # Benchmark Ring Root Construction @@ -71,7 +75,8 @@ def benchmark_ring_proof(warmup_iters: int = 3, bench_iters: int = 10): ring_root = None for _ in range(bench_iters): start = time.perf_counter() - ring_root = RingVRF[Bandersnatch].construct_ring_root(keys) + ring = Ring(keys, params) + ring_root = RingRoot.from_ring(ring, params) elapsed = (time.perf_counter() - start) * 1000 ring_root_times.append(elapsed) @@ -81,9 +86,11 @@ def benchmark_ring_proof(warmup_iters: int = 3, bench_iters: int = 10): print("Benchmarking Proof Generation...") proof_times = [] proofs = [] + ring = Ring(keys, params) + ring_root = RingRoot.from_ring(ring, params) for _ in range(bench_iters): start = time.perf_counter() - ring_vrf_proof = RingVRF[Bandersnatch].prove(alpha, ad, s_k, p_k, keys) + ring_vrf_proof = RingVRF[Bandersnatch].prove(alpha, ad, s_k, p_k, ring, ring_root) elapsed = (time.perf_counter() - start) * 1000 proof_times.append(elapsed) proofs.append(ring_vrf_proof) @@ -95,7 +102,7 @@ def benchmark_ring_proof(warmup_iters: int = 3, bench_iters: int = 10): verify_times = [] for proof in proofs: start = time.perf_counter() - result = proof.verify(alpha, ad, ring_root) + result = proof.verify(alpha, ad, ring, ring_root) elapsed = (time.perf_counter() - start) * 1000 verify_times.append(elapsed) assert result, "Verification failed!" diff --git a/tests/benchmark/test_bench_ring.py b/tests/benchmark/test_bench_ring.py index 2bc64e9..495b962 100644 --- a/tests/benchmark/test_bench_ring.py +++ b/tests/benchmark/test_bench_ring.py @@ -7,6 +7,8 @@ sys.path.insert(0, str(Path(__file__).parent)) from dot_ring.curve.specs.bandersnatch import Bandersnatch +from dot_ring.ring_proof.params import RingProofParams +from dot_ring.vrf.ring.ring_root import Ring, RingRoot from dot_ring.vrf.ring.ring_vrf import RingVRF from ..utils.profiler import Profiler @@ -37,8 +39,10 @@ def test_bench_ring_prove(): ad = bytes.fromhex(item["ad"]) keys = RingVRF[Bandersnatch].parse_keys(bytes.fromhex(item["ring_pks"])) - # Store ring size - ring_root = RingVRF[Bandersnatch].construct_ring_root(keys) + # Construct ring and ring root + params = RingProofParams() + ring = Ring(keys, params) + ring_root = RingRoot.from_ring(ring, params) p_k = RingVRF[Bandersnatch].get_public_key(s_k) # Benchmark proof generation with profiling @@ -49,7 +53,7 @@ def test_bench_ring_prove(): sort_by="cumulative", limit=25, ): - _ = RingVRF[Bandersnatch].prove(alpha, ad, s_k, p_k, keys) + _ = RingVRF[Bandersnatch].prove(alpha, ad, s_k, p_k, ring, ring_root) # Verify correctness assert p_k.hex() == item["pk"], "Invalid Public Key" @@ -79,8 +83,14 @@ def test_bench_ring_verify(): item = data[index] alpha = bytes.fromhex(item["alpha"]) ad = bytes.fromhex(item["ad"]) + keys = RingVRF[Bandersnatch].parse_keys(bytes.fromhex(item["ring_pks"])) ring_root_bytes = bytes.fromhex(item["ring_pks_com"]) + # Construct ring and ring root + params = RingProofParams() + ring = Ring(keys, params) + ring_root = RingRoot.from_bytes(ring_root_bytes) + proof_hex = ( item["gamma"] + item["proof_pk_com"] + item["proof_r"] + item["proof_ok"] + item["proof_s"] + item["proof_sb"] + item["ring_proof"] ) @@ -97,7 +107,7 @@ def test_bench_ring_verify(): sort_by="cumulative", limit=100, ): - result = ring_vrf_proof.verify(alpha, ad, ring_root_bytes) + result = ring_vrf_proof.verify(alpha, ad, ring, ring_root) assert result, "Verification Failed" diff --git a/tests/benchmark/test_bench_ring_proof.py b/tests/benchmark/test_bench_ring_proof.py index a8a9cc3..9aa8e85 100644 --- a/tests/benchmark/test_bench_ring_proof.py +++ b/tests/benchmark/test_bench_ring_proof.py @@ -7,6 +7,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent / "dot_ring" / "blst" / "bindings" / "python")) from dot_ring import Bandersnatch +from dot_ring.vrf.ring.ring_root import Ring, RingRoot from dot_ring.vrf.ring.ring_vrf import RingVRF @@ -28,17 +29,18 @@ def ring_data(request): alpha = b"test_message" ad = b"test_ad" - ring_root = RingVRF[Bandersnatch].construct_ring_root(keys) + ring = Ring(keys) + ring_root = RingRoot.from_ring(ring, ring.params) # Pre-calculate a proof for verification benchmark - proof = RingVRF[Bandersnatch].prove(alpha, ad, s_k, p_k, keys) + proof = RingVRF[Bandersnatch].prove(alpha, ad, s_k, p_k, ring, ring_root) - return {"s_k": s_k, "p_k": p_k, "keys": keys, "alpha": alpha, "ad": ad, "ring_root": ring_root, "proof": proof} + return {"s_k": s_k, "p_k": p_k, "ring": ring, "alpha": alpha, "ad": ad, "ring_root": ring_root, "proof": proof} def test_prove(benchmark, ring_data): - benchmark(RingVRF[Bandersnatch].prove, ring_data["alpha"], ring_data["ad"], ring_data["s_k"], ring_data["p_k"], ring_data["keys"]) + benchmark(RingVRF[Bandersnatch].prove, ring_data["alpha"], ring_data["ad"], ring_data["s_k"], ring_data["p_k"], ring_data["ring"], ring_data["ring_root"]) def test_verify(benchmark, ring_data): - benchmark(ring_data["proof"].verify, ring_data["alpha"], ring_data["ad"], ring_data["ring_root"]) + benchmark(ring_data["proof"].verify, ring_data["alpha"], ring_data["ad"], ring_data["ring"], ring_data["ring_root"]) diff --git a/tests/test_bandersnatch_ark.py b/tests/test_bandersnatch_ark.py index d4218e4..55a25f7 100644 --- a/tests/test_bandersnatch_ark.py +++ b/tests/test_bandersnatch_ark.py @@ -5,8 +5,10 @@ from time import time from dot_ring.curve.specs.bandersnatch import Bandersnatch +from dot_ring.ring_proof.params import RingProofParams from dot_ring.vrf.ietf.ietf import IETF_VRF from dot_ring.vrf.pedersen.pedersen import PedersenVRF +from dot_ring.vrf.ring.ring_root import Ring, RingRoot from dot_ring.vrf.ring.ring_vrf import RingVRF HERE = os.path.dirname(__file__) @@ -119,11 +121,13 @@ def test_ring_proof(): ad = bytes.fromhex(item["ad"]) keys = RingVRF[Bandersnatch].parse_keys(bytes.fromhex(item["ring_pks"])) start = time() - ring_root = RingVRF[Bandersnatch].construct_ring_root(keys) + params = RingProofParams() + ring = Ring(keys, params) + ring_root = RingRoot.from_ring(ring, params) ring_time = time() print(f"\nTime taken for Ring Root Construction: \t\t {1000 * (ring_time - start):.2f} ms") p_k = RingVRF[Bandersnatch].get_public_key(s_k) - ring_vrf_proof = RingVRF[Bandersnatch].prove(alpha, ad, s_k, p_k, keys) + ring_vrf_proof = RingVRF[Bandersnatch].prove(alpha, ad, s_k, p_k, ring, ring_root) pk_time = time() print(f"Time taken for Proof Generation: \t {1000 * (pk_time - ring_time):.2f} ms") proof_bytes = ring_vrf_proof.to_bytes() @@ -136,10 +140,10 @@ def test_ring_proof(): == item["gamma"] + item["proof_pk_com"] + item["proof_r"] + item["proof_ok"] + item["proof_s"] + item["proof_sb"] + item["ring_proof"] ), "Unexpected Proof" - assert ring_vrf_proof.verify(alpha, ad, ring_root), "Verification Failed" + assert ring_vrf_proof.verify(alpha, ad, ring, ring_root), "Verification Failed" assert proof_rt.to_bytes() == proof_bytes start = time() - assert proof_rt.verify(alpha, ad, ring_root) + assert proof_rt.verify(alpha, ad, ring, ring_root) verify_time = time() print(f"Time taken for Proof Verification: \t {1000 * (verify_time - start):.2f} ms") print(f"✅ Testcase {index + 1} of {os.path.basename(file_path)}") diff --git a/tests/test_coverage/test_columns.py b/tests/test_coverage/test_columns.py index 96c16b9..bb8f8f4 100644 --- a/tests/test_coverage/test_columns.py +++ b/tests/test_coverage/test_columns.py @@ -1,8 +1,10 @@ import pytest -from dot_ring.ring_proof.columns.columns import Column, PublicColumnBuilder, WitnessColumnBuilder -from dot_ring.ring_proof.constants import DEFAULT_SIZE, OMEGAS, S_PRIME, PaddingPoint +from dot_ring.curve.specs.bandersnatch import Bandersnatch +from dot_ring.ring_proof.columns.columns import Column, WitnessColumnBuilder +from dot_ring.ring_proof.constants import DEFAULT_SIZE, OMEGAS, S_PRIME from dot_ring.ring_proof.params import RingProofParams +from dot_ring.vrf.ring.ring_root import Ring, RingRoot def test_column_interpolate_rejects_oversize_evals(): @@ -17,23 +19,25 @@ def test_column_commit_requires_coeffs(): col.commit() -def test_public_builder_from_params_and_padding(): +def test_ring_from_params(): + """Test Ring construction with explicit params""" params = RingProofParams(domain_size=8, max_ring_size=3, padding_rows=1) - builder = PublicColumnBuilder.from_params(params) - assert builder.size == 8 - assert builder.max_ring_size == 3 + # Create dummy keys (just using some bytes) + keys = [b"key1" * 8, b"key2" * 8] + ring = Ring(keys, params) - ring = [(1, 1)] - padded = builder._pad_ring_with_padding_point(ring) - assert len(padded) == builder.max_ring_size - assert padded[-1] == PaddingPoint + assert ring.params == params + assert len(ring.nm_points) == params.domain_size -def test_public_builder_rejects_oversize_ring(): - builder = PublicColumnBuilder(size=8, max_ring_size=2, padding_rows=1) - ring_pk = [(0, 0)] * 8 +def test_ring_rejects_oversize_ring(): + """Test Ring rejects rings that are too large""" + params = RingProofParams(domain_size=8, max_ring_size=2, padding_rows=1) + # Create 8 keys which exceeds max_ring_size=2 and domain_size-padding_rows=7 + keys = [b"key" + bytes([i]) * 8 for i in range(8)] + with pytest.raises(ValueError, match="exceeds max supported size"): - builder.build(ring_pk) + Ring(keys, params) def test_witness_builder_from_params_and_bits_vector_error(): diff --git a/tests/test_coverage/test_ops.py b/tests/test_coverage/test_ops.py index 681141b..0d6fe35 100644 --- a/tests/test_coverage/test_ops.py +++ b/tests/test_coverage/test_ops.py @@ -1,20 +1,17 @@ """Tests for polynomial operations module to improve coverage.""" import pytest +from dot_ring.ring_proof.polynomial.poly_ops import poly_add, poly_evaluate_single, poly_mul_linear, poly_subtract +from dot_ring.ring_proof.polynomial.poly_ops import poly_scalar_mul as poly_scalar from dot_ring.ring_proof.constants import D_512, D_2048, S_PRIME from dot_ring.ring_proof.polynomial.ops import ( get_root_of_unity, lagrange_basis_polynomial, mod_inverse, - poly_add, poly_division_general, poly_evaluate, - poly_evaluate_single, - poly_mul_linear, poly_multiply, - poly_scalar, - poly_subtract, vect_scalar_mul, ) diff --git a/tests/test_curve_ops/test_gaps.py b/tests/test_curve_ops/test_gaps.py index 44fd491..57df399 100644 --- a/tests/test_curve_ops/test_gaps.py +++ b/tests/test_curve_ops/test_gaps.py @@ -2,6 +2,8 @@ from dot_ring import IETF_VRF, Bandersnatch, PedersenVRF from dot_ring.curve.specs.ed448 import Ed448_RO +from dot_ring.ring_proof.params import RingProofParams +from dot_ring.vrf.ring.ring_root import Ring, RingRoot class TestCoverageGaps: @@ -207,7 +209,10 @@ def test_ring_from_bytes_skip_pedersen(self): pk = PedersenVRF[Bandersnatch].get_public_key(sk) keys = [pk] - proof = RingVRF[Bandersnatch].prove(alpha, ad, sk, pk, keys) + params = RingProofParams() + ring = Ring(keys, params) + ring_root = RingRoot.from_ring(ring, params) + proof = RingVRF[Bandersnatch].prove(alpha, ad, sk, pk, ring, ring_root) proof_bytes = proof.to_bytes() # Parse with skip_pedersen=True @@ -239,20 +244,19 @@ def test_ring_verify_missing_pedersen(self): pk = PedersenVRF[Bandersnatch].get_public_key(sk) keys = [pk] - proof = RingVRF[Bandersnatch].prove(alpha, ad, sk, pk, keys) + params = RingProofParams() + ring = Ring(keys, params) + ring_root = RingRoot.from_ring(ring, params) + proof = RingVRF[Bandersnatch].prove(alpha, ad, sk, pk, ring, ring_root) proof_bytes = proof.to_bytes() ring_proof_bytes = proof_bytes[192:] parsed = RingVRF[Bandersnatch].from_bytes(ring_proof_bytes, skip_pedersen=True) - ring_root = RingVRF[Bandersnatch].construct_ring_root(keys) - with pytest.raises(ValueError, match="Pedersen proof is missing"): - parsed.verify(alpha, ad, ring_root) + parsed.verify(alpha, ad, ring, ring_root) def test_ring_construct_ring_root_invalid_keys(self): - """Test construct_ring_root with invalid keys.""" - from dot_ring.vrf.ring.ring_vrf import RingVRF - + """Test Ring construction with invalid keys.""" # Invalid key string invalid_key = b"invalid" # Identity point (if we can construct one as string) @@ -263,11 +267,14 @@ def test_ring_construct_ring_root_invalid_keys(self): ] # 33 bytes of zeros might be invalid or identity? # Should not raise, but handle gracefully (skip or use padding) - ring_root = RingVRF[Bandersnatch].construct_ring_root(keys) + params = RingProofParams() + ring = Ring(keys, params) + assert ring is not None + ring_root = RingRoot.from_ring(ring, params) assert ring_root is not None def test_ring_verify_ring_proof_bytes_input(self): - """Test verify_ring_proof handles bytes input for message and ring_root.""" + """Test verify_ring_proof handles bytes input for message.""" from dot_ring.vrf.ring.ring_vrf import RingVRF alpha = b"test" @@ -276,28 +283,30 @@ def test_ring_verify_ring_proof_bytes_input(self): pk = PedersenVRF[Bandersnatch].get_public_key(sk) keys = [pk] - proof = RingVRF[Bandersnatch].prove(alpha, ad, sk, pk, keys) - ring_root = RingVRF[Bandersnatch].construct_ring_root(keys) - ring_root_bytes = ring_root.to_bytes() + params = RingProofParams() + ring = Ring(keys, params) + ring_root = RingRoot.from_ring(ring, params) + proof = RingVRF[Bandersnatch].prove(alpha, ad, sk, pk, ring, ring_root) # message is usually a point (blinded_pk), but verify_ring_proof accepts bytes too. # But wait, verify_ring_proof takes `message: bytes | CurvePoint`. - # In `verify`, it calls `self.verify_ring_proof(self.pedersen_proof.blinded_pk, ring_root)`. + # In `verify`, it calls `self.verify_ring_proof(self.pedersen_proof.blinded_pk, ring, ring_root)`. # `blinded_pk` is a CurvePoint. # If we pass bytes, it tries to decode. # Let's call verify_ring_proof directly with bytes blinded_pk_bytes = proof.pedersen_proof.blinded_pk.point_to_string() - valid = proof.verify_ring_proof(blinded_pk_bytes, ring_root_bytes) + valid = proof.verify_ring_proof(blinded_pk_bytes, ring, ring_root) assert valid def test_ring_construct_ring_root_non_bytes_key(self): - """Test construct_ring_root with non-bytes/str key.""" - from dot_ring.vrf.ring.ring_vrf import RingVRF - + """Test Ring construction with non-bytes/str key.""" keys = [123] # type: ignore - ring_root = RingVRF[Bandersnatch].construct_ring_root(keys) + params = RingProofParams() + ring = Ring(keys, params) + assert ring is not None + ring_root = RingRoot.from_ring(ring, params) assert ring_root is not None def test_ring_verify_ring_proof_invalid_message(self): @@ -310,10 +319,12 @@ def test_ring_verify_ring_proof_invalid_message(self): pk = PedersenVRF[Bandersnatch].get_public_key(sk) keys = [pk] - proof = RingVRF[Bandersnatch].prove(alpha, ad, sk, pk, keys) - ring_root = RingVRF[Bandersnatch].construct_ring_root(keys) + params = RingProofParams() + ring = Ring(keys, params) + ring_root = RingRoot.from_ring(ring, params) + proof = RingVRF[Bandersnatch].prove(alpha, ad, sk, pk, ring, ring_root) invalid_message = b"\xff" * 33 with pytest.raises(ValueError, match="Invalid message point"): - proof.verify_ring_proof(invalid_message, ring_root) + proof.verify_ring_proof(invalid_message, ring, ring_root) diff --git a/tests/test_ring_vrf/test_ring_vrf.py b/tests/test_ring_vrf/test_ring_vrf.py index df2b946..164c84e 100644 --- a/tests/test_ring_vrf/test_ring_vrf.py +++ b/tests/test_ring_vrf/test_ring_vrf.py @@ -3,6 +3,8 @@ import time from dot_ring.curve.specs.bandersnatch import Bandersnatch +from dot_ring.ring_proof.params import RingProofParams +from dot_ring.vrf.ring.ring_root import Ring, RingRoot from dot_ring.vrf.ring.ring_vrf import RingVRF HERE = os.path.dirname(__file__) @@ -22,12 +24,14 @@ def test_ring_proof(): keys = RingVRF[Bandersnatch].parse_keys(bytes.fromhex(item["ring_pks"])) start_time = time.time() - ring_root = RingVRF[Bandersnatch].construct_ring_root(keys) + params = RingProofParams() + ring = Ring(keys, params) + ring_root = RingRoot.from_ring(ring, params) ring_time = time.time() print(f"\nTime taken for Ring Root Construction: \t\t {ring_time - start_time} seconds") p_k = RingVRF[Bandersnatch].get_public_key(s_k) - ring_vrf_proof = RingVRF[Bandersnatch].prove(alpha, ad, s_k, p_k, keys) + ring_vrf_proof = RingVRF[Bandersnatch].prove(alpha, ad, s_k, p_k, ring, ring_root) proof_bytes = ring_vrf_proof.to_bytes() proof_rt = RingVRF[Bandersnatch].from_bytes(proof_bytes) @@ -41,13 +45,13 @@ def test_ring_proof(): == item["gamma"] + item["proof_pk_com"] + item["proof_r"] + item["proof_ok"] + item["proof_s"] + item["proof_sb"] + item["ring_proof"] ), "Unexpected Proof" start = time.time() - assert ring_vrf_proof.verify(alpha, ad, ring_root), "Verification Failed" + assert ring_vrf_proof.verify(alpha, ad, ring, ring_root), "Verification Failed" print( "Time taken for Ring VRF Proof Verification: \t ", time.time() - start, " seconds", ) assert proof_rt.to_bytes() == proof_bytes - assert proof_rt.verify(alpha, ad, ring_root) + assert proof_rt.verify(alpha, ad, ring, ring_root) print(f"✅ Testcase {index + 1} of {os.path.basename(file_path)}") diff --git a/tests/test_vectors.py b/tests/test_vectors.py index 639eb0a..dcb30f7 100644 --- a/tests/test_vectors.py +++ b/tests/test_vectors.py @@ -7,6 +7,8 @@ from dot_ring import IETF_VRF, P256, Bandersnatch, Ed25519, PedersenVRF, RingVRF from dot_ring.ring_proof.helpers import Helpers +from dot_ring.ring_proof.params import RingProofParams +from dot_ring.vrf.ring.ring_root import Ring, RingRoot # Alias Secp256r1 = P256 @@ -269,16 +271,20 @@ def verify_ring_vector(vector: dict[str, Any], curve) -> None: for i in range(0, len(ring_pks_bytes), point_len): ring_pks.append(ring_pks_bytes[i : i + point_len]) + # Construct ring and ring root + params = RingProofParams() + ring = Ring(ring_pks, params) + ring_root = RingRoot.from_ring(ring, params) + # Generate proof - proof = RingVRF[curve].prove(alpha, ad, sk, pk, ring_pks) + proof = RingVRF[curve].prove(alpha, ad, sk, pk, ring, ring_root) # Verify output point matches gamma_bytes = proof.pedersen_proof.output_point.point_to_string() assert gamma_bytes == expected_gamma, f"gamma mismatch: expected {expected_gamma.hex()}, got {gamma_bytes.hex()}" - # Construct ring root and verify - ring_root = RingVRF[curve].construct_ring_root(ring_pks) - assert proof.verify(alpha, ad, ring_root), "Proof verification failed" + # Verify proof + assert proof.verify(alpha, ad, ring, ring_root), "Proof verification failed" # Verify output hash beta = RingVRF[curve].proof_to_hash(proof.pedersen_proof.output_point) @@ -411,16 +417,21 @@ def test_wrong_ring_root(self): alpha = b"test_input" ad = b"test_ad" + # Construct rings and ring roots + params = RingProofParams() + ring_obj1 = Ring(ring1, params) + ring_root1 = RingRoot.from_ring(ring_obj1, params) + ring_obj2 = Ring(ring2, params) + ring_root2 = RingRoot.from_ring(ring_obj2, params) + # Generate proof for ring1 - proof = RingVRF[Bandersnatch].prove(alpha, ad, sk, pk, ring1) + proof = RingVRF[Bandersnatch].prove(alpha, ad, sk, pk, ring_obj1, ring_root1) # Verify with correct ring should pass - ring_root1 = RingVRF[Bandersnatch].construct_ring_root(ring1) - assert proof.verify(alpha, ad, ring_root1) + assert proof.verify(alpha, ad, ring_obj1, ring_root1) # Verify with wrong ring should fail - ring_root2 = RingVRF[Bandersnatch].construct_ring_root(ring2) - assert not proof.verify(alpha, ad, ring_root2) + assert not proof.verify(alpha, ad, ring_obj2, ring_root2) # ============================================================================= From f3ab1ceb3d9c8afc4dc3027ba8304fd46db5f1b8 Mon Sep 17 00:00:00 2001 From: Prasad Kumkar Date: Wed, 11 Feb 2026 20:52:12 +0530 Subject: [PATCH 2/3] fix: linting errors --- tests/benchmark/test_bench_ring_proof.py | 10 +++++++++- tests/test_coverage/test_columns.py | 3 +-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/benchmark/test_bench_ring_proof.py b/tests/benchmark/test_bench_ring_proof.py index 9aa8e85..b6c26da 100644 --- a/tests/benchmark/test_bench_ring_proof.py +++ b/tests/benchmark/test_bench_ring_proof.py @@ -39,7 +39,15 @@ def ring_data(request): def test_prove(benchmark, ring_data): - benchmark(RingVRF[Bandersnatch].prove, ring_data["alpha"], ring_data["ad"], ring_data["s_k"], ring_data["p_k"], ring_data["ring"], ring_data["ring_root"]) + benchmark( + RingVRF[Bandersnatch].prove, + ring_data["alpha"], + ring_data["ad"], + ring_data["s_k"], + ring_data["p_k"], + ring_data["ring"], + ring_data["ring_root"] + ) def test_verify(benchmark, ring_data): diff --git a/tests/test_coverage/test_columns.py b/tests/test_coverage/test_columns.py index bb8f8f4..576fc33 100644 --- a/tests/test_coverage/test_columns.py +++ b/tests/test_coverage/test_columns.py @@ -1,10 +1,9 @@ import pytest -from dot_ring.curve.specs.bandersnatch import Bandersnatch from dot_ring.ring_proof.columns.columns import Column, WitnessColumnBuilder from dot_ring.ring_proof.constants import DEFAULT_SIZE, OMEGAS, S_PRIME from dot_ring.ring_proof.params import RingProofParams -from dot_ring.vrf.ring.ring_root import Ring, RingRoot +from dot_ring.vrf.ring.ring_root import Ring def test_column_interpolate_rejects_oversize_evals(): From 26a260465a667411b74b5f1f7151c84edbcd7b80 Mon Sep 17 00:00:00 2001 From: Prasad Kumkar Date: Wed, 11 Feb 2026 20:53:57 +0530 Subject: [PATCH 3/3] style: clean up whitespace and formatting across multiple files --- dot_ring/ring_proof/columns/columns.py | 1 - dot_ring/ring_proof/polynomial/ops.py | 1 - dot_ring/vrf/ring/ring_root.py | 10 +++++----- dot_ring/vrf/ring/ring_vrf.py | 16 ++++------------ tests/benchmark/bench_ring_large.py | 10 +++++----- tests/benchmark/test_bench_ring_proof.py | 14 +++++++------- 6 files changed, 21 insertions(+), 31 deletions(-) diff --git a/dot_ring/ring_proof/columns/columns.py b/dot_ring/ring_proof/columns/columns.py index 9c5bfbc..9534c06 100644 --- a/dot_ring/ring_proof/columns/columns.py +++ b/dot_ring/ring_proof/columns/columns.py @@ -42,7 +42,6 @@ def commit(self) -> None: self.commitment = KZG.commit(self.coeffs) - @dataclass(slots=True) class WitnessColumnBuilder: ring_pk: list[tuple[int, int]] diff --git a/dot_ring/ring_proof/polynomial/ops.py b/dot_ring/ring_proof/polynomial/ops.py index 451285f..534fd82 100644 --- a/dot_ring/ring_proof/polynomial/ops.py +++ b/dot_ring/ring_proof/polynomial/ops.py @@ -104,7 +104,6 @@ def poly_division_general(coeffs: list[int], domain_size: int) -> list[int]: return quotient - def poly_evaluate(poly: list | Sequence[int], xs: list | int | Sequence[int], prime: int) -> list[int] | int: """Evaluate polynomial at points xs. diff --git a/dot_ring/vrf/ring/ring_root.py b/dot_ring/vrf/ring/ring_root.py index bd6bb7f..b1ea39d 100644 --- a/dot_ring/vrf/ring/ring_root.py +++ b/dot_ring/vrf/ring/ring_root.py @@ -44,7 +44,7 @@ def __init__(self, keys: list[bytes], params: RingProofParams | None = None) -> if len(keys) > params.domain_size - params.padding_rows: raise ValueError(f"ring size {len(keys)} exceeds max supported size {params.domain_size - params.padding_rows}") - + self.nm_points = [] for key in keys: if isinstance(key, (str, bytes)): @@ -56,11 +56,11 @@ def __init__(self, keys: list[bytes], params: RingProofParams | None = None) -> else: # Handle non-string/bytes keys if necessary, or skip/raise continue - + # Pad with special point if needed while len(self.nm_points) < params.max_ring_size: self.nm_points.append(PaddingPoint) - + # Ensure ring size fill_count = params.domain_size - params.padding_rows - len(self.nm_points) if fill_count > 0: @@ -68,14 +68,14 @@ def __init__(self, keys: list[bytes], params: RingProofParams | None = None) -> self.nm_points.extend(h_vec[:fill_count]) if params.padding_rows > 0: self.nm_points.extend([(0, 0)] * params.padding_rows) - + @dataclass class RingRoot: px: Column py: Column s: Column - + @classmethod def from_ring(cls, ring: Ring, params: RingProofParams): # Px, Py, s points diff --git a/dot_ring/vrf/ring/ring_vrf.py b/dot_ring/vrf/ring/ring_vrf.py index 75787ce..585f5f9 100644 --- a/dot_ring/vrf/ring/ring_vrf.py +++ b/dot_ring/vrf/ring/ring_vrf.py @@ -216,10 +216,10 @@ def generate_bls_signature( cast(int, producer_key_point.x), cast(int, producer_key_point.y), ) - + if not ring_root: ring_root = RingRoot.from_ring(ring, params) # ring_root builder - + s_v = ring_root.s.evals producer_index = ring.nm_points.index(producer_key_pt) witness_obj = WitnessColumnBuilder.from_params( @@ -416,9 +416,7 @@ def prove( """ pedersen_proof = PedersenVRF[cast(Any, cls).cv].prove(alpha, secret_key, ad) # type: ignore[misc] - ring_proof = cls.generate_bls_signature( - pedersen_proof._blinding_factor, producer_key, ring=ring, ring_root=ring_root - ) + ring_proof = cls.generate_bls_signature(pedersen_proof._blinding_factor, producer_key, ring=ring, ring_root=ring_root) return cls(pedersen_proof, *ring_proof) @@ -434,13 +432,7 @@ def parse_keys(cls, keys: bytes) -> list[bytes]: """ return [keys[32 * i : 32 * (i + 1)] for i in range(len(keys) // 32)] - def verify( - self, - input: bytes, - ad_data: bytes, - ring: Ring, - ring_root: RingRoot - ) -> bool: + def verify(self, input: bytes, ad_data: bytes, ring: Ring, ring_root: RingRoot) -> bool: """ Verify ring VRF proof (pedersen_proof + ring_proof) """ diff --git a/tests/benchmark/bench_ring_large.py b/tests/benchmark/bench_ring_large.py index f95c28b..0676435 100644 --- a/tests/benchmark/bench_ring_large.py +++ b/tests/benchmark/bench_ring_large.py @@ -84,14 +84,14 @@ def benchmark_large_ring_proof( # ========================================================================= print("Warming up...") for i in range(warmup_iters): - print(f" Warmup iteration {i+1}/{warmup_iters}...") + print(f" Warmup iteration {i + 1}/{warmup_iters}...") # Pass ring_root to avoid rebuilding it ring = Ring(keys, params) ring_root = RingRoot.from_ring(ring, params) ring_vrf_proof = RingVRF[Bandersnatch].prove(alpha, ad, s_k, p_k, ring, ring_root) ring_vrf_proof.verify(alpha, ad, ring, ring_root) print("Warmup complete") - + # ========================================================================= # Benchmark Ring Root Construction # ========================================================================= @@ -105,13 +105,13 @@ def benchmark_large_ring_proof( root_const_times.append(elapsed) # ========================================================================= - # Benchmark Proof Generation + # Benchmark Proof Generation # ========================================================================= print("\nBenchmarking Proof Generation...") proof_times = [] proofs = [] for i in range(bench_iters): - print(f" Iteration {i+1}/{bench_iters}...") + print(f" Iteration {i + 1}/{bench_iters}...") start = time.perf_counter() # Pass ring_root to avoid rebuilding - this is the key optimization! ring_vrf_proof = RingVRF[Bandersnatch].prove(alpha, ad, s_k, p_k, ring, ring_root) @@ -125,7 +125,7 @@ def benchmark_large_ring_proof( print("\nBenchmarking Verification...") verify_times = [] for i, proof in enumerate(proofs): - print(f" Iteration {i+1}/{bench_iters}...") + print(f" Iteration {i + 1}/{bench_iters}...") start = time.perf_counter() result = proof.verify(alpha, ad, ring, ring_root) elapsed = (time.perf_counter() - start) * 1000 diff --git a/tests/benchmark/test_bench_ring_proof.py b/tests/benchmark/test_bench_ring_proof.py index b6c26da..3d6223e 100644 --- a/tests/benchmark/test_bench_ring_proof.py +++ b/tests/benchmark/test_bench_ring_proof.py @@ -40,13 +40,13 @@ def ring_data(request): def test_prove(benchmark, ring_data): benchmark( - RingVRF[Bandersnatch].prove, - ring_data["alpha"], - ring_data["ad"], - ring_data["s_k"], - ring_data["p_k"], - ring_data["ring"], - ring_data["ring_root"] + RingVRF[Bandersnatch].prove, + ring_data["alpha"], + ring_data["ad"], + ring_data["s_k"], + ring_data["p_k"], + ring_data["ring"], + ring_data["ring_root"], )