diff --git a/.typos.toml b/.typos.toml index 716e3bb44..8fab52f64 100644 --- a/.typos.toml +++ b/.typos.toml @@ -8,7 +8,6 @@ IY = "IY" IZ = "IZ" iz = "iz" anc = "anc" -Pn = "Pn" emiss = "emiss" fo = "fo" # HUGR JSON format uses "typ" as a field name for type information @@ -42,3 +41,13 @@ agger = "agger" CPY = "CPY" # Abbreviation for "undetectable" in fault enumeration debug output UNDET = "UNDET" +# Legitimate "mis-" prefix used in comments/messages (mis-ordered, +# mis-declared, mis-emitted, mis-shape, mis-count, mis-mapping, ...) +mis = "mis" +# Valid spelling of "unparseable" (used in comments + a test name) +unparseable = "unparseable" +Unparseable = "Unparseable" +# Prep-gate naming: the "Pn -> PN" entry-point rename (PNX/PNY/PNZ) +Pn = "Pn" +PN = "PN" +pn = "pn" diff --git a/crates/pecos-core/src/angle.rs b/crates/pecos-core/src/angle.rs index 64d805959..9d82d9711 100644 --- a/crates/pecos-core/src/angle.rs +++ b/crates/pecos-core/src/angle.rs @@ -148,6 +148,46 @@ where } } + /// Converts the angle to turns in `[0, 1)` (the inverse of `from_turns`). + /// + /// # Panics + /// This function will panic if the conversion of `fraction` or `max_value` to `f64` fails. + pub fn to_turns(&self) -> f64 { + let max_value = T::max_value() + .to_f64() + .expect("Failed to convert max_value to f64"); + self.fraction + .to_f64() + .expect("Failed to convert fraction to f64") + / max_value + } + + /// Converts the angle to signed turns in `(-0.5, 0.5]`. + pub fn to_turns_signed(&self) -> f64 { + let t = self.to_turns(); + if t > 0.5 { t - 1.0 } else { t } + } + + /// Converts the angle to half-turns in `[0, 2)` (π radians = 1 half-turn). + /// + /// Half-turns are the unit used by some backends (e.g. Guppy's `angle`), + /// where a full turn is `2.0`. + /// + /// # Panics + /// This function will panic if the conversion of `fraction` or `max_value` to `f64` fails. + pub fn to_half_turns(&self) -> f64 { + self.to_turns() * 2.0 + } + + /// Converts the angle to signed half-turns in `(-1, 1]`. + /// + /// Like [`Self::to_radians_signed`], this principal-value form avoids the + /// spurious global phase that the unsigned `[0, 2)` form introduces when a + /// half-angle (`θ/2`) computation crosses the 2π wrap point. + pub fn to_half_turns_signed(&self) -> f64 { + self.to_turns_signed() * 2.0 + } + /// Creates an angle from a value in radians. /// /// # Panics @@ -668,6 +708,35 @@ mod tests { use rand::RngExt; use std::f64::consts::{FRAC_PI_2, FRAC_PI_4, PI, TAU}; + #[test] + fn test_to_turns_and_half_turns() { + // Quarter turn = pi/2 rad = 0.25 turns = 0.5 half-turns. + let q = Angle64::QUARTER_TURN; + assert!((q.to_turns() - 0.25).abs() < 1e-12); + assert!((q.to_half_turns() - 0.5).abs() < 1e-12); + assert!((q.to_turns_signed() - 0.25).abs() < 1e-12); + assert!((q.to_half_turns_signed() - 0.5).abs() < 1e-12); + + // Half-turn boundary: to_turns stays unsigned (0.5), signed maps to 0.5. + let h = Angle64::HALF_TURN; + assert!((h.to_turns() - 0.5).abs() < 1e-12); + assert!((h.to_half_turns() - 1.0).abs() < 1e-12); + + // Three-quarter turn: unsigned 0.75 turns; signed wraps to -0.25 turns + // (-0.5 half-turns), mirroring to_radians_signed. + let tq = Angle64::THREE_QUARTERS_TURN; + assert!((tq.to_turns() - 0.75).abs() < 1e-12); + assert!((tq.to_turns_signed() - (-0.25)).abs() < 1e-12); + assert!((tq.to_half_turns_signed() - (-0.5)).abs() < 1e-12); + + // half-turns == radians / pi for the unsigned form. + let a = Angle64::from_radians(0.7); + assert!((a.to_half_turns() - a.to_radians() / PI).abs() < 1e-9); + // round-trip turns. + let b = Angle64::from_turns(0.3); + assert!((b.to_turns() - 0.3).abs() < 1e-9); + } + // Basic Construction and Properties #[test] fn test_constructors() { diff --git a/crates/pecos-core/src/gate_type.rs b/crates/pecos-core/src/gate_type.rs index 2c2c94386..63c532ca8 100644 --- a/crates/pecos-core/src/gate_type.rs +++ b/crates/pecos-core/src/gate_type.rs @@ -97,12 +97,12 @@ pub enum GateType { // TODO: MPauli instead of the other variants? // PX = 130 - // PnX = 131 + // PNX = 131 // PY = 132 - // PnY = 133 + // PNY = 133 // PZ = 134 PZ = 134, - // PnZ + // PNZ /// Allocate a qubit in the |0⟩ state QAlloc = 135, /// Free/deallocate a qubit diff --git a/crates/pecos-llvm/src/llvm_compat.rs b/crates/pecos-llvm/src/llvm_compat.rs index 483a2469d..90ad38f80 100644 --- a/crates/pecos-llvm/src/llvm_compat.rs +++ b/crates/pecos-llvm/src/llvm_compat.rs @@ -226,7 +226,7 @@ impl<'ctx> LLFunctionType<'ctx> { } /// Wrapper for LLVM types that mirrors llvmlite's type hierarchy -#[derive(Clone, Copy)] +#[derive(Clone, Copy, PartialEq, Eq)] pub enum LLType<'ctx> { Void, Int(IntType<'ctx>), @@ -236,6 +236,38 @@ pub enum LLType<'ctx> { Array(ArrayType<'ctx>), } +// inkwell 0.8.0 only derives `Hash` for `IntType`; the other type wrappers +// are `Eq` (LLVM type-ref pointer equality) but not `Hash`. Hash the same +// `LLVMTypeRef` pointer so `Hash` stays consistent with that `Eq`. +impl std::hash::Hash for LLType<'_> { + fn hash(&self, state: &mut H) { + use inkwell::types::AsTypeRef; + match self { + LLType::Void => 0u8.hash(state), + LLType::Int(t) => { + 1u8.hash(state); + (t.as_type_ref() as usize).hash(state); + } + LLType::Float(t) => { + 2u8.hash(state); + (t.as_type_ref() as usize).hash(state); + } + LLType::Pointer(t) => { + 3u8.hash(state); + (t.as_type_ref() as usize).hash(state); + } + LLType::Struct(t) => { + 4u8.hash(state); + (t.as_type_ref() as usize).hash(state); + } + LLType::Array(t) => { + 5u8.hash(state); + (t.as_type_ref() as usize).hash(state); + } + } + } +} + impl<'ctx> LLType<'ctx> { /// Create void type #[must_use] @@ -717,6 +749,106 @@ impl<'ctx> LLIRBuilder<'ctx> { Ok(LLValue::Pointer(result)) } } + + // ======================================================================== + // Memory ops + casts (unblocks the standard CReg model) + // ======================================================================== + + /// `alloca ` -- stack slot. Caller positions the builder (B2 + /// places `CReg` buffers in the entry block via `position_at_end`). + pub fn alloca(&self, ll_type: LLType<'ctx>, name: &str) -> LLResult> { + let basic_ty = ll_type + .to_basic_metadata_type() + .ok_or_else(|| PecosError::Generic("Cannot alloca a void type".into()))?; + let result = self + .builder + .build_alloca(basic_ty, name) + .map_err(|e| PecosError::Generic(format!("Failed to build alloca: {e}")))?; + Ok(LLValue::Pointer(result)) + } + + /// `load` (LLVM-14 typed pointer: pointee inferred from `ptr`). + pub fn load(&self, ptr: LLValue<'ctx>, name: &str) -> LLResult> { + let result = self + .builder + .build_load(ptr.as_pointer_value(), name) + .map_err(|e| PecosError::Generic(format!("Failed to build load: {e}")))?; + Ok(match result { + BasicValueEnum::IntValue(v) => LLValue::Int(v), + BasicValueEnum::FloatValue(v) => LLValue::Float(v), + BasicValueEnum::PointerValue(v) => LLValue::Pointer(v), + BasicValueEnum::ArrayValue(v) => LLValue::Array(v), + other => { + return Err(PecosError::Generic(format!( + "load: unsupported loaded value type: {other:?}" + ))); + } + }) + } + + /// `store` -- discards inkwell's returned pointer (Python `-> None`). + pub fn store(&self, ptr: LLValue<'ctx>, value: LLValue<'ctx>) -> LLResult<()> { + self.builder + .build_store(ptr.as_pointer_value(), value.to_basic_value()) + .map_err(|e| PecosError::Generic(format!("Failed to build store: {e}")))?; + Ok(()) + } + + /// `zext` int value to a wider int type. + pub fn zext( + &self, + value: LLValue<'ctx>, + dest_type: LLType<'ctx>, + name: &str, + ) -> LLResult> { + let result = self + .builder + .build_int_z_extend(value.as_int_value(), dest_type.as_int_type(), name) + .map_err(|e| PecosError::Generic(format!("Failed to build zext: {e}")))?; + Ok(LLValue::Int(result)) + } + + /// `trunc` int value to a narrower int type. + pub fn trunc( + &self, + value: LLValue<'ctx>, + dest_type: LLType<'ctx>, + name: &str, + ) -> LLResult> { + let result = self + .builder + .build_int_truncate(value.as_int_value(), dest_type.as_int_type(), name) + .map_err(|e| PecosError::Generic(format!("Failed to build trunc: {e}")))?; + Ok(LLValue::Int(result)) + } + + /// Unsigned integer comparison (mirrors `icmp_signed` with U-predicates). + pub fn icmp_unsigned( + &self, + op: &str, + lhs: LLValue<'ctx>, + rhs: LLValue<'ctx>, + name: &str, + ) -> LLResult> { + let predicate = match op { + "==" => IntPredicate::EQ, + "!=" => IntPredicate::NE, + "<" => IntPredicate::ULT, + ">" => IntPredicate::UGT, + "<=" => IntPredicate::ULE, + ">=" => IntPredicate::UGE, + _ => { + return Err(PecosError::Generic(format!( + "Unknown comparison operator: {op}" + ))); + } + }; + let result = self + .builder + .build_int_compare(predicate, lhs.as_int_value(), rhs.as_int_value(), name) + .map_err(|e| PecosError::Generic(format!("Failed to build icmp: {e}")))?; + Ok(LLValue::Int(result)) + } } // ============================================================================ @@ -766,4 +898,18 @@ impl LLConstant { )), } } + + /// Zero/`zeroinitializer` constant of `ll_type` (backs + /// `Constant(ty, None)`; Array -> `zeroinitializer`, Int -> `iN 0`). + pub fn zero(ll_type: LLType<'_>) -> LLResult> { + match ll_type { + LLType::Int(t) => Ok(LLValue::Int(t.const_zero())), + LLType::Float(t) => Ok(LLValue::Float(t.const_zero())), + LLType::Pointer(t) => Ok(LLValue::Pointer(t.const_zero())), + LLType::Array(t) => Ok(LLValue::Array(t.const_zero())), + LLType::Void | LLType::Struct(_) => Err(PecosError::Generic( + "Cannot create a zero constant for void/struct type".to_string(), + )), + } + } } diff --git a/crates/pecos-simulators/src/arbitrary_rotation_gateable.rs b/crates/pecos-simulators/src/arbitrary_rotation_gateable.rs index b38959d62..1ded50a7c 100644 --- a/crates/pecos-simulators/src/arbitrary_rotation_gateable.rs +++ b/crates/pecos-simulators/src/arbitrary_rotation_gateable.rs @@ -254,6 +254,76 @@ pub trait ArbitraryRotationGateable: CliffordGateable { self.rxx(theta, pairs).ryy(phi, pairs).rzz(lambda, pairs) } + /// Applies a controlled-RZ rotation: target qubit gets RZ(theta) when control = |1>. + /// + /// `CRZ(theta) = block-diag(I, RZ(theta)) = diag(1, 1, exp(-i*theta/2), exp(i*theta/2))`. + /// + /// Default 2q-minimal decomposition (1 RZZ + 1 single-qubit RZ on the + /// target): `CRZ(theta) = (I o RZ(theta/2)) . RZZ(-theta/2)`. + /// Verified: with the trait's `RZ = exp(-i*theta/2*Z)` and `RZZ = + /// exp(-i*theta/2*Z*Z)` conventions, the product on the c=0 sector + /// gives `RZ(theta/2) . exp(i*theta/4*I) = I` up to global phase, and + /// on c=1 (where ZZ acts as -Z on target) gives `RZ(theta/2) . X . + /// RZ(theta/2) . X = RZ(theta)` -- i.e. the convention-1 controlled + /// rotation. The non-PECOS-prefactor convention requires no extra + /// RZ on the control. + /// + /// # Parameters + /// - `theta`: The rotation angle on the target. + /// - `pairs`: Pairs of qubit indices `[(control, target), ...]`. + /// + /// # Returns + /// A mutable reference to `Self` for method chaining. + #[inline] + fn crz(&mut self, theta: Angle64, pairs: &[(QubitId, QubitId)]) -> &mut Self { + // Half-angle first, THEN negate -- `Angle` is a wrapping fraction + // of a full turn (modulo 2pi), so `-theta / 2` would halve the wrapped + // 2*pi - theta and produce pi - theta/2, not -theta/2. + let half = theta / 2u64; + let targets: QubitBuf = pairs.iter().map(|&(_, t)| t).collect(); + self.rzz(-half, pairs).rz(half, &targets) + } + + /// Applies a controlled-RX rotation: target qubit gets RX(theta) when control = |1>. + /// + /// Default decomposition: `CRX(theta) = (I o H) . CRZ(theta) . (I o H)`, + /// using `H.Z.H = X` so the c=1 sector applies `H.RZ(theta).H = RX(theta)`. + /// Same 2q cost as `crz` (1 RZZ). + /// + /// # Parameters + /// - `theta`: The rotation angle on the target. + /// - `pairs`: Pairs of qubit indices `[(control, target), ...]`. + /// + /// # Returns + /// A mutable reference to `Self` for method chaining. + #[inline] + fn crx(&mut self, theta: Angle64, pairs: &[(QubitId, QubitId)]) -> &mut Self { + let targets: QubitBuf = pairs.iter().map(|&(_, t)| t).collect(); + self.h(&targets).crz(theta, pairs).h(&targets) + } + + /// Applies a controlled-RY rotation: target qubit gets RY(theta) when control = |1>. + /// + /// Default decomposition: `CRY(theta) = (I o S.H) . CRZ(theta) . (I o H.Sdg)`, + /// using `S.X.Sdg = Y` (so `S.Rx.Sdg = Ry`) and `H.Rz.H = Rx`, giving + /// `S.H.RZ.H.Sdg = RY`. Same 2q cost as `crz` (1 RZZ). + /// + /// # Parameters + /// - `theta`: The rotation angle on the target. + /// - `pairs`: Pairs of qubit indices `[(control, target), ...]`. + /// + /// # Returns + /// A mutable reference to `Self` for method chaining. + #[inline] + fn cry(&mut self, theta: Angle64, pairs: &[(QubitId, QubitId)]) -> &mut Self { + let targets: QubitBuf = pairs.iter().map(|&(_, t)| t).collect(); + self.szdg(&targets) + .h(&targets) + .crz(theta, pairs) + .h(&targets) + .sz(&targets) + } + /// Applies a general 2-qubit unitary via KAK decomposition: /// U = (U3(before[0]) x U3(before[1])) * RXXRYYRZZ(interaction) * (U3(after[0]) x U3(after[1])) /// diff --git a/crates/pecos-simulators/tests/test_state_vec.rs b/crates/pecos-simulators/tests/test_state_vec.rs index ed3dc7aa9..8f9fa71b3 100644 --- a/crates/pecos-simulators/tests/test_state_vec.rs +++ b/crates/pecos-simulators/tests/test_state_vec.rs @@ -2405,7 +2405,7 @@ mod detailed_tq_gate_cases { use num_complex::Complex64; use pecos_core::{Angle64, QubitId}; use pecos_simulators::{ArbitraryRotationGateable, CliffordGateable, StateVec, qid}; - use std::f64::consts::{FRAC_PI_2, FRAC_PI_3, FRAC_PI_4, PI}; + use std::f64::consts::{FRAC_PI_2, FRAC_PI_3, FRAC_PI_4, FRAC_PI_6, PI}; #[test] fn test_cx_decomposition() { @@ -2484,6 +2484,64 @@ mod detailed_tq_gate_cases { assert_states_equal(state_rxx.state(), state_decomposed.state()); } + #[test] + fn test_crz_decomposition() { + // Cross-codegen: default crz impl is (I o RZ(theta/2)) . RZZ(-theta/2). + // Verify the 1-RZZ default matches direct controlled-RZ semantics: + // c=0: identity on target (no rotation applied) + // c=1: target gets RZ(theta) + let theta = Angle64::from_radians(FRAC_PI_3); + + // c=0: |00> stays |00> (no rotation when control = 0). + let mut crz_c0 = StateVec::new(2); + let mut baseline_c0 = StateVec::new(2); + crz_c0.crz(theta, &[(QubitId(0), QubitId(1))]); + assert_states_equal(crz_c0.state(), baseline_c0.state()); + + // c=1: |10> -> |1, RZ(theta)|0>> -- target receives exactly RZ(theta). + let mut crz_c1 = StateVec::new(2); + let mut direct_rz = StateVec::new(2); + crz_c1.x(&qid(0)).crz(theta, &[(QubitId(0), QubitId(1))]); + direct_rz.x(&qid(0)).rz(theta, &qid(1)); + assert_states_equal(crz_c1.state(), direct_rz.state()); + } + + #[test] + fn test_crx_decomposition() { + // Cross-codegen: default crx impl is (I o H) . CRZ . (I o H). + let theta = Angle64::from_radians(FRAC_PI_4); + + let mut crx_c0 = StateVec::new(2); + let mut baseline_c0 = StateVec::new(2); + crx_c0.crx(theta, &[(QubitId(0), QubitId(1))]); + assert_states_equal(crx_c0.state(), baseline_c0.state()); + + // c=1: target gets RX(theta). + let mut crx_c1 = StateVec::new(2); + let mut direct_rx = StateVec::new(2); + crx_c1.x(&qid(0)).crx(theta, &[(QubitId(0), QubitId(1))]); + direct_rx.x(&qid(0)).rx(theta, &qid(1)); + assert_states_equal(crx_c1.state(), direct_rx.state()); + } + + #[test] + fn test_cry_decomposition() { + // Cross-codegen: default cry impl is (I o S.H) . CRZ . (I o H.Sdg). + let theta = Angle64::from_radians(FRAC_PI_6); + + let mut cry_c0 = StateVec::new(2); + let mut baseline_c0 = StateVec::new(2); + cry_c0.cry(theta, &[(QubitId(0), QubitId(1))]); + assert_states_equal(cry_c0.state(), baseline_c0.state()); + + // c=1: target gets RY(theta). + let mut cry_c1 = StateVec::new(2); + let mut direct_ry = StateVec::new(2); + cry_c1.x(&qid(0)).cry(theta, &[(QubitId(0), QubitId(1))]); + direct_ry.x(&qid(0)).ry(theta, &qid(1)); + assert_states_equal(cry_c1.state(), direct_ry.state()); + } + #[test] fn test_two_qubit_unitary_swap_simple() { let mut state_vec = StateVec::new(2); diff --git a/docs/development/parallel-blocks-and-optimization.md b/docs/development/parallel-blocks-and-optimization.md index e345002ed..1fb774022 100644 --- a/docs/development/parallel-blocks-and-optimization.md +++ b/docs/development/parallel-blocks-and-optimization.md @@ -271,7 +271,7 @@ Here's a more complex example showing parallel phase gates: ```python import numpy as np -from pecos.slr import Main, Parallel, QReg +from pecos.slr import Main, Parallel, QReg, rad from pecos.slr.qeclib import qubit as qb @@ -280,7 +280,7 @@ def qft_layer(q, n, k): operations = [] for j in range(k + 1, n): angle = np.pi / (2 ** (j - k)) - operations.append(qb.CRZ[angle](q[j], q[k])) + operations.append(qb.CRZ(rad(angle), q[j], q[k])) return Parallel(*operations) if len(operations) > 1 else operations[0] diff --git a/docs/development/slr-qeclib.md b/docs/development/slr-qeclib.md index 6e4603e3f..0e23d79ec 100644 --- a/docs/development/slr-qeclib.md +++ b/docs/development/slr-qeclib.md @@ -97,10 +97,10 @@ prog = Main( c := CReg("c", 4), # Initialization block Block( - qb.Prep(q[0], "Z"), - qb.Prep(q[1], "Z"), - qb.Prep(q[2], "X"), - qb.Prep(q[3], "X"), + qb.PZ(q[0]), # |0> + qb.PZ(q[1]), # |0> + qb.PX(q[2]), # |+> + qb.PX(q[3]), # |+> ), # Entanglement block Block( @@ -223,7 +223,7 @@ The `qeclib` module provides quantum operations organized by category: ### Qubit Operations (`pecos.slr.qeclib.qubit`) ```python -from pecos.slr import Main, QReg, CReg +from pecos.slr import Main, QReg, CReg, rad from pecos.slr.qeclib import qubit as qb prog = Main( @@ -240,18 +240,18 @@ prog = Main( qb.SZdg(q[0]), # S dagger qb.T(q[0]), # T gate qb.Tdg(q[0]), # T dagger - # Rotations (angle in radians) - qb.RX(q[0], 0.5), - qb.RY(q[0], 0.5), - qb.RZ(q[0], 0.5), + # Rotations (typed angle: rad(...) / turns(...)) + qb.RX(rad(0.5), q[0]), + qb.RY(rad(0.5), q[0]), + qb.RZ(rad(0.5), q[0]), # Two-qubit gates qb.CX(q[0], q[1]), # CNOT qb.CY(q[0], q[1]), qb.CZ(q[0], q[1]), # Measurements and preparations qb.Measure(q[0]) > c[0], - qb.Prep(q[0], "Z"), # Prepare |0> - qb.Prep(q[0], "X"), # Prepare |+> + qb.PZ(q[0]), # Prepare |0> + qb.PX(q[0]), # Prepare |+> ) ``` @@ -429,11 +429,10 @@ def surface_code_syndrome(d: int): ancilla := QReg("anc", num_ancilla), syn := CReg("syn", num_ancilla), # Initialize data qubits - Block(*[qb.Prep(data[i], "Z") for i in range(num_data)]), + Block(*[qb.PZ(data[i]) for i in range(num_data)]), # X stabilizer measurement (simplified) Block( - qb.Prep(ancilla[0], "X"), # Prepare |+> - qb.H(ancilla[0]), + qb.PX(ancilla[0]), # |+> qb.CX(ancilla[0], data[0]), qb.CX(ancilla[0], data[1]), qb.H(ancilla[0]), @@ -441,7 +440,7 @@ def surface_code_syndrome(d: int): ), # Z stabilizer measurement (simplified) Block( - qb.Prep(ancilla[1], "Z"), # Prepare |0> + qb.PZ(ancilla[1]), # Prepare |0> qb.CX(data[0], ancilla[1]), qb.CX(data[3], ancilla[1]), qb.Measure(ancilla[1]) > syn[1], diff --git a/docs/user-guide/parallel-execution.md b/docs/user-guide/parallel-execution.md index 418f02789..e045e5d15 100644 --- a/docs/user-guide/parallel-execution.md +++ b/docs/user-guide/parallel-execution.md @@ -160,7 +160,7 @@ Parallel( Here's a complete example showing parallel quantum phase estimation: ```python -from pecos.slr import Main, Parallel, Block, QReg, CReg, SlrConverter +from pecos.slr import Main, Parallel, Block, QReg, CReg, SlrConverter, rad from pecos.slr.qeclib import qubit as qb import numpy as np @@ -175,15 +175,15 @@ prog = Main( qb.H(q[2]), ), # Apply controlled rotations - qb.CRZ[np.pi](q[0], q[3]), - qb.CRZ[np.pi / 2](q[1], q[3]), - qb.CRZ[np.pi / 4](q[2], q[3]), + qb.CRZ(rad(np.pi), q[0], q[3]), + qb.CRZ(rad(np.pi / 2), q[1], q[3]), + qb.CRZ(rad(np.pi / 4), q[2], q[3]), # Inverse QFT on ancillas qb.H(q[0]), - qb.CRZ[-np.pi / 2](q[0], q[1]), + qb.CRZ(rad(-np.pi / 2), q[0], q[1]), qb.H(q[1]), - qb.CRZ[-np.pi / 4](q[0], q[2]), - qb.CRZ[-np.pi / 2](q[1], q[2]), + qb.CRZ(rad(-np.pi / 4), q[0], q[2]), + qb.CRZ(rad(-np.pi / 2), q[1], q[2]), qb.H(q[2]), # Measure ancillas Parallel( diff --git a/examples/python_examples/logical_steane_code_program.py b/examples/python_examples/logical_steane_code_program.py index a9c6520c8..bb1a758da 100644 --- a/examples/python_examples/logical_steane_code_program.py +++ b/examples/python_examples/logical_steane_code_program.py @@ -16,8 +16,8 @@ computation with error correction circuits. """ -from pecos.qeclib.steane.steane_class import Steane -from pecos.slr import Barrier, CReg, If, Main +from pecos.slr import Barrier, CReg, If, Main, Return, SlrConverter +from pecos.slr.qeclib.steane.steane_class import Steane # Turn of Black's formatting to allow for newline spacing below: # fmt: off @@ -73,9 +73,12 @@ def telep(prep_basis: str, meas_basis: str) -> str: # Final output stored in `m_out[0]` sout.m(meas_basis, m_out[0]), + + # Expose classical results explicitly (Phase 3b: no implicit return). + Return(m_bell, m_out), ) - return prog.qasm() # Convert the program to extended OpenQASM 2.0 + return SlrConverter(prog).qasm() # Convert the program to extended OpenQASM 2.0 def t_gate(prep_basis: str, meas_basis: str) -> str: @@ -121,8 +124,11 @@ def t_gate(prep_basis: str, meas_basis: str) -> str: # Final output stored in `m_out[1]` sin.m(meas_basis, m_out[1]), + + # Expose classical results explicitly (Phase 3b: no implicit return). + Return(m_reject, m_t, m_out), ) - return prog.qasm() + return SlrConverter(prog).qasm() # fmt: on diff --git a/examples/surface/build_report.py b/examples/surface/build_report.py index ea392b7b6..71b71fae2 100644 --- a/examples/surface/build_report.py +++ b/examples/surface/build_report.py @@ -661,7 +661,8 @@ def _build_html(analysis: dict) -> str: tables = analysis.get("comparison_tables", []) curves = analysis.get("threshold_curves", []) - style = dedent(""" + style = dedent( + """ :root { color-scheme: light dark; --bg: #f8fafc; --fg: #0f172a; @@ -762,7 +763,8 @@ def _build_html(analysis: dict) -> str: details.collapsible > summary::before { content: "\\25B6 "; font-size: 0.8em; } details.collapsible[open] > summary::before { content: "\\25BC "; } details.collapsible > .section { margin-top: 8px; } - """).strip() + """, + ).strip() def meta_card(label: str, value: str, *, raw: bool = False) -> str: val = value if raw else html_mod.escape(value) @@ -1051,7 +1053,8 @@ def _violin_plots_for_p(p_val: float) -> list[str]: parts.extend( [ "", - dedent(""" + dedent( + """ - """).strip(), + """, + ).strip(), "", "", ], diff --git a/examples/surface/decoder_comparison.py b/examples/surface/decoder_comparison.py index d1b9ea91d..926b5de7b 100644 --- a/examples/surface/decoder_comparison.py +++ b/examples/surface/decoder_comparison.py @@ -268,7 +268,8 @@ def write_json(path: Path, points: list[ComparisonPoint], config: dict) -> None: def write_html(path: Path, points: list[ComparisonPoint], config: dict) -> None: """Write an HTML report with comparison tables.""" - style = dedent(""" + style = dedent( + """ :root { color-scheme: light dark; --bg: #f8fafc; --fg: #0f172a; @@ -351,7 +352,8 @@ def write_html(path: Path, points: list[ComparisonPoint], config: dict) -> None: td:first-child, th:first-child { text-align: left; } tr:nth-child(even) td { background: var(--table-stripe); } code { font-family: ui-monospace, SFMono-Regular, Menlo, monospace; } - """).strip() + """, + ).strip() def meta_card(label: str, value: str) -> str: return f'
{html.escape(label)}{html.escape(value)}
' @@ -435,7 +437,8 @@ def meta_card(label: str, value: str) -> str: parts.extend( [ "", - dedent(""" + dedent( + """ - """).strip(), + """, + ).strip(), "", "", ], diff --git a/examples/surface_code_slr_exploration.ipynb b/examples/surface_code_slr_exploration.ipynb index 4af7f0ec4..b793de75d 100644 --- a/examples/surface_code_slr_exploration.ipynb +++ b/examples/surface_code_slr_exploration.ipynb @@ -26,20 +26,7 @@ "id": "imports", "metadata": {}, "outputs": [], - "source": [ - "# Our circuit builder for comparison\n", - "from pecos.qec.surface import (\n", - " SurfacePatch,\n", - " generate_guppy_from_patch,\n", - " generate_stim_from_patch,\n", - " generate_tick_circuit_from_patch,\n", - ")\n", - "from pecos.qec.surface.schedule import compute_cnot_schedule\n", - "from pecos.slr import Barrier, Main, QReg\n", - "from pecos.slr.gen_codes.gen_stim import StimGenerator\n", - "from pecos.slr.gen_codes.guppy import IRGuppyGenerator\n", - "from pecos.slr.qeclib.qubit import CX, H, Measure, Prep" - ] + "source": "# Our circuit builder for comparison\nfrom pecos.qec.surface import (\n SurfacePatch,\n generate_guppy_from_patch,\n generate_stim_from_patch,\n generate_tick_circuit_from_patch,\n)\nfrom pecos.qec.surface.schedule import compute_cnot_schedule\nfrom pecos.slr import Barrier, Main, QReg, SlrConverter\nfrom pecos.slr.gen_codes.gen_stim import StimGenerator\nfrom pecos.slr.qeclib.qubit import CX, H, Measure, Prep" }, { "cell_type": "markdown", @@ -234,7 +221,7 @@ "id": "slr-to-guppy", "metadata": {}, "outputs": [], - "source": "# Try generating Guppy from SLR\ntry:\n guppy_gen = IRGuppyGenerator()\n guppy_gen.generate_block(slr_prog)\n slr_guppy = guppy_gen.get_output()\n print(\"=== SLR -> Guppy ===\")\n print(slr_guppy[:2000] if len(slr_guppy) > 2000 else slr_guppy)\n if len(slr_guppy) > 2000:\n print(f\"... ({len(slr_guppy) - 2000} more characters)\")\nexcept (ValueError, TypeError, AttributeError) as e:\n print(f\"Guppy generation failed: {e}\")\n slr_guppy = None" + "source": "# Generate Guppy from SLR via the AST -> Guppy path (SlrConverter.guppy()).\n# (The legacy IRGuppyGenerator was removed in the v1 cutover; SlrConverter\n# is now the supported entrypoint for all codegens.)\ntry:\n slr_guppy = SlrConverter(slr_prog).guppy()\n print(\"=== SLR -> Guppy ===\")\n print(slr_guppy[:2000] if len(slr_guppy) > 2000 else slr_guppy)\n if len(slr_guppy) > 2000:\n print(f\"... ({len(slr_guppy) - 2000} more characters)\")\nexcept (ValueError, TypeError, AttributeError) as e:\n print(f\"Guppy generation failed: {e}\")\n slr_guppy = None" }, { "cell_type": "code", @@ -312,57 +299,7 @@ "id": "gap-analysis", "metadata": {}, "outputs": [], - "source": [ - "gaps = [\n", - " {\n", - " \"feature\": \"TickCircuit Generator\",\n", - " \"status\": \"Missing\",\n", - " \"description\": \"SLR has no TickCircuitGenerator. Would need to add one similar to StimGenerator.\",\n", - " \"effort\": \"Medium\",\n", - " },\n", - " {\n", - " \"feature\": \"Detector Annotations\",\n", - " \"status\": \"Missing\",\n", - " \"description\": \"SLR has no concept of detectors/observables for DEM generation. \"\n", - " \"Would need to add metadata support or a new annotation type.\",\n", - " \"effort\": \"Medium-High\",\n", - " },\n", - " {\n", - " \"feature\": \"Tick/Phase Metadata\",\n", - " \"status\": \"Partial\",\n", - " \"description\": \"SLR has Barrier() which maps to TICK, but no semantic phase annotations \"\n", - " \"(prep, syndrome, measure). Could add via Comments or new metadata.\",\n", - " \"effort\": \"Low\",\n", - " },\n", - " {\n", - " \"feature\": \"Gate-level Metadata\",\n", - " \"status\": \"Missing\",\n", - " \"description\": \"No way to annotate individual gates with labels (e.g., 'sx0', 'Z2').\",\n", - " \"effort\": \"Low-Medium\",\n", - " },\n", - " {\n", - " \"feature\": \"Stim Generation\",\n", - " \"status\": \"Working\",\n", - " \"description\": \"StimGenerator produces valid Stim circuits. Gate counts match.\",\n", - " \"effort\": \"N/A\",\n", - " },\n", - " {\n", - " \"feature\": \"Guppy Generation\",\n", - " \"status\": \"Needs Testing\",\n", - " \"description\": \"IRGuppyGenerator exists but may need updates for surface code patterns.\",\n", - " \"effort\": \"TBD\",\n", - " },\n", - "]\n", - "\n", - "print(\"=== SLR Enhancement Gap Analysis ===\")\n", - "print()\n", - "for gap in gaps:\n", - " print(f\"Feature: {gap['feature']}\")\n", - " print(f\" Status: {gap['status']}\")\n", - " print(f\" Description: {gap['description']}\")\n", - " print(f\" Effort: {gap['effort']}\")\n", - " print()" - ] + "source": "gaps = [\n {\n \"feature\": \"TickCircuit Generator\",\n \"status\": \"Missing\",\n \"description\": \"SLR has no TickCircuitGenerator. Would need to add one similar to StimGenerator.\",\n \"effort\": \"Medium\",\n },\n {\n \"feature\": \"Detector Annotations\",\n \"status\": \"Missing\",\n \"description\": \"SLR has no concept of detectors/observables for DEM generation. \"\n \"Would need to add metadata support or a new annotation type.\",\n \"effort\": \"Medium-High\",\n },\n {\n \"feature\": \"Tick/Phase Metadata\",\n \"status\": \"Partial\",\n \"description\": \"SLR has Barrier() which maps to TICK, but no semantic phase annotations \"\n \"(prep, syndrome, measure). Could add via Comments or new metadata.\",\n \"effort\": \"Low\",\n },\n {\n \"feature\": \"Gate-level Metadata\",\n \"status\": \"Missing\",\n \"description\": \"No way to annotate individual gates with labels (e.g., 'sx0', 'Z2').\",\n \"effort\": \"Low-Medium\",\n },\n {\n \"feature\": \"Stim Generation\",\n \"status\": \"Working\",\n \"description\": \"StimGenerator produces valid Stim circuits. Gate counts match.\",\n \"effort\": \"N/A\",\n },\n {\n \"feature\": \"Guppy Generation\",\n \"status\": \"Working\",\n \"description\": \"SlrConverter(prog).guppy() (AST -> Guppy path) is the supported \"\n \"route; the legacy IRGuppyGenerator was removed in the v1 cutover.\",\n \"effort\": \"N/A\",\n },\n]\n\nprint(\"=== SLR Enhancement Gap Analysis ===\")\nprint()\nfor gap in gaps:\n print(f\"Feature: {gap['feature']}\")\n print(f\" Status: {gap['status']}\")\n print(f\" Description: {gap['description']}\")\n print(f\" Effort: {gap['effort']}\")\n print()" }, { "cell_type": "markdown", diff --git a/pyproject.toml b/pyproject.toml index 9d1575279..f7138dca9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ test = [ # exact pins so workspace tests run against a reproducible environment "hypothesis==6.152.1", "stim==1.15.0", # Stim-comparison and decomposition-invariant tests "matplotlib>=2.2.0", # Surface-patch render tests import matplotlib directly + "qir-qis==0.1.8", # QIR spec-compliance validation gate over the audit corpus (#71 Stage A) ] examples = [ # extras used by the examples/ tree and notebook walkthroughs "jupyter>=1.1.1", diff --git a/python/pecos-rslib-cuda/src/lib.rs b/python/pecos-rslib-cuda/src/lib.rs index e68c01c60..be572de4d 100644 --- a/python/pecos-rslib-cuda/src/lib.rs +++ b/python/pecos-rslib-cuda/src/lib.rs @@ -392,6 +392,33 @@ impl PyCuStateVec { self.inner.rzz(Angle64::from_radians(angle), &pairs); } + /// Apply controlled-RX gate. Pairs of qubits = (control, target). + fn crx(&mut self, angle: f64, qubits: Vec) { + let pairs: Vec<(QubitId, QubitId)> = qubits + .chunks_exact(2) + .map(|c| (QubitId(c[0]), QubitId(c[1]))) + .collect(); + self.inner.crx(Angle64::from_radians(angle), &pairs); + } + + /// Apply controlled-RY gate. Pairs of qubits = (control, target). + fn cry(&mut self, angle: f64, qubits: Vec) { + let pairs: Vec<(QubitId, QubitId)> = qubits + .chunks_exact(2) + .map(|c| (QubitId(c[0]), QubitId(c[1]))) + .collect(); + self.inner.cry(Angle64::from_radians(angle), &pairs); + } + + /// Apply controlled-RZ gate. Pairs of qubits = (control, target). + fn crz(&mut self, angle: f64, qubits: Vec) { + let pairs: Vec<(QubitId, QubitId)> = qubits + .chunks_exact(2) + .map(|c| (QubitId(c[0]), QubitId(c[1]))) + .collect(); + self.inner.crz(Angle64::from_radians(angle), &pairs); + } + /// Apply U gate (general single-qubit rotation). fn u(&mut self, theta: f64, phi: f64, lambda: f64, qubits: Vec) { let qubits: Vec = qubits.into_iter().map(QubitId).collect(); diff --git a/python/pecos-rslib-llvm/src/llvm_bindings.rs b/python/pecos-rslib-llvm/src/llvm_bindings.rs index f295e1ad8..4cf3cefbf 100644 --- a/python/pecos-rslib-llvm/src/llvm_bindings.rs +++ b/python/pecos-rslib-llvm/src/llvm_bindings.rs @@ -31,6 +31,7 @@ use pecos_llvm::prelude::*; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; +use pyo3::pyclass::CompareOp; use regex::Regex; use std::collections::HashMap; use std::sync::{Mutex, OnceLock}; @@ -394,7 +395,7 @@ impl PyModuleContext { // --- Type Classes --- /// Enum to handle any type for function parameters -#[derive(Copy, Clone, FromPyObject)] +#[derive(Copy, Clone, FromPyObject, IntoPyObject)] pub enum PyAnyType { Int(PyIntType), Double(PyDoubleType), @@ -415,6 +416,50 @@ impl PyAnyType { PyAnyType::Array(t) => t.ll_type, } } + + /// Underlying `LLType` (context-free; `to_ll_type` ignores its arg). + fn ll_type(&self) -> LLType<'static> { + match self { + PyAnyType::Int(t) => t.ll_type, + PyAnyType::Double(t) => t.ll_type, + PyAnyType::Void(_) => LLType::Void, + PyAnyType::Pointer(t) => t.ll_type, + PyAnyType::Struct(t) => LLType::Struct(t.struct_type), + PyAnyType::Array(t) => t.ll_type, + } + } +} + +/// Type equality by LLVM `LLVMTypeRef` identity within one `Context` +/// (sufficient for #71 B2: every type check happens inside one module). +/// `Eq`/`Ne` only; non-type `other` -> not-equal (`==` False, `!=` True). +/// Ordering ops are unsupported -> `NotImplemented` so Python raises +/// `TypeError` rather than returning a silently-wrong `False`. +fn lltype_richcmp( + py: Python<'_>, + a: LLType<'static>, + other: &Bound<'_, PyAny>, + op: CompareOp, +) -> Py { + match op { + CompareOp::Eq | CompareOp::Ne => { + let eq = other.extract::().is_ok_and(|o| a == o.ll_type()); + let val = if matches!(op, CompareOp::Eq) { eq } else { !eq }; + val.into_pyobject(py) + .expect("bool -> PyBool is infallible") + .to_owned() + .into_any() + .unbind() + } + _ => py.NotImplemented(), + } +} + +/// Hash consistent with `lltype_richcmp` (equal types hash equal). +fn lltype_hash(a: LLType<'static>) -> u64 { + let mut h = std::collections::hash_map::DefaultHasher::new(); + std::hash::Hash::hash(&a, &mut h); + std::hash::Hasher::finish(&h) } /// Python wrapper for struct types @@ -455,6 +500,13 @@ unsafe impl Sync for PyPointerType {} #[pymethods] impl PyPointerType { + fn __richcmp__(&self, py: Python<'_>, other: &Bound<'_, PyAny>, op: CompareOp) -> Py { + lltype_richcmp(py, self.ll_type, other, op) + } + fn __hash__(&self) -> u64 { + lltype_hash(self.ll_type) + } + fn as_pointer(&self) -> PyPointerType { let context = unsafe { &*self.context_ptr }; let ptr_type = self.ll_type.as_pointer(context); @@ -478,6 +530,13 @@ unsafe impl Sync for PyIntType {} #[pymethods] impl PyIntType { + fn __richcmp__(&self, py: Python<'_>, other: &Bound<'_, PyAny>, op: CompareOp) -> Py { + lltype_richcmp(py, self.ll_type, other, op) + } + fn __hash__(&self) -> u64 { + lltype_hash(self.ll_type) + } + fn as_pointer(&self) -> PyPointerType { let context = unsafe { &*self.context_ptr }; let ptr_type = self.ll_type.as_pointer(context); @@ -510,6 +569,13 @@ unsafe impl Sync for PyDoubleType {} #[pymethods] impl PyDoubleType { + fn __richcmp__(&self, py: Python<'_>, other: &Bound<'_, PyAny>, op: CompareOp) -> Py { + lltype_richcmp(py, self.ll_type, other, op) + } + fn __hash__(&self) -> u64 { + lltype_hash(self.ll_type) + } + fn as_pointer(&self) -> PyPointerType { let context = unsafe { &*self.context_ptr }; let ptr_type = self.ll_type.as_pointer(context); @@ -542,6 +608,13 @@ unsafe impl Sync for PyArrayType {} #[pymethods] impl PyArrayType { + fn __richcmp__(&self, py: Python<'_>, other: &Bound<'_, PyAny>, op: CompareOp) -> Py { + lltype_richcmp(py, self.ll_type, other, op) + } + fn __hash__(&self) -> u64 { + lltype_hash(self.ll_type) + } + #[new] fn new(element_type: PyAnyType, count: u32) -> Self { // Extract context pointer from element type @@ -584,6 +657,20 @@ pub struct PyVoidType { unsafe impl Send for PyVoidType {} unsafe impl Sync for PyVoidType {} +#[pymethods] +impl PyVoidType { + // `&self` is mandated by pyo3 `#[pymethods]` dunders; the void type + // carries no per-instance state, so `self` is unused here by design. + #[allow(clippy::unused_self, clippy::trivially_copy_pass_by_ref)] + fn __richcmp__(&self, py: Python<'_>, other: &Bound<'_, PyAny>, op: CompareOp) -> Py { + lltype_richcmp(py, LLType::Void, other, op) + } + #[allow(clippy::unused_self, clippy::trivially_copy_pass_by_ref)] + fn __hash__(&self) -> u64 { + lltype_hash(LLType::Void) + } +} + // --- IRBuilder - Instruction builder --- /// Python wrapper for LLVM IR instruction builder @@ -843,6 +930,92 @@ impl PyIRBuilder { }) } + /// Allocate a stack slot (`alloca `); caller positions the builder. + #[pyo3(signature = (ty, name=""))] + fn alloca(&mut self, ty: PyAnyType, name: &str) -> PyResult { + let builder = unsafe { &mut *self.builder_ptr }; + let context = unsafe { &*self.context_ptr }; + let ll_type = ty.to_ll_type(context); + let result = builder + .alloca(ll_type, name) + .map_err(|e| PyRuntimeError::new_err(format!("alloca failed: {e}")))?; + Ok(PyLLValue { + value: result, + context_ptr: self.context_ptr, + }) + } + + /// Load from a pointer (`load`; LLVM-14 typed-pointer pointee). + #[pyo3(signature = (ptr, name=""))] + fn load(&mut self, ptr: PyLLValue, name: &str) -> PyResult { + let builder = unsafe { &mut *self.builder_ptr }; + let result = builder + .load(ptr.value, name) + .map_err(|e| PyRuntimeError::new_err(format!("load failed: {e}")))?; + Ok(PyLLValue { + value: result, + context_ptr: self.context_ptr, + }) + } + + /// Store a value to a pointer (`store`). + fn store(&mut self, ptr: PyLLValue, value: PyLLValue) -> PyResult<()> { + let builder = unsafe { &mut *self.builder_ptr }; + builder + .store(ptr.value, value.value) + .map_err(|e| PyRuntimeError::new_err(format!("store failed: {e}")))?; + Ok(()) + } + + /// Zero-extend an integer value to a wider int type (`zext`). + #[pyo3(signature = (value, ty, name=""))] + fn zext(&mut self, value: PyLLValue, ty: PyAnyType, name: &str) -> PyResult { + let builder = unsafe { &mut *self.builder_ptr }; + let context = unsafe { &*self.context_ptr }; + let dest = ty.to_ll_type(context); + let result = builder + .zext(value.value, dest, name) + .map_err(|e| PyRuntimeError::new_err(format!("zext failed: {e}")))?; + Ok(PyLLValue { + value: result, + context_ptr: self.context_ptr, + }) + } + + /// Truncate an integer value to a narrower int type (`trunc`). + #[pyo3(signature = (value, ty, name=""))] + fn trunc(&mut self, value: PyLLValue, ty: PyAnyType, name: &str) -> PyResult { + let builder = unsafe { &mut *self.builder_ptr }; + let context = unsafe { &*self.context_ptr }; + let dest = ty.to_ll_type(context); + let result = builder + .trunc(value.value, dest, name) + .map_err(|e| PyRuntimeError::new_err(format!("trunc failed: {e}")))?; + Ok(PyLLValue { + value: result, + context_ptr: self.context_ptr, + }) + } + + /// Unsigned integer comparison (`icmp` with U-predicates). + #[pyo3(signature = (cmp_op, lhs, rhs, name=""))] + fn icmp_unsigned( + &mut self, + cmp_op: &str, + lhs: PyLLValue, + rhs: PyLLValue, + name: &str, + ) -> PyResult { + let builder = unsafe { &mut *self.builder_ptr }; + let result = builder + .icmp_unsigned(cmp_op, lhs.value, rhs.value, name) + .map_err(|e| PyRuntimeError::new_err(format!("icmp_unsigned failed: {e}")))?; + Ok(PyLLValue { + value: result, + context_ptr: self.context_ptr, + }) + } + /// Position builder at end of block fn position_at_end(&mut self, block: PyBasicBlock) { let builder = unsafe { &mut *self.builder_ptr }; @@ -1241,6 +1414,30 @@ impl PyLLValue { context_ptr: self.context_ptr, }) } + + /// The LLVM type of this value (llvmlite parity: `value.type`). + #[getter] + #[pyo3(name = "type")] + fn type_(&self) -> PyAnyType { + match &self.value { + LLValue::Int(v) => PyAnyType::Int(PyIntType { + ll_type: LLType::Int(v.get_type()), + context_ptr: self.context_ptr, + }), + LLValue::Float(v) => PyAnyType::Double(PyDoubleType { + ll_type: LLType::Float(v.get_type()), + context_ptr: self.context_ptr, + }), + LLValue::Pointer(v) => PyAnyType::Pointer(PyPointerType { + ll_type: LLType::Pointer(v.get_type()), + context_ptr: self.context_ptr, + }), + LLValue::Array(v) => PyAnyType::Array(PyArrayType { + ll_type: LLType::Array(v.get_type()), + context_ptr: self.context_ptr, + }), + } + } } // --- GlobalVariable - Global variable support --- @@ -1348,12 +1545,33 @@ impl PyGlobalVariable { /// ``` #[pyfunction] #[allow(non_snake_case)] -fn Constant(_py: Python, ty: PyAnyType, value: &Bound<'_, PyAny>) -> PyResult { +#[pyo3(signature = (ty, value=None))] +fn Constant(_py: Python, ty: PyAnyType, value: Option<&Bound<'_, PyAny>>) -> PyResult { // Check type isn't void (llvmlite doesn't allow void constants) if matches!(ty, PyAnyType::Void(_)) { return Err(PyRuntimeError::new_err("Cannot create void constant")); } + // `Constant(ty)` / `Constant(ty, None)` -> the type's zeroinitializer + // (Array -> `zeroinitializer`, Int -> `iN 0`); backs B2's zero-init buffer. + let Some(value) = value else { + let context_ptr = match &ty { + PyAnyType::Int(t) => t.context_ptr, + PyAnyType::Double(t) => t.context_ptr, + PyAnyType::Void(t) => t.context_ptr, + PyAnyType::Pointer(t) => t.context_ptr, + PyAnyType::Struct(t) => t.context_ptr, + PyAnyType::Array(t) => t.context_ptr, + }; + let context = unsafe { &*context_ptr }; + let zero = LLConstant::zero(ty.to_ll_type(context)) + .map_err(|e| PyRuntimeError::new_err(format!("Constant(zero) failed: {e}")))?; + return Ok(PyLLValue { + value: zero, + context_ptr, + }); + }; + // Handle different type/value combinations match &ty { PyAnyType::Int(int_ty) => { diff --git a/python/pecos-rslib/src/dtypes.rs b/python/pecos-rslib/src/dtypes.rs index 8f799086d..45ad86d39 100644 --- a/python/pecos-rslib/src/dtypes.rs +++ b/python/pecos-rslib/src/dtypes.rs @@ -3991,6 +3991,26 @@ impl ScalarAngle64 { self.value.to_radians_signed() } + /// Convert to turns (in [0, 1)) -- the inverse of `from_turns` + fn to_turns(&self) -> f64 { + self.value.to_turns() + } + + /// Convert to signed turns (in (-0.5, 0.5]) + fn to_turns_signed(&self) -> f64 { + self.value.to_turns_signed() + } + + /// Convert to half-turns (in [0, 2)); pi radians = 1.0 half-turn + fn to_half_turns(&self) -> f64 { + self.value.to_half_turns() + } + + /// Convert to signed half-turns (in (-1, 1]) + fn to_half_turns_signed(&self) -> f64 { + self.value.to_half_turns_signed() + } + /// Get the raw u64 fraction #[getter] fn fraction(&self) -> u64 { diff --git a/python/pecos-rslib/src/pecos_array.rs b/python/pecos-rslib/src/pecos_array.rs index a926d49a7..ded6972de 100644 --- a/python/pecos-rslib/src/pecos_array.rs +++ b/python/pecos-rslib/src/pecos_array.rs @@ -4702,7 +4702,7 @@ impl Array { let mut result = arr.clone(); for (axis, start, stop, step) in slices { if step < 0 { - // ndarray's Slice doesn't match NumPy for negative steps (see issue #312) + // ndarray's Slice doesn't match NumPy for negative steps // We need to manually implement NumPy's behavior: // 1. Slice forward [stop+1, start+1] with step=1 // 2. Reverse the axis @@ -4732,7 +4732,7 @@ impl Array { let mut result = arr.clone(); for (axis, start, stop, step) in slices { if step < 0 { - // ndarray's Slice doesn't match NumPy for negative steps (see issue #312) + // ndarray's Slice doesn't match NumPy for negative steps // We need to manually implement NumPy's behavior: // 1. Slice forward [stop+1, start+1] with step=1 // 2. Reverse the axis @@ -4762,7 +4762,7 @@ impl Array { let mut result = arr.clone(); for (axis, start, stop, step) in slices { if step < 0 { - // ndarray's Slice doesn't match NumPy for negative steps (see issue #312) + // ndarray's Slice doesn't match NumPy for negative steps // We need to manually implement NumPy's behavior: // 1. Slice forward [stop+1, start+1] with step=1 // 2. Reverse the axis @@ -4792,7 +4792,7 @@ impl Array { let mut result = arr.clone(); for (axis, start, stop, step) in slices { if step < 0 { - // ndarray's Slice doesn't match NumPy for negative steps (see issue #312) + // ndarray's Slice doesn't match NumPy for negative steps // We need to manually implement NumPy's behavior: // 1. Slice forward [stop+1, start+1] with step=1 // 2. Reverse the axis @@ -4822,7 +4822,7 @@ impl Array { let mut result = arr.clone(); for (axis, start, stop, step) in slices { if step < 0 { - // ndarray's Slice doesn't match NumPy for negative steps (see issue #312) + // ndarray's Slice doesn't match NumPy for negative steps // We need to manually implement NumPy's behavior: // 1. Slice forward [stop+1, start+1] with step=1 // 2. Reverse the axis @@ -4948,7 +4948,7 @@ impl Array { let mut result = arr.clone(); for (axis, start, stop, step) in slices { if step < 0 { - // ndarray's Slice doesn't match NumPy for negative steps (see issue #312) + // ndarray's Slice doesn't match NumPy for negative steps // We need to manually implement NumPy's behavior: // 1. Slice forward [stop+1, start+1] with step=1 // 2. Reverse the axis @@ -4978,7 +4978,7 @@ impl Array { let mut result = arr.clone(); for (axis, start, stop, step) in slices { if step < 0 { - // ndarray's Slice doesn't match NumPy for negative steps (see issue #312) + // ndarray's Slice doesn't match NumPy for negative steps // We need to manually implement NumPy's behavior: // 1. Slice forward [stop+1, start+1] with step=1 // 2. Reverse the axis @@ -5008,7 +5008,7 @@ impl Array { let mut result = arr.clone(); for (axis, start, stop, step) in slices { if step < 0 { - // ndarray's Slice doesn't match NumPy for negative steps (see issue #312) + // ndarray's Slice doesn't match NumPy for negative steps // We need to manually implement NumPy's behavior: // 1. Slice forward [stop+1, start+1] with step=1 // 2. Reverse the axis @@ -5038,7 +5038,7 @@ impl Array { let mut result = arr.clone(); for (axis, start, stop, step) in slices { if step < 0 { - // ndarray's Slice doesn't match NumPy for negative steps (see issue #312) + // ndarray's Slice doesn't match NumPy for negative steps // We need to manually implement NumPy's behavior: // 1. Slice forward [stop+1, start+1] with step=1 // 2. Reverse the axis @@ -5484,7 +5484,7 @@ impl Array { // So we can use the slice params as-is, just on the current_axis. if *step < 0 { - // ndarray's Slice doesn't match NumPy for negative steps (see issue #312) + // ndarray's Slice doesn't match NumPy for negative steps // We need to manually implement NumPy's behavior: // 1. Slice forward [stop+1, start+1] with step=1 // 2. Reverse the axis diff --git a/python/pecos-rslib/src/simulator_utils.rs b/python/pecos-rslib/src/simulator_utils.rs index 23bc059fd..c61f7537d 100644 --- a/python/pecos-rslib/src/simulator_utils.rs +++ b/python/pecos-rslib/src/simulator_utils.rs @@ -385,7 +385,7 @@ pub fn try_clifford_batch_dispatch( sim.pz(&collect_single_qubits(locations)?); return Ok(Some(PyDict::new(py).into())); } - "PnZ" | "Init -Z" | "init |1>" | "leak |1>" | "unleak |1>" => { + "PNZ" | "Init -Z" | "init |1>" | "leak |1>" | "unleak |1>" => { sim.pnz(&collect_single_qubits(locations)?); return Ok(Some(PyDict::new(py).into())); } @@ -393,7 +393,7 @@ pub fn try_clifford_batch_dispatch( sim.px(&collect_single_qubits(locations)?); return Ok(Some(PyDict::new(py).into())); } - "PnX" | "Init -X" | "init |->" => { + "PNX" | "Init -X" | "init |->" => { sim.pnx(&collect_single_qubits(locations)?); return Ok(Some(PyDict::new(py).into())); } @@ -401,7 +401,7 @@ pub fn try_clifford_batch_dispatch( sim.py(&collect_single_qubits(locations)?); return Ok(Some(PyDict::new(py).into())); } - "PnY" | "Init -Y" | "init |-i>" => { + "PNY" | "Init -Y" | "init |-i>" => { sim.pny(&collect_single_qubits(locations)?); return Ok(Some(PyDict::new(py).into())); } diff --git a/python/pecos-rslib/src/sparse_sim.rs b/python/pecos-rslib/src/sparse_sim.rs index 400234425..1276ac93e 100644 --- a/python/pecos-rslib/src/sparse_sim.rs +++ b/python/pecos-rslib/src/sparse_sim.rs @@ -158,15 +158,15 @@ impl SparseSim { self.inner.py(q); Ok(None) } - "PnZ" => { + "PNZ" => { self.inner.pnz(q); Ok(None) } - "PnX" => { + "PNX" => { self.inner.pnx(q); Ok(None) } - "PnY" => { + "PNY" => { self.inner.pny(q); Ok(None) } diff --git a/python/pecos-rslib/src/sparse_stab_bindings.rs b/python/pecos-rslib/src/sparse_stab_bindings.rs index 83de68aab..0c3795c8c 100644 --- a/python/pecos-rslib/src/sparse_stab_bindings.rs +++ b/python/pecos-rslib/src/sparse_stab_bindings.rs @@ -240,7 +240,7 @@ impl PySparseStab { self.inner.pz(q); Ok(None) } - "Init -Z" | "init |1>" | "leak |1>" | "unleak |1>" | "PnZ" => { + "Init -Z" | "init |1>" | "leak |1>" | "unleak |1>" | "PNZ" => { self.inner.pnz(q); Ok(None) } @@ -248,7 +248,7 @@ impl PySparseStab { self.inner.px(q); Ok(None) } - "Init -X" | "init |->" | "PnX" => { + "Init -X" | "init |->" | "PNX" => { self.inner.pnx(q); Ok(None) } @@ -256,7 +256,7 @@ impl PySparseStab { self.inner.py(q); Ok(None) } - "Init -Y" | "init |-i>" | "PnY" => { + "Init -Y" | "init |-i>" | "PNY" => { self.inner.pny(q); Ok(None) } diff --git a/python/pecos-rslib/src/stab_bindings.rs b/python/pecos-rslib/src/stab_bindings.rs index 67eafd047..5f58d7561 100644 --- a/python/pecos-rslib/src/stab_bindings.rs +++ b/python/pecos-rslib/src/stab_bindings.rs @@ -242,7 +242,7 @@ impl PyStabilizer { self.inner.pz(q); Ok(None) } - "Init -Z" | "init |1>" | "leak |1>" | "unleak |1>" | "PnZ" => { + "Init -Z" | "init |1>" | "leak |1>" | "unleak |1>" | "PNZ" => { self.inner.pnz(q); Ok(None) } @@ -250,7 +250,7 @@ impl PyStabilizer { self.inner.px(q); Ok(None) } - "Init -X" | "init |->" | "PnX" => { + "Init -X" | "init |->" | "PNX" => { self.inner.pnx(q); Ok(None) } @@ -258,7 +258,7 @@ impl PyStabilizer { self.inner.py(q); Ok(None) } - "Init -Y" | "init |-i>" | "PnY" => { + "Init -Y" | "init |-i>" | "PNY" => { self.inner.pny(q); Ok(None) } diff --git a/python/pecos-rslib/src/stab_vec_bindings.rs b/python/pecos-rslib/src/stab_vec_bindings.rs index da54bca67..59e99517a 100644 --- a/python/pecos-rslib/src/stab_vec_bindings.rs +++ b/python/pecos-rslib/src/stab_vec_bindings.rs @@ -209,7 +209,7 @@ impl PyStabVec { self.inner.pz(q); Ok(None) } - "Init -Z" | "init |1>" | "leak |1>" | "unleak |1>" | "PnZ" => { + "Init -Z" | "init |1>" | "leak |1>" | "unleak |1>" | "PNZ" => { self.inner.pnz(q); Ok(None) } @@ -217,7 +217,7 @@ impl PyStabVec { self.inner.px(q); Ok(None) } - "Init -X" | "init |->" | "PnX" => { + "Init -X" | "init |->" | "PNX" => { self.inner.pnx(q); Ok(None) } @@ -225,7 +225,7 @@ impl PyStabVec { self.inner.py(q); Ok(None) } - "Init -Y" | "init |-i>" | "PnY" => { + "Init -Y" | "init |-i>" | "PNY" => { self.inner.pny(q); Ok(None) } diff --git a/python/pecos-rslib/src/state_vec_bindings.rs b/python/pecos-rslib/src/state_vec_bindings.rs index 81b97edc6..dcc8fa532 100644 --- a/python/pecos-rslib/src/state_vec_bindings.rs +++ b/python/pecos-rslib/src/state_vec_bindings.rs @@ -339,7 +339,7 @@ impl PyStateVec { self.inner.pz(q); Ok(None) } - "Init -Z" | "init |1>" | "leak |1>" | "unleak |1>" | "PnZ" => { + "Init -Z" | "init |1>" | "leak |1>" | "unleak |1>" | "PNZ" => { self.inner.pnz(q); Ok(None) } @@ -347,7 +347,7 @@ impl PyStateVec { self.inner.px(q); Ok(None) } - "Init -X" | "init |->" | "PnX" => { + "Init -X" | "init |->" | "PNX" => { self.inner.pnx(q); Ok(None) } @@ -355,7 +355,7 @@ impl PyStateVec { self.inner.py(q); Ok(None) } - "Init -Y" | "init |-i>" | "PnY" => { + "Init -Y" | "init |-i>" | "PNY" => { self.inner.pny(q); Ok(None) } @@ -508,6 +508,40 @@ impl PyStateVec { Ok(None) } + "CRX" | "CRY" | "CRZ" => { + let Some(params) = params else { + return Err(PyErr::new::( + "Angle parameter missing for controlled rotation gate", + )); + }; + let angle = match params.get_item("angle") { + Ok(Some(py_any)) => py_any.extract::().map_err(|_| { + PyErr::new::( + "Expected a valid angle parameter for controlled rotation gate", + ) + })?, + Ok(None) => { + return Err(PyErr::new::( + "Angle parameter missing for controlled rotation gate", + )); + } + Err(err) => return Err(err), + }; + match symbol { + "CRX" => { + self.inner.crx(angle.0, pair); + } + "CRY" => { + self.inner.cry(angle.0, pair); + } + "CRZ" => { + self.inner.crz(angle.0, pair); + } + _ => unreachable!(), + } + Ok(None) + } + "RXXRYYRZZ" | "RZZRYYRXX" | "R2XXYYZZ" | "RXXYYZZ" => { if let Some(params) = params { match params.get_item("angles") { diff --git a/python/quantum-pecos/docs/reference/_autosummary/pecos.qeclib.qubit.preps.rst b/python/quantum-pecos/docs/reference/_autosummary/pecos.qeclib.qubit.preps.rst index 0b24f822a..012be4346 100644 --- a/python/quantum-pecos/docs/reference/_autosummary/pecos.qeclib.qubit.preps.rst +++ b/python/quantum-pecos/docs/reference/_autosummary/pecos.qeclib.qubit.preps.rst @@ -17,4 +17,9 @@ pecos.qeclib.qubit.preps .. autosummary:: - Prep + PNX + PNY + PNZ + PX + PY + PZ diff --git a/python/quantum-pecos/src/pecos/circuit_converters/hugr_to_ast.py b/python/quantum-pecos/src/pecos/circuit_converters/hugr_to_ast.py index b8605af81..96bc86c80 100644 --- a/python/quantum-pecos/src/pecos/circuit_converters/hugr_to_ast.py +++ b/python/quantum-pecos/src/pecos/circuit_converters/hugr_to_ast.py @@ -276,9 +276,7 @@ def convert(self) -> Program: ) # Add classical register declarations for measurement results - decl_list.extend( - RegisterDecl(name=result_var, size=1, is_result=True) for result_var in self.measurement_results.values() - ) + decl_list.extend(RegisterDecl(name=result_var, size=1) for result_var in self.measurement_results.values()) declarations = tuple(decl_list) @@ -1249,7 +1247,7 @@ def hugr_to_ast( ... >>> package = simple.compile() >>> ast = hugr_to_ast(package.modules[0]) - >>> len(ast.body) # Prep + H + Measure + >>> len(ast.body) # PZ + H + Measure 3 """ converter = HugrToAstConverter(hugr) diff --git a/python/quantum-pecos/src/pecos/circuits/quantum_circuit.py b/python/quantum-pecos/src/pecos/circuits/quantum_circuit.py index ae03fb7bf..9119ea8d1 100644 --- a/python/quantum-pecos/src/pecos/circuits/quantum_circuit.py +++ b/python/quantum-pecos/src/pecos/circuits/quantum_circuit.py @@ -87,7 +87,7 @@ "SWAP": "SWAP", "RXXRYYRZZ": "RXXRYYRZZ", "R2XXYYZZ": "RXXRYYRZZ", - "Prep": "init |0>", + "PZ": "init |0>", "Measure": "measure", "MeasureFree": "measure", "QAlloc": "QAlloc", diff --git a/python/quantum-pecos/src/pecos/qec/surface/circuit_builder.py b/python/quantum-pecos/src/pecos/qec/surface/circuit_builder.py index f88604c10..b50b51718 100644 --- a/python/quantum-pecos/src/pecos/qec/surface/circuit_builder.py +++ b/python/quantum-pecos/src/pecos/qec/surface/circuit_builder.py @@ -766,10 +766,10 @@ def render( """Render to PECOS TickCircuit. The tick structure follows Stim's pattern: - - Tick: Prep data qubits + - Tick: PZ data qubits - Tick: H for X-basis prep (if X-basis) - For each syndrome round: - - Tick: Prep ancillas + - Tick: PZ ancillas - Tick: H on X ancillas - Tick: CX round 1 - Tick: CX round 2 diff --git a/python/quantum-pecos/src/pecos/simulators/cuda_statevec/bindings.py b/python/quantum-pecos/src/pecos/simulators/cuda_statevec/bindings.py index 9dad4fee8..de40bab06 100644 --- a/python/quantum-pecos/src/pecos/simulators/cuda_statevec/bindings.py +++ b/python/quantum-pecos/src/pecos/simulators/cuda_statevec/bindings.py @@ -412,6 +412,45 @@ def RZZ( state.backend.rzz(angles[0], list(qubits)) +def CRX( + state: CudaStateVec, + qubits: tuple[int, int], + angles: tuple[float], + **_params: SimulatorGateParams, +) -> None: + """Controlled-RX gate (qubits = (control, target)).""" + if len(angles) != 1: + msg = "CRX gate requires exactly 1 angle parameter." + raise ValueError(msg) + state.backend.crx(angles[0], list(qubits)) + + +def CRY( + state: CudaStateVec, + qubits: tuple[int, int], + angles: tuple[float], + **_params: SimulatorGateParams, +) -> None: + """Controlled-RY gate (qubits = (control, target)).""" + if len(angles) != 1: + msg = "CRY gate requires exactly 1 angle parameter." + raise ValueError(msg) + state.backend.cry(angles[0], list(qubits)) + + +def CRZ( + state: CudaStateVec, + qubits: tuple[int, int], + angles: tuple[float], + **_params: SimulatorGateParams, +) -> None: + """Controlled-RZ gate (qubits = (control, target)).""" + if len(angles) != 1: + msg = "CRZ gate requires exactly 1 angle parameter." + raise ValueError(msg) + state.backend.crz(angles[0], list(qubits)) + + # ============================================================================= # Gate dictionary # ============================================================================= @@ -505,6 +544,9 @@ def RZZ( "SqrtZZ": SZZ, "SZZdg": SZZdg, # Two-qubit rotations + "CRX": CRX, + "CRY": CRY, + "CRZ": CRZ, "RXX": RXX, "RYY": RYY, "RZZ": RZZ, diff --git a/python/quantum-pecos/src/pecos/simulators/custatevec/bindings.py b/python/quantum-pecos/src/pecos/simulators/custatevec/bindings.py index 93f3d2e8b..2c3791498 100644 --- a/python/quantum-pecos/src/pecos/simulators/custatevec/bindings.py +++ b/python/quantum-pecos/src/pecos/simulators/custatevec/bindings.py @@ -68,6 +68,9 @@ "RXX": two_q.RXX, "RYY": two_q.RYY, "RZZ": two_q.RZZ, + "CRX": two_q.CRX, + "CRY": two_q.CRY, + "CRZ": two_q.CRZ, "RXXRYYRZZ": two_q.RXXRYYRZZ, "R2XXYYZZ": two_q.RXXRYYRZZ, "SXX": two_q.SXX, diff --git a/python/quantum-pecos/src/pecos/simulators/custatevec/gates_two_qubit.py b/python/quantum-pecos/src/pecos/simulators/custatevec/gates_two_qubit.py index 9544c9410..56faf886d 100644 --- a/python/quantum-pecos/src/pecos/simulators/custatevec/gates_two_qubit.py +++ b/python/quantum-pecos/src/pecos/simulators/custatevec/gates_two_qubit.py @@ -318,6 +318,73 @@ def RZZ( _apply_two_qubit_matrix(state, qubits, matrix) +def CRX( + state: CuStateVec, + qubits: tuple[int, int], + angles: tuple[float], + **_params: SimulatorGateParams, +) -> None: + """Controlled-RX gate (qubits[0] = control, qubits[1] = target). + + Uses `_apply_controlled_matrix` so only the 2x2 RX(theta) action on + the target is passed; cuStateVec handles the controlled gating + internally. RX(theta) = [[cos(theta/2), -i sin(theta/2)], + [-i sin(theta/2), cos(theta/2)]]. + """ + if len(angles) != 1: + msg = "CRX gate requires exactly 1 angle parameter." + raise ValueError(msg) + theta = float(angles[0]) + c = cmath.cos(theta / 2) + s = cmath.sin(theta / 2) + matrix = cp.asarray([c, -1j * s, -1j * s, c], dtype=state.cp_type) + _apply_controlled_matrix(state, qubits[0], qubits[1], matrix) + + +def CRY( + state: CuStateVec, + qubits: tuple[int, int], + angles: tuple[float], + **_params: SimulatorGateParams, +) -> None: + """Controlled-RY gate (qubits[0] = control, qubits[1] = target). + + Uses `_apply_controlled_matrix` with the 2x2 RY(theta) on target. + RY(theta) = [[cos(theta/2), -sin(theta/2)], + [sin(theta/2), cos(theta/2)]]. + """ + if len(angles) != 1: + msg = "CRY gate requires exactly 1 angle parameter." + raise ValueError(msg) + theta = float(angles[0]) + c = cmath.cos(theta / 2) + s = cmath.sin(theta / 2) + matrix = cp.asarray([c, -s, s, c], dtype=state.cp_type) + _apply_controlled_matrix(state, qubits[0], qubits[1], matrix) + + +def CRZ( + state: CuStateVec, + qubits: tuple[int, int], + angles: tuple[float], + **_params: SimulatorGateParams, +) -> None: + """Controlled-RZ gate (qubits[0] = control, qubits[1] = target). + + Uses `_apply_controlled_matrix` with the 2x2 RZ(theta) on target. + RZ(theta) = diag(exp(-i theta/2), exp(i theta/2)). + """ + if len(angles) != 1: + msg = "CRZ gate requires exactly 1 angle parameter." + raise ValueError(msg) + theta = float(angles[0]) + matrix = cp.asarray( + [cmath.exp(-1j * theta / 2), 0, 0, cmath.exp(1j * theta / 2)], + dtype=state.cp_type, + ) + _apply_controlled_matrix(state, qubits[0], qubits[1], matrix) + + def RXXRYYRZZ( state: CuStateVec, qubits: tuple[int, int], diff --git a/python/quantum-pecos/src/pecos/simulators/mps_pytket/bindings.py b/python/quantum-pecos/src/pecos/simulators/mps_pytket/bindings.py index aad2a2deb..2b0191145 100644 --- a/python/quantum-pecos/src/pecos/simulators/mps_pytket/bindings.py +++ b/python/quantum-pecos/src/pecos/simulators/mps_pytket/bindings.py @@ -65,6 +65,9 @@ "CX": two_q.CX, "CY": two_q.CY, "CZ": two_q.CZ, + "CRX": two_q.CRX, + "CRY": two_q.CRY, + "CRZ": two_q.CRZ, "RXX": two_q.RXX, "RYY": two_q.RYY, "RZZ": two_q.RZZ, diff --git a/python/quantum-pecos/src/pecos/simulators/mps_pytket/gates_two_qubit.py b/python/quantum-pecos/src/pecos/simulators/mps_pytket/gates_two_qubit.py index 75554de1e..0c7c8c039 100644 --- a/python/quantum-pecos/src/pecos/simulators/mps_pytket/gates_two_qubit.py +++ b/python/quantum-pecos/src/pecos/simulators/mps_pytket/gates_two_qubit.py @@ -203,6 +203,97 @@ def RZZ( _apply_two_qubit_matrix(state, qubits, matrix) +def CRX( + state: MPS, + qubits: tuple[int, int], + angles: tuple[float], + **_params: SimulatorGateParams, +) -> None: + """Controlled-RX gate (qubits[0] = control, qubits[1] = target). + + Block-diag(I, RX(theta)) in q0-HIGH basis (MPS convention): + [[1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, cos(theta/2), -i sin(theta/2)], + [0, 0, -i sin(theta/2), cos(theta/2)]]. + """ + if len(angles) != 1: + msg = "CRX gate requires exactly 1 angle parameter." + raise ValueError(msg) + theta = angles[0] + c = cmath.cos(theta / 2) + s = cmath.sin(theta / 2) + matrix = cp.asarray( + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, c, -1j * s], + [0, 0, -1j * s, c], + ], + dtype=state.dtype, + ) + _apply_two_qubit_matrix(state, qubits, matrix) + + +def CRY( + state: MPS, + qubits: tuple[int, int], + angles: tuple[float], + **_params: SimulatorGateParams, +) -> None: + """Controlled-RY gate (qubits[0] = control, qubits[1] = target). + + Block-diag(I, RY(theta)) in q0-HIGH basis: + [[1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, cos(theta/2), -sin(theta/2)], + [0, 0, sin(theta/2), cos(theta/2)]]. + """ + if len(angles) != 1: + msg = "CRY gate requires exactly 1 angle parameter." + raise ValueError(msg) + theta = angles[0] + c = cmath.cos(theta / 2) + s = cmath.sin(theta / 2) + matrix = cp.asarray( + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, c, -s], + [0, 0, s, c], + ], + dtype=state.dtype, + ) + _apply_two_qubit_matrix(state, qubits, matrix) + + +def CRZ( + state: MPS, + qubits: tuple[int, int], + angles: tuple[float], + **_params: SimulatorGateParams, +) -> None: + """Controlled-RZ gate (qubits[0] = control, qubits[1] = target). + + Block-diag(I, RZ(theta)) in q0-HIGH basis: diag(1, 1, + exp(-i theta/2), exp(i theta/2)). + """ + if len(angles) != 1: + msg = "CRZ gate requires exactly 1 angle parameter." + raise ValueError(msg) + theta = angles[0] + matrix = cp.asarray( + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, cmath.exp(-1j * theta / 2), 0], + [0, 0, 0, cmath.exp(1j * theta / 2)], + ], + dtype=state.dtype, + ) + _apply_two_qubit_matrix(state, qubits, matrix) + + def RXXRYYRZZ( state: MPS, qubits: tuple[int, int], diff --git a/python/quantum-pecos/src/pecos/simulators/statevec/bindings.py b/python/quantum-pecos/src/pecos/simulators/statevec/bindings.py index f7faa9a14..2eb6c6f6c 100644 --- a/python/quantum-pecos/src/pecos/simulators/statevec/bindings.py +++ b/python/quantum-pecos/src/pecos/simulators/statevec/bindings.py @@ -153,25 +153,27 @@ def get_bindings(state: StateVec) -> dict: "PZ": lambda _s, q, **p: sim.run_1q_gate("PZ", q, p), "PX": lambda _s, q, **p: sim.run_1q_gate("PX", q, p), "PY": lambda _s, q, **p: sim.run_1q_gate("PY", q, p), - "PnZ": lambda _s, q, **p: sim.run_1q_gate("PnZ", q, p), + "PNZ": lambda _s, q, **p: sim.run_1q_gate("PNZ", q, p), + "PNX": lambda _s, q, **p: sim.run_1q_gate("PNX", q, p), + "PNY": lambda _s, q, **p: sim.run_1q_gate("PNY", q, p), "Init": lambda _s, q, **p: sim.run_1q_gate("PZ", q, p), "Init +Z": lambda _s, q, **p: sim.run_1q_gate("PZ", q, p), - "Init -Z": lambda _s, q, **p: sim.run_1q_gate("PnZ", q, p), + "Init -Z": lambda _s, q, **p: sim.run_1q_gate("PNZ", q, p), "Init +X": lambda _s, q, **p: sim.run_1q_gate("PX", q, p), - "Init -X": lambda _s, q, **p: sim.run_1q_gate("PnX", q, p), + "Init -X": lambda _s, q, **p: sim.run_1q_gate("PNX", q, p), "Init +Y": lambda _s, q, **p: sim.run_1q_gate("PY", q, p), - "Init -Y": lambda _s, q, **p: sim.run_1q_gate("PnY", q, p), + "Init -Y": lambda _s, q, **p: sim.run_1q_gate("PNY", q, p), "init |0>": lambda _s, q, **p: sim.run_1q_gate("PZ", q, p), - "init |1>": lambda _s, q, **p: sim.run_1q_gate("PnZ", q, p), + "init |1>": lambda _s, q, **p: sim.run_1q_gate("PNZ", q, p), "init |+>": lambda _s, q, **p: sim.run_1q_gate("PX", q, p), - "init |->": lambda _s, q, **p: sim.run_1q_gate("PnX", q, p), + "init |->": lambda _s, q, **p: sim.run_1q_gate("PNX", q, p), "init |+i>": lambda _s, q, **p: sim.run_1q_gate("PY", q, p), - "init |-i>": lambda _s, q, **p: sim.run_1q_gate("PnY", q, p), + "init |-i>": lambda _s, q, **p: sim.run_1q_gate("PNY", q, p), "leak": lambda _s, q, **p: sim.run_1q_gate("PZ", q, p), "leak |0>": lambda _s, q, **p: sim.run_1q_gate("PZ", q, p), - "leak |1>": lambda _s, q, **p: sim.run_1q_gate("PnZ", q, p), + "leak |1>": lambda _s, q, **p: sim.run_1q_gate("PNZ", q, p), "unleak |0>": lambda _s, q, **p: sim.run_1q_gate("PZ", q, p), - "unleak |1>": lambda _s, q, **p: sim.run_1q_gate("PnZ", q, p), + "unleak |1>": lambda _s, q, **p: sim.run_1q_gate("PNZ", q, p), # Aliases "Q": lambda _s, q, **p: sim.run_1q_gate("SX", q, p), "Qd": lambda _s, q, **p: sim.run_1q_gate("SXdg", q, p), @@ -268,4 +270,22 @@ def get_bindings(state: StateVec) -> dict: tuple(qs) if isinstance(qs, list) else qs, {"angles": p["angles"]} if "angles" in p else {"angles": [0, 0, 0]}, ), + # Controlled rotations -- dispatch the parameterized 2q gate to the + # Rust binding which then calls the new `crx`/`cry`/`crz` default + # methods on `ArbitraryRotationGateable` (1-RZZ 2q-minimal decomp). + "CRX": lambda _s, qs, **p: sim.run_2q_gate( + "CRX", + tuple(qs) if isinstance(qs, list) else qs, + {"angle": p["angles"][0]} if "angles" in p else {"angle": 0}, + ), + "CRY": lambda _s, qs, **p: sim.run_2q_gate( + "CRY", + tuple(qs) if isinstance(qs, list) else qs, + {"angle": p["angles"][0]} if "angles" in p else {"angle": 0}, + ), + "CRZ": lambda _s, qs, **p: sim.run_2q_gate( + "CRZ", + tuple(qs) if isinstance(qs, list) else qs, + {"angle": p["angles"][0]} if "angles" in p else {"angle": 0}, + ), } diff --git a/python/quantum-pecos/src/pecos/slr/__init__.py b/python/quantum-pecos/src/pecos/slr/__init__.py index 38e8e38b0..173f5b945 100644 --- a/python/quantum-pecos/src/pecos/slr/__init__.py +++ b/python/quantum-pecos/src/pecos/slr/__init__.py @@ -83,19 +83,15 @@ from typing import TYPE_CHECKING from pecos.slr import ast, qeclib +from pecos.slr.angle import Angle, rad, turns from pecos.slr.block import Block from pecos.slr.cond_block import If, Repeat -from pecos.slr.gen_codes.guppy.qubit_state_validator import ( - QubitStateValidator, - StateViolation, - validate_qubit_states, -) from pecos.slr.loop_block import For, While from pecos.slr.main import Main from pecos.slr.main import ( Main as SLR, ) -from pecos.slr.misc import Barrier, Comment, Parallel, Permute, Return +from pecos.slr.misc import Barrier, Comment, Parallel, Permute, Print, Return from pecos.slr.qalloc import QAlloc, QubitRef, SlotState from pecos.slr.slr_converter import SlrConverter from pecos.slr.types import Array @@ -169,6 +165,7 @@ def generate( __all__ = [ "SLR", + "Angle", "Array", "Barrier", "Bit", @@ -182,20 +179,18 @@ def generate( "Main", "Parallel", "Permute", + "Print", # Qubit allocator (new) "QAlloc", # Legacy register (kept for compatibility) "QReg", "Qubit", "QubitRef", - # State validation - "QubitStateValidator", "QubitType", "Repeat", "Return", "SlotState", "SlrConverter", - "StateViolation", "Vars", "While", # AST module @@ -204,5 +199,7 @@ def generate( "generate", # QEC library "qeclib", - "validate_qubit_states", + # Typed angles for rotation gates + "rad", + "turns", ] diff --git a/python/quantum-pecos/src/pecos/slr/angle.py b/python/quantum-pecos/src/pecos/slr/angle.py new file mode 100644 index 000000000..12bf6581c --- /dev/null +++ b/python/quantum-pecos/src/pecos/slr/angle.py @@ -0,0 +1,65 @@ +"""Typed angle values for SLR rotation gates. + +SLR rotation gates take a typed angle, not a bare float. `rad(x)` and +`turns(x)` construct an :class:`Angle` whose value is the exposed +``pecos.angle64`` fixed-point dtype -- the dtype carries all the math; +the wrapper adds only the source unit, used solely for pretty-printing. +Every backend unwraps to the underlying ``angle64`` before lowering, so +the wrapper is a display policy, not a parallel angle type. +""" + +# Copyright 2026 The PECOS Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License.You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +from pecos_rslib import angle64 + +__all__ = ["Angle", "rad", "turns"] + + +@dataclass(frozen=True) +class Angle: + """A typed rotation angle: a ``pecos.angle64`` value plus its source unit. + + The ``value`` holds the exact fixed-point angle and owns every + conversion (``to_radians``/``to_radians_signed``/``to_half_turns``/...). + ``source_unit`` records whether the user wrote ``rad(...)`` or + ``turns(...)`` so pretty-print can round-trip the unit label; it carries + no math and is never consulted during backend lowering. + """ + + value: angle64 + source_unit: Literal["rad", "turns"] + + def slr_repr(self) -> str: + """Render the SLR source form, e.g. ``rad(0.5)`` / ``turns(0.25)``. + + Uses the signed conversion so ordinary negative rotations read + naturally; the numeric value is the canonicalized fixed-point angle, + not the user's original literal (angle64 does not retain it). + """ + if self.source_unit == "turns": + return f"turns({self.value.to_turns_signed()})" + return f"rad({self.value.to_radians_signed()})" + + +def rad(value: float) -> Angle: + """Construct an :class:`Angle` from a value in radians.""" + return Angle(angle64.from_radians(float(value)), "rad") + + +def turns(value: float) -> Angle: + """Construct an :class:`Angle` from a value in turns (1.0 = a full turn).""" + return Angle(angle64.from_turns(float(value)), "turns") diff --git a/python/quantum-pecos/src/pecos/slr/ast/__init__.py b/python/quantum-pecos/src/pecos/slr/ast/__init__.py index 47c5ed1b7..e46fda47b 100644 --- a/python/quantum-pecos/src/pecos/slr/ast/__init__.py +++ b/python/quantum-pecos/src/pecos/slr/ast/__init__.py @@ -134,6 +134,8 @@ def visit_gate(self, node: GateOp) -> int: from pecos.slr.ast.compare import ast_equal, compare_ast, nodes_equal from pecos.slr.ast.converter import SlrToAst, slr_to_ast from pecos.slr.ast.nodes import ( + # Reusable blocks + AllocatorArg, # Declarations AllocatorDecl, # Types @@ -148,10 +150,15 @@ def visit_gate(self, node: GateOp) -> int: BinaryExpr, # Enums BinaryOp, + BitBundleArg, BitExpr, # References BitRef, BitTypeExpr, + BlockArg, + BlockCall, + BlockDecl, + BlockInput, CommentOp, Declaration, Expression, @@ -165,12 +172,17 @@ def visit_gate(self, node: GateOp) -> int: ParallelBlock, PermuteOp, PrepareOp, + PrintOp, # Program Program, + QubitBundleArg, QubitTypeExpr, RegisterDecl, RepeatStmt, + ResourceEffect, ReturnOp, + SingleBitArg, + SingleQubitArg, SlotRef, SourceLocation, Statement, @@ -190,26 +202,28 @@ def visit_gate(self, node: GateOp) -> int: ) __all__ = [ + "AllocatorArg", "AllocatorDecl", "AllocatorTypeExpr", "ArrayTypeExpr", "AssignOp", - # Base "AstNode", - # Analysis "AstQubitStateValidator", - # Code generation "AstToGuppy", "AstToQasm", - # Visitors "AstVisitor", "BarrierOp", "BaseVisitor", "BinaryExpr", "BinaryOp", + "BitBundleArg", "BitExpr", "BitRef", "BitTypeExpr", + "BlockArg", + "BlockCall", + "BlockDecl", + "BlockInput", "CodegenOptions", "CodegenResult", "CollectingVisitor", @@ -231,18 +245,21 @@ def visit_gate(self, node: GateOp) -> int: "ParallelBlock", "PermuteOp", "PrepareOp", + "PrintOp", # Program "Program", + "QubitBundleArg", "QubitStateTracker", "QubitTypeExpr", "RegisterDecl", "RepeatStmt", "ResourceCount", "ResourceCounter", + "ResourceEffect", "ReturnOp", - # References + "SingleBitArg", + "SingleQubitArg", "SlotRef", - # Converter "SlrToAst", "SourceLocation", "StateViolation", diff --git a/python/quantum-pecos/src/pecos/slr/ast/_block_substitution.py b/python/quantum-pecos/src/pecos/slr/ast/_block_substitution.py new file mode 100644 index 000000000..50c3c1d51 --- /dev/null +++ b/python/quantum-pecos/src/pecos/slr/ast/_block_substitution.py @@ -0,0 +1,431 @@ +# Copyright 2026 The PECOS Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +"""Shared BlockDecl-body reference substitution. + +`BodyRemap` is a bidirectional slot/bit/whole-name remap. Two consumers +build it in opposite directions but use the identical substitution core: + +- `converter._convert_block_call` builds an OUTER -> PARAM remap, turning a + Block instance's body into a reusable `BlockDecl` body. +- `_block_flatten._inline_call` builds the inverse PARAM -> OUTER remap, + inlining a `BlockDecl` body at its call site for non-Guppy codegens. + +Unifying both into one `substitute_stmt` eliminates the +fix-one-forget-the-mirror bug class: two parallel copies previously +drifted, leaving PermuteOp and BlockCall substitution gaps. + +`substitute_stmt` is **fully recursive** over every node that can carry a +`SlotRef` / `BitRef` / expression, including the spots the pre-5e +allocator-name-only substitution silently skipped (`MeasureOp.results`, +`AssignOp`, `PrintOp`, `ReturnOp`, conditions, `GateOp.params`, +`ForStmt` bounds, nested `BlockCall`). + +Name-level refs (`BarrierOp`, bare `PermuteOp`, `PrepareOp(slots=None)`, +`VarExpr`, str `ReturnOp`/`AssignOp` targets) cannot express a partial +(single-qubit / bundle / single-bit) binding; touching a partially-bound +outer name raises `BodySubstitutionError` rather than silently leaking it. +""" + +from __future__ import annotations + +import re +from typing import TYPE_CHECKING + +from pecos.slr.ast.nodes import ( + AllocatorArg, + AssignOp, + BarrierOp, + BinaryExpr, + BitBundleArg, + BitExpr, + BitRef, + BlockCall, + CommentOp, + ForStmt, + GateOp, + IfStmt, + LiteralExpr, + MeasureOp, + ParallelBlock, + PermuteOp, + PrepareOp, + PrintOp, + QubitBundleArg, + RepeatStmt, + ReturnOp, + SingleBitArg, + SingleQubitArg, + SlotRef, + UnaryExpr, + VarExpr, + WhileStmt, +) + +if TYPE_CHECKING: + from pecos.slr.ast.nodes import BlockArg, Expression, Statement + +# A PermuteOp source/target string is either a bare name or `name[idx]`. +_PERMUTE_REF_RE = re.compile(r"([A-Za-z_]\w*)(?:\[(\d+)\])?$") +# Leading identifier token of an arbitrary (possibly unparseable) ref, used +# to decide reject-on-partial WITHOUT a loose substring match -- `"q" in +# "sq[0:2]"` would otherwise falsely reject an unrelated `sq` ref. +_LEADING_IDENT_RE = re.compile(r"[A-Za-z_]\w*") + + +class BodySubstitutionError(ValueError): + """A body reference cannot be substituted (e.g. partial-binding name use). + + Subclasses `ValueError` so existing call sites/tests that expect a + `ValueError` from the old per-helper substitution keep working. + """ + + +class BodyRemap: + """Slot/bit/whole-name remap for BlockDecl body rewriting. + + Tables map a source `(name, index)` to a destination `(name, index)`. + `whole_alloc` records names that are bound *completely* (a whole-`QReg` + binding); only those may be renamed at the name level. Any name that + appears in a per-slot/per-bit binding (single qubit, bundle, single + bit) is recorded as *partial*: name-level use of it is rejected. + """ + + def __init__(self) -> None: + self._slot: dict[tuple[str, int], tuple[str, int]] = {} + self._bit: dict[tuple[str, int], tuple[str, int]] = {} + self._whole_alloc: dict[str, str] = {} + self._partial_names: set[str] = set() + + # ---- builders ---- + # + # An outer name may be bound in exactly ONE mode -- whole xor partial -- + # and a whole binding maps exactly one src. Conflicting builder calls + # (same name bound both whole and partial, or whole-bound twice) are an + # input-aliasing error: reject at construction time rather than letting + # whole silently win at lookup. This protects the QASM/flatten path too, + # where Guppy's own linearity alias check never runs. + + def _reject_conflict(self, name: str, *, mode: str) -> None: + if name in self._whole_alloc: + msg = ( + f"BodyRemap: outer name {name!r} is already bound whole; " + f"cannot also bind it {mode} (input aliasing)" + ) + raise BodySubstitutionError(msg) + if name in self._partial_names and mode == "whole": + msg = ( + f"BodyRemap: outer name {name!r} is already bound partially; " + f"cannot also bind it whole (input aliasing)" + ) + raise BodySubstitutionError(msg) + + def add_whole_alloc(self, src: str, dst: str, size: int) -> None: + """Bind a whole QReg `src` (size N) to `dst`, identity per-slot.""" + self._reject_conflict(src, mode="whole") + self._whole_alloc[src] = dst + for i in range(size): + self._slot[(src, i)] = (dst, i) + + def add_slot(self, src: tuple[str, int], dst: tuple[str, int]) -> None: + """Bind one outer qubit slot; marks `src` allocator partial. + + A repeated exact `src` slot is rejected: silently overwriting it + would drop the earlier binding and corrupt the body rewrite -- e.g. + a `[q[0], q[0]]` bundle, or two single-qubit inputs aliased to the + same outer slot. Qubit aliasing is invalid anyway (no-cloning); + fail loudly here so it cannot reach codegen. + """ + self._reject_conflict(src[0], mode="partial (per-slot)") + if src in self._slot: + msg = ( + f"BodyRemap: outer qubit slot {src!r} is already bound " + f"(to {self._slot[src]!r}); a qubit cannot be aliased to " + "two block-input positions (no-cloning)" + ) + raise BodySubstitutionError(msg) + self._slot[src] = dst + self._partial_names.add(src[0]) + + def add_bit(self, src: tuple[str, int], dst: tuple[str, int]) -> None: + """Bind one outer classical bit; marks `src` register partial. + + A repeated exact `src` bit is rejected for the same reason as + `add_slot`: a second binding would silently overwrite the first + and lose body references to it during substitution. + """ + self._reject_conflict(src[0], mode="partial (per-bit)") + if src in self._bit: + msg = ( + f"BodyRemap: outer bit {src!r} is already bound " + f"(to {self._bit[src]!r}); the same outer bit cannot back " + "two block-input positions (lossy substitution)" + ) + raise BodySubstitutionError(msg) + self._bit[src] = dst + self._partial_names.add(src[0]) + + # ---- lookups ---- + + def slot(self, ref: SlotRef) -> SlotRef: + dst = self._slot.get((ref.allocator, ref.index)) + if dst is None: + return ref + return SlotRef(allocator=dst[0], index=dst[1], location=ref.location) + + def bit(self, ref: BitRef) -> BitRef: + dst = self._bit.get((ref.register, ref.index)) + if dst is None: + return ref + return BitRef(register=dst[0], index=dst[1], location=ref.location) + + def whole_name(self, name: str, *, context: str) -> str: + """Rename a whole-allocator/register name; reject if partially bound. + + Unmapped names pass through unchanged (they reference allocators not + bound by any input -- a Block-local register, say). + """ + if name in self._whole_alloc: + return self._whole_alloc[name] + if name in self._partial_names: + msg = ( + f"{context} references {name!r}, which is only partially bound " + f"(a single-qubit / bundle / single-bit input). A whole-name " + f"reference cannot express that binding -- pass the whole " + f"register, or restructure the Block so the body does not use " + f"{name!r} by bare name." + ) + raise BodySubstitutionError(msg) + return name + + +def substitute_stmt(stmt: Statement, remap: BodyRemap) -> Statement: + """Return `stmt` with every slot/bit/expression reference remapped.""" + if isinstance(stmt, GateOp): + return GateOp( + gate=stmt.gate, + targets=tuple(remap.slot(t) for t in stmt.targets), + params=tuple(_sub_expr(p, remap) for p in stmt.params), + location=stmt.location, + ) + if isinstance(stmt, MeasureOp): + return MeasureOp( + targets=tuple(remap.slot(t) for t in stmt.targets), + results=tuple(remap.bit(r) for r in stmt.results), + location=stmt.location, + ) + if isinstance(stmt, PrepareOp): + return _sub_prepare(stmt, remap) + if isinstance(stmt, BarrierOp): + return BarrierOp( + allocators=tuple(remap.whole_name(a, context="Barrier") for a in stmt.allocators), + location=stmt.location, + ) + if isinstance(stmt, AssignOp): + target = stmt.target + new_target = ( + remap.bit(target) if isinstance(target, BitRef) else remap.whole_name(target, context="assignment target") + ) + return AssignOp( + target=new_target, + value=_sub_expr(stmt.value, remap), + location=stmt.location, + ) + if isinstance(stmt, ReturnOp): + # Substitution remaps names in place (1:1, order/count + # preserved), so the parallel `value_kinds` provenance still + # aligns and MUST be carried (dropping it would re-introduce + # the CReg/QReg name-collision miscompile inside a + # substituted BlockCall body). + return ReturnOp( + values=tuple( + _sub_expr(v, remap) if not isinstance(v, str) else remap.whole_name(v, context="Return value") + for v in stmt.values + ), + value_kinds=stmt.value_kinds, + location=stmt.location, + ) + if isinstance(stmt, PrintOp): + value = stmt.value + new_value = remap.bit(value) if isinstance(value, BitRef) else remap.whole_name(value, context="Print value") + return PrintOp( + value=new_value, + tag=stmt.tag, + namespace=stmt.namespace, + location=stmt.location, + ) + if isinstance(stmt, IfStmt): + return IfStmt( + condition=_sub_expr(stmt.condition, remap), + then_body=tuple(substitute_stmt(s, remap) for s in stmt.then_body), + else_body=tuple(substitute_stmt(s, remap) for s in stmt.else_body), + location=stmt.location, + ) + if isinstance(stmt, WhileStmt): + return WhileStmt( + condition=_sub_expr(stmt.condition, remap), + body=tuple(substitute_stmt(s, remap) for s in stmt.body), + location=stmt.location, + ) + if isinstance(stmt, ForStmt): + return ForStmt( + variable=stmt.variable, + start=_sub_expr(stmt.start, remap), + stop=_sub_expr(stmt.stop, remap), + step=None if stmt.step is None else _sub_expr(stmt.step, remap), + body=tuple(substitute_stmt(s, remap) for s in stmt.body), + location=stmt.location, + ) + if isinstance(stmt, RepeatStmt): + return RepeatStmt( + count=stmt.count, # plain int -- no refs + body=tuple(substitute_stmt(s, remap) for s in stmt.body), + location=stmt.location, + ) + if isinstance(stmt, ParallelBlock): + return ParallelBlock( + body=tuple(substitute_stmt(s, remap) for s in stmt.body), + location=stmt.location, + ) + if isinstance(stmt, PermuteOp): + return PermuteOp( + sources=tuple(_sub_permute_ref(r, remap) for r in stmt.sources), + targets=tuple(_sub_permute_ref(r, remap) for r in stmt.targets), + add_comment=stmt.add_comment, + whole_register=stmt.whole_register, + location=stmt.location, + ) + if isinstance(stmt, BlockCall): + return BlockCall( + callee=stmt.callee, + arg_bindings=tuple(_sub_block_arg(a, remap) for a in stmt.arg_bindings), + out_bindings=tuple(_sub_block_arg(a, remap) for a in stmt.out_bindings), + location=stmt.location, + ) + # CommentOp + anything else carrying no slot/bit/expr ref: pass through. + if not isinstance(stmt, CommentOp): # defensive: surface unhandled nodes + # Unknown statement types are passed through unchanged, matching the + # pre-5e behavior; if a new ref-bearing node is added it must be + # wired in here (the iter-5e plan's recurring-risk note). + pass + return stmt + + +def _sub_expr(expr: Expression, remap: BodyRemap) -> Expression: + """Recurse an expression, remapping any BitRef/VarExpr it contains.""" + if isinstance(expr, LiteralExpr): + return expr + if isinstance(expr, VarExpr): + return VarExpr( + name=remap.whole_name(expr.name, context="variable expression"), + location=expr.location, + ) + if isinstance(expr, BitExpr): + return BitExpr(ref=remap.bit(expr.ref), location=expr.location) + if isinstance(expr, BinaryExpr): + return BinaryExpr( + op=expr.op, + left=_sub_expr(expr.left, remap), + right=_sub_expr(expr.right, remap), + location=expr.location, + ) + if isinstance(expr, UnaryExpr): + return UnaryExpr( + op=expr.op, + operand=_sub_expr(expr.operand, remap), + location=expr.location, + ) + return expr + + +def _sub_prepare(stmt: PrepareOp, remap: BodyRemap) -> PrepareOp: + """Remap a PrepareOp. + + `slots=None` (prepare_all) is a whole-register op -> name-level rename + (reject if partially bound). `slots=(...)` is per-slot: remap each + `(allocator, i)`; all must land in the same destination allocator. + """ + if stmt.slots is None: + return PrepareOp( + allocator=remap.whole_name(stmt.allocator, context="Prepare-all"), + slots=None, + basis=stmt.basis, + location=stmt.location, + ) + remapped = [remap.slot(SlotRef(allocator=stmt.allocator, index=i)) for i in stmt.slots] + dst_allocs = {r.allocator for r in remapped} + if len(dst_allocs) > 1: + msg = ( + f"Prepare on {stmt.allocator!r} maps to multiple destination " + f"allocators {sorted(dst_allocs)} under a bundle binding; a single " + "Prepare cannot span allocators -- restructure the Block." + ) + raise BodySubstitutionError(msg) + (dst_alloc,) = dst_allocs + return PrepareOp( + allocator=dst_alloc, + slots=tuple(r.index for r in remapped), + basis=stmt.basis, + location=stmt.location, + ) + + +def _sub_permute_ref(ref: str, remap: BodyRemap) -> str: + """Remap a PermuteOp `name` or `name[idx]` string. + + `name[idx]` resolves through the per-slot table (works for whole-alloc + identity AND bundles). Bare `name` is a whole-register reference -> + name-level rename, rejected if partially bound. Unparseable refs that + mention a partially-bound name are rejected rather than silently leaked. + """ + match = _PERMUTE_REF_RE.fullmatch(ref) + if match is None: + # Compare the ref's LEADING identifier token exactly against the + # partial-name set (not a substring scan) so `sq[0:2]` is not + # falsely rejected when `q` is partially bound. + lead = _LEADING_IDENT_RE.match(ref) + if lead is not None and lead.group(0) in remap._partial_names: # noqa: SLF001 -- same module + base = lead.group(0) + msg = ( + f"Cannot substitute PermuteOp ref {ref!r}: unsupported " + f"ref form whose base name {base!r} is partially bound" + ) + raise BodySubstitutionError(msg) + return ref + name, idx = match.group(1), match.group(2) + if idx is None: + return remap.whole_name(name, context="Permute") + new = remap.slot(SlotRef(allocator=name, index=int(idx))) + return f"{new.allocator}[{new.index}]" + + +def _sub_block_arg(arg: BlockArg, remap: BodyRemap) -> BlockArg: + """Remap a nested BlockCall's arg/out BlockArg through the remap.""" + if isinstance(arg, AllocatorArg): + return AllocatorArg( + name=remap.whole_name(arg.name, context="nested BlockCall AllocatorArg"), + location=arg.location, + ) + if isinstance(arg, SingleQubitArg): + return SingleQubitArg(slot=remap.slot(arg.slot), location=arg.location) + if isinstance(arg, SingleBitArg): + return SingleBitArg(bit=remap.bit(arg.bit), location=arg.location) + if isinstance(arg, QubitBundleArg): + return QubitBundleArg( + slots=tuple(remap.slot(s) for s in arg.slots), + location=arg.location, + ) + if isinstance(arg, BitBundleArg): + return BitBundleArg( + bits=tuple(remap.bit(b) for b in arg.bits), + location=arg.location, + ) + return arg diff --git a/python/quantum-pecos/src/pecos/slr/ast/analysis/data_flow.py b/python/quantum-pecos/src/pecos/slr/ast/analysis/data_flow.py index d9da3b80e..1372be80c 100644 --- a/python/quantum-pecos/src/pecos/slr/ast/analysis/data_flow.py +++ b/python/quantum-pecos/src/pecos/slr/ast/analysis/data_flow.py @@ -104,7 +104,7 @@ def add_use( self.consumed_at.append(position) def add_replacement(self, position: int) -> None: - """Mark that this value is replaced at a position (e.g., Prep).""" + """Mark that this value is replaced at a position (e.g., PZ).""" self.replaced_at.append(position) def has_use_after_consumption(self) -> bool: diff --git a/python/quantum-pecos/src/pecos/slr/ast/analysis/qubit_state_validator.py b/python/quantum-pecos/src/pecos/slr/ast/analysis/qubit_state_validator.py index 4c663ba6e..75f8e3066 100644 --- a/python/quantum-pecos/src/pecos/slr/ast/analysis/qubit_state_validator.py +++ b/python/quantum-pecos/src/pecos/slr/ast/analysis/qubit_state_validator.py @@ -78,7 +78,7 @@ def message(self) -> str: """Human-readable error message.""" return ( f"Gate '{self.gate.name}' applied to unprepared qubit " - f"{self.allocator}[{self.index}]. Call Prep() before applying gates." + f"{self.allocator}[{self.index}]. Call PZ() before applying gates." ) def __str__(self) -> str: @@ -106,7 +106,7 @@ def get_state(self, allocator: str, index: int) -> ValidationSlotState: return self.slot_states.get((allocator, index), ValidationSlotState.UNPREPARED) def mark_prepared(self, allocator: str, index: int) -> None: - """Mark a slot as prepared (after Prep operation).""" + """Mark a slot as prepared (after PZ operation).""" self.slot_states[(allocator, index)] = ValidationSlotState.PREPARED def mark_unprepared(self, allocator: str, index: int) -> None: diff --git a/python/quantum-pecos/src/pecos/slr/ast/analysis/resource_counter.py b/python/quantum-pecos/src/pecos/slr/ast/analysis/resource_counter.py index d22ae9e29..9ef9ce045 100644 --- a/python/quantum-pecos/src/pecos/slr/ast/analysis/resource_counter.py +++ b/python/quantum-pecos/src/pecos/slr/ast/analysis/resource_counter.py @@ -183,7 +183,7 @@ def _count_prepare(self, node: PrepareOp) -> None: if node.slots is not None: self.result.preparation_count += len(node.slots) else: - # Prep all - would need allocator info to count exactly + # PZ all - would need allocator info to count exactly self.result.preparation_count += 1 def _count_if(self, node: IfStmt) -> None: diff --git a/python/quantum-pecos/src/pecos/slr/ast/codegen/_block_flatten.py b/python/quantum-pecos/src/pecos/slr/ast/codegen/_block_flatten.py new file mode 100644 index 000000000..2d8955016 --- /dev/null +++ b/python/quantum-pecos/src/pecos/slr/ast/codegen/_block_flatten.py @@ -0,0 +1,192 @@ +# Copyright 2026 The PECOS Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +"""BlockDecl/BlockCall flattening for non-Guppy codegens. + +Non-Guppy codegens (qasm, qir, stim, quantum_circuit) cannot represent +reusable functions, so a `BlockCall` is inlined at its call site by +substituting each input parameter name with the corresponding +`arg_binding` outer-scope allocator name. + +The Guppy emitter does NOT use this pass: it lowers `BlockDecl` to +`@guppy def` and `BlockCall` to a packed-array call. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pecos.slr.ast._block_substitution import BodyRemap, substitute_stmt +from pecos.slr.ast.nodes import ( + AllocatorArg, + BitBundleArg, + BlockCall, + ForStmt, + IfStmt, + ParallelBlock, + Program, + QubitBundleArg, + RepeatStmt, + SingleBitArg, + SingleQubitArg, + WhileStmt, +) + +if TYPE_CHECKING: + from pecos.slr.ast.nodes import BlockDecl, Statement + + +def validate_unique_block_decl_names(program: Program) -> None: + """Raise ValueError if any BlockDecl name appears more than once. + + Shared precondition check: both the Guppy emitter and the non-Guppy + flatten pass require globally-unique BlockDecl + names. Keeping the check in one place ensures the contract stays in sync + across codegens. + """ + seen: set[str] = set() + for decl in program.block_decls: + if decl.name in seen: + msg = f"Duplicate BlockDecl name {decl.name!r}" + raise ValueError(msg) + seen.add(decl.name) + + +def flatten_block_calls(program: Program) -> Program: + """Return a new Program with every BlockCall inlined and no BlockDecls left. + + The substitution rule maps each `BlockDecl` input parameter name to the + typed `BlockCall.arg_bindings` BlockArg. Currently only + `AllocatorArg` is supported; richer BlockArg shapes raise + `NotImplementedError`. Quantum-only for now. + """ + if not program.block_decls: + return program + + validate_unique_block_decl_names(program) + decls = {decl.name: decl for decl in program.block_decls} + new_body = _flatten_stmts(program.body, decls) + return Program( + name=program.name, + declarations=program.declarations, + body=new_body, + returns=program.returns, + allocator=program.allocator, + block_decls=(), + ) + + +def _flatten_stmts(body: tuple[Statement, ...], decls: dict[str, BlockDecl]) -> tuple[Statement, ...]: + out: list[Statement] = [] + for stmt in body: + if isinstance(stmt, BlockCall): + inlined = _inline_call(stmt, decls) + out.extend(_flatten_stmts(inlined, decls)) + continue + + if isinstance(stmt, IfStmt): + out.append( + IfStmt( + condition=stmt.condition, + then_body=_flatten_stmts(stmt.then_body, decls), + else_body=_flatten_stmts(stmt.else_body, decls), + location=stmt.location, + ), + ) + continue + + if isinstance(stmt, RepeatStmt): + out.append( + RepeatStmt(count=stmt.count, body=_flatten_stmts(stmt.body, decls), location=stmt.location), + ) + continue + + if isinstance(stmt, ForStmt): + out.append( + ForStmt( + variable=stmt.variable, + start=stmt.start, + stop=stmt.stop, + step=stmt.step, + body=_flatten_stmts(stmt.body, decls), + location=stmt.location, + ), + ) + continue + + if isinstance(stmt, WhileStmt): + out.append( + WhileStmt( + condition=stmt.condition, + body=_flatten_stmts(stmt.body, decls), + location=stmt.location, + ), + ) + continue + + if isinstance(stmt, ParallelBlock): + out.append(ParallelBlock(body=_flatten_stmts(stmt.body, decls), location=stmt.location)) + continue + + out.append(stmt) + return tuple(out) + + +def _inline_call(call: BlockCall, decls: dict[str, BlockDecl]) -> tuple[Statement, ...]: + decl = decls.get(call.callee) + if decl is None: + msg = f"BlockCall references undefined block {call.callee!r}" + raise ValueError(msg) + if len(call.arg_bindings) != len(decl.inputs): + msg = ( + f"BlockCall {call.callee!r}: {len(call.arg_bindings)} arg_bindings but " + f"BlockDecl declares {len(decl.inputs)} inputs" + ) + raise ValueError(msg) + # 5e.2: AllocatorArg / SingleQubitArg / SingleBitArg / QubitBundleArg all + # inline. BitBundleArg is the only still-deferred shape -- reject it (in + # BOTH arg and out bindings; silently allowing it in out_bindings would + # be a silent-fallback). The `test_bitbundle_*_rejected` lock-ins depend + # on this. + for position, args in (("arg", call.arg_bindings), ("out", call.out_bindings)): + for arg in args: + if isinstance(arg, BitBundleArg): + msg = ( + f"Flatten pass does not yet support BlockArg " + f"{type(arg).__name__} in {position}_bindings of " + f"{call.callee!r} (BitBundleArg is still deferred)" + ) + raise NotImplementedError(msg) + # Build the PARAM -> OUTER BodyRemap (flatten inlines a BlockDecl body -- + # which references param names -- at the call site, which uses the outer + # binding). converter builds the inverse OUTER -> PARAM remap; both use + # the shared `substitute_stmt` (5e.1 unification -- one substitution core, + # no fix-one-forget-the-mirror drift). Only arg_bindings drive the body + # rewrite; out_bindings do not contribute (the inlined body writes the + # outer slots directly -- there is no separate return-unpack in flatten). + remap = BodyRemap() + for inp, arg in zip(decl.inputs, call.arg_bindings, strict=True): + if isinstance(arg, AllocatorArg): + # AllocatorArg inputs are always array[qubit, N]; the emitter + # validates this, so flatten can trust inp.type_expr.size. + size = getattr(inp.type_expr, "size", 0) + remap.add_whole_alloc(inp.name, arg.name, size) + elif isinstance(arg, SingleQubitArg): + remap.add_slot((inp.name, 0), (arg.slot.allocator, arg.slot.index)) + elif isinstance(arg, SingleBitArg): + remap.add_bit((inp.name, 0), (arg.bit.register, arg.bit.index)) + elif isinstance(arg, QubitBundleArg): + for k, slot in enumerate(arg.slots): + remap.add_slot((inp.name, k), (slot.allocator, slot.index)) + else: # BitBundleArg already rejected above; defensive + msg = f"Flatten pass: unexpected BlockArg {type(arg).__name__} for input {inp.name!r} of {call.callee!r}" + raise NotImplementedError(msg) + return tuple(substitute_stmt(stmt, remap) for stmt in decl.body) diff --git a/python/quantum-pecos/src/pecos/slr/ast/codegen/_prep_tail.py b/python/quantum-pecos/src/pecos/slr/ast/codegen/_prep_tail.py new file mode 100644 index 000000000..bdd1b151e --- /dev/null +++ b/python/quantum-pecos/src/pecos/slr/ast/codegen/_prep_tail.py @@ -0,0 +1,48 @@ +# Copyright 2026 The PECOS Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License.You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +"""Canonical prep-basis lowering, single source for every codegen. + +A prep gate is a Z-reset (|0>) followed by a fixed Clifford tail. The +tail is expressed as `GateKind`s so each backend reuses its existing +`GateKind -> name` map (`GATE_TO_QIR/STIM/QC/QASM`, guppy +`FUNCTIONAL_GATES`); there is exactly ONE tail table, not six. Pinned +by review (states experimentally verified; +S = diag(1, i)): the uniform symmetric model X-basis = H, +phase-flip via trailing Z; Y-basis = H then S(+)/Sdg(-). +""" + +from __future__ import annotations + +from pecos.slr.ast.nodes import GateKind + +# basis -> Clifford tail applied AFTER a |0> reset. +PREP_TAIL: dict[str, tuple[GateKind, ...]] = { + "PZ": (), # |0> + "PNZ": (GateKind.X,), # |1> + "PX": (GateKind.H,), # |+> + "PNX": (GateKind.H, GateKind.Z), # |-> + "PY": (GateKind.H, GateKind.SZ), # |+i> + "PNY": (GateKind.H, GateKind.SZdg), # |-i> +} + + +def prep_tail(basis: str) -> tuple[GateKind, ...]: + """Tail for `basis`, or fail LOUD on an unknown basis. + + An unknown basis silently lowering as a bare |0> reset would be + exactly the silent-miscompile class this lowering exists to kill. + """ + try: + return PREP_TAIL[basis] + except KeyError: + msg = f"codegen: unknown prep basis {basis!r} (expected one of {sorted(PREP_TAIL)})." + raise NotImplementedError(msg) from None diff --git a/python/quantum-pecos/src/pecos/slr/ast/codegen/entry_wrapper.py b/python/quantum-pecos/src/pecos/slr/ast/codegen/entry_wrapper.py new file mode 100644 index 000000000..4a45b48f1 --- /dev/null +++ b/python/quantum-pecos/src/pecos/slr/ast/codegen/entry_wrapper.py @@ -0,0 +1,226 @@ +# Copyright 2026 The PECOS Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +"""No-arg `entry()` wrapper for AST-emitted parameterized `main(...)`. + +The AST emitter produces `main(q: array[qubit, N] @ owned, ...)`. Downstream +HUGR consumers (`pecos.Hugr(bytes)`, `pecos_rslib.HugrProgram`, the Selene +runtime) require a no-arg entrypoint, matching the legacy IR generator's +shape. This module builds that wrapper by mirroring the same return-shape +logic the emitter uses, so the wrapper signature matches main's exactly. + +Two modes match `AstToGuppy._return_type`: +- Explicit `Return(...)` -> pass through main's return value unchanged. +- No `Return(...)` -> `entry() -> None` and discard (the output model no + longer implicitly returns result-flagged CRegs). +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from pecos.slr.ast import ( + AllocatorDecl, + BitExpr, + ForStmt, + IfStmt, + LiteralExpr, + MeasureOp, + ParallelBlock, + RegisterDecl, + RepeatStmt, + ReturnOp, + WhileStmt, +) + +if TYPE_CHECKING: + from collections.abc import Iterable + + from pecos.slr.ast import Expression, Program, Statement + + +@dataclass(frozen=True) +class EntryWrapperInfo: + """Metadata extracted from the AST program for building the wrapper. + + `all_creg_sizes` mirrors the emitter's `context.registers` lookup view + used by `AstToGuppy._return_value_type`: every declared CReg plus every + inline-Measure-introduced CReg. Needed so an explicit `Return(...)` + referencing a declared or inline CReg resolves the same way the emitter + does (not as `ValueError`). + """ + + allocator_sizes: dict[str, int] + explicit_return: ReturnOp | None + all_creg_sizes: dict[str, int] + + +# Reserved namespace for the opt-in returned-CReg result tags. Private +# (double-underscore) so it cannot collide with user `Print(..., namespace= +# "result")` -> `result.` outputs. Single source of truth shared with +# `_selene_harness` so the emit/read sides never drift. +RETURN_TAG_NAMESPACE = "__pecos_return" + + +def build_no_arg_entry_wrapper( + program: Program, + *, + emit_return_result_tags: bool = False, +) -> tuple[str, EntryWrapperInfo]: + """Generate the wrapper source and return the metadata used to build it. + + Args: + program: The AST program to wrap. + emit_return_result_tags: **Opt-in, test-harness only.** When True and + the program has an explicit `Return(...)`, the wrapper + destructures main's return and emits + `result("__pecos_return.", )` per returned CReg + instead of `return `, and is typed `-> None`. This makes + Selene key the outputs by CReg name (immune to internal, + non-returned measurements -- e.g. RUS verify). Default False + keeps the production wrapper byte-identical (the `.hugr()` / + raw-`measurement_N` consumers must not change). + + Returns: + A `(source, info)` tuple. `source` is the Guppy snippet defining + `entry()`; concatenate it after the main source. `info` exposes the + allocator sizes and explicit Return (if any) that the caller may + need (e.g., Selene's measurement-key generation). + """ + info = _collect_info(program) + source = _render_wrapper(info, emit_return_result_tags=emit_return_result_tags) + return source, info + + +def _collect_info(program: Program) -> EntryWrapperInfo: + allocator_sizes: dict[str, int] = {} + for decl in getattr(program, "declarations", ()): + if isinstance(decl, AllocatorDecl) and decl.parent is None: + allocator_sizes.setdefault(decl.name, decl.capacity) + top = getattr(program, "allocator", None) + if isinstance(top, AllocatorDecl) and top.parent is None: + allocator_sizes.setdefault(top.name, top.capacity) + + declared: dict[str, RegisterDecl] = {} + for decl in getattr(program, "declarations", ()): + if isinstance(decl, RegisterDecl): + declared[decl.name] = decl + + body = getattr(program, "body", ()) + inline_max: dict[str, int] = {} + _walk_for_measure_results(body, declared, inline_max) + + all_creg_sizes: dict[str, int] = {name: decl.size for name, decl in declared.items()} + for name, max_index in inline_max.items(): + all_creg_sizes[name] = max_index + 1 + + explicit_return = body[-1] if body and isinstance(body[-1], ReturnOp) else None + + return EntryWrapperInfo( + allocator_sizes=allocator_sizes, + explicit_return=explicit_return, + all_creg_sizes=all_creg_sizes, + ) + + +def _walk_for_measure_results( + body: Iterable[Statement], + declared: dict[str, RegisterDecl], + inline_max: dict[str, int], +) -> None: + for stmt in body: + if isinstance(stmt, MeasureOp): + for ref in stmt.results: + if ref.register not in declared: + inline_max[ref.register] = max(inline_max.get(ref.register, -1), ref.index) + elif isinstance(stmt, IfStmt): + _walk_for_measure_results(stmt.then_body, declared, inline_max) + _walk_for_measure_results(stmt.else_body, declared, inline_max) + elif isinstance(stmt, (RepeatStmt, ForStmt, WhileStmt, ParallelBlock)): + _walk_for_measure_results(stmt.body, declared, inline_max) + + +def _render_wrapper(info: EntryWrapperInfo, *, emit_return_result_tags: bool = False) -> str: + body_lines: list[str] = [ + f" {name} = array(qubit() for _ in range({size}))" for name, size in info.allocator_sizes.items() + ] + call_args = ", ".join(info.allocator_sizes.keys()) + call_expr = f"main({call_args})" + + if emit_return_result_tags and info.explicit_return is not None: + # Opt-in: destructure main's return and tag each returned CReg + # by name so Selene keys outputs by name, not positional + # measurement_N (which counts internal measurements too). + targets = [v if isinstance(v, str) else getattr(v, "name", None) for v in info.explicit_return.values] + if any(t is None for t in targets): + msg = f"emit_return_result_tags supports only named return values, got {info.explicit_return.values!r}" + raise ValueError(msg) + lhs = targets[0] if len(targets) == 1 else ", ".join(targets) + body_lines.append(f" {lhs} = {call_expr}") + body_lines.extend( + f' result("{RETURN_TAG_NAMESPACE}.{name}", {name})' for name in targets if name in info.all_creg_sizes + ) + return_ann = "None" + elif info.explicit_return is not None: + body_lines.append(f" return {call_expr}") + return_ann = _explicit_return_type(info) + else: + body_lines.append(f" {call_expr}") + return_ann = "None" + + body = "\n".join(body_lines) if body_lines else " pass" + return f"\n\n@guppy\ndef entry() -> {return_ann}:\n{body}\n" + + +def _explicit_return_type(info: EntryWrapperInfo) -> str: + assert info.explicit_return is not None # noqa: S101 + types = [ + _return_value_type(value, info.allocator_sizes, info.all_creg_sizes) for value in info.explicit_return.values + ] + return _tuple_type(types) + + +def _return_value_type(value: Expression | str, allocator_sizes: dict[str, int], creg_sizes: dict[str, int]) -> str: + if isinstance(value, str): + if value in allocator_sizes: + return f"array[qubit, {allocator_sizes[value]}]" + if value in creg_sizes: + return f"array[bool, {creg_sizes[value]}]" + msg = f"Unsupported Guppy return value {value!r}" + raise ValueError(msg) + if isinstance(value, BitExpr): + return "bool" + if isinstance(value, LiteralExpr) and isinstance(value.value, bool): + return "bool" + if isinstance(value, LiteralExpr) and isinstance(value.value, int): + return "int" + msg = f"Unsupported Guppy return expression {value!r}" + raise ValueError(msg) + + +def _tuple_type(types: list[str]) -> str: + """Mirror `AstToGuppy._tuple_type`: empty -> None, single -> bare, multi -> tuple[...].""" + if not types: + return "None" + if len(types) == 1: + return types[0] + return "tuple[" + ", ".join(types) + "]" + + +def truncate_source_for_error(source: str, max_lines: int = 80) -> str: + """Truncate generated Guppy source for inclusion in an error message.""" + lines = source.splitlines() + if len(lines) <= max_lines: + return source + head = lines[: max_lines - 10] + tail = lines[-10:] + return "\n".join([*head, f"... ({len(lines) - max_lines} lines elided) ...", *tail]) diff --git a/python/quantum-pecos/src/pecos/slr/ast/codegen/guppy.py b/python/quantum-pecos/src/pecos/slr/ast/codegen/guppy.py index a9013e44a..948d2616f 100644 --- a/python/quantum-pecos/src/pecos/slr/ast/codegen/guppy.py +++ b/python/quantum-pecos/src/pecos/slr/ast/codegen/guppy.py @@ -11,111 +11,213 @@ """AST to Guppy Python code generator. -This module provides a visitor that transforms AST nodes into Guppy Python code. -Guppy is a quantum programming language that compiles to HUGR. - -Example: - from pecos.slr.ast import slr_to_ast, Program - from pecos.slr.ast.codegen import AstToGuppy - - # Convert SLR to AST - ast = slr_to_ast(slr_program) - - # Generate Guppy code - generator = AstToGuppy() - guppy_code = generator.generate(ast) +This emitter lowers SLR's allocator-style AST to Guppy source. Guppy has +linear qubit ownership, so quantum arrays are unpacked to stable local qubit +variables at function entry and the Guppy-only `GuppyLinearityState` tracks +which local owns each logical slot while recursive descent emits statements. """ from __future__ import annotations +import re +from collections.abc import Callable from dataclasses import dataclass, field -from typing import TYPE_CHECKING - +from typing import TYPE_CHECKING, cast + +from pecos.slr.ast.codegen._block_flatten import validate_unique_block_decl_names +from pecos.slr.ast.codegen._prep_tail import prep_tail +from pecos.slr.ast.codegen.guppy_linearity import ( + GuppyLinearityState, + LinearityError, + Slot, + SlotState, +) from pecos.slr.ast.nodes import ( + AllocatorArg, AllocatorDecl, + ArrayTypeExpr, + BarrierOp, BinaryExpr, BinaryOp, BitExpr, BitRef, + BitTypeExpr, + BlockCall, ForStmt, GateKind, + GateOp, IfStmt, LiteralExpr, MeasureOp, ParallelBlock, + PermuteOp, + PrepareOp, + QubitBundleArg, + QubitTypeExpr, RegisterDecl, RepeatStmt, + ResourceEffect, + ReturnOp, + SingleBitArg, + SingleQubitArg, UnaryExpr, UnaryOp, VarExpr, WhileStmt, ) -from pecos.slr.ast.visitor import BaseVisitor if TYPE_CHECKING: + from collections.abc import Iterator + from pecos.slr.ast.nodes import ( AssignOp, - BarrierOp, + BlockDecl, + BlockInput, CommentOp, Expression, - GateOp, - PermuteOp, - PrepareOp, + PrintOp, Program, - ReturnOp, SlotRef, + Statement, ) -# Mapping from AST GateKind to Guppy function names -GATE_TO_GUPPY: dict[GateKind, str] = { - # Single-qubit Paulis - GateKind.X: "quantum.x", - GateKind.Y: "quantum.y", - GateKind.Z: "quantum.z", - # Hadamard - GateKind.H: "quantum.h", - # Phase gates - GateKind.S: "quantum.s", - GateKind.Sdg: "quantum.sdg", - GateKind.T: "quantum.t", - GateKind.Tdg: "quantum.tdg", - # Square root gates - GateKind.SX: "quantum.sx", - GateKind.SY: "quantum.sy", - GateKind.SZ: "quantum.sz", - GateKind.SXdg: "quantum.sxdg", - GateKind.SYdg: "quantum.sydg", - GateKind.SZdg: "quantum.szdg", - # Rotation gates - GateKind.RX: "quantum.rx", - GateKind.RY: "quantum.ry", - GateKind.RZ: "quantum.rz", - # Two-qubit gates - GateKind.CX: "quantum.cx", - GateKind.CY: "quantum.cy", - GateKind.CZ: "quantum.cz", - GateKind.CH: "quantum.ch", - # Two-qubit rotation gates - GateKind.SXX: "quantum.sxx", - GateKind.SYY: "quantum.syy", - GateKind.SZZ: "quantum.szz", - GateKind.SXXdg: "quantum.sxxdg", - GateKind.SYYdg: "quantum.syydg", - GateKind.SZZdg: "quantum.szzdg", - GateKind.RZZ: "quantum.rzz", - # Controlled rotation gates - GateKind.CRX: "quantum.crx", - GateKind.CRY: "quantum.cry", - GateKind.CRZ: "quantum.crz", - # Face rotations - GateKind.F: "quantum.f", - GateKind.Fdg: "quantum.fdg", - GateKind.F4: "quantum.f4", - GateKind.F4dg: "quantum.f4dg", +FUNCTIONAL_GATES: dict[GateKind, str] = { + GateKind.X: "x", + GateKind.Y: "y", + GateKind.Z: "z", + GateKind.H: "h", + GateKind.T: "t", + GateKind.Tdg: "tdg", + GateKind.SZ: "s", + GateKind.SZdg: "sdg", + # Guppy's `v`/`vdg` are sqrt(X) / sqrt(X)-dagger (the standard V gate). + GateKind.SX: "v", + GateKind.SXdg: "vdg", + GateKind.CX: "cx", + GateKind.CY: "cy", + GateKind.CZ: "cz", + GateKind.CH: "ch", } -# Mapping from AST BinaryOp to Python operators +# Native Guppy parameterized rotation gates: `fn(qubit..., angle)`. +# Guppy's `angle` type stores half-turns (pi radians = 1.0 half-turn), +# which is exactly `angle64.to_half_turns_signed()`, so the angle is +# emitted as `angle()` with no radians/pi conversion. +PARAMETERIZED_FUNCTIONAL_GATES: dict[GateKind, str] = { + GateKind.RX: "rx", + GateKind.RY: "ry", + GateKind.RZ: "rz", + GateKind.CRZ: "crz", +} + +FUNCTIONAL_GATE_IMPORTS = ", ".join( + sorted(set(FUNCTIONAL_GATES.values()) | set(PARAMETERIZED_FUNCTIONAL_GATES.values()) | {"reset"}), +) + +# Decomposition table for PECOS gates with no native single Guppy gate. +# Each step is (guppy_fn, qubit_idx_tuple_into_targets, angle_spec) in +# CIRCUIT order (first applied first). angle_spec is: +# None -> non-parameterized gate: fn(q...) +# float (half-turns) -> constant angle: fn(q..., angle(h)) +# Callable[[input_params], float] -> angle (half-turns) from the input +# gate's params (also half-turns) +# The 1q-Clifford sequences mirror the dual-reviewed QIR `_GATE_DECOMP`; +# the 2q sqrt-Paulis use the native `zz_phase` (= RZZ, qsystem Quantinuum +# extension); CRX/CRY conjugate the native `crz`. +_GuppyAngleSpec = None | float | Callable[[tuple[float, ...]], float] +_GuppyDecompStep = tuple[str, tuple[int, ...], _GuppyAngleSpec] +GUPPY_GATE_DECOMP: dict[GateKind, tuple[_GuppyDecompStep, ...]] = { + # ---- single-qubit Cliffords (no native sqrt-Y / face gate) ---- + GateKind.SY: (("h", (0,), None), ("x", (0,), None)), + GateKind.SYdg: (("h", (0,), None), ("z", (0,), None)), + GateKind.F: (("sdg", (0,), None), ("h", (0,), None)), + GateKind.Fdg: (("h", (0,), None), ("s", (0,), None)), + GateKind.F4: (("h", (0,), None), ("sdg", (0,), None)), + GateKind.F4dg: (("s", (0,), None), ("h", (0,), None)), + # ---- two-qubit sqrt-Paulis via native zz_phase (= RZZ) ---- + GateKind.SZZ: (("zz_phase", (0, 1), 0.5),), + GateKind.SZZdg: (("zz_phase", (0, 1), -0.5),), + GateKind.SXX: ( + ("h", (0,), None), + ("h", (1,), None), + ("zz_phase", (0, 1), 0.5), + ("h", (0,), None), + ("h", (1,), None), + ), + GateKind.SXXdg: ( + ("h", (0,), None), + ("h", (1,), None), + ("zz_phase", (0, 1), -0.5), + ("h", (0,), None), + ("h", (1,), None), + ), + GateKind.SYY: ( + ("sdg", (0,), None), + ("sdg", (1,), None), + ("h", (0,), None), + ("h", (1,), None), + ("zz_phase", (0, 1), 0.5), + ("h", (0,), None), + ("h", (1,), None), + ("s", (0,), None), + ("s", (1,), None), + ), + GateKind.SYYdg: ( + ("sdg", (0,), None), + ("sdg", (1,), None), + ("h", (0,), None), + ("h", (1,), None), + ("zz_phase", (0, 1), -0.5), + ("h", (0,), None), + ("h", (1,), None), + ("s", (0,), None), + ("s", (1,), None), + ), + # ---- parameterized two-qubit gates ---- + # RZZ is the native zz_phase with the passed-through angle. + GateKind.RZZ: (("zz_phase", (0, 1), lambda p: p[0]),), + # CRX = (I o H) . CRZ . (I o H); CRY = (I o S.H) . CRZ . (I o H.Sdg). + # `crz` is native; the passed-through angle threads into it. + GateKind.CRX: ( + ("h", (1,), None), + ("crz", (0, 1), lambda p: p[0]), + ("h", (1,), None), + ), + GateKind.CRY: ( + ("sdg", (1,), None), + ("h", (1,), None), + ("crz", (0, 1), lambda p: p[0]), + ("h", (1,), None), + ("s", (1,), None), + ), +} + +# Gate names whose decomposition uses the qsystem `zz_phase` import. +_ZZ_PHASE_GATES = frozenset( + gk for gk, steps in GUPPY_GATE_DECOMP.items() if any(fn == "zz_phase" for fn, _, _ in steps) +) + + +def _param_to_half_turns(param: object, gate_name: str) -> float: + """Resolve a user gate angle param to signed half-turns for Guppy `angle`. + + Guppy's `angle` is half-turn based (pi rad = 1.0), which is exactly + ``angle64.to_half_turns_signed()``. Requires a typed `Angle` + (`rad(...)` / `turns(...)`) -- a non-`Angle` param (bare float, or a + non-literal classical expression at a gate-param position) fails loud. + """ + from pecos.slr.angle import Angle # noqa: PLC0415 (avoid import cycle) + + if not isinstance(param, LiteralExpr) or not isinstance(param.value, Angle): + msg = ( + f"AST -> Guppy v1: parameterized gate {gate_name} requires a typed `Angle` " + f"parameter (use `rad(...)` / `turns(...)`); got {param!r}" + ) + raise GuppyCodegenError(msg) + return param.value.value.to_half_turns_signed() + + BINARY_OP_TO_PYTHON: dict[BinaryOp, str] = { BinaryOp.ADD: "+", BinaryOp.SUB: "-", @@ -134,32 +236,31 @@ BinaryOp.RSHIFT: ">>", } -# Mapping from AST UnaryOp to Python operators -UNARY_OP_TO_PYTHON: dict[UnaryOp, str] = { - UnaryOp.NOT: "not", - UnaryOp.NEG: "-", -} +BOOL_OPERAND_BINARY_OPS = {BinaryOp.AND, BinaryOp.OR, BinaryOp.XOR} +BOOL_COMPARISON_OPS = {BinaryOp.EQ, BinaryOp.NE} + + +class GuppyCodegenError(LinearityError): + """Raised when the v1 AST -> Guppy emitter rejects an unsupported construct.""" @dataclass -class CodeGenContext: - """Context for code generation.""" +class GuppyContext: + """Mutable state for one Guppy emission run.""" indent_level: int = 0 - allocators: dict[str, int] = field(default_factory=dict) # name -> capacity - allocator_parents: dict[str, str | None] = field( - default_factory=dict, - ) # name -> parent - allocator_offsets: dict[str, int] = field( - default_factory=dict, - ) # name -> offset in parent - registers: dict[str, int] = field(default_factory=dict) # name -> size - measured_slots: set[tuple[str, int]] = field( - default_factory=set, - ) # (allocator, index) - measurement_vars: list[str] = field( - default_factory=list, - ) # variable names for results + root_allocators: dict[str, int] = field(default_factory=dict) + child_allocators: set[str] = field(default_factory=set) + registers: dict[str, RegisterDecl] = field(default_factory=dict) + linearity: GuppyLinearityState | None = None + temp_counter: int = 0 + # Single namespace-wide slot -> Guppy-local table. Populated + # by `populate_slot_locals` after declarations are collected (so + # the registers + allocator names are known); read by both + # `GuppyLinearityState.from_allocators(..., slot_locals=...)` and + # `AstToGuppy._local_name` so all three sites that emit a slot + # name agree (the bug class the xfail tracked). + slot_locals: dict[Slot, str] = field(default_factory=dict) def indent(self) -> str: """Return current indentation string.""" @@ -173,117 +274,106 @@ def pop_indent(self) -> None: """Decrease indentation level.""" self.indent_level = max(0, self.indent_level - 1) - def mark_measured(self, allocator: str, index: int) -> None: - """Mark a qubit slot as consumed by measurement.""" - self.measured_slots.add((allocator, index)) - - def is_allocator_fully_consumed(self, name: str) -> bool: - """Check if all slots of an allocator have been measured.""" - if name not in self.allocators: - return False - capacity = self.allocators[name] - return all((name, i) in self.measured_slots for i in range(capacity)) - - def get_root_allocator(self, name: str) -> str: - """Get the root allocator for a given allocator name.""" - current = name - while self.allocator_parents.get(current) is not None: - current = self.allocator_parents[current] - return current - - def get_absolute_index(self, allocator: str, index: int) -> int: - """Get the absolute index in the root allocator.""" - offset = self.allocator_offsets.get(allocator, 0) - return offset + index - - -class AstToGuppy(BaseVisitor[list[str]]): - """Visitor that generates Guppy Python code from AST. + def temp(self, prefix: str) -> str: + """Return a unique temporary local name.""" + name = f"_{prefix}_{self.temp_counter}" + self.temp_counter += 1 + return name + + def populate_slot_locals(self) -> None: + """Compute disambiguated Guppy local names for every allocator slot. + + Default name is `f"{allocator}_{index}"`; if that collides with + any declared allocator name, register name, or previously + assigned slot local, suffix `_` until unique. The result is the + authority used by both linearity-state binding init and the + emitter's `_local_name` so the entry-unpack LHS, the linearity + bindings, and per-slot references all agree (this disambiguation + prevents `q_0, q_1 = q` from shadowing a separately declared + `QReg("q_0", ...)` parameter). Idempotent: caller may invoke + once after `_collect_declarations` populates allocators+regs. + """ + taken: set[str] = set(self.root_allocators) | set(self.registers) + # Existing slot_locals (a re-population path; should not normally + # happen) are preserved -- once a slot's name is committed, + # everything downstream depends on it. + for name in self.slot_locals.values(): + taken.add(name) + for allocator, size in self.root_allocators.items(): + for index in range(size): + slot = Slot(allocator, index) + if slot in self.slot_locals: + continue + candidate = f"{allocator}_{index}" + while candidate in taken: + candidate += "_" + self.slot_locals[slot] = candidate + taken.add(candidate) - Generates clean Guppy code that can be compiled to HUGR. - Usage: - generator = AstToGuppy() - lines = generator.generate(ast_program) - code = "\\n".join(lines) - """ +class AstToGuppy: + """Recursive-descent Guppy code generator for AST programs.""" def __init__(self) -> None: """Initialize the generator.""" - self.context = CodeGenContext() + self.context = GuppyContext() + self._block_decls: dict[str, BlockDecl] = {} def generate(self, program: Program) -> list[str]: - """Generate Guppy code for a program. - - Args: - program: The AST Program to generate code for. - - Returns: - List of code lines. - """ - self.context = CodeGenContext() - return self.visit(program) - - def default_result(self) -> list[str]: - """Return empty list as default.""" - return [] - - def combine_results(self, results: list[list[str]]) -> list[str]: - """Combine multiple results into a single list.""" - combined = [] - for r in results: - combined.extend(r) - return combined - - # === Program === - - def visit_program(self, node: Program) -> list[str]: - """Generate code for a complete program.""" - lines = [] - - # Standard imports - lines.append("from guppylang import guppy") - lines.append("from guppylang.std import quantum") - lines.append("from guppylang.std.quantum import qubit") - lines.append("") + """Generate Guppy code for a program.""" + validate_unique_block_decl_names(program) + self.context = GuppyContext() + self._block_decls = {decl.name: decl for decl in program.block_decls} + for decl in program.block_decls: + self._validate_block_decl(decl) + self._validate_scratch_outer_slots(program) - # Process declarations to build context - for decl in node.declarations: - if isinstance(decl, AllocatorDecl): - self.context.allocators[decl.name] = decl.capacity - self.context.allocator_parents[decl.name] = decl.parent - elif isinstance(decl, RegisterDecl): - self.context.registers[decl.name] = decl.size + lines = self._imports() - if node.allocator: - self.context.allocators[node.allocator.name] = node.allocator.capacity - self.context.allocator_parents[node.allocator.name] = node.allocator.parent + for decl in program.block_decls: + lines.append("") + lines.extend(self._generate_block_decl(decl)) - # Calculate offsets for child allocators (sequential allocation within parent) - self._calculate_allocator_offsets(node) - - # First pass: scan body to find measurements (to determine return type) - self._scan_for_measurements(node.body) - - # Generate function signature - func_name = node.name.lower() - params = self._generate_params(node) - return_type = self._generate_return_type(node) + lines.append("") + lines.extend(self._generate_main(program)) + return lines - lines.append("@guppy") - lines.append(f"def {func_name}({params}) -> {return_type}:") + def _generate_main(self, program: Program) -> list[str]: + """Emit the main Guppy function for the program body.""" + self.context = GuppyContext() + self._collect_declarations(program) + self._collect_implicit_measure_registers(program.body) + self.context.populate_slot_locals() + self.context.linearity = GuppyLinearityState.from_allocators( + self.context.root_allocators, + slot_locals=self.context.slot_locals, + ) + self._reject_child_allocators() + + # Print AST-level validators (run before emission so failures + # point at the source program, not the generated Guppy). + self._validate_print_paths(program.body) + self._validate_print_inline_creg_assignment(program) + + body = list(program.body) + explicit_return = self._validate_return_position(body) + emitted_body = body[:-1] if explicit_return else body + + lines: list[str] = ["@guppy"] + lines.append(f"def {program.name.lower()}({self._render_params()}) -> {self._return_type(explicit_return)}:") - # Generate body self.context.push_indent() + body_lines: list[str] = [] + body_lines.extend(self._emit_entry_unpacks()) + body_lines.extend(self._emit_register_initializers()) - body_lines = [] - for stmt in node.body: - body_lines.extend(self.visit(stmt)) + for stmt in emitted_body: + body_lines.extend(self._emit_stmt(stmt)) - # Add return statement - return_lines = self._generate_return_statement(node) - if return_lines: - body_lines.extend(return_lines) + if explicit_return is not None: + body_lines.extend(self._emit_explicit_return(explicit_return)) + else: + body_lines.extend(self._emit_end_cleanup()) if body_lines: lines.extend(body_lines) @@ -291,539 +381,1723 @@ def visit_program(self, node: Program) -> list[str]: lines.append(f"{self.context.indent()}pass") self.context.pop_indent() - return lines - def _scan_for_measurements(self, stmts: tuple) -> None: - """Scan statements to find all measurements and mark consumed qubits. + def _validate_block_decl(self, decl: BlockDecl) -> None: + """Reject BlockDecl shapes that the v1 Guppy emitter cannot lower.""" + for inp in decl.inputs: + # Supported input shapes: array[qubit, N]; single qubit (bare + # `QubitTypeExpr`); single classical bit (bare + # `BitTypeExpr`, lowered via an array[bool, 1] write-back proxy). + # Qubit/bit bundles are added later. + is_qubit = isinstance(inp.type_expr, QubitTypeExpr) + is_qubit_array = isinstance(inp.type_expr, ArrayTypeExpr) and isinstance( + inp.type_expr.element, + QubitTypeExpr, + ) + is_bit = isinstance(inp.type_expr, BitTypeExpr) + if not (is_qubit or is_qubit_array or is_bit): + msg = ( + f"BlockDecl {decl.name!r} input {inp.name!r}: only array[qubit, N], " + f"bare qubit, and bare bit inputs are supported " + f"(got {type(inp.type_expr).__name__})" + ) + raise GuppyCodegenError(msg) + if is_bit: + # A classical bit is copyable; CONSUMED/PRODUCED/DROPPED don't + # apply. A write-back bit always survives the call. + if inp.effect is not ResourceEffect.LIVE_PRESERVED: + msg = ( + f"BlockDecl {decl.name!r} input {inp.name!r}: bare bit inputs " + f"must be LIVE_PRESERVED (got {inp.effect.name}); a classical " + "bit is copyable so consumed/produced/dropped do not apply" + ) + raise GuppyCodegenError(msg) + elif inp.effect is ResourceEffect.SCRATCH: + # Scratch ancilla: the + # block resets+measures it internally. Guppy allocates it + # internally (no parameter), so it must be a bare qubit -- + # array/bundle scratch is out of scope (Check's ancilla is a + # bare qubit; Check1Flag's flag is too but is deferred). + if not is_qubit: + msg = ( + f"BlockDecl {decl.name!r} input {inp.name!r}: SCRATCH is " + f"only supported for a bare qubit ancilla (got " + f"{type(inp.type_expr).__name__})" + ) + raise GuppyCodegenError(msg) + elif inp.effect not in {ResourceEffect.LIVE_PRESERVED, ResourceEffect.CONSUMED}: + msg = ( + f"BlockDecl {decl.name!r} input {inp.name!r}: only LIVE_PRESERVED and " + f"CONSUMED effects are supported (got {inp.effect.name})" + ) + raise GuppyCodegenError(msg) + if decl.return_op is not None: + msg = ( + f"BlockDecl {decl.name!r}: explicit Return inside BlockDecl is not yet " + "supported; live_preserved inputs are returned implicitly" + ) + raise GuppyCodegenError(msg) - Also pre-registers measurement variable names for return type generation. - """ - for stmt in stmts: - if isinstance(stmt, MeasureOp): - for i, target in enumerate(stmt.targets): - self.context.mark_measured(target.allocator, target.index) - # Pre-register measurement variable name - if i < len(stmt.results): - result = stmt.results[i] - var_name = f"{result.register}_{result.index}" - else: - var_name = f"_m{len(self.context.measurement_vars)}" - self.context.measurement_vars.append(var_name) - elif isinstance(stmt, IfStmt): - self._scan_for_measurements(stmt.then_body) - if stmt.else_body: - self._scan_for_measurements(stmt.else_body) - elif isinstance(stmt, (ForStmt, WhileStmt, RepeatStmt, ParallelBlock)): - self._scan_for_measurements(stmt.body) + def _generate_block_decl(self, decl: BlockDecl) -> list[str]: + """Emit a BlockDecl as a top-level Guppy @guppy def function. - def _calculate_allocator_offsets(self, node: Program) -> None: - """Calculate the offset of each child allocator within its parent. + Each input is one of: + - `array[qubit, N]`: parameter `name: array[qubit, N] @ owned`, + unpacks at entry into `name_0..name_{N-1}` slots. + - bare `qubit`: parameter `name: qubit @ owned`, aliased to its + 1-slot linearity binding at entry. + - bare `bit`: parameter `name: array[bool, 1] @ owned` (write-back + proxy); body BitRefs to `name` render as `name[0]`; the array is + returned so the caller sees the mutation. - Children are allocated sequentially within their parent's capacity. - This allows translating child[i] to parent[offset + i]. + Future iters add qubit/bit bundles and PRODUCED/DROPPED effects. """ - # Track allocated space per parent - parent_next_offset: dict[str, int] = {} - - # Root allocators have offset 0 - for decl in node.declarations: - if isinstance(decl, AllocatorDecl) and decl.parent is None: - self.context.allocator_offsets[decl.name] = 0 - - if node.allocator and node.allocator.parent is None: - self.context.allocator_offsets[node.allocator.name] = 0 - - # Process child allocators in declaration order - for decl in node.declarations: - if isinstance(decl, AllocatorDecl) and decl.parent is not None: - parent = decl.parent - if parent not in parent_next_offset: - parent_next_offset[parent] = 0 - - # Get parent's offset (for nested hierarchies) - parent_offset = self.context.allocator_offsets.get(parent, 0) - - # This child's offset is parent's offset + next available slot - self.context.allocator_offsets[decl.name] = parent_offset + parent_next_offset[parent] + saved_context = self.context + self.context = GuppyContext() + # Categorize each input. `_validate_block_decl` guarantees the type_expr + # is one of: bare qubit, array[qubit, N], or bare bit. `size` is the + # qubit-array length (1 for single qubit, unused for bit). + input_shapes: list[tuple[str, str, int]] = [] # (name, kind, size) + for inp in decl.inputs: + if inp.effect is ResourceEffect.SCRATCH: + # Bare-qubit ancilla allocated internally (no parameter). + input_shapes.append((inp.name, "scratch_qubit", 1)) + elif isinstance(inp.type_expr, BitTypeExpr): + input_shapes.append((inp.name, "single_bit", 1)) + elif isinstance(inp.type_expr, QubitTypeExpr): + input_shapes.append((inp.name, "single_qubit", 1)) + else: + arr = cast("ArrayTypeExpr", inp.type_expr) + input_shapes.append((inp.name, "qubit_array", arr.size)) + for name, kind, size in input_shapes: + if kind == "single_bit": + # Register the bit-input name so body BitRefs (`name[0]`) render; + # do NOT emit an initializer -- it's a bound parameter. + self.context.registers[name] = RegisterDecl(name=name, size=1) + else: + # scratch_qubit is registered too: the body's Prep/gates/Measure + # still resolve `Slot(name, i)` through linearity. It just has + # no parameter and is seeded CONSUMED below so the first Prep + # allocates a fresh internal `qubit()`. + self.context.root_allocators[name] = size + self.context.populate_slot_locals() + self.context.linearity = GuppyLinearityState.from_allocators( + self.context.root_allocators, + slot_locals=self.context.slot_locals, + ) + # Seed scratch slots CONSUMED so the body's first `Prep(scratch)` takes + # the fresh-`qubit()` branch in `_emit_prepare` (no entry binding, no + # parameter to alias). `from_allocators` starts every slot LIVE. + for name, kind, size in input_shapes: + if kind == "scratch_qubit": + for index in range(size): + self.context.linearity.consume(Slot(name, index)) + + live_inputs = tuple( + (inp, shape) + for inp, shape in zip(decl.inputs, input_shapes, strict=True) + if inp.effect is ResourceEffect.LIVE_PRESERVED + ) + + return_types: list[str] = [] + for _inp, (_name, kind, size) in live_inputs: + if kind == "qubit_array": + return_types.append(f"array[qubit, {size}]") + elif kind == "single_qubit": + return_types.append("qubit") + else: # single_bit + return_types.append("array[bool, 1]") + return_type_str = self._tuple_type(return_types) + + param_parts: list[str] = [] + for name, kind, size in input_shapes: + if kind == "scratch_qubit": + continue # allocated internally -- no parameter + if kind == "qubit_array": + param_parts.append(f"{name}: array[qubit, {size}] @ owned") + elif kind == "single_qubit": + param_parts.append(f"{name}: qubit @ owned") + else: # single_bit write-back proxy + param_parts.append(f"{name}: array[bool, 1] @ owned") + params = ", ".join(param_parts) + + lines: list[str] = ["@guppy", f"def {decl.name}({params}) -> {return_type_str}:"] - # Reserve space in parent - parent_next_offset[parent] += decl.capacity + self.context.push_indent() + body_lines: list[str] = [] + linearity = self._linearity() + # Per-input entry bindings: array inputs unpack into per-slot locals; + # single-qubit inputs alias the param to its single slot's linearity + # local; single-bit inputs need no entry binding (the param IS the + # array[bool, 1] the body writes to via `name[0]`). + for name, kind, size in input_shapes: + if kind == "qubit_array": + locals_for = [binding.local for slot, binding in linearity.bindings() if slot.allocator == name] + lhs = ", ".join(locals_for) + if size == 1: + lhs += "," + body_lines.append(f"{self.context.indent()}{lhs} = {name}") + elif kind == "single_qubit": + slot_local = linearity.live(Slot(name, 0)) + body_lines.append(f"{self.context.indent()}{slot_local} = {name}") + # single_bit: no entry binding needed. + + for stmt in decl.body: + body_lines.extend(self._emit_stmt(stmt)) + + # Auto-emit return for live_preserved inputs. Qubit-array inputs repack + # per-slot locals into `array(...)`; single-qubit inputs return their + # slot's local; single-bit inputs return the (mutated) array[bool, 1] + # parameter directly. + if live_inputs: + return_exprs: list[str] = [] + for _inp, (name, kind, _size) in live_inputs: + if kind == "qubit_array": + return_exprs.append(self._consume_allocator_for_return(name)) + elif kind == "single_qubit": + return_exprs.append(linearity.consume(Slot(name, 0))) + else: # single_bit: the array param itself + return_exprs.append(name) + if len(return_exprs) == 1: + body_lines.append(f"{self.context.indent()}return {return_exprs[0]}") + else: + body_lines.append(f"{self.context.indent()}return {', '.join(return_exprs)}") + else: + body_lines.extend(self._emit_end_cleanup()) - def _generate_params(self, node: Program) -> str: - """Generate function parameters from declarations. + if body_lines: + lines.extend(body_lines) + else: + lines.append(f"{self.context.indent()}pass") - Only includes root allocators (those without parents) as function parameters. - Child allocators are derived from parent allocators within the function. - """ - params = [] + self.context.pop_indent() + self.context = saved_context + return lines - # Add allocator parameters (only root allocators without parents) - for decl in node.declarations: + def _collect_declarations(self, program: Program) -> None: + for decl in program.declarations: if isinstance(decl, AllocatorDecl): - # Skip child allocators - they're derived from parents - if decl.parent is not None: - continue - params.append(f"{decl.name}: array[qubit, {decl.capacity}] @owned") - - if node.allocator and node.allocator.parent is None: - params.append( - f"{node.allocator.name}: array[qubit, {node.allocator.capacity}] @owned", + self._add_allocator_decl(decl) + elif isinstance(decl, RegisterDecl): + self.context.registers[decl.name] = decl + + if program.allocator is not None: + self._add_allocator_decl(program.allocator) + + def _add_allocator_decl(self, decl: AllocatorDecl) -> None: + if decl.parent is not None: + self.context.child_allocators.add(decl.name) + return + self.context.root_allocators.setdefault(decl.name, decl.capacity) + + def _collect_implicit_measure_registers(self, body: tuple[Statement, ...]) -> None: + """Add result registers introduced only as measurement outputs.""" + max_indices: dict[str, int] = {} + for stmt in body: + self._collect_implicit_measure_register_refs(stmt, max_indices) + + for register, max_index in max_indices.items(): + if register not in self.context.registers: + self.context.registers[register] = RegisterDecl(name=register, size=max_index + 1) + + def _collect_implicit_measure_register_refs(self, stmt: Statement, max_indices: dict[str, int]) -> None: + if isinstance(stmt, MeasureOp): + for ref in stmt.results: + if ref.register not in self.context.registers: + max_indices[ref.register] = max(max_indices.get(ref.register, -1), ref.index) + return + + if isinstance(stmt, IfStmt): + self._collect_implicit_measure_registers_in_body(stmt.then_body, max_indices) + self._collect_implicit_measure_registers_in_body(stmt.else_body, max_indices) + return + + if isinstance(stmt, RepeatStmt | ForStmt | WhileStmt | ParallelBlock): + self._collect_implicit_measure_registers_in_body(stmt.body, max_indices) + + def _collect_implicit_measure_registers_in_body( + self, + body: tuple[Statement, ...], + max_indices: dict[str, int], + ) -> None: + for stmt in body: + self._collect_implicit_measure_register_refs(stmt, max_indices) + + def _reject_child_allocators(self) -> None: + if self.context.child_allocators: + names = ", ".join(sorted(self.context.child_allocators)) + msg = f"AST -> Guppy v1 does not support child allocators: {names}" + raise GuppyCodegenError(msg) + + def _iter_stmts(self, body: tuple[Statement, ...]) -> Iterator[Statement]: + """Yield every statement in `body`, recursing into control flow.""" + for stmt in body: + yield stmt + if isinstance(stmt, IfStmt): + yield from self._iter_stmts(stmt.then_body) + yield from self._iter_stmts(stmt.else_body) + elif isinstance(stmt, (WhileStmt, ForStmt, RepeatStmt, ParallelBlock)): + yield from self._iter_stmts(stmt.body) + + def _validate_scratch_outer_slots(self, program: Program) -> None: + """Reject programs where a scratch-bound outer slot is also used as + meaningful caller state. + + A SCRATCH input lowers asymmetrically: flatten substitutes the + scratch param to the outer slot (the block resets+measures THAT + slot), while Guppy allocates the ancilla internally and leaves the + outer slot untouched. The two paths only agree when the outer slot + is pure scratch -- never observed by the caller. If the caller + gates/measures/barriers/permutes/returns it, or hands it to another + block, the codegens diverge. Multiple scratch BlockCalls reusing + the slot stay allowed (the intended SynExtractBare pattern). + + `PrepareOp` on a scratch outer slot IS allowed: a reset is + unobserved and dead under both lowerings (and the qeclib corpus + wholesale-preps the ancilla register, e.g. `Prep(q)` covering + `q[3]`, before using `q[3]` as a check ancilla). + + Runs per scope: `program.body` AND every + `BlockDecl.body` -- a nested BlockCall whose scratch arg references + the enclosing block's param slot has the same purity requirement + within that block's scope. + """ + self._validate_scratch_purity_in_scope(program.body) + for decl in program.block_decls: + self._validate_scratch_purity_in_scope(decl.body) + + @staticmethod + def _ref_base(ref: str) -> str: + """Leading identifier of a string ref (`q[0]`/`q.x`/`q` -> `q`).""" + return ref.split("[", 1)[0].split(".", 1)[0] + + def _validate_scratch_purity_in_scope(self, body: tuple[Statement, ...]) -> None: + scratch: dict[tuple[str, int], str] = {} + for stmt in self._iter_stmts(body): + if not isinstance(stmt, BlockCall): + continue + decl = self._block_decls.get(stmt.callee) + if decl is None: + continue + for inp, arg in zip(decl.inputs, stmt.arg_bindings, strict=False): + if inp.effect is ResourceEffect.SCRATCH and isinstance(arg, SingleQubitArg): + scratch[(arg.slot.allocator, arg.slot.index)] = stmt.callee + if not scratch: + return + scratch_allocs = {alloc for alloc, _ in scratch} + + def _reject(where: str, slot: tuple[str, int]) -> None: + callee = scratch[slot] + msg = ( + f"Scratch outer slot {slot[0]}[{slot[1]}] (bound as the " + f"scratch ancilla of BlockCall {callee!r}) is also used as " + f"meaningful caller state by {where}. A scratch-bound slot " + "must be pure scratch -- flatten mutates it while Guppy " + "allocates the ancilla internally, so any other use " + "diverges. (A bare Prep on it is allowed; reusing it across " + "scratch BlockCalls is allowed.)" ) - - return ", ".join(params) - - def _generate_return_type(self, node: Program) -> str: - """Generate return type annotation based on consumed/unconsumed qubits.""" - return_types = [] - - # Only include qubit arrays that are NOT fully consumed by measurement - for decl in node.declarations: - if isinstance(decl, AllocatorDecl): - # Skip child allocators - only include root allocators in params/returns - if decl.parent is not None: + raise GuppyCodegenError(msg) + + def _reject_alloc(where: str, alloc: str) -> None: + _reject(where, next(s for s in scratch if s[0] == alloc)) + + for stmt in self._iter_stmts(body): + if isinstance(stmt, GateOp): + for t in stmt.targets: + if (t.allocator, t.index) in scratch: + _reject(f"a {stmt.gate.name} gate", (t.allocator, t.index)) + elif isinstance(stmt, MeasureOp): + for t in stmt.targets: + if (t.allocator, t.index) in scratch: + _reject("a Measure", (t.allocator, t.index)) + elif isinstance(stmt, BarrierOp): + for alloc in stmt.allocators: + if alloc in scratch_allocs: + # Name-level: conservatively reject any barrier + # naming a register that hosts a scratch slot. + _reject_alloc("a Barrier", alloc) + elif isinstance(stmt, PermuteOp): + # sources/targets are string refs (`q`, `q[0]`); a permute + # touching the scratch register reorders/observes it. + for ref in (*stmt.sources, *stmt.targets): + if self._ref_base(ref) in scratch_allocs: + _reject_alloc("a Permute", self._ref_base(ref)) + elif isinstance(stmt, ReturnOp): + # Returning the scratch slot (or its register) exposes the + # outer slot the caller would observe -- flatten mutated it, + # Guppy did not. + for v in stmt.values: + name = v if isinstance(v, str) else getattr(v, "name", "") + if name and self._ref_base(name) in scratch_allocs: + _reject_alloc("a Return", self._ref_base(name)) + elif isinstance(stmt, BlockCall): + decl = self._block_decls.get(stmt.callee) + if decl is None: continue - if not self.context.is_allocator_fully_consumed(decl.name): - return_types.append(f"array[qubit, {decl.capacity}]") - - if node.allocator and not self.context.is_allocator_fully_consumed( - node.allocator.name, - ): - return_types.append(f"array[qubit, {node.allocator.capacity}]") - - # Add measurement results (bools) - if self.context.measurement_vars: - return_types.extend("bool" for _ in self.context.measurement_vars) - - if not return_types: + for inp, arg in zip(decl.inputs, stmt.arg_bindings, strict=False): + if inp.effect is ResourceEffect.SCRATCH: + continue # the scratch binding itself -- allowed + if isinstance(arg, AllocatorArg) and arg.name in scratch_allocs: + _reject_alloc(f"a non-scratch input of BlockCall {stmt.callee!r}", arg.name) + elif ( + isinstance(arg, SingleQubitArg) + and ( + arg.slot.allocator, + arg.slot.index, + ) + in scratch + ): + _reject( + f"a non-scratch input of BlockCall {stmt.callee!r}", + (arg.slot.allocator, arg.slot.index), + ) + elif isinstance(arg, QubitBundleArg): + for s in arg.slots: + if (s.allocator, s.index) in scratch: + _reject( + f"a non-scratch input of BlockCall {stmt.callee!r}", + (s.allocator, s.index), + ) + + def _imports(self) -> list[str]: + imports = [ + "from guppylang import guppy", + "from guppylang.std.builtins import array, owned, result", + "from guppylang.std.mem import mem_swap", + "from guppylang.std.quantum import discard, measure, qubit", + f"from guppylang.std.quantum.functional import {FUNCTIONAL_GATE_IMPORTS}", + ] + # `angle` is needed for parameterized rotations; `zz_phase` is + # the native Quantinuum Q-System 2q ZZ rotation (= RZZ) used by + # the SZZ/SXX/SYY-family + RZZ decompositions. Both imported + # unconditionally (cheap; unused imports are harmless -- Guppy + # only compiles the ops a program actually calls). + imports.append("from guppylang.std.angles import angle") + imports.append("from guppylang.std.qsystem.functional import zz_phase") + return imports + + def _render_params(self) -> str: + return ", ".join(f"{name}: array[qubit, {size}] @ owned" for name, size in self.context.root_allocators.items()) + + def _return_type(self, explicit_return: ReturnOp | None) -> str: + if explicit_return is None: return "None" - if len(return_types) == 1: - return return_types[0] - return f"tuple[{', '.join(return_types)}]" - - def _generate_return_statement(self, node: Program) -> list[str]: - """Generate return statement with unconsumed qubits and measurement results.""" - return_values = [] - - # Return unconsumed qubit arrays - for decl in node.declarations: - if isinstance(decl, AllocatorDecl): - # Skip child allocators - if decl.parent is not None: - continue - if not self.context.is_allocator_fully_consumed(decl.name): - return_values.append(decl.name) - - if node.allocator and not self.context.is_allocator_fully_consumed( - node.allocator.name, - ): - return_values.append(node.allocator.name) - - # Return measurement results - return_values.extend(self.context.measurement_vars) - - if not return_values: - return [] - - return [f"{self.context.indent()}return {', '.join(return_values)}"] - - # === Declarations === - - def visit_allocator_decl(self, _node: AllocatorDecl) -> list[str]: - """Allocator declarations are handled at program level.""" - return [] - - def visit_register_decl(self, _node: RegisterDecl) -> list[str]: - """Register declarations are handled at program level.""" - return [] + types = [self._return_value_type(value) for value in explicit_return.values] + return self._tuple_type(types) + + def _return_value_type(self, value: Expression | str) -> str: + if isinstance(value, str): + if value in self.context.root_allocators: + return f"array[qubit, {self.context.root_allocators[value]}]" + if value in self.context.registers: + return f"array[bool, {self.context.registers[value].size}]" + msg = f"Unsupported Guppy return value {value!r}" + raise GuppyCodegenError(msg) + + if isinstance(value, BitExpr): + return "bool" + if isinstance(value, LiteralExpr) and isinstance(value.value, bool): + return "bool" + if isinstance(value, LiteralExpr) and isinstance(value.value, int): + return "int" + msg = f"Unsupported Guppy return expression {value!r}" + raise GuppyCodegenError(msg) + + def _tuple_type(self, types: list[str]) -> str: + if not types: + return "None" + if len(types) == 1: + return types[0] + return f"tuple[{', '.join(types)}]" + + def _emit_entry_unpacks(self) -> list[str]: + lines: list[str] = [] + linearity = self._linearity() + for allocator, size in self.context.root_allocators.items(): + if size == 0: + continue + locals_for_allocator = [ + binding.local for slot, binding in linearity.bindings() if slot.allocator == allocator + ] + lhs = ", ".join(locals_for_allocator) + if size == 1: + lhs += "," + lines.append(f"{self.context.indent()}{lhs} = {allocator}") + return lines - # === Gates === + def _emit_register_initializers(self) -> list[str]: + lines: list[str] = [] + for decl in self.context.registers.values(): + values = ", ".join("False" for _ in range(decl.size)) + lines.append(f"{self.context.indent()}{decl.name} = array({values})") + return lines - def visit_gate(self, node: GateOp) -> list[str]: - """Generate gate operation.""" - gate_func = GATE_TO_GUPPY.get(node.gate, f"quantum.{node.gate.name.lower()}") + def _validate_return_position(self, body: list[Statement]) -> ReturnOp | None: + return_count = self._count_returns(body) + if return_count == 0: + return None + if return_count == 1 and body and isinstance(body[-1], ReturnOp): + return body[-1] + msg = "AST -> Guppy v1 supports only one final root-level Return" + raise GuppyCodegenError(msg) + + def _count_returns(self, body: list[Statement] | tuple[Statement, ...]) -> int: + count = 0 + for stmt in body: + if isinstance(stmt, ReturnOp): + count += 1 + elif isinstance(stmt, IfStmt): + count += self._count_returns(stmt.then_body) + count += self._count_returns(stmt.else_body) + elif isinstance(stmt, WhileStmt | ForStmt | RepeatStmt | ParallelBlock): + count += self._count_returns(stmt.body) + return count + + def _validate_print_paths(self, body: tuple[Statement, ...]) -> None: + """Validate Print path-signature consistency across If/Elif branches. + + Walks the body once, descending into nested control flow. For each + If, both `then_body` and `else_body` (recursively) must emit the + same ordered sequence of Print events. `Repeat(n)` and static-bound + `For(name, start, stop[, step])` multiply inner signatures by the + static trip count. Non-static `For` and `While` reject Prints + since the trip count is not statically known. + + Side effect: raises `GuppyCodegenError` if any validation fails. + """ + from pecos.slr.ast.nodes import PrintOp # noqa: PLC0415 + + self._collect_print_path_signature(body, PrintOp) + + def _collect_print_path_signature( + self, + body: tuple[Statement, ...], + print_op_cls: type, + ) -> tuple[tuple[str, str, str, int], ...]: + """Return the ordered Print signature for `body`, validating as we go. + + Each Print emission contributes one signature tuple: + `(namespace, tag, value_kind, value_shape)` where `value_kind` is + `"creg"` (whole register) or `"bit"` (single bit) and + `value_shape` is the register size (or 1 for bit refs). + + Side effects: + - Raises `GuppyCodegenError` if an `If` body has asymmetric Print + signatures across `then_body` / `else_body`. + - Raises `GuppyCodegenError` if a non-static `For` or any `While` + contains a Print in its body. + """ + signature: list[tuple[str, str, str, int]] = [] + for stmt in body: + if isinstance(stmt, print_op_cls): + signature.append(self._print_op_event(stmt)) + elif isinstance(stmt, IfStmt): + then_sig = self._collect_print_path_signature(stmt.then_body, print_op_cls) + else_sig = self._collect_print_path_signature(stmt.else_body, print_op_cls) + if then_sig != else_sig: + msg = ( + "Print path-signature mismatch across If branches:\n" + f" Then: {then_sig}\n" + f" Else: {else_sig}\n" + "Symmetric Print emission is required across all branches of an " + "If/Elif chain. Either add the missing Print(s) to the lighter branch, " + "or move the Print outside the If." + ) + raise GuppyCodegenError(msg) + signature.extend(then_sig) + elif isinstance(stmt, RepeatStmt): + inner = self._collect_print_path_signature(stmt.body, print_op_cls) + signature.extend(inner * stmt.count) + elif isinstance(stmt, ForStmt): + inner = self._collect_print_path_signature(stmt.body, print_op_cls) + if inner: + trip = self._static_for_trip_count(stmt) + if trip is None: + msg = ( + "Print inside non-static `For` is not supported. " + "Use `Repeat(n)` or `For(name, start, stop)` with literal int " + "start/stop/step, or move the Print outside the For body." + ) + raise GuppyCodegenError(msg) + signature.extend(inner * trip) + elif isinstance(stmt, WhileStmt): + inner = self._collect_print_path_signature(stmt.body, print_op_cls) + if inner: + msg = ( + "Print inside `While` is not supported (no static trip " + "count). Move the Print outside the While body." + ) + raise GuppyCodegenError(msg) + elif isinstance(stmt, ParallelBlock): + signature.extend(self._collect_print_path_signature(stmt.body, print_op_cls)) + return tuple(signature) + + def _print_op_event(self, op) -> tuple[str, str, str, int]: + if isinstance(op.value, BitRef): + return (op.namespace, op.tag, "bit", 1) + # str = whole CReg name + decl = self.context.registers.get(op.value) + shape = decl.size if decl is not None else 0 + return (op.namespace, op.tag, "creg", shape) + + def _static_for_trip_count(self, stmt: ForStmt) -> int | None: + """Compute static trip count for a `For(name, start, stop[, step])`. + + Returns the integer trip count when start/stop/step are all integer + literals; returns None otherwise (Print is then rejected in the + loop body via `_collect_print_path_signature`). + """ + start = self._static_int(stmt.start) + stop = self._static_int(stmt.stop) + if start is None or stop is None: + return None + step = 1 + if stmt.step is not None: + step_val = self._static_int(stmt.step) + if step_val is None: + return None + step = step_val + if step == 0: + return None + return len(range(start, stop, step)) + + def _static_int(self, expr) -> int | None: + if isinstance(expr, LiteralExpr) and isinstance(expr.value, int) and not isinstance(expr.value, bool): + return expr.value + return None + + def _validate_print_inline_creg_assignment(self, program: Program) -> None: + """Reject Print of an inline CReg bit before Measure has written to it. + + Inline CRegs are those introduced only by `Measure(q) > CReg(...)` -- + they appear in `context.registers` via `_collect_implicit_measure_registers` + but not in `program.declarations`. Without explicit user declaration in + `Main(...)`, the runtime sees an auto-initialized all-False register if a + Print runs before any Measure has populated it. That silent zero-emission + is the bug; this validator rejects it. + + Declared CRegs (those in `program.declarations`) are NOT validated -- + users who explicitly declare a CReg are acknowledging the zero-init. + + Granularity: **bit-level**. The validator tracks `(register, bit_index)` + pairs. `Print(c[i])` requires the specific `(c, i)` to be assigned; + whole-CReg `Print(c)` requires every bit `0..size-1` (inferred size) to + be assigned. Bit-level tracking also acts as a bounds check: a `Print` + referencing an index past the inferred size is rejected because that + `(reg, index)` cannot have been added by any Measure. + """ + declared = {d.name for d in program.declarations if isinstance(d, RegisterDecl)} + inline_cregs = {name for name in self.context.registers if name not in declared} + if not inline_cregs: + return + assigned: set[tuple[str, int]] = set() + self._check_print_inline_assignment(program.body, assigned, inline_cregs) + + def _check_print_inline_assignment( + self, + body: tuple[Statement, ...], + assigned: set[tuple[str, int]], + inline_cregs: set[str], + ) -> None: + """Walk body left-to-right; mutate `assigned` in-place across the path. + + At control-flow joins, merge per-path assignment sets: + - `If(...)`: definite-after = intersection of `then`/`else` post-states. + - `Repeat(n>=1)` / static `For(count>=1)`: body runs at least once, so + inner assignments propagate. + - `Repeat(0)` / static `For(count<=0)` / non-static `For` / `While`: + body may not run. Walk for validation (catch unreachable invalid + Prints), but do NOT propagate inner assignments to the outer scope. + - `Parallel`: treated as sequential (matches the emitter's flatten + behavior). + """ + from pecos.slr.ast.nodes import PrintOp # noqa: PLC0415 - # Generate target references - targets = [self._render_slot_ref(t) for t in node.targets] + for stmt in body: + if isinstance(stmt, MeasureOp): + for ref in stmt.results: + if ref.register in inline_cregs: + assigned.add((ref.register, ref.index)) + elif isinstance(stmt, PrintOp): + self._check_print_inline_read(stmt, assigned, inline_cregs) + elif isinstance(stmt, IfStmt): + then_assigned = set(assigned) + self._check_print_inline_assignment(stmt.then_body, then_assigned, inline_cregs) + else_assigned = set(assigned) + self._check_print_inline_assignment(stmt.else_body, else_assigned, inline_cregs) + assigned.update(then_assigned & else_assigned) + elif isinstance(stmt, RepeatStmt): + if stmt.count >= 1: + inner_assigned = set(assigned) + self._check_print_inline_assignment(stmt.body, inner_assigned, inline_cregs) + assigned.update(inner_assigned) + else: + # count == 0: body doesn't run. Walk for validation; do not propagate. + self._check_print_inline_assignment(stmt.body, set(assigned), inline_cregs) + elif isinstance(stmt, ForStmt): + trip = self._static_for_trip_count(stmt) + if trip is not None and trip >= 1: + inner_assigned = set(assigned) + self._check_print_inline_assignment(stmt.body, inner_assigned, inline_cregs) + assigned.update(inner_assigned) + else: + self._check_print_inline_assignment(stmt.body, set(assigned), inline_cregs) + elif isinstance(stmt, WhileStmt): + self._check_print_inline_assignment(stmt.body, set(assigned), inline_cregs) + elif isinstance(stmt, ParallelBlock): + self._check_print_inline_assignment(stmt.body, assigned, inline_cregs) + + def _check_print_inline_read(self, op, assigned: set[tuple[str, int]], inline_cregs: set[str]) -> None: + if isinstance(op.value, BitRef): + reg = op.value.register + if reg not in inline_cregs: + return + if (reg, op.value.index) not in assigned: + msg = ( + f"Print references inline CReg bit {reg}[{op.value.index}] before any " + f"Measure has written to it. Print would emit the auto-initialized False " + f"value (or read past the inferred register bound), not a measurement " + f"result. Move the Print after a Measure(...) > {reg}[{op.value.index}] " + f"that runs on every path, or declare {reg!r} explicitly as a positional " + f"in Main(...) if you intend to print the zero-initialized state." + ) + raise GuppyCodegenError(msg) + elif isinstance(op.value, str): + reg = op.value + if reg not in inline_cregs: + return + # Reject whole-CReg Print of an inline CReg outright. + # The user-stated `CReg(name, size)` size is lost during inline-from- + # Measure inference (only Measure-targeted bit indices contribute to + # the inferred RegisterDecl.size). Emitting `result(tag, c)` for the + # inferred c can silently shrink the register relative to the user's + # intent. Require either an explicit Main(...) declaration (then the + # CReg is no longer inline and whole-CReg Print is allowed) or per-bit + # `Print(c[i], ...)` calls. + msg = ( + f"Print(whole-CReg) of inline CReg {reg!r} is rejected. " + "Whole-register Print can silently shrink an inline CReg because the " + "original `CReg(name, size)` size is lost during inline-from-Measure " + f"inference. Declare {reg!r} as a positional in `Main(...)` (then whole-" + f"CReg Print is allowed) or print individual bits via `Print({reg}[i], ...)`." + ) + raise GuppyCodegenError(msg) + + def _emit_stmt(self, stmt: Statement) -> list[str]: + if isinstance(stmt, GateOp): + return self._emit_gate(stmt) + if isinstance(stmt, PrepareOp): + return self._emit_prepare(stmt) + if isinstance(stmt, MeasureOp): + return self._emit_measure(stmt) + if isinstance(stmt, IfStmt): + return self._emit_if(stmt) + if isinstance(stmt, RepeatStmt): + return self._emit_repeat(stmt) + if isinstance(stmt, ForStmt): + return self._emit_for(stmt) + if isinstance(stmt, WhileStmt): + msg = "AST -> Guppy v1 does not support While loops" + raise GuppyCodegenError(msg) + if isinstance(stmt, ParallelBlock): + return self._emit_parallel(stmt) + if isinstance(stmt, BlockCall): + return self._emit_block_call(stmt) + if isinstance(stmt, ReturnOp): + msg = "AST -> Guppy v1 supports Return only as the final root-level statement" + raise GuppyCodegenError(msg) + + from pecos.slr.ast.nodes import AssignOp, CommentOp, PrintOp # noqa: PLC0415 + + if isinstance(stmt, AssignOp): + return self._emit_assign(stmt) + if isinstance(stmt, BarrierOp): + return self._emit_barrier(stmt) + if isinstance(stmt, CommentOp): + return self._emit_comment(stmt) + if isinstance(stmt, PermuteOp): + return self._emit_permute(stmt) + if isinstance(stmt, PrintOp): + return self._emit_print(stmt) + + msg = f"Unsupported AST statement for Guppy codegen: {type(stmt).__name__}" + raise GuppyCodegenError(msg) + + def _emit_gate(self, node: GateOp) -> list[str]: + if node.gate in PARAMETERIZED_FUNCTIONAL_GATES: + return self._emit_parameterized_gate(node) + + if node.gate in GUPPY_GATE_DECOMP: + return self._emit_decomposed_gate(node) + + gate = FUNCTIONAL_GATES.get(node.gate) + if gate is None: + self._raise_unsupported_gate(node.gate) - # Handle parameterized gates if node.params: - params = [self._render_expression(p) for p in node.params] - args = ", ".join(params + targets) - else: - args = ", ".join(targets) + msg = f"AST -> Guppy v1 does not support parameterized gate {node.gate.name}" + raise GuppyCodegenError(msg) - # Single qubit gates need reassignment for linearity - if node.gate.arity == 1: - target = targets[0] - return [f"{self.context.indent()}{target} = {gate_func}({target})"] - # Two-qubit gates return a tuple - return [ - f"{self.context.indent()}{targets[0]}, {targets[1]} = {gate_func}({args})", - ] + slots = [self._slot_from_ref(target) for target in node.targets] + if len(slots) != len(set(slots)): + msg = f"Gate {node.gate.name} uses the same qubit slot more than once" + raise GuppyCodegenError(msg) - def visit_prepare(self, node: PrepareOp) -> list[str]: - """Generate prepare/reset operation.""" - lines = [] + linearity = self._linearity() + locals_ = [linearity.live(slot) for slot in slots] - if node.slots is None: - # Prepare all - would need array iteration - lines.append( - f"{self.context.indent()}# Prepare all slots in {node.allocator}", - ) - else: - for slot in node.slots: - ref = f"{node.allocator}[{slot}]" - # In Guppy, qubits start in |0⟩ state from allocation - # For re-preparation after measurement, we'd use reset - lines.append( - f"{self.context.indent()}{ref} = quantum.reset({ref})", - ) + if node.gate.arity == 1: + local = locals_[0] + linearity.set_live(slots[0], local) + return [f"{self.context.indent()}{local} = {gate}({local})"] + + if node.gate.arity == 2: + left, right = locals_ + linearity.set_live(slots[0], left) + linearity.set_live(slots[1], right) + return [f"{self.context.indent()}{left}, {right} = {gate}({left}, {right})"] + + msg = f"AST -> Guppy v1 does not support {node.gate.arity}-qubit gate {node.gate.name}" + raise GuppyCodegenError(msg) + + def _emit_parameterized_gate(self, node: GateOp) -> list[str]: + """Emit a native Guppy rotation: `fn(qubit..., angle(half_turns))`. + + Guppy's `angle` stores half-turns (pi rad == 1.0 half-turn), so + the typed `Angle` param is emitted via + ``angle64.to_half_turns_signed()`` -- no radians/pi conversion. + Only typed `Angle` params are supported (a bare float or a + non-literal classical expression at a gate-param position fails + loud, mirroring the QIR backend's parameterized guard). + """ + gate = PARAMETERIZED_FUNCTIONAL_GATES[node.gate] + if not node.params: + msg = f"AST -> Guppy v1: parameterized gate {node.gate.name} requires an angle parameter" + raise GuppyCodegenError(msg) + angle_args = [f"angle({_param_to_half_turns(param, node.gate.name)})" for param in node.params] + angle_str = ", ".join(angle_args) - return lines + slots = [self._slot_from_ref(target) for target in node.targets] + if len(slots) != len(set(slots)): + msg = f"Gate {node.gate.name} uses the same qubit slot more than once" + raise GuppyCodegenError(msg) - def visit_measure(self, node: MeasureOp) -> list[str]: - """Generate measurement operation. + linearity = self._linearity() + locals_ = [linearity.live(slot) for slot in slots] - In Guppy, quantum.measure() consumes the qubit and returns a bool. - We use local variable names for measurement results. - Variable names are pre-registered during scan phase for return type generation. + if node.gate.arity == 1: + local = locals_[0] + linearity.set_live(slots[0], local) + return [f"{self.context.indent()}{local} = {gate}({local}, {angle_str})"] + + if node.gate.arity == 2: + left, right = locals_ + linearity.set_live(slots[0], left) + linearity.set_live(slots[1], right) + return [f"{self.context.indent()}{left}, {right} = {gate}({left}, {right}, {angle_str})"] + + msg = f"AST -> Guppy v1 does not support {node.gate.arity}-qubit parameterized gate {node.gate.name}" + raise GuppyCodegenError(msg) + + def _emit_decomposed_gate(self, node: GateOp) -> list[str]: + """Emit a multi-step decomposition for a PECOS gate with no native + single Guppy gate (`GUPPY_GATE_DECOMP`). + + Each step is a Guppy-native gate call threaded through the + linearity tracker (functional style: `local = fn(local)` or + `a, b = fn(a, b[, angle])`). Angle specs that are callable read + the input gate's (literal) params; non-literal params fail loud. """ - lines = [] - - for i, target in enumerate(node.targets): - target_ref = self._render_slot_ref(target) - - if i < len(node.results): - result = node.results[i] - # Use a proper local variable name instead of array indexing - var_name = f"{result.register}_{result.index}" + steps = GUPPY_GATE_DECOMP[node.gate] + base_slots = [self._slot_from_ref(target) for target in node.targets] + if len(base_slots) != len(set(base_slots)): + msg = f"Gate {node.gate.name} uses the same qubit slot more than once" + raise GuppyCodegenError(msg) + + # Resolve the input gate's literal params once (only needed if a + # step has a callable angle spec). + resolved_params: tuple[float, ...] | None = None + if any(callable(spec) for _, _, spec in steps): + params = node.params or () + if not params: + # A parameterized decomposition (RZZ/CRX/CRY) with no + # angle -- e.g. the malformed positional call + # `RZZ(q0, q1, 0.5)` that passes the angle as a qarg. + # Fail loud with a clear message (parity with the native + # `_emit_parameterized_gate` guard), not a raw IndexError + # when a callable spec indexes `p[0]`. + msg = f"AST -> Guppy v1: parameterized gate {node.gate.name} requires an angle parameter" + raise GuppyCodegenError(msg) + # User params resolve to half-turns; the forwarding lambdas + # (`p[0]`) then carry half-turns, matching the constant specs. + resolved_params = tuple(_param_to_half_turns(param, node.gate.name) for param in params) + + linearity = self._linearity() + lines: list[str] = [] + for fn, idxs, angle_spec in steps: + slots = [base_slots[i] for i in idxs] + locals_ = [linearity.live(slot) for slot in slots] + if angle_spec is None: + angle_arg = "" else: - # No result specified - use indexed name - var_name = f"_m{i}" - - lines.append( - f"{self.context.indent()}{var_name} = quantum.measure({target_ref})", - ) - + theta = angle_spec(resolved_params) if callable(angle_spec) else angle_spec + angle_arg = f", angle({float(theta)})" + if len(idxs) == 1: + local = locals_[0] + lines.append(f"{self.context.indent()}{local} = {fn}({local}{angle_arg})") + linearity.set_live(slots[0], local) + elif len(idxs) == 2: + left, right = locals_ + lines.append(f"{self.context.indent()}{left}, {right} = {fn}({left}, {right}{angle_arg})") + linearity.set_live(slots[0], left) + linearity.set_live(slots[1], right) + else: + msg = f"AST -> Guppy v1: decomposition step for {node.gate.name} has unsupported arity {len(idxs)}" + raise GuppyCodegenError(msg) return lines - # === Statements === + def _raise_unsupported_gate(self, gate: GateKind) -> None: + if gate.is_parameterized: + msg = f"AST -> Guppy v1 does not support parameterized gate {gate.name}" + raise GuppyCodegenError(msg) + msg = f"AST -> Guppy v1 does not support gate {gate.name}" + raise GuppyCodegenError(msg) + + def _emit_prepare(self, node: PrepareOp) -> list[str]: + # Z-reset/alloc to |0>, then the canonical Clifford tail + # (functional, linearity-preserving -- same FUNCTIONAL_GATES + # path as ordinary 1q gates; the qubit primitive yields |0> + # so this is exactly `PZ(q); H(q); ...`). + tail = prep_tail(node.basis) + lines: list[str] = [] + slots = range(self.context.root_allocators[node.allocator]) if node.slots is None else node.slots + linearity = self._linearity() + for index in slots: + slot = Slot(node.allocator, index) + local = self._local_name(slot) + if linearity.status(slot) is SlotState.LIVE: + cur = linearity.live(slot) + lines.append(f"{self.context.indent()}{cur} = reset({cur})") + else: + cur = local + lines.append(f"{self.context.indent()}{cur} = qubit()") + lines.extend(f"{self.context.indent()}{cur} = {FUNCTIONAL_GATES[gk]}({cur})" for gk in tail) + linearity.set_live(slot, cur) + return lines - def visit_assign(self, node: AssignOp) -> list[str]: - """Generate assignment operation.""" - target = f"{node.target.register}[{node.target.index}]" if isinstance(node.target, BitRef) else str(node.target) + def _emit_measure(self, node: MeasureOp) -> list[str]: + lines: list[str] = [] + linearity = self._linearity() + for index, target in enumerate(node.targets): + slot = self._slot_from_ref(target) + local = linearity.consume(slot) + if index < len(node.results): + result = self._render_bit_ref(node.results[index]) + lines.append(f"{self.context.indent()}{result} = measure({local})") + else: + temp = self.context.temp("measurement") + lines.append(f"{self.context.indent()}{temp} = measure({local})") + return lines - value = self._render_expression(node.value) + def _emit_assign(self, node: AssignOp) -> list[str]: + is_bit_target = isinstance(node.target, BitRef) + target = self._render_bit_ref(node.target) if is_bit_target else str(node.target) + value = self._render_expression(node.value, bool_context=is_bit_target) return [f"{self.context.indent()}{target} = {value}"] - def visit_barrier(self, node: BarrierOp) -> list[str]: - """Generate barrier (as comment - no direct Guppy equivalent).""" - if node.allocators: - allocs = ", ".join(node.allocators) - return [f"{self.context.indent()}# barrier({allocs})"] + def _emit_barrier(self, _node: BarrierOp) -> list[str]: return [f"{self.context.indent()}# barrier"] - def visit_comment(self, node: CommentOp) -> list[str]: - """Generate comment.""" - if node.text: - return [f"{self.context.indent()}# {node.text}"] - return [] - - def visit_return(self, node: ReturnOp) -> list[str]: - """Generate return statement.""" - if not node.values: - return [f"{self.context.indent()}return"] - - values = [] - for v in node.values: - if isinstance(v, str): - values.append(v) - else: - values.append(self._render_expression(v)) - - return [f"{self.context.indent()}return {', '.join(values)}"] - - def visit_permute(self, node: PermuteOp) -> list[str]: - """Generate permutation (register swap) code. - - Generates temp variable assignments to swap register references. - For Permute(a, b), generates: - # Swap a and b - _temp_a = a - a = b - b = _temp_a - """ - lines = [] - - if len(node.sources) != len(node.targets): - lines.append( - f"{self.context.indent()}# ERROR: Permute sources/targets length mismatch", - ) - return lines - - if len(node.sources) == 0: - return lines - - # Add comment if requested - if node.add_comment: - names = " and ".join(node.sources) - lines.append(f"{self.context.indent()}# Swap {names}") - - # For a simple two-way swap: a, b = b, a - if len(node.sources) == 1 and node.sources[0] != node.targets[0]: - src = node.sources[0] - tgt = node.targets[0] - temp = f"_temp_{src}" - lines.append(f"{self.context.indent()}{temp} = {src}") - lines.append(f"{self.context.indent()}{src} = {tgt}") - lines.append(f"{self.context.indent()}{tgt} = {temp}") - elif len(node.sources) == 2 and set(node.sources) == set(node.targets): - # Simple swap: Permute([a, b], [b, a]) - # Can use Python tuple swap - a, b = node.sources - lines.append(f"{self.context.indent()}{a}, {b} = {b}, {a}") - else: - # General case: use temp variables - temps = [] - for src in node.sources: - temp = f"_temp_{src}" - temps.append(temp) - lines.append(f"{self.context.indent()}{temp} = {src}") - - for i, src in enumerate(node.sources): - tgt = node.targets[i] - lines.append(f"{self.context.indent()}{src} = {tgt}") - - for i, tgt in enumerate(node.targets): - lines.append(f"{self.context.indent()}{tgt} = {temps[i]}") - - return lines - - # === Control Flow === + def _emit_comment(self, node: CommentOp) -> list[str]: + if not node.text: + return [] + return [f"{self.context.indent()}# {line.strip()}" for line in node.text.splitlines()] - def visit_if(self, node: IfStmt) -> list[str]: - """Generate if statement.""" - lines = [] + def _emit_if(self, node: IfStmt) -> list[str]: + linearity = self._linearity() + before = linearity.snapshot() - cond = self._render_expression(node.condition) - lines.append(f"{self.context.indent()}if {cond}:") + cond = self._render_expression(node.condition, bool_context=True) + lines = [f"{self.context.indent()}if {cond}:"] - # Then block self.context.push_indent() - then_lines = [] - for stmt in node.then_body: - then_lines.extend(self.visit(stmt)) - - if then_lines: - lines.extend(then_lines) - else: - lines.append(f"{self.context.indent()}pass") + then_lines = self._emit_block(node.then_body) + lines.extend(then_lines or [f"{self.context.indent()}pass"]) self.context.pop_indent() + then_state = linearity.snapshot() - # Else block + linearity.restore(before) + else_state = None if node.else_body: lines.append(f"{self.context.indent()}else:") self.context.push_indent() - else_lines = [] - for stmt in node.else_body: - else_lines.extend(self.visit(stmt)) - - if else_lines: - lines.extend(else_lines) - else: - lines.append(f"{self.context.indent()}pass") + else_lines = self._emit_block(node.else_body) + lines.extend(else_lines or [f"{self.context.indent()}pass"]) self.context.pop_indent() + else_state = linearity.snapshot() + linearity.merge_if(before, then_state, else_state, label="If") return lines - def visit_while(self, node: WhileStmt) -> list[str]: - """Generate while loop.""" - lines = [] - - cond = self._render_expression(node.condition) - lines.append(f"{self.context.indent()}while {cond}:") + def _emit_repeat(self, node: RepeatStmt) -> list[str]: + linearity = self._linearity() + before = linearity.snapshot() + lines = [f"{self.context.indent()}for _ in range({node.count}):"] self.context.push_indent() - body_lines = [] - for stmt in node.body: - body_lines.extend(self.visit(stmt)) - - if body_lines: - lines.extend(body_lines) - else: - lines.append(f"{self.context.indent()}pass") + body_lines = self._emit_block(node.body) + lines.extend(body_lines or [f"{self.context.indent()}pass"]) self.context.pop_indent() + after = linearity.snapshot() + linearity.assert_same(before, after, label=f"Repeat({node.count})") return lines - def visit_for(self, node: ForStmt) -> list[str]: - """Generate for loop.""" - lines = [] - + def _emit_for(self, node: ForStmt) -> list[str]: + linearity = self._linearity() start = self._render_expression(node.start) stop = self._render_expression(node.stop) - - if node.step: + if node.step is not None: step = self._render_expression(node.step) - lines.append( - f"{self.context.indent()}for {node.variable} in range({start}, {stop}, {step}):", - ) + header = f"for {node.variable} in range({start}, {stop}, {step}):" else: - lines.append( - f"{self.context.indent()}for {node.variable} in range({start}, {stop}):", - ) + header = f"for {node.variable} in range({start}, {stop}):" + before = linearity.snapshot() + lines = [f"{self.context.indent()}{header}"] self.context.push_indent() - body_lines = [] - for stmt in node.body: - body_lines.extend(self.visit(stmt)) - - if body_lines: - lines.extend(body_lines) - else: - lines.append(f"{self.context.indent()}pass") + body_lines = self._emit_block(node.body) + lines.extend(body_lines or [f"{self.context.indent()}pass"]) self.context.pop_indent() + after = linearity.snapshot() + linearity.assert_same(before, after, label=f"For({node.variable})") return lines - def visit_repeat(self, node: RepeatStmt) -> list[str]: - """Generate repeat loop (as for _ in range(n)).""" - lines = [] + def _emit_parallel(self, node: ParallelBlock) -> list[str]: + return self._emit_block(node.body) + + def _emit_block_call(self, node: BlockCall) -> list[str]: + """Lower a BlockCall to a packed-array call + unpack pattern. + + Per-input dispatch: + - `array[qubit, N]` input + `AllocatorArg(name=outer)`: pack the + outer allocator's slots into `array(outer_0, outer_1, ...)`, + unpack the returned array back into the same slots. + - bare `qubit` input + `SingleQubitArg(slot=outer[i])`: pass the + outer slot's local directly (no array wrap), rebind it from the + returned single qubit value. + - bare `bit` input + `SingleBitArg(bit=c[i])`: wrap into an + `array[bool, 1]` write-back proxy, write the mutated bit back. + - `array[qubit, N]` input + `QubitBundleArg(slots=(...))`: pack N + arbitrary (possibly non-contiguous, cross-allocator) outer slots, + unpack the returned array back into the same slots. + + The remaining BlockArg subclass (`BitBundleArg`) raises -- it lands + with a later iteration. + """ + decl = self._block_decls.get(node.callee) + if decl is None: + msg = f"BlockCall references undefined block {node.callee!r}" + raise GuppyCodegenError(msg) + + if len(node.arg_bindings) != len(decl.inputs): + msg = ( + f"BlockCall {node.callee!r}: {len(node.arg_bindings)} arg_bindings " + f"but BlockDecl declares {len(decl.inputs)} inputs" + ) + raise GuppyCodegenError(msg) + + # Phase 1: validate every arg + out binding BEFORE touching linearity state, so + # a late-raised GuppyCodegenError can't leave the tracker half-consumed. + # Each validated_args entry is tagged with one of "array", + # "single_qubit", "single_bit", or "qubit_bundle" so the Phase-2 emit + # step knows how to pack the call argument. + validated_args: list[tuple[BlockInput, str, tuple]] = [] + live_inputs_out: list[BlockInput] = [] + # Scratch slots are still validated (type/presence/bounds) and fed + # into the cross-input alias check below, but kept OUT of + # `validated_args` so Phase 2 never packs/consumes/returns them and + # they are not positional call arguments (the block allocates the + # ancilla internally; the outer slot stays live and is discarded at + # end-of-scope per R1). Validating-then-excluding -- not skipping + # before validation -- so a malformed scratch binding (unknown/OOB + # outer slot) still fails loudly. + scratch_slots: list[tuple[str, int]] = [] + for inp, arg in zip(decl.inputs, node.arg_bindings, strict=True): + if inp.effect is ResourceEffect.SCRATCH: + kind, info = self._validate_block_call_arg(node.callee, inp, arg) + if kind != "single_qubit": + msg = ( + f"BlockCall {node.callee!r}: SCRATCH input {inp.name!r} " + f"must be bound by a SingleQubitArg (got {kind})" + ) + raise GuppyCodegenError(msg) + scratch_slots.append(info) + continue + kind, info = self._validate_block_call_arg(node.callee, inp, arg) + validated_args.append((inp, kind, info)) + if inp.effect is ResourceEffect.LIVE_PRESERVED: + live_inputs_out.append(inp) + + if len(node.out_bindings) != len(live_inputs_out): + expected = [inp.name for inp in live_inputs_out] + msg = ( + f"BlockCall {node.callee!r}: out_bindings count " + f"({len(node.out_bindings)}) does not match expected return positions {expected}" + ) + raise GuppyCodegenError(msg) + + validated_outs: list[tuple[str, tuple]] = [] + # Build a quick lookup of (validated_args index) for each LIVE_PRESERVED + # input. NOTE: this indexes `validated_args`, which excludes SCRATCH + # inputs -- so it must enumerate `validated_args`, not `decl.inputs` + # (a scratch input would otherwise shift every later index). + live_arg_index: dict[int, int] = {} + live_count = 0 + for va_index, (va_inp, _k, _i) in enumerate(validated_args): + if va_inp.effect is ResourceEffect.LIVE_PRESERVED: + live_arg_index[live_count] = va_index + live_count += 1 + for out_idx, (out, inp) in enumerate(zip(node.out_bindings, live_inputs_out, strict=True)): + kind, info = self._validate_block_call_arg(node.callee, inp, out, is_out=True) + # Cross-check: a LIVE_PRESERVED input's out_binding MUST reference the same + # outer-scope slot/allocator as its arg_binding. Otherwise the emitter would + # blindly set_live() on a slot that was never consumed, producing invalid + # Guppy where the never-consumed slot is overwritten. + arg_kind, arg_info = validated_args[live_arg_index[out_idx]][1:] + if (kind, info) != (arg_kind, arg_info): + msg = ( + f"BlockCall {node.callee!r}: LIVE_PRESERVED input {inp.name!r} " + f"must use an identical arg_binding and out_binding (same " + f"allocator name for AllocatorArg; same slot for " + f"SingleQubitArg; same bit for SingleBitArg; same ordered " + f"slot tuple for QubitBundleArg); got " + f"arg={arg_kind}{arg_info} vs out={kind}{info}" + ) + raise GuppyCodegenError(msg) + validated_outs.append((kind, info)) + + # Cross-input aliasing check (still Phase 1, pre-consume): two distinct + # quantum arg_bindings must not reference the same outer qubit slot. + # Without this, the overlap would only surface mid-Phase-2 as a + # LinearityError ("slot consumed") with the tracker half-mutated. Raising + # here keeps the "all validation before linearity mutation" invariant + # strict. Bits are copyable so single_bit args are excluded. + seen_slots: dict[tuple[str, int], str] = {} + for inp, kind, info in validated_args: + if kind == "array": + alloc, outer_size = info + slots = [(alloc, i) for i in range(outer_size)] + elif kind == "single_qubit": + slots = [info] + elif kind == "qubit_bundle": + slots = list(info) + else: # single_bit -- no qubit slots + continue + for slot in slots: + if slot in seen_slots: + msg = ( + f"BlockCall {node.callee!r}: qubit slot " + f"{slot[0]}[{slot[1]}] is referenced by more than one " + f"arg_binding (inputs {seen_slots[slot]!r} and " + f"{inp.name!r}); a qubit cannot be passed to two inputs" + ) + raise GuppyCodegenError(msg) + seen_slots[slot] = inp.name + # Scratch slots participate in the alias check too: a scratch slot + # shared with another (scratch or non-scratch) input slot is invalid. + for scratch_slot in scratch_slots: + if scratch_slot in seen_slots: + msg = ( + f"BlockCall {node.callee!r}: scratch qubit slot " + f"{scratch_slot[0]}[{scratch_slot[1]}] is also bound by " + f"input {seen_slots[scratch_slot]!r}; a scratch ancilla " + "cannot be shared with another input" + ) + raise GuppyCodegenError(msg) + seen_slots[scratch_slot] = "" + + # Phase 2: now that every check passed, consume slots and emit code. + linearity = self._linearity() + arg_exprs: list[str] = [] + for _inp, kind, info in validated_args: + if kind == "array": + arg_name, outer_size = info + locals_ = [linearity.consume(Slot(arg_name, i)) for i in range(outer_size)] + arg_exprs.append(f"array({', '.join(locals_)})") + elif kind == "single_qubit": + outer_alloc, outer_index = info + arg_exprs.append(linearity.consume(Slot(outer_alloc, outer_index))) + elif kind == "single_bit": + # Wrap the outer CReg bit into a 1-element bool array (write-back + # proxy). Bits are copyable, so no linearity consume. + register, bit_index = info + arg_exprs.append(f"array({register}[{bit_index}])") + elif kind == "qubit_bundle": + # Pack arbitrary (non-contiguous) outer slots into one array. + locals_ = [linearity.consume(Slot(alloc, idx)) for alloc, idx in info] + arg_exprs.append(f"array({', '.join(locals_)})") + else: + msg = f"Unsupported validated arg kind {kind!r}" # pragma: no cover + raise GuppyCodegenError(msg) - lines.append(f"{self.context.indent()}for _ in range({node.count}):") + call_expr = f"{node.callee}({', '.join(arg_exprs)})" + lines: list[str] = [] - self.context.push_indent() - body_lines = [] - for stmt in node.body: - body_lines.extend(self.visit(stmt)) + if not live_inputs_out: + lines.append(f"{self.context.indent()}{call_expr}") + return lines - if body_lines: - lines.extend(body_lines) - else: - lines.append(f"{self.context.indent()}pass") - self.context.pop_indent() + if len(live_inputs_out) == 1: + kind, info = validated_outs[0] + ret_temp = self.context.temp("call_ret") + lines.append(f"{self.context.indent()}{ret_temp} = {call_expr}") + lines.extend(self._unpack_block_call_return(kind, info, ret_temp)) + return lines + ret_temps = [self.context.temp("call_ret") for _ in live_inputs_out] + lines.append(f"{self.context.indent()}{', '.join(ret_temps)} = {call_expr}") + for ret_temp, (kind, info) in zip(ret_temps, validated_outs, strict=True): + lines.extend(self._unpack_block_call_return(kind, info, ret_temp)) return lines - def visit_parallel(self, node: ParallelBlock) -> list[str]: - """Generate parallel block (as comment + sequential for now).""" - lines = [] - lines.append(f"{self.context.indent()}# parallel begin") - - for stmt in node.body: - lines.extend(self.visit(stmt)) + def _validate_block_call_arg( + self, + callee: str, + inp: BlockInput, + arg: object, + *, + is_out: bool = False, + ) -> tuple[str, tuple]: + """Cross-check input type and BlockArg shape; return (kind, info). + + kind is one of: + - "array": info = (outer_alloc_name, outer_size) + - "single_qubit": info = (outer_alloc_name, outer_index) + - "single_bit": info = (outer_register_name, outer_bit_index) + """ + position = "out_binding" if is_out else "arg" + if isinstance(arg, AllocatorArg): + if not isinstance(inp.type_expr, ArrayTypeExpr): + msg = ( + f"BlockCall {callee!r} {position} for input {inp.name!r}: " + f"AllocatorArg requires an array[qubit, N] input (got " + f"{type(inp.type_expr).__name__})" + ) + raise GuppyCodegenError(msg) + input_size = inp.type_expr.size + if arg.name not in self.context.root_allocators: + msg = f"BlockCall {callee!r} {position} {arg.name!r} must be an outer root allocator name" + raise GuppyCodegenError(msg) + outer_size = self.context.root_allocators[arg.name] + if outer_size != input_size: + msg = ( + f"BlockCall {callee!r} {position} {arg.name!r} size {outer_size} " + f"does not match input {inp.name!r} size {input_size}" + ) + raise GuppyCodegenError(msg) + return "array", (arg.name, outer_size) + + if isinstance(arg, SingleQubitArg): + if not isinstance(inp.type_expr, QubitTypeExpr): + msg = ( + f"BlockCall {callee!r} {position} for input {inp.name!r}: " + f"SingleQubitArg requires a bare qubit input (got " + f"{type(inp.type_expr).__name__})" + ) + raise GuppyCodegenError(msg) + slot = arg.slot + if slot.allocator not in self.context.root_allocators: + msg = ( + f"BlockCall {callee!r} {position} for input {inp.name!r}: slot " + f"{slot.allocator}[{slot.index}] references an unknown outer allocator" + ) + raise GuppyCodegenError(msg) + outer_size = self.context.root_allocators[slot.allocator] + if not (0 <= slot.index < outer_size): + msg = ( + f"BlockCall {callee!r} {position} for input {inp.name!r}: slot " + f"index {slot.index} out of bounds for allocator " + f"{slot.allocator!r} of size {outer_size}" + ) + raise GuppyCodegenError(msg) + return "single_qubit", (slot.allocator, slot.index) + + if isinstance(arg, SingleBitArg): + if not isinstance(inp.type_expr, BitTypeExpr): + msg = ( + f"BlockCall {callee!r} {position} for input {inp.name!r}: " + f"SingleBitArg requires a bare bit input (got " + f"{type(inp.type_expr).__name__})" + ) + raise GuppyCodegenError(msg) + bit = arg.bit + if bit.register not in self.context.registers: + msg = ( + f"BlockCall {callee!r} {position} for input {inp.name!r}: bit " + f"{bit.register}[{bit.index}] references an unknown outer CReg" + ) + raise GuppyCodegenError(msg) + reg_size = self.context.registers[bit.register].size + if not (0 <= bit.index < reg_size): + msg = ( + f"BlockCall {callee!r} {position} for input {inp.name!r}: bit " + f"index {bit.index} out of bounds for CReg {bit.register!r} " + f"of size {reg_size}" + ) + raise GuppyCodegenError(msg) + return "single_bit", (bit.register, bit.index) + + if isinstance(arg, QubitBundleArg): + if not isinstance(inp.type_expr, ArrayTypeExpr) or not isinstance( + inp.type_expr.element, + QubitTypeExpr, + ): + msg = ( + f"BlockCall {callee!r} {position} for input {inp.name!r}: " + f"QubitBundleArg requires an array[qubit, N] input (got " + f"{type(inp.type_expr).__name__})" + ) + raise GuppyCodegenError(msg) + input_size = inp.type_expr.size + if len(arg.slots) != input_size: + msg = ( + f"BlockCall {callee!r} {position} for input {inp.name!r}: " + f"QubitBundleArg has {len(arg.slots)} slots but input " + f"{inp.name!r} expects {input_size}" + ) + raise GuppyCodegenError(msg) + resolved: list[tuple[str, int]] = [] + seen: set[tuple[str, int]] = set() + for slot in arg.slots: + if slot.allocator not in self.context.root_allocators: + msg = ( + f"BlockCall {callee!r} {position} for input {inp.name!r}: " + f"bundle slot {slot.allocator}[{slot.index}] references an " + "unknown outer allocator" + ) + raise GuppyCodegenError(msg) + outer_size = self.context.root_allocators[slot.allocator] + if not (0 <= slot.index < outer_size): + msg = ( + f"BlockCall {callee!r} {position} for input {inp.name!r}: " + f"bundle slot index {slot.index} out of bounds for allocator " + f"{slot.allocator!r} of size {outer_size}" + ) + raise GuppyCodegenError(msg) + key = (slot.allocator, slot.index) + if key in seen: + msg = ( + f"BlockCall {callee!r} {position} for input {inp.name!r}: " + f"bundle references slot {slot.allocator}[{slot.index}] more " + "than once (a qubit cannot be passed twice)" + ) + raise GuppyCodegenError(msg) + seen.add(key) + resolved.append(key) + return "qubit_bundle", tuple(resolved) + + msg = ( + f"BlockCall {callee!r} {position} for input {inp.name!r}: BlockArg " + f"{type(arg).__name__} is not yet supported" + ) + raise GuppyCodegenError(msg) + + def _unpack_block_call_return(self, kind: str, info: tuple, ret_temp: str) -> list[str]: + """Unpack a single return value back into outer-scope linearity bindings.""" + linearity = self._linearity() + if kind == "array": + out_name, _outer_size = info + return self._unpack_return_array(out_name, ret_temp) + if kind == "single_qubit": + outer_alloc, outer_index = info + # Bind the returned qubit to a fresh local that the linearity tracker + # treats as the new owner of the outer slot. Uses the standard + # `{allocator}_{index}` naming so subsequent gates resolve cleanly. + new_local = f"{outer_alloc}_{outer_index}" + linearity.set_live(Slot(outer_alloc, outer_index), new_local) + return [f"{self.context.indent()}{new_local} = {ret_temp}"] + if kind == "single_bit": + # Write the mutated bit back into the outer CReg. Bits are copyable + # so there's no linearity rebind -- just a value assignment. + register, bit_index = info + return [f"{self.context.indent()}{register}[{bit_index}] = {ret_temp}[0]"] + if kind == "qubit_bundle": + # Destructure the returned array back into the SAME outer slots the + # bundle consumed, rebinding each via the canonical local name. + new_locals = [f"{alloc}_{idx}" for alloc, idx in info] + if len(new_locals) == 1: + line = f"{self.context.indent()}{new_locals[0]}, = {ret_temp}" + else: + line = f"{self.context.indent()}{', '.join(new_locals)} = {ret_temp}" + for (alloc, idx), local in zip(info, new_locals, strict=True): + linearity.set_live(Slot(alloc, idx), local) + return [line] + msg = f"Unsupported return kind {kind!r}" # pragma: no cover + raise GuppyCodegenError(msg) + + def _require_out_binding_matches(self, callee: str, out_name: str, inp: BlockInput) -> None: + # _validate_block_decl guarantees type_expr is ArrayTypeExpr[QubitTypeExpr]. + input_size = cast("ArrayTypeExpr", inp.type_expr).size + if out_name not in self.context.root_allocators: + msg = f"BlockCall {callee!r} out_binding {out_name!r} is not an outer allocator" + raise GuppyCodegenError(msg) + outer_size = self.context.root_allocators[out_name] + if outer_size != input_size: + msg = ( + f"BlockCall {callee!r} out_binding {out_name!r} size {outer_size} " + f"does not match input {inp.name!r} size {input_size}" + ) + raise GuppyCodegenError(msg) - lines.append(f"{self.context.indent()}# parallel end") + def _unpack_return_array(self, out_name: str, ret_temp: str) -> list[str]: + size = self.context.root_allocators[out_name] + new_locals = [f"{out_name}_{i}" for i in range(size)] + if size == 1: + line = f"{self.context.indent()}{new_locals[0]}, = {ret_temp}" + else: + line = f"{self.context.indent()}{', '.join(new_locals)} = {ret_temp}" + linearity = self._linearity() + for i, local in enumerate(new_locals): + linearity.set_live(Slot(out_name, i), local) + return [line] + + def _emit_block(self, body: tuple[Statement, ...]) -> list[str]: + lines: list[str] = [] + for stmt in body: + lines.extend(self._emit_stmt(stmt)) return lines - # === References === + def _emit_print(self, node: PrintOp) -> list[str]: + """Lower PrintOp to a Guppy `result(., )` call. - def visit_slot_ref(self, node: SlotRef) -> list[str]: - """Slot refs are rendered inline.""" - return [self._render_slot_ref(node)] - - def visit_bit_ref(self, node: BitRef) -> list[str]: - """Bit refs are rendered inline.""" - return [f"{node.register}[{node.index}]"] + Per v2-print.md, Print is scope-orthogonal: it does not allocate, does + not touch the result-register set, and does not affect main's return + type. Path-signature consistency for Print inside If branches and + inline-CReg definite-assignment are enforced by separate validation + passes; this emitter assumes both have already accepted the AST. + """ + full_tag = f"{node.namespace}.{node.tag}" + + value_expr: str + if isinstance(node.value, BitRef): + register = node.value.register + if register not in self.context.registers: + msg = ( + f"Print(c[{node.value.index}]) references unknown CReg {register!r}; " + "declare the CReg or measure into it before Print." + ) + raise GuppyCodegenError(msg) + value_expr = f"{register}[{node.value.index}]" + elif isinstance(node.value, str): + register = node.value + if register not in self.context.registers: + msg = ( + f"Print({register}) references unknown CReg {register!r}; " + "declare the CReg or measure into it before Print." + ) + raise GuppyCodegenError(msg) + value_expr = register + else: + msg = f"Unsupported Print value type for Guppy codegen: {type(node.value).__name__}" + raise GuppyCodegenError(msg) - # === Expressions === + return [f'{self.context.indent()}result("{full_tag}", {value_expr})'] - def visit_literal(self, node: LiteralExpr) -> list[str]: - """Literals are rendered inline.""" - return [self._render_literal(node)] + def _emit_permute(self, node: PermuteOp) -> list[str]: + if len(node.sources) != len(node.targets): + msg = "Permute source/target length mismatch" + raise GuppyCodegenError(msg) + + quantum_mapping: dict[Slot, Slot] = {} + classical_mapping: dict[BitRef, BitRef] = {} + for source, target in zip(node.sources, node.targets, strict=True): + source_refs = self._expand_permute_ref(source) + target_refs = self._expand_permute_ref(target) + if len(source_refs) != len(target_refs): + msg = f"Permute element count mismatch for {source!r} -> {target!r}" + raise GuppyCodegenError(msg) + for source_ref, target_ref in zip(source_refs, target_refs, strict=True): + if isinstance(source_ref, Slot) and isinstance(target_ref, Slot): + quantum_mapping[source_ref] = target_ref + elif isinstance(source_ref, BitRef) and isinstance(target_ref, BitRef): + classical_mapping[source_ref] = target_ref + else: + msg = f"Permute cannot map quantum and classical refs together: {source!r} -> {target!r}" + raise GuppyCodegenError(msg) + + lines: list[str] = [] + if quantum_mapping: + self._linearity().permute(quantum_mapping, label="Permute") + + if classical_mapping: + lines.extend(self._emit_classical_permute(classical_mapping)) + + if node.add_comment and (quantum_mapping or classical_mapping): + pairs = ", ".join( + f"{source} -> {target}" for source, target in zip(node.sources, node.targets, strict=True) + ) + lines.insert(0, f"{self.context.indent()}# Permute: {pairs}") + return lines - def visit_var(self, node: VarExpr) -> list[str]: - """Variables are rendered inline.""" - return [node.name] + def _expand_permute_ref(self, ref: str) -> list[Slot | BitRef]: + parsed = self._parse_indexed_ref(ref) + if parsed is not None: + name, index = parsed + if name in self.context.root_allocators: + return [Slot(name, index)] + if name in self.context.registers: + return [BitRef(register=name, index=index)] + msg = f"Unknown Permute ref {ref!r}" + raise GuppyCodegenError(msg) + + if ref in self.context.root_allocators: + return [Slot(ref, index) for index in range(self.context.root_allocators[ref])] + if ref in self.context.registers: + return [BitRef(register=ref, index=index) for index in range(self.context.registers[ref].size)] + + msg = f"Unknown Permute ref {ref!r}" + raise GuppyCodegenError(msg) + + def _emit_classical_permute(self, mapping: dict[BitRef, BitRef]) -> list[str]: + if set(mapping) != set(mapping.values()): + msg = "Classical Permute must be bijective over the same bit set" + raise GuppyCodegenError(msg) + + lines: list[str] = [] + visited: set[BitRef] = set() + for start, target in mapping.items(): + if start in visited or target == start: + visited.add(start) + continue + cycle = [start] + visited.add(start) + current = target + while current != start: + if current in visited: + msg = "Classical Permute contains a malformed cycle" + raise GuppyCodegenError(msg) + cycle.append(current) + visited.add(current) + current = mapping[current] + + lines.extend( + f"{self.context.indent()}mem_swap({self._render_bit_ref(cycle[index])}, " + f"{self._render_bit_ref(cycle[index + 1])})" + for index in range(len(cycle) - 1) + ) + return lines - def visit_bit_expr(self, node: BitExpr) -> list[str]: - """Bit expressions are rendered inline.""" - return [f"{node.ref.register}[{node.ref.index}]"] + def _emit_end_cleanup(self) -> list[str]: + return [f"{self.context.indent()}discard({local})" for _slot, local in self._linearity().discard_live()] - def visit_binary(self, node: BinaryExpr) -> list[str]: - """Binary expressions are rendered inline.""" - return [self._render_binary(node)] + def _emit_explicit_return(self, node: ReturnOp) -> list[str]: + values = [self._return_value_expr(value) for value in node.values] + lines = self._emit_end_cleanup() + if values: + lines.append(f"{self.context.indent()}return {', '.join(values)}") + else: + lines.append(f"{self.context.indent()}return") + return lines - def visit_unary(self, node: UnaryExpr) -> list[str]: - """Unary expressions are rendered inline.""" - return [self._render_unary(node)] + def _return_value_expr(self, value: Expression | str) -> str: + if isinstance(value, str): + if value in self.context.root_allocators: + return self._consume_allocator_for_return(value) + if value in self.context.registers: + return value + msg = f"Unsupported Guppy return value {value!r}" + raise GuppyCodegenError(msg) + return self._render_expression(value) + + def _consume_allocator_for_return(self, allocator: str) -> str: + linearity = self._linearity() + locals_ = [ + linearity.consume(Slot(allocator, index)) for index in range(self.context.root_allocators[allocator]) + ] + return f"array({', '.join(locals_)})" + + def _linearity(self) -> GuppyLinearityState: + if self.context.linearity is None: + msg = "Guppy linearity state was not initialized" + raise GuppyCodegenError(msg) + return self.context.linearity + + def _slot_from_ref(self, ref: SlotRef) -> Slot: + if ref.allocator not in self.context.root_allocators: + msg = f"AST -> Guppy v1 does not support allocator {ref.allocator!r}" + raise GuppyCodegenError(msg) + return Slot(ref.allocator, ref.index) + + def _local_name(self, slot: Slot) -> str: + # Read from the disambiguated slot-locals table populated + # by `GuppyContext.populate_slot_locals` so this site agrees + # with the linearity-state binding and the entry-unpack LHS. + # Fall back to the bare formula only if a caller emits a slot + # before the table is populated (defensive; should not happen + # on a normal emission path). + cached = self.context.slot_locals.get(slot) + if cached is not None: + return cached + return f"{slot.allocator}_{slot.index}" + + def _render_bit_ref(self, ref: BitRef) -> str: + if ref.register not in self.context.registers: + msg = f"Unknown classical register {ref.register!r}" + raise GuppyCodegenError(msg) + return f"{ref.register}[{ref.index}]" + + def _render_expression(self, expr: Expression, *, bool_context: bool = False) -> str: + if isinstance(expr, LiteralExpr): + return self._render_literal(expr, bool_context=bool_context) + if isinstance(expr, VarExpr): + return expr.name + if isinstance(expr, BitExpr): + return self._render_bit_ref(expr.ref) + if isinstance(expr, BinaryExpr): + return self._render_binary(expr) + if isinstance(expr, UnaryExpr): + return self._render_unary(expr, bool_context=bool_context) + msg = f"Unsupported Guppy expression {expr!r}" + raise GuppyCodegenError(msg) + + def _render_literal(self, expr: LiteralExpr, *, bool_context: bool = False) -> str: + if isinstance(expr.value, bool): + return "True" if expr.value else "False" + if bool_context and isinstance(expr.value, int): + if expr.value in {0, 1}: + return "True" if expr.value else "False" + msg = f"Cannot render integer literal {expr.value!r} as a Guppy bool" + raise GuppyCodegenError(msg) + return str(expr.value) + + def _render_binary(self, expr: BinaryExpr) -> str: + op = BINARY_OP_TO_PYTHON.get(expr.op) + if op is None: + msg = f"Unsupported Guppy binary op {expr.op.name}" + raise GuppyCodegenError(msg) + + compares_bool_expression = expr.op in BOOL_COMPARISON_OPS and ( + self._is_bool_expression(expr.left) or self._is_bool_expression(expr.right) + ) + operand_bool_context = expr.op in BOOL_OPERAND_BINARY_OPS or compares_bool_expression + left = self._render_expression(expr.left, bool_context=operand_bool_context) + right = self._render_expression(expr.right, bool_context=operand_bool_context) + return f"({left} {op} {right})" - # === Type expressions === + def _render_unary(self, expr: UnaryExpr, *, bool_context: bool = False) -> str: + operand = self._render_expression(expr.operand, bool_context=bool_context or expr.op == UnaryOp.NOT) + if expr.op == UnaryOp.NOT: + return f"(not {operand})" + if expr.op == UnaryOp.NEG: + return f"(-{operand})" + msg = f"Unsupported Guppy unary op {expr.op.name}" + raise GuppyCodegenError(msg) - def visit_qubit_type(self, _node: object) -> list[str]: + def _is_bool_expression(self, expr: Expression) -> bool: + if isinstance(expr, BitExpr): + return True + if isinstance(expr, LiteralExpr): + return isinstance(expr.value, bool) + if isinstance(expr, UnaryExpr): + return expr.op == UnaryOp.NOT + return isinstance(expr, BinaryExpr) and expr.op in { + BinaryOp.AND, + BinaryOp.OR, + BinaryOp.XOR, + BinaryOp.EQ, + BinaryOp.NE, + BinaryOp.LT, + BinaryOp.LE, + BinaryOp.GT, + BinaryOp.GE, + } + + def _parse_indexed_ref(self, ref: str) -> tuple[str, int] | None: + match = re.fullmatch(r"([A-Za-z_]\w*)\[(\d+)\]", ref) + if match is None: + return None + return match.group(1), int(match.group(2)) + + def visit_qubit_type(self, _node: QubitTypeExpr) -> list[str]: + """Render a qubit type expression.""" return ["qubit"] - def visit_bit_type(self, _node: object) -> list[str]: + def visit_bit_type(self, _node: BitTypeExpr) -> list[str]: + """Render a bit type expression.""" return ["bool"] - def visit_array_type(self, node) -> list[str]: - elem = self.visit(node.element)[0] if self.visit(node.element) else "qubit" + def visit_array_type(self, node: object) -> list[str]: + """Render an array type expression.""" + if isinstance(node.element, QubitTypeExpr): + elem = "qubit" + elif isinstance(node.element, BitTypeExpr): + elem = "bool" + else: + elem = "qubit" return [f"array[{elem}, {node.size}]"] - def visit_allocator_type(self, node) -> list[str]: - return [f"array[qubit, {node.capacity}]"] - # === Helper methods === +def ast_to_guppy(program: Program) -> str: + """Convert an AST Program to Guppy Python code.""" + generator = AstToGuppy() + return "\n".join(generator.generate(program)) - def _render_slot_ref(self, node: SlotRef) -> str: - """Render a slot reference as array access. - For child allocators, translates to root allocator with computed offset. - E.g., data[0] -> base[0], ancilla[0] -> base[4] - """ - # Get the root allocator and absolute index - root = self.context.get_root_allocator(node.allocator) - abs_index = self.context.get_absolute_index(node.allocator, node.index) +def validate_slr_for_guppy_v1(block: object | None) -> None: + """Reject SLR constructs that the v1 AST -> Guppy path cannot represent soundly.""" + if block is None: + return + _validate_slr_node_for_guppy_v1(block) - return f"{root}[{abs_index}]" - def _render_expression(self, expr: Expression) -> str: - """Render an expression to a string.""" - if isinstance(expr, LiteralExpr): - return self._render_literal(expr) - if isinstance(expr, VarExpr): - return expr.name - if isinstance(expr, BitExpr): - # Use underscore naming to match measurement variable names - return f"{expr.ref.register}_{expr.ref.index}" - if isinstance(expr, BinaryExpr): - return self._render_binary(expr) - if isinstance(expr, UnaryExpr): - return self._render_unary(expr) - return str(expr) - - def _render_literal(self, node: LiteralExpr) -> str: - """Render a literal value.""" - if isinstance(node.value, bool): - return "True" if node.value else "False" - return str(node.value) - - def _render_binary(self, node: BinaryExpr) -> str: - """Render a binary expression.""" - left = self._render_expression(node.left) - right = self._render_expression(node.right) - op = BINARY_OP_TO_PYTHON.get(node.op, str(node.op)) - return f"({left} {op} {right})" +def _validate_slr_node_for_guppy_v1(node: object) -> None: + node_type = type(node).__name__ + if node_type == "While": + msg = "AST -> Guppy v1 does not support While loops" + raise GuppyCodegenError(msg) - def _render_unary(self, node: UnaryExpr) -> str: - """Render a unary expression.""" - operand = self._render_expression(node.operand) - op = UNARY_OP_TO_PYTHON.get(node.op, str(node.op)) - return f"({op} {operand})" + if getattr(node, "is_qgate", False): + _validate_slr_gate_for_guppy_v1(node) + for child in getattr(node, "ops", ()) or (): + _validate_slr_node_for_guppy_v1(child) -def ast_to_guppy(program: Program) -> str: - """Convert an AST Program to Guppy Python code. + else_block = getattr(node, "else_block", None) + if else_block is not None: + _validate_slr_node_for_guppy_v1(else_block) - Convenience function for simple code generation. - Args: - program: The AST Program to convert. +def _validate_slr_gate_for_guppy_v1(gate: object) -> None: + qargs = getattr(gate, "qargs", ()) or () + cout = getattr(gate, "cout", ()) or () - Returns: - Generated Guppy Python code as a string. - """ - generator = AstToGuppy() - lines = generator.generate(program) - return "\n".join(lines) + if _contains_symbolic_index(qargs) or _contains_symbolic_index(cout): + msg = "AST -> Guppy v1 does not support symbolic LoopVar indexing" + raise GuppyCodegenError(msg) + + # (The non-Z `Prep` string-basis preflight reject was + # removed -- prep basis is the gate identity now; the dedicated + # gates carry it through `PrepareOp.basis` and the converter + # already fails loud on any stray prep string arg.) + + +def _contains_symbolic_index(value: object) -> bool: + for item in _nested_items(value): + if hasattr(item, "index_var") or type(item).__name__.startswith("Symbolic"): + return True + return False + + +def _nested_items(value: object) -> Iterator[object]: + if isinstance(value, str): + yield value + return + if isinstance(value, list | tuple): + for item in value: + yield from _nested_items(item) + return + yield value diff --git a/python/quantum-pecos/src/pecos/slr/ast/codegen/guppy_linearity.py b/python/quantum-pecos/src/pecos/slr/ast/codegen/guppy_linearity.py new file mode 100644 index 000000000..204969d67 --- /dev/null +++ b/python/quantum-pecos/src/pecos/slr/ast/codegen/guppy_linearity.py @@ -0,0 +1,238 @@ +# Copyright 2026 The PECOS Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +"""Guppy-only slot ownership tracking for AST code generation. + +This module is deliberately target-scoped. It tracks the Guppy local that +currently owns each logical SLR qubit slot while `ast/codegen/guppy.py` +emits source. It does not annotate AST nodes and does not model non-Guppy +codegens. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum, auto +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Iterable, Mapping + + +@dataclass(frozen=True, slots=True) +class Slot: + """Logical qubit slot from an SLR allocator, such as `q[0]`.""" + + allocator: str + index: int + + def __str__(self) -> str: + """Return a compact user-facing slot name.""" + return f"{self.allocator}[{self.index}]" + + +class SlotState(Enum): + """Guppy ownership state for a logical qubit slot.""" + + LIVE = auto() + CONSUMED = auto() + + +@dataclass(frozen=True, slots=True) +class Binding: + """Current Guppy local name and ownership state for one slot.""" + + local: str + state: SlotState + + +LinearitySnapshot = dict[Slot, Binding] + + +class LinearityError(Exception): + """Raised when AST emission would produce unsound Guppy ownership.""" + + +class GuppyLinearityState: + """Track logical qubit slots while the Guppy emitter writes locals.""" + + def __init__(self, bindings: Mapping[Slot, Binding]) -> None: + """Create state from an explicit binding table in stable order.""" + self._order = tuple(bindings) + self._bindings = dict(bindings) + + @classmethod + def from_allocators( + cls, + allocators: Mapping[str, int], + *, + slot_locals: Mapping[Slot, str] | None = None, + ) -> GuppyLinearityState: + """Create live slot bindings for root QReg/QAlloc declarations. + + `slot_locals`, when provided, is the single namespace-wide + slot-to-Guppy-local name table from `GuppyContext.slot_locals` + (disambiguates the default `f"{allocator}_{index}"` against + register names so the entry-unpack LHS does not shadow another + declared register). When omitted, the default name is used -- + kept for callers that build a linearity table outside the main + emitter (no register collision risk for those isolated paths). + """ + bindings: dict[Slot, Binding] = {} + for allocator, size in allocators.items(): + if size < 0: + msg = f"Allocator {allocator!r} has negative size {size}" + raise LinearityError(msg) + for index in range(size): + slot = Slot(allocator, index) + local = slot_locals[slot] if slot_locals is not None and slot in slot_locals else f"{allocator}_{index}" + bindings[slot] = Binding(local=local, state=SlotState.LIVE) + return cls(bindings) + + def bindings(self) -> Iterable[tuple[Slot, Binding]]: + """Iterate bindings in stable allocator/index order.""" + return ((slot, self._bindings[slot]) for slot in self._order) + + def binding(self, slot: Slot) -> Binding: + """Return the current binding for a slot, including consumed slots.""" + self._require_known(slot) + return self._bindings[slot] + + def status(self, slot: Slot) -> SlotState: + """Return whether a slot is live or consumed.""" + return self.binding(slot).state + + def live(self, slot: Slot) -> str: + """Return the live Guppy local for a slot, or raise if consumed.""" + binding = self.binding(slot) + if binding.state is not SlotState.LIVE: + msg = f"Slot {slot} is consumed and has no live Guppy local" + raise LinearityError(msg) + return binding.local + + def set_live(self, slot: Slot, local: str) -> None: + """Record the current live owner for a slot; the name may be unchanged.""" + self._require_known(slot) + self._bindings[slot] = Binding(local=local, state=SlotState.LIVE) + + def consume(self, slot: Slot) -> str: + """Return the live local and mark the slot consumed; raise if unavailable.""" + local = self.live(slot) + self._bindings[slot] = Binding(local=local, state=SlotState.CONSUMED) + return local + + def discard_live(self) -> list[tuple[Slot, str]]: + """Consume all remaining live slots for end-of-function cleanup.""" + discarded: list[tuple[Slot, str]] = [] + for slot in self._order: + binding = self._bindings[slot] + if binding.state is SlotState.LIVE: + discarded.append((slot, binding.local)) + self._bindings[slot] = Binding(local=binding.local, state=SlotState.CONSUMED) + return discarded + + def snapshot(self) -> LinearitySnapshot: + """Return an opaque copy for speculative branch or loop emission.""" + return dict(self._bindings) + + def restore(self, snapshot: LinearitySnapshot) -> None: + """Restore a previous snapshot before emitting another region.""" + self._require_valid_snapshot(snapshot, label="restore") + self._bindings = dict(snapshot) + + def merge_if( + self, + before: LinearitySnapshot, + then_state: LinearitySnapshot, + else_state: LinearitySnapshot | None = None, + *, + label: str, + ) -> None: + """Accept an if only when both exits leave identical slot bindings.""" + self._require_valid_snapshot(before, label=f"{label} before") + self._require_valid_snapshot(then_state, label=f"{label} then") + merged_else = before if else_state is None else else_state + self._require_valid_snapshot(merged_else, label=f"{label} else") + + if then_state != merged_else: + msg = ( + f"{label} leaves divergent Guppy slot states; " + f"then={self._snapshot_summary(then_state)}, else={self._snapshot_summary(merged_else)}" + ) + raise LinearityError(msg) + self._bindings = dict(then_state) + + def assert_same( + self, + before: LinearitySnapshot, + after: LinearitySnapshot, + *, + label: str, + ) -> None: + """Require a loop/region body to preserve exact slot bindings.""" + self._require_valid_snapshot(before, label=f"{label} before") + self._require_valid_snapshot(after, label=f"{label} after") + if before != after: + msg = ( + f"{label} changes Guppy slot state across a required invariant; " + f"before={self._snapshot_summary(before)}, after={self._snapshot_summary(after)}" + ) + raise LinearityError(msg) + self._bindings = dict(after) + + def permute(self, mapping: Mapping[Slot, Slot], *, label: str) -> None: + """Apply a static logical-slot permutation to the binding table. + + `mapping` is interpreted as `logical_source -> old_logical_target`: + after `permute({a[0]: a[1], a[1]: a[0]})`, references to `a[0]` + use the binding that previously belonged to `a[1]`. + """ + keys = set(mapping) + values = set(mapping.values()) + if keys != values: + msg = f"{label} must be bijective over the same slot set" + raise LinearityError(msg) + + for slot in keys: + self._require_known(slot) + for slot in values: + self._require_known(slot) + + old_bindings = dict(self._bindings) + for source, target in mapping.items(): + self._bindings[source] = old_bindings[target] + + def _require_known(self, slot: Slot) -> None: + if slot not in self._bindings: + msg = f"Unknown Guppy slot {slot}" + raise LinearityError(msg) + + def _require_valid_snapshot(self, snapshot: LinearitySnapshot, *, label: str) -> None: + if set(snapshot) != set(self._bindings): + msg = f"{label} snapshot has different slot set" + raise LinearityError(msg) + + def _snapshot_summary(self, snapshot: LinearitySnapshot) -> str: + parts = [] + for slot in self._order: + binding = snapshot[slot] + parts.append(f"{slot}:{binding.local}/{binding.state.name}") + return "{" + ", ".join(parts) + "}" + + +__all__ = [ + "Binding", + "GuppyLinearityState", + "LinearityError", + "LinearitySnapshot", + "Slot", + "SlotState", +] diff --git a/python/quantum-pecos/src/pecos/slr/ast/codegen/qasm.py b/python/quantum-pecos/src/pecos/slr/ast/codegen/qasm.py index 2242fe699..bb4c412d5 100644 --- a/python/quantum-pecos/src/pecos/slr/ast/codegen/qasm.py +++ b/python/quantum-pecos/src/pecos/slr/ast/codegen/qasm.py @@ -32,6 +32,8 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING +from pecos.slr.ast.codegen._block_flatten import flatten_block_calls +from pecos.slr.ast.codegen._prep_tail import prep_tail from pecos.slr.ast.nodes import ( AllocatorDecl, BinaryExpr, @@ -60,6 +62,7 @@ ParallelBlock, PermuteOp, PrepareOp, + PrintOp, Program, RepeatStmt, ReturnOp, @@ -77,8 +80,6 @@ # Hadamard GateKind.H: "h", # Phase gates - GateKind.S: "s", - GateKind.Sdg: "sdg", GateKind.T: "rz(pi/4)", # T gate as rotation GateKind.Tdg: "rz(-pi/4)", # Square root gates @@ -218,6 +219,7 @@ def generate(self, program: Program) -> list[str]: Returns: List of code lines. """ + program = flatten_block_calls(program) self.context = QasmContext() return self.visit(program) @@ -315,6 +317,37 @@ def visit_gate(self, node: GateOp) -> list[str]: """Generate gate operation.""" lines = [] + # Fail-loud arity guard (defense-in-depth vs the angle-first + # mis-order footgun): a parameterized gate with fewer targets + # than its qubit arity would otherwise SILENTLY emit no QASM + # line (the two-qubit path is gated on `len(targets) >= 2`, the + # single-qubit path iterates zero targets). The SLR + # `QGate.__call__` rejects the mis-ordered call at the source; + # this guards a malformed GateOp reaching QASM from any other + # path. Multi-target (parallel) application is fine (len >= arity). + if node.gate.is_parameterized and len(node.targets) < node.gate.arity: + msg = ( + f"QASM codegen: parameterized gate {node.gate.name!r} has " + f"{len(node.targets)} qubit target(s) but needs at least " + f"{node.gate.arity} (a mis-ordered `gate(qubit, angle)` call " + "drops the qubit). Call it as `gate(angle, qubit...)`." + ) + raise NotImplementedError(msg) + + # Typed-angle guard: a user/direct-AST parameterized gate's params + # must be typed `Angle` literals (matches Guppy + the typed-AST + # contract); reject bare floats so backends do not diverge. + if node.gate.is_parameterized: + from pecos.slr.angle import Angle # noqa: PLC0415 (avoid import cycle) + + for p in node.params: + if not (isinstance(p, LiteralExpr) and isinstance(p.value, Angle)): + msg = ( + f"QASM codegen: parameterized gate {node.gate.name!r} requires typed `Angle` " + f"params (use `rad(...)` / `turns(...)` in SLR); got {p!r}." + ) + raise NotImplementedError(msg) + # Handle special face rotation gates if node.gate == GateKind.F: for target in node.targets: @@ -369,22 +402,25 @@ def visit_gate(self, node: GateOp) -> list[str]: return lines def visit_prepare(self, node: PrepareOp) -> list[str]: - """Generate reset/prep operation.""" + """Generate reset/prep operation (Z-reset + canonical basis tail).""" lines = [] + tail = prep_tail(node.basis) # Get root allocator for this allocator root = self.context.get_root_allocator(node.allocator) if node.slots is None: - # Reset all qubits in the allocator + if tail: + msg = f"QASM codegen: prepare_all with non-PZ basis {node.basis!r} is not supported" + raise NotImplementedError(msg) capacity = self.context.allocators.get(node.allocator, 1) - for i in range(capacity): - abs_index = self.context.get_absolute_index(node.allocator, i) - lines.append(self._maybe_conditional(f"reset {root}[{abs_index}];")) + indices = [self.context.get_absolute_index(node.allocator, i) for i in range(capacity)] else: - for slot in node.slots: - abs_index = self.context.get_absolute_index(node.allocator, slot) - lines.append(self._maybe_conditional(f"reset {root}[{abs_index}];")) + indices = [self.context.get_absolute_index(node.allocator, slot) for slot in node.slots] + + for abs_index in indices: + lines.append(self._maybe_conditional(f"reset {root}[{abs_index}];")) + lines.extend(self._maybe_conditional(f"{GATE_TO_QASM[gk]} {root}[{abs_index}];") for gk in tail) return lines @@ -444,6 +480,11 @@ def visit_return(self, _node: ReturnOp) -> list[str]: """Return is not a QASM concept - ignored.""" return [] + def visit_print(self, node: PrintOp) -> list[str]: + """Print has no native QASM equivalent; emit as a comment for traceability.""" + value_repr = node.value if isinstance(node.value, str) else f"{node.value.register}[{node.value.index}]" + return [f"// Print {node.namespace}.{node.tag} {value_repr}"] + def visit_permute(self, node: PermuteOp) -> list[str]: """Handle permutation. @@ -783,6 +824,11 @@ def _render_expression(self, expr: Expression) -> str: def _render_literal(self, node: LiteralExpr) -> str: """Render a literal value.""" + from pecos.slr.angle import Angle # noqa: PLC0415 (avoid import cycle) + + if isinstance(node.value, Angle): + # OpenQASM rotations are in radians; signed principal value. + return str(node.value.value.to_radians_signed()) return str(node.value) def _render_binary(self, node: BinaryExpr) -> str: diff --git a/python/quantum-pecos/src/pecos/slr/ast/codegen/qir.py b/python/quantum-pecos/src/pecos/slr/ast/codegen/qir.py index 5ad44075b..ff67c035f 100644 --- a/python/quantum-pecos/src/pecos/slr/ast/codegen/qir.py +++ b/python/quantum-pecos/src/pecos/slr/ast/codegen/qir.py @@ -25,11 +25,14 @@ from __future__ import annotations +import math import re -from dataclasses import dataclass, field +from collections.abc import Callable +from dataclasses import dataclass, field, replace from typing import TYPE_CHECKING, Any -import pecos as pc +from pecos.slr.ast.codegen._block_flatten import flatten_block_calls +from pecos.slr.ast.codegen._prep_tail import prep_tail from pecos.slr.ast.nodes import ( AllocatorDecl, AssignOp, @@ -47,8 +50,10 @@ ParallelBlock, PermuteOp, PrepareOp, + PrintOp, RegisterDecl, RepeatStmt, + ReturnOp, SlotRef, UnaryExpr, UnaryOp, @@ -81,8 +86,6 @@ # Hadamard GateKind.H: "h", # Phase gates - GateKind.S: "s", - GateKind.Sdg: "s__adj", GateKind.T: "t", GateKind.Tdg: "t__adj", # Square root gates - mapped to S variants @@ -92,18 +95,192 @@ GateKind.RX: "rx", GateKind.RY: "ry", GateKind.RZ: "rz", - # Two-qubit gates + # Two-qubit gates -- only the qir-qis ALLOWED_QIS_FNS that + # actually execute through `qir_to_qis -> selene`. The native + # Quantinuum 2q is `rzz` (parameterized); `__quantum__qis__zz__body` + # is NOT in the allowlist (cf. ~/Repos/qir-qis/src/lib.rs:59). + # SZZ/SZZdg/SXX/SXXdg/SYY/SYYdg are lowered via `_GATE_DECOMP` + # to RZZ + 1q Cliffords (verified up-to-phase + end-to-end). GateKind.CX: "cnot", GateKind.CZ: "cz", GateKind.RZZ: "rzz", - GateKind.SZZ: "zz", +} + +# Decomposition table: a sequence of (primitive_kind, qubit_idx_tuple, +# params_tuple) steps. Each step's qubit_idx_tuple indexes into the +# input gate's `targets`; params_tuple is the (constant) angles for a +# parameterized primitive (RZZ, RY, RX, RZ). Every entry was found by +# extracting the gate's authoritative unitary from +# `pecos.simulators.StateVec` (or the canonical matrix oracle in +# `tests/pecos/integration/state_sim_tests/gate_matrix_def.py`), +# searching/deriving a decomposition into ONLY the qir-qis ALLOWED +# primitive set (`h, x, y, z, s, s__adj, t, t__adj, rx, ry, rz, rzz, +# rxy, cnot, cz`), verifying it equal up to a GLOBAL PHASE +# (unobservable for measurement-terminated circuits) to the PECOS +# unitary, AND verifying it end-to-end through `qir_to_qis -> selene` +# with discriminating deterministic identities (a no-op lowering +# would fail). For *Clifford* gates the selene Stim backend can +# verify; non-Clifford gates (e.g. CH, T-decompositions, arbitrary- +# angle rotations) use the selene Quest statevector backend. +# Sequences are in CIRCUIT order (first applied first). Decompositions +# minimize 2q-gate count first (2q ops are the hardware cost driver). +# A decomposition step's params slot is either: +# - a tuple of constant floats (most common, e.g. SZZ -> RZZ(pi/2)), or +# - a callable that takes the *input* gate's params (e.g. CRZ(theta)) +# and returns the step's params (e.g. (theta/2,) on RZ, (-theta/2,) +# on RZZ). This lets parameterized controlled-rotation gates thread +# their angle through a 2q-minimal decomposition. +_DecompParams = tuple[float, ...] | Callable[[tuple[float, ...]], tuple[float, ...]] +_DecompStep = tuple[GateKind, tuple[int, ...], _DecompParams] +_GATE_DECOMP: dict[GateKind, tuple[_DecompStep, ...]] = { + # ---- single-qubit Clifford sqrt + face rotations ---- + GateKind.SX: ((GateKind.H, (0,), ()), (GateKind.SZ, (0,), ()), (GateKind.H, (0,), ())), + GateKind.SXdg: ((GateKind.H, (0,), ()), (GateKind.SZdg, (0,), ()), (GateKind.H, (0,), ())), + GateKind.SY: ((GateKind.H, (0,), ()), (GateKind.X, (0,), ())), + GateKind.SYdg: ((GateKind.H, (0,), ()), (GateKind.Z, (0,), ())), + GateKind.F: ((GateKind.SZdg, (0,), ()), (GateKind.H, (0,), ())), + GateKind.Fdg: ((GateKind.H, (0,), ()), (GateKind.SZ, (0,), ())), + GateKind.F4: ((GateKind.H, (0,), ()), (GateKind.SZdg, (0,), ())), + GateKind.F4dg: ((GateKind.SZ, (0,), ()), (GateKind.H, (0,), ())), + # ---- two-qubit Clifford gates ---- + # SZZ/SZZdg directly via the native parameterized rzz. + GateKind.SZZ: ((GateKind.RZZ, (0, 1), (math.pi / 2,)),), + GateKind.SZZdg: ((GateKind.RZZ, (0, 1), (-math.pi / 2,)),), + # SXX = (H⊗H)·SZZ·(H⊗H); SXXdg with -π/2. + GateKind.SXX: ( + (GateKind.H, (0,), ()), + (GateKind.H, (1,), ()), + (GateKind.RZZ, (0, 1), (math.pi / 2,)), + (GateKind.H, (0,), ()), + (GateKind.H, (1,), ()), + ), + GateKind.SXXdg: ( + (GateKind.H, (0,), ()), + (GateKind.H, (1,), ()), + (GateKind.RZZ, (0, 1), (-math.pi / 2,)), + (GateKind.H, (0,), ()), + (GateKind.H, (1,), ()), + ), + # SYY = (S⊗S)·(H⊗H)·SZZ·(H⊗H)·(Sdg⊗Sdg) since Y = S·X·S† and + # XX = (H⊗H)·ZZ·(H⊗H). SYYdg with -π/2. + GateKind.SYY: ( + (GateKind.SZdg, (0,), ()), + (GateKind.SZdg, (1,), ()), + (GateKind.H, (0,), ()), + (GateKind.H, (1,), ()), + (GateKind.RZZ, (0, 1), (math.pi / 2,)), + (GateKind.H, (0,), ()), + (GateKind.H, (1,), ()), + (GateKind.SZ, (0,), ()), + (GateKind.SZ, (1,), ()), + ), + GateKind.SYYdg: ( + (GateKind.SZdg, (0,), ()), + (GateKind.SZdg, (1,), ()), + (GateKind.H, (0,), ()), + (GateKind.H, (1,), ()), + (GateKind.RZZ, (0, 1), (-math.pi / 2,)), + (GateKind.H, (0,), ()), + (GateKind.H, (1,), ()), + (GateKind.SZ, (0,), ()), + (GateKind.SZ, (1,), ()), + ), + # CY = Sdg(target); CX(control,target); S(target). + GateKind.CY: ( + (GateKind.SZdg, (1,), ()), + (GateKind.CX, (0, 1), ()), + (GateKind.SZ, (1,), ()), + ), + # CH = (I_c x Ry(-pi/4)_t) . CX(c,t) . (I_c x Ry(pi/4)_t) -- 1 CX + # (the 2q-minimal Clifford+rotation form; conjugation by Ry maps + # X to H since Ry(-pi/4) X Ry(pi/4) = cos(-pi/4) X - sin(-pi/4) Z + # = (X+Z)/sqrt(2) = H). The PECOS oracle CH() in gate_matrix_def + # uses a Clifford+T 2-CX form; ours matches it up to global phase + # (max_err 3e-14) and matches textbook block-diag(I,H) exactly. + GateKind.CH: ( + (GateKind.RY, (1,), (math.pi / 4,)), + (GateKind.CX, (0, 1), ()), + (GateKind.RY, (1,), (-math.pi / 4,)), + ), + # ---- parameterized controlled rotations ---- + # CRZ(theta) = (RZ(theta/2) o RZ(theta/2)) . RZZ(-theta/2): 1 RZZ, + # 2 single-qubit RZ. The RZ on the control absorbs the e^{i theta/2} + # phase that PECOS's R*(theta) all carry (otherwise it would be a + # c=1-only relative phase, which is observable). Verified against + # gate_matrix_def.CRZ(theta) for 5 random angles. + GateKind.CRZ: ( + (GateKind.RZZ, (0, 1), lambda p: (-p[0] / 2,)), + (GateKind.RZ, (0,), lambda p: (p[0] / 2,)), + (GateKind.RZ, (1,), lambda p: (p[0] / 2,)), + ), + # CRX(theta) = (I o H) . CRZ(theta) . (I o H): conjugate CRZ by H + # on the target since H.Z.H = X. Same 1 RZZ. + GateKind.CRX: ( + (GateKind.H, (1,), ()), + (GateKind.RZZ, (0, 1), lambda p: (-p[0] / 2,)), + (GateKind.RZ, (0,), lambda p: (p[0] / 2,)), + (GateKind.RZ, (1,), lambda p: (p[0] / 2,)), + (GateKind.H, (1,), ()), + ), + # CRY(theta) = (I o (S.H)) . CRZ(theta) . (I o (H.Sdg)): conjugate + # CRZ by (S.H) on the target since S.X.Sdg = Y (and H.Z.H = X). + # Same 1 RZZ. + GateKind.CRY: ( + (GateKind.SZdg, (1,), ()), + (GateKind.H, (1,), ()), + (GateKind.RZZ, (0, 1), lambda p: (-p[0] / 2,)), + (GateKind.RZ, (0,), lambda p: (p[0] / 2,)), + (GateKind.RZ, (1,), lambda p: (p[0] / 2,)), + (GateKind.H, (1,), ()), + (GateKind.SZ, (1,), ()), + ), } # Gates with rotation parameters PARAMETERIZED_GATES = {GateKind.RX, GateKind.RY, GateKind.RZ, GateKind.RZZ} # Two-qubit gates -TWO_QUBIT_GATES = {GateKind.CX, GateKind.CZ, GateKind.RZZ, GateKind.SZZ} +TWO_QUBIT_GATES = {GateKind.CX, GateKind.CZ, GateKind.RZZ} + + +def _param_to_radians(p: object) -> float: + """Resolve a gate angle param to a signed-radians float. + + Accepts a `LiteralExpr` wrapping either a typed `Angle` or a bare + number, or a raw number (decomposition steps thread raw floats). + Typed angles use the signed principal value so the float-based + decomposition arithmetic (`-theta/2`) avoids the global-phase flip + that the unsigned `[0, 2pi)` form would introduce at the wrap point. + """ + from pecos.slr.angle import Angle # noqa: PLC0415 (avoid import cycle) + + value = p.value if isinstance(p, LiteralExpr) else p + if isinstance(value, Angle): + return value.value.to_radians_signed() + return float(value) + + +def _require_typed_angle_params(node: GateOp, backend: str) -> None: + """Fail loud if a parameterized user gate has a non-`Angle` param. + + Enforces the typed-AST-dialect contract uniformly across backends: a + parameterized `GateOp` reaching codegen from the user / direct-AST path + must carry typed `Angle` literals (`rad(...)` / `turns(...)`), not bare + floats. (Internal decomposition steps thread raw floats but reach the + per-gate emitters directly, not this top-level entry.) + """ + from pecos.slr.angle import Angle # noqa: PLC0415 (avoid import cycle) + + if not node.gate.is_parameterized: + return + for p in node.params: + if not (isinstance(p, LiteralExpr) and isinstance(p.value, Angle)): + gate_name = getattr(node.gate, "name", node.gate) + msg = ( + f"{backend} codegen: parameterized gate {gate_name!r} requires typed `Angle` " + f"params (use `rad(...)` / `turns(...)` in SLR); got {p!r}." + ) + raise NotImplementedError(msg) @dataclass @@ -113,9 +290,16 @@ class QirCodeGenContext: qubit_map: dict[tuple[str, int], int] = field(default_factory=dict) qubit_count: int = 0 creg_map: dict[str, int] = field(default_factory=dict) # name -> size + qreg_sizes: dict[str, int] = field(default_factory=dict) # name -> capacity measurement_count: int = 0 allocator_parents: dict[str, str | None] = field(default_factory=dict) allocator_offsets: dict[str, int] = field(default_factory=dict) + # Static logical permutation, mirroring the Guppy linearity + # tracker's `.permute()` (compile-time relabel; QIR/Selene have no + # runtime permute intrinsic). Maps a logical (reg, index) ref to + # the (reg, index) whose storage it should resolve to. Consulted + # at every qubit-ref and classical-bit-ref lowering. + permutation_map: dict[tuple[str, int], tuple[str, int]] = field(default_factory=dict) def get_root_allocator(self, name: str) -> str: """Get the root allocator for a given allocator name.""" @@ -134,6 +318,11 @@ def get_qubit_index(self, allocator: str, index: int) -> int: For child allocators, translates to root allocator with computed offset. """ + # Resolve any active logical permutation first (identity until + # a Permute runs; decl-time pre-population sees the empty map, + # so real qubits are still allocated 1:1). + allocator, index = self.permutation_map.get((allocator, index), (allocator, index)) + # Translate to root allocator and absolute index root = self.get_root_allocator(allocator) abs_index = self.get_absolute_index(allocator, index) @@ -182,6 +371,8 @@ def generate(self, program: Program) -> str: msg = "LLVM dependencies not available. Install with 'pip install pecos[qir]'" raise ImportError(msg) + program = flatten_block_calls(program) + self.context = QirCodeGenContext() self._gate_cache = {} self._creg_ptrs = {} @@ -206,15 +397,20 @@ def generate(self, program: Program) -> str: # Setup creg helper functions self._setup_creg_funcs() - # Setup measurement function - self._mz_to_bit = self._declare_function( - "mz_to_creg_bit", + # Standard QIR classical model. A measurement lowers to a + # static `%Result*` slot -> `__quantum__qis__mz__body` -> + # `__quantum__rt__read_result` -> `store` into a per-CReg mutable + # `[N x i1]` entry-block `alloca` buffer. `%Result*` + # is the existing `result_ptr` type. + self._mz_body = self._declare_function( + "__quantum__qis__mz__body", self._types["void"], - [ - self._types["qubit_ptr"], - self._types["bool"].as_pointer(), - self._types["int"], - ], + [self._types["qubit_ptr"], self._types["result_ptr"]], + ) + self._read_result = self._declare_function( + "__quantum__rt__read_result", + self._types["bool"], + [self._types["result_ptr"]], ) # Setup main function @@ -222,9 +418,6 @@ def generate(self, program: Program) -> str: self._main_func = llvm_ir.Function(self._module, main_fnty, name="main") entry_block = self._main_func.append_basic_block(name="entry") self._builder = llvm_ir.IRBuilder(entry_block) - self._builder.comment( - f"// Generated from AST using: PECOS version {pc.__version__}", - ) # Setup operator map self._setup_op_map() @@ -246,38 +439,15 @@ def generate(self, program: Program) -> str: return self._finalize_module() def _setup_creg_funcs(self) -> None: - """Setup classical register helper functions.""" + """Declare the standard classical-output runtime function. + The static `[N x i1]` CReg model replaced the bespoke + `create_creg`/`get_creg_bit`/ + `set_creg_bit`/`get_int_from_creg`/`set_creg_to_int`/`mz_to_creg_bit` + runtime helpers with native `alloca`/`store`/`load`/`gep`/`zext`, + so only the standard `__quantum__rt__int_record_output` remains. + """ self._creg_funcs = { - "create_creg": self._declare_function( - "create_creg", - self._types["bool"].as_pointer(), - [self._types["int"]], - ), - "creg_to_int": self._declare_function( - "get_int_from_creg", - self._types["int"], - [self._types["bool"].as_pointer()], - ), - "get_creg_bit": self._declare_function( - "get_creg_bit", - self._types["bool"], - [self._types["bool"].as_pointer(), self._types["int"]], - ), - "set_creg_bit": self._declare_function( - "set_creg_bit", - self._types["void"], - [ - self._types["bool"].as_pointer(), - self._types["int"], - self._types["bool"], - ], - ), - "set_creg": self._declare_function( - "set_creg_to_int", - self._types["void"], - [self._types["bool"].as_pointer(), self._types["int"]], - ), "int_result": self._declare_function( "__quantum__rt__int_record_output", self._types["void"], @@ -292,13 +462,22 @@ def _declare_function(self, name: str, ret_ty: Any, arg_tys: list) -> Any: def _setup_op_map(self) -> None: """Setup binary operator mapping.""" + # CReg comparisons are UNSIGNED (the CReg's `[N x i1]` buffer + # packs to an i64 unsigned bit pattern via `_pack_creg`). The + # `icmp_unsigned` choice matters only when bit 63 of a 64-bit + # CReg is set; for narrower CRegs (or for bit-level `c[i]` + # comparisons that zext to i64 with the top 63 bits zero) + # signed and unsigned agree. Switching is safe -- no existing + # test asserts the signed semantics on a high-bit-set 64-bit + # CReg, and unsigned is the correct interpretation of the + # bit-pattern semantics SLR exposes. self._op_map = { - BinaryOp.EQ: lambda lhs, rhs: self._builder.icmp_signed("==", lhs, rhs), - BinaryOp.NE: lambda lhs, rhs: self._builder.icmp_signed("!=", lhs, rhs), - BinaryOp.LT: lambda lhs, rhs: self._builder.icmp_signed("<", lhs, rhs), - BinaryOp.GT: lambda lhs, rhs: self._builder.icmp_signed(">", lhs, rhs), - BinaryOp.LE: lambda lhs, rhs: self._builder.icmp_signed("<=", lhs, rhs), - BinaryOp.GE: lambda lhs, rhs: self._builder.icmp_signed(">=", lhs, rhs), + BinaryOp.EQ: lambda lhs, rhs: self._builder.icmp_unsigned("==", lhs, rhs), + BinaryOp.NE: lambda lhs, rhs: self._builder.icmp_unsigned("!=", lhs, rhs), + BinaryOp.LT: lambda lhs, rhs: self._builder.icmp_unsigned("<", lhs, rhs), + BinaryOp.GT: lambda lhs, rhs: self._builder.icmp_unsigned(">", lhs, rhs), + BinaryOp.LE: lambda lhs, rhs: self._builder.icmp_unsigned("<=", lhs, rhs), + BinaryOp.GE: lambda lhs, rhs: self._builder.icmp_unsigned(">=", lhs, rhs), BinaryOp.MUL: self._builder.mul, BinaryOp.DIV: self._builder.udiv, BinaryOp.XOR: self._builder.xor, @@ -326,19 +505,37 @@ def _process_declarations(self, program: Program) -> None: # Process allocator declarations - only allocate for root allocators for decl in program.declarations: if isinstance(decl, AllocatorDecl): + self.context.qreg_sizes[decl.name] = decl.capacity if decl.parent is None: for i in range(decl.capacity): self.context.get_qubit_index(decl.name, i) elif isinstance(decl, RegisterDecl): - self.context.creg_map[decl.name] = decl.size - # Create classical register - if decl.size < 64: - self._creg_ptrs[decl.name] = self._builder.call( - self._creg_funcs["create_creg"], - [llvm_ir.Constant(self._types["int"], decl.size)], - name=decl.name, + # A CReg is a mutable `[N x i1]` buffer in the + # entry block (declarations are processed at entry, before + # any control flow, so the builder is positioned there), + # zero-initialised so unmeasured/unset bits read 0. The + # record pack is a single `i64` for `int_record_output`, so + # the model caps at 64 bits. A >64-bit CReg must fail LOUD + # here -- silently dropping its storage/output (the old + # `create_creg` `size < 64` behaviour) is a miscompile. + # Fail BEFORE recording any state (no partial creg_map). + if decl.size > 64: + msg = ( + f"QIR codegen: CReg {decl.name!r} has {decl.size} bits, " + "but the static classical model packs each CReg " + "into a single i64 for " + "`__quantum__rt__int_record_output` (64-bit cap). " + ">64-bit CRegs are not supported by the QIR backend." ) + raise NotImplementedError(msg) + self.context.creg_map[decl.name] = decl.size + arr_ty = llvm_ir.ArrayType(self._types["bool"], decl.size) + creg_ptr = self._builder.alloca(arr_ty, decl.name) + self._builder.store(creg_ptr, llvm_ir.Constant(arr_ty)) + self._creg_ptrs[decl.name] = creg_ptr + if program.allocator: + self.context.qreg_sizes[program.allocator.name] = program.allocator.capacity if program.allocator and program.allocator.parent is None: for i in range(program.allocator.capacity): self.context.get_qubit_index(program.allocator.name, i) @@ -390,13 +587,133 @@ def _process_statement(self, stmt: Statement) -> None: self._process_parallel(stmt) elif isinstance(stmt, PermuteOp): self._process_permute(stmt) + elif isinstance(stmt, ReturnOp): + self._process_return(stmt) + elif isinstance(stmt, PrintOp): + # Classical-output streaming (`Print` -> Guppy `result(...)`) + # is unimplemented in the QIR backend. Silently dropping it + # loses observable program output -- fail LOUD instead. + msg = ( + "QIR codegen does not support Print (classical output " + "streaming is unimplemented; silently dropping it would " + "lose observable program output)." + ) + raise NotImplementedError(msg) + + def _process_return(self, node: ReturnOp) -> None: + """Validate that returned CLASSICAL registers have QIR storage. + + `_generate_results` records every Main-declared CReg, but a + CReg surfaced ONLY via `Return(creg)` (never measured / + assigned / read) reaches no other `_require_creg` site, so + an inline / local-scope returned CReg produced ZERO recorded + output for an explicit `Return` -- the build succeeded and + validated, the program just silently returned nothing + (same silent-output-loss class as the four + point-of-use sites). Qubit returns record no classical + output and are skipped via per-value provenance + (`ReturnOp.value_kinds` from `_convert_return`), so a + `Return(qreg)` is not false-rejected AND a returned inline + CReg whose name collides with a declared QReg is still + validated (a name-membership skip was unsound). + """ + # Provenance comes from `_convert_return` (it knows the real + # QReg/CReg object), NOT from a name-membership guess: a + # returned inline CReg can share a declared QReg's name, which + # a `qubit_map`-name skip silently mistook for a qubit return + # and dropped. Unknown kind + # ("" -- e.g. a directly-constructed ReturnOp) falls back to + # "classical", the fail-loud-safe default. + kinds = node.value_kinds + for i, value in enumerate(node.values): + kind = kinds[i] if i < len(kinds) else "classical" + if isinstance(value, str): + if kind == "quantum": + continue # qubit-register return: no classical record + self._require_creg(value) + elif isinstance(value, BitRef): + self._require_creg(value.register) + elif isinstance(value, BitExpr): + self._require_creg(value.ref.register) def _process_gate(self, node: GateOp) -> None: """Process a gate operation.""" + # Fail-loud arity guard (defense-in-depth vs the angle-first + # mis-order footgun): a parameterized gate with fewer targets + # than its qubit arity would otherwise SILENTLY emit no call + # (the per-target emit loops just iterate zero/too-few times), + # or hit a raw IndexError in the decomposition path. The SLR + # `QGate.__call__` already rejects the mis-ordered call at the + # source; this guards a malformed GateOp reaching codegen from + # any other path. Multi-target (parallel) application is fine + # (len >= arity). + if node.gate.is_parameterized and len(node.targets) < node.gate.arity: + gate_name = getattr(node.gate, "name", node.gate) + msg = ( + f"QIR codegen: parameterized gate {gate_name!r} has " + f"{len(node.targets)} qubit target(s) but needs at least " + f"{node.gate.arity} (a mis-ordered `gate(qubit, angle)` call " + "drops the qubit). Call it as `gate(angle, qubit...)`." + ) + raise NotImplementedError(msg) + + # Typed-angle guard: a user/direct-AST parameterized gate's params + # must be typed `Angle` literals (matches the Guppy backend and the + # typed-AST-dialect contract). Internal decomposition steps thread + # raw floats but reach `_process_*_gate` directly, bypassing here. + _require_typed_angle_params(node, "QIR") + qir_name = GATE_TO_QIR.get(node.gate) - if qir_name is None: - # Skip unsupported gates + if qir_name is None and node.gate in _GATE_DECOMP: + # A gate with no direct QIR primitive but a verified + # decomposition into the qir-qis ALLOWED primitive set. + # Emit each step in circuit order, routing its qubits + # through the input gate's `targets` and threading params + # (constant for non-parameterized gates like SZZ -> RZZ(pi/2); + # callable on the input gate's params for parameterized + # gates like CRZ(theta) -> RZZ(-theta/2)). LiteralExpr + # bracket-params are unwrapped to floats here so the + # callable can do arithmetic on them; non-literal + # expressions (VarExpr / BinaryExpr at gate-param position) + # are not yet supported for parameterized decomposition + # (out of scope; classical-var lowering covers it). + input_params_raw = tuple(node.params or ()) + for prim_kind, idxs, params_spec in _GATE_DECOMP[node.gate]: + prim_targets = tuple(node.targets[i] for i in idxs) + if callable(params_spec): + try: + input_params_resolved = tuple(_param_to_radians(p) for p in input_params_raw) + except (AttributeError, TypeError) as exc: + msg = ( + f"Parameterized decomposition of gate {node.gate.name} requires literal " + f"params; got non-literal: {input_params_raw}" + ) + raise NotImplementedError(msg) from exc + prim_params = params_spec(input_params_resolved) + else: + prim_params = params_spec + prim_node = replace(node, gate=prim_kind, targets=prim_targets, params=prim_params) + if prim_kind in TWO_QUBIT_GATES: + self._process_two_qubit_gate(prim_node, GATE_TO_QIR[prim_kind]) + else: + self._process_single_qubit_gate(prim_node, GATE_TO_QIR[prim_kind]) return + if qir_name is None: + # A gate with no + # GATE_TO_QIR entry was SILENTLY DROPPED -- valid QIR, + # wrong semantics, qir-qis-uncatchable. Fail + # loud instead of miscompiling. Gates with a real QIR + # lowering should be added to GATE_TO_QIR (a feature); + # until then a program using one must not be silently + # mis-emitted. + gate_name = getattr(node.gate, "name", node.gate) + msg = ( + f"QIR codegen: gate {gate_name!r} has no QIR lowering " + "(not in GATE_TO_QIR). Emitting QIR without it would be " + "a silent miscompile; it is not supported by the QIR " + "backend." + ) + raise NotImplementedError(msg) if node.gate in TWO_QUBIT_GATES: self._process_two_qubit_gate(node, qir_name) @@ -417,7 +734,10 @@ def _process_single_qubit_gate(self, node: GateOp, qir_name: str) -> None: args = [] if node.gate in PARAMETERIZED_GATES and node.params: - args.extend(llvm_ir.Constant(self._types["double"], float(p)) for p in node.params) + # An angle param reaches here as a `LiteralExpr` wrapping a + # typed `Angle` (or a raw float from a decomposition step); + # resolve to signed radians for the QIR `double`. + args.extend(llvm_ir.Constant(self._types["double"], _param_to_radians(p)) for p in node.params) args.append(qubit_ptr) self._builder.call(gate_func, args, name="") @@ -436,7 +756,10 @@ def _process_two_qubit_gate(self, node: GateOp, qir_name: str) -> None: args = [] if node.gate in PARAMETERIZED_GATES and node.params: - args.extend(llvm_ir.Constant(self._types["double"], float(p)) for p in node.params) + # An angle param reaches here as a `LiteralExpr` wrapping a + # typed `Angle` (or a raw float from a decomposition step); + # resolve to signed radians for the QIR `double`. + args.extend(llvm_ir.Constant(self._types["double"], _param_to_radians(p)) for p in node.params) args.extend([q0_ptr, q1_ptr]) self._builder.call(gate_func, args, name="") @@ -476,35 +799,118 @@ def _get_qubit_ptr(self, target: SlotRef) -> Any: self._types["qubit_ptr"], ) + def _require_creg(self, reg_name: str) -> None: + """Fail LOUD if `reg_name` has no entry-block CReg storage. + + Only CRegs declared at Main scope (`program.declarations`) + get an `alloca [N x i1]` in `_process_declarations`. An + inline / local-scope CReg -- e.g. one created in a block and + only surfaced via `Return(creg)` -- has no storage, so every + measure/assign/read against it used to be SILENTLY skipped + (the store dropped, a read folded to constant 0) and the + explicit returned value vanished from the QIS records (the + `docs.inline_measure_creg` defect surfaced). Mirror the + fail-loud doctrine: a silent miscompile must become a loud + `NotImplementedError`, not a buried wrong answer. + """ + if reg_name not in self._creg_ptrs: + msg = ( + f"QIR codegen: classical register {reg_name!r} is " + "used/measured/returned but was not declared at Main " + "scope, so it has no QIR storage. Inline / local-scope " + "CRegs are not supported by the QIR backend (their " + "values would be silently dropped from the recorded " + f"output). Declare {reg_name!r} in Main(...)." + ) + raise NotImplementedError(msg) + + def _creg_bit_ptr(self, reg_name: str, index: int) -> Any: + """`getelementptr [N x i1], [N x i1]* %creg, i64 0, i64 index`. + + Emitted at point-of-use (not cached across blocks) so the pointer + always dominates its uses under control flow. + """ + # Resolve any active logical permutation (classical bits, like + # qubits, are relabelled by a Permute -- mirrors Guppy's + # mem_swap-based classical permute). + reg_name, index = self.context.permutation_map.get((reg_name, index), (reg_name, index)) + return self._builder.gep( + self._creg_ptrs[reg_name], + [ + llvm_ir.Constant(self._types["int"], 0), + llvm_ir.Constant(self._types["int"], index), + ], + name="", + ) + + def _as_i1(self, value: Any) -> Any: + """Coerce an evaluated expression to `i1` (bit-store / predicate).""" + if value.type == self._types["bool"]: + return value + return self._builder.icmp_signed( + "!=", + value, + llvm_ir.Constant(self._types["int"], 0), + ) + + def _as_i64(self, value: Any) -> Any: + """Coerce an evaluated expression to `i64` (the canonical width).""" + if value.type == self._types["bool"]: + return self._builder.zext(value, self._types["int"]) + return value + def _process_measure(self, node: MeasureOp) -> None: - """Process a measurement operation.""" + """Process a measurement operation. + + Every measured target emits `__quantum__qis__mz__body(q, %Result*)` + against a static result slot; `read_result` + `store` into the CReg + buffer only when a classical result target exists (qir-qis accepts + `mz` without `read_result`, but rejects `read_result` before `mz`). + """ for i, target in enumerate(node.targets): self.context.measurement_count += 1 + slot = self.context.measurement_count - 1 qubit_ptr = self._get_qubit_ptr(target) + result_ptr = llvm_ir.Constant(self._types["int"], slot).inttoptr( + self._types["result_ptr"], + ) + self._builder.call(self._mz_body, [qubit_ptr, result_ptr], name="") if i < len(node.results): result = node.results[i] - if result.register in self._creg_ptrs: - creg_ptr = self._creg_ptrs[result.register] - bit_index = llvm_ir.Constant(self._types["int"], result.index) - self._builder.call( - self._mz_to_bit, - [qubit_ptr, creg_ptr, bit_index], - name="", - ) + self._require_creg(result.register) + bit = self._builder.call( + self._read_result, + [result_ptr], + name="", + ) + self._builder.store( + self._creg_bit_ptr(result.register, result.index), + bit, + ) def _process_prepare(self, node: PrepareOp) -> None: - """Process a prepare/reset operation.""" + """Process a prepare/reset operation (Z-reset + canonical basis tail).""" + tail = prep_tail(node.basis) if node.slots is None: + # prepare_all is a pre-existing no-op gap; a NON-PZ + # prepare_all silently doing nothing would be a basis + # miscompile -- fail loud rather than extend the gap. + if tail: + msg = f"QIR codegen: prepare_all with non-PZ basis {node.basis!r} is not supported" + raise NotImplementedError(msg) return reset_func = self._get_or_create_gate("reset", has_params=False, num_qubits=1) + tail_funcs = [(self._get_or_create_gate(GATE_TO_QIR[gk], has_params=False, num_qubits=1)) for gk in tail] for slot in node.slots: qubit_ptr = self._get_qubit_ptr( SlotRef(allocator=node.allocator, index=slot), ) self._builder.call(reset_func, [qubit_ptr], name="") + for func in tail_funcs: + self._builder.call(func, [qubit_ptr], name="") def _process_barrier(self, node: BarrierOp) -> None: """Process a barrier operation.""" @@ -534,26 +940,34 @@ def _process_assign(self, node: AssignOp) -> None: """Process an assignment operation.""" if isinstance(node.target, BitRef): reg_name = node.target.register - if reg_name not in self._creg_ptrs: - return - - creg_ptr = self._creg_ptrs[reg_name] - bit_index = llvm_ir.Constant(self._types["int"], node.target.index) - - # Evaluate RHS - rhs = self._eval_expression(node.value) - - self._builder.call( - self._creg_funcs["set_creg_bit"], - [creg_ptr, bit_index, rhs], - name="", + self._require_creg(reg_name) + rhs = self._as_i1(self._eval_expression(node.value)) + self._builder.store( + self._creg_bit_ptr(reg_name, node.target.index), + rhs, ) + elif isinstance(node.target, str): + # Whole-CReg `c.set(int)` (converter.py:928/930 -> target=str). + # Unpack the i64 value bit-by-bit into the buffer. + reg_name = node.target + self._require_creg(reg_name) + size = self.context.creg_map.get(reg_name, 0) + val = self._as_i64(self._eval_expression(node.value)) + for i in range(size): + shifted = self._builder.lshr( + val, + llvm_ir.Constant(self._types["int"], i), + ) + self._builder.store( + self._creg_bit_ptr(reg_name, i), + self._builder.trunc(shifted, self._types["bool"]), + ) def _eval_expression(self, expr: Expression) -> Any: """Evaluate an expression to an LLVM value.""" if isinstance(expr, LiteralExpr): if isinstance(expr.value, bool): - return llvm_ir.Constant(self._types["bool"], 1 if expr.value else 0) + return llvm_ir.Constant(self._types["int"], 1 if expr.value else 0) if isinstance(expr.value, int): return llvm_ir.Constant(self._types["int"], expr.value) if isinstance(expr.value, float): @@ -562,19 +976,32 @@ def _eval_expression(self, expr: Expression) -> Any: if isinstance(expr, BitExpr): reg_name = expr.ref.register - if reg_name not in self._creg_ptrs: - return llvm_ir.Constant(self._types["bool"], 0) - creg_ptr = self._creg_ptrs[reg_name] - bit_index = llvm_ir.Constant(self._types["int"], expr.ref.index) - return self._builder.call( - self._creg_funcs["get_creg_bit"], - [creg_ptr, bit_index], + self._require_creg(reg_name) + # `load i1, gep c[i]` then `zext -> i64` (canonical width). + bit = self._builder.load( + self._creg_bit_ptr(reg_name, expr.ref.index), name="", ) + return self._builder.zext(bit, self._types["int"]) if isinstance(expr, VarExpr): - # Variable lookup - for now just return 0 - return llvm_ir.Constant(self._types["int"], 0) + # A classical `VarExpr` denotes a whole CReg used as a + # scalar (e.g. `If(m == 0)`, `o.set(m + n)`, Steane + # `smid_flag_x`). SLR has no scalar-integer classical type + # (verified: `pecos.slr.vars` exposes only Reg/CReg/Bit/ + # SymbolicElem/LoopVar). A `LoopVar` used as a symbolic index + # is resolved by `For` unrolling; a `LoopVar` appearing as a + # bare classical scalar (e.g. `If(i == 0)`) is NOT substituted + # by `_process_for` and reaches here as `VarExpr(name="i")`, + # where `_require_creg` fails it loud (no CReg named `i`) -- + # never silent-0. The lowering + # is the existing i64 pack (`OR_i (zext c[i] << i)`) + # factored out of `_generate_results` as `_pack_creg`. A + # `VarExpr` whose name is not a declared Main-scope CReg + # fails LOUD via `_require_creg`, preserving the + # anti-silent-0 guarantee. + self._require_creg(expr.name) + return self._pack_creg(expr.name) if isinstance(expr, BinaryExpr): left = self._eval_expression(expr.left) @@ -591,11 +1018,21 @@ def _eval_expression(self, expr: Expression) -> Any: return self._builder.not_(operand) return operand - return llvm_ir.Constant(self._types["int"], 0) + # Any unhandled expression type silently evaluating to constant + # 0 is a value miscompile qir-qis cannot catch (the + # fail-loud class -- same smell as the VarExpr arm above). Every + # currently-reachable type is handled above; a new/unhandled one + # must fail LOUD, not lower as 0. + msg = ( + f"QIR codegen: unsupported classical expression " + f"{type(expr).__name__} (it must not be silently evaluated " + "as 0 -- that would be a value miscompile)." + ) + raise NotImplementedError(msg) def _process_if(self, node: IfStmt) -> None: """Process an if statement.""" - pred = self._eval_expression(node.condition) + pred = self._as_i1(self._eval_expression(node.condition)) if node.else_body: with self._builder.if_else(pred) as (then, otherwise): @@ -611,19 +1048,54 @@ def _process_if(self, node: IfStmt) -> None: self._process_statement(stmt) def _process_while(self, node: WhileStmt) -> None: - """Process a while loop.""" - # QIR supports loops through LLVM branch instructions - # For simplicity, we process the body once (approximation) - for stmt in node.body: - self._process_statement(stmt) + """`While` is not supported by the QIR backend. + + Unbounded iteration / fixed-point linear state through an + unknown iteration count is out of scope for the sound emitter; + the AST->Guppy path also rejects `While`. + Fail LOUD here -- the previous single-pass approximation + silently dropped the loop condition and all iterations (a + miscompile qir-qis cannot catch, since one pass is valid QIR). + """ + msg = ( + "QIR codegen does not support While loops (unbounded " + "iteration / fixed-point linear state is out of scope for " + "the QIR backend; a single-pass approximation would be a " + "silent miscompile)." + ) + raise NotImplementedError(msg) + + def _static_int_bound(self, expr: Any, which: str) -> int: + """Resolve a static integer `For` bound. + + The converter wraps integer range bounds in `LiteralExpr` + (`converter.py` `_convert_for`), so the bound is never a raw + `int`. A non-literal / non-int bound is a symbolic/dynamic + `For`, which is unsupported -- fail LOUD rather than silently + drop the loop body (the previous `isinstance(int)` guard was + always false, so EVERY `For` body was silently dropped). + """ + if isinstance(expr, LiteralExpr) and isinstance(expr.value, int) and not isinstance(expr.value, bool): + return expr.value + msg = ( + f"QIR codegen: For loop {which} bound is not a static integer " + f"({type(expr).__name__}); only fixed-bound `For(i, , " + ")` is supported (symbolic/dynamic For is out of scope -- " + "and must not silently drop the loop body)." + ) + raise NotImplementedError(msg) def _process_for(self, node: ForStmt) -> None: - """Process a for loop by unrolling.""" - if isinstance(node.start, int) and isinstance(node.stop, int): - step = node.step if isinstance(node.step, int) else 1 - for _ in range(node.start, node.stop, step): - for stmt in node.body: - self._process_statement(stmt) + """Unroll a static fixed-bound `For` (v1-supported).""" + start = self._static_int_bound(node.start, "start") + stop = self._static_int_bound(node.stop, "stop") + step = 1 if node.step is None else self._static_int_bound(node.step, "step") + if step == 0: + msg = "QIR codegen: For loop step is 0 (infinite loop); only a non-zero static step is supported." + raise NotImplementedError(msg) + for _ in range(start, stop, step): + for stmt in node.body: + self._process_statement(stmt) def _process_repeat(self, node: RepeatStmt) -> None: """Process a repeat loop by unrolling.""" @@ -637,25 +1109,118 @@ def _process_parallel(self, node: ParallelBlock) -> None: for stmt in node.body: self._process_statement(stmt) - def _process_permute(self, node: PermuteOp) -> None: - """Process a permutation operation. + def _expand_permute_ref(self, ref: str) -> list[tuple[str, int]]: + """Expand a Permute ref string to logical (reg, index) pairs. - Updates the internal allocator mapping to swap qubit references. - QIR doesn't have a permute instruction, so this just updates - how we map allocator names to qubit indices. + `name[idx]` -> a single element; bare `name` -> every element + of the register (QReg capacity or CReg size). Mirrors the + Guppy codegen's `_expand_permute_ref`. """ - # Swap the allocator offsets - for src, tgt in zip(node.sources, node.targets, strict=False): - # Get current offsets - src_offset = self.context.allocator_offsets.get(src, 0) - tgt_offset = self.context.allocator_offsets.get(tgt, 0) - # Swap them - self.context.allocator_offsets[src] = tgt_offset - self.context.allocator_offsets[tgt] = src_offset + if ref.endswith("]") and "[" in ref: + name, idx = ref[:-1].split("[", 1) + return [(name, int(idx))] + if ref in self.context.qreg_sizes: + return [(ref, i) for i in range(self.context.qreg_sizes[ref])] + if ref in self.context.creg_map: + return [(ref, i) for i in range(self.context.creg_map[ref])] + msg = f"QIR codegen: unknown Permute ref {ref!r}" + raise NotImplementedError(msg) + + def _process_permute(self, node: PermuteOp) -> None: + """Realize a Permute as a static logical relabel. + + QIR (and the Selene runtime) have no permute intrinsic, so -- + exactly like the legacy gen_qir permutation_map and the Guppy + linearity tracker's `.permute()` -- a permutation is realized + at compile time by relabelling which storage each logical + (reg, index) ref resolves to. Build the source->target logical + mapping, require it bijective over the same ref set, then + compose it ATOMICALLY into the standing permutation_map + (snapshot old, then `map[s] = old.get(t, t)`). Every + qubit-ref and classical-bit-ref lowering consults + permutation_map, so subsequent refs hit the permuted storage. + Works uniformly for whole-register and element-wise, QReg and + CReg. + """ + if len(node.sources) != len(node.targets): + msg = "QIR codegen: Permute source/target length mismatch" + raise NotImplementedError(msg) + + # Accumulate the expanded refs as LISTS first and validate + # BEFORE building the dict: a dict would silently collapse a + # duplicate expanded source (e.g. + # `Permute([a[0], a[0]], [b[0], a[0]])`) so a genuinely + # non-bijective Permute would compile -- a silent miscompile. + src_all: list[tuple[str, int]] = [] + tgt_all: list[tuple[str, int]] = [] + for source, target in zip(node.sources, node.targets, strict=True): + src_refs = self._expand_permute_ref(source) + tgt_refs = self._expand_permute_ref(target) + if len(src_refs) != len(tgt_refs): + msg = f"QIR codegen: Permute element count mismatch for {source!r} -> {target!r}" + raise NotImplementedError(msg) + src_all.extend(src_refs) + tgt_all.extend(tgt_refs) + + if len(src_all) != len(set(src_all)): + msg = "QIR codegen: Permute has a duplicate source ref (not a permutation)" + raise NotImplementedError(msg) + if len(tgt_all) != len(set(tgt_all)): + msg = "QIR codegen: Permute has a duplicate target ref (not a permutation)" + raise NotImplementedError(msg) + if set(src_all) != set(tgt_all): + msg = "QIR codegen: Permute must be bijective over the same ref set" + raise NotImplementedError(msg) + + mapping: dict[tuple[str, int], tuple[str, int]] = dict(zip(src_all, tgt_all, strict=True)) + + # Human-readable comment mirroring the legacy gen_qir format + # (rendered from the post-substitution sources so it stays + # correct inside a flattened BlockCall). + if node.add_comment and node.sources: + if node.whole_register and len(node.sources) >= 2: + self._builder.comment(f"; Permutation: {node.sources[0]} <-> {node.sources[1]}") + else: + pairs = ", ".join(f"{s} -> {t}" for s, t in zip(node.sources, node.targets, strict=True)) + self._builder.comment(f"; Permutation: {pairs}") + + # Compose ATOMICALLY (Guppy `.permute` semantics): a whole + # register swap arrives as sources=(a,b)/targets=(b,a), so a + # sequential apply would cancel to a no-op; snapshotting the + # old map first applies the relabel exactly once. + old = dict(self.context.permutation_map) + for s_ref, t_ref in mapping.items(): + self.context.permutation_map[s_ref] = old.get(t_ref, t_ref) + + def _pack_creg(self, reg_name: str) -> Any: + """Pack a CReg's `[N x i1]` buffer into a single i64 value. + + `OR_i (zext c[i] << i)` -- this is the canonical SLR CReg-as- + integer lowering. Used by both `_generate_results` (for the + `__quantum__rt__int_record_output` call) and by + `_eval_expression(VarExpr)` (a whole-CReg scalar + reference in `If(m == 0)` / `o.set(m + n)` / etc.). Sharing + the pack ensures the record-output and VarExpr interpretations + of `m` are bit-identical -- the same packed i64. + """ + c_int: Any = llvm_ir.Constant(self._types["int"], 0) + for i in range(self.context.creg_map.get(reg_name, 0)): + bit = self._builder.load( + self._creg_bit_ptr(reg_name, i), + name="", + ) + widened = self._builder.zext(bit, self._types["int"]) + if i: + widened = self._builder.shl( + widened, + llvm_ir.Constant(self._types["int"], i), + ) + c_int = self._builder.or_(c_int, widened) + return c_int def _generate_results(self) -> None: """Generate result output calls.""" - for reg_name, creg_ptr in self._creg_ptrs.items(): + for reg_name in self._creg_ptrs: # Create tag for the register name reg_name_bytes = bytearray(reg_name.encode("utf-8")) tag_type = llvm_ir.ArrayType(llvm_ir.IntType(8), len(reg_name)) @@ -664,12 +1229,10 @@ def _generate_results(self) -> None: reg_tag.global_constant = True reg_tag.linkage = "private" - # Convert creg to int and output - c_int = self._builder.call( - self._creg_funcs["creg_to_int"], - [creg_ptr], - name="", - ) + # Pack the [N x i1] buffer into one i64 (shared with the + # VarExpr lowering). + c_int = self._pack_creg(reg_name) + reg_tag_gep = reg_tag.gep( ( llvm_ir.Constant(llvm_ir.IntType(32), 0), @@ -688,9 +1251,25 @@ def _finalize_module(self) -> str: mod_w_attr = ll_text.replace("@main()", "@main() #0") mod_w_attr += '\nattributes #0 = { "entry_point"' - mod_w_attr += ' "qir_profiles"="custom"' + # adaptive_profile: PECOS emits measurement-conditioned `If` -> + # adaptive, not base profile. + mod_w_attr += ' "qir_profiles"="adaptive_profile"' + mod_w_attr += ' "output_labeling_schema"="labeled"' mod_w_attr += f' "required_num_qubits"="{self.context.qubit_count}"' mod_w_attr += f' "required_num_results"="{self.context.measurement_count}" }}' + + # QIR module flags (Adaptive Profile). pecos_rslib_llvm.ir.Module has + # no named-metadata API, so append as raw IR text -- same approach as + # the entry attributes above. The emitted module carries no `!` + # metadata, so !0..!4 are collision-free. The static classical + # model sets dynamic_*/arrays = false (it keeps the static + # %Result + mutable local-buffer model; flags must match). + mod_w_attr += "\n!llvm.module.flags = !{!0, !1, !2, !3, !4}" + mod_w_attr += '\n!0 = !{i32 1, !"qir_major_version", i32 1}' + mod_w_attr += '\n!1 = !{i32 7, !"qir_minor_version", i32 0}' + mod_w_attr += '\n!2 = !{i32 1, !"dynamic_qubit_management", i1 false}' + mod_w_attr += '\n!3 = !{i32 1, !"dynamic_result_management", i1 false}' + mod_w_attr += '\n!4 = !{i32 1, !"arrays", i1 false}' return mod_w_attr def _fix_internal_consts(self, llvm_ir: str) -> str: diff --git a/python/quantum-pecos/src/pecos/slr/ast/codegen/quantum_circuit.py b/python/quantum-pecos/src/pecos/slr/ast/codegen/quantum_circuit.py index 981d63f4a..45c050ab9 100644 --- a/python/quantum-pecos/src/pecos/slr/ast/codegen/quantum_circuit.py +++ b/python/quantum-pecos/src/pecos/slr/ast/codegen/quantum_circuit.py @@ -28,6 +28,8 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING +from pecos.slr.ast.codegen._block_flatten import flatten_block_calls +from pecos.slr.ast.codegen._prep_tail import prep_tail from pecos.slr.ast.nodes import ( AllocatorDecl, BarrierOp, @@ -35,10 +37,12 @@ GateKind, GateOp, IfStmt, + LiteralExpr, MeasureOp, ParallelBlock, PermuteOp, PrepareOp, + PrintOp, RegisterDecl, RepeatStmt, WhileStmt, @@ -60,8 +64,6 @@ # Hadamard GateKind.H: "H", # Phase gates - GateKind.S: "S", - GateKind.Sdg: "SDG", GateKind.T: "T", GateKind.Tdg: "TDG", # Square root gates @@ -100,6 +102,9 @@ GateKind.SYYdg, GateKind.SZZdg, GateKind.RZZ, + GateKind.CRX, + GateKind.CRY, + GateKind.CRZ, } @@ -112,6 +117,12 @@ class QCCodeGenContext: current_tick: dict[str, set] = field(default_factory=dict) allocator_parents: dict[str, str | None] = field(default_factory=dict) allocator_offsets: dict[str, int] = field(default_factory=dict) + qreg_sizes: dict[str, int] = field(default_factory=dict) # name -> capacity + # Static logical permutation (same model as the QIR codegen / + # the Guppy linearity tracker -- QuantumCircuit has no permute + # instruction). Maps a logical (reg, index) ref to the (reg, + # index) whose qubit it resolves to; consulted in `get_qubit`. + permutation_map: dict[tuple[str, int], tuple[str, int]] = field(default_factory=dict) def get_root_allocator(self, name: str) -> str: """Get the root allocator for a given allocator name.""" @@ -130,6 +141,11 @@ def get_qubit(self, allocator: str, index: int) -> int: For child allocators, translates to root allocator with computed offset. """ + # Resolve any active logical permutation first (identity until + # a Permute runs; decl-time pre-population sees the empty map, + # so real qubits are still allocated 1:1). + allocator, index = self.permutation_map.get((allocator, index), (allocator, index)) + # Translate to root allocator and absolute index root = self.get_root_allocator(allocator) abs_index = self.get_absolute_index(allocator, index) @@ -166,6 +182,8 @@ def generate(self, program: Program) -> QuantumCircuit: """ from pecos.circuits.quantum_circuit import QuantumCircuit # noqa: PLC0415 + program = flatten_block_calls(program) + self.context = QCCodeGenContext() self.circuit = QuantumCircuit() self._in_parallel = False @@ -198,12 +216,15 @@ def _process_declarations(self, program: Program) -> None: # Allocate qubits only for root allocators for decl in program.declarations: if isinstance(decl, AllocatorDecl): + self.context.qreg_sizes[decl.name] = decl.capacity if decl.parent is None: for i in range(decl.capacity): self.context.get_qubit(decl.name, i) elif isinstance(decl, RegisterDecl): pass # Classical registers don't need qubit allocation + if program.allocator: + self.context.qreg_sizes[program.allocator.name] = program.allocator.capacity if program.allocator and program.allocator.parent is None: for i in range(program.allocator.capacity): self.context.get_qubit(program.allocator.name, i) @@ -259,11 +280,36 @@ def _process_statement(self, stmt: Statement) -> None: self._process_parallel(stmt) elif isinstance(stmt, PermuteOp): self._process_permute(stmt) + elif isinstance(stmt, PrintOp): + # Print is not yet implemented for the QuantumCircuit + # backend. Fail LOUD rather than silently drop observable + # program output (fail-loud principle). Unlike Stim, PECOS owns + # the QuantumCircuit format, so a `Print` representation + # could be added later -- this is "not yet", not "cannot". + msg = ( + "QuantumCircuit codegen does not yet support Print " + "(classical output streaming is not implemented for this " + "backend; it is silently-drop-free by design -- fail " + "loud. PECOS controls this format, so Print support may " + "be added in future)." + ) + raise NotImplementedError(msg) def _process_gate(self, node: GateOp) -> None: """Process a gate operation.""" gate_name = GATE_TO_QC.get(node.gate, node.gate.name) + if node.params: + # Parameterized gates (RX/RY/RZ/RZZ/CRX/CRY/CRZ etc.): + # PECOS QuantumCircuit ticks are parallel sets keyed on + # (gate_name, params), so a tick mixing different param + # values would lose information. Flush the current tick, + # then emit the parameterized gate as its own tick via + # `circuit.append(gate, locations, angles=[...])` which the + # Rust gate registry routes to the typed-param dispatcher. + self._process_parameterized_gate(node, gate_name) + return + if node.gate in TWO_QUBIT_GATES: self._process_two_qubit_gate(node, gate_name) else: @@ -300,6 +346,53 @@ def _process_two_qubit_gate(self, node: GateOp, gate_name: str) -> None: ) self._add_to_tick(gate_name, (q0, q1)) + def _process_parameterized_gate(self, node: GateOp, gate_name: str) -> None: + """Emit a parameterized gate (rotation angle threaded through). + + Resolves `LiteralExpr` bracket-params to raw floats (the AST + converter wraps the `qb.RZ(0.5, q)` angle as a `LiteralExpr`, so a bare + `float(p)` would fail). The QC `circuit.append(..., + angles=...)` path forwards the angle list to Rust's typed- + parameter dispatcher (e.g. `RZ` requires 1 angle, `RXXRYYRZZ` + requires 3). + + Flushes the current tick before emission so the parameterized + gate gets its own tick -- mixing different param values within + a single tick would lose information (tick batches by + `(gate_name, target)` only). Non-literal expressions (VarExpr, + BinaryExpr at gate-param position) are not yet supported for + QC parameterized gates; they fail loud here. + """ + from pecos.slr.angle import Angle # noqa: PLC0415 (avoid import cycle) + + # Typed-angle guard: a user/direct-AST parameterized gate's params + # must be typed `Angle` literals (matches Guppy + the typed-AST + # contract); reject bare floats so backends do not diverge. + for p in node.params: + if not (isinstance(p, LiteralExpr) and isinstance(p.value, Angle)): + msg = ( + f"QuantumCircuit codegen: parameterized gate {gate_name!r} requires typed `Angle` " + f"params (use `rad(...)` / `turns(...)` in SLR); got {p!r}." + ) + raise NotImplementedError(msg) + angles = [p.value.value.to_radians_signed() for p in node.params] + + self._flush_tick() + if node.gate in TWO_QUBIT_GATES: + if len(node.targets) < 2: + msg = f"QuantumCircuit codegen: two-qubit gate {gate_name!r} needs >=2 targets, got {len(node.targets)}" + raise ValueError(msg) + # Emit each consecutive pair (mirrors the un-parameterized + # two-qubit path which iterates over target pairs). + for i in range(0, len(node.targets) - 1, 2): + q0 = self.context.get_qubit(node.targets[i].allocator, node.targets[i].index) + q1 = self.context.get_qubit(node.targets[i + 1].allocator, node.targets[i + 1].index) + self.circuit.append(gate_name, {(q0, q1)}, angles=angles) + else: + for target in node.targets: + qubit = self.context.get_qubit(target.allocator, target.index) + self.circuit.append(gate_name, {qubit}, angles=angles) + def _process_measure(self, node: MeasureOp) -> None: """Process a measurement operation.""" for target in node.targets: @@ -307,13 +400,26 @@ def _process_measure(self, node: MeasureOp) -> None: self._add_to_tick("Measure", qubit) def _process_prepare(self, node: PrepareOp) -> None: - """Process a prepare/reset operation.""" + """Process a prepare/reset operation (Z-reset + canonical basis tail). + + QC ticks are parallel sets; reset and the Clifford tail MUST + be sequential, so each is its own flushed tick (PZ has no + tail -> byte-identical to the prior behaviour). + """ + tail = prep_tail(node.basis) if node.slots is None: + if tail: + msg = f"QuantumCircuit codegen: prepare_all with non-PZ basis {node.basis!r} is not supported" + raise NotImplementedError(msg) return - for slot in node.slots: - qubit = self.context.get_qubit(node.allocator, slot) + qubits = [self.context.get_qubit(node.allocator, slot) for slot in node.slots] + for qubit in qubits: self._add_to_tick("RESET", qubit) + for gk in tail: + self._flush_tick() # sequence reset/prev-tail before this gate + for qubit in qubits: + self._add_to_tick(GATE_TO_QC[gk], qubit) def _add_to_tick(self, gate_name: str, target: int | tuple[int, int]) -> None: """Add a gate to the current tick.""" @@ -349,17 +455,40 @@ def _process_while(self, node: WhileStmt) -> None: ) raise NotImplementedError(msg) + def _static_int_bound(self, expr: object, which: str) -> int: + """Resolve a static integer `For` bound. + + The AST converter wraps integer range bounds in `LiteralExpr` + (`converter.py` `_convert_for`), so the bound is never a raw + `int` -- the old `isinstance(int)` guard was always false, so + the `else` branch *rejected every* static `For` (even valid + `For(i, 0, 3)`). Resolve the literal; a non-literal / non-int + bound is a symbolic/dynamic `For`: fail LOUD. + """ + if isinstance(expr, LiteralExpr) and isinstance(expr.value, int) and not isinstance(expr.value, bool): + return expr.value + msg = ( + f"QuantumCircuit codegen: For loop {which} bound is not a " + f"static integer ({type(expr).__name__}); only fixed-bound " + "`For(i, , )` is supported (symbolic/dynamic For " + "is out of scope)." + ) + raise NotImplementedError(msg) + def _process_for(self, node: ForStmt) -> None: - """Process a for loop by unrolling.""" - # Unroll if bounds are static - if isinstance(node.start, int) and isinstance(node.stop, int): - step = node.step if isinstance(node.step, int) else 1 - for _ in range(node.start, node.stop, step): - for stmt in node.body: - self._process_statement(stmt) - else: - msg = f"Cannot unroll For loop with non-integer bounds: start={node.start}, stop={node.stop}" - raise TypeError(msg) + """Unroll a static fixed-bound `For` (v1-supported).""" + start = self._static_int_bound(node.start, "start") + stop = self._static_int_bound(node.stop, "stop") + step = 1 if node.step is None else self._static_int_bound(node.step, "step") + if step == 0: + msg = ( + "QuantumCircuit codegen: For loop step is 0 (infinite " + "loop); only a non-zero static step is supported." + ) + raise NotImplementedError(msg) + for _ in range(start, stop, step): + for stmt in node.body: + self._process_statement(stmt) def _process_repeat(self, node: RepeatStmt) -> None: """Process a repeat loop by unrolling.""" @@ -381,21 +510,69 @@ def _process_parallel(self, node: ParallelBlock) -> None: self._in_parallel = False self._flush_tick() - def _process_permute(self, node: PermuteOp) -> None: - """Process a permutation operation. + def _expand_permute_ref(self, ref: str) -> list[tuple[str, int]]: + """Expand a Permute ref string to logical (reg, index) pairs. + + `name[idx]` -> a single element; bare `name` -> every element + of the qubit register. QuantumCircuit has no realized + classical-register model, so a bare CReg permute is not + realizable -> fail loud (never a silent no-op). Mirrors the + QIR codegen's helper. + """ + if ref.endswith("]") and "[" in ref: + name, idx = ref[:-1].split("[", 1) + return [(name, int(idx))] + if ref in self.context.qreg_sizes: + return [(ref, i) for i in range(self.context.qreg_sizes[ref])] + msg = ( + f"QuantumCircuit codegen: whole-register Permute of {ref!r} " + "is not supported (no classical-register model); a " + "qubit-register or element-wise Permute is realizable." + ) + raise NotImplementedError(msg) - Updates the internal allocator mapping to swap qubit references. - QuantumCircuit doesn't have a permute instruction, so this just updates - how we map allocator names to qubit indices. + def _process_permute(self, node: PermuteOp) -> None: + """Realize a Permute as a static logical relabel. + + QuantumCircuit has no permute instruction, so -- exactly like + the QIR codegen and the Guppy linearity tracker -- a + Permute is realized at compile time by relabelling which qubit + each logical (reg, index) ref resolves to (consulted in + `get_qubit`). The old `allocator_offsets` swap was a no-op for + element-wise refs and self-cancelling for a whole-register + (a,b)/(b,a) pair -- a silent miscompile. """ - # Swap the allocator offsets - for src, tgt in zip(node.sources, node.targets, strict=False): - # Get current offsets - src_offset = self.context.allocator_offsets.get(src, 0) - tgt_offset = self.context.allocator_offsets.get(tgt, 0) - # Swap them - self.context.allocator_offsets[src] = tgt_offset - self.context.allocator_offsets[tgt] = src_offset + if len(node.sources) != len(node.targets): + msg = "QuantumCircuit codegen: Permute source/target length mismatch" + raise NotImplementedError(msg) + + # Validate the expanded ref lists BEFORE building the dict (a + # dict would silently collapse a duplicate source). + src_all: list[tuple[str, int]] = [] + tgt_all: list[tuple[str, int]] = [] + for source, target in zip(node.sources, node.targets, strict=True): + src_refs = self._expand_permute_ref(source) + tgt_refs = self._expand_permute_ref(target) + if len(src_refs) != len(tgt_refs): + msg = f"QuantumCircuit codegen: Permute element count mismatch for {source!r} -> {target!r}" + raise NotImplementedError(msg) + src_all.extend(src_refs) + tgt_all.extend(tgt_refs) + + if len(src_all) != len(set(src_all)): + msg = "QuantumCircuit codegen: Permute has a duplicate source ref (not a permutation)" + raise NotImplementedError(msg) + if len(tgt_all) != len(set(tgt_all)): + msg = "QuantumCircuit codegen: Permute has a duplicate target ref (not a permutation)" + raise NotImplementedError(msg) + if set(src_all) != set(tgt_all): + msg = "QuantumCircuit codegen: Permute must be bijective over the same ref set" + raise NotImplementedError(msg) + + # Compose ATOMICALLY (snapshot old, then map[s] = old.get(t, t)) + # so a whole-register (a,b)/(b,a) pair applies once. + old = dict(self.context.permutation_map) + self.context.permutation_map.update({s: old.get(t, t) for s, t in zip(src_all, tgt_all, strict=True)}) def ast_to_quantum_circuit(program: Program) -> QuantumCircuit: diff --git a/python/quantum-pecos/src/pecos/slr/ast/codegen/stim.py b/python/quantum-pecos/src/pecos/slr/ast/codegen/stim.py index 7d2768f37..f3cf2f31e 100644 --- a/python/quantum-pecos/src/pecos/slr/ast/codegen/stim.py +++ b/python/quantum-pecos/src/pecos/slr/ast/codegen/stim.py @@ -28,6 +28,8 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING +from pecos.slr.ast.codegen._block_flatten import flatten_block_calls +from pecos.slr.ast.codegen._prep_tail import prep_tail from pecos.slr.ast.nodes import ( AllocatorDecl, BarrierOp, @@ -35,10 +37,12 @@ GateKind, GateOp, IfStmt, + LiteralExpr, MeasureOp, ParallelBlock, PermuteOp, PrepareOp, + PrintOp, RegisterDecl, RepeatStmt, WhileStmt, @@ -61,8 +65,6 @@ # Hadamard GateKind.H: "H", # Phase gates - GateKind.S: "S", - GateKind.Sdg: "S_DAG", GateKind.T: "T", GateKind.Tdg: "T_DAG", # Square root gates (mapped to S variants) @@ -100,6 +102,34 @@ GateKind.RZZ, } +# Decomposition table: Clifford-only PECOS gates with no direct Stim +# primitive, lowered into the Stim-native gate set. Each entry is a +# sequence of (Stim gate name, qubit-index tuple-into-targets) steps +# in CIRCUIT order (first applied first). The compositions mirror the +# already-verified QIR `_GATE_DECOMP` entries: each H/S/S_DAG step is +# itself a Stim primitive, so a decomp using only those primitives is +# correct iff the QIR-side decomp is correct (which was verified +# up-to-phase against the PECOS oracle and end-to-end via selene). +# +# Stim is a Clifford-only stabilizer simulator: gates that involve +# non-Clifford rotations (CH, CRX/CRY/CRZ, T-decomposable forms) or +# arbitrary continuous rotations (RX/RY/RZ/RZZ at non-pi/2 angles) +# remain fail-loud here -- there is no Clifford-only decomposition, +# and the qir-qis-style "decompose to a Clifford+T target" path is +# fundamentally unavailable to Stim (the user's "support all gates +# in all languages, IF DECOMPOSABLE" directive admits this caveat). +_GATE_DECOMP: dict[GateKind, tuple[tuple[str, tuple[int, ...]], ...]] = { + # F = H . SZdg (matrix product; circuit-time: SZdg first, then H). + # F cycles Paulis X -> Y -> Z -> X (face rotation of the Bloch cube). + GateKind.F: (("S_DAG", (0,)), ("H", (0,))), + # Fdg = SZ . H (inverse of F; cycles X <- Y <- Z <- X). + GateKind.Fdg: (("H", (0,)), ("S", (0,))), + # F4 = SZdg . H -- the F-rotation around a different face axis. + GateKind.F4: (("H", (0,)), ("S_DAG", (0,))), + # F4dg = H . SZ -- inverse of F4. + GateKind.F4dg: (("S", (0,)), ("H", (0,))), +} + @dataclass class StimCodeGenContext: @@ -110,6 +140,12 @@ class StimCodeGenContext: measurement_count: int = 0 allocator_parents: dict[str, str | None] = field(default_factory=dict) allocator_offsets: dict[str, int] = field(default_factory=dict) + qreg_sizes: dict[str, int] = field(default_factory=dict) # name -> capacity + # Static logical permutation (same model as the QIR codegen / + # the Guppy linearity tracker -- Stim has no permute instruction). + # Maps a logical (reg, index) ref to the (reg, index) whose qubit + # it resolves to. Consulted at every qubit-ref lowering. + permutation_map: dict[tuple[str, int], tuple[str, int]] = field(default_factory=dict) def get_root_allocator(self, name: str) -> str: """Get the root allocator for a given allocator name.""" @@ -128,6 +164,11 @@ def get_qubit(self, allocator: str, index: int) -> int: For child allocators, translates to root allocator with computed offset. """ + # Resolve any active logical permutation first (identity until + # a Permute runs; decl-time pre-population sees the empty map, + # so real qubits are still allocated 1:1). + allocator, index = self.permutation_map.get((allocator, index), (allocator, index)) + # Translate to root allocator and absolute index root = self.get_root_allocator(allocator) abs_index = self.get_absolute_index(allocator, index) @@ -165,6 +206,8 @@ def generate(self, program: Program) -> stim.Circuit: """ import stim # noqa: PLC0415 + program = flatten_block_calls(program) + self.context = StimCodeGenContext() self.circuit = stim.Circuit() @@ -193,6 +236,7 @@ def _process_declarations(self, program: Program) -> None: # Allocate qubits only for root allocators for decl in program.declarations: if isinstance(decl, AllocatorDecl): + self.context.qreg_sizes[decl.name] = decl.capacity # Only allocate for root allocators (those without parents) if decl.parent is None: for i in range(decl.capacity): @@ -200,6 +244,8 @@ def _process_declarations(self, program: Program) -> None: elif isinstance(decl, RegisterDecl): pass # Classical registers don't need qubit allocation + if program.allocator: + self.context.qreg_sizes[program.allocator.name] = program.allocator.capacity if program.allocator and program.allocator.parent is None: for i in range(program.allocator.capacity): self.context.get_qubit(program.allocator.name, i) @@ -249,14 +295,51 @@ def _process_statement(self, stmt: Statement) -> None: self._process_parallel(stmt) elif isinstance(stmt, PermuteOp): self._process_permute(stmt) + elif isinstance(stmt, PrintOp): + # Classical-output streaming is unimplemented in the Stim + # backend. Silently dropping it loses observable program + # output -- fail LOUD (same decision as the QIR backend). + msg = ( + "Stim codegen does not support Print (classical output " + "streaming is unimplemented; silently dropping it would " + "lose observable program output)." + ) + raise NotImplementedError(msg) # Other statement types (Comment, Assign, Return) don't generate Stim output def _process_gate(self, node: GateOp) -> None: """Process a gate operation.""" stim_gate = GATE_TO_STIM.get(node.gate) - if stim_gate is None: - # Skip unsupported gates + if stim_gate is None and node.gate in _GATE_DECOMP: + # A Clifford gate with no direct Stim primitive but a + # verified decomposition into Stim-native primitives. + # Emit each step in circuit order. + qubits = [self.context.get_qubit(t.allocator, t.index) for t in node.targets] + for prim_name, idxs in _GATE_DECOMP[node.gate]: + prim_qubits = [qubits[i] for i in idxs] + self.circuit.append_operation(prim_name, prim_qubits) return + if stim_gate is None: + # A gate with no GATE_TO_STIM entry was SILENTLY DROPPED + # -- the emitted Stim circuit ran but with wrong + # semantics (a silent miscompile, + # uncatchable downstream). Fail loud instead. Stim is + # Clifford-only, so non-Clifford rotations + # (RX/RY/RZ/RZZ/CR*) are fundamentally unrepresentable + # here; CH is non-Clifford too. Gates that are Clifford + # but lack a direct Stim primitive get a verified + # `_GATE_DECOMP` entry above; anything that reaches this + # raise has no representable form. + gate_name = getattr(node.gate, "name", node.gate) + msg = ( + f"Stim codegen: gate {gate_name!r} has no Stim lowering " + "(not in GATE_TO_STIM, no Clifford decomposition in " + "_GATE_DECOMP). Emitting the circuit without it would be " + "a silent miscompile; it is not supported by the Stim " + "backend (non-Clifford gates like CH, CR*, continuous " + "rotations are fundamentally unrepresentable here)." + ) + raise NotImplementedError(msg) if node.gate in TWO_QUBIT_GATES: self._process_two_qubit_gate(node, stim_gate) @@ -300,12 +383,18 @@ def _process_measure(self, node: MeasureOp) -> None: self.context.measurement_count += len(qubits) def _process_prepare(self, node: PrepareOp) -> None: - """Process a prepare/reset operation.""" + """Process a prepare/reset operation (Z-reset + canonical basis tail).""" + tail = prep_tail(node.basis) if node.slots is None: + if tail: + msg = f"Stim codegen: prepare_all with non-PZ basis {node.basis!r} is not supported" + raise NotImplementedError(msg) return qubits = [self.context.get_qubit(node.allocator, slot) for slot in node.slots] self.circuit.append_operation("R", qubits) + for gk in tail: + self.circuit.append_operation(GATE_TO_STIM[gk], qubits) def _process_barrier(self) -> None: """Process a barrier as TICK.""" @@ -326,22 +415,49 @@ def _process_if(self, node: IfStmt) -> None: self._process_statement(stmt) def _process_while(self, node: WhileStmt) -> None: - """Process a while loop.""" - # Stim doesn't support runtime loops - process body once - self.circuit.append("TICK") - for stmt in node.body: - self._process_statement(stmt) + """`While` is not supported by the Stim backend. + + Stim has no runtime loop. The previous "process body once + TICK" + silently dropped the loop condition and all iterations -- a + miscompile. Fail LOUD instead (same decision as the QIR backend; + real While is out of scope). + """ + _ = node + msg = ( + "Stim codegen does not support While loops (Stim has no " + "runtime loop; a single-pass approximation would be a silent " + "miscompile)." + ) + raise NotImplementedError(msg) + + def _static_int_bound(self, expr: object, which: str) -> int: + """Resolve a static integer `For` bound. + + The AST converter wraps integer range bounds in `LiteralExpr` + (`converter.py` `_convert_for`), so the bound is never a raw + `int` -- the old `isinstance(int)` guard was always false and + silently dropped every `For` body. A non-literal / non-int + bound is a symbolic/dynamic `For`: fail LOUD, never drop. + """ + if isinstance(expr, LiteralExpr) and isinstance(expr.value, int) and not isinstance(expr.value, bool): + return expr.value + msg = ( + f"Stim codegen: For loop {which} bound is not a static integer " + f"({type(expr).__name__}); only fixed-bound `For(i, , " + ")` is supported (symbolic/dynamic For is out of scope -- " + "and must not silently drop the loop body)." + ) + raise NotImplementedError(msg) def _process_for(self, node: ForStmt) -> None: - """Process a for loop.""" - # Try to unroll if bounds are static - if isinstance(node.start, int) and isinstance(node.stop, int): - step = node.step if isinstance(node.step, int) else 1 - for _ in range(node.start, node.stop, step): - for stmt in node.body: - self._process_statement(stmt) - else: - # Can't unroll - process body once + """Unroll a static fixed-bound `For` (v1-supported).""" + start = self._static_int_bound(node.start, "start") + stop = self._static_int_bound(node.stop, "stop") + step = 1 if node.step is None else self._static_int_bound(node.step, "step") + if step == 0: + msg = "Stim codegen: For loop step is 0 (infinite loop); only a non-zero static step is supported." + raise NotImplementedError(msg) + for _ in range(start, stop, step): for stmt in node.body: self._process_statement(stmt) @@ -372,21 +488,68 @@ def _process_parallel(self, node: ParallelBlock) -> None: for stmt in node.body: self._process_statement(stmt) - def _process_permute(self, node: PermuteOp) -> None: - """Process a permutation operation. + def _expand_permute_ref(self, ref: str) -> list[tuple[str, int]]: + """Expand a Permute ref string to logical (reg, index) pairs. - Updates the internal allocator mapping to swap qubit references. - Stim doesn't have a permute instruction, so this just updates - how we map allocator names to qubit indices. + `name[idx]` -> a single element; bare `name` -> every element + of the qubit register. Stim has no classical-register model, + so a bare CReg permute is not realizable -> fail loud (never + a silent no-op). Mirrors the QIR codegen's helper. + """ + if ref.endswith("]") and "[" in ref: + name, idx = ref[:-1].split("[", 1) + return [(name, int(idx))] + if ref in self.context.qreg_sizes: + return [(ref, i) for i in range(self.context.qreg_sizes[ref])] + msg = ( + f"Stim codegen: whole-register Permute of {ref!r} is not " + "supported (no classical-register model in Stim); a " + "qubit-register or element-wise Permute is realizable." + ) + raise NotImplementedError(msg) + + def _process_permute(self, node: PermuteOp) -> None: + """Realize a Permute as a static logical relabel. + + Stim has no permute instruction, so -- exactly like the QIR + codegen and the Guppy linearity tracker -- a Permute is + realized at compile time by relabelling which qubit each + logical (reg, index) ref resolves to (consulted in + `get_qubit`). The old `allocator_offsets` swap was a no-op + for element-wise refs and self-cancelling for a whole-register + (a,b)/(b,a) pair -- a silent miscompile. """ - # Swap the allocator mappings - for src, tgt in zip(node.sources, node.targets, strict=False): - # Get current offsets - src_offset = self.context.allocator_offsets.get(src, 0) - tgt_offset = self.context.allocator_offsets.get(tgt, 0) - # Swap them - self.context.allocator_offsets[src] = tgt_offset - self.context.allocator_offsets[tgt] = src_offset + if len(node.sources) != len(node.targets): + msg = "Stim codegen: Permute source/target length mismatch" + raise NotImplementedError(msg) + + # Validate the expanded ref lists BEFORE building the dict (a + # dict would silently collapse a duplicate source). + src_all: list[tuple[str, int]] = [] + tgt_all: list[tuple[str, int]] = [] + for source, target in zip(node.sources, node.targets, strict=True): + src_refs = self._expand_permute_ref(source) + tgt_refs = self._expand_permute_ref(target) + if len(src_refs) != len(tgt_refs): + msg = f"Stim codegen: Permute element count mismatch for {source!r} -> {target!r}" + raise NotImplementedError(msg) + src_all.extend(src_refs) + tgt_all.extend(tgt_refs) + + if len(src_all) != len(set(src_all)): + msg = "Stim codegen: Permute has a duplicate source ref (not a permutation)" + raise NotImplementedError(msg) + if len(tgt_all) != len(set(tgt_all)): + msg = "Stim codegen: Permute has a duplicate target ref (not a permutation)" + raise NotImplementedError(msg) + if set(src_all) != set(tgt_all): + msg = "Stim codegen: Permute must be bijective over the same ref set" + raise NotImplementedError(msg) + + # Compose ATOMICALLY (snapshot old, then map[s] = old.get(t, t)) + # so a whole-register (a,b)/(b,a) pair applies once. + old = dict(self.context.permutation_map) + self.context.permutation_map.update({s: old.get(t, t) for s, t in zip(src_all, tgt_all, strict=True)}) def ast_to_stim(program: Program) -> stim.Circuit: diff --git a/python/quantum-pecos/src/pecos/slr/ast/compare.py b/python/quantum-pecos/src/pecos/slr/ast/compare.py index 8e87c9146..b8d37e309 100644 --- a/python/quantum-pecos/src/pecos/slr/ast/compare.py +++ b/python/quantum-pecos/src/pecos/slr/ast/compare.py @@ -143,6 +143,17 @@ def _compare_nodes(self, a: Any, b: Any) -> bool: return False return True + # Typed angle: a leaf compared by exact value (angle64 fraction + + # source unit). Recursing into its angle64 field would fail since + # angle64 is not an AST node / primitive handled below. + from pecos.slr.angle import Angle # noqa: PLC0415 (avoid import cycle) + + if isinstance(a, Angle): + if a != b: + self._add_diff(f"angle mismatch: {a!r} vs {b!r}") + return False + return True + # Handle enums if hasattr(a, "name") and hasattr(a, "value") and not is_dataclass(a): if a != b: diff --git a/python/quantum-pecos/src/pecos/slr/ast/converter.py b/python/quantum-pecos/src/pecos/slr/ast/converter.py index 1bacf6ab9..7e70e5400 100644 --- a/python/quantum-pecos/src/pecos/slr/ast/converter.py +++ b/python/quantum-pecos/src/pecos/slr/ast/converter.py @@ -21,7 +21,7 @@ prog = Main( q := QReg("q", 2), - qb.Prep(q[0]), + qb.PZ(q[0]), qb.H(q[0]), ) @@ -33,7 +33,9 @@ from typing import TYPE_CHECKING, Any +from pecos.slr.ast._block_substitution import BodyRemap, substitute_stmt from pecos.slr.ast.nodes import ( + AllocatorArg, AllocatorDecl, ArrayTypeExpr, AssignOp, @@ -43,6 +45,9 @@ BitExpr, BitRef, BitTypeExpr, + BlockCall, + BlockDecl, + BlockInput, CommentOp, ForStmt, GateKind, @@ -53,11 +58,16 @@ ParallelBlock, PermuteOp, PrepareOp, + PrintOp, Program, + QubitBundleArg, QubitTypeExpr, RegisterDecl, RepeatStmt, + ResourceEffect, ReturnOp, + SingleBitArg, + SingleQubitArg, SlotRef, UnaryExpr, UnaryOp, @@ -67,6 +77,7 @@ if TYPE_CHECKING: from pecos.slr.ast.nodes import ( + BlockArg, Expression, Statement, TypeExpr, @@ -75,6 +86,18 @@ from pecos.slr.main import Main +# SLR prep-gate symbol -> canonical AST prep basis. The basis +# is the gate identity, not a string argument. (`Prep` was removed, +# hard-replaced repo-wide by `PZ`.) +_PREP_BASIS: dict[str, str] = { + "PZ": "PZ", + "PNZ": "PNZ", + "PX": "PX", + "PNX": "PNX", + "PY": "PY", + "PNY": "PNY", +} + # Mapping from SLR gate class names to AST GateKind GATE_KIND_MAP: dict[str, GateKind] = { # Single-qubit Paulis @@ -84,8 +107,8 @@ # Hadamard "H": GateKind.H, # Phase gates - "S": GateKind.S, - "Sdg": GateKind.Sdg, + "S": GateKind.SZ, + "Sdg": GateKind.SZdg, "T": GateKind.T, "Tdg": GateKind.Tdg, # Square root gates @@ -159,6 +182,15 @@ class SlrToAst: def __init__(self) -> None: """Initialize the converter.""" self._position = 0 # Track position for source locations + self._block_decls: list[BlockDecl] = [] # Hoisted BlockDecls accumulated during convert() + self._decl_counter = 0 + # True while converting the ops of a legacy-flattened composite + # Block. A `Return` inside such a block is a block-boundary qubit + # handoff (the qubits are already in linear scope by allocator name), + # NOT the root Main return -- elide it. Provenance-based, not + # position/count (a single final root ReturnOp can come from a qeclib + # block, e.g. `Main(q, EncodingCircuit(q))`). + self._in_flattened_block = False def convert(self, block: Main | Block) -> Program: """Convert an SLR Main/Block to an AST Program. @@ -170,6 +202,8 @@ def convert(self, block: Main | Block) -> Program: An AST Program node. """ self._position = 0 + self._block_decls = [] + self._decl_counter = 0 # Get the block name name = getattr(block, "block_name", block.__class__.__name__) @@ -200,6 +234,7 @@ def convert(self, block: Main | Block) -> Program: body=tuple(body), returns=returns, allocator=allocator, + block_decls=tuple(self._block_decls), ) def _convert_declarations(self, block: Block) -> list: @@ -294,7 +329,6 @@ def _convert_var_to_declaration(self, var: Any): return RegisterDecl( name=var.sym, size=var.size, - is_result=getattr(var, "result", True), ) if var_class == "QAlloc": @@ -354,30 +388,234 @@ def _convert_statement(self, op: Any) -> Statement | None: return CommentOp(text=op.txt) if op_class == "Return": + if self._in_flattened_block: + # Elide block-boundary Return from a flattened composite. + return None return self._convert_return(op) if op_class == "Permute": return self._convert_permute(op) + if op_class == "Print": + return self._convert_print(op) + # Assignment operations if op_class == "SET": return self._convert_assignment(op) # Nested blocks (Block subclasses) if hasattr(op, "ops"): - # This is a nested block - flatten its statements into the parent - # We return a special marker that _convert_statements will handle - return ("__FLATTEN__", self._convert_statements(op.ops)) + # A Block subclass that declares `block_inputs` opts in to + # BlockDecl/BlockCall lowering. Without it, the legacy flatten path + # remains the default for all qeclib Blocks. + if _has_block_inputs(op): + return self._convert_block_call(op) + + # Legacy flatten path. Returns inside the flattened composite are + # block-boundary handoffs, not the root Main return: + # elide them via the in_flattened_block provenance flag, including + # inside nested If/Repeat/For/While/Parallel. Save/restore so + # nested composites and the root scope are handled correctly. + prev_in_flattened = self._in_flattened_block + self._in_flattened_block = True + try: + flattened = self._convert_statements(op.ops) + finally: + self._in_flattened_block = prev_in_flattened + return ("__FLATTEN__", flattened) return None + def _convert_block_call(self, block: Block) -> BlockCall: + """Build a BlockDecl + BlockCall pair from a Block subclass with `block_inputs`. + + Each call site emits a fresh BlockDecl (no dedup yet; dedup is + a later optimization once qeclib conversion lands). The BlockDecl + body is the block's `ops` with allocator names rewritten from the + outer-scope binding name to the input parameter name. + """ + # Concrete-isinstance detection (`.sym`/`.size` duck-typing + # misclassifies CReg as QReg). Local import keeps the + # SLR var classes off the module-import path (circular-safe, matching + # the other lazy `from pecos.slr...` imports in this file). + from pecos.slr.vars import ( # noqa: PLC0415 + Bit as SlrBit, + ) + from pecos.slr.vars import ( # noqa: PLC0415 + CReg, + QReg, + Qubit, + SymbolicElem, + ) + + inputs_spec = type(block).block_inputs # type: ignore[attr-defined] + cls_name = type(block).__name__ + remap = BodyRemap() + block_inputs: list[BlockInput] = [] + arg_bindings: list[BlockArg] = [] + out_bindings: list[BlockArg] = [] + scratch_inputs: list[str] = [] + + # Cross-input aliasing guard. The same concrete qubit slot (or + # classical bit) bound to two block-input positions -- + # e.g. `[q[0], q[0]]` or two single-qubit inputs both `q[0]` -- + # silently corrupts body substitution: Guppy rejects it on linearity + # while the QASM flatten path emits invalid `cx q[0], q[0];`. + # BodyRemap.add_slot/add_bit also reject this (defense-in-depth), but + # we check here first to give a Block-context error message. + seen_qubit: dict[tuple[str, int], str] = {} + seen_bit: dict[tuple[str, int], str] = {} + + def _claim_qubit(reg: str, index: int, *, here: str, owner: str) -> None: + prior = seen_qubit.get((reg, index)) + if prior is not None: + msg = ( + f"{here}: qubit {reg}[{index}] is also bound by input " + f"{prior!r}; a qubit cannot be aliased to two block-input " + "positions (no-cloning)" + ) + raise ValueError(msg) + seen_qubit[(reg, index)] = owner + + def _claim_bit(reg: str, index: int, *, here: str, owner: str) -> None: + prior = seen_bit.get((reg, index)) + if prior is not None: + msg = ( + f"{here}: bit {reg}[{index}] is also bound by input " + f"{prior!r}; the same outer bit cannot back two " + "block-input positions (lossy substitution)" + ) + raise ValueError(msg) + seen_bit[(reg, index)] = owner + + for input_name, effect_value in inputs_spec.items(): + if not hasattr(block, input_name): + msg = ( + f"Block {cls_name!r} declares input {input_name!r} but the " + f"instance does not bind self.{input_name} (set it in __init__ " + "after super().__init__())" + ) + raise ValueError(msg) + var = getattr(block, input_name) + where = f"Block {cls_name!r} input {input_name!r}" + effect = _normalize_effect(effect_value, where=where) + + type_expr: TypeExpr + arg: BlockArg + if isinstance(var, QReg): + type_expr = ArrayTypeExpr(element=QubitTypeExpr(), size=var.size) + arg = AllocatorArg(name=var.sym) + remap.add_whole_alloc(var.sym, input_name, var.size) + elif isinstance(var, Qubit): + type_expr = QubitTypeExpr() + _claim_qubit(var.reg.sym, var.index, here=where, owner=input_name) + arg = SingleQubitArg(slot=SlotRef(allocator=var.reg.sym, index=var.index)) + remap.add_slot((var.reg.sym, var.index), (input_name, 0)) + elif isinstance(var, SlrBit): + type_expr = BitTypeExpr() + _claim_bit(var.reg.sym, var.index, here=where, owner=input_name) + arg = SingleBitArg(bit=BitRef(register=var.reg.sym, index=var.index)) + remap.add_bit((var.reg.sym, var.index), (input_name, 0)) + elif isinstance(var, (list, tuple)): + if not var: + msg = f"{where}: empty {type(var).__name__} bundle is not supported" + raise ValueError(msg) + if all(isinstance(e, Qubit) for e in var): + slots = tuple(SlotRef(allocator=e.reg.sym, index=e.index) for e in var) + type_expr = ArrayTypeExpr(element=QubitTypeExpr(), size=len(var)) + arg = QubitBundleArg(slots=slots) + for k, e in enumerate(var): + _claim_qubit(e.reg.sym, e.index, here=where, owner=input_name) + remap.add_slot((e.reg.sym, e.index), (input_name, k)) + elif all(isinstance(e, SlrBit) for e in var): + msg = f"{where}: list[Bit] (classical bit bundle / BitBundleArg) is not yet supported" + raise ValueError(msg) + else: + msg = ( + f"{where}: a bundle must be all Qubit (or all Bit, deferred); " + f"got mixed/unsupported element types in {var!r}" + ) + raise ValueError(msg) + # Every unsupported-input-shape branch raises ValueError uniformly + # (the bundle branches above do too): the public rejection surface + # is a single exception type so callers catch one thing, and the + # rejection tests pin ValueError. TRY004 wants + # TypeError for the isinstance-guarded branches, but that would + # split the surface across two exception types -- noqa is the + # correct call here. + elif isinstance(var, SymbolicElem): + msg = ( + f"{where}: symbolic (loop-variable-indexed) Qubit/Bit is not " + "supported as a block input; pass a concrete element" + ) + raise ValueError(msg) # noqa: TRY004 + elif isinstance(var, CReg): + msg = ( + f"{where}: whole CReg input is not yet supported " + "(only single Bit via SingleBitArg); pass a single bit `c[i]`" + ) + raise ValueError(msg) # noqa: TRY004 + else: + msg = ( + f"{where}: unsupported binding type {type(var).__name__} " + f"({var!r}); supported: QReg, Qubit, list[Qubit], Bit" + ) + raise ValueError(msg) # noqa: TRY004 + + block_inputs.append(BlockInput(name=input_name, effect=effect, type_expr=type_expr)) + arg_bindings.append(arg) + # out_bindings mirror the arg BlockArg for each LIVE_PRESERVED input + # (the emitter's iter-5b cross-check requires arg == out, by value). + # SCRATCH is NOT live-preserved -> never in out_bindings (same as + # CONSUMED); it stays in arg_bindings so the 5e.2 cross-input alias + # guard still rejects a scratch slot aliased to another input. + if effect is ResourceEffect.LIVE_PRESERVED: + out_bindings.append(arg) + if effect is ResourceEffect.SCRATCH: + scratch_inputs.append(input_name) + + # Build the BlockDecl body by converting block.ops in a NEW sub-converter, + # then remapping outer -> param via the shared substitute_stmt. Share + # `_decl_counter` so nested converted Blocks get globally-unique names. + sub = SlrToAst() + sub._decl_counter = self._decl_counter # type: ignore[attr-defined] + sub_body = sub._convert_statements(block.ops) # type: ignore[attr-defined] + self._decl_counter = sub._decl_counter # type: ignore[attr-defined] + rewritten_body = tuple(substitute_stmt(stmt, remap) for stmt in sub_body) + + # A SCRATCH input is only sound if the block + # resets it before any other use and never touches it after the + # terminal measurement without re-Prepping. The Guppy lowering + # allocates the qubit internally on that assumption; validate + # it here so a mis-declared scratch input fails loudly at + # conversion rather than miscompiling. + for scratch_name in scratch_inputs: + _validate_scratch_input(rewritten_body, scratch_name, cls_name=cls_name) + + # Unique decl name per call site (dedup is a later optimization). + decl_name = f"{cls_name.lower()}_{self._decl_counter}" + self._decl_counter += 1 + + decl = BlockDecl(name=decl_name, inputs=tuple(block_inputs), body=rewritten_body) + # Hoist any block_decls emitted by the sub-converter (nested BlockCalls) first. + self._block_decls.extend(sub._block_decls) + self._block_decls.append(decl) + + return BlockCall( + callee=decl_name, + arg_bindings=tuple(arg_bindings), + out_bindings=tuple(out_bindings), + ) + def _convert_gate(self, gate: Any) -> Statement: """Convert an SLR gate to an AST GateOp, PrepareOp, or MeasureOp.""" gate_name = gate.sym - # Handle special operations - if gate_name == "Prep": - return self._convert_prep(gate) + # Handle special operations. All prep gates route through one + # path; the basis is the GATE IDENTITY + # (PZ/PNZ/PX/PNX/PY/PNY), never a string argument. + if gate_name in _PREP_BASIS: + return self._convert_prep(gate, basis=_PREP_BASIS[gate_name]) if gate_name == "Measure": return self._convert_measure(gate) @@ -430,12 +668,33 @@ def _convert_gate(self, gate: Any) -> Statement: return GateOp(gate=gate_kind, targets=targets, params=params) - def _convert_prep(self, gate: Any) -> Statement: - """Convert an SLR Prep gate to an AST PrepareOp or flattened list.""" + def _convert_prep(self, gate: Any, basis: str = "PZ") -> Statement: + """Convert an SLR prep gate to an AST PrepareOp. + + `basis` is the canonical eigenstate from the gate IDENTITY + (`PZ`/`PNZ`/`PX`/`PNX`/`PY`/`PNY`; `Prep` -> `PZ`), NOT + a string argument. + """ if not gate.qargs: msg = "Prep gate has no qubit arguments" raise ValueError(msg) + # A stray STRING qarg on ANY prep gate is rejected loudly. + # `_expand_qubit_args` silently drops strings, so a basis + # string (`PZ(q, "X")`, the legacy `Prep(q, "X")`, etc.) + # would otherwise be silently dropped and the gate lowered as + # its plain basis -- a miscompile under the prep-basis + # symmetry rule. The prep basis is the gate identity; pass NO string. + for arg in gate.qargs: + if isinstance(arg, str): + msg = ( + f"AST conversion: prep gate {gate.sym!r} got a stray " + f"string argument {arg!r}. The prep basis is the gate " + "identity, not an argument -- use PZ/PNZ/PX/PNX/PY/PNY " + "(a string would be silently dropped -- a miscompile)." + ) + raise NotImplementedError(msg) + # Expand full registers into individual qubits expanded_qargs = self._expand_qubit_args(gate.qargs) @@ -453,7 +712,7 @@ def _convert_prep(self, gate: Any) -> Statement: slots = tuple(q.index for q in expanded_qargs) - return PrepareOp(allocator=allocator, slots=slots) + return PrepareOp(allocator=allocator, slots=slots, basis=basis) def _convert_measure(self, gate: Any) -> MeasureOp: """Convert an SLR Measure gate to an AST MeasureOp.""" @@ -472,12 +731,13 @@ def _convert_measure(self, gate: Any) -> MeasureOp: def _expand_qubit_args(self, qargs: list) -> list: """Expand qubit arguments, converting full registers to individual qubits. - Filters out non-qubit arguments like strings (e.g., basis state "Z" in Prep). + Filters out non-qubit arguments like strings (a stray basis + string on a prep gate is rejected upstream in `_convert_prep`). """ expanded = [] for q in qargs: if isinstance(q, str): - # Skip string arguments (e.g., basis state in Prep) + # Skip non-qubit string arguments continue if isinstance(q, list): # This is a slice (list of qubits) - recursively expand @@ -594,18 +854,35 @@ def _convert_barrier(self, op: Any) -> BarrierOp: return BarrierOp(allocators=allocators) def _convert_return(self, op: Any) -> ReturnOp: - """Convert an SLR Return to an AST ReturnOp.""" + """Convert an SLR Return to an AST ReturnOp. + + Carries per-value provenance (`value_kinds`) derived from the + SLR object's type. A whole-register return flattens to its + bare `sym` name; a returned inline CReg can collide with a + declared QReg of the same name, which is undecidable + downstream by name alone (the name-collision bug). The + QReg/CReg distinction is known HERE (the real object), so it + is preserved instead of guessed in codegen. + """ + from pecos.slr.vars import QReg # noqa: PLC0415 + values: list = [] + kinds: list[str] = [] for var in op.return_vars: if isinstance(var, str): + # Raw user-supplied name (provenance unknown); the + # fail-loud-safe default is "classical" so a backend + # validates it as a declared classical register. values.append(var) + kinds.append("classical") elif hasattr(var, "sym"): values.append(var.sym) + kinds.append("quantum" if isinstance(var, QReg) else "classical") else: - # Try to convert as expression values.append(self._convert_expression(var)) + kinds.append("expr") - return ReturnOp(values=tuple(values)) + return ReturnOp(values=tuple(values), value_kinds=tuple(kinds)) def _convert_permute(self, op: Any) -> PermuteOp: """Convert an SLR Permute to an AST PermuteOp.""" @@ -627,6 +904,7 @@ def _convert_permute(self, op: Any) -> PermuteOp: sources=(elems_i.sym, elems_f.sym), targets=(elems_f.sym, elems_i.sym), add_comment=add_comment, + whole_register=True, ) # Extract register/allocator names from sources (elems_i) @@ -669,6 +947,27 @@ def _convert_permute(self, op: Any) -> PermuteOp: add_comment=add_comment, ) + def _convert_print(self, op: Any) -> PrintOp: + """Convert an SLR Print to an AST PrintOp. + + SLR-side already validated tag/namespace characters, derived the tag + from the value's name when omitted, and rejected non-CReg/Bit values. + AST conversion lowers the value into AST shape: a `BitRef` for + Bit values, or a CReg name string for whole-register Print. + """ + value: BitRef | str + if hasattr(op.value, "reg") and hasattr(op.value, "index"): + # Bit reference (e.g., c[0]). + value = self._convert_bit_ref(op.value) + elif hasattr(op.value, "sym"): + # Whole CReg. + value = op.value.sym + else: + msg = f"Print value must be a CReg or Bit; got {type(op.value).__name__}" + raise TypeError(msg) + + return PrintOp(value=value, tag=op.tag, namespace=op.namespace) + def _convert_assignment(self, op: Any) -> AssignOp: """Convert an SLR SET operation to an AST AssignOp.""" # Target @@ -691,6 +990,16 @@ def _convert_expression(self, expr: Any) -> Expression: if expr is None: return LiteralExpr(value=0) + # Typed angle (rotation-gate parameter). Stored verbatim in the + # literal so pretty-print can round-trip the source unit and each + # backend can unwrap to the underlying `angle64`. Must come before + # the generic `hasattr(expr, "value")` fallback, which would strip + # the wrapper down to the bare `angle64` and lose the unit. + from pecos.slr.angle import Angle # noqa: PLC0415 (avoid import cycle) + + if isinstance(expr, Angle): + return LiteralExpr(value=expr) + # Literal values if isinstance(expr, bool | int | float): return LiteralExpr(value=expr) @@ -752,3 +1061,184 @@ def slr_to_ast(block: Main | Block) -> Program: """ converter = SlrToAst() return converter.convert(block) + + +# === BlockDecl/BlockCall helpers === + +_EFFECT_NAME_MAP: dict[str, ResourceEffect] = { + "live_preserved": ResourceEffect.LIVE_PRESERVED, + "consumed": ResourceEffect.CONSUMED, + "produced": ResourceEffect.PRODUCED, + "dropped": ResourceEffect.DROPPED, + "scratch": ResourceEffect.SCRATCH, +} + + +def _has_block_inputs(block: Block) -> bool: + return isinstance(getattr(type(block), "block_inputs", None), dict) + + +def _normalize_effect(value: Any, *, where: str) -> ResourceEffect: + """Accept either a ResourceEffect enum or its lowercased string name.""" + if isinstance(value, ResourceEffect): + return value + if isinstance(value, str): + normalized = value.lower() + if normalized in _EFFECT_NAME_MAP: + return _EFFECT_NAME_MAP[normalized] + msg = ( + f"{where}: effect {value!r} is not a valid ResourceEffect. " + f"Expected one of: ResourceEffect enum, or strings {sorted(_EFFECT_NAME_MAP)}" + ) + raise ValueError(msg) + + +def _ref_base_name(ref: str) -> str: + """Leading identifier of a PermuteOp ref string (`q[0]`/`q`/`q.x` -> `q`).""" + return ref.split("[", 1)[0].split(".", 1)[0] + + +def _scratch_events(stmt: Statement, name: str) -> list[str]: + """Ordered scratch-lifecycle events for allocator `name` in `stmt`. + + Returns a list drawn from PREP / MEASURE / USE in execution order. + A reference inside control flow / Parallel / a nested BlockCall / + PermuteOp cannot be linearized for the reset-first analysis, so + it is reported as the sentinel UNSUPPORTED and the validator rejects + (scratch inputs must have a flat Prep -> ... -> Measure + lifecycle). + """ + if isinstance(stmt, CommentOp): + return [] + if isinstance(stmt, PrepareOp): + return ["PREP"] if stmt.allocator == name else [] + if isinstance(stmt, MeasureOp): + return ["MEASURE"] if any(t.allocator == name for t in stmt.targets) else [] + if isinstance(stmt, GateOp): + return ["USE"] if any(t.allocator == name for t in stmt.targets) else [] + if isinstance(stmt, BarrierOp): + return ["USE"] if name in stmt.allocators else [] + if isinstance(stmt, PermuteOp): + refs = (*stmt.sources, *stmt.targets) + return ["UNSUPPORTED"] if any(_ref_base_name(r) == name for r in refs) else [] + if isinstance(stmt, BlockCall): + for arg in (*stmt.arg_bindings, *stmt.out_bindings): + if _block_arg_mentions_allocator(arg, name): + return ["UNSUPPORTED"] + return [] + if isinstance(stmt, (IfStmt, WhileStmt, ForStmt, RepeatStmt, ParallelBlock)): + bodies: list[tuple[Statement, ...]] = ( + [stmt.then_body, stmt.else_body] if isinstance(stmt, IfStmt) else [stmt.body] + ) + for b in bodies: + for inner in b: + if _scratch_events(inner, name): + return ["UNSUPPORTED"] + return [] + if isinstance(stmt, ReturnOp): + # A scratch input must never be handed back to the caller: the + # Guppy lowering allocates it internally and does not thread it through caller + # state, so returning it would diverge from the flatten/QASM path + # Detection cannot be precise here -- the + # substitution leaves Return values as the OUTER name (a partial + # VarExpr passes through `whole_name` unchanged), so a returned + # scratch slot is indistinguishable from a returned classical + # value at this point. Conservatively reject ANY ReturnOp in a + # scratch-bearing block (in-scope blocks like `Check` + # have no Return; relax deliberately in a later stage if needed). + return ["UNSUPPORTED"] + # Anything else carries no qubit-slot ref. AssignOp/PrintOp operate on + # classical bit/int expressions only -- a qubit scratch slot cannot + # appear there -- so they are irrelevant to a qubit scratch lifecycle. + return [] + + +def _block_arg_mentions_allocator(arg: BlockArg, name: str) -> bool: + """True if a nested BlockCall arg references qubit allocator `name`.""" + slot = getattr(arg, "slot", None) + if slot is not None and getattr(slot, "allocator", None) == name: + return True + if getattr(arg, "name", None) == name: # AllocatorArg(name=...) + return True + slots = getattr(arg, "slots", None) + return bool(slots) and any(getattr(s, "allocator", None) == name for s in slots) + + +def _validate_scratch_input( + body: tuple[Statement, ...], + name: str, + *, + cls_name: str, +) -> None: + """Enforce the reset-first rule for a SCRATCH input. + + The block must reset `name` before any other use, and every Prep + lifecycle must be closed by a Measure before the next Prep and + before the body ends (the Guppy lowering allocates the scratch + qubit internally, so an unmeasured trailing Prep/use would diverge + from the flatten/QASM path). Anything the + linearization analysis cannot handle (control flow / Parallel / nested + BlockCall / Permute over the scratch slot) is rejected loudly -- + silently allowing it would let the Guppy internal-allocation + lowering miscompile. + """ + events: list[str] = [] + for stmt in body: + events.extend(_scratch_events(stmt, name)) + + where = f"Block {cls_name!r} scratch input {name!r}" + if "UNSUPPORTED" in events: + msg = ( + f"{where}: scratch inputs must have a flat PZ -> ... -> " + "Measure lifecycle; referencing the scratch slot inside " + "control flow, Parallel, a nested BlockCall, or a Permute -- " + "or any ReturnOp in a scratch-bearing block -- is not " + "supported" + ) + raise ValueError(msg) + if not events: + msg = f"{where}: declared SCRATCH but never used in the block body (expected PZ(...) then Measure(...))" + raise ValueError(msg) + if events[0] != "PREP": + msg = ( + f"{where}: first use is {events[0]}, but a SCRATCH input must " + "be reset (PZ) before any other use (reset-first)" + ) + raise ValueError(msg) + # Segment state machine: a PREP opens a lifecycle that MUST close with a + # MEASURE before the next PREP and before the body ends. The Guppy + # lowering allocates the scratch qubit internally; an unmeasured trailing + # Prep/use would leave the flatten/QASM path with a live reset outer slot + # while Guppy silently drops it -- a semantic divergence. + # So every Prep must be terminated by a Measure. + open_prep = False + for ev in events: + if ev == "PREP": + if open_prep: + msg = ( + f"{where}: re-Prepped before measuring the previous " + "scratch lifecycle (every PZ must be closed by a " + "Measure)" + ) + raise ValueError(msg) + open_prep = True + elif ev == "MEASURE": + if not open_prep: + msg = ( + f"{where}: measured without an open PZ (measure " + "before PZ, or measured again without an " + "intervening PZ)" + ) + raise ValueError(msg) + open_prep = False + else: # USE + if not open_prep: + msg = f"{where}: used after measurement without re-PZ (scratch lifecycle must be PZ -> use -> Measure)" + raise ValueError(msg) + if open_prep: + msg = ( + f"{where}: scratch lifecycle not closed -- a Prep has no " + "terminating Measure; the block must measure the scratch " + "qubit before returning (allocated internally)" + ) + raise ValueError(msg) diff --git a/python/quantum-pecos/src/pecos/slr/ast/nodes.py b/python/quantum-pecos/src/pecos/slr/ast/nodes.py index f79d1013c..2cd89fa28 100644 --- a/python/quantum-pecos/src/pecos/slr/ast/nodes.py +++ b/python/quantum-pecos/src/pecos/slr/ast/nodes.py @@ -24,7 +24,6 @@ from __future__ import annotations -from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum, auto from typing import TYPE_CHECKING @@ -32,7 +31,7 @@ if TYPE_CHECKING: from collections.abc import Sequence - from pecos.slr.ast.visitor import AstVisitor, T + from pecos.slr.angle import Angle # ============================================================================= @@ -63,22 +62,17 @@ def __str__(self) -> str: @dataclass(frozen=True, kw_only=True) -class AstNode(ABC): +class AstNode: """Base class for all AST nodes. All AST nodes are immutable frozen dataclasses that support: - - Visitor pattern via accept() + - Visitor traversal via `BaseVisitor` (centralized dispatch) - Child traversal via children() - Optional source location tracking """ location: SourceLocation | None = field(default=None, compare=False, repr=False) - @abstractmethod - def accept(self, visitor: AstVisitor[T]) -> T: - """Accept a visitor for traversal.""" - ... - def children(self) -> Sequence[AstNode]: """Return child nodes for traversal. Override in subclasses.""" return () @@ -101,8 +95,6 @@ class GateKind(Enum): H = auto() # Phase gates - S = auto() - Sdg = auto() T = auto() Tdg = auto() @@ -214,6 +206,21 @@ class UnaryOp(Enum): NEG = auto() +class ResourceEffect(Enum): + """Effect declared by a `BlockDecl` input on the outer scope's binding.""" + + LIVE_PRESERVED = auto() # caller binding survives the call unchanged + CONSUMED = auto() # caller binding is invalidated by the call + PRODUCED = auto() # callee writes; caller's binding is rebound from return + DROPPED = auto() # callee discards; caller binding is invalidated + # Reset-reused scratch ancilla: the block resets the input at entry and + # measures it at exit, depending on no incoming state. The caller's slot is + # a flatten-path naming vehicle only; in Guppy the block allocates the qubit + # internally so the same outer slot can feed a subsequent BlockCall (the + # `consumed` model would kill it). + SCRATCH = auto() + + # ============================================================================= # References # ============================================================================= @@ -230,9 +237,6 @@ class SlotRef(AstNode): allocator: str index: int - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_slot_ref(self) - def __str__(self) -> str: return f"{self.allocator}[{self.index}]" @@ -244,9 +248,6 @@ class BitRef(AstNode): register: str index: int - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_bit_ref(self) - def __str__(self) -> str: return f"{self.register}[{self.index}]" @@ -257,18 +258,15 @@ def __str__(self) -> str: @dataclass(frozen=True, kw_only=True) -class Expression(AstNode, ABC): +class Expression(AstNode): """Base class for all expressions.""" @dataclass(frozen=True, kw_only=True) class LiteralExpr(Expression): - """Literal value (int, float, bool).""" - - value: int | float | bool + """Literal value (int, float, bool, or a typed rotation `Angle`).""" - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_literal(self) + value: int | float | bool | Angle @dataclass(frozen=True, kw_only=True) @@ -277,9 +275,6 @@ class VarExpr(Expression): name: str - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_var(self) - @dataclass(frozen=True, kw_only=True) class BitExpr(Expression): @@ -287,9 +282,6 @@ class BitExpr(Expression): ref: BitRef - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_bit_expr(self) - def children(self) -> Sequence[AstNode]: return (self.ref,) @@ -302,9 +294,6 @@ class BinaryExpr(Expression): left: Expression right: Expression - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_binary(self) - def children(self) -> Sequence[AstNode]: return (self.left, self.right) @@ -316,9 +305,6 @@ class UnaryExpr(Expression): op: UnaryOp operand: Expression - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_unary(self) - def children(self) -> Sequence[AstNode]: return (self.operand,) @@ -329,7 +315,7 @@ def children(self) -> Sequence[AstNode]: @dataclass(frozen=True, kw_only=True) -class TypeExpr(AstNode, ABC): +class TypeExpr(AstNode): """Base class for type expressions.""" @@ -337,17 +323,11 @@ class TypeExpr(AstNode, ABC): class QubitTypeExpr(TypeExpr): """Single qubit type.""" - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_qubit_type(self) - @dataclass(frozen=True, kw_only=True) class BitTypeExpr(TypeExpr): """Single classical bit type.""" - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_bit_type(self) - @dataclass(frozen=True, kw_only=True) class ArrayTypeExpr(TypeExpr): @@ -356,9 +336,6 @@ class ArrayTypeExpr(TypeExpr): element: TypeExpr size: int - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_array_type(self) - def children(self) -> Sequence[AstNode]: return (self.element,) @@ -369,9 +346,6 @@ class AllocatorTypeExpr(TypeExpr): capacity: int - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_allocator_type(self) - # ============================================================================= # Declarations @@ -379,7 +353,7 @@ def accept(self, visitor: AstVisitor[T]) -> T: @dataclass(frozen=True, kw_only=True) -class Declaration(AstNode, ABC): +class Declaration(AstNode): """Base class for all declarations.""" @@ -395,9 +369,6 @@ class AllocatorDecl(Declaration): capacity: int parent: str | None = None # Name of parent allocator - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_allocator_decl(self) - @dataclass(frozen=True, kw_only=True) class RegisterDecl(Declaration): @@ -405,10 +376,6 @@ class RegisterDecl(Declaration): name: str size: int - is_result: bool = True - - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_register_decl(self) # ============================================================================= @@ -417,7 +384,7 @@ def accept(self, visitor: AstVisitor[T]) -> T: @dataclass(frozen=True, kw_only=True) -class Statement(AstNode, ABC): +class Statement(AstNode): """Base class for all statements.""" @@ -429,9 +396,6 @@ class GateOp(Statement): targets: tuple[SlotRef, ...] params: tuple[Expression, ...] = () - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_gate(self) - def children(self) -> Sequence[AstNode]: return (*self.targets, *self.params) @@ -446,9 +410,14 @@ class PrepareOp(Statement): allocator: str slots: tuple[int, ...] | None = None # None means prepare_all - - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_prepare(self) + # Canonical prep basis / target eigenstate, one of + # {PZ,PNZ,PX,PNX,PY,PNY} (|0>,|1>,|+>,|->,|+i>,|-i>). Set by + # `_convert_prep` from the SLR gate symbol (PZ default). Carried + # on the AST so codegens lower the correct reset+Clifford tail; + # MUST be preserved through block substitution (else a non-PZ + # prep inside a BlockCall body silently reverts to PZ -- a + # soundness-critical case). + basis: str = "PZ" @dataclass(frozen=True, kw_only=True) @@ -461,9 +430,6 @@ class MeasureOp(Statement): targets: tuple[SlotRef, ...] results: tuple[BitRef, ...] = () - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_measure(self) - def children(self) -> Sequence[AstNode]: return (*self.targets, *self.results) @@ -475,9 +441,6 @@ class AssignOp(Statement): target: BitRef | str # Variable name or bit reference value: Expression - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_assign(self) - def children(self) -> Sequence[AstNode]: nodes: list[AstNode] = [] if isinstance(self.target, BitRef): @@ -492,9 +455,6 @@ class BarrierOp(Statement): allocators: tuple[str, ...] = () - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_barrier(self) - @dataclass(frozen=True, kw_only=True) class CommentOp(Statement): @@ -502,23 +462,52 @@ class CommentOp(Statement): text: str - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_comment(self) - @dataclass(frozen=True, kw_only=True) class ReturnOp(Statement): """Return statement.""" values: tuple[Expression | str, ...] = () # Can be variable names - - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_return(self) + # Per-value provenance, parallel to `values`: "quantum" (a QReg -- + # no classical record), "classical" (a CReg -- must be recorded), + # or "expr". Set by `_convert_return` from the SLR object type; + # `()` means unknown (e.g. a directly-constructed ReturnOp), in + # which case a backend treats a bare-name value as classical + # (the fail-loud-safe default). A bare `values` string cannot be + # disambiguated CReg-vs-QReg by name alone (a returned inline + # CReg can collide with a declared QReg name -- the + # name-collision bug), so provenance is carried here, not guessed. + value_kinds: tuple[str, ...] = () def children(self) -> Sequence[AstNode]: return tuple(v for v in self.values if isinstance(v, AstNode)) +@dataclass(frozen=True, kw_only=True) +class PrintOp(Statement): + """Emit an intermediate streamed value at the call site. + + Lowers to Guppy's `result(name, value)`. Scope-orthogonal side-effect: + does not allocate, does not modify the result-register set used for + return-shape computation. + + `value` is either a `BitRef` (single bit) or a register name string + (whole-CReg emission). `tag` is the resolved tag (the SLR-side + conversion derives the default from the value's name when the user + did not pass `tag=` explicitly). `namespace` is the tag prefix; the + full emitted Guppy tag is `f"{namespace}.{tag}"`. + """ + + value: BitRef | str + tag: str + namespace: str = "result" + + def children(self) -> Sequence[AstNode]: + if isinstance(self.value, AstNode): + return (self.value,) + return () + + @dataclass(frozen=True, kw_only=True) class PermuteOp(Statement): """Permute qubit register assignments. @@ -533,9 +522,11 @@ class PermuteOp(Statement): sources: tuple[str, ...] # Initial register/allocator names targets: tuple[str, ...] # Final register/allocator names add_comment: bool = True # Whether to add a comment in generated code - - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_permute(self) + # Whole-register swap `Permute(a, b)` -> comment `; Permutation: + # a <-> b`; else per-element `; Permutation: a[0] -> b[1], ...` + # (the comment is rendered at codegen from the post-substitution + # sources/targets, mirroring the legacy gen_qir format). + whole_register: bool = False # ============================================================================= @@ -551,9 +542,6 @@ class IfStmt(Statement): then_body: tuple[Statement, ...] else_body: tuple[Statement, ...] = () - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_if(self) - def children(self) -> Sequence[AstNode]: return (self.condition, *self.then_body, *self.else_body) @@ -565,9 +553,6 @@ class WhileStmt(Statement): condition: Expression body: tuple[Statement, ...] - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_while(self) - def children(self) -> Sequence[AstNode]: return (self.condition, *self.body) @@ -582,9 +567,6 @@ class ForStmt(Statement): step: Expression | None = None body: tuple[Statement, ...] = () - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_for(self) - def children(self) -> Sequence[AstNode]: nodes: list[AstNode] = [self.start, self.stop] if self.step: @@ -600,9 +582,6 @@ class RepeatStmt(Statement): count: int body: tuple[Statement, ...] - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_repeat(self) - def children(self) -> Sequence[AstNode]: return self.body @@ -613,13 +592,126 @@ class ParallelBlock(Statement): body: tuple[Statement, ...] - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_parallel(self) - def children(self) -> Sequence[AstNode]: return self.body +# ============================================================================= +# Reusable block declarations +# ============================================================================= + + +@dataclass(frozen=True, kw_only=True) +class BlockInput(AstNode): + """One declared input parameter to a `BlockDecl`.""" + + name: str + effect: ResourceEffect + type_expr: TypeExpr + + def children(self) -> Sequence[AstNode]: + return (self.type_expr,) + + +@dataclass(frozen=True, kw_only=True) +class BlockDecl(AstNode): + """Reusable Block declaration that lowers to a top-level Guppy function. + + Non-Guppy codegens inline `body` at every `BlockCall` site. + """ + + name: str + inputs: tuple[BlockInput, ...] + body: tuple[Statement, ...] + return_op: ReturnOp | None = None + + def children(self) -> Sequence[AstNode]: + nodes: list[AstNode] = list(self.inputs) + nodes.extend(self.body) + if self.return_op is not None: + nodes.append(self.return_op) + return nodes + + +# ---- BlockCall argument types (typed sum type) ---- + + +@dataclass(frozen=True, kw_only=True) +class BlockArg(AstNode): + """Base class for `BlockCall` argument bindings. + + Each BlockInput on the callee is bound to exactly one BlockArg at the + caller, describing what outer-scope state the input refers to. + """ + + +@dataclass(frozen=True, kw_only=True) +class AllocatorArg(BlockArg): + """Whole-allocator binding: every slot of an outer-scope allocator.""" + + name: str + + +@dataclass(frozen=True, kw_only=True) +class SingleQubitArg(BlockArg): + """Single-qubit slot binding.""" + + slot: SlotRef + + def children(self) -> Sequence[AstNode]: + return (self.slot,) + + +@dataclass(frozen=True, kw_only=True) +class SingleBitArg(BlockArg): + """Single classical-bit binding (write-back via array[bool, 1] proxy in emitter).""" + + bit: BitRef + + def children(self) -> Sequence[AstNode]: + return (self.bit,) + + +@dataclass(frozen=True, kw_only=True) +class QubitBundleArg(BlockArg): + """Non-contiguous bundle of qubit slots packed into a single array[qubit, N].""" + + slots: tuple[SlotRef, ...] + + def children(self) -> Sequence[AstNode]: + return self.slots + + +@dataclass(frozen=True, kw_only=True) +class BitBundleArg(BlockArg): + """Non-contiguous bundle of classical bits packed into a single array[bool, N].""" + + bits: tuple[BitRef, ...] + + def children(self) -> Sequence[AstNode]: + return self.bits + + +@dataclass(frozen=True, kw_only=True) +class BlockCall(Statement): + """Invoke a `BlockDecl` from the outer scope. + + `arg_bindings` lists outer-scope bindings (typed `BlockArg`) in the same + order as the callee's declared inputs (one per input). + `out_bindings` lists outer-scope bindings that receive the callee's + outputs (`live_preserved`/`produced` inputs + explicit `Return` values, + in declaration order then return order). Empty for callees that return + nothing. + """ + + callee: str + arg_bindings: tuple[BlockArg, ...] + out_bindings: tuple[BlockArg, ...] = () + + def children(self) -> Sequence[AstNode]: + return (*self.arg_bindings, *self.out_bindings) + + # ============================================================================= # Program # ============================================================================= @@ -641,19 +733,25 @@ class Program(AstNode): body: tuple[Statement, ...] = () returns: tuple[TypeExpr, ...] = () allocator: AllocatorDecl | None = None # Base allocator - - def accept(self, visitor: AstVisitor[T]) -> T: - return visitor.visit_program(self) + block_decls: tuple[BlockDecl, ...] = () def children(self) -> Sequence[AstNode]: nodes: list[AstNode] = [] if self.allocator: nodes.append(self.allocator) nodes.extend(self.declarations) + nodes.extend(self.block_decls) nodes.extend(self.body) nodes.extend(self.returns) return nodes + def get_block_decl(self, name: str) -> BlockDecl | None: + """Find a BlockDecl by name.""" + for decl in self.block_decls: + if decl.name == name: + return decl + return None + def get_allocator(self, name: str) -> AllocatorDecl | None: """Find an allocator declaration by name.""" if self.allocator and self.allocator.name == name: diff --git a/python/quantum-pecos/src/pecos/slr/ast/optimizations/base.py b/python/quantum-pecos/src/pecos/slr/ast/optimizations/base.py index 24d9b5798..6e4dda800 100644 --- a/python/quantum-pecos/src/pecos/slr/ast/optimizations/base.py +++ b/python/quantum-pecos/src/pecos/slr/ast/optimizations/base.py @@ -111,6 +111,7 @@ def optimize(self, program: Program) -> OptimizationResult: body=optimized_body, returns=program.returns, allocator=program.allocator, + block_decls=program.block_decls, location=program.location, ) diff --git a/python/quantum-pecos/src/pecos/slr/ast/optimizations/gate_properties.py b/python/quantum-pecos/src/pecos/slr/ast/optimizations/gate_properties.py index 659ec0951..264d497d0 100644 --- a/python/quantum-pecos/src/pecos/slr/ast/optimizations/gate_properties.py +++ b/python/quantum-pecos/src/pecos/slr/ast/optimizations/gate_properties.py @@ -40,8 +40,6 @@ # Mapping from gate to its inverse INVERSE_PAIRS: dict[GateKind, GateKind] = { - GateKind.S: GateKind.Sdg, - GateKind.Sdg: GateKind.S, GateKind.T: GateKind.Tdg, GateKind.Tdg: GateKind.T, GateKind.SX: GateKind.SXdg, diff --git a/python/quantum-pecos/src/pecos/slr/ast/optimizations/identity_removal.py b/python/quantum-pecos/src/pecos/slr/ast/optimizations/identity_removal.py index fb2bb9f83..6b734915d 100644 --- a/python/quantum-pecos/src/pecos/slr/ast/optimizations/identity_removal.py +++ b/python/quantum-pecos/src/pecos/slr/ast/optimizations/identity_removal.py @@ -74,6 +74,7 @@ def optimize(self, program: Program) -> OptimizationResult: body=optimized_body, returns=program.returns, allocator=program.allocator, + block_decls=program.block_decls, location=program.location, ) @@ -151,7 +152,15 @@ def _is_identity(self, gate: GateOp) -> bool: # Can't evaluate symbolic angles at compile time return False + from pecos.slr.angle import Angle # noqa: PLC0415 (avoid import cycle) + value = angle.value + if isinstance(value, Angle): + # Use the signed-radians principal value with the same tolerance + # as the float path: fixed-point addition (e.g. rad(0.5) + + # rad(-0.5) from merging) leaves a sub-ULP residual, so an exact + # `fraction == 0` check would miss near-identity rotations. + return abs(value.value.to_radians_signed()) < self.IDENTITY_ANGLE_TOLERANCE if not isinstance(value, (int, float)): return False diff --git a/python/quantum-pecos/src/pecos/slr/ast/optimizations/rotation_merging.py b/python/quantum-pecos/src/pecos/slr/ast/optimizations/rotation_merging.py index 577c01e03..c3ed4ad16 100644 --- a/python/quantum-pecos/src/pecos/slr/ast/optimizations/rotation_merging.py +++ b/python/quantum-pecos/src/pecos/slr/ast/optimizations/rotation_merging.py @@ -57,10 +57,10 @@ class RotationMergingPass(OptimizationPass): Example: # Before optimization - RX(0.5, q[0]), RX(0.3, q[0]) + RX(rad(0.5), q[0]), RX(rad(0.3), q[0]) # After optimization - RX(0.8, q[0]) + RX(rad(0.8), q[0]) """ @property @@ -77,6 +77,7 @@ def optimize(self, program: Program) -> OptimizationResult: body=optimized_body, returns=program.returns, allocator=program.allocator, + block_decls=program.block_decls, location=program.location, ) @@ -176,12 +177,23 @@ def _merge_rotations(self, gate1: GateOp, gate2: GateOp) -> GateOp: If both angles are literals, computes the sum at compile time. Otherwise, creates a symbolic BinaryExpr for the sum. """ + from pecos.slr.angle import Angle # noqa: PLC0415 (avoid import cycle) + angle1 = gate1.params[0] angle2 = gate2.params[0] # Try to evaluate at compile time if both are literals if isinstance(angle1, LiteralExpr) and isinstance(angle2, LiteralExpr): - merged_angle = LiteralExpr(value=angle1.value + angle2.value) + v1, v2 = angle1.value, angle2.value + if isinstance(v1, Angle) and isinstance(v2, Angle): + # Sum in exact fixed-point (angle64 wraps mod a full turn, + # which is the correct rotation-composition semantics). + merged_angle = LiteralExpr(value=Angle(v1.value + v2.value, v1.source_unit)) + elif not isinstance(v1, Angle) and not isinstance(v2, Angle): + merged_angle = LiteralExpr(value=v1 + v2) + else: + # Mixed Angle/number (unexpected from SLR) -- keep symbolic. + merged_angle = BinaryExpr(op=BinaryOp.ADD, left=angle1, right=angle2) else: # Create symbolic sum merged_angle = BinaryExpr( diff --git a/python/quantum-pecos/src/pecos/slr/ast/pretty_print.py b/python/quantum-pecos/src/pecos/slr/ast/pretty_print.py index caeb88c36..256550a2b 100644 --- a/python/quantum-pecos/src/pecos/slr/ast/pretty_print.py +++ b/python/quantum-pecos/src/pecos/slr/ast/pretty_print.py @@ -31,9 +31,12 @@ from pecos.slr.ast.nodes import ( AllocatorDecl, + ArrayTypeExpr, BinaryOp, BitRef, + BitTypeExpr, Expression, + QubitTypeExpr, RegisterDecl, UnaryOp, ) @@ -47,6 +50,10 @@ BarrierOp, BinaryExpr, BitExpr, + BlockArg, + BlockCall, + BlockDecl, + BlockInput, CommentOp, ForStmt, GateOp, @@ -56,11 +63,13 @@ ParallelBlock, PermuteOp, PrepareOp, + PrintOp, Program, RepeatStmt, ReturnOp, SlotRef, Statement, + TypeExpr, UnaryExpr, VarExpr, WhileStmt, @@ -134,7 +143,14 @@ def _with_indent(self) -> _IndentContext: def visit_program(self, node: Program) -> str: """Visit program node.""" - lines = ["Main("] + # BlockDecls are hoisted above the Main block in the rendering so the + # output reads top-down (defs first, call sites below). + lines: list[str] = [] + for decl in node.block_decls: + lines.append(self.visit_block_decl(decl)) + lines.append("") + + lines.append("Main(") self._level += 1 # Allocator @@ -186,7 +202,7 @@ def visit_register_decl(self, node: RegisterDecl) -> str: def format_statement(self, stmt: Statement) -> str: """Format a statement.""" - return stmt.accept(self) + return self.visit(stmt) def visit_gate(self, node: GateOp) -> str: """Visit gate operation.""" @@ -194,18 +210,24 @@ def visit_gate(self, node: GateOp) -> str: targets = ", ".join(self.visit_slot_ref(t) for t in node.targets) if node.params: + # Angle-first SLR API -- `qb.RX(theta, q)`, angles + # before qubits (not the legacy `qb.RX[theta](q)` bracket). params = ", ".join(self.format_expression(p) for p in node.params) - return f"qb.{gate_name}[{params}]({targets})" + sep = ", " if targets else "" + return f"qb.{gate_name}({params}{sep}{targets})" return f"qb.{gate_name}({targets})" def visit_prepare(self, node: PrepareOp) -> str: """Visit prepare operation.""" + # Default `PZ` stays byte-identical (no churn for the existing + # corpus); a non-PZ basis is shown so the dump is faithful. + b = "" if node.basis == "PZ" else f", basis={node.basis!r}" if node.slots is None: - return f"{node.allocator}.prepare_all()" + return f"{node.allocator}.prepare_all({node.basis!r})" if b else f"{node.allocator}.prepare_all()" if len(node.slots) == 1: - return f"{node.allocator}.prepare({node.slots[0]})" + return f"{node.allocator}.prepare({node.slots[0]}{b})" slots = ", ".join(str(s) for s in node.slots) - return f"{node.allocator}.prepare({slots})" + return f"{node.allocator}.prepare({slots}{b})" def visit_measure(self, node: MeasureOp) -> str: """Visit measure operation.""" @@ -245,6 +267,11 @@ def visit_permute(self, node: PermuteOp) -> str: targets = ", ".join(node.targets) return f"Permute([{sources}], [{targets}])" + def visit_print(self, node: PrintOp) -> str: + """Visit print.""" + value = node.value if isinstance(node.value, str) else f"{node.value.register}[{node.value.index}]" + return f'Print({value}, tag="{node.tag}", namespace="{node.namespace}")' + # Control flow def visit_if(self, node: IfStmt) -> str: @@ -316,6 +343,69 @@ def visit_parallel(self, node: ParallelBlock) -> str: lines.append(")") return "\n".join(lines) + # Reusable blocks + + def visit_block_decl(self, node: BlockDecl) -> str: + """Visit a BlockDecl, rendering it as a reusable function-like declaration.""" + inputs_str = ", ".join(self.visit_block_input(inp) for inp in node.inputs) + lines = [f'BlockDecl("{node.name}", inputs=[{inputs_str}]):'] + + self._level += 1 + if not node.body and node.return_op is None: + lines.append(self._indented("pass")) + else: + lines.extend(self._indented(f"{self.format_statement(stmt)},") for stmt in node.body) + if node.return_op is not None: + lines.append(self._indented(f"{self.format_statement(node.return_op)},")) + self._level -= 1 + return "\n".join(lines) + + def visit_block_input(self, node: BlockInput) -> str: + """Visit a BlockInput, rendering it as `name: type @ effect`.""" + return f"{node.name}: {self._format_type_expr(node.type_expr)} @ {node.effect.name.lower()}" + + def visit_block_call(self, node: BlockCall) -> str: + """Visit a BlockCall as a parenthesized invocation site.""" + args = ", ".join(self._format_block_arg(a) for a in node.arg_bindings) + if node.out_bindings: + outs = ", ".join(self._format_block_arg(a) for a in node.out_bindings) + return f"({outs}) = BlockCall({node.callee!r}, {args})" + return f"BlockCall({node.callee!r}, {args})" + + def _format_block_arg(self, arg: BlockArg) -> str: + """Render a BlockArg inline (for visit_block_call).""" + from pecos.slr.ast.nodes import ( # noqa: PLC0415 -- subclass dispatch + AllocatorArg, + BitBundleArg, + QubitBundleArg, + SingleBitArg, + SingleQubitArg, + ) + + if isinstance(arg, AllocatorArg): + return arg.name + if isinstance(arg, SingleQubitArg): + return self.visit_slot_ref(arg.slot) + if isinstance(arg, SingleBitArg): + return self.visit_bit_ref(arg.bit) + if isinstance(arg, QubitBundleArg): + inner = ", ".join(self.visit_slot_ref(s) for s in arg.slots) + return f"[{inner}]" + if isinstance(arg, BitBundleArg): + inner = ", ".join(self.visit_bit_ref(b) for b in arg.bits) + return f"[{inner}]" + return repr(arg) + + def _format_type_expr(self, type_expr: TypeExpr) -> str: + """Render a TypeExpr inline (BlockInput rendering uses this).""" + if isinstance(type_expr, ArrayTypeExpr): + return f"array[{self._format_type_expr(type_expr.element)}, {type_expr.size}]" + if isinstance(type_expr, QubitTypeExpr): + return "qubit" + if isinstance(type_expr, BitTypeExpr): + return "bit" + return self.visit(type_expr) + # References def visit_slot_ref(self, node: SlotRef) -> str: @@ -330,10 +420,15 @@ def visit_bit_ref(self, node: BitRef) -> str: def format_expression(self, expr: Expression) -> str: """Format an expression.""" - return expr.accept(self) + return self.visit(expr) def visit_literal(self, node: LiteralExpr) -> str: """Visit literal expression.""" + from pecos.slr.angle import Angle # noqa: PLC0415 (avoid import cycle) + + if isinstance(node.value, Angle): + # Round-trip the source unit: `rad(0.5)` / `turns(0.25)`. + return node.value.slr_repr() if isinstance(node.value, bool): return "True" if node.value else "False" if isinstance(node.value, float): @@ -376,7 +471,7 @@ def visit_bit_type(self, _node: object) -> str: def visit_array_type(self, node) -> str: """Visit array type.""" - element = node.element.accept(self) + element = self.visit(node.element) return f"Array[{element}, {node.size}]" def visit_allocator_type(self, node) -> str: diff --git a/python/quantum-pecos/src/pecos/slr/ast/serialize.py b/python/quantum-pecos/src/pecos/slr/ast/serialize.py index 03dcd7212..1adddab0f 100644 --- a/python/quantum-pecos/src/pecos/slr/ast/serialize.py +++ b/python/quantum-pecos/src/pecos/slr/ast/serialize.py @@ -32,7 +32,11 @@ from dataclasses import fields, is_dataclass from typing import TYPE_CHECKING, Any +from pecos_rslib import angle64 + +from pecos.slr.angle import Angle from pecos.slr.ast.nodes import ( + AllocatorArg, AllocatorDecl, AllocatorTypeExpr, ArrayTypeExpr, @@ -40,9 +44,13 @@ BarrierOp, BinaryExpr, BinaryOp, + BitBundleArg, BitExpr, BitRef, BitTypeExpr, + BlockCall, + BlockDecl, + BlockInput, CommentOp, ForStmt, GateKind, @@ -53,11 +61,16 @@ ParallelBlock, PermuteOp, PrepareOp, + PrintOp, Program, + QubitBundleArg, QubitTypeExpr, RegisterDecl, RepeatStmt, + ResourceEffect, ReturnOp, + SingleBitArg, + SingleQubitArg, SlotRef, SourceLocation, UnaryExpr, @@ -82,6 +95,14 @@ "BitExpr": BitExpr, "BitRef": BitRef, "BitTypeExpr": BitTypeExpr, + "AllocatorArg": AllocatorArg, + "BitBundleArg": BitBundleArg, + "BlockCall": BlockCall, + "BlockDecl": BlockDecl, + "BlockInput": BlockInput, + "QubitBundleArg": QubitBundleArg, + "SingleBitArg": SingleBitArg, + "SingleQubitArg": SingleQubitArg, "CommentOp": CommentOp, "ForStmt": ForStmt, "GateOp": GateOp, @@ -91,6 +112,7 @@ "ParallelBlock": ParallelBlock, "PermuteOp": PermuteOp, "PrepareOp": PrepareOp, + "PrintOp": PrintOp, "Program": Program, "QubitTypeExpr": QubitTypeExpr, "RegisterDecl": RegisterDecl, @@ -108,6 +130,7 @@ "GateKind": GateKind, "BinaryOp": BinaryOp, "UnaryOp": UnaryOp, + "ResourceEffect": ResourceEffect, } @@ -139,8 +162,14 @@ def _serialize_value(value: Any) -> Any: return None if isinstance(value, (int, float, bool, str)): return value - if isinstance(value, (GateKind, BinaryOp, UnaryOp)): + if isinstance(value, (GateKind, BinaryOp, UnaryOp, ResourceEffect)): return {"_enum": type(value).__name__, "value": value.name} + if isinstance(value, Angle): + # Encode by the exact fixed-point fraction + source unit so the + # angle64 value round-trips losslessly and pretty-print keeps the + # unit label. (Must precede the generic dataclass branch below -- + # Angle is a dataclass but its `angle64` field is not serializable.) + return {"_angle": {"fraction": value.value.fraction, "unit": value.source_unit}} if isinstance(value, tuple): return [_serialize_value(v) for v in value] if isinstance(value, list): @@ -216,6 +245,10 @@ def _deserialize_value(value: Any, field_name: str, field_info: dict) -> Any: msg = f"Unknown enum type: {value['_enum']}" raise ValueError(msg) return enum_class[value["value"]] + if "_angle" in value: + # Typed rotation angle: rebuild from the fixed-point fraction. + a = value["_angle"] + return Angle(angle64(a["fraction"]), a["unit"]) if "_type" in value: # Nested AST node return dict_to_ast(value) diff --git a/python/quantum-pecos/src/pecos/slr/ast/validation/type_checker.py b/python/quantum-pecos/src/pecos/slr/ast/validation/type_checker.py index 727b4ddd8..e1f86bdf3 100644 --- a/python/quantum-pecos/src/pecos/slr/ast/validation/type_checker.py +++ b/python/quantum-pecos/src/pecos/slr/ast/validation/type_checker.py @@ -73,8 +73,6 @@ GateKind.Y: 1, GateKind.Z: 1, GateKind.H: 1, - GateKind.S: 1, - GateKind.Sdg: 1, GateKind.T: 1, GateKind.Tdg: 1, GateKind.SX: 1, @@ -211,12 +209,14 @@ def _validate_gate(self, node: GateOp) -> None: ) def _validate_numeric_expression(self, expr: Expression, context: str) -> None: - """Validate that an expression is numeric.""" + """Validate that an expression is numeric (or a typed angle).""" + from pecos.slr.angle import Angle # noqa: PLC0415 (avoid import cycle) + if isinstance(expr, LiteralExpr): - if not isinstance(expr.value, int | float): + if not isinstance(expr.value, int | float | Angle): self.errors.append( ValidationError( - message=f"Expected numeric value for {context}, got {type(expr.value).__name__}", + message=f"Expected numeric value or Angle for {context}, got {type(expr.value).__name__}", location=expr.location, severity=Severity.ERROR, code="E203", diff --git a/python/quantum-pecos/src/pecos/slr/ast/visitor.py b/python/quantum-pecos/src/pecos/slr/ast/visitor.py index c1290114f..edf4ed7e5 100644 --- a/python/quantum-pecos/src/pecos/slr/ast/visitor.py +++ b/python/quantum-pecos/src/pecos/slr/ast/visitor.py @@ -24,15 +24,20 @@ if TYPE_CHECKING: from pecos.slr.ast.nodes import ( + AllocatorArg, AllocatorDecl, AllocatorTypeExpr, ArrayTypeExpr, AssignOp, BarrierOp, BinaryExpr, + BitBundleArg, BitExpr, BitRef, BitTypeExpr, + BlockCall, + BlockDecl, + BlockInput, CommentOp, ForStmt, GateOp, @@ -42,11 +47,15 @@ ParallelBlock, PermuteOp, PrepareOp, + PrintOp, Program, + QubitBundleArg, QubitTypeExpr, RegisterDecl, RepeatStmt, ReturnOp, + SingleBitArg, + SingleQubitArg, SlotRef, UnaryExpr, VarExpr, @@ -56,6 +65,51 @@ T = TypeVar("T") T_co = TypeVar("T_co", covariant=True) +# Node-class-name -> BaseVisitor method name. This is the single source of +# truth for visitor dispatch (replaces the per-node `accept()` double +# dispatch). Keyed by `type(node).__name__` so no runtime import of the +# node classes is needed (avoids an import cycle with nodes.py). +# `test_ast_visitor` asserts every concrete AstNode subclass is registered +# here, so a new node without an entry fails loudly in tests. +_DISPATCH: dict[str, str] = { + "Program": "visit_program", + "AllocatorDecl": "visit_allocator_decl", + "RegisterDecl": "visit_register_decl", + "GateOp": "visit_gate", + "PrepareOp": "visit_prepare", + "MeasureOp": "visit_measure", + "AssignOp": "visit_assign", + "BarrierOp": "visit_barrier", + "CommentOp": "visit_comment", + "ReturnOp": "visit_return", + "PrintOp": "visit_print", + "PermuteOp": "visit_permute", + "IfStmt": "visit_if", + "WhileStmt": "visit_while", + "ForStmt": "visit_for", + "RepeatStmt": "visit_repeat", + "ParallelBlock": "visit_parallel", + "BlockInput": "visit_block_input", + "BlockDecl": "visit_block_decl", + "BlockCall": "visit_block_call", + "AllocatorArg": "visit_allocator_arg", + "SingleQubitArg": "visit_single_qubit_arg", + "SingleBitArg": "visit_single_bit_arg", + "QubitBundleArg": "visit_qubit_bundle_arg", + "BitBundleArg": "visit_bit_bundle_arg", + "SlotRef": "visit_slot_ref", + "BitRef": "visit_bit_ref", + "LiteralExpr": "visit_literal", + "VarExpr": "visit_var", + "BitExpr": "visit_bit_expr", + "BinaryExpr": "visit_binary", + "UnaryExpr": "visit_unary", + "QubitTypeExpr": "visit_qubit_type", + "BitTypeExpr": "visit_bit_type", + "ArrayTypeExpr": "visit_array_type", + "AllocatorTypeExpr": "visit_allocator_type", +} + class AstVisitor(Protocol[T_co]): """Protocol defining the visitor interface for AST nodes. @@ -88,6 +142,26 @@ def visit_return(self, node: ReturnOp) -> T_co: ... def visit_permute(self, node: PermuteOp) -> T_co: ... + def visit_print(self, node: PrintOp) -> T_co: ... + + # Reusable blocks + def visit_block_decl(self, node: BlockDecl) -> T_co: ... + + def visit_block_input(self, node: BlockInput) -> T_co: ... + + def visit_block_call(self, node: BlockCall) -> T_co: ... + + # BlockCall argument bindings + def visit_allocator_arg(self, node: AllocatorArg) -> T_co: ... + + def visit_single_qubit_arg(self, node: SingleQubitArg) -> T_co: ... + + def visit_single_bit_arg(self, node: SingleBitArg) -> T_co: ... + + def visit_qubit_bundle_arg(self, node: QubitBundleArg) -> T_co: ... + + def visit_bit_bundle_arg(self, node: BitBundleArg) -> T_co: ... + # Control flow def visit_if(self, node: IfStmt) -> T_co: ... @@ -144,8 +218,29 @@ def combine_results(self, results: list[str]) -> str: """ def visit(self, node) -> T: - """Dispatch to the appropriate visit method.""" - return node.accept(self) + """Dispatch to the appropriate visit method by node type. + + Centralized dispatch (replaces per-node `accept()` double + dispatch): nodes carry no visitor coupling. `_DISPATCH` maps a + node class name to its `visit_*` method; the lookup is + late-bound via `getattr` so subclass overrides still apply. + + Resolution walks `type(node).__mro__` and uses the first + registered ancestor. This preserves the old inherited-`accept()` + semantics: a user subclass of a concrete node (e.g. + `class MyGate(GateOp)`) inherited `GateOp.accept` and dispatched + to `visit_gate`; MRO lookup reproduces that exactly. + """ + for cls in type(node).__mro__: + method = _DISPATCH.get(cls.__name__) + if method is not None: + return getattr(self, method)(node) + msg = ( + f"BaseVisitor: no visit method registered for AST node " + f"{type(node).__name__!r} (add it to _DISPATCH in " + f"pecos.slr.ast.visitor)" + ) + raise TypeError(msg) def visit_children(self, node) -> list[T]: """Visit all children and collect results.""" @@ -202,6 +297,45 @@ def visit_return(self, node: ReturnOp) -> T: def visit_permute(self, _node: PermuteOp) -> T: return self.default_result() + def visit_print(self, node: PrintOp) -> T: + results = self.visit_children(node) + return self.combine_results(results) + + # Reusable blocks + + def visit_block_decl(self, node: BlockDecl) -> T: + results = self.visit_children(node) + return self.combine_results(results) + + def visit_block_input(self, node: BlockInput) -> T: + results = self.visit_children(node) + return self.combine_results(results) + + def visit_block_call(self, node: BlockCall) -> T: + results = self.visit_children(node) + return self.combine_results(results) + + # BlockCall argument bindings + + def visit_allocator_arg(self, _node: AllocatorArg) -> T: + return self.default_result() + + def visit_single_qubit_arg(self, node: SingleQubitArg) -> T: + results = self.visit_children(node) + return self.combine_results(results) + + def visit_single_bit_arg(self, node: SingleBitArg) -> T: + results = self.visit_children(node) + return self.combine_results(results) + + def visit_qubit_bundle_arg(self, node: QubitBundleArg) -> T: + results = self.visit_children(node) + return self.combine_results(results) + + def visit_bit_bundle_arg(self, node: BitBundleArg) -> T: + results = self.visit_children(node) + return self.combine_results(results) + # Control flow def visit_if(self, node: IfStmt) -> T: diff --git a/python/quantum-pecos/src/pecos/slr/converters/from_quantum_circuit.py b/python/quantum-pecos/src/pecos/slr/converters/from_quantum_circuit.py index 0c8512808..4a44e89c3 100644 --- a/python/quantum-pecos/src/pecos/slr/converters/from_quantum_circuit.py +++ b/python/quantum-pecos/src/pecos/slr/converters/from_quantum_circuit.py @@ -277,18 +277,18 @@ def _convert_gate_set(gate_symbol, locations, q, c, measurement_offset): elif gate_upper in ["R", "RZ", "RESET"]: for loc in locations: if isinstance(loc, int): - ops.append(qubit.Prep(q[loc])) + ops.append(qubit.PZ(q[loc])) elif isinstance(loc, tuple) and len(loc) == 1: - ops.append(qubit.Prep(q[loc[0]])) + ops.append(qubit.PZ(q[loc[0]])) elif gate_upper == "RX": for loc in locations: if isinstance(loc, int): - ops.append(qubit.Prep(q[loc])) + ops.append(qubit.PZ(q[loc])) ops.append(qubit.H(q[loc])) elif gate_upper == "RY": for loc in locations: if isinstance(loc, int): - ops.append(qubit.Prep(q[loc])) + ops.append(qubit.PZ(q[loc])) ops.append(qubit.H(q[loc])) ops.append(qubit.SZ(q[loc])) else: diff --git a/python/quantum-pecos/src/pecos/slr/converters/from_stim.py b/python/quantum-pecos/src/pecos/slr/converters/from_stim.py index edc1127ff..8ecac2a32 100644 --- a/python/quantum-pecos/src/pecos/slr/converters/from_stim.py +++ b/python/quantum-pecos/src/pecos/slr/converters/from_stim.py @@ -217,11 +217,11 @@ def _map_gate(gate_name, targets, args, q, c, measurement_offset): ops.append(qubit.Measure(q[idx]) > c[measurement_offset + i]) elif gate_name in ["R", "RZ"]: # Reset - ops.extend(qubit.Prep(q[idx]) for idx in qubit_targets) + ops.extend(qubit.PZ(q[idx]) for idx in qubit_targets) elif gate_name in ["RX", "RY"]: # Reset in X or Y basis for idx in qubit_targets: - ops.append(qubit.Prep(q[idx])) + ops.append(qubit.PZ(q[idx])) if gate_name == "RX": ops.append(qubit.H(q[idx])) else: # RY diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/__init__.py b/python/quantum-pecos/src/pecos/slr/gen_codes/__init__.py index d2c013507..faa19035d 100644 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/__init__.py +++ b/python/quantum-pecos/src/pecos/slr/gen_codes/__init__.py @@ -34,14 +34,12 @@ For AST-based code generation, see :mod:`pecos.slr.ast.codegen`. """ -from pecos.slr.gen_codes.gen_guppy import GuppyGenerator from pecos.slr.gen_codes.gen_qasm import QASMGenerator from pecos.slr.gen_codes.gen_qir import QIRGenerator from pecos.slr.gen_codes.gen_quantum_circuit import QuantumCircuitGenerator from pecos.slr.gen_codes.gen_stim import StimGenerator __all__ = [ - "GuppyGenerator", "QASMGenerator", "QIRGenerator", "QuantumCircuitGenerator", diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/gen_guppy.py b/python/quantum-pecos/src/pecos/slr/gen_codes/gen_guppy.py deleted file mode 100644 index 1d7721d71..000000000 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/gen_guppy.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Guppy code generation for SLR programs. - -This module provides the entry point for Guppy code generation. -The actual implementation is in the guppy/ subdirectory. -""" - -from pecos.slr.gen_codes.guppy import IRGuppyGenerator - -# Alias for convenience -GuppyGenerator = IRGuppyGenerator - -__all__ = ["GuppyGenerator", "IRGuppyGenerator"] diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/gen_qasm.py b/python/quantum-pecos/src/pecos/slr/gen_codes/gen_qasm.py index fdb295297..a04ba6ee2 100644 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/gen_qasm.py +++ b/python/quantum-pecos/src/pecos/slr/gen_codes/gen_qasm.py @@ -17,6 +17,15 @@ from pecos.slr.gen_codes.generator import Generator +def _param_qasm(p: object) -> str: + """Render a gate param for QASM: a typed `Angle` -> signed radians.""" + from pecos.slr.angle import Angle + + if isinstance(p, Angle): + return str(p.value.to_radians_signed()) + return str(p) + + class QASMGenerator(Generator): """Generate QASM code from SLR programs. @@ -645,7 +654,7 @@ def process_qgate(self, op): ], ) - case "Prep": + case "PZ": op_str = self.qgate_sq_qasm(op, "reset") case "T": @@ -686,7 +695,7 @@ def qgate_sq_qasm(self, op, repr_str: str | None = None): repr_str = op.sym.lower() if op.params: - str_cargs = ", ".join([str(p) for p in op.params]) + str_cargs = ", ".join([_param_qasm(p) for p in op.params]) repr_str = f"{repr_str}({str_cargs})" str_list = [] @@ -721,7 +730,7 @@ def qgate_tq_qasm(self, op, repr_str: str | None = None): repr_str = op.sym.lower() if op.params: - str_cargs = ",".join([str(p) for p in op.params]) + str_cargs = ",".join([_param_qasm(p) for p in op.params]) repr_str = f"{repr_str}({str_cargs})" str_list = [] diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/gen_qir.py b/python/quantum-pecos/src/pecos/slr/gen_codes/gen_qir.py index 8fa30eca0..d67cd7747 100644 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/gen_qir.py +++ b/python/quantum-pecos/src/pecos/slr/gen_codes/gen_qir.py @@ -18,7 +18,7 @@ from pecos_rslib_llvm import binding, ir -from pecos.slr import Block, If, Repeat +from pecos.slr import Angle, Block, If, Repeat from pecos.slr.cops import ( NEG, NOT, @@ -299,7 +299,9 @@ def create_creg(self, creg: CReg): [ir.Constant(ir.IntType(64), creg.size)], f"{creg.sym}", ), - creg.result, + # The per-register result flag was removed; all user cregs + # are recorded (internal scratch temps still pass False below). + True, ) def create_qreg(self, qreg: QReg): @@ -320,7 +322,7 @@ def _generate_results(self) -> None: """Generates the proper results calls at the end of the SLR program, according to all the classical registers that were defined.""" for reg_name, (reg_inst, result) in self._creg_dict.items(): - if not result: # ignore non-result cregs + if not result: # skip internal scratch temps (all user cregs are recorded post-3b) continue # add global tag for each CReg reg_name_bytes = bytearray(reg_name.encode("utf-8")) @@ -804,7 +806,13 @@ def _create_qgate_call(self, gate: qgate_base.QGate) -> None: gate_declaration = self._gate_declaration_cache[gate.sym] gate_args = [] if gate.has_parameters: - gate_args = [ir.Constant(self._types.double_type, param) for param in gate.params] + gate_args = [ + ir.Constant( + self._types.double_type, + param.value.to_radians_signed() if isinstance(param, Angle) else float(param), + ) + for param in gate.params + ] gate_args.extend([self._qarg_to_qubit_ptr(qarg) for qarg in qargs]) # Create the actual invocation on the builder using the args passed in diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/gen_quantum_circuit.py b/python/quantum-pecos/src/pecos/slr/gen_codes/gen_quantum_circuit.py index e5deb6e60..5cbdb1961 100644 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/gen_quantum_circuit.py +++ b/python/quantum-pecos/src/pecos/slr/gen_codes/gen_quantum_circuit.py @@ -322,7 +322,7 @@ def _handle_quantum_op(self, op): "CY": "CY", "CZ": "CZ", "Measure": "Measure", - "Prep": "RESET", + "PZ": "RESET", "RX": "RX", "RY": "RY", "RZ": "RZ", diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/gen_stim.py b/python/quantum-pecos/src/pecos/slr/gen_codes/gen_stim.py index bded9112b..98d6d759a 100644 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/gen_stim.py +++ b/python/quantum-pecos/src/pecos/slr/gen_codes/gen_stim.py @@ -250,7 +250,7 @@ def _handle_quantum_op(self, op): elif op_class == "Measure": self.circuit.append_operation("M", qubits) self.measurement_count += len(qubits) - elif op_class == "Prep": + elif op_class == "PZ": self.circuit.append_operation("R", qubits) elif op_class in ["RX", "RY", "RZ"]: # Rotation gates - add as parameterized gates if supported diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/__init__.py b/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/__init__.py deleted file mode 100644 index d2668b286..000000000 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Guppy code generation package for SLR programs.""" - -from pecos.slr.gen_codes.guppy.ir_generator import IRGuppyGenerator - -__all__ = ["IRGuppyGenerator"] diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/data_flow.py b/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/data_flow.py deleted file mode 100644 index 8673d750f..000000000 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/data_flow.py +++ /dev/null @@ -1,379 +0,0 @@ -"""Data flow analysis for SLR to Guppy code generation. - -This module provides data flow analysis to track how quantum and classical values -flow through a program, particularly tracking measurement results and their usage. - -The key insight is that we need to distinguish between: -1. Operations BEFORE measurement (don't require unpacking) -2. Operations AFTER measurement that use the SAME qubit (require unpacking for replacement) -3. Operations AFTER measurement that use DIFFERENT qubits (don't require unpacking) - -Current heuristics over-approximate by treating ANY operation after ANY measurement -as requiring unpacking, leading to unnecessary array unpacking. -""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from pecos.slr import Block as SLRBlock - - -@dataclass -class ValueUse: - """Represents a use of a value (qubit or classical bit).""" - - array_name: str - index: int - position: int # Position in operation sequence - operation_type: str # e.g., "gate", "measurement", "condition" - is_consuming: bool = False # True if this use consumes the value (e.g., measurement) - - -@dataclass -class DataFlowInfo: - """Information about data flow for a single array element.""" - - array_name: str - index: int - is_classical: bool - - # Track all uses of this element - uses: list[ValueUse] = field(default_factory=list) - - # Track consumption points (measurements) - consumed_at: list[int] = field(default_factory=list) - - # Track replacements (e.g., Prep after measurement) - replaced_at: list[int] = field(default_factory=list) - - def add_use( - self, - position: int, - operation_type: str, - *, - is_consuming: bool = False, - ) -> None: - """Add a use of this value.""" - use = ValueUse( - array_name=self.array_name, - index=self.index, - position=position, - operation_type=operation_type, - is_consuming=is_consuming, - ) - self.uses.append(use) - - if is_consuming: - self.consumed_at.append(position) - - def add_replacement(self, position: int) -> None: - """Mark that this value is replaced at a position (e.g., Prep).""" - self.replaced_at.append(position) - - def has_use_after_consumption(self) -> bool: - """Check if this element is used after being consumed. - - This is the key analysis for determining if unpacking is needed. - If a qubit is measured and then used again (not just replaced), - we need unpacking to handle the replacement properly. - """ - if not self.consumed_at: - return False - - # Find the first consumption point - first_consumption = min(self.consumed_at) - - # Check if there are any non-replacement uses after consumption - for use in self.uses: - if use.position > first_consumption and use.position not in self.replaced_at: - # This is a real use after consumption, not just replacement - # However, we need to check if it's AFTER replacement - # Find if there's a replacement between consumption and this use - replacements_between = [r for r in self.replaced_at if first_consumption < r < use.position] - - if not replacements_between: - # Use after consumption with no replacement in between - # This requires unpacking - return True - - return False - - def requires_unpacking_for_flow(self) -> bool: - """Determine if this element requires unpacking based on data flow. - - This is more precise than the heuristic approach: - - Classical values can be used multiple times without issue - - Quantum values can be used multiple times if not measured - - Quantum values that are measured and then used require unpacking - """ - if self.is_classical: - # Classical values can be read multiple times - return False - - # Quantum values: check if used after consumption - return self.has_use_after_consumption() - - -@dataclass -class DataFlowAnalysis: - """Complete data flow analysis for a block.""" - - # Map from (array_name, index) to DataFlowInfo - element_flows: dict[tuple[str, int], DataFlowInfo] = field(default_factory=dict) - - # Track conditionally accessed elements - conditional_accesses: set[tuple[str, int]] = field(default_factory=set) - - def get_or_create_flow( - self, - array_name: str, - index: int, - is_classical: bool, - ) -> DataFlowInfo: - """Get or create data flow info for an array element.""" - key = (array_name, index) - if key not in self.element_flows: - self.element_flows[key] = DataFlowInfo( - array_name=array_name, - index=index, - is_classical=is_classical, - ) - return self.element_flows[key] - - def add_gate_use(self, array_name: str, index: int, position: int) -> None: - """Record a gate operation on an array element.""" - flow = self.get_or_create_flow(array_name, index, is_classical=False) - flow.add_use(position, "gate", is_consuming=False) - - def add_measurement( - self, - quantum_array: str, - quantum_index: int, - position: int, - classical_array: str | None = None, - classical_index: int | None = None, - ) -> None: - """Record a measurement operation.""" - # Quantum side: consumption - q_flow = self.get_or_create_flow( - quantum_array, - quantum_index, - is_classical=False, - ) - q_flow.add_use(position, "measurement", is_consuming=True) - - # Classical side: creation (if specified) - if classical_array is not None and classical_index is not None: - c_flow = self.get_or_create_flow( - classical_array, - classical_index, - is_classical=True, - ) - c_flow.add_use(position, "measurement_result", is_consuming=False) - - def add_preparation(self, array_name: str, index: int, position: int) -> None: - """Record a preparation/reset operation (replaces a qubit).""" - flow = self.get_or_create_flow(array_name, index, is_classical=False) - flow.add_use(position, "preparation", is_consuming=False) - flow.add_replacement(position) - - def add_conditional_use( - self, - array_name: str, - index: int, - position: int, - is_classical: bool, - ) -> None: - """Record a conditional use of an array element.""" - flow = self.get_or_create_flow(array_name, index, is_classical) - flow.add_use(position, "condition", is_consuming=False) - self.conditional_accesses.add((array_name, index)) - - def elements_requiring_unpacking(self) -> set[tuple[str, int]]: - """Get the set of array elements that require unpacking based on data flow.""" - requiring_unpacking = set() - - for key, flow in self.element_flows.items(): - if flow.requires_unpacking_for_flow(): - requiring_unpacking.add(key) - - return requiring_unpacking - - def array_requires_unpacking(self, array_name: str) -> bool: - """Check if an entire array requires unpacking based on data flow.""" - for key, flow in self.element_flows.items(): - if key[0] == array_name and flow.requires_unpacking_for_flow(): - return True - return False - - -class DataFlowAnalyzer: - """Analyzes data flow in SLR blocks.""" - - def __init__(self): - self.position_counter = 0 - self.in_conditional = False - - def analyze( - self, - block: SLRBlock, - variable_context: dict[str, Any], - ) -> DataFlowAnalysis: - """Analyze data flow in a block. - - Args: - block: The SLR block to analyze - variable_context: Context of variables (QReg, CReg, etc.) - - Returns: - DataFlowAnalysis containing all data flow information - """ - analysis = DataFlowAnalysis() - self.position_counter = 0 - self.in_conditional = False - - # Analyze all operations - if hasattr(block, "ops"): - for op in block.ops: - self._analyze_operation(op, analysis, variable_context) - self.position_counter += 1 - - return analysis - - def _analyze_operation( - self, - op: Any, - analysis: DataFlowAnalysis, - variable_context: dict[str, Any], - ) -> None: - """Analyze a single operation.""" - op_type = type(op).__name__ - - if op_type == "Measure": - self._analyze_measurement(op, analysis) - elif op_type == "If": - self._analyze_if_block(op, analysis, variable_context) - elif hasattr(op, "qargs"): - # Check if this is a preparation operation - if self._is_preparation(op): - self._analyze_preparation(op, analysis) - else: - self._analyze_quantum_operation(op, analysis) - elif hasattr(op, "ops"): - # Nested block - recurse - for nested_op in op.ops: - self._analyze_operation(nested_op, analysis, variable_context) - - def _is_preparation(self, op: Any) -> bool: - """Check if an operation is a preparation/reset.""" - op_name = type(op).__name__ - return op_name in ["Prep", "Init", "Reset"] - - def _analyze_measurement(self, meas: Any, analysis: DataFlowAnalysis) -> None: - """Analyze a measurement operation.""" - # Get classical targets - classical_targets = [] - if hasattr(meas, "cout") and meas.cout: - classical_targets.extend( - (cout.reg.sym, cout.index) - for cout in meas.cout - if hasattr(cout, "reg") and hasattr(cout.reg, "sym") and hasattr(cout, "index") - ) - - # Analyze quantum sources - if hasattr(meas, "qargs") and meas.qargs: - for i, qarg in enumerate(meas.qargs): - # Individual element measurement - if hasattr(qarg, "reg") and hasattr(qarg.reg, "sym") and hasattr(qarg, "index"): - array_name = qarg.reg.sym - index = qarg.index - - # Get corresponding classical target if exists - classical_array = None - classical_index = None - if i < len(classical_targets): - classical_array, classical_index = classical_targets[i] - - analysis.add_measurement( - quantum_array=array_name, - quantum_index=index, - position=self.position_counter, - classical_array=classical_array, - classical_index=classical_index, - ) - - def _analyze_preparation(self, op: Any, analysis: DataFlowAnalysis) -> None: - """Analyze a preparation/reset operation.""" - if hasattr(op, "qargs") and op.qargs: - for qarg in op.qargs: - if hasattr(qarg, "reg") and hasattr(qarg.reg, "sym") and hasattr(qarg, "index"): - array_name = qarg.reg.sym - index = qarg.index - analysis.add_preparation(array_name, index, self.position_counter) - - def _analyze_quantum_operation(self, op: Any, analysis: DataFlowAnalysis) -> None: - """Analyze a quantum gate operation.""" - if hasattr(op, "qargs") and op.qargs: - for qarg in op.qargs: - if hasattr(qarg, "reg") and hasattr(qarg.reg, "sym") and hasattr(qarg, "index"): - array_name = qarg.reg.sym - index = qarg.index - - if self.in_conditional: - analysis.add_conditional_use( - array_name, - index, - self.position_counter, - is_classical=False, - ) - else: - analysis.add_gate_use(array_name, index, self.position_counter) - - def _analyze_if_block( - self, - if_block: Any, - analysis: DataFlowAnalysis, - variable_context: dict[str, Any], - ) -> None: - """Analyze an if block.""" - prev_conditional = self.in_conditional - self.in_conditional = True - - # Analyze condition - if hasattr(if_block, "cond"): - self._analyze_condition(if_block.cond, analysis) - - # Analyze then block - if hasattr(if_block, "ops"): - for op in if_block.ops: - self._analyze_operation(op, analysis, variable_context) - - # Analyze else block - if hasattr(if_block, "else_block") and if_block.else_block and hasattr(if_block.else_block, "ops"): - for op in if_block.else_block.ops: - self._analyze_operation(op, analysis, variable_context) - - self.in_conditional = prev_conditional - - def _analyze_condition(self, cond: Any, analysis: DataFlowAnalysis) -> None: - """Analyze a condition expression.""" - cond_type = type(cond).__name__ - - if cond_type == "Bit" and hasattr(cond, "reg") and hasattr(cond.reg, "sym") and hasattr(cond, "index"): - array_name = cond.reg.sym - index = cond.index - analysis.add_conditional_use( - array_name, - index, - self.position_counter, - is_classical=True, - ) - - # Handle compound conditions - if hasattr(cond, "left"): - self._analyze_condition(cond.left, analysis) - if hasattr(cond, "right"): - self._analyze_condition(cond.right, analysis) diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/dependency_analyzer.py b/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/dependency_analyzer.py deleted file mode 100644 index bbd77521a..000000000 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/dependency_analyzer.py +++ /dev/null @@ -1,210 +0,0 @@ -"""Dependency analyzer for SLR blocks.""" - -from __future__ import annotations - -import inspect -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from pecos.slr import Block - - -@dataclass -class BlockDependency: - """Represents dependencies for a block.""" - - block_type: str - constructor_params: dict[str, Any] # parameter name -> value/type - used_variables: set[str] # Set of variable names used in operations - nested_blocks: list[BlockDependency] # Dependencies of nested blocks - - -class DependencyAnalyzer: - """Analyzes SLR blocks to determine parameter dependencies.""" - - def __init__(self): - self.analyzed_blocks = {} # Cache of analyzed block types - - def analyze_block(self, block: Block) -> BlockDependency: - """Analyze a block to determine its dependencies.""" - block_type = type(block).__name__ - - # Get constructor parameters - constructor_params = self._get_constructor_params(block) - - # Find used variables in operations - used_variables = set() - nested_blocks = [] - - if hasattr(block, "ops"): - for op in block.ops: - # Collect variables from operations - self._collect_variables_from_op(op, used_variables) - - # If it's a nested block, analyze it too - if hasattr(op, "ops") and hasattr(op, "vars"): - nested_dep = self.analyze_block(op) - nested_blocks.append(nested_dep) - # Add nested block's used variables - used_variables.update(nested_dep.used_variables) - - return BlockDependency( - block_type=block_type, - constructor_params=constructor_params, - used_variables=used_variables, - nested_blocks=nested_blocks, - ) - - def _get_constructor_params(self, block: Block) -> dict[str, Any]: - """Extract constructor parameters from a block instance.""" - params = {} - - # Get the constructor signature - sig = inspect.signature(type(block).__init__) - - # Try to match parameters with instance attributes - for param_name in sig.parameters: - if param_name == "self": - continue - - # Common patterns for how parameters are stored - if hasattr(block, param_name): - params[param_name] = getattr(block, param_name) - elif hasattr(block, f"_{param_name}"): - params[param_name] = getattr(block, f"_{param_name}") - # Try to infer from operations - elif param_name in ["data", "qubits", "q"]: - # Look for quantum registers - params[param_name] = self._find_qreg_in_ops(block) - elif param_name in ["ancilla", "a"]: - # Look for ancilla qubits - params[param_name] = self._find_ancilla_in_ops(block) - elif param_name in ["init_bit", "init", "bit", "c"]: - # Look for classical bits - params[param_name] = self._find_bit_in_ops(block) - - return params - - def _collect_variables_from_op(self, op, used_vars: set[str]): - """Collect variable names used in an operation.""" - # Check quantum arguments - if hasattr(op, "qargs"): - for qarg in op.qargs: - if hasattr(qarg, "reg") and hasattr(qarg.reg, "sym"): - used_vars.add(qarg.reg.sym) - elif hasattr(qarg, "sym"): - # Direct QReg object - used_vars.add(qarg.sym) - elif isinstance(qarg, tuple): - # Handle tuples of qubits - for q in qarg: - if hasattr(q, "reg") and hasattr(q.reg, "sym"): - used_vars.add(q.reg.sym) - - # Check classical arguments - if hasattr(op, "cargs"): - for carg in op.cargs: - if hasattr(carg, "reg") and hasattr(carg.reg, "sym"): - used_vars.add(carg.reg.sym) - - # Check output bits (for measurements) - if hasattr(op, "cout") and op.cout: - for cout in op.cout: - if hasattr(cout, "reg") and hasattr(cout.reg, "sym"): - used_vars.add(cout.reg.sym) - elif hasattr(cout, "sym"): - # Direct CReg reference - used_vars.add(cout.sym) - - # Check condition (for If blocks) - if hasattr(op, "cond"): - self._collect_variables_from_expr(op.cond, used_vars) - - # Check expressions (for classical operations) - if hasattr(op, "left"): - self._collect_variables_from_expr(op.left, used_vars) - if hasattr(op, "right"): - self._collect_variables_from_expr(op.right, used_vars) - - def _collect_variables_from_expr(self, expr, used_vars: set[str]): - """Collect variable names from expressions.""" - if hasattr(expr, "reg") and hasattr(expr.reg, "sym"): - used_vars.add(expr.reg.sym) - elif hasattr(expr, "left") and hasattr(expr, "right"): - # Binary operation - self._collect_variables_from_expr(expr.left, used_vars) - self._collect_variables_from_expr(expr.right, used_vars) - elif hasattr(expr, "args"): - # Function call or similar - for arg in expr.args: - self._collect_variables_from_expr(arg, used_vars) - - def _find_qreg_in_ops(self, block): - """Try to find quantum register used in operations.""" - if hasattr(block, "ops"): - for op in block.ops: - if hasattr(op, "qargs") and op.qargs: - qarg = op.qargs[0] - if hasattr(qarg, "reg"): - return qarg.reg - return None - - def _find_ancilla_in_ops(self, block): - """Try to find ancilla qubit used in operations.""" - # Look for single qubit operations that might be ancilla - if hasattr(block, "ops"): - for op in block.ops: - if type(op).__name__ == "Prep" and hasattr(op, "qargs"): - # Prep operations often reset ancillas - for qarg in op.qargs: - if hasattr(qarg, "index") and not hasattr(qarg, "size"): - # Single qubit - return qarg - return None - - def _find_bit_in_ops(self, block): - """Try to find classical bit used in operations.""" - if hasattr(block, "ops"): - for op in block.ops: - if hasattr(op, "cout") and op.cout: - # Measurement output - return op.cout[0] - return None - - def get_required_parameters( - self, - block: Block, - parent_context: dict[str, Any], - ) -> list[tuple[str, str]]: - """Get the parameters required for a block function. - - Args: - block: The block to analyze - parent_context: Dictionary mapping variable names to their types/values - - Returns: - List of (param_name, param_type) tuples - """ - dep = self.analyze_block(block) - - # Collect all used variables - all_used = dep.used_variables.copy() - - # Map to parameter types - params = [] - for var_name in sorted(all_used): - if var_name in parent_context: - var_info = parent_context[var_name] - if hasattr(var_info, "__class__"): - var_type = var_info.__class__.__name__ - if var_type == "QReg": - size = var_info.size if hasattr(var_info, "size") else 1 - params.append((var_name, f"array[quantum.qubit, {size}]")) - elif var_type == "CReg": - size = var_info.size if hasattr(var_info, "size") else 1 - params.append((var_name, f"array[bool, {size}]")) - else: - params.append((var_name, var_type)) - - return params diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/hugr_compiler.py b/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/hugr_compiler.py deleted file mode 100644 index 8fb6f29dd..000000000 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/hugr_compiler.py +++ /dev/null @@ -1,107 +0,0 @@ -"""HUGR compiler for Guppy code generation.""" - -from __future__ import annotations - -import tempfile -from typing import Any - -from pecos.slr.gen_codes.guppy.hugr_error_handler import HugrErrorHandler - - -class HugrCompiler: - """Compiles generated Guppy code to HUGR.""" - - def __init__(self, generator): - """Initialize the HUGR compiler. - - Args: - generator: A generator instance with generated code (must have get_output() method) - """ - self.generator = generator - - def compile_to_hugr(self) -> Any: - """Compile the generated Guppy code to HUGR. - - Returns: - The compiled HUGR module - - Raises: - RuntimeError: If compilation fails - """ - # Get the generated Guppy code - guppy_code = self.generator.get_output() - - # Create a temporary file to hold the generated code - # This is necessary because guppy.compile() needs to be able to inspect the source - with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: - temp_file = f.name - f.write(guppy_code) - - try: - # Import the module from the temporary file - import importlib.util - import linecache - import sys - - # Add the source to linecache for better error tracking - lines = guppy_code.splitlines(keepends=True) - linecache.cache[temp_file] = ( - len(guppy_code), - None, - lines, - temp_file, - ) - - spec = importlib.util.spec_from_file_location("_guppy_generated", temp_file) - if spec is None or spec.loader is None: - msg = "Failed to create module spec" - raise RuntimeError(msg) - - module = importlib.util.module_from_spec(spec) - - # Ensure the module has proper file tracking - module.__file__ = temp_file - - # Add to sys.modules temporarily to help with source tracking - sys.modules["_guppy_generated"] = module - - spec.loader.exec_module(module) - - # Get the main function - if not hasattr(module, "main"): - msg = "No main function found in generated code" - raise RuntimeError(msg) - - main_func = module.main - - # Compile to HUGR - try: - # Debug: print the generated code - # print("DEBUG: Generated Guppy code:") - # print(guppy_code) - # print("="*50) - - # Use the new API: func.compile() instead of guppy.compile(func) - return main_func.compile() - except (AttributeError, TypeError, ValueError, RuntimeError) as e: - # Use the enhanced error handler - error_handler = HugrErrorHandler(guppy_code) - detailed_error = error_handler.analyze_error(e) - raise RuntimeError(detailed_error) - - finally: - # Clean up - try: - # Remove from sys.modules - import sys - - if "_guppy_generated" in sys.modules: - del sys.modules["_guppy_generated"] - - # Clean up the temporary file - from pathlib import Path - - Path(temp_file).unlink() - except (OSError, FileNotFoundError): - # Ignore cleanup errors - pass diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/hugr_error_handler.py b/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/hugr_error_handler.py deleted file mode 100644 index ba8d265fb..000000000 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/hugr_error_handler.py +++ /dev/null @@ -1,207 +0,0 @@ -"""Enhanced error handling for HUGR compilation failures.""" - -from __future__ import annotations - -import re -from dataclasses import dataclass -from typing import ClassVar - - -@dataclass -class ErrorContext: - """Context information for an error.""" - - line_number: int - line_content: str - variable_name: str | None = None - error_type: str | None = None - suggestion: str | None = None - - -class HugrErrorHandler: - """Provides detailed error messages and suggestions for HUGR compilation failures.""" - - # Common error patterns and their explanations - ERROR_PATTERNS: ClassVar[dict[str, dict[str, str]]] = { - r"PlaceNotUsedError.*Variable\(name='(\w+)'": { - "type": "PlaceNotUsedError", - "message": "Quantum register '{var}' was not consumed", - "suggestion": "Add a measurement for this quantum register or ensure it's consumed in all execution paths", - }, - r"NotOwnedError.*Variable\(name='(\w+)'": { - "type": "NotOwnedError", - "message": "Variable '{var}' is not owned in this context", - "suggestion": "Ensure the variable is passed with @owned annotation or is properly borrowed", - }, - r"AlreadyUsedError.*Variable\(name='(\w+)'": { - "type": "AlreadyUsedError", - "message": "Variable '{var}' has already been consumed", - "suggestion": "Quantum resources can only be used once. Check for duplicate measurements or operations", - }, - r"MoveOutOfSubscriptError": { - "type": "MoveOutOfSubscriptError", - "message": "Cannot move out of array subscript", - "suggestion": ( - "Use array unpacking or measure_array() instead of individual element access after consumption" - ), - }, - r"NotCallableError.*'(\w+)'": { - "type": "NotCallableError", - "message": "'{var}' is not callable", - "suggestion": "Check if a variable name conflicts with a function name (e.g., 'result')", - }, - r"NameError.*name\s+'(\w+)'\s+is\s+not\s+defined": { - "type": "NameError", - "message": "Variable '{var}' is not defined", - "suggestion": "Check variable scoping or if the variable was renamed to avoid conflicts", - }, - r"TypeError.*missing.*positional argument.*'(\w+)'": { - "type": "TypeError", - "message": "Missing required argument '{var}'", - "suggestion": "Check function signatures and ensure all required parameters are provided", - }, - r"UnknownSourceError.*obj=": { - "type": "UnknownSourceError", - "message": "Cannot find source location for dynamically generated class '{var}'", - "suggestion": "This is a known limitation with dynamically generated structs. " - "The struct definition is correct but lacks source tracking metadata.", - }, - } - - def __init__(self, guppy_code: str): - """Initialize with the generated Guppy code.""" - self.code_lines = guppy_code.split("\n") - - def analyze_error(self, error: Exception) -> str: - """Analyze an error and provide detailed diagnostics.""" - error_str = str(error) - error_type = type(error).__name__ - - # Try to match against known patterns - for pattern, info in self.ERROR_PATTERNS.items(): - match = re.search(pattern, error_str) - if match: - return self._format_known_error(match, info, error_str) - - # Handle specific error types with custom logic - if "MoveOutOfSubscriptError" in error_str: - return self._analyze_subscript_error(error_str) - - # Generic error handling - return self._format_generic_error(error_type, error_str) - - def _format_known_error(self, match: re.Match, info: dict, error_str: str) -> str: - """Format a known error pattern.""" - var_name = match.group(1) if match.groups() else None - - lines = [ - f"\n{'='*60}", - f"HUGR Compilation Error: {info['type']}", - f"{'='*60}\n", - ] - - if var_name: - lines.append(f"Problem: {info['message'].format(var=var_name)}") - else: - lines.append(f"Problem: {info['message']}") - - lines.append(f"\nSuggestion: {info['suggestion']}") - - # Find relevant code context - context = self._find_code_context(var_name) - if context: - lines.append("\nRelevant code:") - lines.extend(f" Line {ctx.line_number}: {ctx.line_content.strip()}" for ctx in context) - - # Add specific examples for common fixes - if info["type"] == "PlaceNotUsedError" and var_name: - lines.append("\nExample fix:") - lines.append(" # Add before the end of the function:") - lines.append(f" _ = quantum.measure_array({var_name})") - - elif info["type"] == "MoveOutOfSubscriptError": - lines.append("\nExample fix:") - lines.append(" # Instead of accessing elements after measurement:") - lines.append(" # BAD: c = measure_array(q); x = q[0]") - lines.append(" # GOOD: q_0, q_1 = q; c_0 = measure(q_0)") - - lines.append(f"\nOriginal error: {error_str}") - lines.append(f"{'='*60}\n") - - return "\n".join(lines) - - def _analyze_subscript_error(self, error_str: str) -> str: - """Analyze MoveOutOfSubscriptError in detail.""" - lines = [ - f"\n{'='*60}", - "HUGR Compilation Error: MoveOutOfSubscriptError", - f"{'='*60}\n", - "Problem: Attempting to access array elements after the array has been consumed", - "\nThis typically happens when:", - " 1. You measure an entire array with measure_array()", - " 2. Then try to access individual elements like array[0]", - "\nSolution approaches:", - " 1. Unpack the array before measurements:", - " q_0, q_1, q_2 = q # Unpack at the start", - " c_0 = quantum.measure(q_0) # Use unpacked names", - "\n 2. Use measure_array() for the entire array:", - " c = quantum.measure_array(q) # Measure all at once", - "\n 3. Measure individual elements without unpacking:", - " c[0] = quantum.measure(q[0]) # Before array is consumed", - ] - - # Look for array access patterns in the code - for i, line in enumerate(self.code_lines): - if ( - "measure_array" in line - and "[" - in self.code_lines[ - min(i + 1, len(self.code_lines) - 1) : min( - i + 5, - len(self.code_lines), - ) - ] - ): - lines.append(f"\nPotential issue found around line {i+1}:") - lines.append(f" {line.strip()}") - - lines.append(f"\nOriginal error: {error_str}") - lines.append(f"{'='*60}\n") - - return "\n".join(lines) - - def _format_generic_error(self, error_type: str, error_str: str) -> str: - """Format a generic error with basic diagnostics.""" - lines = [ - f"\n{'='*60}", - f"HUGR Compilation Error: {error_type}", - f"{'='*60}\n", - f"Error details: {error_str}", - "\nGeneral troubleshooting tips:", - " 1. Check that all quantum registers are consumed (measured)", - " 2. Ensure variables don't conflict with reserved names (result, array, quantum)", - " 3. Verify array operations happen before the array is consumed", - " 4. Check function parameter types and ownership annotations", - "\nFor more specific help, please check the error message above.", - f"{'='*60}\n", - ] - - return "\n".join(lines) - - def _find_code_context(self, var_name: str | None) -> list[ErrorContext]: - """Find relevant code lines for a variable.""" - if not var_name: - return [] - - contexts = [] - for i, line in enumerate(self.code_lines): - if var_name in line: - contexts.append( - ErrorContext( - line_number=i + 1, - line_content=line, - variable_name=var_name, - ), - ) - - return contexts diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/ir.py b/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/ir.py deleted file mode 100644 index db700a003..000000000 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/ir.py +++ /dev/null @@ -1,680 +0,0 @@ -"""Intermediate Representation for Guppy code generation. - -This module provides an IR that allows us to analyze and transform code -before generating the final Guppy output. -""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from enum import Enum -from typing import Any, ClassVar - - -class ResourceState(Enum): - """State of a quantum resource.""" - - AVAILABLE = "available" - CONSUMED = "consumed" - BORROWED = "borrowed" - - -@dataclass -class VariableInfo: - """Information about a variable.""" - - name: str - original_name: str # Before renaming - var_type: str # "quantum", "classical", etc. - size: int | None = None - is_array: bool = False - is_unpacked: bool = False - unpacked_names: list[str] = field(default_factory=list) - state: ResourceState = ResourceState.AVAILABLE - is_struct: bool = False - struct_info: dict | None = None - is_struct_field: bool = False - struct_name: str | None = None - field_name: str | None = None - - -@dataclass -class ScopeContext: - """Context for a scope (function, block, etc.).""" - - parent: ScopeContext | None = None - variables: dict[str, VariableInfo] = field(default_factory=dict) - unpacked_arrays: dict[str, list[str]] = field(default_factory=dict) - consumed_resources: set[str] = field(default_factory=set) - refreshed_arrays: dict[str, str] = field( - default_factory=dict, - ) # original_name -> fresh_name - - def lookup_variable(self, name: str) -> VariableInfo | None: - """Look up a variable in this scope or parent scopes.""" - if name in self.variables: - return self.variables[name] - - # Check if this variable was refreshed by a function call - if name in self.refreshed_arrays: - fresh_name = self.refreshed_arrays[name] - if fresh_name in self.variables: - return self.variables[fresh_name] - - if self.parent: - return self.parent.lookup_variable(name) - return None - - def add_variable(self, var_info: VariableInfo) -> None: - """Add a variable to this scope.""" - self.variables[var_info.name] = var_info - - def mark_consumed(self, name: str) -> None: - """Mark a resource as consumed.""" - self.consumed_resources.add(name) - var = self.lookup_variable(name) - if var: - var.state = ResourceState.CONSUMED - - -class IRNode(ABC): - """Base class for all IR nodes.""" - - @abstractmethod - def analyze(self, context: ScopeContext) -> None: - """Analyze this node for resource usage, unpacking needs, etc.""" - - @abstractmethod - def render(self, context: ScopeContext) -> list[str]: - """Render this node to Guppy code lines.""" - - -@dataclass -class ArrayAccess(IRNode): - """Represents array[index] access.""" - - array_name: str = None # Optional for backwards compatibility - array: IRNode = None # Can be a FieldAccess for struct.field[index] - index: int | str | IRNode = None - force_array_syntax: bool = False # If True, never use unpacked names - - def __post_init__(self): - """Initialize ArrayAccess, supporting both old and new API.""" - # Support both old and new API - if self.array_name and not self.array: - self.array = VariableRef(self.array_name) - - def analyze(self, context: ScopeContext) -> None: - """Mark that this array needs element access.""" - if self.array: - self.array.analyze(context) - - def render(self, context: ScopeContext) -> list[str]: - """Render array access, using unpacked name if available.""" - # Handle old API - if self.array_name and not self.force_array_syntax: - var = context.lookup_variable(self.array_name) - if var and var.is_unpacked and isinstance(self.index, int) and self.index < len(var.unpacked_names): - return [var.unpacked_names[self.index]] - - # Render array if it's an IRNode (e.g., FieldAccess) - if self.array: - array_code = self.array.render(context) - array_str = array_code[0] if len(array_code) == 1 else "???" - else: - array_str = self.array_name - - # Render index if it's an IRNode - if isinstance(self.index, IRNode): - index_code = self.index.render(context) - if len(index_code) == 1: - return [f"{array_str}[{index_code[0]}]"] - # Complex index expression - shouldn't happen usually - return [f"{array_str}[{' '.join(index_code)}]"] - - return [f"{array_str}[{self.index}]"] - - -@dataclass -class FieldAccess(IRNode): - """Access to a struct field: obj.field""" - - obj: IRNode - field: str - - def analyze(self, context: ScopeContext) -> None: - """Analyze the object being accessed.""" - self.obj.analyze(context) - - def render(self, context: ScopeContext) -> list[str]: - """Render field access as obj.field.""" - obj_code = self.obj.render(context) - obj_str = obj_code[0] if len(obj_code) == 1 else "???" - return [f"{obj_str}.{self.field}"] - - -@dataclass -class VariableRef(IRNode): - """Reference to a variable.""" - - name: str - - def analyze(self, context: ScopeContext) -> None: - """Check variable exists.""" - # Nothing to analyze for a simple variable reference - - def render(self, context: ScopeContext) -> list[str]: - """Render variable reference.""" - var = context.lookup_variable(self.name) - - if var: - return [var.name] # Use potentially renamed name - return [self.name] - - -@dataclass -class Literal(IRNode): - """Literal value.""" - - value: Any - - def analyze(self, context: ScopeContext) -> None: - """Analyze literal - no-op as literals don't need analysis.""" - # Nothing to analyze for literals - - def render(self, context: ScopeContext) -> list[str]: - _ = context # Context not needed for literal rendering - if isinstance(self.value, bool): - return ["True" if self.value else "False"] - if isinstance(self.value, str): - return [f'"{self.value}"'] - return [str(self.value)] - - -@dataclass -class Statement(IRNode): - """Base class for statements.""" - - -@dataclass -class Expression(IRNode): - """Base class for expressions.""" - - -@dataclass -class BinaryOp(Expression): - """Binary operation: left op right.""" - - left: IRNode - op: str - right: IRNode - needs_parens: bool = False # Track if this expression needs parentheses - - # Operator precedence (higher number = higher precedence) - PRECEDENCE: ClassVar[dict[str, int]] = { - "or": 1, - "|": 1, - "and": 2, - "&": 2, - "^": 3, - "==": 4, - "!=": 4, - "<": 4, - ">": 4, - "<=": 4, - ">=": 4, - "+": 5, - "-": 5, - "*": 6, - "/": 6, - "//": 6, - "%": 6, - "**": 7, - } - - def analyze(self, context: ScopeContext) -> None: - """Analyze both operands.""" - self.left.analyze(context) - self.right.analyze(context) - - def _needs_parens(self, child: IRNode, *, is_right: bool = False) -> bool: - """Check if child expression needs parentheses.""" - if not isinstance(child, BinaryOp): - return False - - child_prec = self.PRECEDENCE.get(child.op, 10) - self_prec = self.PRECEDENCE.get(self.op, 10) - - # Lower precedence needs parens - if child_prec < self_prec: - return True - # Same precedence: check associativity (left-to-right) - # For operators like -, /, we need parens on the right - return child_prec == self_prec and is_right and self.op in ["-", "/", "//", "%"] - - def render(self, context: ScopeContext) -> list[str]: - """Render binary operation with proper precedence handling.""" - # Render children - left_code = self.left.render(context) - right_code = self.right.render(context) - - # Add parentheses if needed for children - left_str = left_code[0] if len(left_code) == 1 else " ".join(left_code) - right_str = right_code[0] if len(right_code) == 1 else " ".join(right_code) - - if self._needs_parens(self.left): - left_str = f"({left_str})" - if self._needs_parens(self.right, is_right=True): - right_str = f"({right_str})" - - result = f"{left_str} {self.op} {right_str}" - - # Add parentheses if this expression was marked as needing them - if self.needs_parens: - result = f"({result})" - - return [result] - - -@dataclass -class UnaryOp(Expression): - """Unary operation: op operand.""" - - op: str - operand: IRNode - - def analyze(self, context: ScopeContext) -> None: - """Analyze the operand.""" - self.operand.analyze(context) - - def render(self, context: ScopeContext) -> list[str]: - """Render unary operation.""" - operand_code = self.operand.render(context) - if len(operand_code) == 1: - return [f"{self.op} {operand_code[0]}"] - # For complex expressions, use parentheses - return [f"{self.op} ({' '.join(operand_code)})"] - - -@dataclass -class Assignment(Statement): - """Assignment statement: target = value.""" - - target: IRNode - value: IRNode - - def analyze(self, context: ScopeContext) -> None: - """Analyze both target and value.""" - self.target.analyze(context) - self.value.analyze(context) - - def render(self, context: ScopeContext) -> list[str]: - target_code = self.target.render(context) - value_code = self.value.render(context) - if len(target_code) == 1 and len(value_code) == 1: - return [f"{target_code[0]} = {value_code[0]}"] - # Handle multi-line expressions - result = value_code[:-1] # All but last line - result.append(f"{target_code[0]} = {value_code[-1]}") - return result - - -@dataclass -class FunctionCall(Expression): - """Function call expression.""" - - func_name: str - args: list[IRNode] - - def analyze(self, context: ScopeContext) -> None: - for arg in self.args: - arg.analyze(context) - - def render(self, context: ScopeContext) -> list[str]: - arg_strs = [] - for arg in self.args: - arg_code = arg.render(context) - arg_strs.append(arg_code[0] if len(arg_code) == 1 else "???") - return [f"{self.func_name}({', '.join(arg_strs)})"] - - -@dataclass -class MethodCall(Expression): - """Method call: obj.method(args).""" - - obj: IRNode - method: str - args: list[IRNode] - - def analyze(self, context: ScopeContext) -> None: - self.obj.analyze(context) - for arg in self.args: - arg.analyze(context) - - def render(self, context: ScopeContext) -> list[str]: - obj_code = self.obj.render(context) - arg_strs = [] - for arg in self.args: - arg_code = arg.render(context) - arg_strs.append(arg_code[0] if len(arg_code) == 1 else "???") - - obj_str = obj_code[0] if len(obj_code) == 1 else "???" - return [f"{obj_str}.{self.method}({', '.join(arg_strs)})"] - - -@dataclass -class Measurement(Statement): - """Measurement operation.""" - - qubit: IRNode - target: IRNode | None = None - - def analyze(self, context: ScopeContext) -> None: - self.qubit.analyze(context) - if self.target: - self.target.analyze(context) - - # Mark qubit as consumed if it's a simple reference - if isinstance(self.qubit, VariableRef): - context.mark_consumed(self.qubit.name) - elif isinstance(self.qubit, ArrayAccess): - # Track that this array element is consumed - pass - - def render(self, context: ScopeContext) -> list[str]: - qubit_code = self.qubit.render(context) - qubit_str = qubit_code[0] if len(qubit_code) == 1 else "???" - - if self.target: - target_code = self.target.render(context) - target_str = target_code[0] if len(target_code) == 1 else "???" - return [f"{target_str} = quantum.measure({qubit_str})"] - return [f"quantum.measure({qubit_str})"] - - -@dataclass -class ArrayUnpack(Statement): - """Array unpacking: a, b, c = array.""" - - targets: list[str] - source: str - - def analyze(self, context: ScopeContext) -> None: - # Mark the array as unpacked - var = context.lookup_variable(self.source) - if var: - var.is_unpacked = True - var.unpacked_names = self.targets - - def render(self, context: ScopeContext) -> list[str]: - _ = context # Context not needed for unpacking - if len(self.targets) == 1: - # Special syntax for single element - return [f"{self.targets[0]}, = {self.source}"] - return [f"{', '.join(self.targets)} = {self.source}"] - - -@dataclass -class Comment(Statement): - """Comment line.""" - - text: str - - def analyze(self, context: ScopeContext) -> None: - """Analyze comment - no-op.""" - # Nothing to analyze for comments - - def render(self, context: ScopeContext) -> list[str]: - """Render comment line.""" - _ = context # Context not needed for comments - if self.text: - return [f"# {self.text}"] - return [] # Don't render empty comments - - -@dataclass -class ReturnStatement(Statement): - """Return statement.""" - - value: IRNode | None = None - - def analyze(self, context: ScopeContext) -> None: - """Analyze return value if present.""" - if self.value: - self.value.analyze(context) - - def render(self, context: ScopeContext) -> list[str]: - """Render return statement.""" - if self.value: - value_code = self.value.render(context) - return [f"return {value_code[0]}"] - return ["return"] - - -@dataclass -class TupleExpression(Expression): - """Tuple expression for multiple returns.""" - - elements: list[IRNode] - - def analyze(self, context: ScopeContext) -> None: - """Analyze all elements.""" - for elem in self.elements: - elem.analyze(context) - - def render(self, context: ScopeContext) -> list[str]: - """Render tuple expression.""" - elem_codes = [elem.render(context)[0] for elem in self.elements] - return [", ".join(elem_codes)] # No parentheses needed for tuple returns - - -@dataclass -class Block(IRNode): - """Block of statements.""" - - statements: list[Statement] = field(default_factory=list) - - def analyze(self, context: ScopeContext) -> None: - for stmt in self.statements: - stmt.analyze(context) - - def render(self, context: ScopeContext) -> list[str]: - lines = [] - for stmt in self.statements: - lines.extend(stmt.render(context)) - return lines - - -@dataclass -class IfStatement(Statement): - """If statement with optional else.""" - - condition: IRNode - then_block: Block - else_block: Block | None = None - - def analyze(self, context: ScopeContext) -> None: - self.condition.analyze(context) - - # Create new scope for then block - then_context = ScopeContext(parent=context) - self.then_block.analyze(then_context) - - if self.else_block: - # Create new scope for else block - else_context = ScopeContext(parent=context) - self.else_block.analyze(else_context) - - def render(self, context: ScopeContext) -> list[str]: - lines = [] - - # Render condition - cond_code = self.condition.render(context) - cond_str = cond_code[0] if len(cond_code) == 1 else "???" - lines.append(f"if {cond_str}:") - - # Render then block (indented) - then_lines = self.then_block.render(context) - if then_lines: - lines.extend(f" {line}" for line in then_lines) - else: - lines.append(" pass") - - # Render else block if present - if self.else_block: - lines.append("else:") - else_lines = self.else_block.render(context) - if else_lines: - lines.extend(f" {line}" for line in else_lines) - else: - lines.append(" pass") - - return lines - - -@dataclass -class WhileStatement(Statement): - """While loop statement.""" - - condition: IRNode - body: Block - - def analyze(self, context: ScopeContext) -> None: - self.condition.analyze(context) - # Create new scope for loop body - loop_context = ScopeContext(parent=context) - self.body.analyze(loop_context) - - def render(self, context: ScopeContext) -> list[str]: - lines = [] - - # Render condition - cond_code = self.condition.render(context) - cond_str = cond_code[0] if len(cond_code) == 1 else "???" - lines.append(f"while {cond_str}:") - - # Render body (indented) - body_lines = self.body.render(context) - if body_lines: - lines.extend(f" {line}" for line in body_lines) - else: - lines.append(" pass") - - return lines - - -@dataclass -class ForStatement(Statement): - """For loop statement.""" - - loop_var: str - iterable: IRNode - body: Block - - def analyze(self, context: ScopeContext) -> None: - # Analyze iterable - self.iterable.analyze(context) - - # Create new scope for loop body with loop variable - loop_context = ScopeContext(parent=context) - # Add loop variable to context (simplified - would need type info) - self.body.analyze(loop_context) - - def render(self, context: ScopeContext) -> list[str]: - lines = [] - - # Render iterable - iter_code = self.iterable.render(context) - iter_str = iter_code[0] if len(iter_code) == 1 else "???" - lines.append(f"for {self.loop_var} in {iter_str}:") - - # Render body (indented) - body_lines = self.body.render(context) - if body_lines: - lines.extend(f" {line}" for line in body_lines) - else: - lines.append(" pass") - - return lines - - -@dataclass -class Function(IRNode): - """Function definition.""" - - name: str - params: list[tuple[str, str]] # [(name, type), ...] - return_type: str - body: Block - decorators: list[str] = field(default_factory=list) - - def analyze(self, context: ScopeContext) -> None: - # Create new scope for function - func_context = ScopeContext(parent=context) - - # Add parameters to scope - for param_name, param_type in self.params: - var_info = VariableInfo( - name=param_name, - original_name=param_name, - var_type=param_type, - ) - func_context.add_variable(var_info) - - # Analyze body - self.body.analyze(func_context) - - def render(self, context: ScopeContext) -> list[str]: - lines = [] - - # Decorators - lines.extend(f"@{decorator}" for decorator in self.decorators) - - # Function signature - param_strs = [] - for name, ptype in self.params: - param_strs.append(f"{name}: {ptype}") - - lines.append(f"def {self.name}({', '.join(param_strs)}) -> {self.return_type}:") - - # Body - body_lines = self.body.render(context) - if body_lines: - lines.extend(f" {line}" for line in body_lines) - else: - lines.append(" pass") - - return lines - - -@dataclass -class Module(IRNode): - """Module containing imports and definitions.""" - - imports: list[str] = field(default_factory=list) - functions: list[Function] = field(default_factory=list) - refreshed_arrays: dict[str, set[str]] = field( - default_factory=dict, - ) # function_name -> set of refreshed array names - - def analyze(self, context: ScopeContext) -> None: - for func in self.functions: - func.analyze(context) - - def render(self, context: ScopeContext) -> list[str]: - lines = [] - - # Imports - lines.extend(self.imports) - - if self.imports and self.functions: - lines.append("") # Blank line between imports and code - - # Functions - for i, func in enumerate(self.functions): - if i > 0: - lines.append("") # Blank line between functions - lines.extend(func.render(context)) - - return lines diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/ir_analyzer.py b/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/ir_analyzer.py deleted file mode 100644 index 954ddf13e..000000000 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/ir_analyzer.py +++ /dev/null @@ -1,362 +0,0 @@ -"""Analyzer for determining array unpacking and other transformations needed.""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from pecos.slr import Block as SLRBlock - - -@dataclass -class ArrayAccessInfo: - """Information about how an array is accessed.""" - - array_name: str - size: int - is_classical: bool = False # Track if this is a CReg - - # Track individual element accesses - element_accesses: set[int] = field(default_factory=set) - element_access_positions: dict[int, list[int]] = field(default_factory=dict) - - # Track full array accesses - full_array_accesses: list[int] = field(default_factory=list) - - # Track if passed to blocks - passed_to_blocks: bool = False - - # Track operations between accesses - has_operations_between: bool = False - has_conditionals_between: bool = False - - # NEW: Track which specific elements are conditionally accessed - # This is more precise than the boolean flag above - conditionally_accessed_elements: set[int] = field(default_factory=set) - - # Consumption info - elements_consumed: set[int] = field(default_factory=set) - fully_consumed: bool = False - consumed_at_position: int | None = None - - @property - def has_individual_access(self) -> bool: - """Check if array has individual element access.""" - return len(self.element_accesses) > 0 - - @property - def all_elements_accessed(self) -> bool: - """Check if all elements are accessed.""" - return len(self.element_accesses) == self.size - - @property - def needs_unpacking(self) -> bool: - """Determine if this array needs unpacking. - - This uses a rule-based decision tree for clearer, more maintainable logic. - See unpacking_rules.py for the detailed decision tree implementation. - """ - from pecos.slr.gen_codes.guppy.unpacking_rules import should_unpack_array - - return should_unpack_array(self) - - -@dataclass -class UnpackingPlan: - """Plan for unpacking arrays in a scope.""" - - arrays_to_unpack: dict[str, ArrayAccessInfo] = field(default_factory=dict) - unpack_at_start: set[str] = field(default_factory=set) - renamed_variables: dict[str, str] = field(default_factory=dict) - # Store all analyzed arrays, including those that don't need unpacking - all_analyzed_arrays: dict[str, ArrayAccessInfo] = field(default_factory=dict) - - -class IRAnalyzer: - """Analyzes SLR blocks to determine IR transformations needed.""" - - def __init__(self): - self.array_info: dict[str, ArrayAccessInfo] = {} - self.position_counter = 0 - self.in_conditional = False - self.reserved_names = {"result", "array", "quantum", "guppy", "owned"} - self.has_nested_blocks = False - - def analyze_block( - self, - block: SLRBlock, - variable_context: dict[str, Any], - ) -> UnpackingPlan: - """Analyze a block and return unpacking plan.""" - plan = UnpackingPlan() - - # Reset state - self.array_info.clear() - self.position_counter = 0 - - # First, collect array information from variables - self._collect_array_info(block, variable_context) - - # Perform data flow analysis to get precise information - from pecos.slr.gen_codes.guppy.data_flow import DataFlowAnalyzer - - data_flow_analyzer = DataFlowAnalyzer() - data_flow = data_flow_analyzer.analyze(block, variable_context) - - # Analyze operations to determine access patterns - if hasattr(block, "ops"): - for op in block.ops: - self._analyze_operation(op) - self.position_counter += 1 - - # Update array info with data flow analysis results - self._integrate_data_flow(data_flow) - - # Determine which arrays need unpacking - # Special case: if we have nested blocks but @owned parameters, we must unpack - # because @owned parameters require unpacking to access elements - must_unpack_for_owned = hasattr(self, "has_nested_blocks_with_owned") and self.has_nested_blocks_with_owned - - # Store all analyzed arrays in the plan - plan.all_analyzed_arrays = self.array_info.copy() - - if not self.has_nested_blocks or must_unpack_for_owned: - for array_name, info in self.array_info.items(): - should_unpack = info.needs_unpacking - - # Force unpacking for @owned parameters even with nested blocks - if ( - must_unpack_for_owned - and hasattr(self, "expected_owned_params") - and array_name in self.expected_owned_params - ): - should_unpack = True - - if should_unpack: - plan.arrays_to_unpack[array_name] = info - plan.unpack_at_start.add(array_name) - - # Check for variable name conflicts - self._check_name_conflicts(block, plan) - - return plan - - def _collect_array_info( - self, - block: SLRBlock, - variable_context: dict[str, Any], - ) -> None: - """Collect information about arrays in the block.""" - # From block variables - if hasattr(block, "vars"): - for var in block.vars: - var_type = type(var).__name__ - if var_type in ["QReg", "CReg"] and hasattr(var, "sym") and hasattr(var, "size"): - self.array_info[var.sym] = ArrayAccessInfo( - array_name=var.sym, - size=var.size, - is_classical=(var_type == "CReg"), - ) - - # From variable context - if variable_context: - for var_name, var in variable_context.items(): - var_type = type(var).__name__ - if var_type in ["QReg", "CReg"] and hasattr(var, "size") and var_name not in self.array_info: - self.array_info[var_name] = ArrayAccessInfo( - array_name=var_name, - size=var.size, - is_classical=(var_type == "CReg"), - ) - - def _analyze_operation(self, op: Any) -> None: - """Analyze a single operation.""" - op_type = type(op).__name__ - - if op_type == "Measure": - self._analyze_measurement(op) - elif op_type == "If": - self._analyze_if_block(op) - elif hasattr(op, "qargs"): - self._analyze_quantum_operation(op) - elif hasattr(op, "ops"): - # Check if this is a nested Block - if hasattr(op, "__class__"): - from pecos.slr import Block as SlrBlock - - try: - if issubclass(op.__class__, SlrBlock): - # Mark that we have nested blocks - self.has_nested_blocks = True - except (TypeError, AttributeError): - # Not a class or doesn't have expected attributes - pass - - # Nested block - recurse into its operations - for nested_op in op.ops: - self._analyze_operation(nested_op) - - def _analyze_measurement(self, meas: Any) -> None: - """Analyze a measurement operation.""" - # Analyze classical targets if present - if hasattr(meas, "cout") and meas.cout: - for cout in meas.cout: - if hasattr(cout, "reg") and hasattr(cout.reg, "sym"): - array_name = cout.reg.sym - if array_name in self.array_info and hasattr(cout, "index"): - info = self.array_info[array_name] - # Track individual classical element access - info.element_accesses.add(cout.index) - - # Analyze quantum sources - if hasattr(meas, "qargs") and meas.qargs: - for qarg in meas.qargs: - # Handle full array measurement (QReg directly) - if hasattr(qarg, "sym") and hasattr(qarg, "size"): - array_name = qarg.sym - if array_name in self.array_info: - info = self.array_info[array_name] - # Full array measurement - info.full_array_accesses.append(self.position_counter) - info.fully_consumed = True - info.consumed_at_position = self.position_counter - - # Mark all elements as consumed - for i in range(info.size): - info.elements_consumed.add(i) - - # Handle individual element measurement (Qubit with reg) - elif hasattr(qarg, "reg") and hasattr(qarg.reg, "sym"): - array_name = qarg.reg.sym - if array_name in self.array_info: - info = self.array_info[array_name] - - if hasattr(qarg, "index"): - # Individual element measurement - index = qarg.index - info.element_accesses.add(index) - info.elements_consumed.add(index) - - if index not in info.element_access_positions: - info.element_access_positions[index] = [] - info.element_access_positions[index].append( - self.position_counter, - ) - - def _analyze_quantum_operation(self, op: Any) -> None: - """Analyze a quantum operation (gate, etc.).""" - if hasattr(op, "qargs") and op.qargs: - for qarg in op.qargs: - if hasattr(qarg, "reg") and hasattr(qarg.reg, "sym"): - array_name = qarg.reg.sym - if array_name in self.array_info: - info = self.array_info[array_name] - - if hasattr(qarg, "index"): - # Individual element access - index = qarg.index - info.element_accesses.add(index) - - if index not in info.element_access_positions: - info.element_access_positions[index] = [] - info.element_access_positions[index].append( - self.position_counter, - ) - - # Check if there are measurements before this - if info.elements_consumed: - info.has_operations_between = True - - def _analyze_if_block(self, if_block: Any) -> None: - """Analyze an if block.""" - prev_conditional = self.in_conditional - self.in_conditional = True - - # Check condition for array accesses - if hasattr(if_block, "cond"): - self._analyze_condition(if_block.cond) - - # Analyze then block - if hasattr(if_block, "ops"): - for op in if_block.ops: - self._analyze_operation(op) - - # Analyze else block - if hasattr(if_block, "else_block") and if_block.else_block and hasattr(if_block.else_block, "ops"): - for op in if_block.else_block.ops: - self._analyze_operation(op) - - self.in_conditional = prev_conditional - - # Mark arrays used in conditionals - if self.in_conditional: - for info in self.array_info.values(): - if info.element_accesses: - info.has_conditionals_between = True - - def _analyze_condition(self, cond: Any) -> None: - """Analyze a condition expression.""" - # Look for array accesses in conditions - cond_type = type(cond).__name__ - - if cond_type == "Bit": - if hasattr(cond, "reg") and hasattr(cond.reg, "sym"): - array_name = cond.reg.sym - if array_name in self.array_info and hasattr(cond, "index"): - info = self.array_info[array_name] - info.element_accesses.add(cond.index) - info.has_conditionals_between = True - - # Handle compound conditions - elif hasattr(cond, "left"): - self._analyze_condition(cond.left) - if hasattr(cond, "right"): - self._analyze_condition(cond.right) - - def _check_name_conflicts(self, block: SLRBlock, plan: UnpackingPlan) -> None: - """Check for variable names that conflict with reserved words.""" - if hasattr(block, "vars"): - for var in block.vars: - if hasattr(var, "sym") and var.sym in self.reserved_names: - # Need to rename this variable - new_name = f"{var.sym}_reg" - plan.renamed_variables[var.sym] = new_name - - def _integrate_data_flow(self, data_flow) -> None: - """Integrate data flow analysis results into array access info. - - This provides more precise information about operations between accesses, - reducing false positives from the heuristic analysis. - - Args: - data_flow: DataFlowAnalysis from the data flow analyzer - """ - from pecos.slr.gen_codes.guppy.data_flow import DataFlowAnalysis - - if not isinstance(data_flow, DataFlowAnalysis): - return - - # For each array element in the data flow analysis - for (array_name, index), flow_info in data_flow.element_flows.items(): - if array_name in self.array_info: - info = self.array_info[array_name] - - # Update has_operations_between with precise data flow information - # Only set to True if THIS SPECIFIC element is used after its own measurement - if flow_info.has_use_after_consumption(): - # Mark that THIS array has operations between for THIS element - # This is more precise than the heuristic which marks the whole array - info.has_operations_between = True - - # Also check conditional accesses from data flow - for array_name, index in data_flow.conditional_accesses: - if array_name in self.array_info: - info = self.array_info[array_name] - # NEW: Track the specific element that is conditionally accessed - info.conditionally_accessed_elements.add(index) - - # Keep the old flag for backward compatibility - # But now we have more precise information in conditionally_accessed_elements - if index in info.element_accesses: - info.has_conditionals_between = True diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/ir_builder.py b/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/ir_builder.py deleted file mode 100644 index 5163a275e..000000000 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/ir_builder.py +++ /dev/null @@ -1,9360 +0,0 @@ -"""Builder for converting SLR operations to IR. - -IMPORTANT LIMITATION - Partial Consumption in Loops: -==================================================== - -The current implementation returns ONLY unconsumed array elements from functions. -This works correctly for most patterns, but has a known limitation with certain -verification loop patterns (e.g., Steane code). - -WORKING PATTERN (Partial Consumption): --------------------------------------- -def process_qubits(q: array[quantum.qubit, 4] @owned) -> array[quantum.qubit, 2]: - # Measures q[0] and q[2], returns q[1] and q[3] - # Return type correctly reflects only unconsumed elements - -PROBLEMATIC PATTERN (Verification Ancillas in Loops): ------------------------------------------------------ -def verify(ancilla: array[qubit, 3] @owned) -> tuple[array[qubit, 2], ...]: - # Measures ancilla[0], creates fresh qubit at ancilla[0] - # Returns ONLY ancilla[1] and ancilla[2] (unconsumed elements) - # Fresh qubit is NOT returned (it's an automatic replacement for linearity) - -# In calling function: -for _ in range(2): - ancilla_returned = verify(ancilla) # ERROR: Returns size 2, needs size 3 - -WHY THIS HAPPENS: -- Automatic qubit replacements (lines 2966-2977) are created for Guppy's linear - type system, not for meaningful quantum operations -- The replacement qubit is not semantically part of the verification result -- Only unconsumed elements (ancilla[1], ancilla[2]) are returned -- This creates a size mismatch in subsequent loop iterations - -ARCHITECTURAL SOLUTIONS: -- Don't use partial consumption for verification ancillas that need reuse -- Use separate ancilla qubits instead of array elements for verification -- Or restructure the verification pattern to avoid the loop issue - -See tests/slr_tests/guppy/test_partial_array_returns.py for correct usage patterns. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, ClassVar - -if TYPE_CHECKING: - from pecos.slr import Block as SLRBlock - from pecos.slr.gen_codes.guppy.ir import IRNode - from pecos.slr.gen_codes.guppy.ir_analyzer import UnpackingPlan - from pecos.slr.gen_codes.guppy.unified_resource_planner import ( - UnifiedResourceAnalysis, - ) - -# AllocationOptimizer removed - now using UnifiedResourceAnalysis directly -from pecos.slr.gen_codes.guppy.ir import ( - ArrayAccess, - ArrayUnpack, - Assignment, - BinaryOp, - Block, - Comment, - Expression, - FieldAccess, - ForStatement, - Function, - FunctionCall, - IfStatement, - Literal, - Measurement, - Module, - ResourceState, - ReturnStatement, - ScopeContext, - Statement, - TupleExpression, - UnaryOp, - VariableInfo, - VariableRef, - WhileStatement, -) -from pecos.slr.gen_codes.guppy.qubit_usage_analyzer import QubitRole, QubitUsageAnalyzer -from pecos.slr.gen_codes.guppy.scope_manager import ( - ResourceUsage, - ScopeManager, - ScopeType, -) - - -class IRBuilder: - """Builds IR from SLR operations.""" - - # Core blocks that should remain as control flow (not converted to functions) - CORE_BLOCKS: ClassVar[set[str]] = { - "If", - "Repeat", - "While", - "For", - "Main", - "Block", - "Comment", - "Barrier", - } - - def __init__( - self, - unpacking_plan: UnpackingPlan, - *, - unified_analysis: UnifiedResourceAnalysis | None = None, - include_optimization_report: bool = False, - ): - self.plan = unpacking_plan - self.unified_analysis = unified_analysis - self.context = ScopeContext() - self.scope_manager = ScopeManager() - self.current_block: Block | None = None - # AllocationOptimizer removed - using UnifiedResourceAnalysis directly - self.include_optimization_report = include_optimization_report - - # Track arrays that have been refreshed by function calls - # Maps original array name -> fresh returned name - self.refreshed_arrays = {} - # Track which function refreshed each array - # Maps original array name -> function name that refreshed it - self.refreshed_by_function = {} - - # Track conditionally consumed variables (e.g., in if blocks) - # Maps original variable -> conditionally consumed version - self.conditional_fresh_vars = {} - - # Track blocks for function generation - self.block_registry = {} # Maps block signature to function name - self.pending_functions = [] # Functions to be generated - self.generated_functions = set() # Functions already generated (actually built) - self.discovered_functions = set() # Functions discovered but maybe not built yet - self.function_counter = 0 # For generating unique function names - self.function_info = {} # Track metadata about functions - self.function_return_types = {} # Maps function name to return type - - # Struct generation tracking - self.struct_info = {} # Maps prefix -> {fields: [(suffix, type, size)], struct_name: str} - - # Track all used variable names to avoid conflicts - self.used_var_names = set() - - # Track explicit Prep (reset) operations for return type calculation - # Maps array_name -> set of indices that were explicitly reset - self.explicitly_reset_qubits = {} - - # Variable remapping for handling measurement+Prep pattern - # Maps old_name -> new_name for variables that need fresh names - self.variable_remapping: dict[str, str] = {} - # Track version numbers for generating unique variable names - self.variable_version_counter: dict[str, int] = {} - - # Unified variable-state tracking: replaces ad-hoc dicts like - # `unpacked_vars`, `refreshed_arrays`, etc. (See variable_state.py - # for rationale.) Migration is incremental -- legacy dicts still - # populated, this object is consulted at sites that need a coherent - # view of "what Guppy form is this SLR variable in right now?". - from pecos.slr.gen_codes.guppy.variable_state import VariableState - - self.var_state = VariableState() - - def _get_unique_var_name(self, base_name: str, index: int | None = None) -> str: - """Generate a unique variable name that doesn't conflict with existing names. - - Args: - base_name: The base name for the variable - index: Optional index to append to the base name - - Returns: - A unique variable name - """ - candidate = f"{base_name}_{index}" if index is not None else base_name - - # If the name doesn't conflict, use it - if candidate not in self.used_var_names: - self.used_var_names.add(candidate) - return candidate - - # Add underscores until we find a unique name - while candidate in self.used_var_names: - candidate = f"_{candidate}" - - self.used_var_names.add(candidate) - return candidate - - def _collect_var_names(self, block) -> None: - """Collect all variable names from a block to avoid conflicts.""" - if hasattr(block, "vars"): - for var in block.vars: - if hasattr(var, "sym"): - self.used_var_names.add(var.sym) - # Also check ops recursively - if hasattr(block, "ops"): - for op in block.ops: - if hasattr(op, "__class__") and op.__class__.__name__ in [ - "Main", - "Block", - ]: - self._collect_var_names(op) - - def build_module(self, main_block: SLRBlock, pending_functions: list) -> Module: - """Build a complete module from SLR.""" - module = Module() - - # Collect all existing variable names to avoid conflicts - self._collect_var_names(main_block) - - # Allocation analysis now comes from UnifiedResourceAnalysis - # (passed via unified_analysis parameter) - - # Analyze qubit usage to identify ancillas - qubit_analyzer = QubitUsageAnalyzer() - self.qubit_usage_stats = qubit_analyzer.analyze_block(main_block) - - # Detect and analyze struct patterns (will use qubit usage stats) - self._detect_struct_patterns(main_block) - - # Add imports (including functional quantum operations for Array Unpacking Pattern) - module.imports = [ - "from __future__ import annotations", - "", - "from typing import no_type_check", - "", - "from guppylang.decorator import guppy", - "from guppylang.std import quantum", - "from guppylang.std.quantum import qubit", - "from guppylang.std.quantum.functional import (reset, h, x, y, z, s, t, sdg, tdg, cx, cy, cz)", - "from guppylang.std.builtins import array, owned, result, py", - ] - - # Generate struct definitions after imports - if self.struct_info: - module.imports.append("") - struct_defs = self._generate_struct_definitions() - module.imports.extend(struct_defs) - - # Add optimization report as comments (only if requested) - if self.include_optimization_report and self.unified_analysis: - # Use unified resource planning report (comprehensive) - report = self.unified_analysis.get_report() - module.imports.extend( - [ - "", - *["# " + line for line in report.split("\n") if line.strip()], - ], - ) - - # Build main function - main_func = self.build_main_function(main_block) - module.functions.append(main_func) - # Store refreshed arrays for main function - module.refreshed_arrays["main"] = self.refreshed_arrays.copy() - # Also store which functions refreshed each array in main - if not hasattr(module, "refreshed_by_function_map"): - module.refreshed_by_function_map = {} - module.refreshed_by_function_map["main"] = self.refreshed_by_function.copy() - - # Generate helper functions for structs - for prefix, info in self.struct_info.items(): - # Generate decompose function (always needed for cleanup) - decompose_func = self._generate_struct_decompose_function(prefix, info) - if decompose_func: - module.functions.append(decompose_func) - - # Also generate discard function (useful for other contexts) - discard_func = self._generate_struct_discard_function(prefix, info) - if discard_func: - module.functions.append(discard_func) - - # Build any pending functions (from both parameter and internal tracking) - all_pending = list(pending_functions) + self.pending_functions - while all_pending: - func_info = all_pending.pop(0) - func = self.build_function(func_info) - if func: - module.functions.append(func) - # Mark this function as generated - if len(func_info) >= 2: - self.generated_functions.add(func_info[1]) - # Store refreshed arrays for this function - module.refreshed_arrays[func_info[1]] = self.refreshed_arrays.copy() - # Also store which functions refreshed each array - if not hasattr(module, "refreshed_by_function_map"): - module.refreshed_by_function_map = {} - module.refreshed_by_function_map[func_info[1]] = self.refreshed_by_function.copy() - # Check if building this function added more pending functions - # Add any new pending functions, avoiding duplicates - for new_func in self.pending_functions: - _new_block, new_name, _new_sig = new_func - # Check if this function is already built or pending - already_pending = any(f[1] == new_name for f in all_pending if len(f) >= 2) - if new_name not in self.generated_functions and not already_pending: - all_pending.append(new_func) - self.pending_functions = [] - - # SECOND PASS: Correct return types for functions that return values from other functions - # This is needed because nested functions are built after their parents - self._correct_return_types_from_called_functions(module) - - return module - - def _correct_return_types_from_called_functions(self, module): - """Correct return types for functions that return values from other functions. - - This is a second pass needed because nested functions are built after their parents, - so when calculating the parent's return type, the nested function's return type - isn't available yet. - """ - - # For each function, check if it needs return type correction - for func in module.functions: - if func.name == "main": - continue # Skip main function - - # Check if this function has refreshed_by_function mappings - if func.name not in module.refreshed_arrays: - continue - - func_refreshed_arrays = module.refreshed_arrays[func.name] - if not func_refreshed_arrays: - continue - - # We need to check if this function's return type should be corrected - # by looking at which functions refreshed its arrays - # For now, we'll use a simpler approach: check if the return type - # involves arrays that were refreshed by other functions - - # Parse the current return type - current_return_type = func.return_type - if current_return_type == "None": - continue # Procedural function, no correction needed - - # Get the refreshed_by_function mapping for this function - if not hasattr(module, "refreshed_by_function_map"): - continue - if func.name not in module.refreshed_by_function_map: - continue - - func_refreshed_by_function = module.refreshed_by_function_map[func.name] - if not func_refreshed_by_function: - continue - - # For functions returning tuples, we need to check each element - if current_return_type.startswith("tuple["): - import re - - tuple_match = re.match(r"tuple\[(.*)\]", current_return_type) - if tuple_match: - # Get parameter names from function params (quantum arrays only) - param_names = [p[0] for p in func.params if "array[quantum.qubit," in p[1]] - - # For each quantum parameter, check if it was refreshed by a function - corrected_types = [] - for param_name in param_names: - if param_name in func_refreshed_by_function: - func_info = func_refreshed_by_function[param_name] - # Extract function name from the dict (or handle legacy string format) - called_func_name = ( - func_info["function"] - if isinstance(func_info, dict) - else func_info # Legacy string format - ) - - # Look up the called function's return type - if called_func_name in self.function_return_types: - called_return_type = self.function_return_types[called_func_name] - - # If the called function returns a tuple, extract the type for this param - if called_return_type.startswith("tuple["): - tuple_match2 = re.match( - r"tuple\[(.*)\]", - called_return_type, - ) - if tuple_match2: - called_types_str = tuple_match2.group(1) - # Parse the types (handling nested brackets) - types_list = [] - bracket_depth = 0 - current_type = "" - for char in called_types_str: - if char == "[": - bracket_depth += 1 - current_type += char - elif char == "]": - bracket_depth -= 1 - current_type += char - elif char == "," and bracket_depth == 0: - types_list.append(current_type.strip()) - current_type = "" - else: - current_type += char - if current_type: - types_list.append(current_type.strip()) - - # Find which position this param is in - param_idx = param_names.index(param_name) - if param_idx < len(types_list): - corrected_types.append( - types_list[param_idx], - ) - else: - # Fallback: use current type - corrected_types.append(None) - else: - corrected_types.append(None) - else: - # Single return - use it directly if this is the only param - if len(param_names) == 1: - corrected_types.append(called_return_type) - else: - corrected_types.append(None) - else: - corrected_types.append(None) - else: - corrected_types.append(None) - - # If we have corrections, update the function's return type - if any(ct is not None for ct in corrected_types): - # Parse current types - current_types_str = tuple_match.group(1) - current_types_list = [] - bracket_depth = 0 - current_type = "" - for char in current_types_str: - if char == "[": - bracket_depth += 1 - current_type += char - elif char == "]": - bracket_depth -= 1 - current_type += char - elif char == "," and bracket_depth == 0: - current_types_list.append(current_type.strip()) - current_type = "" - else: - current_type += char - if current_type: - current_types_list.append(current_type.strip()) - - # Apply corrections - new_types = [] - for i, corrected in enumerate(corrected_types): - if corrected is not None: - new_types.append(corrected) - elif i < len(current_types_list): - new_types.append(current_types_list[i]) - else: - # Something went wrong, skip correction - new_types = None - break - - if new_types: - new_return_type = f"tuple[{', '.join(new_types)}]" - func.return_type = new_return_type - # Also update the registry - self.function_return_types[func.name] = new_return_type - - def build_main_function(self, block: SLRBlock) -> Function: - """Build the main function.""" - # Set current function name - self.current_function_name = "main" - - # Reset function-local state - self.refreshed_arrays = {} - self.refreshed_by_function = {} - self.conditional_fresh_vars = {} - self.array_remapping = {} # Reset array remapping for main function - - # Analyze qubit usage patterns - usage_analyzer = QubitUsageAnalyzer() - usage_analyzer.analyze_block(block, self.struct_info) - self.allocation_recommendations = usage_analyzer.get_allocation_recommendations() - - # Pre-analyze explicit reset operations (Prep) to distinguish them from automatic replacements - consumed_in_main = {} - self._track_consumed_qubits(block, consumed_in_main) - - # Override allocation recommendations for struct fields to ensure they're pre-allocated - # (struct constructors need all fields to be available) - if self.struct_info: - for prefix, info in self.struct_info.items(): - for suffix, _, _ in info["fields"]: - var_name = info["var_names"][suffix] - # Override the allocation recommendations system - if var_name in self.allocation_recommendations: - recommendation = self.allocation_recommendations[var_name] - if recommendation.get("allocation") == "dynamic": - # Override dynamic allocation for struct fields - self.allocation_recommendations[var_name] = { - "allocation": "pre_allocate", - "reason": "Struct field requires pre-allocation", - "keep_packed": recommendation.get("keep_packed", True), - "pre_allocate": True, - } - - body = Block() - self.current_block = body - - # Track arrays consumed by @owned function calls - self.consumed_arrays = set() - - # Add variable declarations - if hasattr(block, "vars"): - # First, add non-struct variables - struct_vars = set() - for prefix, info in self.struct_info.items(): - struct_vars.update(info["var_names"].values()) - - # Get ancilla variables that were excluded from structs - ancilla_vars = getattr(self, "ancilla_qubits", set()) - - for var in block.vars: - if hasattr(var, "sym"): - # Add if not in struct OR if it's an ancilla (which was excluded from struct) - if var.sym not in struct_vars or var.sym in ancilla_vars: - self._add_variable_declaration(var, block) - - # Add to scope context for resource tracking - var_type = type(var).__name__ - if var_type in ["QReg", "CReg"]: - is_quantum = var_type == "QReg" - size = getattr(var, "size", None) - - var_info = VariableInfo( - name=var.sym, - original_name=var.sym, - var_type="quantum" if is_quantum else "classical", - size=size, - is_array=True, - ) - self.context.add_variable(var_info) - - # Then, create struct instances - for prefix, info in self.struct_info.items(): - self._add_struct_initialization(prefix, info, block) - - # Main function maintains natural SLR array semantics - # Arrays are only unpacked internally when needed for selective measurements - - # Track unpacked vars for main - self.unpacked_vars = {} - - # First pass: determine which quantum arrays will be unpacked - will_unpack_quantum = set() - for array_name in self.plan.unpack_at_start: - if array_name in self.plan.arrays_to_unpack: - info = self.plan.arrays_to_unpack[array_name] - - # Skip struct fields - is_struct_field = False - if self.struct_info: - for prefix, struct_info in self.struct_info.items(): - if array_name in struct_info.get("var_names", {}).values(): - is_struct_field = True - break - - if is_struct_field: - continue - - # Skip dynamically allocated arrays - if hasattr(self, "dynamic_allocations") and array_name in self.dynamic_allocations: - continue - - # Mark quantum arrays that will be unpacked - if not info.is_classical: - will_unpack_quantum.add(array_name) - - # Second pass: actually unpack arrays - for array_name in self.plan.unpack_at_start: - if array_name in self.plan.arrays_to_unpack: - info = self.plan.arrays_to_unpack[array_name] - - # Skip unpacking for arrays that are struct fields - # (already consumed by struct constructor) - is_struct_field = False - if self.struct_info: - for prefix, struct_info in self.struct_info.items(): - if array_name in struct_info.get("var_names", {}).values(): - is_struct_field = True - break - - if is_struct_field: - # Skip unpacking - array is consumed by struct constructor - # Individual elements can be accessed via struct decomposition - self.current_block.statements.append( - Comment( - f"Skip unpacking {array_name} - consumed by struct constructor", - ), - ) - continue - - # For dynamically allocated arrays, skip unpacking - qubits are allocated on first use - if hasattr(self, "dynamic_allocations") and array_name in self.dynamic_allocations: - # Don't unpack - the array doesn't exist, qubits are allocated individually - continue - if not info.is_classical: - # Regular unpacking for quantum arrays - self.current_block.statements.append( - Comment(f"Unpack {array_name} for individual access"), - ) - self._add_array_unpacking(array_name, info.size) - else: - # For classical arrays, unpack if any quantum array is unpacked - # This ensures consistent variable naming patterns - should_unpack_classical = len(will_unpack_quantum) > 0 or ( - hasattr(self, "dynamic_allocations") and len(self.dynamic_allocations) > 0 - ) - if should_unpack_classical: - # Unpack classical array to support quantum unpacking pattern - self.current_block.statements.append( - Comment( - f"Unpack {array_name} for individual measurement results", - ), - ) - self._add_array_unpacking(array_name, info.size) - else: - # Skip unpacking classical arrays in main to avoid linearity violations - # Classical arrays can be accessed directly and passed to functions - self.current_block.statements.append( - Comment( - f"Skip unpacking classical array {array_name} - not needed for linearity", - ), - ) - - # Add operations - if hasattr(block, "ops"): - # Store block reference for look-ahead in operation conversion - self.current_block_ops = block.ops - for op_index, op in enumerate(block.ops): - # Store current operation index for look-ahead - self.current_op_index = op_index - stmt = self._convert_operation(op) - if stmt: - body.statements.append(stmt) - # Clear after processing - self.current_block_ops = None - self.current_op_index = None - - # Handle struct decomposition, results, and cleanup - self._add_final_handling(block) - - return Function( - name="main", - params=[], - return_type="None", - body=body, - decorators=["guppy", "no_type_check"], - ) - - def build_function(self, func_info) -> Function | None: - """Build a function from pending function info.""" - - # Reset function-local state - self.refreshed_arrays = {} - self.refreshed_by_function = {} - self.conditional_fresh_vars = {} - self.array_remapping = {} # Reset array remapping for each function - # Reset parameter_unpacked_arrays for each function - self.parameter_unpacked_arrays = set() - # Reset explicitly_reset_qubits for each function to prevent cross-contamination - self.explicitly_reset_qubits = {} - - # Handle different formats of func_info - if len(func_info) == 3: - # New format from IR builder: (block, func_name, signature) - sample_block, func_name, _block_signature = func_info - elif len(func_info) == 4: - # Old format: (block_key, func_name, sample_block, block_name) - _block_key, func_name, sample_block, _block_name = func_info - else: - return None - - # Analyze dependencies to determine parameters - deps = self._analyze_block_dependencies(sample_block) - - # Build parameter list - params = [] - param_mapping = {} # Maps parameter names to original variable names - - # Check if we should use structs instead of individual arrays - struct_params = set() # Structs we've already added - vars_in_structs = set() # Variables that are part of structs - - # First pass: identify which variables are part of structs - for prefix, info in self.struct_info.items(): - vars_in_this_struct = [] - for var in info["var_names"].values(): - if var in deps["quantum"] or var in deps["classical"]: - vars_in_structs.add(var) - vars_in_this_struct.append(var) - - # If any variable from this struct is used, add the struct as a parameter - if vars_in_this_struct and prefix not in struct_params: - # Add struct parameter - struct_name = info["struct_name"] - param_type = struct_name - - # Check if this struct contains quantum resources - has_quantum = any(v in deps["quantum"] for v in vars_in_this_struct) - if has_quantum and self._block_consumes_quantum(sample_block): - param_type = f"{param_type} @owned" - - params.append((prefix, param_type)) - param_mapping[prefix] = prefix - struct_params.add(prefix) - - # Black Box Pattern: All functions that handle quantum arrays should use - # functional pattern. This maintains SLR's global array semantics at - # boundaries while using functional internals - # BUT: Only unpack if the IR analyzer determined it's necessary - # First, run the IR analyzer on this block to get unpacking plan - from pecos.slr.gen_codes.guppy.ir_analyzer import IRAnalyzer - - # Pre-analyze consumption to inform the IR analyzer about @owned parameters - consumed_params = set() - if hasattr(sample_block, "ops"): - # Check if this function has nested blocks - has_nested_blocks = False - for op in sample_block.ops: - if hasattr(op, "__class__"): - from pecos.slr import Block as SlrBlock - - try: - if issubclass(op.__class__, SlrBlock): - has_nested_blocks = True - break - except (TypeError, AttributeError): - # Not a class or doesn't have required attributes - pass - - # Analyze consumption - this will help determine @owned parameters - consumed_params = self._analyze_consumed_parameters(sample_block) - # Also analyze which arrays have subscript access - they also need @owned - subscripted_params = self._analyze_subscript_access(sample_block) - # Store for later use in @owned determination - self.subscripted_params = subscripted_params - else: - # No ops - initialize empty set - self.subscripted_params = set() - - analyzer = IRAnalyzer() - - # Pass information about expected @owned parameters to the analyzer - analyzer.expected_owned_params = consumed_params - analyzer.has_nested_blocks_with_owned = has_nested_blocks and bool( - consumed_params, - ) - - block_plan = analyzer.analyze_block(sample_block, self.context.variables) - - # Only unpack if there are arrays that need unpacking according to the analyzer - needs_unpacking = len(block_plan.arrays_to_unpack) > 0 - - # Check if this function consumes its quantum arrays - # For the functional pattern in Guppy, all functions that take quantum arrays - # and will return them need @owned annotation - self._block_consumes_quantum(sample_block) - - # If the function has quantum parameters, it should use @owned - # This is required for Guppy's linearity system when arrays are returned - bool(deps["quantum"] & deps["reads"]) - - # Add quantum parameters (skip those in structs UNLESS they're ancillas) - for var in sorted(deps["quantum"] & deps["reads"]): - # Check if this is an ancilla that was excluded from structs - is_excluded_ancilla = hasattr(self, "ancilla_qubits") and var in self.ancilla_qubits - - if var in vars_in_structs and not is_excluded_ancilla: - continue - param_name = var # Use the same name, no need for _param suffix - param_mapping[param_name] = var - # Determine type from context or default to qubit array - var_info = self.context.lookup_variable(var) - if var_info: - if var_info.is_unpacked: - # This is an unpacked array - need the original array type - param_type = f"array[quantum.qubit, {var_info.size}]" - else: - # Always use array type to maintain consistency with SLR semantics - param_type = f"array[quantum.qubit, {var_info.size}]" - else: - # Default assumption for quantum variables - param_type = "array[quantum.qubit, 7]" - - params.append((param_name, param_type)) - - # Add classical parameters (no ownership, but include written vars - # since arrays are mutable) - for var in sorted(deps["classical"] & (deps["reads"] | deps["writes"])): - if var in vars_in_structs: - continue - param_name = var # Use the same name, no need for _param suffix - param_mapping[param_name] = var - # Determine type from context - var_info = self.context.lookup_variable(var) - # Always use array type for consistency - param_type = f"array[bool, {var_info.size}]" if var_info else "array[bool, 32]" - params.append((param_name, param_type)) - - # Create function body - body = Block() - prev_block = self.current_block - prev_mapping = self.param_mapping if hasattr(self, "param_mapping") else {} - self.current_block = body - self.param_mapping = param_mapping - - # Create a variable remapping context for this function - # This maps original variable names to their parameter names - var_remapping = {} - for param_name, original_name in param_mapping.items(): - var_remapping[original_name] = param_name - - # Also handle unpacked variables - var_info = self.context.lookup_variable(original_name) - if var_info and var_info.is_unpacked: - # Map each unpacked element - for i, unpacked_name in enumerate(var_info.unpacked_names): - var_remapping[unpacked_name] = f"{param_name}[{i}]" - - # Store current function context - self.current_function_name = func_name - self.current_function_params = params - self.current_function_return_type = None # Will be set after we determine it - - # Clear fresh_return_vars tracking for this new function - # (to avoid bleeding from previous function builds) - self.fresh_return_vars = {} - - # Track if this function has @owned struct parameters - has_owned_struct_params = any( - "@owned" in param_type and param_name in self.struct_info for param_name, param_type in params - ) - self.function_info[func_name] = { - "has_owned_struct_params": has_owned_struct_params, - "params": params, - } - - # Store the remapping for use during conversion - prev_var_remapping = getattr(self, "var_remapping", {}) - self.var_remapping = var_remapping - - # Track unpacked variables (only if needed) - self.unpacked_vars = {} # Maps array_name -> [element_names] - self.replaced_qubits = {} # Maps array_name -> set of replaced indices - - # Initially add array unpacking for arrays that the analyzer determined need it - if needs_unpacking: - for param_name, param_type in params: - if "array[quantum.qubit," in param_type and param_name in block_plan.arrays_to_unpack: - # Extract array size - import re - - match = re.search(r"array\[quantum\.qubit, (\d+)\]", param_type) - if match: - size = int(match.group(1)) - # Generate unpacked variable names - element_names = [self._get_unique_var_name(param_name, i) for i in range(size)] - self.unpacked_vars[param_name] = element_names - - # Add unpacking statement to function body - unpacking_stmt = self._create_array_unpack_statement( - param_name, - element_names, - ) - body.statements.append(unpacking_stmt) - - # Additionally, check for ALL @owned arrays that need unpacking - # With the functional pattern, @owned arrays must be unpacked to avoid MoveOutOfSubscriptError - # UNLESS they're passed to nested blocks - for param_name, param_type in params: - if "@owned" in param_type and "array[quantum.qubit," in param_type and param_name not in self.unpacked_vars: - # Check if this function has any nested block calls - # If so, we can't unpack @owned arrays as we may need to pass them - # But this will cause MoveOutOfSubscriptError, so we need a different approach - has_nested_blocks = False - if hasattr(sample_block, "ops"): - for op in sample_block.ops: - # Check if this is a Block subclass - if hasattr(op, "__class__"): - from pecos.slr import Block as SlrBlock - - try: - if issubclass(op.__class__, SlrBlock): - has_nested_blocks = True - break - except (TypeError, AttributeError): - # Not a class or doesn't have required attributes - pass - - # @owned parameters MUST be unpacked regardless of analyzer decision - # This is required by Guppy's type system to avoid MoveOutOfSubscriptError - force_unpack = "@owned" in param_type - - # Check if the analyzer decided this array should be unpacked - # Even with nested blocks, @owned arrays need unpacking to access elements - if not force_unpack and param_name not in block_plan.arrays_to_unpack: - if has_nested_blocks: - body.statements.append( - Comment( - f"Skip unpacking {param_name} - function has nested blocks", - ), - ) - continue - - # This @owned array needs unpacking to avoid MoveOutOfSubscriptError - import re - - match = re.search(r"array\[quantum\.qubit, (\d+)\]", param_type) - if match: - size = int(match.group(1)) - # Generate unpacked variable names - element_names = [self._get_unique_var_name(param_name, i) for i in range(size)] - self.unpacked_vars[param_name] = element_names - - # Track that this was unpacked from a parameter (not a return value) - # Parameter-unpacked arrays should NOT be reconstructed for function calls - if not hasattr(self, "parameter_unpacked_arrays"): - self.parameter_unpacked_arrays = set() - self.parameter_unpacked_arrays.add(param_name) - - # Add comment explaining why we're unpacking - body.statements.append( - Comment( - f"Unpack @owned array {param_name} to avoid MoveOutOfSubscriptError", - ), - ) - - # Add unpacking statement to function body - unpacking_stmt = ArrayUnpack( - source=param_name, - targets=element_names, - ) - body.statements.append(unpacking_stmt) - - # Add struct unpacking for struct parameters - struct_field_vars = {} # Maps original var name to struct field path for @owned structs - struct_reconstruction = {} # Maps struct param name to list of field vars for reconstruction - - for param_name, param_type in params: - if "@owned" in param_type and param_name in self.struct_info: - # This is an @owned struct parameter - # For @owned structs, we must decompose them immediately to avoid AlreadyUsedError - # when accessing multiple fields - struct_info = self.struct_info[param_name] - - # Track that we have an owned struct - if not hasattr(self, "owned_structs"): - self.owned_structs = set() - self.owned_structs.add(param_name) - - # Decompose the @owned struct using the decompose function - # Use the struct name, not the parameter name (e.g., steane_decompose not c_decompose) - struct_name = struct_info["struct_name"].replace("_struct", "") - decompose_func_name = f"{struct_name}_decompose" - - # Create decomposition call - field_vars = [] - for suffix, field_type, field_size in sorted(struct_info["fields"]): - field_var = f"{param_name}_{suffix}" - field_vars.append(field_var) - - # Add comment explaining decomposition - body.statements.append( - Comment( - f"Decompose @owned struct {param_name} to avoid AlreadyUsedError", - ), - ) - - # Add decomposition statement: c_c, c_d, ... = steane_decompose(c) - class TupleAssignment(Statement): - def __init__(self, targets, value): - self.targets = targets - self.value = value - - def analyze(self, context): - self.value.analyze(context) - - def render(self, context): - target_str = ", ".join(self.targets) - value_str = self.value.render(context)[0] - return [f"{target_str} = {value_str}"] - - decompose_call = FunctionCall( - func_name=decompose_func_name, - args=[VariableRef(param_name)], - ) - - decomposition_stmt = TupleAssignment( - targets=field_vars, - value=decompose_call, - ) - body.statements.append(decomposition_stmt) - - # Map original variables to the decomposed field variables - for suffix, field_type, field_size in sorted(struct_info["fields"]): - original_var = struct_info["var_names"].get(suffix) - if original_var: - field_var = f"{param_name}_{suffix}" - # Map the original variable name to the decomposed variable - if not hasattr(self, "var_remapping"): - self.var_remapping = {} - self.var_remapping[original_var] = field_var - - # Track the field variables for reconstruction in return statements - struct_reconstruction[param_name] = field_vars - - # Track decomposed variables for field access - if not hasattr(self, "decomposed_vars"): - self.decomposed_vars = {} - field_mapping = {} - for suffix, field_type, field_size in sorted(struct_info["fields"]): - field_var = f"{param_name}_{suffix}" - field_mapping[suffix] = field_var - self.decomposed_vars[param_name] = field_mapping - - # Skip normal unpacking for @owned structs - continue - if param_name in self.struct_info: - # Non-owned struct parameter - can unpack normally - struct_info = self.struct_info[param_name] - field_vars = [] - - # Generate unpacking statement - use same order as struct - # definition (sorted by suffix) - unpack_targets = [] - for suffix, field_type, field_size in sorted(struct_info["fields"]): - field_var = f"{param_name}_{suffix}" - unpack_targets.append(field_var) - field_vars.append(field_var) - - # Map the original variable name to this unpacked field variable - original_var = struct_info["var_names"].get(suffix) - if original_var: - struct_field_vars[original_var] = field_var - # Also update var_remapping to use field access directly - self.var_remapping[original_var] = field_var - - # Create the unpacking statement: - # field1, field2, ... = struct.field1, struct.field2, ... - # In Guppy, we need to unpack the entire struct at once - - # use same order as struct definition - unpack_stmt = Assignment( - target=TupleExpression( - [VariableRef(var) for var in unpack_targets], - ), - value=TupleExpression( - [FieldAccess(VariableRef(param_name), field) for field, _, _ in sorted(struct_info["fields"])], - ), - ) - body.statements.append(unpack_stmt) - - # Store for reconstruction - struct_reconstruction[param_name] = field_vars - - # Store struct field mappings for use in variable references - self.struct_field_mapping = struct_field_vars - - # Pre-analyze what qubits will be consumed to determine return type - consumed_in_function = {} - self._track_consumed_qubits(sample_block, consumed_in_function) - - # Pre-determine if this function will return quantum arrays - # (needed for measurement replacement logic) - will_return_quantum = False - has_quantum_arrays = any("array[quantum.qubit," in ptype for name, ptype in params) - has_structs = any(name in self.struct_info for name, ptype in params) - - if has_quantum_arrays or has_structs: - # Check if any quantum arrays will be returned - for name, ptype in params: - if "array[quantum.qubit," in ptype: - # Check if this array is part of a struct - in_struct = False - for prefix, info in self.struct_info.items(): - if name in info["var_names"].values(): - in_struct = True - break - - # Check if this is an ancilla that was excluded from structs - is_excluded_ancilla = hasattr(self, "ancilla_qubits") and name in self.ancilla_qubits - - # Check if this array has any live qubits - if name in consumed_in_function: - # Some elements were consumed - check if any are still live - consumed_indices = consumed_in_function[name] - import re - - size_match = re.search( - r"array\[quantum\.qubit,\s*(\d+)\]", - ptype, - ) - array_size = int(size_match.group(1)) if size_match else 2 - total_indices = set(range(array_size)) - live_indices = total_indices - consumed_indices - include_array = bool( - live_indices, - ) # Only include if has live qubits - else: - # No consumption tracked for this array - assume it's live - include_array = not in_struct or is_excluded_ancilla - - if include_array: - will_return_quantum = True - break - - # Check if this is a procedural block based on resource flow - # If the block has live qubits that should be returned, it's not procedural - _consumed_qubits, live_qubits = self._analyze_quantum_resource_flow( - sample_block, - ) - has_live_qubits = bool(live_qubits) - is_procedural_block = not has_live_qubits - - # SMART DETECTION: Determine if this function should be procedural based on usage patterns - # Functions should be procedural if: - # 1. They don't need their quantum returns to be used afterward in the calling scope - # 2. They primarily do terminal operations (measurements, cleanup) - # 3. Making them procedural would avoid PlaceNotUsedError issues - - # HYBRID APPROACH: Use smart detection to determine optimal strategy - should_be_procedural = self._should_function_be_procedural( - func_name, - sample_block, - params, - has_live_qubits, - ) - - if should_be_procedural: - is_procedural_block = True - # Function determined to be procedural - - # If it appears to be procedural based on live qubits, double-check with signature - if is_procedural_block and hasattr(sample_block, "__init__"): - import inspect - - try: - sig = inspect.signature(sample_block.__class__.__init__) - return_annotation = sig.return_annotation - if return_annotation is None or return_annotation is type(None) or str(return_annotation) == "None": - is_procedural_block = True - else: - is_procedural_block = False # Has return annotation, not procedural - except (ValueError, TypeError, AttributeError): - # Default to procedural if can't inspect signature - # ValueError: signature cannot be determined - # TypeError: object is not callable - # AttributeError: missing expected attributes - is_procedural_block = True - - # Store whether this is a procedural block for measurement logic - self.current_function_is_procedural = is_procedural_block - - # Process params and add @owned annotations (now that we know if it's procedural) - # HYBRID OWNERSHIP: Smart @owned annotation based on function type and consumption - processed_params = [] - for param_name, param_type in params: - if "array[quantum.qubit," in param_type: - # Determine if this parameter should be @owned based on consumption analysis - should_be_owned = False - - if is_procedural_block: - # For procedural blocks, be selective with @owned - # Only use @owned if the parameter is truly consumed (measured) and not reused - # BUT also check if this parameter is passed to other functions that might expect @owned - # This is necessary for functions like prep_rus that pass parameters to prep_encoding_ft_zero - # For simplicity, if the block has nested blocks, make quantum params @owned - # If a procedural block calls other blocks, those blocks might need @owned params - should_be_owned = True if has_nested_blocks else param_name in consumed_params - else: - # For functional blocks that return quantum arrays, check if parameter is actually consumed - # In Guppy's linear type system: - # - @owned: parameter is consumed by the function - # - non-@owned: parameter is borrowed and must be returned - # IMPORTANT: In Guppy, subscripting an array (c_a[0]) marks it as used - # So ANY element access requires @owned annotation to avoid MoveOutOfSubscriptError - if param_name in consumed_in_function: - # ANY consumption requires @owned (not just full consumption) - # This is because subscripting marks the array as used - consumed_indices = consumed_in_function[param_name] - should_be_owned = len(consumed_indices) > 0 - elif hasattr(self, "subscripted_params") and param_name in self.subscripted_params: - # Array has subscript access (c_d[0]) which requires @owned - should_be_owned = True - else: - # Check if there's element access even without consumption - # (e.g., gates applied to elements) - # Arrays in arrays_to_unpack need @owned - should_be_owned = param_name in block_plan.arrays_to_unpack - if should_be_owned: - pass - else: - # Last resort: if parameter is used in the function at all, it likely needs @owned - # In Guppy, any use of an array parameter in a functional block requires @owned - # because the generated IR will likely subscript it - # Check if the parameter appears in deps (it's used in the function) - if param_name in deps["quantum"]: - should_be_owned = True - - if should_be_owned: - param_type = f"{param_type} @owned" - - processed_params.append((param_name, param_type)) - params = processed_params - - # HYBRID UNPACKING: After parameter processing, check for @owned arrays that need unpacking - # @owned arrays must be unpacked to avoid MoveOutOfSubscriptError when accessing elements - for param_name, param_type in params: - # Don't double-unpack - is_owned_qubit_array = "array[quantum.qubit," in param_type and "@owned" in param_type - if is_owned_qubit_array and param_name not in self.unpacked_vars: - # Adding @owned array unpacking - # Extract array size - import re - - match = re.search(r"array\[quantum\.qubit, (\d+)\]", param_type) - if match: - size = int(match.group(1)) - # Generate unpacked variable names - element_names = [self._get_unique_var_name(param_name, i) for i in range(size)] - self.unpacked_vars[param_name] = element_names - - # Track that this was unpacked from a parameter (not a return value) - # Parameter-unpacked arrays should NOT be reconstructed for function calls - self.parameter_unpacked_arrays.add(param_name) - - # Add unpacking statement to function body - unpacking_stmt = self._create_array_unpack_statement( - param_name, - element_names, - ) - body.statements.append(unpacking_stmt) - - # Store whether this function returns quantum arrays - self.current_function_returns_quantum = will_return_quantum - - # Pre-extract conditions that might be needed in loops with @owned structs - # This must happen BEFORE any operations that might consume the structs - if hasattr(sample_block, "ops") and self._function_has_owned_struct_params( - params, - ): - extracted_conditions = self._pre_extract_loop_conditions(sample_block, body) - - # Track extracted conditions for later use - if extracted_conditions: - if not hasattr(self, "pre_extracted_conditions"): - self.pre_extracted_conditions = {} - self.pre_extracted_conditions.update(extracted_conditions) - - # Now convert operations (can use will_return_quantum flag) - if hasattr(sample_block, "ops"): - # Store block reference for look-ahead in operation conversion - # This enables measurement+Prep pattern detection in _convert_operation - self.current_block_ops = sample_block.ops - for op_index, op in enumerate(sample_block.ops): - # Store current operation index for look-ahead - self.current_op_index = op_index - stmt = self._convert_operation(op) - if stmt: - body.statements.append(stmt) - # Clear after processing - self.current_block_ops = None - self.current_op_index = None - - # Fix linearity issues: add fresh qubit allocations after consuming operations - self._fix_post_consuming_linearity_issues(body) - - # Fix unused fresh variables in conditional execution paths - self._fix_unused_fresh_variables(body) - - # Save the current variable remapping (includes changes from Prep operations) - # BEFORE restoring previous mapping, as we need it for return statement generation - self.function_var_remapping = self.variable_remapping.copy() if hasattr(self, "variable_remapping") else {} - - # Restore previous remapping - self.var_remapping = prev_var_remapping - self.current_block = prev_block - self.param_mapping = prev_mapping - - # Now calculate the actual detailed return type and generate return statements - return_type = "None" - - # Black Box Pattern: functions that handle quantum arrays return modified arrays - # BUT: if function consumes arrays (@owned), don't return them - # Check if we have quantum arrays or structs to return (regardless of unpacking) - has_quantum_arrays = any("array[quantum.qubit," in ptype for name, ptype in params) - has_structs = any(name in self.struct_info for name, ptype in params) - - # For procedural blocks, don't generate return statements - if not is_procedural_block and (has_quantum_arrays or has_structs): - # Array/struct return pattern: functions return reconstructed arrays or structs - quantum_returns = [] - - # Add structs first - even @owned structs can be returned if they're reconstructed - for name, ptype in params: - if name in self.struct_info: - # Remove @owned annotation from type for return type - return_type = ptype.replace(" @owned", "") - quantum_returns.append((name, return_type)) - - # Then add individual arrays not in structs (including ancillas) - for name, ptype in params: - if "array[quantum.qubit," in ptype: - # Check if this array is part of a struct - in_struct = False - is_excluded_ancilla = False - - for prefix, info in self.struct_info.items(): - if name in info["var_names"].values(): - in_struct = True - break - - # Check if this is an ancilla that was excluded from structs - if hasattr(self, "ancilla_qubits") and name in self.ancilla_qubits: - is_excluded_ancilla = True - - # Only include arrays that have live (unconsumed) qubits - # Check if this array has any unconsumed elements - if name in consumed_in_function: - # Some elements were consumed - check if any are still live - consumed_indices = consumed_in_function[name] - # Extract size from parameter type - import re - - size_match = re.search( - r"array\[quantum\.qubit,\s*(\d+)\]", - ptype, - ) - array_size = int(size_match.group(1)) if size_match else 2 - total_indices = set(range(array_size)) - - # Live indices = unconsumed OR explicitly reset - # Explicitly reset qubits are consumed by measurement but recreated by Prep - explicitly_reset_indices = set() - if hasattr(self, "explicitly_reset_qubits") and name in self.explicitly_reset_qubits: - explicitly_reset_indices = self.explicitly_reset_qubits[name] - - live_indices = (total_indices - consumed_indices) | explicitly_reset_indices - include_array = bool( - live_indices, - ) # Only include if has live qubits (unconsumed OR explicitly reset) - else: - # No consumption tracked for this array - assume it's live - include_array = not in_struct or is_excluded_ancilla - - if include_array: - # PRIORITY 1: Check if this array was refreshed by a function call - # If so, use the called function's return type instead of consumption analysis - if hasattr(self, "refreshed_arrays") and name in self.refreshed_arrays: - self.refreshed_arrays[name] - # Find which function call produced this fresh variable - # by looking at the refreshed_by_function mapping - if hasattr(self, "refreshed_by_function") and name in self.refreshed_by_function: - func_info = self.refreshed_by_function[name] - # Extract function name from the dict (or handle legacy string format) - called_func_name = ( - func_info["function"] - if isinstance(func_info, dict) - else func_info # Legacy string format - ) - # Look up that function's return type - if called_func_name in self.function_return_types: - called_func_return = self.function_return_types[called_func_name] - # If it returns a tuple, extract the type for this array - if called_func_return.startswith("tuple["): - # Parse tuple to find the type for this array - import re - - tuple_match = re.match( - r"tuple\[(.*)\]", - called_func_return, - ) - if tuple_match: - return_types_str = tuple_match.group(1) - # Split by comma but handle nested brackets - types_list = [] - bracket_depth = 0 - current_type = "" - for char in return_types_str: - if char == "[": - bracket_depth += 1 - current_type += char - elif char == "]": - bracket_depth -= 1 - current_type += char - elif char == "," and bracket_depth == 0: - types_list.append( - current_type.strip(), - ) - current_type = "" - else: - current_type += char - if current_type: - types_list.append(current_type.strip()) - - # Find which position this array is in the function's parameters - quantum_param_names = [ - n for n, pt in params if "array[quantum.qubit," in pt - ] - if name in quantum_param_names: - param_idx = quantum_param_names.index( - name, - ) - if param_idx < len(types_list): - # Use the return type from the called function - new_type = types_list[param_idx] - quantum_returns.append( - (name, new_type), - ) - continue # Skip consumption analysis - else: - # Single return - use it directly - quantum_returns.append( - (name, called_func_return), - ) - continue # Skip consumption analysis - - # PRIORITY 2: Use consumption analysis if array wasn't refreshed by a function - # Check if any elements remain unconsumed for ALL arrays - if name in consumed_in_function: - # Extract array size from type - import re - - match = re.search(r"array\[quantum\.qubit, (\d+)\]", ptype) - if match: - original_size = int(match.group(1)) - consumed_indices = consumed_in_function[name] - - # Check if any consumed qubits were replaced - if hasattr(self, "replaced_qubits") and name in self.replaced_qubits: - self.replaced_qubits[name] - - # Check if this parameter was fully consumed (all elements measured) - # BUT: if consumed qubits were explicitly reset, they should be returned - fully_consumed = len(consumed_indices) == original_size - - # Check if any consumed qubits were explicitly reset - explicitly_reset_indices = set() - if hasattr(self, "explicitly_reset_qubits") and name in self.explicitly_reset_qubits: - explicitly_reset_indices = self.explicitly_reset_qubits[name] - - # If fully consumed BUT some were explicitly reset, we should return those - if fully_consumed and not explicitly_reset_indices: - # All qubits were measured and none were explicitly reset - don't return - pass # Don't add to quantum_returns - else: - # Not fully consumed - return the array - # Determine how many qubits will actually be returned - # This depends on: - # 1. Whether this will be a single or multiple return - # 2. Whether consumed qubits were replaced - - # Count how many quantum arrays will likely be returned - # (This is a heuristic - we're building quantum_returns as we go) - num_quantum_params = 0 - for n, pt in params: - if "array[quantum.qubit," in pt: - # Check if this array is part of a struct - in_struct = False - if isinstance(self.struct_info, dict) and n in self.struct_info.values(): - in_struct = True - if not in_struct: - num_quantum_params += 1 - - # For both single and multiple returns with partial consumption: - # Return unconsumed + explicitly reset elements - # Automatic replacements (for linearity) are not returned - # Matches return statement generation at lines 1424-1465 - - # Calculate how many elements to return - explicitly_reset_indices = set() - if ( - hasattr(self, "explicitly_reset_qubits") - and name in self.explicitly_reset_qubits - ): - explicitly_reset_indices = self.explicitly_reset_qubits[name] - - # Count elements that are either unconsumed OR explicitly reset - elements_to_return_count = 0 - for i in range(original_size): - if i not in consumed_indices or i in explicitly_reset_indices: - elements_to_return_count += 1 - - remaining_count = elements_to_return_count - - if remaining_count > 0: - # Some qubits remain - return array with correct size - if remaining_count < original_size: - # Partial consumption - return array with reduced size - new_type = f"array[quantum.qubit, {remaining_count}]" - else: - # No consumption - return original type - new_type = ptype.replace(" @owned", "") - quantum_returns.append((name, new_type)) - else: - # No consumption tracked - return full array - # Remove @owned annotation from return type - return_type = ptype.replace(" @owned", "") - quantum_returns.append((name, return_type)) - - if quantum_returns: - # Add return statements - if len(quantum_returns) == 1: - name, ptype = quantum_returns[0] - - # Check if this is a partial return - if name in consumed_in_function and "array[quantum.qubit," in ptype: - # Need to return only unconsumed elements - import re - - match = re.search(r"array\[quantum\.qubit, (\d+)\]", ptype) - if match: - int(match.group(1)) - original_match = re.search( - r"array\[quantum\.qubit, (\d+)\]", - next(pt for n, pt in params if n == name), - ) - if original_match: - original_size = int(original_match.group(1)) - consumed_indices = consumed_in_function[name] - - # Build array with unconsumed + explicitly reset elements - # - # DESIGN DECISION: Return unconsumed and explicitly reset elements - # - Unconsumed: Elements never measured/consumed - # - Explicitly reset: Elements reset via Prep operation (quantum.reset) - # - Automatic replacements: Created for linearity, NOT returned - # - # This distinguishes: - # 1. Explicit Prep(qubit) - semantic reset operation → RETURN - # 2. Automatic post-measurement replacement → DON'T RETURN - # - # Examples: - # - Steane verification: Prep(ancilla) → explicit reset → included - # - Partial consumption: Measure(q[0]) → automatic replacement → excluded - - # Determine which consumed indices should be returned - # (i.e., those that were explicitly reset) - explicitly_reset_indices = set() - if hasattr(self, "explicitly_reset_qubits") and name in self.explicitly_reset_qubits: - explicitly_reset_indices = self.explicitly_reset_qubits[name] - - elements_to_return = [] - for i in range(original_size): - # Include if: (1) not consumed, OR (2) explicitly reset - if i not in consumed_indices or i in explicitly_reset_indices: - if name in self.unpacked_vars: - # Use unpacked element name directly using original index - # NOTE: index_mapping maps original index → - # compact position for function CALLS - # But for RETURNS, we still have all original - # unpacked elements available - # So we use the original index 'i' directly! - element_name = self.unpacked_vars[name][i] - # Apply variable remapping if element was - # reassigned (e.g., Prep after Measure) - if hasattr(self, "function_var_remapping"): - element_name = self.function_var_remapping.get( - element_name, - element_name, - ) - elements_to_return.append( - VariableRef(element_name), - ) - else: - # Use array indexing - elements_to_return.append( - ArrayAccess(array_name=name, index=i), - ) - - # Create array construction - array_expr = FunctionCall( - func_name="array", - args=elements_to_return, - ) - body.statements.append( - ReturnStatement(value=array_expr), - ) - elif name in self.unpacked_vars: - # Array was unpacked - check for partial consumption - # CRITICAL: Also check consumed_in_function here! - # The earlier check (line 1548) might have failed due to return type detection issues - if name in consumed_in_function: - # Partial consumption - return unconsumed + explicitly reset elements - consumed_indices = consumed_in_function[name] - element_names = self.unpacked_vars[name] - - # Get explicitly reset indices - explicitly_reset_indices = set() - if hasattr(self, "explicitly_reset_qubits") and name in self.explicitly_reset_qubits: - explicitly_reset_indices = self.explicitly_reset_qubits[name] - - # Filter: include unconsumed OR explicitly reset - elements_to_return = [] - for i, elem_name in enumerate(element_names): - if i not in consumed_indices or i in explicitly_reset_indices: - # Apply variable remapping if element was reassigned (e.g., Prep after Measure) - if hasattr(self, "function_var_remapping"): - elem_name = self.function_var_remapping.get( - elem_name, - elem_name, - ) - elements_to_return.append(VariableRef(elem_name)) - array_construction = FunctionCall( - func_name="array", - args=elements_to_return, - ) - body.statements.append( - ReturnStatement(value=array_construction), - ) - elif hasattr(self, "refreshed_arrays") and name in self.refreshed_arrays: - # Array was unpacked AND refreshed - return the fresh version - fresh_name = self.refreshed_arrays[name] - body.statements.append( - ReturnStatement(value=VariableRef(fresh_name)), - ) - else: - # Array was unpacked - must reconstruct from elements for linearity - # Even if no elements were consumed, the original array is "moved" by unpacking - element_names = self.unpacked_vars[name] - array_construction = self._create_array_reconstruction( - element_names, - ) - body.statements.append( - ReturnStatement(value=array_construction), - ) - elif name in struct_reconstruction: - # Struct was decomposed - but check if it was also refreshed by function calls - if hasattr(self, "refreshed_arrays") and name in self.refreshed_arrays: - # Struct was refreshed - return the fresh version directly - fresh_name = self.refreshed_arrays[name] - body.statements.append( - ReturnStatement(value=VariableRef(fresh_name)), - ) - else: - # Struct was decomposed - reconstruct it from field variables - struct_info = self.struct_info[name] - - # Check if this is an @owned struct that was decomposed - is_owned_struct = hasattr(self, "owned_structs") and name in self.owned_structs - - # For @owned structs, always reconstruct from decomposed variables - # For regular structs, check if the unpacked variables are still valid - if is_owned_struct: - should_reconstruct = True - else: - # Check if the unpacked variables are still valid - # They're only valid if we haven't passed the struct - # to any @owned functions - should_reconstruct = all( - struct_info["var_names"].get(suffix) in self.var_remapping - for suffix, _, _ in struct_info["fields"] - ) - - if should_reconstruct: - # Create struct constructor call - use same order - # as struct definition (sorted by suffix) - constructor_args = [] - all_vars_available = True - - for suffix, field_type, field_size in sorted( - struct_info["fields"], - ): - field_var = f"{name}_{suffix}" - - # Check if we have a fresh version of this field variable - if hasattr(self, "refreshed_arrays") and field_var in self.refreshed_arrays: - field_var = self.refreshed_arrays[field_var] - elif hasattr(self, "var_remapping") and field_var in self.var_remapping: - field_var = self.var_remapping[field_var] - else: - # Check if the variable was consumed in operations - if hasattr(self, "consumed_vars") and field_var in self.consumed_vars: - all_vars_available = False - break - - constructor_args.append(VariableRef(field_var)) - - if all_vars_available and constructor_args: - struct_constructor = FunctionCall( - func_name=struct_info["struct_name"], - args=constructor_args, - ) - body.statements.append( - ReturnStatement(value=struct_constructor), - ) - else: - # Variables were consumed - cannot reconstruct - # Return void or handle appropriately for @owned structs - pass - else: - # Unpacked variables are no longer valid - return the struct directly - body.statements.append( - ReturnStatement(value=VariableRef(name)), - ) - else: - # Check if this variable was refreshed due to being borrowed - # (e.g., c_d -> c_d_returned) - if hasattr(self, "refreshed_arrays") and name in self.refreshed_arrays: - # Use the refreshed name for the return - return_name = self.refreshed_arrays[name] - body.statements.append( - ReturnStatement(value=VariableRef(return_name)), - ) - elif hasattr(self, "owned_structs") and name in self.owned_structs and name in self.struct_info: - # @owned struct needs reconstruction from decomposed variables - struct_info = self.struct_info[name] - - # Create struct constructor call - constructor_args = [] - all_vars_available = True - - for suffix, field_type, field_size in sorted( - struct_info["fields"], - ): - field_var = f"{name}_{suffix}" - - # Check if we have a fresh version of this field variable - if hasattr(self, "refreshed_arrays") and field_var in self.refreshed_arrays: - field_var = self.refreshed_arrays[field_var] - elif hasattr(self, "var_remapping") and field_var in self.var_remapping: - field_var = self.var_remapping[field_var] - else: - # Check if the variable was consumed in operations - if hasattr(self, "consumed_vars") and field_var in self.consumed_vars: - all_vars_available = False - break - - constructor_args.append(VariableRef(field_var)) - - if all_vars_available and constructor_args: - struct_constructor = FunctionCall( - func_name=struct_info["struct_name"], - args=constructor_args, - ) - body.statements.append( - ReturnStatement(value=struct_constructor), - ) - else: - # Check if this variable has been refreshed by function calls - var_to_return = name - if hasattr(self, "refreshed_arrays") and name in self.refreshed_arrays: - var_to_return = self.refreshed_arrays[name] - body.statements.append( - ReturnStatement(value=VariableRef(var_to_return)), - ) - - # Set return type - return_type = ptype # Use the potentially modified type - else: - # Multiple arrays/structs - return tuple - return_exprs = [] - return_types = [] - for name, ptype in quantum_returns: - if name in self.unpacked_vars: - # Array was unpacked - check if it was also refreshed by function calls - if hasattr(self, "refreshed_arrays") and name in self.refreshed_arrays: - # Array was refreshed after unpacking - return the fresh version - fresh_name = self.refreshed_arrays[name] - return_exprs.append(VariableRef(fresh_name)) - else: - # Array was unpacked - check if elements are still available for reconstruction - element_names = self.unpacked_vars[name] - - # For arrays with size 0 in return type, create empty arrays instead of reconstructing - if "array[quantum.qubit, 0]" in ptype: - # All elements consumed - create empty quantum array using generator expression - # Create custom expression for: array(quantum.qubit() for _ in range(0)) - - class EmptyArrayExpression(Expression): - def analyze(self, _context): - pass # No analysis needed for empty array - - def render(self, _context): - return [ - "array(quantum.qubit() for _ in range(0))", - ] - - empty_array = EmptyArrayExpression() - return_exprs.append(empty_array) - else: - # Check if this array has partial consumption - if name in consumed_in_function: - consumed_indices = consumed_in_function[name] - - # Build array with unconsumed + explicitly reset elements - # See single return path (lines 1424-1465) for detailed rationale - - # Get explicitly reset indices - explicitly_reset_indices = set() - if ( - hasattr(self, "explicitly_reset_qubits") - and name in self.explicitly_reset_qubits - ): - explicitly_reset_indices = self.explicitly_reset_qubits[name] - - elements_to_return = [] - for i in range(len(element_names)): - # Include if: (1) not consumed, OR (2) explicitly reset - if i not in consumed_indices or i in explicitly_reset_indices: - element_name = element_names[i] - # Apply variable remapping if element was reassigned - # Use function_var_remapping which includes Prep changes - if hasattr( - self, - "function_var_remapping", - ): - element_name = self.function_var_remapping.get( - element_name, - element_name, - ) - elements_to_return.append( - VariableRef(element_name), - ) - - if elements_to_return: - # Create array from unconsumed elements - array_construction = FunctionCall( - func_name="array", - args=elements_to_return, - ) - return_exprs.append(array_construction) - else: - # All elements consumed - use empty array - class EmptyArrayExpression(Expression): - def analyze(self, _context): - pass - - def render(self, _context): - return [ - "array(quantum.qubit() for _ in range(0))", - ] - - return_exprs.append(EmptyArrayExpression()) - else: - # No consumption or not tracked - standard reconstruction from all elements - array_construction = self._create_array_reconstruction( - element_names, - ) - return_exprs.append(array_construction) - elif name in struct_reconstruction: - # Struct was decomposed - but check if it was also refreshed by function calls - if hasattr(self, "refreshed_arrays") and name in self.refreshed_arrays: - # Struct was refreshed - return the fresh version directly - fresh_name = self.refreshed_arrays[name] - return_exprs.append(VariableRef(fresh_name)) - else: - # Struct was decomposed - check if we can still use - # the decomposed variables - struct_info = self.struct_info[name] - - # Check if this is an @owned struct that was decomposed - is_owned_struct = hasattr(self, "owned_structs") and name in self.owned_structs - - # For @owned structs, always reconstruct from decomposed variables - # For regular structs, check if the unpacked variables are still valid - if is_owned_struct: - unpacked_vars_valid = True - else: - # Check if the unpacked variables are still valid - unpacked_vars_valid = all( - struct_info["var_names"].get(suffix) in self.var_remapping - for suffix, _, _ in struct_info["fields"] - ) - - if unpacked_vars_valid: - # Create struct constructor call - use same order - # as struct definition (sorted by suffix) - constructor_args = [] - all_vars_available = True - - for suffix, field_type, field_size in sorted( - struct_info["fields"], - ): - field_var = f"{name}_{suffix}" - - # Check if we have a fresh version of this field variable - if hasattr(self, "refreshed_arrays") and field_var in self.refreshed_arrays: - field_var = self.refreshed_arrays[field_var] - elif hasattr(self, "var_remapping") and field_var in self.var_remapping: - field_var = self.var_remapping[field_var] - else: - # Check if the variable was consumed in operations - if hasattr(self, "consumed_vars") and field_var in self.consumed_vars: - all_vars_available = False - break - - constructor_args.append(VariableRef(field_var)) - - if all_vars_available and constructor_args: - struct_constructor = FunctionCall( - func_name=struct_info["struct_name"], - args=constructor_args, - ) - return_exprs.append(struct_constructor) - else: - # Variables were consumed - handle appropriately - var_to_return = name - if hasattr(self, "refreshed_arrays") and name in self.refreshed_arrays: - var_to_return = self.refreshed_arrays[name] - return_exprs.append(VariableRef(var_to_return)) - else: - # Unpacked variables are no longer valid - - # return the struct directly - # Check if this variable has been refreshed by function calls - var_to_return = name - if hasattr(self, "refreshed_arrays") and name in self.refreshed_arrays: - var_to_return = self.refreshed_arrays[name] - return_exprs.append(VariableRef(var_to_return)) - else: - # Array/struct was not unpacked - return it directly - # Check if this is an @owned struct that needs reconstruction - if ( - hasattr(self, "owned_structs") - and name in self.owned_structs - and name in self.struct_info - ): - # @owned struct needs reconstruction from decomposed variables - struct_info = self.struct_info[name] - - # Create struct constructor call - constructor_args = [] - for suffix, field_type, field_size in sorted( - struct_info["fields"], - ): - field_var = f"{name}_{suffix}" - - # Check if we have a fresh version of this field variable - if hasattr(self, "refreshed_arrays") and field_var in self.refreshed_arrays: - field_var = self.refreshed_arrays[field_var] - elif hasattr(self, "var_remapping") and field_var in self.var_remapping: - field_var = self.var_remapping[field_var] - - constructor_args.append(VariableRef(field_var)) - - struct_constructor = FunctionCall( - func_name=struct_info["struct_name"], - args=constructor_args, - ) - return_exprs.append(struct_constructor) - else: - # Check if this variable has been refreshed by function calls - var_to_return = name - if hasattr(self, "refreshed_arrays") and name in self.refreshed_arrays: - var_to_return = self.refreshed_arrays[name] - return_exprs.append(VariableRef(var_to_return)) - - # Add type to return types - return_types.append(ptype) - - if return_exprs: - body.statements.append( - ReturnStatement( - value=TupleExpression(elements=return_exprs), - ), - ) - return_type = f"tuple[{', '.join(return_types)}]" - - # For procedural blocks, override return type to None even if they return arrays internally - if is_procedural_block: - return_type = "None" - # Also remove any return statements from the body since this is procedural - body.statements = [stmt for stmt in body.statements if not isinstance(stmt, ReturnStatement)] - - # Add cleanup for unused quantum arrays that might have been created - # by function calls but not consumed (e.g., fresh variables) - # GENERAL APPROACH: Check for any fresh_return_vars that were created - if hasattr(self, "fresh_return_vars") and self.fresh_return_vars: - # Add discard for each fresh variable that wasn't consumed - # (consumed variables are tracked in consumed_arrays or consumed_resources) - for fresh_name, info in self.fresh_return_vars.items(): - # Check if this fresh variable was consumed - was_consumed = False - if hasattr(self, "consumed_arrays"): - was_consumed = fresh_name in self.consumed_arrays - if not was_consumed and hasattr(self, "consumed_resources"): - was_consumed = fresh_name in self.consumed_resources - - # If the fresh array was unpacked into element vars, the - # array itself was moved by the unpack -- discard_array - # would error. Element-level cleanup is handled separately - # (or the elements were consumed by gates/measurements). - # The unpacked-state tracker keys by the *original* SLR - # symbol, so we look up via the original; the fresh name - # itself doesn't appear in unpacked_vars. - original_name = info.get("original") - if ( - original_name - and self.var_state.is_unpacked(original_name) - and hasattr(self, "refreshed_arrays") - and self.refreshed_arrays.get(original_name) == fresh_name - ): - was_consumed = True - - if not was_consumed and info.get("is_quantum_array"): - # Add discard statement - discard_stmt = FunctionCall( - func_name="quantum.discard_array", - args=[VariableRef(fresh_name)], - ) - - # Wrap in expression statement - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - body.statements.append(Comment(f"Discard unused {fresh_name}")) - body.statements.append(ExpressionStatement(discard_stmt)) - - # Clear tracking for next function - self.fresh_return_vars = {} - - # Store the return type for use in other parts of the code - self.current_function_return_type = return_type - # Store in function return types registry for later lookup - self.function_return_types[func_name] = return_type - - return Function( - name=func_name, - params=params, - return_type=return_type, - body=body, - decorators=["guppy", "no_type_check"], - ) - - def _add_variable_declaration(self, var, block=None) -> None: - """Add variable declaration to current block.""" - var_type = type(var).__name__ - var_name = var.sym - - # Check for renaming - if var_name in self.plan.renamed_variables: - var_name = self.plan.renamed_variables[var_name] - - if var_type == "QReg": - # Get size for all cases - size = var.size - - # Check allocation recommendation for this array - recommendation = self.allocation_recommendations.get(var.sym, {}) - - # Get resource plan from unified analysis if available - resource_plan = None - if self.unified_analysis: - resource_plan = self.unified_analysis.get_plan(var.sym) - - # Check if this array needs unpacking (selective measurements) - needs_unpacking = var.sym in self.plan.arrays_to_unpack - - # Check if this array is used in full array operations - needs_full_array = self._array_needs_full_allocation(var.sym, block) - - # Check if this should be dynamically allocated based on usage patterns - # But only if it doesn't need unpacking for selective measurements - # AND not used in full array ops - # AND not a function parameter in current function - # AND the unified resource plan agrees with dynamic allocation - is_function_parameter = hasattr(self, "current_function_params") and any( - param_name == var.sym for param_name, _ in self.current_function_params - ) - - # Use the unified resource plan if available, otherwise fall back to recommendation - should_use_dynamic = False - if resource_plan: - # Resource plan from unified analysis (authoritative) - should_use_dynamic = resource_plan.uses_dynamic_allocation - else: - # Fall back to recommendation - should_use_dynamic = recommendation.get("allocation") == "dynamic" - - if should_use_dynamic and not needs_unpacking and not needs_full_array and not is_function_parameter: - # Check if this ancilla array is used as a function parameter - # If so, we need to pre-allocate it despite being an ancilla - is_function_param = False - if hasattr(self, "ancilla_qubits") and var_name in self.ancilla_qubits: - # This is an ancilla that was excluded from structs - # It will be passed as a parameter to functions, so pre-allocate it - is_function_param = True - - if is_function_param: - # For ancilla qubits, create individual qubits instead of arrays - # This avoids @owned array passing issues that cause linearity violations - self.current_block.statements.append( - Comment( - f"Create individual ancilla qubits for {var_name} (avoids @owned array issues)", - ), - ) - - # Create individual qubits: c_a_0, c_a_1, c_a_2 instead of array c_a - for i in range(size): - qubit_name = f"{var_name}_{i}" - init_expr = FunctionCall(func_name="quantum.qubit", args=[]) - assignment = Assignment( - target=VariableRef(qubit_name), - value=init_expr, - ) - self.current_block.statements.append(assignment) - - # Mark this variable as having been decomposed into individual qubits - if not hasattr(self, "decomposed_ancilla_arrays"): - self.decomposed_ancilla_arrays = {} - self.decomposed_ancilla_arrays[var_name] = [f"{var_name}_{i}" for i in range(size)] - - # Add a function to reconstruct the array when needed for function calls - # This creates: c_a = array(c_a_0, c_a_1, c_a_2) - self.current_block.statements.append( - Comment(f"# Reconstruct {var_name} array for function calls"), - ) - array_construction_args = [VariableRef(f"{var_name}_{i}") for i in range(size)] - reconstruct_expr = FunctionCall( - func_name="array", - args=array_construction_args, - ) - reconstruct_assignment = Assignment( - target=VariableRef(var_name), - value=reconstruct_expr, - ) - self.current_block.statements.append(reconstruct_assignment) - - # Track that this array has been reconstructed - use the variable directly, not individual qubits - if not hasattr(self, "reconstructed_arrays"): - self.reconstructed_arrays = set() - self.reconstructed_arrays.add(var_name) - else: - # For other ancillas, don't pre-allocate array - reason = recommendation.get("reason", "ancilla pattern") - # Before marking for dynamic allocation, check if this variable - # is used as a function argument in the current block - is_function_arg = self._is_variable_used_as_function_arg( - var.sym, - block, - ) - - if is_function_arg: - # For ancilla qubits used as function arguments, create individual qubits - # This avoids @owned array passing issues - comment_text = ( - f"Create individual ancilla qubits for {var_name} " - f"(function argument, avoids @owned array issues)" - ) - self.current_block.statements.append( - Comment(comment_text), - ) - - # Create individual qubits: c_a_0, c_a_1, c_a_2 instead of array c_a - for i in range(size): - qubit_name = f"{var_name}_{i}" - init_expr = FunctionCall(func_name="quantum.qubit", args=[]) - assignment = Assignment( - target=VariableRef(qubit_name), - value=init_expr, - ) - self.current_block.statements.append(assignment) - - # Mark this variable as having been decomposed into individual qubits - if not hasattr(self, "decomposed_ancilla_arrays"): - self.decomposed_ancilla_arrays = {} - self.decomposed_ancilla_arrays[var_name] = [f"{var_name}_{i}" for i in range(size)] - else: - # Normal dynamic allocation - self.current_block.statements.append( - Comment( - f"# {var_name} will be allocated dynamically ({reason})", - ), - ) - # Track that this is dynamically allocated - if not hasattr(self, "dynamic_allocations"): - self.dynamic_allocations = set() - self.dynamic_allocations.add(var.sym) - elif resource_plan and resource_plan.uses_dynamic_allocation: - # Check if all elements are local (full dynamic allocation) - if len(resource_plan.elements_to_allocate_locally) == size: - # Don't pre-allocate - all will be allocated when first used - self.current_block.statements.append( - Comment(f"Qubits from {var_name} will be allocated locally"), - ) - # Track that this is dynamically allocated - if not hasattr(self, "dynamic_allocations"): - self.dynamic_allocations = set() - self.dynamic_allocations.add(var.sym) - else: - # Mixed strategy - pre-allocate some, allocate others locally - # But only if the array doesn't need unpacking - if needs_unpacking: - # Can't use mixed allocation with unpacking - fall back to full pre-allocation - init_expr = FunctionCall( - func_name="array", - args=[ - FunctionCall( - func_name="quantum.qubit() for _ in range", - args=[Literal(size)], - ), - ], - ) - assignment = Assignment( - target=VariableRef(var_name), - value=init_expr, - ) - self.current_block.statements.append(assignment) - self.current_block.statements.append( - Comment( - f"Note: Full pre-allocation used because {var_name} needs unpacking", - ), - ) - elif size - len(resource_plan.elements_to_allocate_locally) > 0: - pre_alloc_size = size - len( - resource_plan.elements_to_allocate_locally, - ) - init_expr = FunctionCall( - func_name="array", - args=[ - FunctionCall( - func_name="quantum.qubit() for _ in range", - args=[Literal(pre_alloc_size)], - ), - ], - ) - assignment = Assignment( - target=VariableRef(var_name), - value=init_expr, - ) - self.current_block.statements.append(assignment) - - self.current_block.statements.append( - Comment( - f"Elements {sorted(resource_plan.elements_to_allocate_locally)} of " - f"{var_name} will be allocated locally", - ), - ) - else: - # Check if this is an ancilla array that should be decomposed - if hasattr(self, "ancilla_qubits") and var_name in self.ancilla_qubits: - # Decompose ancilla arrays into individual qubits to avoid @owned linearity issues - self.current_block.statements.append( - Comment( - f"Create individual ancilla qubits for {var_name} (avoids @owned array linearity issues)", - ), - ) - - # Create individual qubits: c_a_0, c_a_1, c_a_2 instead of array c_a - for i in range(size): - qubit_name = f"{var_name}_{i}" - init_expr = FunctionCall(func_name="quantum.qubit", args=[]) - assignment = Assignment( - target=VariableRef(qubit_name), - value=init_expr, - ) - self.current_block.statements.append(assignment) - - # Mark this variable as having been decomposed into individual qubits - if not hasattr(self, "decomposed_ancilla_arrays"): - self.decomposed_ancilla_arrays = {} - self.decomposed_ancilla_arrays[var_name] = [f"{var_name}_{i}" for i in range(size)] - - # Add a function to reconstruct the array when needed for function calls - # This creates: c_a = array(c_a_0, c_a_1, c_a_2) - self.current_block.statements.append( - Comment(f"# Reconstruct {var_name} array for function calls"), - ) - array_construction_args = [VariableRef(f"{var_name}_{i}") for i in range(size)] - reconstruct_expr = FunctionCall( - func_name="array", - args=array_construction_args, - ) - reconstruct_assignment = Assignment( - target=VariableRef(var_name), - value=reconstruct_expr, - ) - self.current_block.statements.append(reconstruct_assignment) - - # Track that this array has been reconstructed - use the variable directly, not individual qubits - if not hasattr(self, "reconstructed_arrays"): - self.reconstructed_arrays = set() - self.reconstructed_arrays.add(var_name) - else: - # Check if this ancilla array was already decomposed into individual qubits - if hasattr(self, "decomposed_ancilla_arrays") and var_name in self.decomposed_ancilla_arrays: - # Skip array creation - individual qubits were already created - qubit_list = ", ".join(self.decomposed_ancilla_arrays[var_name]) - comment_text = f"# {var_name} already decomposed into individual qubits: {qubit_list}" - self.current_block.statements.append( - Comment(comment_text), - ) - else: - # Default: pre-allocate all qubits - init_expr = FunctionCall( - func_name="array", - args=[ - FunctionCall( - func_name="quantum.qubit() for _ in range", - args=[Literal(size)], - ), - ], - ) - assignment = Assignment( - target=VariableRef(var_name), - value=init_expr, - ) - self.current_block.statements.append(assignment) - - # Track in context - var_info = VariableInfo( - name=var_name, - original_name=var.sym, - var_type="quantum", - size=size, - is_array=True, - ) - self.context.add_variable(var_info) - self.scope_manager.current_context.add_variable(var_info) - - elif var_type == "CReg": - # Create classical array - size = var.size - init_expr = FunctionCall( - func_name="array", - args=[ - FunctionCall( - func_name="False for _ in range", - args=[Literal(size)], - ), - ], - ) - assignment = Assignment( - target=VariableRef(var_name), - value=init_expr, - ) - self.current_block.statements.append(assignment) - - # Track in context - var_info = VariableInfo( - name=var_name, - original_name=var.sym, - var_type="classical", - size=size, - is_array=True, - ) - self.context.add_variable(var_info) - self.scope_manager.current_context.add_variable(var_info) - - def _block_consumes_quantum(self, block) -> bool: - """Check if a block consumes ALL quantum resources. - - Only return True if the block consumes ALL its quantum inputs. - Most SLR functions modify arrays in-place without consuming them. - - However, functions that access quantum fields within structs need @owned - annotation to satisfy Guppy's linearity requirements. - """ - # For now, be very conservative - assume functions don't consume - # their parameters unless they're explicitly measurement blocks - # that measure ALL qubits - - # Check the block name - only certain blocks truly consume all resources - block_name = type(block).__name__ - if block_name in ["MeasureAll", "DiscardAll"]: - return True - - # IMPORTANT: Functions that will access quantum fields within structs - # need @owned annotation for Guppy's linearity system - # Otherwise assume the function modifies in-place without consuming - return self._block_accesses_struct_quantum_fields(block) - - def _analyze_consumed_parameters(self, block) -> set[str]: - """Analyze which quantum parameters are consumed by a block. - - A parameter is consumed if: - 1. It appears in a Measure operation that measures the full register - 2. All its elements are measured individually - 3. It's passed to a nested Block that consumes it - """ - consumed_params = set() - element_measurements = {} # Track which array elements are measured - - if not hasattr(block, "ops"): - return consumed_params - - # Recursively analyze all operations including nested blocks - def analyze_ops(ops_list): - for op in ops_list: - op_type = type(op).__name__ - - # Measurement consumes qubits - if op_type == "Measure": - if hasattr(op, "qargs"): - for qarg in op.qargs: - # Check if it's a full register measurement (not indexed) - if hasattr(qarg, "sym"): - # This is a full register being measured - consumed_params.add(qarg.sym) - # Check for indexed measurements (e.g., q[0], q[1]) - elif hasattr(qarg, "reg") and hasattr(qarg.reg, "sym"): - array_name = qarg.reg.sym - if array_name not in element_measurements: - element_measurements[array_name] = set() - if hasattr(qarg, "index"): - element_measurements[array_name].add(qarg.index) - - # Check if this is a nested Block call - elif hasattr(op, "__class__") and hasattr(op.__class__, "__bases__"): - from pecos.slr import Block as SlrBlock - - # Check if op is a Block subclass - # Need to check the class itself, not just the base name - try: - if issubclass(op.__class__, SlrBlock) and hasattr(op, "ops"): - # Recursively analyze nested block - analyze_ops(op.ops) - except (TypeError, AttributeError): - # Not a class or missing expected attributes - pass - - # Analyze all operations - analyze_ops(block.ops) - - # Check if arrays are consumed - # In Guppy, any measurement of array elements requires @owned annotation - # because it consumes those elements - for array_name, measured_indices in element_measurements.items(): - # If any element is measured, the array is consumed and needs @owned - if len(measured_indices) > 0: - consumed_params.add(array_name) - - return consumed_params - - def _analyze_subscript_access(self, block) -> set[str]: - """Analyze which quantum arrays have subscript access in a block. - - In Guppy, any subscript access (c_d[0]) marks the array as used, - requiring @owned annotation to avoid MoveOutOfSubscriptError. - - Returns: - set of array names that have subscript access - """ - subscripted_arrays = set() - - if not hasattr(block, "ops"): - return subscripted_arrays - - # Recursively analyze all operations - def analyze_ops(ops_list): - for op in ops_list: - # Check for any quantum operation with indexed arguments - if hasattr(op, "qargs"): - for qarg in op.qargs: - # Check for indexed access (e.g., q[0]) - if hasattr(qarg, "reg") and hasattr(qarg.reg, "sym"): - array_name = qarg.reg.sym - subscripted_arrays.add(array_name) - # Also check for register-wide operations that will be converted to loops - # (e.g., qubit.H(q) becomes for i in range(7): quantum.h(q[i])) - elif hasattr(qarg, "sym") and hasattr(qarg, "elems") and len(qarg.elems) > 1: - # This is a register-wide operation - will use subscripts - array_name = qarg.sym - subscripted_arrays.add(array_name) - # else: qarg doesn't match expected patterns - - # Check for classical array subscripts too - if hasattr(op, "cargs"): - for carg in op.cargs: - if hasattr(carg, "reg") and hasattr(carg.reg, "sym"): - # This is classical, skip for now - pass - - # Check nested blocks - if hasattr(op, "__class__") and hasattr(op.__class__, "__bases__"): - from pecos.slr import Block as SlrBlock - - try: - if issubclass(op.__class__, SlrBlock) and hasattr(op, "ops"): - analyze_ops(op.ops) - except (TypeError, AttributeError): - # Not a class or missing expected attributes - pass - - analyze_ops(block.ops) - return subscripted_arrays - - def _analyze_block_element_usage(self, block) -> dict: - """Analyze which specific array elements are consumed vs returned by a block. - - Returns: - dict: { - 'consumed_elements': {'array_name': {consumed_indices}}, - 'array_sizes': {'array_name': size}, - 'returned_elements': {'array_name': {returned_indices}} - } - """ - consumed_elements = {} - array_sizes = {} - - if not hasattr(block, "ops"): - return { - "consumed_elements": consumed_elements, - "array_sizes": array_sizes, - "returned_elements": {}, - } - - # Analyze block to find measurements - def analyze_ops(ops_list): - for op in ops_list: - op_type = type(op).__name__ - - # Measurement consumes qubits - if op_type == "Measure": - if hasattr(op, "qargs"): - for qarg in op.qargs: - # Check for indexed measurements (e.g., q[0]) - if hasattr(qarg, "reg") and hasattr(qarg.reg, "sym"): - array_name = qarg.reg.sym - if array_name not in consumed_elements: - consumed_elements[array_name] = set() - if hasattr(qarg, "index"): - consumed_elements[array_name].add(qarg.index) - - # Check if this is a nested Block call - elif hasattr(op, "__class__") and hasattr(op.__class__, "__bases__"): - from pecos.slr import Block as SlrBlock - - try: - if issubclass(op.__class__, SlrBlock) and hasattr(op, "ops"): - # Recursively analyze nested block - analyze_ops(op.ops) - except (TypeError, AttributeError): - # Not a class or missing expected attributes - pass - - # Get array sizes from block parameters - if hasattr(block, "q") and hasattr(block.q, "size"): - array_sizes["q"] = block.q.size - - analyze_ops(block.ops) - - # Pre-track explicit resets to know which consumed qubits are reset and should be returned - consumed_for_tracking = {} - self._track_consumed_qubits(block, consumed_for_tracking) - - # Calculate returned elements - # = (all elements - consumed) + explicitly_reset - # Explicitly reset qubits are consumed by measurement but then recreated by Prep - returned_elements = {} - for array_name, size in array_sizes.items(): - consumed = consumed_elements.get(array_name, set()) - all_indices = set(range(size)) - unconsumed = all_indices - consumed - - # Add explicitly reset qubits (they're consumed but then reset, so should be returned) - explicitly_reset = set() - if hasattr(self, "explicitly_reset_qubits") and array_name in self.explicitly_reset_qubits: - explicitly_reset = self.explicitly_reset_qubits[array_name] - - returned_elements[array_name] = unconsumed | explicitly_reset - - return { - "consumed_elements": consumed_elements, - "array_sizes": array_sizes, - "returned_elements": returned_elements, - } - - def _block_accesses_struct_quantum_fields(self, block) -> bool: - """Check if a block accesses quantum fields within structs. - - This is important because Guppy's linearity system requires @owned - annotation for functions that access quantum fields within structs. - """ - if not hasattr(block, "ops"): - return False - - # If we have struct info, assume that functions accessing quantum operations - # will need to access quantum fields within structs - if self.struct_info: - # Check if this block has quantum operations - for op in block.ops: - # Check for quantum operations (gates, measurements, etc.) - op_name = type(op).__name__ - if op_name in [ - "H", - "X", - "Y", - "Z", - "CX", - "CY", - "CZ", - "Reset", - "Measure", - "S", - "T", - "Sdg", - "Tdg", - ]: - return True - - # Also check for nested quantum operations - if hasattr(op, "ops") and self._block_accesses_struct_quantum_fields( - op, - ): - return True - - return False - - def _needs_unpacking_workaround(self, block) -> bool: - """Detect if a block needs the unpacking workaround for Guppy constraints.""" - if not hasattr(block, "ops"): - return False - - # Check for patterns that cause MoveOutOfSubscriptError - for op in block.ops: - op_type = type(op).__name__ - - # Reset operations on arrays are the main culprit - if op_type == "Prep" and hasattr(op, "qargs"): - for qarg in op.qargs: - # If it's an array operation, it might cause issues - if hasattr(qarg, "sym") and hasattr(qarg, "size") and qarg.size > 1: - return True - - # Multiple operations on the same array elements might cause issues - # This is a more complex heuristic we could add later - - # Recursively check nested blocks - if hasattr(op, "ops") and self._needs_unpacking_workaround(op): - return True - - return False - - def _function_needs_unpacking(self, func_name: str) -> bool: - """Check if a function uses the unpacking pattern by analyzing function behavior. - - This method analyzes the actual function operations rather than using hardcoded names, - making it general for all QEC codes. - """ - _ = func_name # Currently not used, reserved for future use - # Since this function is not currently used, return False for now - # In the future, this could analyze the function's block to determine - # if it performs operations that would benefit from unpacking - return False - - def _function_consumes_parameters(self, func_name: str, block) -> bool: - """Check if a function consumes its quantum parameters (has @owned).""" - _ = func_name # Currently not used, reserved for future use - # Check if we already know about this function - if hasattr(block, "ops"): - return self._block_consumes_quantum(block) - - # Default: assume functions don't consume unless we know otherwise - return False - - def _is_variable_used_as_function_arg(self, var_name: str, block) -> bool: - """Check if a variable is used as an argument to block operations (functions).""" - if not hasattr(block, "ops"): - return False - - for op in block.ops: - # Check if this is a Block-type operation - if hasattr(op, "ops") and hasattr(op, "vars"): - # This is a block - check variables used by operations inside it - # Since constructor arguments aren't preserved, we need to analyze the inner operations - for inner_op in op.ops: - # Check quantum arguments - if hasattr(inner_op, "qargs"): - for qarg in inner_op.qargs: - if hasattr(qarg, "reg") and hasattr(qarg.reg, "sym"): - if qarg.reg.sym == var_name: - return True - elif hasattr(qarg, "sym") and qarg.sym == var_name: - return True - - # Check measurement targets - if hasattr(inner_op, "cout") and inner_op.cout: - for cout in inner_op.cout: - if hasattr(cout, "reg") and hasattr(cout.reg, "sym"): - if cout.reg.sym == var_name: - return True - elif hasattr(cout, "sym") and cout.sym == var_name: - return True - - return False - - def _create_array_unpack_statement( - self, - array_name: str, - element_names: list[str], - ) -> Statement: - """Create an array unpacking statement: q_0, q_1, q_2 = q""" - - class ArrayUnpackStatement(Statement): - def __init__(self, targets, source): - self.targets = targets - self.source = source - - def analyze(self, context): - _ = context # Not used - - def render(self, context): - _ = context # Not used - # For single element unpacking, we need a trailing comma - target_str = self.targets[0] + "," if len(self.targets) == 1 else ", ".join(self.targets) - return [f"{target_str} = {self.source}"] - - return ArrayUnpackStatement(element_names, array_name) - - def _create_array_construction(self, element_names: list[str]) -> Expression: - """Create an array construction expression: array([q_0, q_1, q_2])""" - - class ArrayConstructionExpression(Expression): - def __init__(self, elements): - self.elements = elements - - def analyze(self, context): - _ = context # Not used - - def render(self, context): - _ = context # Not used - element_str = ", ".join(self.elements) - return [f"array({element_str})"] - - return ArrayConstructionExpression(element_names) - - def _create_array_reconstruction(self, element_names: list[str]) -> Expression: - """Create an array reconstruction expression for returns: array([q_0, q_1])""" - - # Apply variable remapping to get the latest names - # Use function_var_remapping if available (includes Prep changes) - remapping = ( - self.function_var_remapping - if hasattr(self, "function_var_remapping") - else self.variable_remapping if hasattr(self, "variable_remapping") else {} - ) - remapped_element_names = [remapping.get(elem, elem) for elem in element_names] - - class ArrayReconstructionExpression(Expression): - def __init__(self, elements): - self.elements = elements - - def analyze(self, context): - _ = context # Not used - - def render(self, context): - _ = context # Not used - element_str = ", ".join(self.elements) - return [f"array({element_str})"] - - return ArrayReconstructionExpression(remapped_element_names) - - def _create_struct_construction( - self, - struct_name: str, - field_names: list[str], - field_values: list[Expression], - ) -> Expression: - """Create a struct construction expression.""" - - class StructConstructionExpression(Expression): - def __init__(self, struct_name, field_names, field_values): - self.struct_name = struct_name - self.field_names = field_names - self.field_values = field_values - - def analyze(self, context): - for value in self.field_values: - value.analyze(context) - - def render(self, context): - # Render as struct_name(value1, value2, ...) - positional args only - # Guppy doesn't support keyword arguments in struct construction - field_values_str = [] - for value in self.field_values: - value_str = value.render(context)[0] - field_values_str.append(value_str) - return [f"{self.struct_name}({', '.join(field_values_str)})"] - - return StructConstructionExpression(struct_name, field_names, field_values) - - def _add_array_unpacking(self, array_name: str, size: int) -> None: - """Add array unpacking statement.""" - # Check if this array is already unpacked in the current context - if hasattr(self, "unpacked_vars") and array_name in self.unpacked_vars: - # Array is already unpacked, don't unpack again - return - - # Get the actual variable name (might be renamed) - actual_name = array_name - if array_name in self.plan.renamed_variables: - actual_name = self.plan.renamed_variables[array_name] - - # Generate unpacked names - unpacked_names = [self._get_unique_var_name(array_name, i) for i in range(size)] - - # Track unpacked vars in the builder - self.unpacked_vars[array_name] = unpacked_names - - # Comment already added by caller, don't add another one - - # Add unpacking statement - unpack = ArrayUnpack( - targets=unpacked_names, - source=actual_name, - ) - self.current_block.statements.append(unpack) - - # Update variable info - var = self.context.lookup_variable(actual_name) - if var: - var.is_unpacked = True - var.unpacked_names = unpacked_names - - def _is_prep_rus_block(self, op) -> bool: - """Check if this is a PrepRUS block that needs special handling.""" - return hasattr(op, "block_name") and op.block_name == "PrepRUS" - - def _convert_prep_rus_special(self, op) -> Statement | None: - """Special conversion for PrepRUS to avoid linearity issues.""" - # PrepRUS has a specific pattern that causes issues: - # 1. PrepEncodingFTZero creates fresh variables - # 2. Repeat with conditional PrepEncodingFTZero - # 3. LogZeroRot uses the variables - - # We'll generate a simplified version that avoids the conditional consumption - self.current_block.statements.append( - Comment("Special handling for PrepRUS to avoid linearity issues"), - ) - - # Process the operations in PrepRUS - if hasattr(op, "ops"): - for sub_op in op.ops: - # Skip the Repeat block with conditional consumption - if type(sub_op).__name__ == "Repeat": - # Instead of the loop with conditional, just do it once unconditionally - self.current_block.statements.append( - Comment("Simplified repeat to avoid conditional consumption"), - ) - # Don't process the Repeat block - continue - - # Process other operations normally - stmt = self._convert_operation(sub_op) - if stmt: - self.current_block.statements.append(stmt) - - return None - - def _convert_operation(self, op) -> Statement | None: - """Convert an SLR operation to IR statement.""" - op_type = type(op).__name__ - - if op_type == "Measure": - return self._convert_measurement(op) - if op_type == "If": - return self._convert_if(op) - if op_type == "While": - return self._convert_while(op) - if op_type == "For": - return self._convert_for(op) - if op_type == "Repeat": - return self._convert_repeat(op) - if op_type == "Comment": - return self._convert_comment(op) - if op_type == "Permute": - return self._convert_permute(op) - if hasattr(op, "qargs"): - stmt = self._convert_quantum_gate(op) - # Handle case where quantum gate returns a Block - if stmt and type(stmt).__name__ == "Block": - # Add all statements from the block - for s in stmt.statements: - self.current_block.statements.append(s) - return None # Already added - return stmt - if hasattr(op, "ops") and hasattr(op, "vars"): - # This is a block - convert to function call - return self._convert_block_call(op) - if op_type == "SET": - # Classical bit assignment - return self._convert_set_operation(op) - if op_type == "Barrier": - # Barriers are just synchronization points, ignore in Guppy - return None - if op_type == "Return": - # Return is metadata for type checking and block analysis - # The actual return handling is done by the function generation code - return None - - # Unknown operation - return Comment(f"TODO: Handle {op_type}") - - def _convert_measurement(self, meas) -> Statement | None: - """Convert measurement operation.""" - if not hasattr(meas, "qargs") or not meas.qargs: - return None - - # Check if we're measuring a struct field qubit with @owned struct - if hasattr(meas, "qargs") and len(meas.qargs) > 0: - qarg = meas.qargs[0] - if hasattr(qarg, "reg") and hasattr(qarg.reg, "sym"): - array_name = qarg.reg.sym - # Check if this is a struct field - for info in self.struct_info.values(): - if ( - array_name in info["var_names"].values() - and hasattr(self, "function_info") - and hasattr(self, "current_function_name") - ): - func_info = self.function_info.get( - self.current_function_name, - {}, - ) - if func_info.get("has_owned_struct_params", False): - # This is a known limitation - add a warning comment - self.current_block.statements.append( - Comment( - "WARNING: Measuring qubits from @owned struct arrays " - "is not supported by guppylang", - ), - ) - self.current_block.statements.append( - Comment( - "This will cause a MoveOutOfSubscriptError during compilation", - ), - ) - - # Check if we're in a function that takes and returns a struct - # If so, we need to be careful about struct field access - if hasattr(self, "current_function_params"): - for param_name, param_type in self.current_function_params: - if "_struct" in str(param_type) and "@owned" not in str(param_type): - break - - # Check if this is a full array measurement - if ( - len(meas.qargs) == 1 - and hasattr(meas.qargs[0], "sym") - and hasattr(meas.qargs[0], "size") - and meas.qargs[0].size >= 1 - ): - # Full array measurement - qreg = meas.qargs[0] - - # Track full array consumption globally - if not hasattr(self, "consumed_resources"): - self.consumed_resources = {} - if qreg.sym not in self.consumed_resources: - self.consumed_resources[qreg.sym] = set() - self.consumed_resources[qreg.sym].update(range(qreg.size)) - - # Track in scope manager too - self.scope_manager.track_resource_usage( - qreg.sym, - set(range(qreg.size)), - consumed=True, - ) - - # Check if this array was dynamically allocated - if hasattr(self, "dynamic_allocations") and qreg.sym in self.dynamic_allocations: - # For dynamically allocated arrays, we need to handle this differently - # Generate individual measurements - stmts = [] - - # Check for target - if hasattr(meas, "cout") and meas.cout and len(meas.cout) == 1: - cout = meas.cout[0] - if hasattr(cout, "sym"): - creg_name = cout.sym - # Measure each individual qubit - for i in range(qreg.size): - ancilla_var = self._get_unique_var_name(qreg.sym, i) - # Allocate if not already allocated - if not hasattr(self, "allocated_ancillas"): - self.allocated_ancillas = set() - if ancilla_var not in self.allocated_ancillas: - alloc_stmt = Assignment( - target=VariableRef(ancilla_var), - value=FunctionCall( - func_name="quantum.qubit", - args=[], - ), - ) - stmts.append(alloc_stmt) - self.allocated_ancillas.add(ancilla_var) - - # Measure individual qubit - meas_call = FunctionCall( - func_name="quantum.measure", - args=[VariableRef(ancilla_var)], - ) - creg_access = ArrayAccess(array_name=creg_name, index=i) - assign = Assignment(target=creg_access, value=meas_call) - stmts.append(assign) - - # Return block with all statements - if len(stmts) == 1: - return stmts[0] - return Block(statements=stmts) - else: - # No target - measure individual qubits without storing - for i in range(qreg.size): - # Use consistent mapping from (array_name, index) to variable name - if not hasattr(self, "allocated_qubit_vars"): - self.allocated_qubit_vars = {} - - array_index_key = (qreg.sym, i) - - # Check if we already have a variable for this array element - if array_index_key in self.allocated_qubit_vars: - ancilla_var = self.allocated_qubit_vars[array_index_key] - else: - # Create a new variable name for this specific array element - ancilla_var = self._get_unique_var_name(qreg.sym, i) - self.allocated_qubit_vars[array_index_key] = ancilla_var - - alloc_stmt = Assignment( - target=VariableRef(ancilla_var), - value=FunctionCall(func_name="quantum.qubit", args=[]), - ) - stmts.append(alloc_stmt) - - # Measure and discard result - meas_call = FunctionCall( - func_name="quantum.measure", - args=[VariableRef(ancilla_var)], - ) - - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return f"_ = {self.expr.render(context)}" - - stmts.append(ExpressionStatement(meas_call)) - - if len(stmts) == 1: - return stmts[0] - return Block(statements=stmts) - else: - # Regular pre-allocated array - use measure_array - qreg_ref = self._convert_qubit_ref(qreg) - - # If the array was previously unpacked (e.g., to access an - # individual element after a function call returned it), - # Guppy considers the original variable name consumed by - # the unpack. Repack from the element vars so measure_array - # can take the whole array as input. We emit the repack - # statement *prepended* to whatever statement(s) the rest - # of this branch produces (see `_prepend_to_result`). - # - # var_state and the legacy `unpacked_vars` dict are both - # updated so other code paths agree the array is whole again. - repack_stmt = None - if hasattr(qreg, "sym") and self.var_state.is_unpacked(qreg.sym): - binding = self.var_state.get(qreg.sym) - repack_stmt = Assignment( - target=VariableRef(qreg.sym), - value=self._create_array_reconstruction(list(binding.element_names)), - ) - self.var_state.bind_whole(qreg.sym, qreg.sym) - if hasattr(self, "unpacked_vars") and qreg.sym in self.unpacked_vars: - del self.unpacked_vars[qreg.sym] - if hasattr(self, "context"): - var = self.context.lookup_variable(qreg.sym) - if var: - var.is_unpacked = False - var.unpacked_names = [] - # qreg_ref was computed *before* the repack -- recompute - # so it points at the now-whole array, not stale unpacked - # element variables. - qreg_ref = self._convert_qubit_ref(qreg) - - # Mark fresh variable as used if this is measuring a fresh variable - if hasattr(self, "fresh_variables_to_track") and hasattr( - self, - "refreshed_arrays", - ): - # Check if qreg is using a fresh variable - for orig_name, fresh_name in self.refreshed_arrays.items(): - if fresh_name in self.fresh_variables_to_track and orig_name == qreg.sym: - # Mark this fresh variable as used - self.fresh_variables_to_track[fresh_name]["used"] = True - break - - # Check for target - if hasattr(meas, "cout") and meas.cout and len(meas.cout) == 1: - cout = meas.cout[0] - if hasattr(cout, "sym"): - # Check for renamed variable - creg_name = cout.sym - if creg_name in self.plan.renamed_variables: - creg_name = self.plan.renamed_variables[creg_name] - - # Check if this variable is remapped (e.g., function parameter) - is_function_param = False - if hasattr(self, "var_remapping") and creg_name in self.var_remapping: - creg_name = self.var_remapping[creg_name] - # Check if this is a function parameter (not in main) - is_function_param = ( - hasattr(self, "current_function_name") and self.current_function_name != "main" - ) - - # For function parameters (classical arrays), we need to update in-place - # to avoid BorrowShadowedError - if is_function_param: - # Generate element-wise measurements - stmts = [] - - # IMPORTANT: Do NOT automatically replace qubits after measurement - # The old logic tried to maintain array size, but this breaks partial consumption. - # Only replace if allocation optimizer detected reuse. - should_replace = False # Disabled automatic replacement - - for i in range(qreg.size): - # Check if the quantum array was unpacked - if hasattr(self, "unpacked_vars") and qreg.sym in self.unpacked_vars: - # Use unpacked variable - element_names = self.unpacked_vars[qreg.sym] - qubit_ref = VariableRef(element_names[i]) - qubit_var_name = element_names[i] - else: - # Use array access - qubit_ref = ArrayAccess( - array_name=( - self._convert_qubit_ref(qreg).name - if hasattr( - self._convert_qubit_ref(qreg), - "name", - ) - else qreg.sym - ), - index=i, - ) - qubit_var_name = None - - meas_call = FunctionCall( - func_name="quantum.measure", - args=[qubit_ref], - ) - # Assign to array element - creg_access = ArrayAccess(array_name=creg_name, index=i) - assign = Assignment(target=creg_access, value=meas_call) - stmts.append(assign) - - # Replace measured qubit with fresh one if needed - if should_replace and qubit_var_name: - replacement_stmt = Assignment( - target=VariableRef(qubit_var_name), - value=FunctionCall( - func_name="quantum.qubit", - args=[], - ), - ) - stmts.append(replacement_stmt) - - # Track that this qubit was replaced - if not hasattr(self, "replaced_qubits"): - self.replaced_qubits = {} - if qreg.sym not in self.replaced_qubits: - self.replaced_qubits[qreg.sym] = set() - self.replaced_qubits[qreg.sym].add(i) - - # Return block with all statements - if len(stmts) == 1: - return stmts[0] - return Block(statements=stmts) - # Not a function parameter - can reassign whole array - creg_ref = VariableRef(creg_name) - # Generate measure_array - call = FunctionCall( - func_name="quantum.measure_array", - args=[qreg_ref], - ) - result = Assignment(target=creg_ref, value=call) - if repack_stmt is not None: - return Block(statements=[repack_stmt, result]) - return result - - # No target - just measure - call = FunctionCall( - func_name="quantum.measure_array", - args=[qreg_ref], - ) - - # Create expression statement wrapper - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - result = ExpressionStatement(call) - if repack_stmt is not None: - return Block(statements=[repack_stmt, result]) - return result - - # Handle single qubit measurement - if len(meas.qargs) == 1: - qarg = meas.qargs[0] - qubit_ref = self._convert_qubit_ref(qarg) - - # Get target if specified - target_ref = None - if hasattr(meas, "cout") and meas.cout and len(meas.cout) == 1: - cout = meas.cout[0] - # For measurements, the target should use unpacked names if available - # So we pass is_assignment_target=False to use unpacked names - target_ref = self._convert_bit_ref(cout, is_assignment_target=False) - - # Track resource consumption for linearity checking - if hasattr(qarg, "reg") and hasattr(qarg.reg, "sym") and hasattr(qarg, "index"): - array_name = qarg.reg.sym - qubit_index = qarg.index - self.scope_manager.track_resource_usage( - array_name, - {qubit_index}, - consumed=True, - ) - - # Also track globally for conditional resource balancing - if not hasattr(self, "consumed_resources"): - self.consumed_resources = {} - if array_name not in self.consumed_resources: - self.consumed_resources[array_name] = set() - self.consumed_resources[array_name].add(qubit_index) - - # Generate measurement statement - meas_stmt = Measurement(qubit=qubit_ref, target=target_ref) - - # IMPORTANT: Do NOT automatically replace measured qubits! - # The old "black box pattern" logic assumed functions should maintain array size, - # but this breaks partial consumption patterns where a function consumes some qubits - # and returns others. Only explicit Prep operations should create fresh qubits. - # - # The correct behavior: - # - Measure consumes the qubit → it's gone - # - If user wants to reset, they use explicit Prep(q[i]) → creates fresh qubit - # - Function returns only the qubits that weren't consumed - # - # Check if this qubit is marked as needing replacement due to reuse - # (e.g., unified analysis detected it's used again after consumption) - needs_replacement_for_reuse = False - if self.unified_analysis and hasattr(qarg, "reg") and hasattr(qarg.reg, "sym") and hasattr(qarg, "index"): - array_name = qarg.reg.sym - qubit_index = qarg.index - resource_plan = self.unified_analysis.get_plan(array_name) - if resource_plan and qubit_index in resource_plan.elements_requiring_replacement: - # CRITICAL: Check if the next operation is a Prep on this same qubit - # If so, skip measurement replacement - let Prep handle it - next_op_is_prep_on_same_qubit = False - if ( - hasattr(self, "current_block_ops") - and hasattr(self, "current_op_index") - and self.current_block_ops is not None - and self.current_op_index is not None - ): - next_index = self.current_op_index + 1 - if next_index < len(self.current_block_ops): - next_op = self.current_block_ops[next_index] - # Check if next operation is Prep on the same qubit - if type(next_op).__name__ == "Prep" and hasattr( - next_op, - "qargs", - ): - for prep_qarg in next_op.qargs: - if ( - hasattr(prep_qarg, "reg") - and hasattr(prep_qarg.reg, "sym") - and prep_qarg.reg.sym == array_name - and hasattr(prep_qarg, "index") - and prep_qarg.index == qubit_index - ): - next_op_is_prep_on_same_qubit = True - break - - if not next_op_is_prep_on_same_qubit: - # No Prep follows - we need to replace the qubit - needs_replacement_for_reuse = True - - # Only replace if allocation optimizer determined it's reused - if ( - needs_replacement_for_reuse - and hasattr(self, "unpacked_vars") - and hasattr(qarg, "reg") - and hasattr(qarg.reg, "sym") - and hasattr(qarg, "index") - ): - array_name = qarg.reg.sym - qubit_index = qarg.index - - # Check if this array is unpacked in current function - if array_name in self.unpacked_vars: - element_names = self.unpacked_vars[array_name] - if qubit_index < len(element_names): - # Replace the measured qubit with a fresh one - replacement_stmt = Assignment( - target=VariableRef(element_names[qubit_index]), - value=FunctionCall(func_name="quantum.qubit", args=[]), - ) - - # Track that this qubit was replaced (not consumed) - if not hasattr(self, "replaced_qubits"): - self.replaced_qubits = {} - if array_name not in self.replaced_qubits: - self.replaced_qubits[array_name] = set() - self.replaced_qubits[array_name].add(qubit_index) - - # Return a block with measurement followed by replacement - statements = [meas_stmt, replacement_stmt] - return Block(statements=statements) - - return meas_stmt - - # Handle multi-qubit measurements by generating multiple single-qubit measurements - if len(meas.qargs) > 1: - # Verify we have corresponding classical outputs - if not hasattr(meas, "cout") or not meas.cout: - # No classical outputs specified - generate measurements without targets - measurements = [] - for qarg in meas.qargs: - qubit_ref = self._convert_qubit_ref(qarg) - - # Track resource consumption for each qubit - if hasattr(qarg, "reg") and hasattr(qarg.reg, "sym") and hasattr(qarg, "index"): - array_name = qarg.reg.sym - qubit_index = qarg.index - self.scope_manager.track_resource_usage( - array_name, - {qubit_index}, - consumed=True, - ) - - # Also track globally for conditional resource balancing - if not hasattr(self, "consumed_resources"): - self.consumed_resources = {} - if array_name not in self.consumed_resources: - self.consumed_resources[array_name] = set() - self.consumed_resources[array_name].add(qubit_index) - - meas_stmt = Measurement(qubit=qubit_ref, target=None) - measurements.append(meas_stmt) - - return Block(statements=measurements) - - # Multi-qubit measurement with classical outputs - if len(meas.cout) != len(meas.qargs): - # Mismatch between number of qubits and classical outputs - return Comment( - f"ERROR: Multi-qubit measurement has {len(meas.qargs)} qubits " - f"but {len(meas.cout)} classical outputs", - ) - - # Generate one measurement statement for each qubit-bit pair - measurements = [] - for qarg, cout in zip(meas.qargs, meas.cout): - qubit_ref = self._convert_qubit_ref(qarg) - target_ref = self._convert_bit_ref(cout, is_assignment_target=False) - - # Track resource consumption for each qubit - if hasattr(qarg, "reg") and hasattr(qarg.reg, "sym") and hasattr(qarg, "index"): - array_name = qarg.reg.sym - qubit_index = qarg.index - self.scope_manager.track_resource_usage( - array_name, - {qubit_index}, - consumed=True, - ) - - # Also track globally for conditional resource balancing - if not hasattr(self, "consumed_resources"): - self.consumed_resources = {} - if array_name not in self.consumed_resources: - self.consumed_resources[array_name] = set() - self.consumed_resources[array_name].add(qubit_index) - - # Generate measurement statement - meas_stmt = Measurement(qubit=qubit_ref, target=target_ref) - measurements.append(meas_stmt) - - # Return a block containing all the measurements - return Block(statements=measurements) - - # Shouldn't reach here, but just in case - return Comment(f"Unhandled measurement with {len(meas.qargs)} qubits") - - def _convert_qubit_ref(self, qarg) -> IRNode: - """Convert a qubit reference to IR.""" - if hasattr(qarg, "reg") and hasattr(qarg.reg, "sym"): - array_name = qarg.reg.sym - original_array = array_name - - # Check if this array has been remapped to a reconstructed name - if hasattr(self, "array_remapping") and array_name in self.array_remapping: - # Use the reconstructed array name instead - remapped_name = self.array_remapping[array_name] - - # Check if the original array was unpacked after remapping - # If it was, use the unpacked variables instead of array indexing - if hasattr(self, "unpacked_vars") and array_name in self.unpacked_vars and hasattr(qarg, "index"): - element_names = self.unpacked_vars[array_name] - - # CRITICAL: Check if we have index mapping for partial consumption - # If so, map original index to unpacked variable index - if hasattr(self, "index_mapping") and array_name in self.index_mapping: - mapped_index = self.index_mapping[array_name].get(qarg.index) - if mapped_index is not None and mapped_index < len( - element_names, - ): - var_name = element_names[mapped_index] - # Apply variable remapping if exists - var_name = self.variable_remapping.get(var_name, var_name) - return VariableRef(var_name) - elif qarg.index < len(element_names) and element_names[qarg.index] is not None: - # No index mapping - use direct indexing (full array return) - var_name = element_names[qarg.index] - # Apply variable remapping if exists - var_name = self.variable_remapping.get(var_name, var_name) - return VariableRef(var_name) - - # Not unpacked, use array indexing with remapped name - if hasattr(qarg, "index"): - return ArrayAccess( - array=VariableRef(remapped_name), - index=qarg.index, - force_array_syntax=True, # Force array syntax for remapped arrays - ) - - # Check if this array has been refreshed by function call - # If it was refreshed AND then unpacked, use the unpacked variables - if hasattr(self, "refreshed_arrays") and array_name in self.refreshed_arrays and hasattr(qarg, "index"): - # Array was refreshed by function call - fresh_array_name = self.refreshed_arrays[array_name] - - # Check if the original array name was unpacked after refresh - # (the unpacked_vars gets updated to point to the new unpacked elements) - if hasattr(self, "unpacked_vars") and array_name in self.unpacked_vars: - # It was unpacked after being refreshed - use unpacked variables - element_names = self.unpacked_vars[array_name] - - # CRITICAL: Check if we have index mapping for partial consumption - # If so, map original index to unpacked variable index - if hasattr(self, "index_mapping") and array_name in self.index_mapping: - # Map original index to position in returned array - mapped_index = self.index_mapping[array_name].get(qarg.index) - if mapped_index is not None and mapped_index < len( - element_names, - ): - var_name = element_names[mapped_index] - # Apply variable remapping if exists - var_name = self.variable_remapping.get(var_name, var_name) - return VariableRef(var_name) - elif qarg.index < len(element_names) and element_names[qarg.index] is not None: - # No index mapping - use direct indexing (full array return) - var_name = element_names[qarg.index] - # Apply variable remapping if exists - var_name = self.variable_remapping.get(var_name, var_name) - return VariableRef(var_name) - - # Also check if the fresh array itself was unpacked - if hasattr(self, "unpacked_vars") and fresh_array_name in self.unpacked_vars: - element_names = self.unpacked_vars[fresh_array_name] - if qarg.index < len(element_names) and element_names[qarg.index] is not None: - var_name = element_names[qarg.index] - # Apply variable remapping if exists - var_name = self.variable_remapping.get(var_name, var_name) - return VariableRef(var_name) - - # Not unpacked - use array indexing on fresh name - return ArrayAccess( - array=VariableRef(fresh_array_name), - index=qarg.index, - force_array_syntax=True, # Force array syntax for refreshed arrays - ) - - # Check if this array has been unpacked (for ancilla arrays with @owned) - if hasattr(self, "unpacked_vars") and array_name in self.unpacked_vars and hasattr(qarg, "index"): - # This array was unpacked - use the unpacked variable directly - element_names = self.unpacked_vars[array_name] - if qarg.index < len(element_names) and element_names[qarg.index] is not None: - var_name = element_names[qarg.index] - # Apply variable remapping if exists - var_name = self.variable_remapping.get(var_name, var_name) - return VariableRef(var_name) - if qarg.index < len(element_names) and element_names[qarg.index] is None: - # This element was consumed - this is an error case but let's fallback - pass - - # Check if this variable is mapped to a struct field (for @owned structs) - if hasattr(self, "struct_field_mapping") and original_array in self.struct_field_mapping: - struct_field_path = self.struct_field_mapping[original_array] - if "." in struct_field_path: - struct_name, field_name = struct_field_path.split(".", 1) - if hasattr(qarg, "index"): - # Return struct.field[index] - field_access = FieldAccess( - obj=VariableRef(struct_name), - field=field_name, - ) - return ArrayAccess(array=field_access, index=qarg.index) - # Return struct.field - return FieldAccess(obj=VariableRef(struct_name), field=field_name) - - # Check if this is a dynamically allocated array (ancilla) - if ( - hasattr(self, "dynamic_allocations") - and original_array in self.dynamic_allocations - and hasattr(qarg, "index") - ): - # Use a consistent mapping from (array_name, index) to variable name - if not hasattr(self, "allocated_qubit_vars"): - self.allocated_qubit_vars = {} - - array_index_key = (original_array, qarg.index) - - # Check if we already have a variable for this array element - if array_index_key in self.allocated_qubit_vars: - var_name = self.allocated_qubit_vars[array_index_key] - # Apply variable remapping if exists (for Prep operations) - var_name = self.variable_remapping.get(var_name, var_name) - return VariableRef(var_name) - - # Create a new variable name for this specific array element - ancilla_var = self._get_unique_var_name(original_array, qarg.index) - - # Record the mapping and allocate the qubit - self.allocated_qubit_vars[array_index_key] = ancilla_var - - # Also track in allocated_ancillas for cleanup - if not hasattr(self, "allocated_ancillas"): - self.allocated_ancillas = set() - self.allocated_ancillas.add(ancilla_var) - - alloc_stmt = Assignment( - target=VariableRef(ancilla_var), - value=FunctionCall(func_name="quantum.qubit", args=[]), - ) - self.current_block.statements.append(alloc_stmt) - - # Apply variable remapping if exists (for Prep operations) - var_name = self.variable_remapping.get(ancilla_var, ancilla_var) - return VariableRef(var_name) - - # Check if this variable is part of a struct and has been unpacked - if hasattr(self, "var_remapping") and original_array in self.var_remapping: - # Use the unpacked field variable - unpacked_var_name = self.var_remapping[original_array] - if hasattr(qarg, "index"): - # Array element access with unpacked variable: c_d[0] - return ArrayAccess( - array=VariableRef(unpacked_var_name), - index=qarg.index, - ) - # Full array access with unpacked variable: c_d - return VariableRef(unpacked_var_name) - - # Check if this array is part of a struct (fallback) - for prefix, info in self.struct_info.items(): - if array_name in info["var_names"].values(): - # This is a struct field - suffix = next(k for k, v in info["var_names"].items() if v == array_name) - - # Check if we're in a function that takes this struct as parameter - struct_param_name = prefix # Default to the struct name - if hasattr(self, "param_mapping") and prefix in self.param_mapping: - struct_param_name = self.param_mapping[prefix] - - # Check if the struct has a fresh version (after function calls) - if hasattr(self, "refreshed_arrays") and prefix in self.refreshed_arrays: - struct_param_name = self.refreshed_arrays[prefix] - - if hasattr(qarg, "index"): - # Struct field element access: c.d[0] - field_access = FieldAccess( - obj=VariableRef(struct_param_name), - field=suffix, - ) - return ArrayAccess(array=field_access, index=qarg.index) - # Full struct field access: c.d - return FieldAccess(obj=VariableRef(struct_param_name), field=suffix) - - # Check if we're inside a function and need to use remapped names - if hasattr(self, "var_remapping") and original_array in self.var_remapping: - array_name = self.var_remapping[original_array] - - # Check for renaming - if array_name in self.plan.renamed_variables: - array_name = self.plan.renamed_variables[array_name] - - if hasattr(qarg, "index"): - # Array Unpacking Pattern: use unpacked variable names instead of array indexing - # Check both the original name and any remapped name - check_names = [original_array] - if hasattr(self, "var_remapping") and original_array in self.var_remapping: - check_names.append(self.var_remapping[original_array]) - if array_name != original_array: - check_names.append(array_name) - - # Try each possible name for unpacked variables - for check_name in check_names: - if ( - hasattr(self, "unpacked_vars") - and check_name in self.unpacked_vars - # Don't use unpacked variables if the array was refreshed - and check_name not in self.refreshed_arrays - ): - element_names = self.unpacked_vars[check_name] - if qarg.index < len(element_names): - var_name = element_names[qarg.index] - # Apply variable remapping if exists - var_name = self.variable_remapping.get(var_name, var_name) - return VariableRef(var_name) - - # Check if this element should be allocated locally - resource_plan = None - if self.unified_analysis: - resource_plan = self.unified_analysis.get_plan(original_array) - if resource_plan and qarg.index in resource_plan.elements_to_allocate_locally: - # This element should be allocated locally - local_var_name = f"{original_array}_{qarg.index}_local" - - # Add local allocation if not already done - if not hasattr(self, "_local_allocations"): - self._local_allocations = set() - - if local_var_name not in self._local_allocations: - self._local_allocations.add(local_var_name) - # Add allocation statement - alloc_stmt = Assignment( - target=VariableRef(local_var_name), - value=FunctionCall(func_name="quantum.qubit", args=[]), - ) - self.current_block.statements.append(alloc_stmt) - - # Apply variable remapping if exists (for Prep operations) - local_var_name = self.variable_remapping.get( - local_var_name, - local_var_name, - ) - return VariableRef(local_var_name) - - # Array element access - # Skip this shortcut - we need to check for unpacked vars first - # The unpacking check above should handle function cases too - - # In main function, check if this array is unpacked - if original_array in self.plan.arrays_to_unpack: - # This array should be unpacked, use unpacked name - info = self.plan.arrays_to_unpack[original_array] - if qarg.index < info.size: - # Check if the array is actually unpacked yet - var_info = self.context.lookup_variable(array_name) - if var_info and var_info.is_unpacked: - # Use the actual unpacked name from our tracking - if array_name in self.unpacked_vars and qarg.index < len( - self.unpacked_vars[array_name], - ): - unpacked_name = self.unpacked_vars[array_name][qarg.index] - else: - # Fallback to generating the name (should not normally happen) - unpacked_name = self._get_unique_var_name( - original_array, - qarg.index, - ) - # Apply variable remapping if exists (for Prep operations) - unpacked_name = self.variable_remapping.get( - unpacked_name, - unpacked_name, - ) - return VariableRef(unpacked_name) - - # Not unpacked or inside function, use array access - return ArrayAccess(array_name=array_name, index=qarg.index) - - # Full array reference - check if array was refreshed by function call - if hasattr(self, "refreshed_arrays") and original_array in self.refreshed_arrays: - # Use the fresh returned array name instead of the original - fresh_array_name = self.refreshed_arrays[original_array] - return VariableRef(fresh_array_name) - - return VariableRef(array_name) - if hasattr(qarg, "sym"): - # Direct variable reference - var_name = qarg.sym - original_var = var_name - - # Check if this variable was refreshed by function call - if hasattr(self, "refreshed_arrays") and original_var in self.refreshed_arrays: - # Use the fresh returned variable name instead of the original - fresh_var_name = self.refreshed_arrays[original_var] - return VariableRef(fresh_var_name) - - # Check if we're inside a function and need to use remapped names - if hasattr(self, "var_remapping") and original_var in self.var_remapping: - var_name = self.var_remapping[original_var] - - # Check for renaming - if var_name in self.plan.renamed_variables: - var_name = self.plan.renamed_variables[var_name] - return VariableRef(var_name) - - # Fallback - return VariableRef(str(qarg)) - - def _convert_bit_ref(self, carg, *, is_assignment_target: bool = False) -> IRNode: - """Convert a classical bit reference to IR. - - Args: - carg: The classical argument to convert - is_assignment_target: If True, always use array indexing (for assignments) - """ - if hasattr(carg, "reg") and hasattr(carg.reg, "sym"): - array_name = carg.reg.sym - original_array = array_name - - # Check if this array has been refreshed by function call - # If so, prefer array indexing over stale unpacked variables - if hasattr(self, "refreshed_arrays") and array_name in self.refreshed_arrays and hasattr(carg, "index"): - # Array was refreshed by function call - use the fresh returned name - fresh_array_name = self.refreshed_arrays[array_name] - return ArrayAccess( - array=VariableRef(fresh_array_name), - index=carg.index, - force_array_syntax=True, # Force array syntax for refreshed arrays - ) - - # Check if this variable is mapped to a struct field (for @owned structs) - if hasattr(self, "struct_field_mapping") and original_array in self.struct_field_mapping: - struct_field_path = self.struct_field_mapping[original_array] - if "." in struct_field_path: - struct_name, field_name = struct_field_path.split(".", 1) - if hasattr(carg, "index"): - # Return struct.field[index] - field_access = FieldAccess( - obj=VariableRef(struct_name), - field=field_name, - ) - return ArrayAccess(array=field_access, index=carg.index) - # Return struct.field - return FieldAccess(obj=VariableRef(struct_name), field=field_name) - - # Check if this variable is part of a struct and has been unpacked - if hasattr(self, "var_remapping") and original_array in self.var_remapping: - # Use the unpacked field variable - unpacked_var_name = self.var_remapping[original_array] - if hasattr(carg, "index"): - # Array element access with unpacked variable: c_verify_prep[0] - return ArrayAccess( - array=VariableRef(unpacked_var_name), - index=carg.index, - ) - # Full array access with unpacked variable: c_verify_prep - return VariableRef(unpacked_var_name) - - # Check if this variable is part of a struct in main context (fallback) - for prefix, info in self.struct_info.items(): - if original_array in info["var_names"].values(): - # Find the field name - for suffix, var_name in info["var_names"].items(): - if var_name == original_array: - # Check if the struct has been decomposed and we should use decomposed variables - if hasattr(self, "var_remapping") and original_array in self.var_remapping: - # Struct was decomposed - use the decomposed variable directly - decomposed_var = self.var_remapping[original_array] - if hasattr(carg, "index"): - return ArrayAccess( - array=VariableRef(decomposed_var), - index=carg.index, - ) - return VariableRef(decomposed_var) - - # Check if we're in a function that receives the struct - struct_param_name = prefix - if hasattr(self, "param_mapping") and prefix in self.param_mapping: - struct_param_name = self.param_mapping[prefix] - - # Check if we have decomposed variables for fresh structs - if hasattr(self, "refreshed_arrays") and prefix in self.refreshed_arrays: - fresh_struct_name = self.refreshed_arrays[prefix] - # Check if this fresh struct was decomposed - if hasattr(self, "decomposed_vars") and fresh_struct_name in self.decomposed_vars: - # Use the decomposed variable - field_vars = self.decomposed_vars[fresh_struct_name] - if suffix in field_vars: - decomposed_var = field_vars[suffix] - if hasattr(carg, "index"): - return ArrayAccess( - array=VariableRef(decomposed_var), - index=carg.index, - ) - return VariableRef(decomposed_var) - struct_param_name = fresh_struct_name - - if hasattr(carg, "index"): - # Struct field element access: c.verify_prep[0] - field_access = FieldAccess( - obj=VariableRef(struct_param_name), - field=suffix, - ) - return ArrayAccess(array=field_access, index=carg.index) - # Full struct field access: c.verify_prep - return FieldAccess( - obj=VariableRef(struct_param_name), - field=suffix, - ) - - # Check if we're inside a function and need to use remapped names - if hasattr(self, "var_remapping") and original_array in self.var_remapping: - array_name = self.var_remapping[original_array] - - # Check for renaming - if array_name in self.plan.renamed_variables: - array_name = self.plan.renamed_variables[array_name] - - if hasattr(carg, "index"): - # Check if this array is unpacked and we're not assigning - var_info = self.context.lookup_variable(array_name) - if ( - not is_assignment_target - and var_info - and var_info.is_unpacked - and hasattr(var_info, "unpacked_names") - ): - # Use unpacked variable name for reading - index = carg.index - if index < len(var_info.unpacked_names): - return VariableRef(var_info.unpacked_names[index]) - - # Use array access for assignments or non-unpacked arrays - return ArrayAccess(array_name=array_name, index=carg.index) - # Full array reference - return VariableRef(array_name) - if hasattr(carg, "sym"): - # Direct variable reference - var_name = carg.sym - # Check for renaming - if var_name in self.plan.renamed_variables: - var_name = self.plan.renamed_variables[var_name] - return VariableRef(var_name) - - # Fallback - return VariableRef(str(carg)) - - def _convert_quantum_gate(self, gate) -> Statement | None: - """Convert quantum gate operation.""" - gate_name = type(gate).__name__ - - # Regular gate mapping for in-place operations - gate_map = { - "H": "quantum.h", - "X": "quantum.x", - "Y": "quantum.y", - "Z": "quantum.z", - "S": "quantum.s", - "SZ": "quantum.s", - "SZdg": "quantum.sdg", - "T": "quantum.t", - "Tdg": "quantum.tdg", - "CX": "quantum.cx", - "CY": "quantum.cy", - "CZ": "quantum.cz", - "Prep": "quantum.qubit", # Prep allocates a fresh qubit - } - - if gate_name not in gate_map: - return Comment(f"Unknown gate: {gate_name}") - - func_name = gate_map[gate_name] - - # Convert qubit arguments - args = [] - if hasattr(gate, "qargs") and gate.qargs: - # Check if this is a single-qubit gate with multiple arguments - if gate_name in ["H", "X", "Y", "Z", "S", "SZ", "SZdg", "T", "Tdg", "Prep"] and len(gate.qargs) > 1: - # Single-qubit gate applied to multiple qubits - # Check if all qargs are consecutive array elements from the same array - if ( - all(hasattr(qarg, "reg") and hasattr(qarg, "index") for qarg in gate.qargs) - and len({qarg.reg.sym for qarg in gate.qargs}) == 1 - ): - # All from same array - check if consecutive - indices = [qarg.index for qarg in gate.qargs] - array_name = gate.qargs[0].reg.sym - - if indices == list(range(min(indices), max(indices) + 1)): - # Consecutive indices - generate a loop - loop_var = "i" - start = min(indices) - stop = max(indices) + 1 - - # Create loop body - body_block = Block() - - # Check if the array name needs remapping (for unpacked struct fields) - actual_array_name = array_name - if hasattr(self, "var_remapping") and array_name in self.var_remapping: - actual_array_name = self.var_remapping[array_name] - - array_ref = VariableRef(actual_array_name) - index_ref = VariableRef(loop_var) - elem_access = ArrayAccess(array=array_ref, index=index_ref) - call = FunctionCall(func_name=func_name, args=[elem_access]) - - # Create expression statement wrapper - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - body_block.statements.append(ExpressionStatement(call)) - - # Create for loop - range_call = FunctionCall( - func_name="range", - args=[Literal(start), Literal(stop)], - ) - return ForStatement( - loop_var=loop_var, - iterable=range_call, - body=body_block, - ) - - # Not consecutive or not from same array - expand to individual calls - stmts = [] - for qarg in gate.qargs: - qref = self._convert_qubit_ref(qarg) - call = FunctionCall(func_name=func_name, args=[qref]) - - # Create expression statement wrapper - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - stmts.append(ExpressionStatement(call)) - # Return a block with all statements - return Block(statements=stmts) - # Handle multi-qubit gates with tuple arguments - if gate_name in ["CX", "CY", "CZ"] and all(isinstance(arg, tuple) and len(arg) == 2 for arg in gate.qargs): - # Multiple (control, target) pairs - generate multiple statements - stmts = [] - for ctrl, tgt in gate.qargs: - ctrl_ref = self._convert_qubit_ref(ctrl) - tgt_ref = self._convert_qubit_ref(tgt) - call = FunctionCall(func_name=func_name, args=[ctrl_ref, tgt_ref]) - - # Create expression statement wrapper - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - stmts.append(ExpressionStatement(call)) - # Return a block with all statements - return Block(statements=stmts) - # Standard argument handling - for qarg in gate.qargs: - # Check if this is a full array (no index) - if hasattr(qarg, "sym") and hasattr(qarg, "size") and qarg.size > 1: - # This is a full array - need to expand to individual gates - stmts = [] - array_name = qarg.sym - - # Check for renaming - if array_name in self.plan.renamed_variables: - array_name = self.plan.renamed_variables[array_name] - - # Check if this array name needs remapping (for unpacked struct fields) - if hasattr(self, "var_remapping") and array_name in self.var_remapping: - array_name = self.var_remapping[array_name] - - # Apply gate to each element - # For operations on arrays, we need to expand to individual operations - # However, reset operations in functions with owned arrays - # need special handling - - if ( - gate_name == "Prep" - and hasattr(self, "var_remapping") - and self.var_remapping - and array_name in self.var_remapping - ): - # Array Unpacking Pattern: use unpacked variables with - # functional operations - stmts.append(Comment(f"Reset all qubits in {array_name}")) - - if hasattr(self, "unpacked_vars") and array_name in self.unpacked_vars: - # Use unpacked variables with functional assignments - # Note: Explicit reset tracking is done during consumption analysis - # in _track_consumed_qubits(), not here - element_names = self.unpacked_vars[array_name] - - for i in range(min(qarg.size, len(element_names))): - # CRITICAL: Check if this qubit was just replaced by a measurement - # If so, skip the entire Prep (qubit already fresh) - if hasattr(self, "replaced_qubits") and ( - array_name in self.replaced_qubits and i in self.replaced_qubits[array_name] - ): - # This qubit was just replaced by measurement - skip Prep - self.replaced_qubits[array_name].discard(i) - # Add comment but no actual operation - stmts.append( - Comment( - f"Prep skipped for {element_names[i]} - already fresh from measurement", - ), - ) - continue - - elem_var = VariableRef(element_names[i]) - - # CRITICAL: Prep (reset) requires discard-then-allocate pattern - # Can't pass old qubit as argument to quantum.qubit() - # Pattern: quantum.discard(q); q = quantum.qubit() - - # 1. Discard the old qubit - discard_call = FunctionCall( - func_name="quantum.discard", - args=[elem_var], - ) - - # Create expression statement wrapper - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - stmts.append(ExpressionStatement(discard_call)) - - # 2. Allocate fresh qubit - fresh_qubit_call = FunctionCall( - func_name="quantum.qubit", - args=[], # No arguments - fresh allocation - ) - assignment = Assignment( - target=elem_var, - value=fresh_qubit_call, - ) - stmts.append(assignment) - else: - # Fallback to array indexing if no unpacking - for i in range(qarg.size): - elem_ref = ArrayAccess(array_name=array_name, index=i) - call = FunctionCall( - func_name=func_name, - args=[elem_ref], - ) - - # Create expression statement wrapper - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - stmts.append(ExpressionStatement(call)) - else: - # Regular case - generate a loop instead of expanding - # Check if this array is part of a struct - is_struct_field = False - - # First check if we have a remapped variable (unpacked struct field) - # The key insight is that if we're in a function with - # @owned struct parameters - # and this array is a struct field that has been unpacked, we should use - # the unpacked variable name directly, not struct.field notation - use_unpacked = False - if hasattr(self, "var_remapping") and array_name in self.var_remapping: - # Check if this is a struct field that has been unpacked - for prefix, info in self.struct_info.items(): - if array_name in info["var_names"].values() and hasattr( - self, - "current_function_params", - ): - # Check if the struct is an @owned parameter - for ( - param_name, - param_type, - ) in self.current_function_params: - if param_name == prefix and "@owned" in str( - param_type, - ): - use_unpacked = True - break - if use_unpacked: - break - - if use_unpacked: - # Generate a loop using the unpacked variable - loop_var = "i" - body_block = Block() - - # Use the remapped name from var_remapping - remapped_name = self.var_remapping.get( - array_name, - array_name, - ) - elem_ref = ArrayAccess( - array=VariableRef(remapped_name), - index=VariableRef(loop_var), - ) - call = FunctionCall(func_name=func_name, args=[elem_ref]) - - # Create expression statement wrapper - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - body_block.statements.append(ExpressionStatement(call)) - - # Create for loop - range_call = FunctionCall( - func_name="range", - args=[Literal(0), Literal(qarg.size)], - ) - for_stmt = ForStatement( - loop_var=loop_var, - iterable=range_call, - body=body_block, - ) - stmts.append(for_stmt) - is_struct_field = True # Skip the struct field check below - - if not is_struct_field: - for prefix, info in self.struct_info.items(): - if qarg.sym in info["var_names"].values(): - # Find the field name - for suffix, var_name in info["var_names"].items(): - if var_name == qarg.sym: - # Check if we're in a function that receives the struct - struct_param_name = prefix - if hasattr(self, "param_mapping") and prefix in self.param_mapping: - struct_param_name = self.param_mapping[prefix] - - # Check if the struct has a fresh version (after function calls) - if hasattr(self, "refreshed_arrays") and prefix in self.refreshed_arrays: - struct_param_name = self.refreshed_arrays[prefix] - - # Generate a loop for struct field access - loop_var = "i" - body_block = Block() - - field_access = FieldAccess( - obj=VariableRef(struct_param_name), - field=suffix, - ) - elem_ref = ArrayAccess( - array=field_access, - index=VariableRef(loop_var), - ) - call = FunctionCall( - func_name=func_name, - args=[elem_ref], - ) - - # Create expression statement wrapper - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - body_block.statements.append( - ExpressionStatement(call), - ) - - # Create for loop - range_call = FunctionCall( - func_name="range", - args=[Literal(0), Literal(qarg.size)], - ) - for_stmt = ForStatement( - loop_var=loop_var, - iterable=range_call, - body=body_block, - ) - stmts.append(for_stmt) - is_struct_field = True - break - break - - if not is_struct_field: - # Not in a struct - check if array was unpacked - if hasattr(self, "unpacked_vars") and array_name in self.unpacked_vars: - # Array was unpacked - UNROLL the loop to use unpacked elements directly - # This avoids: unpack → reconstruct → loop → unpack (AlreadyUsedError) - # Instead: unpack → apply to each element (no reconstruction needed) - element_names = self.unpacked_vars[array_name] - - # Unroll: apply the operation to each unpacked element - for i in range(qarg.size): - if i < len(element_names): - elem_ref = VariableRef(element_names[i]) - call = FunctionCall( - func_name=func_name, - args=[elem_ref], - ) - - # Create expression statement wrapper - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - stmts.append(ExpressionStatement(call)) - - # No need to update unpacked_vars - elements are modified in-place - else: - # Array not unpacked - generate a loop - loop_var = "i" - body_block = Block() - - # Check if the array name needs remapping (for unpacked struct fields) - actual_array_name = array_name - if hasattr(self, "var_remapping") and array_name in self.var_remapping: - actual_array_name = self.var_remapping[array_name] - - elem_ref = ArrayAccess( - array=VariableRef(actual_array_name), - index=VariableRef(loop_var), - ) - call = FunctionCall( - func_name=func_name, - args=[elem_ref], - ) - - # Create expression statement wrapper - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - body_block.statements.append(ExpressionStatement(call)) - - # Create for loop - range_call = FunctionCall( - func_name="range", - args=[Literal(0), Literal(qarg.size)], - ) - for_stmt = ForStatement( - loop_var=loop_var, - iterable=range_call, - body=body_block, - ) - stmts.append(for_stmt) - - # Return a block with all statements - return Block(statements=stmts) - args.append(self._convert_qubit_ref(qarg)) - - # If we get here, we have regular args (not arrays) - if args: - # Create function call expression - call = FunctionCall(func_name=func_name, args=args) - - # Special handling for Prep - it allocates a fresh qubit - # so we need to use assignment, not an expression statement - # Note: Explicit reset tracking is done during consumption analysis - # in _track_consumed_qubits(), not here - # Prep generates: discard + fresh allocation (reset pattern) - if gate_name == "Prep" and len(args) == 1: - # Get the target variable (where to store the fresh qubit) - target = args[0] - - # CRITICAL: Check if the previous operation was a measurement on this same qubit - # If so, skip the discard step (qubit already consumed by measurement) - skip_discard = False - if ( - hasattr(self, "current_block_ops") - and hasattr(self, "current_op_index") - and self.current_block_ops is not None - and self.current_op_index is not None - and self.current_op_index > 0 - and hasattr(target, "name") - ): - prev_index = self.current_op_index - 1 - prev_op = self.current_block_ops[prev_index] - # Check if previous operation was a measurement - if type(prev_op).__name__ == "Measure" and hasattr( - prev_op, - "qargs", - ): - for meas_qarg in prev_op.qargs: - # Get the variable name that would have been generated for this qubit - if hasattr(meas_qarg, "reg") and hasattr( - meas_qarg.reg, - "sym", - ): - array_name = meas_qarg.reg.sym - # Check both unpacked vars and locally allocated vars - if ( - hasattr(self, "unpacked_vars") - and array_name in self.unpacked_vars - and hasattr(meas_qarg, "index") - ): - element_names = self.unpacked_vars[array_name] - qubit_index = meas_qarg.index - if qubit_index < len(element_names): - meas_var_name = element_names[qubit_index] - if meas_var_name == target.name: - # Same qubit - skip discard - skip_discard = True - break - # Also check if this is a locally allocated qubit (two patterns) - elif hasattr(meas_qarg, "index"): - qubit_index = meas_qarg.index - # Pattern 1: {array}_{index}_local (from line 3712) - local_var_name = f"{array_name}_{qubit_index}_local" - # Pattern 2: {array}_{index} (from UNPACKED_MIXED with local allocation) - unpacked_var_name = f"{array_name}_{qubit_index}" - - if target.name in ( - local_var_name, - unpacked_var_name, - ): - # This is the same qubit that was measured - skip discard - skip_discard = True - break - - # CRITICAL: Use discard-then-allocate pattern for reset - # Pattern: quantum.discard(q); q = quantum.qubit() - # BUT: If qubit was just consumed by measurement, use fresh variable name - # to satisfy Guppy's linear type constraints - stmts = [] - - # Determine target variable for the fresh qubit - if skip_discard: - # Previous operation consumed the qubit - # We need a fresh variable name to avoid PlaceNotUsedError - old_name = target.name - - # Generate a new version for this variable - version = self.variable_version_counter.get(old_name, 0) + 1 - self.variable_version_counter[old_name] = version - new_name = f"{old_name}_{version}" - - # Add remapping so subsequent operations use the new name - self.variable_remapping[old_name] = new_name - - # Track the new variable for cleanup - if not hasattr(self, "allocated_ancillas"): - self.allocated_ancillas = set() - self.allocated_ancillas.add(new_name) - - # Allocate to the new variable - fresh_target = VariableRef(new_name) - else: - # Discard the old qubit first - discard_call = FunctionCall( - func_name="quantum.discard", - args=[target], - ) - - # Create expression statement wrapper - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - stmts.append(ExpressionStatement(discard_call)) - - # Reuse the same variable - fresh_target = target - - # Allocate fresh qubit - fresh_qubit_call = FunctionCall(func_name="quantum.qubit", args=[]) - stmts.append(Assignment(target=fresh_target, value=fresh_qubit_call)) - - return Block(statements=stmts) - - # No longer use functional operations - all gates are in-place - - # Create expression statement wrapper for non-functional operations - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - return ExpressionStatement(call) - - return None - - def _should_restructure_conditional_consumption(self, if_block) -> bool: - """Check if this If block needs restructuring to avoid conditional consumption.""" - # Check if we're in a conditional consumption loop - if not (hasattr(self, "_in_conditional_consumption_loop") and self._in_conditional_consumption_loop): - return False - - # Check if the If block contains function calls that consume variables - if hasattr(if_block, "ops"): - for op in if_block.ops: - if hasattr(op, "block_name") and op.block_name in [ - "PrepEncodingFTZero", - "PrepEncodingNonFTZero", - ]: - return True - - return False - - def _convert_if(self, if_block) -> Statement | None: - """Convert If block.""" - # Check if this conditional needs restructuring to avoid consumption issues - if self._should_restructure_conditional_consumption(if_block): - # Restructure to avoid conditional consumption - # Instead of: if cond: consume(vars) - # We do: vars = consume(vars); if not cond: pass - # This ensures vars are always consumed, maintaining linearity - - self.current_block.statements.append( - Comment("Restructured conditional to avoid consumption in conditional"), - ) - - # Execute the operations unconditionally - if hasattr(if_block, "ops"): - for op in if_block.ops: - stmt = self._convert_operation(op) - if stmt: - self.current_block.statements.append(stmt) - - # The condition check becomes a no-op since we already executed - return None - - # Check if we have a pre-extracted condition for this If block - if hasattr(self, "pre_extracted_conditions") and id(if_block) in self.pre_extracted_conditions: - # Use the pre-extracted condition variable - condition_var_name = self.pre_extracted_conditions[id(if_block)] - condition = VariableRef(condition_var_name) - - # Convert then block - then_block = Block() - if hasattr(if_block, "ops"): - prev_block = self.current_block - self.current_block = then_block - - for op in if_block.ops: - stmt = self._convert_operation(op) - if stmt: - then_block.statements.append(stmt) - - self.current_block = prev_block - - # Handle else block if present - else_block = None - if hasattr(if_block, "else_ops") and if_block.else_ops: - else_block = Block() - prev_block = self.current_block - self.current_block = else_block - - for op in if_block.else_ops: - stmt = self._convert_operation(op) - if stmt: - else_block.statements.append(stmt) - - self.current_block = prev_block - - return IfStatement( - condition=condition, - then_block=then_block, - else_block=else_block, - ) - - # Check if this If block has struct field access in loop with @owned parameters - if hasattr(if_block, "cond") and self._is_struct_field_in_loop_with_owned( - if_block.cond, - ): - # Implement a proper fix by extracting the condition value before the conditional - # This allows us to check the struct field without violating @owned constraints - - # Extract the struct field that's being tested - condition_var = self._extract_condition_variable(if_block.cond) - if condition_var: - self.current_block.statements.append( - Comment( - "Extract condition variable to avoid @owned struct field access in loop", - ), - ) - - # Create a local variable to hold the condition value - condition_stmt = Assignment( - target=VariableRef(condition_var["var_name"]), - value=self._convert_condition_value(if_block.cond), - ) - self.current_block.statements.append(condition_stmt) - - # Convert then block first - then_block = Block() - if hasattr(if_block, "ops"): - # Enter a new scope for the If block - prev_block = self.current_block - self.current_block = then_block - - for op in if_block.ops: - stmt = self._convert_operation(op) - if stmt: - then_block.statements.append(stmt) - - self.current_block = prev_block - - # Now create the If statement using the extracted variable - if condition_var["comparison"] == "EQUIV": - # For bool comparison with 1, convert to just the boolean variable - # Since verify_prep[0] is bool and we're checking == 1, - # this means "if verification failed" which is just the boolean value - if condition_var["compare_value"] == 1: - condition = VariableRef(condition_var["var_name"]) - else: - # For other comparisons, use == operator with appropriate type - condition = BinaryOp( - left=VariableRef(condition_var["var_name"]), - op="==", - right=Literal(condition_var["compare_value"]), - ) - else: - condition = VariableRef(condition_var["var_name"]) - - # Create and return the If statement - return IfStatement( - condition=condition, - then_block=then_block, - ) - # Fallback to the conservative approach if we can't extract the condition - self.current_block.statements.append( - Comment( - "Fallback: If condition with struct field access simplified for @owned compatibility", - ), - ) - - # Convert the If body operations unconditionally - if hasattr(if_block, "ops"): - for op in if_block.ops: - stmt = self._convert_operation(op) - if stmt: - self.current_block.statements.append(stmt) - - return None - - # Convert condition - condition = self._convert_condition(if_block.cond) - - # Track what resources were consumed before this conditional - # We need to ensure we don't try to re-consume them in else blocks - consumed_before_if = {} - if not hasattr(self, "consumed_resources"): - self.consumed_resources = {} - for res_name, indices in self.consumed_resources.items(): - consumed_before_if[res_name] = indices.copy() if isinstance(indices, set) else set(indices) - - # Convert then block with scope tracking - then_block = Block() - prev_block = self.current_block - - with self.scope_manager.enter_scope(ScopeType.IF_THEN) as then_scope: - self.current_block = then_block - - if hasattr(if_block, "ops"): - for op in if_block.ops: - stmt = self._convert_operation(op) - if stmt: - then_block.statements.append(stmt) - - # Convert else block if present - else_block = None - else_scope_info = None - - if hasattr(if_block, "else_block") and if_block.else_block: - else_block = Block() - - with self.scope_manager.enter_scope(ScopeType.IF_ELSE) as else_scope: - else_scope_info = else_scope - self.current_block = else_block - - if hasattr(if_block.else_block, "ops"): - for op in if_block.else_block.ops: - stmt = self._convert_operation(op) - if stmt: - else_block.statements.append(stmt) - - # Check for resource balancing needs - # Analyze resource consumption across branches - unbalanced = self.scope_manager.analyze_conditional_branches( - then_scope, - else_scope_info, - self.context, - ) - - # If there are unbalanced resources, we need to balance them - if unbalanced: - # Helper function to add resource consumption - def add_resource_consumption(block, res_name, indices): - # Filter out indices that were already consumed before the if statement - if res_name in consumed_before_if: - already_consumed = consumed_before_if[res_name] - indices = indices - already_consumed - - if indices: - block.statements.append( - Comment("Consume qubits to maintain linearity"), - ) - for idx in sorted(indices): - # Check if resource is unpacked - if res_name in self.unpacked_vars: - element_names = self.unpacked_vars[res_name] - if idx < len(element_names): - # Measure the unpacked qubit - meas_expr = FunctionCall( - func_name="quantum.measure", - args=[VariableRef(element_names[idx])], - ) - block.statements.append( - Assignment( - target=VariableRef("_"), - value=meas_expr, - ), - ) - elif hasattr(self, "dynamic_allocations") and res_name in self.dynamic_allocations: - # For dynamic allocations, allocate a fresh qubit and measure it - # Always allocate a fresh qubit for consumption (for linearity balancing) - var_name = self._get_unique_var_name(res_name, idx) - block.statements.append( - Assignment( - target=VariableRef(var_name), - value=FunctionCall( - func_name="quantum.qubit", - args=[], - ), - ), - ) - # Measure the qubit - meas_expr = FunctionCall( - func_name="quantum.measure", - args=[VariableRef(var_name)], - ) - block.statements.append( - Assignment(target=VariableRef("_"), value=meas_expr), - ) - else: - # Use array indexing - meas_expr = FunctionCall( - func_name="quantum.measure", - args=[ArrayAccess(array_name=res_name, index=idx)], - ) - block.statements.append( - Assignment(target=VariableRef("_"), value=meas_expr), - ) - - # If we have an else block, add balancing to both branches - if else_block: - # Add to then branch what else consumed - for res_name, indices in unbalanced.items(): - if res_name in then_scope.resource_usage: - then_usage = then_scope.resource_usage[res_name] - else_usage = else_scope_info.resource_usage.get( - res_name, - ResourceUsage(res_name, set()), - ) - missing_in_then = else_usage.consumed - then_usage.consumed - if missing_in_then: - add_resource_consumption( - then_block, - res_name, - missing_in_then, - ) - - # Add to else branch what then consumed - for res_name in then_scope.resource_usage: - then_usage = then_scope.resource_usage[res_name] - else_usage = else_scope_info.resource_usage.get( - res_name, - ResourceUsage(res_name, set()), - ) - missing_in_else = then_usage.consumed - else_usage.consumed - if missing_in_else: - add_resource_consumption(else_block, res_name, missing_in_else) - else: - # No else block - create one to consume resources - else_block = Block() - else_block.statements.append( - Comment("Auto-generated else block for linearity"), - ) - - for res_name, indices in unbalanced.items(): - add_resource_consumption(else_block, res_name, indices) - - self.current_block = prev_block - - return IfStatement( - condition=condition, - then_block=then_block, - else_block=else_block, - ) - - def _convert_while(self, while_block) -> Statement | None: - """Convert While loop.""" - # Convert condition - condition = self._convert_condition(while_block.cond) - - # Convert body with scope tracking - body_block = Block() - prev_block = self.current_block - - with self.scope_manager.enter_scope(ScopeType.LOOP): - self.current_block = body_block - - if hasattr(while_block, "ops"): - for op in while_block.ops: - stmt = self._convert_operation(op) - if stmt: - body_block.statements.append(stmt) - - self.current_block = prev_block - - return WhileStatement( - condition=condition, - body=body_block, - ) - - def _convert_for(self, for_block) -> Statement | None: - """Convert For loop.""" - # Get loop variable and range - loop_var = for_block.var - - # Determine the iteration pattern - if hasattr(for_block, "iterable") and for_block.iterable: - # For(i, iterable) - return self._convert_for_iterable(for_block, loop_var) - if hasattr(for_block, "start") and hasattr(for_block, "stop"): - # For(i, start, stop, [step]) - return self._convert_for_range(for_block, loop_var) - # Unknown pattern - return Comment(f"TODO: Unsupported For loop pattern with variable {loop_var}") - - def _convert_for_range(self, for_block, loop_var) -> Statement | None: - """Convert For loop with range pattern.""" - start = for_block.start - stop = for_block.stop - step = getattr(for_block, "step", 1) - - # Create range() call - if step == 1: - # range(start, stop) - range_call = FunctionCall( - func_name="range", - args=[Literal(start), Literal(stop)], - ) - else: - # range(start, stop, step) - range_call = FunctionCall( - func_name="range", - args=[Literal(start), Literal(stop), Literal(step)], - ) - - # Check if we need to pre-extract conditions from If statements in the loop body - # This is necessary when we have @owned struct parameters and If conditions that - # access struct fields inside the loop - extracted_conditions = [] - if self._should_pre_extract_conditions(for_block) and hasattr(for_block, "ops"): - # Find all If statements in the loop body and extract their conditions - for op in for_block.ops: - if type(op).__name__ == "If" and hasattr(op, "cond") and self._is_struct_field_access(op.cond): - condition_var = self._generate_condition_var_name(op.cond) - if condition_var: - # Generate the extraction statement before the loop - self.current_block.statements.append( - Comment( - "Pre-extract condition to avoid @owned struct field access in loop", - ), - ) - condition_stmt = Assignment( - target=VariableRef(condition_var), - value=self._convert_condition(op.cond), - ) - self.current_block.statements.append(condition_stmt) - extracted_conditions.append((op, condition_var)) - - # Convert body with scope tracking - body_block = Block() - prev_block = self.current_block - - # Track extracted conditions so If converter can use them - if extracted_conditions: - if not hasattr(self, "pre_extracted_conditions"): - self.pre_extracted_conditions = {} - for if_op, var_name in extracted_conditions: - self.pre_extracted_conditions[id(if_op)] = var_name - - with self.scope_manager.enter_scope(ScopeType.LOOP): - self.current_block = body_block - - if hasattr(for_block, "ops"): - for op in for_block.ops: - stmt = self._convert_operation(op) - if stmt: - body_block.statements.append(stmt) - - self.current_block = prev_block - - return ForStatement( - loop_var=str(loop_var), - iterable=range_call, - body=body_block, - ) - - def _convert_for_iterable(self, for_block, loop_var) -> Statement | None: - """Convert For loop with iterable pattern.""" - # For now, just handle the iterable as a variable reference - iterable = for_block.iterable - - # Try to convert it to an IR node - if isinstance(iterable, str): - iter_node = VariableRef(iterable) - elif hasattr(iterable, "sym"): - iter_node = VariableRef(iterable.sym) - else: - # Try to represent it somehow - iter_node = Literal(str(iterable)) - - # Convert body - body_block = Block() - prev_block = self.current_block - - with self.scope_manager.enter_scope(ScopeType.LOOP): - self.current_block = body_block - - if hasattr(for_block, "ops"): - for op in for_block.ops: - stmt = self._convert_operation(op) - if stmt: - body_block.statements.append(stmt) - - self.current_block = prev_block - - return ForStatement( - loop_var=str(loop_var), - iterable=iter_node, - body=body_block, - ) - - def _convert_condition(self, cond) -> IRNode: - """Convert condition expression.""" - cond_type = type(cond).__name__ - - if cond_type == "Bit": - # Bit reference - return self._convert_bit_ref(cond) - if cond_type == "EQUIV": - # Equality comparison - - left = self._convert_condition(cond.left) - right = self._convert_condition(cond.right) - - # Optimize boolean comparisons to 1 - if isinstance(right, Literal) and right.value == 1 and type(cond.left).__name__ == "Bit": - # Just return the boolean value itself - return left - - return BinaryOp(left=left, op="==", right=right) - if cond_type == "LT": - # Less than - left = self._convert_condition(cond.left) - right = self._convert_condition(cond.right) - return BinaryOp(left=left, op="<", right=right) - if cond_type == "GT": - # Greater than - left = self._convert_condition(cond.left) - right = self._convert_condition(cond.right) - return BinaryOp(left=left, op=">", right=right) - if cond_type == "AND": - # Bitwise AND (used as logical in conditions) - left = self._convert_condition(cond.left) - right = self._convert_condition(cond.right) - return BinaryOp(left=left, op="&", right=right) - if cond_type == "OR": - # Bitwise OR (used as logical in conditions) - left = self._convert_condition(cond.left) - right = self._convert_condition(cond.right) - return BinaryOp(left=left, op="|", right=right) - if cond_type == "NOT": - # Logical NOT - operand = self._convert_condition(cond.value) - return UnaryOp(op="not", operand=operand) - if hasattr(cond, "value"): - # Literal value - return Literal(cond.value) - if isinstance(cond, int | bool | str): - # Direct literal - return Literal(cond) - - # Default: try to convert as bit reference - return self._convert_bit_ref(cond) - - def _convert_repeat(self, repeat_block) -> Statement | None: - """Convert Repeat block to for loop.""" - # Repeat is essentially a for loop with an anonymous variable - repeat_count = repeat_block.cond - - # Check if this repeat block contains conditional consumption patterns - # that would violate linearity (e.g., conditional function calls with @owned params) - has_conditional_consumption = self._has_conditional_consumption_pattern( - repeat_block, - ) - - if has_conditional_consumption: - # Special handling for conditional consumption patterns - # Instead of a loop with conditional consumption, we need to restructure - # to avoid linearity violations - return self._convert_repeat_with_conditional_consumption(repeat_block) - - # Check if conditions have already been pre-extracted at the function level - # If not, extract them here (for non-function contexts) - extracted_conditions = [] - already_extracted = hasattr(self, "pre_extracted_conditions") and self.pre_extracted_conditions - - should_extract = ( - not already_extracted - and self._should_pre_extract_conditions_repeat(repeat_block) - and hasattr(repeat_block, "ops") - ) - if should_extract: - # Find all If statements in the loop body and extract their conditions - for op in repeat_block.ops: - if type(op).__name__ == "If" and hasattr(op, "cond"): - # Check if this condition was already pre-extracted - if hasattr(self, "pre_extracted_conditions") and id(op) in self.pre_extracted_conditions: - continue # Skip - already handled - - if self._is_struct_field_access(op.cond): - condition_var = self._generate_condition_var_name(op.cond) - if condition_var: - # Generate the extraction statement before the loop - self.current_block.statements.append( - Comment( - "Pre-extract condition to avoid @owned struct field access in loop", - ), - ) - condition_stmt = Assignment( - target=VariableRef(condition_var), - value=self._convert_condition(op.cond), - ) - self.current_block.statements.append(condition_stmt) - extracted_conditions.append((op, condition_var)) - - # Convert body - body_block = Block() - prev_block = self.current_block - - # Track extracted conditions so If converter can use them - if extracted_conditions: - if not hasattr(self, "pre_extracted_conditions"): - self.pre_extracted_conditions = {} - for if_op, var_name in extracted_conditions: - self.pre_extracted_conditions[id(if_op)] = var_name - - with self.scope_manager.enter_scope(ScopeType.LOOP): - self.current_block = body_block - - if hasattr(repeat_block, "ops"): - for op in repeat_block.ops: - stmt = self._convert_operation(op) - if stmt: - body_block.statements.append(stmt) - - self.current_block = prev_block - - # Create ForStatement with anonymous variable - return ForStatement( - loop_var="_", - iterable=FunctionCall(func_name="range", args=[Literal(repeat_count)]), - body=body_block, - ) - - def _has_conditional_consumption_pattern(self, repeat_block) -> bool: - """Check if a repeat block contains conditional consumption patterns.""" - if not hasattr(repeat_block, "ops"): - return False - - # Look for If blocks containing function calls with @owned parameters - for op in repeat_block.ops: - if type(op).__name__ == "If" and hasattr(op, "ops"): - for inner_op in op.ops: - # Check if this is a function call that might have @owned params - if hasattr(inner_op, "block_name"): - # Check if this function has @owned parameters - func_name = inner_op.block_name - if func_name in [ - "PrepEncodingFTZero", - "PrepEncodingNonFTZero", - "PrepZeroVerify", - ]: - return True - return False - - def _update_mappings_after_conditional_loop(self) -> None: - """Update variable mappings after a loop with conditional consumption. - - After a loop with conditional consumption, variables might have been - conditionally replaced with fresh versions. We need to ensure that - subsequent operations use the right variables. - """ - # For the specific pattern where we have c_d_fresh that might have been - # conditionally consumed to create c_d_fresh_1, we need to ensure - # that subsequent uses reference the original c_d_fresh (not _1) - # because the _1 version only exists conditionally. - # - # The proper solution would be to track which variables are guaranteed - # to exist and use those. For now, we'll stick with the original names. - - def _convert_repeat_with_conditional_consumption( - self, - repeat_block, - ) -> Statement | None: - """Convert repeat block with conditional consumption to avoid linearity violations.""" - repeat_count = repeat_block.cond - - # For conditional consumption patterns, we need to be careful - # The issue is that variables might be consumed conditionally in the loop - # but then used unconditionally afterward - - # Track that we're in a special conditional consumption context - self._in_conditional_consumption_loop = True - - # Convert as normal for loop - body_block = Block() - prev_block = self.current_block - - with self.scope_manager.enter_scope(ScopeType.LOOP): - self.current_block = body_block - - if hasattr(repeat_block, "ops"): - for op in repeat_block.ops: - stmt = self._convert_operation(op) - if stmt: - body_block.statements.append(stmt) - - self.current_block = prev_block - self._in_conditional_consumption_loop = False - - return ForStatement( - loop_var="_", - iterable=FunctionCall(func_name="range", args=[Literal(repeat_count)]), - body=body_block, - ) - - def _convert_comment(self, comment) -> Statement | None: - """Convert comment.""" - if hasattr(comment, "txt") and comment.txt: - return Comment(comment.txt) - return None # Skip empty comments - - def _is_struct_field_in_loop_with_owned(self, cond) -> bool: - """Check if a condition accesses a struct field in a problematic context. - - Returns True if: - 1. We're in a loop scope - 2. We're in a function with @owned struct parameters - 3. The condition accesses a struct field - """ - # Check if we're in a loop - if not hasattr(self, "scope_manager") or not self.scope_manager.is_in_loop(): - return False - - # Check if we're in a function with @owned struct parameters - if not hasattr(self, "function_info") or self.current_function_name == "main": - return False - - func_info = self.function_info.get(self.current_function_name, {}) - if not func_info.get("has_owned_struct_params", False): - return False - - # Check if the condition accesses a struct field - # Handle different condition types - cond_type = type(cond).__name__ - - if cond_type == "EQUIV": - # For equality comparisons, check the left side - if hasattr(cond, "left"): - return self._is_struct_field_in_loop_with_owned(cond.left) - elif hasattr(cond, "reg") and hasattr(cond.reg, "sym"): - array_name = cond.reg.sym - # Check if this variable is a struct field - for info in self.struct_info.values(): - if array_name in info["var_names"].values(): - return True - - return False - - def _extract_condition_variable(self, cond) -> dict | None: - """Extract information about a condition variable that accesses a struct field. - - Returns a dict with: - - var_name: suggested variable name for the extracted value - - struct_field: the struct field being accessed (e.g., 'c.verify_prep[0]') - - comparison: the comparison type (e.g., 'EQUIV') - - compare_value: the value being compared against - """ - cond_type = type(cond).__name__ - - if cond_type == "EQUIV" and hasattr(cond, "left") and hasattr(cond, "right"): - # Handle EQUIV(c_verify_prep[0], 1) - left = cond.left - right = cond.right - - # Check if left side is a struct field access - if hasattr(left, "reg") and hasattr(left.reg, "sym") and hasattr(left, "index"): - array_name = left.reg.sym - index = left.index - - # Check if this is a struct field - for prefix, info in self.struct_info.items(): - if array_name in info["var_names"].values(): - # Find the field name - field_name = None - for suffix, var_name in info["var_names"].items(): - if var_name == array_name: - field_name = suffix - break - - if field_name: - # Extract the comparison value - compare_value = getattr(right, "val", right) if hasattr(right, "val") else right - - return { - "var_name": f"{field_name}_{index}_extracted", - "struct_field": f"{prefix}.{field_name}[{index}]", - "comparison": "EQUIV", - "compare_value": compare_value, - } - - return None - - def _convert_condition_value(self, cond) -> IRNode: - """Convert the struct field access part of a condition to an IR node.""" - cond_type = type(cond).__name__ - - if cond_type == "EQUIV" and hasattr(cond, "left"): - # For EQUIV(c_verify_prep[0], 1), convert the left side (c_verify_prep[0]) - left = cond.left - - if hasattr(left, "reg") and hasattr(left.reg, "sym") and hasattr(left, "index"): - array_name = left.reg.sym - index = left.index - - # Check if this is a struct field and get the struct parameter name - for prefix, info in self.struct_info.items(): - if array_name in info["var_names"].values(): - # Find the field name - field_name = None - for suffix, var_name in info["var_names"].items(): - if var_name == array_name: - field_name = suffix - break - - if field_name: - # Check if the struct has been decomposed and we should use decomposed variables - if hasattr(self, "var_remapping") and array_name in self.var_remapping: - # Struct was decomposed - use the decomposed variable directly - decomposed_var = self.var_remapping[array_name] - return ArrayAccess( - array=VariableRef(decomposed_var), - index=index, - ) - - # Get the struct parameter name (e.g., 'c') - struct_param_name = prefix - if hasattr(self, "param_mapping") and prefix in self.param_mapping: - struct_param_name = self.param_mapping[prefix] - - # Check if we have fresh structs - use them directly - if hasattr(self, "refreshed_arrays") and prefix in self.refreshed_arrays: - fresh_struct_name = self.refreshed_arrays[prefix] - struct_param_name = fresh_struct_name - # Don't replace field access for fresh structs - - # Create: c.verify_prep[0] - but check for decomposed variables first - # Check if we have decomposed variables for this struct - if hasattr(self, "decomposed_vars") and struct_param_name in self.decomposed_vars: - field_vars = self.decomposed_vars[struct_param_name] - if field_name in field_vars: - # Use the decomposed variable instead - decomposed_var = field_vars[field_name] - return ArrayAccess( - array=VariableRef(decomposed_var), - index=index, - ) - - # Fallback to original struct field access (this should now be rare) - field_access = FieldAccess( - obj=VariableRef(struct_param_name), - field=field_name, - ) - return ArrayAccess(array=field_access, index=index) - - # Fallback - return Literal(0) - - def _function_has_owned_struct_params(self, params) -> bool: - """Check if function has @owned struct parameters.""" - return any("@owned" in param_type and param_name in self.struct_info for param_name, param_type in params) - - def _has_function_calls_before_loops(self, block) -> bool: - """Check if the function has function calls before loops. - - This indicates that decomposed struct variables will be consumed for - struct reconstruction, so we can't pre-extract conditions from them. - """ - if not hasattr(block, "ops"): - return False - - # Look for function calls before any loops - found_function_call = False - - for op in block.ops: - op_type = type(op).__name__ - - # Check for function calls (which would trigger struct reconstruction) - if op_type == "Call" and hasattr(op, "func"): - # This is a function call that might consume structs - found_function_call = True - - # Check for Repeat/For loops - if we find function calls before loops, - # then we'll need to reconstruct structs and can't pre-extract - if op_type in ["Repeat", "For"] and found_function_call: - return True - - return False - - def _pre_extract_loop_conditions(self, block, body) -> dict: - """Pre-extract conditions from loops that might access @owned struct fields. - - Returns a dictionary mapping If block IDs to extracted condition variable names. - """ - return {} - - # Disable pre-extraction for now - it causes linearity conflicts with struct reconstruction - # TODO: Implement proper post-function-call condition extraction - # The code below is currently unreachable but kept for future reference - - # Find all Repeat blocks with If conditions that access struct fields - extracted: dict = {} # Initialize for dead code below - if hasattr(block, "ops"): - for op in block.ops: - if type(op).__name__ == "Repeat" and hasattr(op, "ops"): - # Check if this Repeat block contains If statements with struct field access - for inner_op in op.ops: - if ( - type(inner_op).__name__ == "If" - and hasattr( - inner_op, - "cond", - ) - and self._is_struct_field_access(inner_op.cond) - ): - # Extract this condition NOW before any operations - condition_var = self._generate_condition_var_name( - inner_op.cond, - ) - if condition_var: - body.statements.append( - Comment( - "Pre-extract condition to avoid @owned struct field access in loop", - ), - ) - condition_stmt = Assignment( - target=VariableRef(condition_var), - value=self._convert_condition(inner_op.cond), - ) - body.statements.append(condition_stmt) - extracted[id(inner_op)] = condition_var - - return extracted - - def _should_pre_extract_conditions_repeat(self, repeat_block) -> bool: - """Check if we need to pre-extract conditions from this repeat block. - - Returns True if: - 1. The loop contains If statements with conditions - 2. We're in a function with @owned struct parameters - 3. The conditions access struct fields - 4. BUT False if we have function calls that will consume the decomposed variables - """ - # Check if we're in a function with @owned struct parameters - if not hasattr(self, "function_info") or self.current_function_name == "main": - return False - - func_info = self.function_info.get(self.current_function_name, {}) - if not func_info.get("has_owned_struct_params", False): - return False - - # Check if we have decomposed variables that might be consumed for struct reconstruction - # This indicates we're in a context where pre-extraction would conflict with reconstruction - if hasattr(self, "decomposed_vars") and self.decomposed_vars: - return False - - # Check if the loop contains If statements with struct field access - if hasattr(repeat_block, "ops"): - for op in repeat_block.ops: - if type(op).__name__ == "If" and hasattr(op, "cond") and self._is_struct_field_access(op.cond): - return True - - return False - - def _should_pre_extract_conditions(self, for_block) -> bool: - """Check if we need to pre-extract conditions from this for loop. - - Returns True if: - 1. The loop contains If statements with conditions - 2. We're in a function with @owned struct parameters OR have fresh structs from returns - 3. The conditions access struct fields - """ - # Check if we're in a function with @owned struct parameters or fresh structs - if not hasattr(self, "function_info") or self.current_function_name == "main": - return False - - func_info = self.function_info.get(self.current_function_name, {}) - has_owned_params = func_info.get("has_owned_struct_params", False) - has_fresh_structs = hasattr(self, "refreshed_arrays") and bool( - self.refreshed_arrays, - ) - - if not (has_owned_params or has_fresh_structs): - return False - - # Check if the loop contains If statements with struct field access - if hasattr(for_block, "ops"): - for op in for_block.ops: - if type(op).__name__ == "If" and hasattr(op, "cond") and self._is_struct_field_access(op.cond): - return True - - return False - - def _is_struct_field_access(self, cond) -> bool: - """Check if a condition accesses a struct field.""" - cond_type = type(cond).__name__ - - if cond_type == "EQUIV": - # For equality comparisons, check the left side - if hasattr(cond, "left"): - return self._is_struct_field_access(cond.left) - elif cond_type == "Bit": - # Check if this is a struct field - if hasattr(cond, "reg") and hasattr(cond.reg, "sym"): - array_name = cond.reg.sym - # Check if this variable is a struct field (original or fresh) - for prefix, info in self.struct_info.items(): - # Check original struct fields - if array_name in info["var_names"].values(): - return True - # Check fresh struct field patterns (e.g., c_fresh accessing verify_prep) - if hasattr(self, "refreshed_arrays"): - for orig_name in self.refreshed_arrays: - if orig_name == prefix: - # Check if array_name matches fresh struct field pattern - for field_name in info["var_names"].values(): - # The condition might be accessing fresh_struct.field - if array_name == field_name: # Original field being accessed - return True - elif cond_type in ["AND", "OR", "XOR", "NOT"]: - # Check both sides for binary ops - if hasattr(cond, "left") and self._is_struct_field_access(cond.left): - return True - if hasattr(cond, "right") and self._is_struct_field_access(cond.right): - return True - - return False - - def _generate_condition_var_name(self, cond) -> str | None: - """Generate a variable name for an extracted condition.""" - cond_type = type(cond).__name__ - - if cond_type == "EQUIV" and hasattr(cond, "left"): - left = cond.left - if hasattr(left, "reg") and hasattr(left.reg, "sym") and hasattr(left, "index"): - array_name = left.reg.sym - index = left.index - - # Check if this is a struct field - for info in self.struct_info.values(): - if array_name in info["var_names"].values(): - # Find the field name - for suffix, var_name in info["var_names"].items(): - if var_name == array_name: - return f"{suffix}_{index}_condition" - elif cond_type == "Bit": - if hasattr(cond, "reg") and hasattr(cond.reg, "sym") and hasattr(cond, "index"): - array_name = cond.reg.sym - index = cond.index - - # Check if this is a struct field - for info in self.struct_info.values(): - if array_name in info["var_names"].values(): - # Find the field name - for suffix, var_name in info["var_names"].items(): - if var_name == array_name: - return f"{suffix}_{index}_condition" - - # Generate a generic name - return "extracted_condition" - - def _convert_set_operation(self, set_op) -> Statement | None: - """Convert SET operation for classical bits.""" - if not hasattr(set_op, "left") or not hasattr(set_op, "right"): - return Comment("Invalid SET operation") - - # Convert left side (target) - use array indexing for assignments - target = self._convert_bit_ref(set_op.left, is_assignment_target=True) - - # Convert right side (value) - value = self._convert_set_value(set_op.right) - - return Assignment(target=target, value=value) - - def _convert_set_value(self, value, parent_op=None) -> IRNode: - """Convert value in SET operation. - - Args: - value: The value to convert - parent_op: The parent operation type (if any) to determine if parens are needed - """ - # Check if it's a literal - if isinstance(value, int | bool): - return Literal(bool(value)) - - # Check if it's a bit reference - value_type = type(value).__name__ - if value_type == "Bit": - return self._convert_bit_ref(value) - - # Check for bitwise operations - if value_type == "XOR": - left = self._convert_set_value(value.left, parent_op=value_type) - right = self._convert_set_value(value.right, parent_op=value_type) - result = BinaryOp(left=left, op="^", right=right) - # XOR has same precedence as AND, higher than OR - # Only need parens if parent is AND (to clarify precedence) - if parent_op == "AND": - result.needs_parens = True - return result - if value_type == "AND": - left = self._convert_set_value(value.left, parent_op=value_type) - right = self._convert_set_value(value.right, parent_op=value_type) - result = BinaryOp(left=left, op="&", right=right) - # Mark as needing parens if it's a child of | - if parent_op == "OR": - result.needs_parens = True - return result - if value_type == "OR": - left = self._convert_set_value(value.left, parent_op=value_type) - right = self._convert_set_value(value.right, parent_op=value_type) - return BinaryOp(left=left, op="|", right=right) - if value_type == "NOT": - # NOT might have 'operand' or be applied to first item - if hasattr(value, "operand"): - operand = self._convert_set_value(value.operand, parent_op=value_type) - elif hasattr(value, "value"): - operand = self._convert_set_value(value.value, parent_op=value_type) - else: - # Try to get the operand another way - operand = Literal(value=True) - return UnaryOp(op="not", operand=operand) - - # Unknown value type - generate function call as fallback - args = [] - if hasattr(value, "left"): - args.append(self._convert_set_value(value.left, parent_op=value_type)) - if hasattr(value, "right"): - args.append(self._convert_set_value(value.right, parent_op=value_type)) - return FunctionCall(func_name=value_type, args=args) - - def _convert_permute(self, permute) -> Statement | None: - """Convert Permute operation.""" - # Permute swaps registers or elements - # In Guppy, we can implement this using Python's swap syntax - - if hasattr(permute, "elems_i") and hasattr(permute, "elems_f"): - elems_i = permute.elems_i - elems_f = permute.elems_f - - # Case 1: Simple register swap (a, b = b, a) - if hasattr(elems_i, "sym") and hasattr(elems_f, "sym"): - # Full register swap - comment = Comment(f"Swap {elems_i.sym} and {elems_f.sym}") - self.current_block.statements.append(comment) - - # In Guppy, we need to use a temporary variable - temp_var = f"_temp_{elems_i.sym}" - - # temp = a - self.current_block.statements.append( - Assignment( - target=VariableRef(temp_var), - value=VariableRef(elems_i.sym), - ), - ) - - # a = b - self.current_block.statements.append( - Assignment( - target=VariableRef(elems_i.sym), - value=VariableRef(elems_f.sym), - ), - ) - - # b = temp - self.current_block.statements.append( - Assignment( - target=VariableRef(elems_f.sym), - value=VariableRef(temp_var), - ), - ) - - return None # Already added statements - - # Case 2: List of elements permutation - if isinstance(elems_i, list) and isinstance(elems_f, list): - if len(elems_i) != len(elems_f): - return Comment("ERROR: Permutation lists must have same length") - - # Analyze the permutation pattern - permutation_map = self._analyze_permutation(elems_i, elems_f) - - if permutation_map is None: - return Comment("ERROR: Invalid permutation - elements don't match") - - # Generate permutation code based on the pattern - return self._generate_permutation_code( - permutation_map, - elems_i, - elems_f, - ) - - # Fallback for unrecognized patterns - return Comment("TODO: Implement complex permutation") - - def _analyze_permutation(self, elems_i, elems_f): - """Analyze permutation to create a mapping.""" - # Create a set of all elements to ensure they match - elems_i_set = set() - elems_f_set = set() - - # Build element signatures for comparison - for elem in elems_i: - if hasattr(elem, "reg") and hasattr(elem, "index"): - elems_i_set.add((elem.reg.sym, elem.index)) - elif hasattr(elem, "sym"): - # Full register reference - elems_i_set.add((elem.sym, None)) - - for elem in elems_f: - if hasattr(elem, "reg") and hasattr(elem, "index"): - elems_f_set.add((elem.reg.sym, elem.index)) - elif hasattr(elem, "sym"): - elems_f_set.add((elem.sym, None)) - - # Check if the sets match (same elements, just reordered) - if elems_i_set != elems_f_set: - return None - - # Create the mapping: what goes to position i - # If elems_f[i] == elems_i[j], then position i gets value from position j - permutation_map = {} - for i, elem_f in enumerate(elems_f): - # Find which element in elems_i matches elem_f - for j, elem_i in enumerate(elems_i): - if self._elements_equal(elem_i, elem_f): - permutation_map[i] = j # position i gets value from position j - break - - return permutation_map - - def _elements_equal(self, elem1, elem2): - """Check if two elements refer to the same qubit.""" - # Both are register[index] references - if hasattr(elem1, "reg") and hasattr(elem1, "index") and hasattr(elem2, "reg") and hasattr(elem2, "index"): - return elem1.reg.sym == elem2.reg.sym and elem1.index == elem2.index - # Both are full register references - if hasattr(elem1, "sym") and hasattr(elem2, "sym"): - return elem1.sym == elem2.sym - return False - - def _generate_permutation_code(self, permutation_map, elems_i, elems_f): - """Generate code for complex permutation patterns.""" - _ = elems_f # Currently not used, reserved for future use - # Identify cycles in the permutation - cycles = self._find_permutation_cycles(permutation_map) - - if not cycles: - return Comment("Identity permutation - no action needed") - - # Add comment describing the permutation - self.current_block.statements.append( - Comment(f"Permute {len(elems_i)} elements"), - ) - - # For each cycle, generate swap operations - for cycle in cycles: - if len(cycle) == 1: - # Fixed point, no action needed - continue - if len(cycle) == 2: - # Simple swap - self._generate_swap(elems_i[cycle[0]], elems_i[cycle[1]]) - else: - # Multi-element cycle: use temporary variables - self._generate_cycle_permutation(cycle, elems_i) - - return None # Statements already added - - def _find_permutation_cycles(self, permutation_map): - """Find cycles in a permutation.""" - visited = set() - cycles = [] - - for start in permutation_map: - if start in visited: - continue - - cycle = [] - current = start - while current not in visited: - visited.add(current) - cycle.append(current) - current = permutation_map.get(current, current) - - if len(cycle) > 0 and (len(cycle) > 1 or cycle[0] != permutation_map.get(cycle[0], cycle[0])): - cycles.append(cycle) - - return cycles - - def _generate_swap(self, elem1, elem2): - """Generate code to swap two elements.""" - ref1 = self._convert_qubit_ref(elem1) - ref2 = self._convert_qubit_ref(elem2) - - # Use a temporary variable - temp_var = "_temp_swap" - - self.current_block.statements.append( - Assignment(target=VariableRef(temp_var), value=ref1), - ) - self.current_block.statements.append( - Assignment(target=ref1, value=ref2), - ) - self.current_block.statements.append( - Assignment(target=ref2, value=VariableRef(temp_var)), - ) - - def _generate_cycle_permutation(self, cycle, elements): - """Generate code for a multi-element cycle permutation.""" - if len(cycle) < 2: - return - - # Save the first element - first_elem = elements[cycle[0]] - first_ref = self._convert_qubit_ref(first_elem) - temp_var = "_temp_cycle" - - self.current_block.statements.append( - Assignment(target=VariableRef(temp_var), value=first_ref), - ) - - # Shift elements in the cycle - for i in range(len(cycle) - 1): - src_elem = elements[cycle[i + 1]] - dst_elem = elements[cycle[i]] - - src_ref = self._convert_qubit_ref(src_elem) - dst_ref = self._convert_qubit_ref(dst_elem) - - self.current_block.statements.append( - Assignment(target=dst_ref, value=src_ref), - ) - - # Complete the cycle - last_elem = elements[cycle[-1]] - last_ref = self._convert_qubit_ref(last_elem) - - self.current_block.statements.append( - Assignment(target=last_ref, value=VariableRef(temp_var)), - ) - - def _convert_block_call(self, block) -> Statement | None: - """Convert a block to a function call or inline expansion.""" - block_type = type(block) - block_name = block_type.__name__ - - # Get original block info if preserved - original_block_name = getattr(block, "block_name", block_name) - original_block_module = getattr(block, "block_module", block_type.__module__) - - # If we're in a loop, check if we need to restore array sizes before this call - if self.scope_manager.is_in_loop(): - self._restore_array_sizes_for_block_call(block) - - # Check if this is a core block that should be inlined - if original_block_name in self.CORE_BLOCKS: - # Inline core blocks - if hasattr(block, "ops"): - self.current_block.statements.append( - Comment(f"Begin {block_name} block"), - ) - for op in block.ops: - stmt = self._convert_operation(op) - if stmt: - self.current_block.statements.append(stmt) - self.current_block.statements.append( - Comment(f"End {block_name} block"), - ) - return None - - # For non-core blocks, create a function - block_signature = self._get_block_signature(block) - - # Check if we already have a function for this block type - if block_signature not in self.block_registry: - # Determine struct prefix if this block operates on a struct - struct_prefix = None - deps = self._analyze_block_dependencies(block) - - # Check if all variables belong to the same struct - for prefix, info in self.struct_info.items(): - vars_in_this_struct = set() - for var in info["var_names"].values(): - if var in deps["quantum"] or var in deps["classical"]: - vars_in_this_struct.add(var) - - # If this block operates on variables from this struct, use - # QEC code name if available - if vars_in_this_struct: - # Use the QEC code name if we have it, otherwise use prefix - struct_prefix = info.get("qec_code_name", prefix) - break - - # Generate a unique function name with struct prefix - # Include module name if not __main__ - base_name = original_block_name - - # For Parallel blocks with content hash, include the content info - if len(block_signature) > 2 and original_block_name == "Parallel": - content_hash = block_signature[2] - # Create a more readable suffix from the hash - # e.g., "H_H" becomes "_h", "X_X" becomes "_x" - if content_hash: - gates = content_hash.split("_") - if all(g == gates[0] for g in gates): - # All gates are the same type - base_name += f"_{gates[0].lower()}" - else: - # Mixed gates - use first letter of each - suffix = "_".join(g[0].lower() for g in gates[:3]) # Limit to 3 - base_name += f"_{suffix}" - - if original_block_module and original_block_module != "__main__": - # Extract just the last part of the module name (e.g., 'test_linearity_patterns') - module_parts = original_block_module.split(".") - module_name = module_parts[-1] if module_parts else "" - if module_name and module_name.startswith("test_"): - # For test modules, include the module name - func_name = self._generate_function_name( - f"{module_name}_{base_name}", - struct_prefix, - ) - else: - func_name = self._generate_function_name(base_name, struct_prefix) - else: - func_name = self._generate_function_name(base_name, struct_prefix) - self.block_registry[block_signature] = func_name - - # Add to pending functions if not already discovered - if func_name not in self.discovered_functions: - self.pending_functions.append((block, func_name, block_signature)) - self.discovered_functions.add(func_name) - else: - func_name = self.block_registry[block_signature] - - # Generate function call - stmt = self._generate_function_call(func_name, block) - if stmt: - self.current_block.statements.append(stmt) - return None # Already added to current block - - def _get_block_signature(self, block) -> tuple: - """Get a unique signature for a block type.""" - block_type = type(block) - block_name = block_type.__name__ - original_block_name = getattr(block, "block_name", block_name) - original_block_module = getattr(block, "block_module", block_type.__module__) - - # For Parallel blocks, include content hash to differentiate blocks - # with different operations - if original_block_name == "Parallel" and hasattr(block, "ops"): - content_hash = self._get_block_content_hash(block) - return (original_block_name, original_block_module, content_hash) - - # For now, use block name and module as signature - # Could be enhanced to include parameter info - return (original_block_name, original_block_module) - - def _generate_function_name( - self, - block_name: str, - struct_prefix: str | None = None, - ) -> str: - """Generate a unique function name for a block. - - Args: - block_name: The original block name (e.g., 'H', 'PrepRUS') - struct_prefix: Optional struct prefix (e.g., 'c' for c_struct) - """ - # Convert CamelCase to snake_case, handling acronyms better - import re - - # First, handle transitions from lowercase to uppercase - snake_case = re.sub("([a-z0-9])([A-Z])", r"\1_\2", block_name) - - # Then handle multiple consecutive capitals (acronyms) - snake_case = re.sub("([A-Z]+)([A-Z][a-z])", r"\1_\2", snake_case) - - # Convert to lowercase - snake_case = snake_case.lower() - - # Add struct prefix if provided - base_name = f"{struct_prefix}_{snake_case}" if struct_prefix else snake_case - - # Ensure uniqueness - func_name = base_name - counter = 1 - while func_name in self.generated_functions: - func_name = f"{base_name}_{counter}" - counter += 1 - - return func_name - - def _get_block_content_hash(self, block) -> str: - """Get a hash of block operations for differentiation. - - This is used to differentiate Parallel blocks with different operations. - """ - ops_summary = [] - if hasattr(block, "ops"): - for op in block.ops: - op_type = type(op).__name__ - # Include gate types to differentiate - ops_summary.append(op_type) - - # Create a simple hash from operation types - return "_".join(sorted(ops_summary)) if ops_summary else "empty" - - def _generate_function_call(self, func_name: str, block) -> Statement: - """Generate a function call for a block.""" - from pecos.slr.gen_codes.guppy.ir import Assignment, Comment, VariableRef - - # Analyze block dependencies to determine arguments - deps = self._analyze_block_dependencies(block) - - # Initialize as procedural, will be updated after resource flow analysis - is_procedural_function = True - - # CRITICAL: Save which arrays are currently unpacked BEFORE processing arguments - # This is needed to detect if a function call return should use a fresh variable name - # (when the parameter was unpacked and consumed in argument processing) - unpacked_before_call = set() - if hasattr(self, "unpacked_vars"): - unpacked_before_call = set(self.unpacked_vars.keys()) - - # Determine which variables need to be passed as arguments - args = [] - quantum_args = [] # Track quantum args for return value assignment - - # Check if we should pass structs instead of individual arrays - struct_args = set() # Structs we've already added - vars_in_structs = set() # Variables that are part of structs - - # First pass: identify which variables are part of structs - for prefix, info in self.struct_info.items(): - for var in info["var_names"].values(): - if var in deps["quantum"] or var in deps["classical"]: - vars_in_structs.add(var) - if prefix not in struct_args: - # Check if this struct has been refreshed (e.g., from a previous function call) - struct_to_use = prefix - if hasattr(self, "refreshed_arrays") and prefix in self.refreshed_arrays: - # Use the refreshed name (e.g., c_fresh instead of c) - struct_to_use = self.refreshed_arrays[prefix] - - # Check if this is a struct that was decomposed and needs reconstruction - # This includes @owned structs and fresh structs that were decomposed for field access - needs_reconstruction = False - struct_was_decomposed = struct_to_use in self.decomposed_vars or ( - prefix in self.decomposed_vars and struct_to_use == prefix - ) - if hasattr(self, "decomposed_vars") and struct_was_decomposed: - # Check if the struct we want to use was decomposed - needs_reconstruction = True - - if needs_reconstruction: - # Struct was decomposed - reconstruct it from decomposed variables - struct_info = self.struct_info[prefix] - - # Create a unique name for the reconstructed struct - reconstructed_var = self._get_unique_var_name( - f"{prefix}_reconstructed", - ) - - # Create struct constructor call - constructor_args = [] - - # Check if we have decomposed field variables for this struct - if struct_to_use in self.decomposed_vars: - # Use the decomposed field variables - field_mapping = self.decomposed_vars[struct_to_use] - for suffix, field_type, field_size in sorted( - struct_info["fields"], - ): - # Fallback to default naming if not in mapping - field_var = field_mapping.get( - suffix, - f"{struct_to_use}_{suffix}", - ) - constructor_args.append(VariableRef(field_var)) - else: - # Use the default field variable naming - for suffix, field_type, field_size in sorted( - struct_info["fields"], - ): - field_var = f"{prefix}_{suffix}" - - # Check if we have a fresh version of this field variable - if hasattr(self, "refreshed_arrays") and field_var in self.refreshed_arrays: - field_var = self.refreshed_arrays[field_var] - elif hasattr(self, "var_remapping") and field_var in self.var_remapping: - field_var = self.var_remapping[field_var] - - constructor_args.append(VariableRef(field_var)) - - struct_constructor = FunctionCall( - func_name=struct_info["struct_name"], - args=constructor_args, - ) - - # Add reconstruction statement - reconstruction_stmt = Assignment( - target=VariableRef(reconstructed_var), - value=struct_constructor, - ) - self.current_block.statements.append(reconstruction_stmt) - - # Use the reconstructed struct - struct_to_use = reconstructed_var - - # Add the struct as an argument - args.append(VariableRef(struct_to_use)) - struct_args.add(prefix) - # Track this for return value handling - if var in deps["quantum"]: - quantum_args.append(prefix) - - # Track unpacked arrays that need restoration after procedural calls - saved_unpacked_arrays = [] - - # Black Box Pattern: Pass complete global arrays to maintain SLR semantics - for var in sorted(deps["quantum"] & deps["reads"]): - # Check if this is an ancilla that was excluded from structs - is_excluded_ancilla = hasattr(self, "ancilla_qubits") and var in self.ancilla_qubits - - # Skip if this variable is part of a struct UNLESS it's an excluded ancilla - if var in vars_in_structs and not is_excluded_ancilla: - continue - - # Check if this variable needs remapping (we're inside a function) - actual_var = var - if hasattr(self, "var_remapping") and var in self.var_remapping: - actual_var = self.var_remapping[var] - - # For procedural functions (borrow), we can't use unpacked arrays - they need the original array - # For consuming functions (@owned), reconstruct the array from unpacked elements - # Also handle dynamically allocated arrays and decomposed ancilla arrays - if hasattr(self, "decomposed_ancilla_arrays") and var in self.decomposed_ancilla_arrays: - # Check if the array has already been reconstructed into a variable - if hasattr(self, "reconstructed_arrays") and var in self.reconstructed_arrays: - # Check if it was unpacked AFTER reconstruction - if hasattr(self, "unpacked_vars") and actual_var in self.unpacked_vars: - # Array was unpacked after reconstruction - need to reconstruct again - # First check if there's a refreshed version from a previous function call - if hasattr(self, "refreshed_arrays") and var in self.refreshed_arrays: - refreshed_name = self.refreshed_arrays[var] - args.append(VariableRef(refreshed_name)) - quantum_args.append(var) - else: - # Reconstruct from unpacked elements - element_names = self.unpacked_vars[actual_var] - array_construction = self._create_array_construction( - element_names, - ) - args.append(array_construction) - quantum_args.append(var) - else: - # Use the reconstructed array variable directly (not unpacked) - args.append(VariableRef(actual_var)) - quantum_args.append(var) - else: - # This array was decomposed into individual qubits - # Check if there's a refreshed version from a previous function call - if hasattr(self, "refreshed_arrays") and var in self.refreshed_arrays: - # Use the refreshed array from previous function call - refreshed_name = self.refreshed_arrays[var] - args.append(VariableRef(refreshed_name)) - quantum_args.append(var) - else: - # Reconstruct from decomposed elements - element_names = self.decomposed_ancilla_arrays[var] - array_construction = self._create_array_construction( - element_names, - ) - args.append(array_construction) - quantum_args.append(var) - elif hasattr(self, "dynamic_allocations") and var in self.dynamic_allocations: - # Dynamically allocated - check if there's a refreshed version first - if hasattr(self, "refreshed_arrays") and var in self.refreshed_arrays: - # Use the refreshed array from previous function call - refreshed_name = self.refreshed_arrays[var] - args.append(VariableRef(refreshed_name)) - quantum_args.append(var) - else: - # Dynamically allocated - construct array from individual qubits - # Get the size from context - var_info = self.context.lookup_variable(var) - if var_info and var_info.size: - size = var_info.size - element_names = [f"{var}_{i}" for i in range(size)] - array_construction = self._create_array_construction( - element_names, - ) - args.append(array_construction) - quantum_args.append(var) - else: - # Fallback - just pass the variable (will likely error) - args.append(VariableRef(actual_var)) - quantum_args.append(actual_var) - elif hasattr(self, "unpacked_vars") and actual_var in self.unpacked_vars: - # Array was unpacked (either from parameter or return value) - # OPTIMIZATION: If we're using ALL unpacked elements AND the array variable exists, - # just pass the array variable instead of reconstructing inline - # This happens when a function returns an array, we unpack it, then immediately - # pass it to another function - in this case, just use the variable! - element_names = self.unpacked_vars[actual_var] - - # Check if we have partial consumption (via index_mapping) - has_partial_consumption = hasattr(self, "index_mapping") and actual_var in self.index_mapping - - # Check if this was unpacked from a parameter - is_parameter_unpacked = ( - hasattr(self, "parameter_unpacked_arrays") and actual_var in self.parameter_unpacked_arrays - ) - - # Use the variable directly if: - # 1. No partial consumption (using all elements) - # 2. Not parameter-unpacked (return-unpacked arrays have the variable available) - # 3. The variable wasn't consumed yet - if not has_partial_consumption and not is_parameter_unpacked: - # The array variable should still exist - use it directly - args.append(VariableRef(actual_var)) - quantum_args.append(actual_var) - # Don't delete from unpacked_vars yet - might be needed later - else: - # Use inline array construction - # This is needed for: - # - Partial consumption (not all elements) - # - Parameter-unpacked arrays (no array variable exists) - array_construction = self._create_array_construction(element_names) - args.append(array_construction) - quantum_args.append(actual_var) - - # CRITICAL: After using inline construction, the unpacked elements are CONSUMED - # Remove from tracking so subsequent calls use the returned value instead - if hasattr(self, "parameter_unpacked_arrays"): - self.parameter_unpacked_arrays.discard(actual_var) - del self.unpacked_vars[actual_var] - if hasattr(self, "index_mapping") and actual_var in self.index_mapping: - del self.index_mapping[actual_var] - else: - # Array is already in the correct global form - # Check if this array has been refreshed (e.g., from a previous function call) - if hasattr(self, "refreshed_arrays") and var in self.refreshed_arrays: - # Use the refreshed name (e.g., data_fresh instead of data) - refreshed_name = self.refreshed_arrays[var] - args.append(VariableRef(refreshed_name)) - quantum_args.append(var) # Keep original name for tracking - else: - args.append(VariableRef(actual_var)) - quantum_args.append(actual_var) - - # Pass classical variables that are read or written (arrays are passed by reference) - for var in sorted(deps["classical"] & (deps["reads"] | deps["writes"])): - # Skip if this variable is part of a struct - if var in vars_in_structs: - continue - - # Check if this variable needs remapping - actual_var = var - if hasattr(self, "var_remapping") and var in self.var_remapping: - actual_var = self.var_remapping[var] - - # Classical arrays also need reconstruction if they were unpacked - if hasattr(self, "unpacked_vars") and actual_var in self.unpacked_vars: - # Reconstruct the classical array from unpacked elements - element_names = self.unpacked_vars[actual_var] - array_construction = self._create_array_construction(element_names) - - # Use a unique name for reconstruction to avoid linearity violation - reconstructed_var = self._get_unique_var_name(f"{actual_var}_array") - reconstruction_stmt = Assignment( - target=VariableRef(reconstructed_var), - value=array_construction, - ) - self.current_block.statements.append(reconstruction_stmt) - - # Clear the unpacking info since we've reconstructed the array - del self.unpacked_vars[actual_var] - args.append(VariableRef(reconstructed_var)) - else: - # Array is already in the correct form - args.append(VariableRef(actual_var)) - - # Create function call - call = FunctionCall( - func_name=func_name, - args=args, - ) - - # Use proper resource flow analysis to determine what's actually returned - _consumed_qubits, live_qubits = self._analyze_quantum_resource_flow(block) - - # Determine if this is a procedural function based on resource flow - # If the block has live qubits that should be returned, it's not procedural - has_live_qubits = bool(live_qubits) - is_procedural_function = not has_live_qubits - - # HYBRID APPROACH: Use smart detection for consistent function calls - if hasattr(self, "function_return_types") and func_name in self.function_return_types: - func_return_type = self.function_return_types[func_name] - if func_return_type == "None": - is_procedural_function = True - else: - # Fallback: use the same smart detection logic - should_be_procedural_call = self._should_function_be_procedural( - func_name, - block, - [(arg, "array[quantum.qubit, 2]") for arg in quantum_args], - has_live_qubits, - ) - if should_be_procedural_call: - is_procedural_function = True - - # Override: if function has multiple quantum args, it's likely not procedural - # if len(quantum_args) > 1: - # is_procedural_function = False - - # Override: if function returns a tuple, it's not procedural - # if func_name in self.function_return_types: - # func_return_type = self.function_return_types[func_name] - # if func_return_type.startswith("tuple["): - # is_procedural_function = False - - # If it appears to be procedural based on live qubits, double-check with signature - if is_procedural_function and hasattr(block, "__init__"): - import inspect - - try: - sig = inspect.signature(block.__class__.__init__) - return_annotation = sig.return_annotation - if return_annotation is None or return_annotation is type(None) or str(return_annotation) == "None": - is_procedural_function = True - else: - is_procedural_function = False # Has return annotation, not procedural - except (ValueError, TypeError, AttributeError): - # Default to procedural if can't inspect signature - # ValueError: signature cannot be determined - # TypeError: object is not callable - # AttributeError: missing expected attributes - is_procedural_function = True - - # Now determine if the calling function consumes quantum arrays - deps_for_func = self._analyze_block_dependencies(block) - has_quantum_params = bool(deps_for_func["quantum"] & deps_for_func["reads"]) - # Check if we're in main function - is_main_context = self.current_function_name == "main" - # Functions consume quantum arrays if they have quantum params AND the called function is not procedural - # This supports the nested blocks pattern where non-procedural functions return live qubits - function_consumes = has_quantum_params and (is_main_context or not is_procedural_function) - - # Force function consumption if multiple quantum args (likely tuple return) - if has_quantum_params and len(quantum_args) > 1: - function_consumes = True - - # Track consumed arrays in main function - # Check if the function being called has @owned parameters - if self.current_function_name == "main": - # Since function_info is not populated yet when building main, - # we need to be conservative and assume all quantum arrays passed to functions - # might have @owned parameters. This is especially true for procedural functions - # that have nested blocks (like prep_rus). - - # For safety, mark all quantum arrays passed to functions as consumed - # This prevents double-use errors when arrays are passed to @owned functions - for arg in quantum_args: - if isinstance(arg, str): # It's an array name - if not hasattr(self, "consumed_resources"): - self.consumed_resources = {} - if arg not in self.consumed_resources: - self.consumed_resources[arg] = set() - # Mark the entire array as consumed conservatively - # We don't know the exact size, but we can mark it as fully consumed - # by using a large range (quantum arrays are typically small) - self.consumed_resources[arg].update( - range(100), - ) # Conservative upper bound - - # Use natural SLR semantics: arrays are global resources modified in-place - # Functions that use unpacking still return arrays at boundaries to maintain this illusion - # Keep track of struct arguments before filtering - struct_args = [arg for arg in quantum_args if isinstance(arg, str) and arg in self.struct_info] - - quantum_args = [arg for arg in quantum_args if isinstance(arg, str)] # Filter for array names - - # Check if we're returning structs (already collected above) - - # Check if the function returns something based on our function definitions - self._function_returns_something(func_name) - - # CRITICAL: Determine actual return type by analyzing the block being called - # This is more reliable than looking it up in function_return_types which may not be populated yet - # APPROACH 1: Check Python type annotation on the block class - actual_returns_tuple = False - if hasattr(block, "__class__"): - try: - import inspect - - sig = inspect.signature(block.__class__.__init__) - return_annotation = sig.return_annotation - if return_annotation and return_annotation is not type(None): - return_str = str(return_annotation) - # Check if it's a tuple type annotation - actual_returns_tuple = ( - "tuple[" in return_str.lower() - or "Tuple[" in return_str - or (hasattr(return_annotation, "__origin__") and return_annotation.__origin__ is tuple) - ) - except (ValueError, TypeError, AttributeError): - # Can't inspect signature, will use APPROACH 2 - pass # Fallback to approach 2 - - # APPROACH 2: Infer from live_qubits analysis - # If live_qubits has multiple quantum arrays, function returns a tuple - if not actual_returns_tuple and len(live_qubits) > 1: - # Multiple quantum arrays are live - function returns a tuple - actual_returns_tuple = True - - # For both @owned and non-@owned functions, only return arrays with live qubits - # Fully consumed arrays should not be returned - returned_quantum_args = [] - for arg in quantum_args: - if isinstance(arg, str): - # Check if this arg (possibly reconstructed) maps to an original array with live qubits - original_name = arg - # Handle reconstructed array names (e.g., _q_array -> q) - if hasattr(self, "array_remapping") and arg in self.array_remapping: - original_name = self.array_remapping[arg] - elif arg.startswith("_") and arg.endswith("_array"): - # Try to infer original name from reconstructed name - # _q_array -> q - potential_original = arg[1:].replace("_array", "") - if potential_original in live_qubits: - original_name = potential_original - - if original_name in live_qubits: - returned_quantum_args.append( - arg, - ) # Use the actual arg name for assignment - - # If we forced function_consumes but have no returned_quantum_args, - # assume all quantum args should be returned (common with partial consumption patterns) - if function_consumes and not returned_quantum_args and len(quantum_args) > 1: - returned_quantum_args = list(quantum_args) - - # Also include structs that have live quantum fields - for struct_arg in struct_args: - if struct_arg not in returned_quantum_args and struct_arg in self.struct_info: - # Check if struct has any live quantum fields - struct_info = self.struct_info[struct_arg] - has_live_fields = False - for suffix, var_type, size in struct_info.get("fields", []): - if var_type == "qubit": - var_name = struct_info["var_names"].get(suffix) - if var_name and var_name in live_qubits: - has_live_fields = True - break - if has_live_fields: - returned_quantum_args.append(struct_arg) - - # Track arrays that are consumed (passed with @owned but not returned) - # Also mark arrays as consumed when passed to nested blocks (even without @owned) - is_nested_block = False - try: - from pecos.slr import Block as SlrBlock - - if hasattr(block, "__class__") and issubclass(block.__class__, SlrBlock): - is_nested_block = True - except (TypeError, AttributeError): - # Not a class or missing expected attributes - pass - - if (function_consumes or is_nested_block) and hasattr(self, "consumed_arrays"): - - # Check function signature for @owned parameters - owned_params = set() - - # TEMPORARY FIX: Hardcode known @owned parameter patterns for quantum error correction functions - # This covers the specific functions that are causing issues in the Steane code - known_owned_patterns = { - "prep_rus": [0, 1], # c_a and c_d are both @owned - "prep_encoding_ft_zero": [0, 1], # c_a and c_d are both @owned - "prep_zero_verify": [0, 1], # c_a and c_d are both @owned - "prep_encoding_non_ft_zero": [0], # c_d is @owned (first parameter) - "log_zero_rot": [0], # c_d is @owned (first parameter) - "h": [0], # c_d is @owned (first parameter) - } - - if func_name in known_owned_patterns: - owned_indices = known_owned_patterns[func_name] - for i in owned_indices: - if i < len(quantum_args): - owned_arg = quantum_args[i] - owned_params.add(owned_arg) - - # Try to find the function definition in the current module (future improvement) - # [Previous function definition lookup code can be restored later if needed] - - for arg in quantum_args: - if isinstance(arg, str): - # CRITICAL: Determine if this array should be marked as consumed - # Two cases: - # 1. Procedural function (returns None): ALL args are consumed - # 2. Functional function (returns values): Only args NOT returned are consumed - - # Procedural function - mark all args as consumed - # Functional function - only mark if not returned - should_mark_consumed = True if is_procedural_function else arg not in returned_quantum_args - - if should_mark_consumed: - # This array was consumed (not returned) - # Track the actual array name that was passed (might be reconstructed or fresh) - # Check if there's a fresh/refreshed version of this array - actual_name_to_mark = arg - if hasattr(self, "refreshed_arrays") and arg in self.refreshed_arrays: - # Use the refreshed/fresh name (e.g., c_d_fresh instead of c_d) - actual_name_to_mark = self.refreshed_arrays[arg] - elif hasattr(self, "array_remapping") and arg in self.array_remapping: - # Use the remapped name - actual_name_to_mark = self.array_remapping[arg] - - self.consumed_arrays.add(actual_name_to_mark) - # Also mark the original name to prevent double cleanup - if actual_name_to_mark != arg: - self.consumed_arrays.add(arg) - - # For procedural functions, don't assign the result - just call the function - if is_procedural_function: - # Create expression statement for the function call (no assignment) - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, _context): - return [] - - def render(self, context): - return self.expr.render(context) - - # After a procedural call, restore the unpacked arrays - # Procedural functions borrow, they don't consume, so the unpacked variables are still valid - if saved_unpacked_arrays: - for item in saved_unpacked_arrays: - if len(item) == 3: # Has reconstructed name and element names - array_name, element_names, _ = item - # Restore the unpacked variables - they're still valid after a borrow - if not hasattr(self, "unpacked_vars"): - self.unpacked_vars = {} - self.unpacked_vars[array_name] = element_names - - return ExpressionStatement(call) - - # With the functional pattern, functions that consume quantum arrays return the live ones - if returned_quantum_args and function_consumes: - # Black Box Pattern: Function returns modified global arrays/structs - # Assign directly back to original names to maintain SLR semantics - # ALSO handle @owned functions that return reconstructed structs - statements = [] - - # Check if the function returns a tuple by looking up its return type - func_return_type = self.function_return_types.get(func_name, "") - returns_tuple = func_return_type.startswith("tuple[") - - # CRITICAL: Use actual_returns_tuple from block inspection if available - # This is more reliable than function_return_types which may not be populated yet - if actual_returns_tuple: - returns_tuple = True - - # Don't force tuple unpacking based on argument count - use actual return type - # A function can take multiple args but return only one (e.g., consume some, return others) - - if len(returned_quantum_args) == 1 and not returns_tuple: - # Single return - assign back to the same variable name - # In Guppy's linear type system, reassigning to the same name shadows the old binding - name = returned_quantum_args[0] - - # Handle both reconstructed array names (_q_array) and original names (q) - base_name = name[1:].replace("_array", "") if name.startswith("_") and name.endswith("_array") else name - - # CRITICAL: If the variable was already unpacked (parameter unpacked at function start), - # we cannot assign to the same name - need a fresh variable name - # Example: def f(c_d: array @owned): - # __c_d_0, ... = c_d # c_d consumed - # c_d = h(...) # ERROR - c_d already consumed! - # Fix: use fresh name like c_d_fresh - # Use unpacked_before_call (saved state before argument processing) - # because argument processing may have deleted the array from unpacked_vars - if base_name in unpacked_before_call: - # Variable was unpacked - use fresh name for assignment - fresh_name = self._get_unique_var_name(f"{name}_fresh") - # Clear the unpacked tracking if still present - if hasattr(self, "unpacked_vars") and base_name in self.unpacked_vars: - del self.unpacked_vars[base_name] - else: - # Variable wasn't unpacked - can assign to same name (shadows old binding) - fresh_name = name - - # Use the appropriate variable name for the assignment - assignment = Assignment(target=VariableRef(fresh_name), value=call) - statements.append(assignment) - - # Track fresh variables for cleanup in procedural functions - # If we created a fresh variable (not same as parameter name), track it - if fresh_name != name: - if not hasattr(self, "fresh_return_vars"): - self.fresh_return_vars = {} - self.fresh_return_vars[fresh_name] = { - "original": name, - "func_name": func_name, - "is_quantum_array": True, - } - - # Update context for returned variable - self._update_context_for_returned_variable(name, fresh_name) - - # Also update array remapping for cleanup logic - if not hasattr(self, "array_remapping"): - self.array_remapping = {} - self.array_remapping[name] = fresh_name - - # Track this array as refreshed by function call - self.refreshed_arrays[name] = fresh_name - # Track which function refreshed this array and its position (0 for single return) - if not hasattr(self, "refreshed_by_function"): - self.refreshed_by_function = {} - self.refreshed_by_function[name] = { - "function": func_name, - "position": 0, - } - - # If this is a struct, decompose it to avoid field access issues - if name in self.struct_info: - struct_info = self.struct_info[name] - # Always decompose fresh structs to avoid AlreadyUsedError on field access - needs_decomposition = True - - if needs_decomposition: - # IMPORTANT: We cannot re-unpack from the struct because it may have been - # consumed by the function call. Instead, we need to - # update our var_remapping - # to indicate that the unpacked variables are no longer valid. - # The code should use the struct fields directly after function calls. - - # Comment explaining why we can't re-unpack - statements.append( - Comment( - "Note: Cannot use unpacked variables after calling function with @owned struct", - ), - ) - - # For fresh structs returned from functions, we need to decompose them immediately - # to avoid AlreadyUsedError when accessing fields - struct_name = struct_info["struct_name"].replace("_struct", "") - decompose_func_name = f"{struct_name}_decompose" - - # Generate field variables for decomposition - field_vars = [] - for suffix, field_type, field_size in sorted( - struct_info["fields"], - ): - field_var = f"{fresh_name}_{suffix}" - field_vars.append(field_var) - - # Add decomposition statement for the fresh struct - statements.append( - Comment( - "Decompose fresh struct to avoid field access on consumed struct", - ), - ) - - class TupleAssignment(Statement): - def __init__(self, targets, value): - self.targets = targets - self.value = value - - def analyze(self, context): - self.value.analyze(context) - - def render(self, context): - target_str = ", ".join(self.targets) - value_str = self.value.render(context)[0] - return [f"{target_str} = {value_str}"] - - decompose_call = FunctionCall( - func_name=decompose_func_name, - args=[VariableRef(fresh_name)], - ) - - decomposition_stmt = TupleAssignment( - targets=field_vars, - value=decompose_call, - ) - statements.append(decomposition_stmt) - - # Track decomposed variables for field access - if not hasattr(self, "decomposed_vars"): - self.decomposed_vars = {} - field_mapping = {} - for suffix, field_type, field_size in sorted( - struct_info["fields"], - ): - field_var = f"{fresh_name}_{suffix}" - field_mapping[suffix] = field_var - self.decomposed_vars[fresh_name] = field_mapping - - # Update var_remapping to indicate these variables should not be used - # by mapping them back to struct field access - for var_name in struct_info["var_names"].values(): - if var_name in self.var_remapping: - # This will cause future references to use struct.field notation - del self.var_remapping[var_name] - - # Force unpacking for arrays that need element access after function calls - # This is the core fix for the nested blocks MoveOutOfSubscriptError - # For refreshed arrays, check if they have element access that requires unpacking - needs_unpacking_for_refresh = False - if name in self.refreshed_arrays: - # CRITICAL FIX: Don't automatically unpack refreshed arrays - # The original analysis was for the INPUT parameter, not the refreshed return value - # Only unpack if there's explicit subscript usage AFTER this call - # This is handled by force_unpack_for_subscript below - needs_unpacking_for_refresh = False - - # CRITICAL: Only unpack returned arrays if they actually need element access - # Don't unpack just because the array was unpacked at function start - # Check if the array CURRENTLY needs unpacking based on how it's used AFTER this call - should_unpack_returned = ( - # Only unpack if actively needed for element access after this point - needs_unpacking_for_refresh - ) and name not in self.struct_info - - # CRITICAL: Always check if function returns array - # If so, force unpacking to avoid MoveOutOfSubscriptError - force_unpack_for_subscript = False - return_array_size_check = None - - # Try to get return type from function_return_types (if already analyzed) - if func_name in self.function_return_types: - return_type = self.function_return_types[func_name] - import re - - match = re.search(r"array\[.*?,\s*(\d+)\]", return_type) - if match: - return_array_size_check = int(match.group(1)) - - # Check if next operation uses subscript on this array - # This catches the pattern: q = func(q); measure(q[0]) - if ( - hasattr(self, "current_block_ops") - and hasattr(self, "current_op_index") - and self.current_block_ops is not None - and self.current_op_index is not None - ): - next_index = self.current_op_index + 1 - if next_index < len(self.current_block_ops): - next_op = self.current_block_ops[next_index] - # Check if next op uses subscript on this array - if hasattr(next_op, "qargs"): - for qarg in next_op.qargs: - if ( - hasattr(qarg, "reg") - and hasattr(qarg.reg, "sym") - and qarg.reg.sym == name - and hasattr(qarg, "index") - ): - # Next op uses subscript on returned array - force_unpack_for_subscript = True - break - else: - # Function not analyzed yet - use live_qubits from block analysis - # Check if this array has live qubits that indicate return size - if name in live_qubits and len(live_qubits[name]) >= 1: - # The block returns live qubits from this array - return_array_size_check = len(live_qubits[name]) - - # Check if next operation uses subscript on this array - if ( - hasattr(self, "current_block_ops") - and hasattr(self, "current_op_index") - and self.current_block_ops is not None - and self.current_op_index is not None - ): - next_index = self.current_op_index + 1 - if next_index < len(self.current_block_ops): - next_op = self.current_block_ops[next_index] - # Check if next op uses subscript on this array - if hasattr(next_op, "qargs"): - for qarg in next_op.qargs: - if ( - hasattr(qarg, "reg") - and hasattr(qarg.reg, "sym") - and qarg.reg.sym == name - and hasattr(qarg, "index") - ): - # Next op uses subscript on returned array - force_unpack_for_subscript = True - break - - if should_unpack_returned or force_unpack_for_subscript: - # Use the size we already extracted - return_array_size = return_array_size_check - - # If we know the return size and it's >= 1, unpack for element access - # Even size-1 arrays need unpacking to avoid MoveOutOfSubscriptError - if return_array_size and return_array_size >= 1: - # Generate unpacked variable names - # IMPORTANT: Use unique suffix "_ret" to avoid shadowing initial allocations - # When we do local_allocate strategy, we create q_0, q_1, q_2 - # When function returns array, we unpack to q_0_ret, q_1_ret to avoid conflicts - # CRITICAL: Make names unique across multiple unpackings using a counter - if not hasattr(self, "_unpack_counter"): - self._unpack_counter = {} - if name not in self._unpack_counter: - self._unpack_counter[name] = 0 - else: - self._unpack_counter[name] += 1 - unpack_suffix = ( - f"_ret{self._unpack_counter[name]}" if self._unpack_counter[name] > 0 else "_ret" - ) - element_names = [f"{name}_{i}{unpack_suffix}" for i in range(return_array_size)] - - # Add unpacking statement using ArrayUnpack IR class. - # When the array was refreshed by a function call (e.g., - # q → q_fresh), unpack from the refreshed name -- the - # original is moved/consumed at this point. Without - # this, generated Guppy looks like `q_0_ret, = q` and - # Guppy rejects with WrongNumberOfUnpacksError or - # AlreadyUsedError. - from pecos.slr.gen_codes.guppy.ir import ArrayUnpack - - unpack_source = self.refreshed_arrays.get(name, name) - unpack_stmt = ArrayUnpack( - targets=element_names, - source=unpack_source, - ) - statements.append(unpack_stmt) - - # Track unpacked variables - if not hasattr(self, "unpacked_vars"): - self.unpacked_vars = {} - self.unpacked_vars[name] = element_names - - # CRITICAL: Track index mapping for partial consumption - # If live_qubits tells us which original indices are in the returned array, - # create a mapping from original index → unpacked variable index - index_map: dict[int, int] | None = None - if name in live_qubits: - original_indices = sorted(live_qubits[name]) - if not hasattr(self, "index_mapping"): - self.index_mapping = {} - # Map original index to position in returned/unpacked array - index_map = {orig_idx: new_idx for new_idx, orig_idx in enumerate(original_indices)} - self.index_mapping[name] = index_map - - # Mirror to unified variable state (see variable_state.py) - self.var_state.bind_unpacked(name, list(element_names), index_map) - - # Update context - if hasattr(self, "context"): - var = self.context.lookup_variable(name) - if var: - var.is_unpacked = True - var.unpacked_names = element_names - - # DON'T immediately reconstruct - just leave the array unpacked - # Reconstruction will happen on-demand when needed (see below) - elif hasattr(self, "unpacked_vars") and name in self.unpacked_vars: - # Classical array or other case - invalidate old unpacked variables - old_element_names = self.unpacked_vars[name] - del self.unpacked_vars[name] - - # Also update the context to invalidate unpacked variable information - if hasattr(self, "context"): - var = self.context.lookup_variable(name) - if var: - var.is_unpacked = False - var.unpacked_names = [] - - # Add comment explaining why we can't re-unpack - statements.append( - Comment( - f"Note: Unpacked variables {old_element_names} invalidated " - "after function call - array size may have changed", - ), - ) - elif name in self.plan.arrays_to_unpack and name not in self.unpacked_vars: - # After function calls, don't automatically re-unpack arrays - # The array may have changed size and old unpacked variables are stale - # Instead, use array indexing for future references - statements.append( - Comment( - f"Note: Not re-unpacking {name} after function call - " - "array may have changed size, use array indexing instead", - ), - ) - - else: - # HYBRID TUPLE ASSIGNMENT: Choose strategy based on function and usage patterns - use_fresh_variables = self._should_use_fresh_variables( - func_name, - quantum_args, - ) - - if use_fresh_variables: - # Use fresh variables to avoid PlaceNotUsedError in problematic patterns - # Generate unique names to avoid reassignment issues in loops - if not hasattr(self, "_fresh_var_counter"): - self._fresh_var_counter = {} - - fresh_targets = [] - - # Check if we're in a consumption loop (conditional or not) - in_consumption_loop = ( - hasattr(self, "_in_conditional_consumption_loop") - and self._in_conditional_consumption_loop - and hasattr(self, "scope_manager") - and self.scope_manager.is_in_loop() - ) - - for arg in quantum_args: - # If we're in a consumption loop, - # reuse existing fresh names to avoid creating new variables in each iteration - if in_consumption_loop and arg in self.refreshed_arrays: - # Reuse the existing fresh variable name - fresh_name = self.refreshed_arrays[arg] - fresh_targets.append(fresh_name) - else: - base_name = f"{arg}_fresh" - # For loops and repeated calls, use unique suffixes - if base_name in self._fresh_var_counter: - self._fresh_var_counter[base_name] += 1 - unique_name = f"{base_name}_{self._fresh_var_counter[base_name]}" - else: - self._fresh_var_counter[base_name] = 0 - unique_name = base_name - fresh_targets.append(unique_name) - else: - # Standard tuple assignment - but check if we need to avoid borrowed variables - # OR if variables were unpacked before the call - fresh_targets = [] - for arg_idx, arg in enumerate(quantum_args): - # CRITICAL: Check if this parameter was already unpacked before the call - # If so, we MUST use a fresh variable name (can't assign to consumed variable) - # This is the same issue we fixed for single returns - was_unpacked = arg in unpacked_before_call - - # Check if this variable is a borrowed parameter (not @owned) - # If so, we need to use a different name to avoid BorrowShadowedError - is_borrowed = False - if hasattr(self, "current_function_name") and self.current_function_name: - # Check if this is a function parameter - func_info = self.function_info.get( - self.current_function_name, - {}, - ) - params = func_info.get("params", []) - for param_name, param_type in params: - if ( - param_name == arg - and "@owned" not in param_type - and "array[quantum.qubit" in param_type - ): - # This is a borrowed quantum array parameter - is_borrowed = True - break - - # Determine if we need a fresh name for any reason: - # 1. Variable was unpacked before call (consumed) - # 2. Variable is a borrowed parameter (can't shadow) - needs_fresh_name = was_unpacked or is_borrowed - - if needs_fresh_name: - # Use a fresh name to avoid: - # - AlreadyUsedError (if unpacked before call) - # - BorrowShadowedError (if borrowed parameter) - # Check if we're in a loop - if so, reuse the existing variable name - in_loop = hasattr(self, "scope_manager") and self.scope_manager.is_in_loop() - - if in_loop and hasattr(self, "refreshed_arrays") and arg in self.refreshed_arrays: - # In a loop, reuse the existing refreshed name to avoid undefined variable errors - fresh_name = self.refreshed_arrays[arg] - elif hasattr(self, "refreshed_arrays") and arg in self.refreshed_arrays: - # Not in a loop but already have a returned version, need a new unique name - if not hasattr(self, "_returned_var_counter"): - self._returned_var_counter = {} - base_name = f"{arg}_returned" - if base_name not in self._returned_var_counter: - self._returned_var_counter[base_name] = 1 - else: - self._returned_var_counter[base_name] += 1 - fresh_name = f"{base_name}_{self._returned_var_counter[base_name]}" - else: - # Choose suffix based on reason for fresh name - if was_unpacked: - # Use _fresh suffix for unpacked parameters (more descriptive) - fresh_name = self._get_unique_var_name( - f"{arg}_fresh", - ) - else: - # Use _returned suffix for borrowed parameters - fresh_name = f"{arg}_returned" - - fresh_targets.append(fresh_name) - - # Track this for later use - if not hasattr(self, "refreshed_arrays"): - self.refreshed_arrays = {} - self.refreshed_arrays[arg] = fresh_name - # Track which function refreshed this array and its position in return tuple - if not hasattr(self, "refreshed_by_function"): - self.refreshed_by_function = {} - self.refreshed_by_function[arg] = { - "function": func_name, - "position": arg_idx, - } - - # Also track in fresh_return_vars for cleanup in procedural functions - if was_unpacked: - if not hasattr(self, "fresh_return_vars"): - self.fresh_return_vars = {} - self.fresh_return_vars[fresh_name] = { - "original": arg, - "func_name": func_name, - "is_quantum_array": True, - } - else: - # Safe to use the original name (not unpacked, not borrowed) - fresh_targets.append(arg) - - class TupleAssignment(Statement): - def __init__(self, targets, value): - self.targets = targets - self.value = value - - def analyze(self, context): - self.value.analyze(context) - - def render(self, context): - target_str = ", ".join(self.targets) - value_str = self.value.render(context)[0] - return [f"{target_str} = {value_str}"] - - assignment = TupleAssignment(targets=fresh_targets, value=call) - statements.append(assignment) - - # Track all refreshed/returned variables for proper return handling - for i, original_name in enumerate(quantum_args): - if i < len(fresh_targets): - fresh_name = fresh_targets[i] - if fresh_name != original_name: - # This variable was renamed (either _fresh or _returned) - # Track it so return statements use the correct name - if not hasattr(self, "refreshed_arrays"): - self.refreshed_arrays = {} - # Always update the mapping for return handling - self.refreshed_arrays[original_name] = fresh_name - # Track which function refreshed this array and its position in return tuple - if not hasattr(self, "refreshed_by_function"): - self.refreshed_by_function = {} - self.refreshed_by_function[original_name] = { - "function": func_name, - "position": i, - } - - # Also track in fresh_return_vars for cleanup in procedural functions - # All fresh variables from tuple returns need cleanup tracking - if not hasattr(self, "fresh_return_vars"): - self.fresh_return_vars = {} - self.fresh_return_vars[fresh_name] = { - "original": original_name, - "func_name": func_name, - "is_quantum_array": True, - } - - # Check if any of the returned variables are structs and decompose them immediately - for var_name in fresh_targets: - # Check if this variable name corresponds to a struct - # It might be a fresh name (e.g., c_fresh) or original name (e.g., c) - struct_info = None - - if var_name in self.struct_info: - struct_info = self.struct_info[var_name] - else: - # Check if this is a renamed struct (e.g., c_fresh -> c) - # Be precise: only match if the variable is actually a renamed version of the struct - for key, info in self.struct_info.items(): - # Check for exact pattern: key_suffix (e.g., c_fresh) - if var_name == f"{key}_fresh" or var_name == f"{key}_returned": - struct_info = info - break - - if struct_info: - # Decompose fresh structs that will be used in loops - # This allows us to access fields without consuming the struct - struct_name = struct_info["struct_name"].replace("_struct", "") - decompose_func_name = f"{struct_name}_decompose" - - # Generate field variables for decomposition - field_vars = [] - for suffix, field_type, field_size in sorted( - struct_info["fields"], - ): - field_var = f"{var_name}_{suffix}" - field_vars.append(field_var) - - # Add decomposition statement - statements.append( - Comment(f"Decompose {var_name} for field access"), - ) - - decompose_call = FunctionCall( - func_name=decompose_func_name, - args=[VariableRef(var_name)], - ) - - decomposition_stmt = TupleAssignment( - targets=field_vars, - value=decompose_call, - ) - statements.append(decomposition_stmt) - - # Track decomposed variables - if not hasattr(self, "decomposed_vars"): - self.decomposed_vars = {} - field_mapping = {} - for suffix, field_type, field_size in sorted( - struct_info["fields"], - ): - field_var = f"{var_name}_{suffix}" - field_mapping[suffix] = field_var - self.decomposed_vars[var_name] = field_mapping - - # Handle variable mapping based on whether we used fresh variables - if use_fresh_variables: - statements.append( - Comment("Using fresh variables to avoid linearity conflicts"), - ) - - # Check if we're in a conditional within a loop - # This requires special handling to avoid linearity violations - (hasattr(self, "scope_manager") and self.scope_manager.is_in_conditional_within_loop()) - - # Update variable mapping so future references use the fresh names - # BUT only for functions that truly "refresh" the same arrays - # Functions like prep_zero_verify return different arrays, not refreshed inputs - refresh_functions = [ - "process_qubits", # Functions that process and return the same qubits - "apply_gates", # Functions that apply operations and return the same qubits - "measure_and_reset", # Functions that measure, reset, and return the same qubits - ] - - # Check if this function actually refreshes arrays (returns processed versions of inputs) - should_refresh_arrays = any(pattern in func_name.lower() for pattern in refresh_functions) - - # Additional check: if function has @owned parameters and returns fresh variables, - # it's likely refreshing the arrays - if not should_refresh_arrays and use_fresh_variables: - # Check if any fresh target names contain "fresh" - indicates array refreshing - has_fresh_returns = any("fresh" in target for target in fresh_targets) - if has_fresh_returns: - # Most quantum functions that return "fresh" variables are refreshing arrays - # This includes verification functions that return processed versions of inputs - should_refresh_arrays = True - - if should_refresh_arrays: - for i, original_name in enumerate(quantum_args): - if i < len(fresh_targets): - fresh_name = fresh_targets[i] - if fresh_name != original_name: # Only map if actually fresh - # Check if this is a conditional fresh variable (ending in _1) - if fresh_name.endswith("_1"): - # Don't update mapping for conditional variables to avoid errors - # Conditional consumption in loops is fundamentally incompatible - # with guppylang's linearity requirements - base_fresh_name = fresh_name[:-2] # Remove _1 suffix - self.conditional_fresh_vars[base_fresh_name] = fresh_name - elif original_name not in self.refreshed_arrays: - # Safe to update - first assignment - self.refreshed_arrays[original_name] = fresh_name - # Track which function refreshed this array and its position in return tuple - if not hasattr(self, "refreshed_by_function"): - self.refreshed_by_function = {} - self.refreshed_by_function[original_name] = { - "function": func_name, - "position": i, - } - self._update_context_for_returned_variable( - original_name, - fresh_name, - ) - else: - # For functions that return different arrays (like prep_zero_verify), - # don't map fresh variables as refreshed versions of inputs - # This allows proper reconstruction from unpacked variables in returns - pass - - # Immediately check if any fresh variables are likely to be unused - # and add discard for them - # Specifically, check for the ancilla pattern where ancilla_fresh is returned - # but not used after syndrome extraction - for i, original_name in enumerate(quantum_args): - if i < len(fresh_targets): - fresh_name = fresh_targets[i] - # Check if this is likely an ancilla array that won't be used - # Pattern: ancilla arrays that are measured inside the function - is_ancilla = "ancilla" in original_name.lower() - is_fresh = fresh_name != original_name - in_main = self.current_function_name == "main" - if is_ancilla and is_fresh and in_main: - # Check if we're in main (where ancillas are typically not reused) - # Add immediate discard for ancilla_fresh - statements.append( - Comment( - f"Discard unused {fresh_name} immediately", - ), - ) - discard_stmt = FunctionCall( - func_name="quantum.discard_array", - args=[VariableRef(fresh_name)], - ) - - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - statements.append(ExpressionStatement(discard_stmt)) - else: - statements.append( - Comment("Standard tuple assignment to original variables"), - ) - # For standard assignment, variables keep their original names - # BUT don't overwrite if we already set a different mapping (e.g., for _returned variables) - for i, original_name in enumerate(quantum_args): - if i < len(fresh_targets): - fresh_name = fresh_targets[i] - # Only set to original name if we haven't already mapped to a different name - if fresh_name == original_name: - self.refreshed_arrays[original_name] = original_name - # If fresh_name != original_name, the mapping was already set above - - # Handle struct field invalidation after function call - for array_name in quantum_args: - if array_name in self.struct_info and hasattr( - self, - "var_remapping", - ): - struct_info = self.struct_info[array_name] - # Check if any of the struct's fields are in var_remapping - needs_update = any(var in self.var_remapping for var in struct_info["var_names"].values()) - - if needs_update: - # Cannot re-unpack - invalidate the unpacked variables - statements.append( - Comment( - "Note: Cannot use unpacked variables after calling function with @owned struct", - ), - ) - - # Update var_remapping to indicate these variables should not be used - for var_name in struct_info["var_names"].values(): - if var_name in self.var_remapping: - del self.var_remapping[var_name] - - # Unpack any arrays that need it after the function call - # BUT: Don't unpack if already unpacked (to avoid AlreadyUsedError) - for array_name in quantum_args: - if ( - array_name in self.plan.unpack_at_start - and array_name not in self.struct_info - and array_name in self.plan.arrays_to_unpack - and array_name not in self.unpacked_vars # Don't re-unpack! - ): - info = self.plan.arrays_to_unpack[array_name] - self._add_array_unpacking(array_name, info.size) - - # Check if current function is procedural (returns None) and add discards for unused quantum arrays - is_in_procedural = getattr(self, "current_function_is_procedural", False) - if is_in_procedural and len(statements) == 1: - # This is a procedural function with a single assignment (likely the last operation) - # Check if we have an unused quantum array to discard - # This happens when a procedural function calls a function that returns an array - # but doesn't use the result - stmt = statements[0] - if isinstance(stmt, Assignment): - # Check if this is an assignment to a quantum array - target_name = None - if hasattr(stmt.target, "name"): - target_name = stmt.target.name - - # Check if this is a quantum array by checking: - # 1. If it's in returned_quantum_args (passed as quantum param) - # 2. Or if func_name returns a quantum array (if we know the return type) - is_quantum_array = target_name in returned_quantum_args - - if not is_quantum_array and func_name in self.function_return_types: - return_type = self.function_return_types[func_name] - is_quantum_array = "array[quantum.qubit," in return_type - - if target_name and is_quantum_array: - # This is a quantum array that was assigned but may not be used - # Add a discard statement for it - discard_call = FunctionCall( - func_name="quantum.discard_array", - args=[VariableRef(target_name)], - ) - - # Define ExpressionStatement locally if not already defined - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, _context): - return [] - - def render(self, context): - return self.expr.render(context) - - statements.append(Comment(f"Discard unused {target_name}")) - statements.append(ExpressionStatement(discard_call)) - - # Return block with all statements - if len(statements) == 1: - return statements[0] - return Block(statements=statements) - - # Either no quantum arrays OR function consumes its parameters - # In both cases, just call the function without assignment - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - return ExpressionStatement(call) - - def _function_returns_something(self, func_name: str) -> bool: - """Check if a function returns a value (not None).""" - # Functions that work with structs and return modified structs - # Check if this function name indicates it works with structs - if self.struct_info: - for info in self.struct_info.values(): - struct_name = info.get("struct_name", "") - # Extract the base name from the struct name (e.g., "steane" from "steane_struct") - if "_struct" in struct_name: - base_name = struct_name.replace("_struct", "").lower() - else: - base_name = struct_name.lower() - - if func_name.startswith(f"{base_name}_"): - # Struct functions typically return the modified struct - # Exception: functions ending in 'discard' or 'decompose' - # don't return the struct - return not (func_name.endswith(("_discard", "_decompose"))) - - # For other functions, assume they return something if they have quantum args - # This is a conservative approach - return False - - def _analyze_quantum_resource_flow( - self, - block, - ) -> tuple[dict[str, set[int]], dict[str, set[int]]]: - """Analyze which quantum resources are consumed vs. live in a block. - - Returns: - consumed_qubits: dict mapping qreg names to sets of consumed indices - live_qubits: dict mapping qreg names to sets of live indices - """ - consumed_qubits = {} - live_qubits = {} - - # Track all quantum variables used - all_quantum_vars = set() - - if hasattr(block, "ops"): - for op in block.ops: - # Check for measurements that consume qubits - if type(op).__name__ == "Measure": - if hasattr(op, "qargs"): - for qarg in op.qargs: - if hasattr(qarg, "reg") and hasattr(qarg.reg, "sym"): - qreg_name = qarg.reg.sym - if hasattr(qarg, "index"): - # Single qubit measurement - if qreg_name not in consumed_qubits: - consumed_qubits[qreg_name] = set() - consumed_qubits[qreg_name].add(qarg.index) - elif hasattr(qarg, "sym"): - # Full array measurement - qreg_name = qarg.sym - if hasattr(qarg, "size"): - if qreg_name not in consumed_qubits: - consumed_qubits[qreg_name] = set() - consumed_qubits[qreg_name].update(range(qarg.size)) - - # Check for nested Block operations that may consume qubits - elif hasattr(op, "ops") and hasattr(op, "vars"): - # This is a nested block - analyze it recursively - nested_consumed, _nested_live = self._analyze_quantum_resource_flow( - op, - ) - - # Merge nested consumption into our tracking - for qreg_name, indices in nested_consumed.items(): - if qreg_name not in consumed_qubits: - consumed_qubits[qreg_name] = set() - consumed_qubits[qreg_name].update(indices) - - # Track all quantum variables used (for determining what's live) - if hasattr(op, "qargs"): - for qarg in op.qargs: - if isinstance(qarg, tuple): - for sub_qarg in qarg: - if hasattr(sub_qarg, "reg") and hasattr( - sub_qarg.reg, - "sym", - ): - all_quantum_vars.add(sub_qarg.reg.sym) - elif hasattr(sub_qarg, "sym"): - all_quantum_vars.add(sub_qarg.sym) - elif hasattr(qarg, "reg") and hasattr(qarg.reg, "sym"): - all_quantum_vars.add(qarg.reg.sym) - elif hasattr(qarg, "sym"): - all_quantum_vars.add(qarg.sym) - - # Determine live qubits (used but not consumed) - # We need to know the actual size of arrays to determine what's live - # Get size information from the block's variable definitions - array_sizes = {} - - # Check all attributes of the block for QReg/CReg definitions - for attr_name in dir(block): - if not attr_name.startswith("_"): # Skip private attributes - try: - attr = getattr(block, attr_name, None) - if attr and hasattr(attr, "size") and hasattr(attr, "sym"): - array_sizes[attr.sym] = attr.size - # Add to all_quantum_vars if it's a quantum register - if hasattr(attr, "__class__") and "QReg" in attr.__class__.__name__: - all_quantum_vars.add(attr.sym) - except (AttributeError, TypeError): - # Ignore attributes without expected structure - pass - - # Also check variable context if available - if hasattr(self, "context") and self.context: - for var_name in all_quantum_vars: - var_info = self.context.lookup_variable(var_name) - if var_info and var_info.size: - array_sizes[var_name] = var_info.size - - # Pre-track explicit resets to know which consumed qubits are reset and should be considered live - consumed_for_tracking = {} - self._track_consumed_qubits(block, consumed_for_tracking) - - for var_name in all_quantum_vars: - if var_name not in consumed_qubits: - # Variable is used but not consumed - it's fully live - # Determine size from context or default - size = array_sizes.get(var_name, 2) # Default to 2 if unknown - live_qubits[var_name] = set(range(size)) - else: - # Check if only partially consumed - consumed_indices = consumed_qubits[var_name] - size = array_sizes.get(var_name, 2) # Default to 2 if unknown - - # Any indices not consumed OR explicitly reset are live - # Explicitly reset qubits are consumed by measurement but then recreated by Prep - explicitly_reset_indices = set() - if hasattr(self, "explicitly_reset_qubits") and var_name in self.explicitly_reset_qubits: - explicitly_reset_indices = self.explicitly_reset_qubits[var_name] - - live_indices = (set(range(size)) - consumed_indices) | explicitly_reset_indices - if live_indices: - live_qubits[var_name] = live_indices - - return consumed_qubits, live_qubits - - def _should_function_be_procedural( - self, - func_name: str, - block, - params, - has_live_qubits: bool, - ) -> bool: - """ - Smart detection to determine if a function should be procedural (return None) - vs functional (return tuple of quantum arrays). - - Functions should be procedural if they: - 1. Primarily do terminal operations (measurements without further quantum operations) - 2. Are not used in patterns where quantum returns are needed afterward - 3. Would cause PlaceNotUsedError issues with tuple returns - - Functions should be functional if they: - 1. Their quantum returns are needed for subsequent operations in the calling scope - 2. They are part of partial consumption patterns - """ - - # Pattern-based detection for known procedural functions - # BUT: only if they don't have live qubits - procedural_patterns = [ - "syndrome_extraction", # Terminal syndrome measurement blocks - "cleanup", # Cleanup operations - "discard", # Discard operations - ] - - # Check if this is an inner block that will be called by outer blocks - # Inner blocks should NOT be procedural to avoid consumption issues - if "inner" in func_name.lower(): - return False - - # Only apply pattern matching if there are no live qubits - # Functions with live qubits should return them, regardless of name - if not has_live_qubits: - for pattern in procedural_patterns: - if pattern in func_name.lower(): - # These are good candidates for procedural - return True - - # Functions with quantum parameters but no live qubits are good candidates for procedural - has_quantum_params = any("array[quantum.qubit," in param[1] for param in params if len(param) == 2) - - if has_quantum_params and not has_live_qubits: - # This is a terminal function - good candidate for procedural - return True - - # Check if this function would benefit from procedural approach based on operations - if hasattr(block, "ops"): - measurement_count = 0 - gate_count = 0 - - for op in block.ops: - if hasattr(op, "__class__"): - op_name = op.__class__.__name__ - if "Measure" in op_name: - measurement_count += 1 - elif hasattr(op, "name") or any(gate in str(op) for gate in ["H", "X", "Y", "Z", "CX", "CZ"]): - gate_count += 1 - - # If mostly measurements with no quantum gates, good candidate for procedural - # But be conservative - only if no gates at all or very few - # AND only if there are no live qubits to return (partial consumption must return live qubits) - if measurement_count > 0 and gate_count == 0 and not has_live_qubits: - return True - - # CONSERVATIVE: Default to functional approach unless clearly terminal - # This avoids breaking partial consumption patterns - return False - - def _should_use_fresh_variables(self, func_name: str, quantum_args: list) -> bool: - """ - Determine if fresh variables should be used for tuple assignment. - - Fresh variables help avoid PlaceNotUsedError when: - 1. Function has complex ownership patterns (@owned mixed with borrowed) - 2. Function might cause circular assignment issues - 3. Function is known to cause tuple assignment problems - """ - - # Known problematic patterns that benefit from fresh variables - fresh_variable_patterns = [ - "measure_ancillas", # Mixed ownership - some params consumed, some borrowed - "partial_consumption", # Partial consumption patterns - "process_qubits", # Functions that process and return quantum arrays - ] - - for pattern in fresh_variable_patterns: - if pattern in func_name.lower(): - return True - - # Check if we're inside a function that will return these values - # If the function will return these arrays, don't use fresh variables - # to avoid PlaceNotUsedError for unused fresh variables - special_funcs = ["prep_zero_verify", "prep_encoding_non_ft_zero"] - in_function = hasattr(self, "current_function_name") and self.current_function_name - if in_function and func_name in special_funcs: - # Check if this is the last statement in the function that will be returned - # For now, assume functions that manipulate and return the same arrays - # should NOT use fresh variables to avoid unused variable errors - # These functions return arrays that should be used directly - return False - - # If function has multiple quantum arguments, it might have mixed ownership - # Use fresh variables to be safe - if len(quantum_args) > 1 and hasattr(self, "current_block") and hasattr(self.current_block, "statements"): - # But check if we're at the end of a function where the result will be returned - # In that case, don't use fresh variables - # This is a heuristic - if there are not many statements after this, - # it's likely the return statement - return False # Don't use fresh variables for now - - # Default: use standard tuple assignment - return False - - def _fix_post_consuming_linearity_issues(self, body: Block) -> None: - """ - Fix linearity issues by adding fresh qubit allocations after consuming operations. - - When a qubit is consumed (e.g., by quantum.reset), and then used again later, - we need to allocate a fresh qubit to satisfy guppylang's linearity constraints. - """ - - # Track variables that have been consumed - new_statements = [] - - for stmt in body.statements: - # Add the current statement - new_statements.append(stmt) - - # Check if this statement consumes any variables - # Note: quantum.reset is now handled with assignment (qubit = quantum.reset(qubit)) - # so we no longer need to add automatic fresh qubit allocations - if hasattr(stmt, "expr") and hasattr(stmt.expr, "func_name"): - # Handle function calls that consume qubits - func_call = stmt.expr - if hasattr(func_call, "func_name") and func_call.func_name == "quantum.reset": - # quantum.reset now uses assignment, so no need for fresh allocation - # The reset operation returns the reset qubit - pass - - # Replace the statements - body.statements = new_statements - - def _fix_unused_fresh_variables(self, body: Block) -> None: - """ - Fix PlaceNotUsedError for fresh variables that may not be used in all execution paths. - - This handles the general pattern where: - 1. Fresh variables are created from function calls - 2. These variables are only used conditionally in loops - 3. Some fresh variables remain unconsumed, causing PlaceNotUsedError - """ - from pecos.slr.gen_codes.guppy.ir import Comment, FunctionCall, VariableRef - - # Define ExpressionStatement class for standalone function calls - class ExpressionStatement: - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - # General approach: find fresh variables that might be unused in conditional paths - fresh_variables_created = set() - fresh_variables_used_conditionally = set() - has_conditional_usage = False - - def collect_fresh_variables(statements): - """Recursively collect all fresh variables created and used.""" - for stmt in statements: - # Check if this is a Block and recurse into it - if hasattr(stmt, "statements"): - collect_fresh_variables(stmt.statements) - - # Find tuple assignments that create fresh variables - if hasattr(stmt, "targets") and len(stmt.targets) > 0: - for target in stmt.targets: - if isinstance(target, str) and "_fresh" in target: - fresh_variables_created.add(target) - - # Check for conditional statements (if/for) containing fresh variable usage - is_conditional = hasattr(stmt, "condition") or hasattr(stmt, "iterable") - has_body = hasattr(stmt, "body") and hasattr(stmt.body, "statements") - if is_conditional and has_body: # IfStatement or ForStatement - nonlocal has_conditional_usage - has_conditional_usage = True - # Look for fresh variable usage in conditional blocks - self._find_fresh_usage_in_statements( - stmt.body.statements, - fresh_variables_used_conditionally, - ) - - def find_procedural_functions_with_unused_fresh(): - """Find procedural functions (return None) that might have unused fresh variables.""" - if not (hasattr(self, "current_function_name") and self.current_function_name): - return False - - # Check if this is a procedural function that might have the pattern - # Method 1: Check if already recorded in function_return_types - if ( - hasattr(self, "function_return_types") - and self.function_return_types.get(self.current_function_name) == "None" - ): - return True - - # Method 2: Check if the function body has no return statements (procedural) - # This is a heuristic for functions that don't explicitly return values - has_return_stmt = any( - hasattr(stmt, "value") and hasattr(stmt, "__class__") and "return" in str(type(stmt)).lower() - for stmt in body.statements - ) - - # Method 3: Use pattern matching - functions that end with calls to other functions - # but don't return their results are likely procedural - if not has_return_stmt and len(body.statements) > 0: - last_stmt = body.statements[-1] - if hasattr(last_stmt, "expr") and hasattr(last_stmt.expr, "func_name"): - return True # Likely procedural if ends with a function call - - return False - - collect_fresh_variables(body.statements) - - is_procedural = find_procedural_functions_with_unused_fresh() - - # If we have fresh variables created and conditional usage patterns, - # and this is a procedural function, add discard statements for unused fresh variables - if fresh_variables_created and has_conditional_usage and is_procedural: - - # Find fresh variables that are likely unused in some execution paths - potentially_unused = fresh_variables_created - fresh_variables_used_conditionally - - # Also check which fresh variables are used after conditionals (shouldn't be discarded) - fresh_variables_used_after_conditionals = set() - self._find_fresh_usage_in_statements( - body.statements, - fresh_variables_used_after_conditionals, - ) - - # Only discard variables that are not used after conditionals - safe_to_discard = potentially_unused - fresh_variables_used_after_conditionals - - # Add discard statements before the last statement for potentially unused variables - last_stmt_idx = len(body.statements) - 1 - insert_offset = 0 - - for fresh_var in sorted(safe_to_discard): # Sort for consistent ordering - comment = Comment( - f"# Discard {fresh_var} to avoid PlaceNotUsedError in conditional paths", - ) - discard_call = FunctionCall( - func_name="quantum.discard_array", - args=[VariableRef(fresh_var)], - ) - discard_stmt = ExpressionStatement(discard_call) - - # Insert before the last statement - body.statements.insert(last_stmt_idx + insert_offset, comment) - body.statements.insert(last_stmt_idx + insert_offset + 1, discard_stmt) - insert_offset += 2 - - def _find_fresh_usage_in_statements(self, statements, used_set): - """Helper to find fresh variable usage in a list of statements.""" - for stmt in statements: - if hasattr(stmt, "statements"): - self._find_fresh_usage_in_statements(stmt.statements, used_set) - - # Look for function calls that use fresh variables as arguments - if hasattr(stmt, "expr") and hasattr(stmt.expr, "args"): - for arg in stmt.expr.args: - if hasattr(arg, "name") and "_fresh" in arg.name: - used_set.add(arg.name) - - # Look for assignments that use fresh variables - if hasattr(stmt, "value") and hasattr(stmt.value, "args"): - for arg in stmt.value.args: - if hasattr(arg, "name") and "_fresh" in arg.name: - used_set.add(arg.name) - - def _update_context_for_returned_variable( - self, - original_name: str, - fresh_name: str, - ) -> None: - """Update context to redirect variable lookups from original to fresh name.""" - original_var = self.context.lookup_variable(original_name) - if original_var: - from pecos.slr.gen_codes.guppy.ir import ResourceState, VariableInfo - - # Create new variable info for the fresh returned variable - new_var_info = VariableInfo( - name=fresh_name, - original_name=fresh_name, - var_type=original_var.var_type, - size=original_var.size, - is_array=original_var.is_array, - state=ResourceState.AVAILABLE, - is_unpacked=original_var.is_unpacked, - unpacked_names=(original_var.unpacked_names.copy() if original_var.unpacked_names else []), - ) - - # Add the fresh variable to context - self.context.add_variable(new_var_info) - - # Add to refreshed arrays mapping for variable reference resolution - self.context.refreshed_arrays[original_name] = fresh_name - - # Mark the original variable as consumed since it was moved to the returned variable - self.context.consumed_resources.add(original_name) - - def _analyze_block_dependencies(self, block) -> dict[str, Any]: - """Analyze what variables a block depends on.""" - dependencies = { - "reads": set(), # Variables read - "writes": set(), # Variables written - "quantum": set(), # Quantum variables used - "classical": set(), # Classical variables used - } - - # Analyze operations in the block - if hasattr(block, "ops"): - for op in block.ops: - self._analyze_op_dependencies(op, dependencies, depth=0) - - return dependencies - - def _analyze_op_dependencies( - self, - op, - deps: dict[str, set], - depth: int = 0, - ) -> None: - """Analyze dependencies of a single operation.""" - op_type = type(op).__name__ - - # Handle quantum gates - if hasattr(op, "qargs"): - for qarg in op.qargs: - # Handle tuple arguments (e.g., CX gates with (control, target) pairs) - if isinstance(qarg, tuple): - for sub_qarg in qarg: - if hasattr(sub_qarg, "reg") and hasattr(sub_qarg.reg, "sym"): - var_name = sub_qarg.reg.sym - deps["reads"].add(var_name) - deps["quantum"].add(var_name) - elif hasattr(sub_qarg, "sym"): - var_name = sub_qarg.sym - deps["reads"].add(var_name) - deps["quantum"].add(var_name) - elif hasattr(qarg, "reg") and hasattr(qarg.reg, "sym"): - var_name = qarg.reg.sym - deps["reads"].add(var_name) - deps["quantum"].add(var_name) - elif hasattr(qarg, "sym"): - # Direct QReg reference - var_name = qarg.sym - deps["reads"].add(var_name) - deps["quantum"].add(var_name) - - # Handle measurements - if op_type == "Measure": - if hasattr(op, "qargs"): - for qarg in op.qargs: - if hasattr(qarg, "reg") and hasattr(qarg.reg, "sym"): - var_name = qarg.reg.sym - deps["reads"].add(var_name) - deps["quantum"].add(var_name) - elif hasattr(qarg, "sym"): - # Direct QReg reference - var_name = qarg.sym - deps["reads"].add(var_name) - deps["quantum"].add(var_name) - if hasattr(op, "cout") and op.cout: - for cout in op.cout: - if hasattr(cout, "reg") and hasattr(cout.reg, "sym"): - var_name = cout.reg.sym - deps["writes"].add(var_name) - deps["classical"].add(var_name) - elif hasattr(cout, "sym"): - # Direct CReg reference - var_name = cout.sym - deps["writes"].add(var_name) - deps["classical"].add(var_name) - - # Handle SET operations - if op_type == "SET": - if hasattr(op, "left") and hasattr(op.left, "reg"): - var_name = op.left.reg.sym - deps["writes"].add(var_name) - deps["classical"].add(var_name) - if hasattr(op, "right"): - self._analyze_expression_deps(op.right, deps) - - # Handle control flow - if op_type in ["If", "While", "For", "Repeat"]: - # Analyze condition - if hasattr(op, "condition"): - self._analyze_expression_deps(op.condition, deps) - # Analyze body operations - if hasattr(op, "ops"): - for sub_op in op.ops: - self._analyze_op_dependencies(sub_op, deps, depth + 1) - - # Handle nested blocks (but not too deep to avoid infinite recursion) - elif hasattr(op, "ops") and hasattr(op, "vars") and depth < 2: - # This is a block call - analyze it recursively but not too deep - for sub_op in op.ops: - self._analyze_op_dependencies(sub_op, deps, depth + 1) - - def _analyze_expression_deps(self, expr, deps: dict[str, set]) -> None: - """Analyze dependencies in an expression.""" - expr_type = type(expr).__name__ - - if expr_type == "Bit": - if hasattr(expr, "reg") and hasattr(expr.reg, "sym"): - var_name = expr.reg.sym - deps["reads"].add(var_name) - deps["classical"].add(var_name) - elif expr_type == "Qubit": - if hasattr(expr, "reg") and hasattr(expr.reg, "sym"): - var_name = expr.reg.sym - deps["reads"].add(var_name) - deps["quantum"].add(var_name) - elif hasattr(expr, "left") and hasattr(expr, "right"): - self._analyze_expression_deps(expr.left, deps) - self._analyze_expression_deps(expr.right, deps) - elif hasattr(expr, "value"): - self._analyze_expression_deps(expr.value, deps) - - def _add_final_handling(self, block) -> None: - """Handle struct decomposition, results, and cleanup in the correct order.""" - # First, decompose any structs that need cleanup - struct_decompositions = {} # prefix -> list of decomposed variable names - - for prefix, info in self.struct_info.items(): - # Check if this struct has unconsumed quantum fields - has_unconsumed_quantum = False - for suffix, var_type, size in info["fields"]: - if var_type == "qubit": - var_name = info["var_names"][suffix] - if var_name not in self.consumed_arrays: - has_unconsumed_quantum = True - break - - if has_unconsumed_quantum: - # Decompose the struct - qec_code_name = info.get("qec_code_name", prefix) - func_name = f"{qec_code_name}_decompose" if qec_code_name else f"{prefix}_decompose" - - # Generate variable names for decomposed fields - decomposed_vars = [] - for suffix, _, _ in sorted(info["fields"]): - decomposed_vars.append(f"{prefix}_{suffix}_final") - - # Create the decomposition call - targets = decomposed_vars - call = FunctionCall( - func_name=func_name, - args=[VariableRef(prefix)], - ) - - # Create assignment - target_tuple = TupleExpression( - elements=[VariableRef(name) for name in targets], - ) - stmt = Assignment(target=target_tuple, value=call) - - self.current_block.statements.append( - Comment(f"Decompose struct {prefix} for cleanup"), - ) - self.current_block.statements.append(stmt) - - # Store decomposition info - struct_decompositions[prefix] = list( - zip( - [f[0] for f in sorted(info["fields"])], # suffixes - decomposed_vars, - [f[1] for f in sorted(info["fields"])], # types - [f[2] for f in sorted(info["fields"])], # sizes - ), - ) - - # Now add results, using decomposed variables where necessary - self._add_results_with_decomposition(block, struct_decompositions) - - # Track what arrays have been cleaned up to avoid double-discard - cleaned_up_arrays = set() - - # Finally, clean up quantum arrays - self._add_cleanup_with_decomposition( - block, - struct_decompositions, - cleaned_up_arrays, - ) - - # Also run the regular cleanup for non-struct arrays - self._add_cleanup(block, cleaned_up_arrays) - - def _add_results_with_decomposition(self, block, struct_decompositions) -> None: - """Add result calls, using decomposed variables where necessary.""" - if hasattr(block, "vars"): - for var in block.vars: - if type(var).__name__ == "CReg": - var_name = var.sym - - # Check for renaming - actual_name = var_name - if var_name in self.plan.renamed_variables: - actual_name = self.plan.renamed_variables[var_name] - - # Check if this variable is part of a decomposed struct - value_ref = None - for prefix, info in self.struct_info.items(): - if var_name in info["var_names"].values(): - # Find the field name for this variable - for suffix, mapped_var in info["var_names"].items(): - if mapped_var == var_name: - # Check if struct was decomposed - if prefix in struct_decompositions: - # Find the decomposed variable - for ( - field_suffix, - decomposed_var, - _, - _, - ) in struct_decompositions[prefix]: - if field_suffix == suffix: - value_ref = VariableRef(decomposed_var) - break - else: - # Struct not decomposed, use field access - value_ref = FieldAccess( - obj=VariableRef(prefix), - field=suffix, - ) - break - break - - if value_ref is None: - # Check if this array was unpacked - # Check both var_name (original) and actual_name (renamed) - is_unpacked = var_name in self.plan.arrays_to_unpack or ( - hasattr(self, "unpacked_vars") - and (var_name in self.unpacked_vars or actual_name in self.unpacked_vars) - ) - - if is_unpacked: - # Array was unpacked - must reconstruct from elements for linearity - element_names = None - if hasattr(self, "unpacked_vars"): - # Try original name first, then renamed name - if var_name in self.unpacked_vars: - element_names = self.unpacked_vars[var_name] - elif actual_name in self.unpacked_vars: - element_names = self.unpacked_vars[actual_name] - - if element_names: - # Reconstruct the array and assign it back to the original variable - reconstruction_expr = self._create_array_reconstruction( - element_names, - ) - reconstruction_stmt = Assignment( - target=VariableRef(actual_name), - value=reconstruction_expr, - ) - self.current_block.statements.append( - reconstruction_stmt, - ) - value_ref = VariableRef(actual_name) - else: - # Fallback: use original array if unpacked_vars not available - value_ref = VariableRef(actual_name) - else: - # Not unpacked, use direct variable reference - value_ref = VariableRef(actual_name) - - # Add result call - call = FunctionCall( - func_name="result", - args=[ - Literal(var.sym), # Original name as label - value_ref, # Actual variable or decomposed field - ], - ) - - # Create a wrapper that renders just the function call - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - self.current_block.statements.append(ExpressionStatement(call)) - - def _add_cleanup_with_decomposition( - self, - block, - struct_decompositions, - cleaned_up_arrays, - ) -> None: - _ = block # Currently not used - """Add cleanup for quantum arrays, using decomposed variables.""" - # First handle decomposed struct fields - for prefix, fields in struct_decompositions.items(): - self.current_block.statements.append( - Comment(f"Discard quantum fields from {prefix}"), - ) - for suffix, decomposed_var, var_type, size in fields: - if var_type == "qubit" and decomposed_var not in cleaned_up_arrays: - stmt = FunctionCall( - func_name="quantum.discard_array", - args=[VariableRef(decomposed_var)], - ) - cleaned_up_arrays.add(decomposed_var) - # Also track the original variable name to prevent double cleanup - if prefix in self.struct_info: - info = self.struct_info[prefix] - if suffix in info["var_names"]: - original_var = info["var_names"][suffix] - cleaned_up_arrays.add(original_var) - - # Create expression statement wrapper - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - self.current_block.statements.append(ExpressionStatement(stmt)) - - # Note: Non-struct arrays are handled in _add_cleanup, not here - - def _add_cleanup(self, block, cleaned_up_arrays=None) -> None: - """Add cleanup for unconsumed qubits.""" - if cleaned_up_arrays is None: - cleaned_up_arrays = set() - # Track consumed qubits during operation conversion - consumed = {} # qreg_name -> set of indices - - # Analyze operations to find consumed qubits - if hasattr(block, "ops"): - for op in block.ops: - self._track_consumed_qubits(op, consumed) - - # First, check if we have structs that need cleanup - struct_cleanup_done = set() - for prefix, info in self.struct_info.items(): - # Check if any quantum arrays in this struct need cleanup - needs_cleanup = False - for suffix, var_type, size in info["fields"]: - if var_type == "qubit": - var_name = info["var_names"][suffix] - if var_name not in self.consumed_arrays: - needs_cleanup = True - break - - if needs_cleanup and prefix not in struct_cleanup_done: - # We're at the end of main, after results. - # We can't access struct fields directly after consuming the struct, - # so we'll just leave quantum arrays in structs for now. - # The HUGR compiler will need to handle this pattern. - - # Add a comment noting this limitation - self.current_block.statements.append( - Comment( - f"Note: struct {prefix} contains unconsumed quantum arrays", - ), - ) - - struct_cleanup_done.add(prefix) - # Mark arrays as handled - for suffix, var_type, size in info["fields"]: - if var_type == "qubit": - var_name = info["var_names"][suffix] - self.consumed_arrays.add(var_name) - - # First handle fresh variables from function returns - if hasattr(self, "fresh_variables_to_track"): - for fresh_name, info in self.fresh_variables_to_track.items(): - if info["type"] == "quantum_array" and not info.get("used", False): - # This fresh variable was not used, add cleanup - # Check if it was already cleaned up (e.g., by being measured) - original_name = info["original"] - was_consumed = (hasattr(self, "consumed_arrays") and original_name in self.consumed_arrays) or ( - hasattr(self, "consumed_resources") and original_name in self.consumed_resources - ) - - if not was_consumed and fresh_name not in cleaned_up_arrays: - self.current_block.statements.append( - Comment(f"Discard unused fresh variable {fresh_name}"), - ) - # Need to check if this is an array or needs special handling - # For now, assume it's a quantum array that needs discard_array - stmt = FunctionCall( - func_name="quantum.discard_array", - args=[VariableRef(fresh_name)], - ) - - # Create expression statement wrapper - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - self.current_block.statements.append(ExpressionStatement(stmt)) - cleaned_up_arrays.add(fresh_name) - - # Check each quantum register not in structs - if hasattr(block, "vars"): - - for var in block.vars: - if type(var).__name__ == "QReg": - var_name = var.sym - - # Skip if this array is part of a struct - in_struct = False - for prefix, info in self.struct_info.items(): - if var_name in info["var_names"].values(): - in_struct = True - break - - if in_struct: - continue - # Check for renaming - if var_name in self.plan.renamed_variables: - var_name = self.plan.renamed_variables[var_name] - - consumed_indices = consumed.get(var.sym, set()) - - # Check if this array was consumed by an @owned function or measurement - was_consumed_by_function = hasattr(self, "consumed_arrays") and var.sym in self.consumed_arrays - - was_consumed_by_measurement = ( - hasattr(self, "consumed_resources") and var.sym in self.consumed_resources - ) - was_dynamically_allocated = ( - hasattr(self, "dynamic_allocations") and var.sym in self.dynamic_allocations - ) - - # Handle partially consumed arrays - # BUT: Skip if the whole array was consumed by an @owned function - if len(consumed_indices) > 0 and len(consumed_indices) < var.size and not was_consumed_by_function: - # Array was partially consumed - need to discard entire array - if var_name not in cleaned_up_arrays: - self.current_block.statements.append( - Comment(f"Discard {var.sym}"), - ) - stmt = FunctionCall( - func_name="quantum.discard_array", - args=[VariableRef(var_name)], - ) - - # Create expression statement wrapper - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - self.current_block.statements.append( - ExpressionStatement(stmt), - ) - cleaned_up_arrays.add(var_name) - # Only discard arrays that weren't consumed by @owned functions or measurements - # UNLESS they have explicitly reset qubits (which need cleanup) - elif True: - # Check if this array has explicitly reset qubits (from Prep operations) - # Even if consumed by measurement, explicitly reset qubits need cleanup - has_explicitly_reset = ( - hasattr(self, "explicitly_reset_qubits") - and var.sym in self.explicitly_reset_qubits - and len(self.explicitly_reset_qubits[var.sym]) > 0 - ) - - if not was_consumed_by_function and (not was_consumed_by_measurement or has_explicitly_reset): - if was_dynamically_allocated: - # For dynamically allocated arrays, discard individual - # qubits that weren't measured - self.current_block.statements.append( - Comment(f"Discard dynamically allocated {var.sym}"), - ) - - # Check which individual qubits were allocated and not consumed - if hasattr(self, "allocated_ancillas"): - # Track which variables we've already discarded to avoid duplicates - discarded_vars = set() - - # Discard each allocated ancilla that belongs to this qreg - # We need to check all allocated ancillas that start with the qreg name - for ancilla_var in list(self.allocated_ancillas): - # Check if this ancilla belongs to the current qreg - # It should start with the qreg name followed by underscore - if ancilla_var.startswith( - (f"{var.sym}_", f"_{var.sym}_"), - ): - # Apply variable remapping if exists (for Prep operations) - var_to_discard = self.variable_remapping.get( - ancilla_var, - ancilla_var, - ) - - # Skip if we've already discarded this variable - if var_to_discard in discarded_vars: - continue - discarded_vars.add(var_to_discard) - - discard_stmt = FunctionCall( - func_name="quantum.discard", - args=[VariableRef(var_to_discard)], - ) - - # Create expression statement wrapper - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - self.current_block.statements.append( - ExpressionStatement(discard_stmt), - ) - else: - # Regular pre-allocated array - - # Skip if already consumed by a function call - # Also check if the remapped name was consumed - remapped_consumed = False - if hasattr(self, "array_remapping") and var_name in self.array_remapping: - remapped_name = self.array_remapping[var_name] - if hasattr(self, "consumed_arrays") and remapped_name in self.consumed_arrays: - remapped_consumed = True - - # Check if array has explicitly reset qubits (from Prep operations) - # These need cleanup even if consumed by measurement - has_explicitly_reset = ( - hasattr(self, "explicitly_reset_qubits") - and var.sym in self.explicitly_reset_qubits - and len(self.explicitly_reset_qubits[var.sym]) > 0 - ) - - # Check if array was consumed by an @owned function call or by measurements - array_consumed = ( - hasattr(self, "consumed_arrays") - and (var.sym in self.consumed_arrays or var_name in self.consumed_arrays) - ) or ( - hasattr(self, "consumed_resources") - and (var.sym in self.consumed_resources or var_name in self.consumed_resources) - ) - - # Also check if this is a reconstructed array that was passed to a function - is_reconstructed = ( - hasattr(self, "reconstructed_arrays") and var_name in self.reconstructed_arrays - ) - - # Allow cleanup if: - # 1. Array not already cleaned up - # 2. Either not consumed OR has explicitly reset qubits - # 3. Remapped version not consumed - # 4. Not a reconstructed array - if ( - var_name not in cleaned_up_arrays - and (not array_consumed or has_explicitly_reset) - and not remapped_consumed - and not is_reconstructed - ): - # Check if this array has been unpacked or remapped - # If so, we can't discard the original name - if hasattr(self, "unpacked_vars") and var_name in self.unpacked_vars: - # Array was unpacked - check if it has explicitly reset qubits - explicitly_reset_indices = set() - if ( - hasattr(self, "explicitly_reset_qubits") - and var_name in self.explicitly_reset_qubits - ): - explicitly_reset_indices = self.explicitly_reset_qubits[var_name] - - if explicitly_reset_indices: - # Check if we already did inline reconstruction - # If so, skip cleanup reconstruction to avoid AlreadyUsedError - skip_cleanup_reconstruction = ( - hasattr( - self, - "inline_reconstructed_arrays", - ) - and var_name in self.inline_reconstructed_arrays - ) - - if not skip_cleanup_reconstruction: - # Array has fresh qubits from Prep - reconstruct and discard - comment_text = ( - f"Reconstruct {var.sym} from unpacked " - f"elements (has fresh qubits)" - ) - self.current_block.statements.append( - Comment(comment_text), - ) - - # Get unpacked element names (it's a list, not a dict) - element_names = self.unpacked_vars[var_name] - - # Apply variable remapping to get the latest names - remapped_element_names = [ - self.variable_remapping.get( - elem, - elem, - ) - for elem in element_names - ] - - # Reconstruct array: var = array(elem1, elem2, ...) - array_elements = [VariableRef(elem) for elem in remapped_element_names] - array_constructor = FunctionCall( - func_name="array", - args=array_elements, - ) - reconstruct_stmt = Assignment( - target=VariableRef(var_name), - value=array_constructor, - ) - self.current_block.statements.append( - reconstruct_stmt, - ) - - # Now discard the reconstructed array - self.current_block.statements.append( - Comment( - f"Discard reconstructed {var.sym}", - ), - ) - array_ref = VariableRef(var_name) - stmt = FunctionCall( - func_name="quantum.discard_array", - args=[array_ref], - ) - - # Create expression statement wrapper - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - self.current_block.statements.append( - ExpressionStatement(stmt), - ) - cleaned_up_arrays.add(var_name) - # Skip the remapping/normal discard code below - continue - # Array was unpacked and fully consumed - skip discard - self.current_block.statements.append( - Comment( - f"Skip discard {var.sym} - already unpacked and consumed", - ), - ) - continue - if hasattr(self, "array_remapping") and var_name in self.array_remapping: - # Array was remapped - use the new name - remapped_name = self.array_remapping[var_name] - self.current_block.statements.append( - Comment( - f"Discard {var.sym} (remapped to {remapped_name})", - ), - ) - array_ref = VariableRef(remapped_name) - else: - # Normal case - use original name - self.current_block.statements.append( - Comment(f"Discard {var.sym}"), - ) - array_ref = VariableRef(var_name) - - stmt = FunctionCall( - func_name="quantum.discard_array", - args=[array_ref], - ) - - # Create expression statement wrapper - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - self.current_block.statements.append( - ExpressionStatement(stmt), - ) - cleaned_up_arrays.add(var_name) - - def _check_has_element_operations(self, block, var_name: str) -> bool: - """Check if a block has element-wise operations on a variable. - - This is used to determine if we should use @owned for array parameters. - Element-wise operations (like reset on individual elements) don't work - with @owned arrays in Guppy. - """ - if not hasattr(block, "ops"): - return False - - for op in block.ops: - op_type = type(op).__name__ - - # Check for Prep operations on the whole array - if op_type == "Prep" and hasattr(op, "qargs"): - for qarg in op.qargs: - if hasattr(qarg, "sym") and qarg.sym == var_name: - # Prep on the whole array - this needs element access - return True - - # Check for operations on individual elements - if hasattr(op, "qargs"): - for qarg in op.qargs: - if ( - hasattr(qarg, "reg") - and hasattr(qarg.reg, "sym") - and qarg.reg.sym == var_name - and hasattr(qarg, "index") - and op_type in ["Prep", "Measure"] - ): - return True - - # Recursively check nested blocks - if hasattr(op, "ops") and self._check_has_element_operations(op, var_name): - return True - - return False - - def _track_consumed_qubits(self, op, consumed: dict[str, set[int]]) -> None: - """Track which qubits are consumed by an operation or block. - - Also tracks explicit Prep (reset) operations to distinguish them from - automatic post-measurement replacements. - """ - op_type = type(op).__name__ - - # Handle Block types - recurse into their operations - if hasattr(op, "ops") and op_type not in ["Measure", "If", "Else", "While"]: - # This is a custom Block - analyze its operations - for nested_op in op.ops: - self._track_consumed_qubits(nested_op, consumed) - return - - # Track explicit Prep operations - these are semantic resets that should be returned - if op_type == "Prep" and hasattr(op, "qargs") and op.qargs: - for qarg in op.qargs: - # Handle full array reset - if hasattr(qarg, "sym") and hasattr(qarg, "size"): - qreg_name = qarg.sym - if qreg_name not in self.explicitly_reset_qubits: - self.explicitly_reset_qubits[qreg_name] = set() - # Track all indices as explicitly reset - for i in range(qarg.size): - self.explicitly_reset_qubits[qreg_name].add(i) - # Handle individual qubit reset - elif hasattr(qarg, "reg") and hasattr(qarg.reg, "sym"): - qreg_name = qarg.reg.sym - if qreg_name not in self.explicitly_reset_qubits: - self.explicitly_reset_qubits[qreg_name] = set() - - if hasattr(qarg, "index"): - self.explicitly_reset_qubits[qreg_name].add(qarg.index) - - if op_type == "Measure" and hasattr(op, "qargs") and op.qargs: - for qarg in op.qargs: - # Handle full array measurement - if hasattr(qarg, "sym") and hasattr(qarg, "size"): - qreg_name = qarg.sym - if qreg_name not in consumed: - consumed[qreg_name] = set() - # Mark all qubits as consumed - indices = set(range(qarg.size)) - for i in indices: - consumed[qreg_name].add(i) - # Track in scope manager - self.scope_manager.track_resource_usage( - qreg_name, - indices, - consumed=True, - ) - # Handle individual qubit measurement - elif hasattr(qarg, "reg") and hasattr(qarg.reg, "sym"): - qreg_name = qarg.reg.sym - if qreg_name not in consumed: - consumed[qreg_name] = set() - - if hasattr(qarg, "index"): - consumed[qreg_name].add(qarg.index) - # Track in scope manager - self.scope_manager.track_resource_usage( - qreg_name, - {qarg.index}, - consumed=True, - ) - - # Don't recurse into nested blocks that are separate function calls - # They handle their own consumption and return fresh qubits - # Only recurse into inline blocks (like If/Else) - if hasattr(op, "ops") and op_type in ["If", "Else", "While"]: - for nested_op in op.ops: - self._track_consumed_qubits(nested_op, consumed) - - # Check else blocks - if op_type == "If" and hasattr(op, "else_block") and op.else_block and hasattr(op.else_block, "ops"): - for nested_op in op.else_block.ops: - self._track_consumed_qubits(nested_op, consumed) - - def _array_needs_full_allocation(self, array_name: str, block) -> bool: - """Check if an array needs full allocation due to full array operations.""" - if not hasattr(block, "ops"): - return False - - for op in block.ops: - if self._operation_uses_full_array(op, array_name): - return True - - # Check nested operations - if hasattr(op, "ops"): - for nested_op in op.ops: - if self._operation_uses_full_array(nested_op, array_name): - return True - - # Check else blocks - if hasattr(op, "else_block") and op.else_block and hasattr(op.else_block, "ops"): - for nested_op in op.else_block.ops: - if self._operation_uses_full_array(nested_op, array_name): - return True - - return False - - def _operation_uses_full_array(self, op, array_name: str) -> bool: - """Check if an operation uses a full array (e.g., Measure(q) > c).""" - if hasattr(op, "qargs") and len(op.qargs) == 1: - qarg = op.qargs[0] - # Check for full array reference (has sym and size but no index) - if ( - hasattr(qarg, "sym") - and qarg.sym == array_name - and hasattr(qarg, "size") - and qarg.size > 1 - and not hasattr(qarg, "index") - ): - return True - return False - - def _add_results(self, block) -> None: - """Add result() calls for classical registers.""" - # Debug: Uncomment to see unpacked_vars state - if hasattr(block, "vars"): - for var in block.vars: - if type(var).__name__ == "CReg": - var_name = var.sym - - # Check for renaming - actual_name = var_name - if var_name in self.plan.renamed_variables: - actual_name = self.plan.renamed_variables[var_name] - - # Check if this variable is part of a struct - value_ref = None - for prefix, info in self.struct_info.items(): - if var_name in info["var_names"].values(): - # Find the field name for this variable - for suffix, mapped_var in info["var_names"].items(): - if mapped_var == var_name: - # Access through struct field - value_ref = FieldAccess( - obj=VariableRef(prefix), - field=suffix, - ) - break - break - - if value_ref is None: - # Check if this array was unpacked - # Check both var_name (original) and actual_name (renamed) - is_unpacked = var_name in self.plan.arrays_to_unpack or ( - hasattr(self, "unpacked_vars") - and (var_name in self.unpacked_vars or actual_name in self.unpacked_vars) - ) - - if is_unpacked: - # Array was unpacked - must reconstruct from elements for linearity - element_names = None - if hasattr(self, "unpacked_vars"): - # Try original name first, then renamed name - if var_name in self.unpacked_vars: - element_names = self.unpacked_vars[var_name] - elif actual_name in self.unpacked_vars: - element_names = self.unpacked_vars[actual_name] - - if element_names: - value_ref = self._create_array_reconstruction( - element_names, - ) - else: - # Fallback: use original array if unpacked_vars not available - value_ref = VariableRef(actual_name) - else: - # Not unpacked, use direct variable reference - value_ref = VariableRef(actual_name) - - # Add result call - call = FunctionCall( - func_name="result", - args=[ - Literal(var.sym), # Original name as label - value_ref, # Actual variable or struct field - ], - ) - - # Create a wrapper that renders just the function call - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - self.current_block.statements.append(ExpressionStatement(call)) - - def _detect_struct_patterns(self, block: SLRBlock) -> None: - """Detect variables that should be grouped into structs. - - Looking for patterns where multiple variables share a common prefix - (e.g., x_d, x_a, x_c all belong to quantum code 'x'). - """ - # First, try to determine the quantum code class from variable metadata - qec_code_name = None - qec_instance_mapping = {} # Maps instance name -> class name - - # Check if block.vars has source class information - if hasattr(block, "vars") and hasattr(block.vars, "var_source_classes"): - # Get the source class from the metadata - for var_name, source_class in block.vars.var_source_classes.items(): - # Extract the prefix from the variable name - if "_" in var_name: - prefix = var_name.split("_")[0] - if prefix not in qec_instance_mapping: - qec_instance_mapping[prefix] = source_class.lower() - if not qec_code_name: - qec_code_name = source_class.lower() - - # If no QEC class found in vars, fall back to searching operations - if not qec_code_name: - # Helper function to recursively search for QEC code - def find_qec_code_in_block(op, depth=0, max_depth=5): - if depth > max_depth: - return None - - results = [] - - # Check if this op has QEC module info - if hasattr(op, "__class__") and hasattr(op.__class__, "__module__"): - module = op.__class__.__module__ - # Extract QEC code name from module path like - # 'pecos.qeclib.steane.preps.pauli_states' - if "pecos.qeclib." in module: - parts = module.split(".") - if len(parts) > 2 and "qeclib" in parts: - qec_idx = parts.index("qeclib") - if qec_idx + 1 < len(parts): - candidate = parts[qec_idx + 1] - # Skip generic names like 'qubit' - if candidate not in ["qubit", "bit", "ops", "gates"]: - results.append(candidate) - - # Check nested operations - if hasattr(op, "ops"): - for nested_op in op.ops: - result = find_qec_code_in_block(nested_op, depth + 1, max_depth) - if result: - results.append(result) - - # Return the first non-generic result - for r in results: - if r not in ["qubit", "bit", "ops", "gates"]: - return r - - return results[0] if results else None - - # Try to find the QEC code class from the operations - if hasattr(block, "ops"): - for op in block.ops: - qec_code_name = find_qec_code_in_block(op) - if qec_code_name: - break - - # Collect all variables - all_vars = {} - if hasattr(block, "vars"): - for var in block.vars: - if hasattr(var, "sym"): - var_name = var.sym - all_vars[var_name] = var - - # Also check context variables - for var_name, var_info in self.context.variables.items(): - if var_name not in all_vars: - all_vars[var_name] = var_info - - # Group by prefix - prefix_groups = {} - for var_name, var in all_vars.items(): - if "_" in var_name: - prefix = var_name.split("_")[0] - suffix = "_".join(var_name.split("_")[1:]) - - if prefix not in prefix_groups: - prefix_groups[prefix] = [] - - # Determine type and size - size = var.size if hasattr(var, "size") else 1 - - # Determine if quantum or classical - is_quantum = True - if hasattr(var, "is_quantum"): - is_quantum = var.is_quantum - elif type(var).__name__ == "CReg": - is_quantum = False - elif hasattr(var, "resource_type"): - is_quantum = var.resource_type == ResourceState.QUANTUM - - var_type = "qubit" if is_quantum else "bool" - - # Check if this is an ancilla qubit that should be kept separate - is_ancilla = False - if var_type == "qubit" and hasattr(self, "qubit_usage_stats"): - stats = self.qubit_usage_stats.get(var_name) - if stats: - role = stats.classify_role() - if role == QubitRole.ANCILLA: - is_ancilla = True - # Store this for later use - if not hasattr(self, "ancilla_qubits"): - self.ancilla_qubits = set() - self.ancilla_qubits.add(var_name) - - if not is_ancilla: - prefix_groups[prefix].append((suffix, var_type, size, var_name)) - - # Create struct info for groups with multiple related variables - # BUT avoid structs with too many fields due to guppylang limitations - # Setting to 5 to be very conservative - complex QEC codes need individual array handling - max_struct_fields = 5 # Limit to avoid guppylang linearity issues - - for prefix, vars_list in prefix_groups.items(): - if len(vars_list) >= 2: - # Check if this looks like a quantum code pattern - has_quantum = any(var[1] == "qubit" for var in vars_list) - - # Skip struct creation if too many fields (causes guppylang issues) - if len(vars_list) > max_struct_fields: - msg = ( - f"# Skipping struct creation for '{prefix}' with " - f"{len(vars_list)} fields (exceeds limit of {max_struct_fields})" - ) - print(msg) - continue - - if has_quantum: - # Use QEC code name for struct if available, otherwise use prefix - struct_base_name = qec_code_name if qec_code_name else prefix - - self.struct_info[prefix] = { - "fields": [(v[0], v[1], v[2]) for v in vars_list], - "struct_name": f"{struct_base_name}_struct", - "var_names": {v[0]: v[3] for v in vars_list}, # suffix -> full var name - "qec_code_name": qec_code_name, # Store for function naming - "ancilla_vars": getattr( - self, - "ancilla_qubits", - set(), - ), # Track which vars were excluded - } - - def _generate_struct_definitions(self) -> list[str]: - """Generate Guppy struct definitions.""" - lines = [] - - for prefix, info in sorted(self.struct_info.items()): - struct_name = info["struct_name"] - - # Generate struct - lines.append("@guppy.struct") - lines.append("@no_type_check") - lines.append(f"class {struct_name}:") - - # Add fields sorted by suffix - for suffix, var_type, size in sorted(info["fields"]): - field_type = f"array[{var_type}, {size}]" if size > 1 else var_type - lines.append(f" {suffix}: {field_type}") - - lines.append("") # Empty line after struct - - return lines - - def _generate_struct_decompose_function( - self, - prefix: str, - info: dict, - ) -> Function | None: - """Generate a decompose function for a struct.""" - struct_name = info["struct_name"] - qec_code_name = info.get("qec_code_name", prefix) - func_name = f"{qec_code_name}_decompose" if qec_code_name else f"{prefix}_decompose" - - # Build return type - tuple of all fields - return_types = [] - field_names = [] - for suffix, var_type, size in sorted(info["fields"]): - field_names.append(suffix) - return_types.append( - f"array[{var_type}, {size}]" if size > 1 else var_type, - ) - - return_type = f"tuple[{', '.join(return_types)}]" - - # Create function body - body = Block() - - # The key to avoiding AlreadyUsedError: return all fields in a single expression - # This works because guppylang handles the struct consumption atomically - field_refs = [FieldAccess(obj=VariableRef(prefix), field=suffix) for suffix in field_names] - - # Return all fields directly in one statement - return_stmt = ReturnStatement(value=TupleExpression(elements=field_refs)) - body.statements.append(return_stmt) - - return Function( - name=func_name, - params=[(prefix, f"{struct_name} @owned")], - return_type=return_type, - body=body, - decorators=["guppy", "no_type_check"], - ) - - def _generate_struct_discard_function( - self, - prefix: str, - info: dict, - ) -> Function | None: - """Generate a discard function for a struct.""" - # Check if struct has quantum fields - has_quantum = any(field[1] == "qubit" for field in info["fields"]) - if not has_quantum: - return None - - struct_name = info["struct_name"] - qec_code_name = info.get("qec_code_name", prefix) - func_name = f"{qec_code_name}_discard" if qec_code_name else f"{prefix}_discard" - - # Create function body - body = Block() - - # We need to handle discard differently to avoid AlreadyUsedError - # First decompose the struct, then discard quantum fields - - # Build list of field names for decomposition - field_names = [suffix for suffix, _, _ in sorted(info["fields"])] - - # Call decompose to get all fields - decompose_func_name = f"{qec_code_name}_decompose" if qec_code_name else f"{prefix}_decompose" - decompose_call = FunctionCall( - func_name=decompose_func_name, - args=[VariableRef(prefix)], - ) - - # Create variables to hold decomposed fields - field_vars = [f"_{suffix}" if suffix == prefix else suffix for suffix in field_names] - - # Define TupleAssignment locally - class TupleAssignment(Statement): - def __init__(self, targets, value): - self.targets = targets - self.value = value - - def analyze(self, context): - self.value.analyze(context) - - def render(self, context): - targets_str = ", ".join(self.targets) - value_lines = self.value.render(context) - # FunctionCall render returns a list with one string - value_str = value_lines[0] if value_lines else "" - return [f"{targets_str} = {value_str}"] - - decompose_stmt = TupleAssignment( - targets=field_vars, - value=decompose_call, - ) - body.statements.append(decompose_stmt) - - # Now discard quantum fields - for i, (suffix, var_type, size) in enumerate(sorted(info["fields"])): - if var_type == "qubit": - field_var = field_vars[i] - stmt = FunctionCall( - func_name="quantum.discard_array", - args=[VariableRef(field_var)], - ) - - # Create expression statement wrapper - class ExpressionStatement(Statement): - def __init__(self, expr): - self.expr = expr - - def analyze(self, context): - self.expr.analyze(context) - - def render(self, context): - return self.expr.render(context) - - body.statements.append(ExpressionStatement(stmt)) - - return Function( - name=func_name, - params=[(prefix, f"{struct_name} @owned")], - return_type="None", - body=body, - decorators=["guppy", "no_type_check"], - ) - - def _add_struct_initialization( - self, - prefix: str, - info: dict, - block: SLRBlock, - ) -> None: - """Add struct initialization to current block.""" - struct_name = info["struct_name"] - - # Create the struct instance - # For now, initialize fields individually then create struct - # TODO: Could be optimized to initialize struct directly - - # First, declare the individual arrays - for suffix, var_type, size in info["fields"]: - var_name = info["var_names"][suffix] - # Find the original variable - for var in block.vars: - if hasattr(var, "sym") and var.sym == var_name: - self._add_variable_declaration(var) - break - - # Then create struct instance - field_refs = [] - for suffix, _, _ in sorted(info["fields"]): - var_name = info["var_names"][suffix] - field_refs.append(VariableRef(var_name)) - - # Create struct construction expression - struct_expr = self._create_struct_construction( - struct_name, - [f[0] for f in sorted(info["fields"])], - field_refs, - ) - - # Add assignment: prefix = struct_name(field1=var1, field2=var2, ...) - stmt = Assignment( - target=VariableRef(prefix), - value=struct_expr, - ) - self.current_block.statements.append(stmt) - - # Update context to track struct variable - self.context.add_variable( - VariableInfo( - name=prefix, - original_name=prefix, - var_type=struct_name, - is_struct=True, - struct_info=info, - ), - ) - - # Mark the individual arrays as part of the struct so operations use struct fields - for suffix, var_type, size in info["fields"]: - var_name = info["var_names"][suffix] - var_info = self.context.lookup_variable(var_name) - if var_info: - var_info.is_struct_field = True - var_info.struct_name = prefix - var_info.field_name = suffix - - def _restore_array_sizes_for_block_call(self, block) -> None: - """Restore array sizes before a function call in a loop. - - When a function returns a smaller array than it receives (e.g., consuming qubits), - and that result is used in a loop to call the same function again, we need to - restore the array size by allocating fresh qubits before the next call. - - This implements the user's guidance: "We could prepare them right before we need them" - """ - - # Check if this is a block that will become a function call - if not hasattr(block, "ops") or not hasattr(block, "vars"): - return - - # Analyze the block to get array size information - from pecos.slr.gen_codes.guppy.ir_analyzer import IRAnalyzer - - analyzer = IRAnalyzer() - analyzer.analyze_block(block, self.context.variables) - - # Analyze what this block needs - deps = self._analyze_block_dependencies(block) - - # Determine what function this block will call - func_name = self._get_function_name_for_block(block) - - # Check quantum arrays that this block uses - for var in deps["quantum"] & deps["reads"]: - # Skip struct variables - if any(var in info["var_names"].values() for info in self.struct_info.values()): - continue - - # Check if we have a refreshed version from a previous function call - actual_var = var - if hasattr(self, "refreshed_arrays") and var in self.refreshed_arrays: - actual_var = self.refreshed_arrays[var] - - # Get the expected size from the original variable context - expected_size = None - if var in self.context.variables: - var_info = self.context.variables[var] - if hasattr(var_info, "size"): - expected_size = var_info.size - - if expected_size is None: - continue # Couldn't determine expected size - - # Check the actual current size if the array is unpacked - actual_size = None - if hasattr(self, "unpacked_vars") and actual_var in self.unpacked_vars: - actual_size = len(self.unpacked_vars[actual_var]) - if actual_size is None and actual_var != var: - # This is a refreshed array from a function return - # Try to determine its size from the upcoming function call's return type - actual_size = self._infer_current_array_size_from_fresh_var( - var, - actual_var, - func_name, - expected_size, - ) - - # If we have a size mismatch, restore the array size - if actual_size is not None and actual_size < expected_size: - self._insert_array_size_restoration( - var, - actual_var, - actual_size, - expected_size, - ) - - def _get_function_name_for_block(self, block) -> str | None: - """Determine what function name a block will call when converted.""" - # The block has a name attribute that corresponds to the function - if hasattr(block, "name"): - return block.name - # If block has a __class__ attribute with the name - if hasattr(block, "__class__"): - return block.__class__.__name__.lower() - return None - - def _infer_current_array_size_from_fresh_var( - self, - var: str, - _actual_var: str, - _func_name: str | None, - expected_size: int, - ) -> int | None: - """Infer the current size of a refreshed array by checking what function produced it. - - This looks at refreshed_by_function to find what function was called to produce actual_var, - then looks up that function's return type to determine the actual size. - """ - import re - - # Check if we've tracked which function call produced this refreshed variable - if not hasattr(self, "refreshed_by_function") or var not in self.refreshed_by_function: - # No information about which function produced this variable - # This happens on the first iteration of a loop before any calls - return expected_size - - func_info = self.refreshed_by_function[var] - # Extract function name and position - if isinstance(func_info, dict): - called_func_name = func_info["function"] - return_position = func_info.get("position", 0) - else: - called_func_name = func_info # Legacy string format - return_position = 0 - - # Get the return type for this function - # Try multiple sources: function_return_types, function_info - return_type = None - - if hasattr(self, "function_return_types") and called_func_name in self.function_return_types: - return_type = self.function_return_types[called_func_name] - elif hasattr(self, "function_info") and called_func_name in self.function_info: - func_info_entry = self.function_info[called_func_name] - if "return_type" in func_info_entry: - return_type = func_info_entry["return_type"] - - if return_type is None and hasattr(self, "pending_functions"): - # Check pending functions - they haven't been built yet but we can analyze their blocks - for pending_block, pending_name, _pending_sig in self.pending_functions: - if pending_name == called_func_name: - # Analyze the pending block to determine its return type - return_type = self._infer_return_type_from_block(pending_block) - break - - if return_type is None: - return expected_size - - # Parse the return type to extract array sizes - # Return type could be: - # - "array[quantum.qubit, N]" for single return - # - "tuple[array[quantum.qubit, N1], array[quantum.qubit, N2], ...]" for multiple returns - - # Check if it's a tuple return - if return_type.startswith("tuple["): - # Extract all array sizes from the tuple - # Pattern: array[quantum.qubit, SIZE] - array_pattern = r"array\[quantum\.qubit,\s*(\d+)\]" - matches = re.findall(array_pattern, return_type) - - if return_position < len(matches): - return int(matches[return_position]) - else: - # Single return value - match = re.search(r"array\[quantum\.qubit,\s*(\d+)\]", return_type) - if match: - return int(match.group(1)) - - # If we can't determine the size, assume it's the same as expected (no restoration needed) - return expected_size - - def _infer_return_type_from_block(self, block) -> str | None: - """Analyze a block to infer its return type. - - Priority order: - 1. If both block_returns annotation AND Return() statement exist, use them together - for precise variable-to-type mapping - 2. If only block_returns annotation exists, use positional sizes - 3. Fall back to analyzing block.vars and context (old behavior) - - Returns: - A Guppy type string like "array[quantum.qubit, 2]" or - "tuple[array[quantum.qubit, 2], array[quantum.qubit, 7]]" - """ - # BEST CASE: Both annotation and Return() statement exist - if hasattr(block, "__slr_return_type__") and hasattr(block, "get_return_vars"): - return_vars = block.get_return_vars() - if return_vars: - # We have explicit Return(var1, var2, ...) statement - # Combine with annotation for robust type checking - sizes = block.__slr_return_type__ - if len(return_vars) == len(sizes): - # Perfect match - we know which variable has which size - return_types = [f"array[quantum.qubit, {size}]" for size in sizes] - if len(return_types) == 1: - return return_types[0] - return f"tuple[{', '.join(return_types)}]" - # Mismatch - validation should have caught this, but proceed with annotation - - # SECOND BEST: Just the annotation (positional sizes) - if hasattr(block, "__slr_return_type__"): - sizes = block.__slr_return_type__ - return_types = [f"array[quantum.qubit, {size}]" for size in sizes] - if len(return_types) == 1: - return return_types[0] - return f"tuple[{', '.join(return_types)}]" - - # FALLBACK: Try to infer from Return() statement variables - if hasattr(block, "get_return_vars"): - return_vars = block.get_return_vars() - if return_vars: - return self._infer_types_from_return_vars(return_vars) - - # OLD FALLBACK: Try to infer from vars and context - if not hasattr(block, "vars") or not block.vars: - return None - - # Get the return variables from block.vars - return_vars = block.vars if isinstance(block.vars, list | tuple) else [block.vars] - return self._infer_types_from_return_vars(return_vars) - - def _infer_types_from_return_vars(self, return_vars) -> str | None: - """Infer Guppy types from a list of return variables by looking them up in context. - - Args: - return_vars: List of variables to infer types for - - Returns: - A Guppy type string or None if types couldn't be inferred - """ - # For each return variable, determine its type and size - return_types = [] - for var in return_vars: - var_name = var.sym if hasattr(var, "sym") else str(var) - - # Check if the Vars object itself has size information - if hasattr(var, "size"): - size = var.size - return_types.append(f"array[quantum.qubit, {size}]") - continue - - # Check if this is a quantum array in context - if var_name in self.context.variables: - var_info = self.context.variables[var_name] - if hasattr(var_info, "size"): - # This is a quantum array - size = var_info.size - return_types.append(f"array[quantum.qubit, {size}]") - # else: Not a quantum array, skip for now - - if not return_types: - return None - - if len(return_types) == 1: - return return_types[0] - return f"tuple[{', '.join(return_types)}]" - - def _infer_refreshed_array_size( - self, - var: str, - _actual_var: str, - expected_size: int, - ) -> int | None: - """Infer the size of a refreshed array from function return types. - - When a function returns a smaller array than it received, we need to know - the actual returned size. This method looks up the function call that - produced the refreshed array and extracts the size from its return type. - """ - import re - - # Check if we've tracked which function call produced this refreshed variable - if not hasattr(self, "refreshed_by_function") or var not in self.refreshed_by_function: - # No information about which function produced this variable - return expected_size - - func_info = self.refreshed_by_function[var] - func_name = func_info.get("function") - return_position = func_info.get( - "position", - 0, - ) # Which element in the return tuple - - # Get the return type for this function - if not hasattr(self, "function_return_types") or func_name not in self.function_return_types: - return expected_size - - return_type = self.function_return_types[func_name] - - # Parse the return type to extract array sizes - # Return type could be: - # - "array[quantum.qubit, N]" for single return - # - "tuple[array[quantum.qubit, N1], array[quantum.qubit, N2], ...]" for multiple returns - - # Check if it's a tuple return - if return_type.startswith("tuple["): - # Extract all array sizes from the tuple - # Pattern: array[quantum.qubit, SIZE] - array_pattern = r"array\[quantum\.qubit,\s*(\d+)\]" - matches = re.findall(array_pattern, return_type) - - if return_position < len(matches): - return int(matches[return_position]) - else: - # Single return value - match = re.search(r"array\[quantum\.qubit,\s*(\d+)\]", return_type) - if match: - return int(match.group(1)) - - # If we can't determine the size, assume it's the same as expected (no restoration needed) - return expected_size - - def _insert_array_size_restoration( - self, - var: str, - actual_var: str, - actual_size: int, - expected_size: int, - ) -> None: - """Insert code to restore an array to its expected size by allocating fresh qubits.""" - from pecos.slr.gen_codes.guppy.ir import ( - Assignment, - Comment, - FunctionCall, - VariableRef, - ) - - num_to_allocate = expected_size - actual_size - - self.current_block.statements.append( - Comment(f"Restore {var} array size from {actual_size} to {expected_size}"), - ) - - # Unpack the current smaller array - if hasattr(self, "unpacked_vars") and actual_var in self.unpacked_vars: - current_elements = self.unpacked_vars[actual_var] - else: - # Create unpacking statement - current_elements = [f"{actual_var}_{i}" for i in range(actual_size)] - unpack_targets = ", ".join(current_elements) - self.current_block.statements.append( - Assignment( - target=VariableRef(unpack_targets), - value=VariableRef(actual_var), - ), - ) - - # Allocate fresh qubits - new_elements = [] - for i in range(num_to_allocate): - fresh_var = self._get_unique_var_name(f"{var}_allocated_{actual_size + i}") - self.current_block.statements.append( - Assignment( - target=VariableRef(fresh_var), - value=FunctionCall(func_name="quantum.qubit", args=[]), - ), - ) - new_elements.append(fresh_var) - - # Reconstruct the full-size array and reassign to the actual_var (fresh variable) - # This ensures the variable stays consistently defined throughout the loop - all_elements = current_elements + new_elements - - array_construction = self._create_array_construction(all_elements) - self.current_block.statements.append( - Assignment( - target=VariableRef(actual_var), - value=array_construction, - ), - ) diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/ir_generator.py b/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/ir_generator.py deleted file mode 100644 index b39c587eb..000000000 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/ir_generator.py +++ /dev/null @@ -1,79 +0,0 @@ -"""IR-based Guppy generator that uses two-pass architecture.""" - -from __future__ import annotations - -import warnings -from typing import TYPE_CHECKING - -from pecos.slr.gen_codes.generator import Generator -from pecos.slr.gen_codes.guppy.dependency_analyzer import DependencyAnalyzer -from pecos.slr.gen_codes.guppy.ir import ScopeContext -from pecos.slr.gen_codes.guppy.ir_builder import IRBuilder -from pecos.slr.gen_codes.guppy.ir_postprocessor import IRPostProcessor -from pecos.slr.gen_codes.guppy.unified_resource_planner import UnifiedResourcePlanner - -if TYPE_CHECKING: - from pecos.slr import Block - - -class IRGuppyGenerator(Generator): - """Generator that uses IR for two-pass Guppy code generation. - - .. deprecated:: - Use :func:`pecos.slr.generate` with ``target="guppy"`` instead. - """ - - def __init__(self, *, _internal: bool = False): - """Initialize the IR-based generator.""" - if not _internal: - warnings.warn( - "GuppyGenerator/IRGuppyGenerator is deprecated. Use pecos.slr.generate(prog, 'guppy') instead.", - DeprecationWarning, - stacklevel=2, - ) - self.output = [] - self.variable_context = {} - self.dependency_analyzer = DependencyAnalyzer() - - def generate_block(self, block: Block) -> None: - """Generate Guppy code for a block using IR.""" - # Build variable context from block - self._build_variable_context(block) - - # First pass: Analyze the block with unified resource planning - # This coordinates unpacking decisions with allocation strategies - planner = UnifiedResourcePlanner() - unified_analysis = planner.analyze(block, self.variable_context) - - # Convert unified analysis to UnpackingPlan - # The unified planner internally runs IRAnalyzer, so we don't need to run it again - unpacking_plan = unified_analysis.get_unpacking_plan() - - # Second pass: Build IR with both unpacking plan and unified analysis - builder = IRBuilder( - unpacking_plan, - unified_analysis=unified_analysis, - include_optimization_report=True, - ) - module = builder.build_module(block, []) # No pending functions for now - - # Post-processing pass: Fix array accesses after unpacking - context = ScopeContext() - postprocessor = IRPostProcessor() - postprocessor.process_module(module, context) - - # Third pass: Render to Guppy code - lines = module.render(context) - - self.output = lines - - def _build_variable_context(self, block: Block) -> None: - """Build variable context from block declarations.""" - if hasattr(block, "vars"): - for var in block.vars: - if hasattr(var, "sym"): - self.variable_context[var.sym] = var - - def get_output(self) -> str: - """Get the generated Guppy code.""" - return "\n".join(self.output) diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/ir_postprocessor.py b/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/ir_postprocessor.py deleted file mode 100644 index 664756b8f..000000000 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/ir_postprocessor.py +++ /dev/null @@ -1,253 +0,0 @@ -"""Post-processor for IR nodes to fix array access after unpacking. - -This module provides a post-processing pass that runs after IR building -but before rendering to replace ArrayAccess nodes with VariableRef nodes -for arrays that have been unpacked. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from pecos.slr.gen_codes.guppy.ir import ( - ArrayAccess, - ArrayUnpack, - Assignment, - BinaryOp, - Block, - FieldAccess, - ForStatement, - FunctionCall, - IfStatement, - IRNode, - Measurement, - MethodCall, - ReturnStatement, - ScopeContext, - TupleExpression, - UnaryOp, - VariableInfo, - VariableRef, - WhileStatement, -) - -if TYPE_CHECKING: - from pecos.slr.gen_codes.guppy.ir import ( - Function, - Module, - ) - - -class IRPostProcessor: - """Post-processes IR to fix array accesses after unpacking decisions.""" - - def __init__(self): - # Track unpacked arrays per function: func_name -> array_name -> list of unpacked variable names - self.unpacked_arrays_by_function: dict[str, dict[str, list[str]]] = {} - # Track current scope for variable lookups - self.current_scope: ScopeContext | None = None - # Track refreshed arrays per function - self.refreshed_arrays: dict[str, set[str]] = {} - # Track current function being processed - self.current_function: str | None = None - - def process_module(self, module: Module, context: ScopeContext) -> None: - """Process a module and all its functions.""" - self.current_scope = context - - # Store refreshed arrays from module - self.refreshed_arrays = module.refreshed_arrays - - # First, analyze the module to populate unpacking information - module.analyze(context) - - # Then traverse and fix array accesses - for func in module.functions: - self._process_function(func, context) - - def _process_function(self, func: Function, parent_context: ScopeContext) -> None: - """Process a function.""" - # Track current function - self.current_function = func.name - - # Initialize unpacked arrays for this function if not exists - if func.name not in self.unpacked_arrays_by_function: - self.unpacked_arrays_by_function[func.name] = {} - - # Create function scope - func_context = ScopeContext(parent=parent_context) - - # Add parameters to scope - for param_name, param_type in func.params: - var_info = VariableInfo( - name=param_name, - original_name=param_name, - var_type=param_type, - ) - func_context.add_variable(var_info) - - # Process function body - self._process_block(func.body, func_context) - - def _process_block(self, block: Block, context: ScopeContext) -> None: - """Process a block of statements.""" - old_scope = self.current_scope - self.current_scope = context - - # First pass: collect unpacking information - for stmt in block.statements: - if isinstance(stmt, ArrayUnpack): - # Record unpacking info for the current function - if self.current_function: - self.unpacked_arrays_by_function[self.current_function][stmt.source] = stmt.targets - # Also update the context - var = context.lookup_variable(stmt.source) - if var: - var.is_unpacked = True - var.unpacked_names = stmt.targets - else: - # Create variable info if it doesn't exist - var_info = VariableInfo( - name=stmt.source, - original_name=stmt.source, - var_type="quantum", - is_array=True, - is_unpacked=True, - unpacked_names=stmt.targets, - ) - context.add_variable(var_info) - - # Second pass: process all statements - for i, stmt in enumerate(block.statements): - block.statements[i] = self._process_node(stmt, context) - - self.current_scope = old_scope - - def _process_node(self, node: IRNode, context: ScopeContext) -> IRNode: - """Process any IR node, replacing ArrayAccess as needed.""" - if node is None: - return None - - # Handle ArrayAccess specially - if isinstance(node, ArrayAccess): - return self._process_array_access(node, context) - - # Handle compound nodes that contain other nodes - if isinstance(node, Assignment): - node.target = self._process_node(node.target, context) - node.value = self._process_node(node.value, context) - - elif isinstance(node, FunctionCall): - node.args = [self._process_node(arg, context) for arg in node.args] - - elif isinstance(node, MethodCall): - node.obj = self._process_node(node.obj, context) - node.args = [self._process_node(arg, context) for arg in node.args] - - elif isinstance(node, BinaryOp): - node.left = self._process_node(node.left, context) - node.right = self._process_node(node.right, context) - - elif isinstance(node, UnaryOp): - node.operand = self._process_node(node.operand, context) - - elif isinstance(node, Measurement): - node.qubit = self._process_node(node.qubit, context) - if node.target: - node.target = self._process_node(node.target, context) - - elif isinstance(node, ReturnStatement): - if node.value: - node.value = self._process_node(node.value, context) - - elif isinstance(node, TupleExpression): - node.elements = [self._process_node(elem, context) for elem in node.elements] - - elif isinstance(node, IfStatement): - node.condition = self._process_node(node.condition, context) - self._process_block(node.then_block, ScopeContext(parent=context)) - if node.else_block: - self._process_block(node.else_block, ScopeContext(parent=context)) - - elif isinstance(node, WhileStatement): - node.condition = self._process_node(node.condition, context) - self._process_block(node.body, ScopeContext(parent=context)) - - elif isinstance(node, ForStatement): - node.iterable = self._process_node(node.iterable, context) - self._process_block(node.body, ScopeContext(parent=context)) - - elif isinstance(node, Block): - self._process_block(node, context) - - elif isinstance(node, FieldAccess): - node.obj = self._process_node(node.obj, context) - - # Return the node (possibly modified) - return node - - def _process_array_access(self, node: ArrayAccess, context: ScopeContext) -> IRNode: - """Process an ArrayAccess node, possibly replacing it with VariableRef.""" - # Check if this is accessing an unpacked array - array_name = None - - # Extract array name from different forms - if node.array_name: - # Old API: direct array name - array_name = node.array_name - elif isinstance(node.array, VariableRef): - # New API: array is a VariableRef - array_name = node.array.name - elif isinstance(node.array, FieldAccess): - # Complex case: struct.field[index] - # Process the field access but don't replace the array access - node.array = self._process_node(node.array, context) - return node - - # Debug: Print what we're processing - # print(f"DEBUG: Processing ArrayAccess - array_name={array_name}, index={node.index}") - # print(f"DEBUG: unpacked_arrays={self.unpacked_arrays}") - - # If we have an array name and a constant index, check for unpacking - if array_name and isinstance(node.index, int): - # Check if this array was refreshed by a function call - # If so, we should NOT convert to unpacked variable names - if ( - self.current_function - and self.current_function in self.refreshed_arrays - and array_name in self.refreshed_arrays[self.current_function] - ): - # Array was refreshed, keep as ArrayAccess with force_array_syntax - node.force_array_syntax = True - # Process array and index if needed - if node.array and isinstance(node.array, IRNode): - node.array = self._process_node(node.array, context) - if isinstance(node.index, IRNode): - node.index = self._process_node(node.index, context) - return node - - # Look up variable info - var = context.lookup_variable(array_name) - if var and var.is_unpacked and node.index < len(var.unpacked_names): - # Replace with VariableRef to the unpacked variable - # print(f"DEBUG: Replacing {array_name}[{node.index}] with {var.unpacked_names[node.index]}") - return VariableRef(var.unpacked_names[node.index]) - - # Also check our function-specific tracking - if self.current_function and self.current_function in self.unpacked_arrays_by_function: - func_unpacked = self.unpacked_arrays_by_function[self.current_function] - if array_name in func_unpacked: - unpacked_names = func_unpacked[array_name] - if node.index < len(unpacked_names): - # print(f"DEBUG: Replacing {array_name}[{node.index}] with {unpacked_names[node.index]}") - return VariableRef(unpacked_names[node.index]) - - # Process array if it's an IRNode - if node.array and isinstance(node.array, IRNode): - node.array = self._process_node(node.array, context) - - # Process index if it's an IRNode - if isinstance(node.index, IRNode): - node.index = self._process_node(node.index, context) - - return node diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/naming.py b/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/naming.py deleted file mode 100644 index 91b16bbd3..000000000 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/naming.py +++ /dev/null @@ -1,138 +0,0 @@ -"""Utilities for converting block names to function names.""" - -import re - - -def class_to_function_name(class_name: str) -> str: - """Convert a PascalCase class name to snake_case function name. - - Examples: - PrepareGHZ -> prepare_ghz - ApplyXCorrection -> apply_x_correction - QPEStep -> qpe_step - PrepareLogical0 -> prepare_logical_0 - """ - # First, handle the sequence of capital letters followed by lowercase (e.g., 'GHZ' -> 'ghz') - s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", class_name) - # Handle lowercase (or number) followed by capital letter - s2 = re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1) - # Handle letter followed by number - s3 = re.sub(r"([a-zA-Z])(\d)", r"\1_\2", s2) - return s3.lower() - - -def get_module_prefix(block_class) -> str | None: - """Get module-based prefix for a block class. - - Examples: - pecos.qeclib.steane.PrepareLogical0 -> steane_ - pecos.qeclib.surface.MeasureStabilizers -> surface_ - mypackage.circuits.teleport.BellPair -> teleport_ - """ - module = getattr(block_class, "__module__", "") - if not module: - return None - - # Look for qeclib patterns - if "qeclib" in module: - parts = module.split(".") - try: - qeclib_idx = parts.index("qeclib") - if qeclib_idx + 1 < len(parts): - # Return the module after qeclib (e.g., 'steane', 'surface') - return parts[qeclib_idx + 1] + "_" - except (ValueError, IndexError): - pass - - # For other patterns, look for meaningful module names - common_modules = { - "blocks", - "ops", - "operations", - "circuits", - "components", - "__main__", - } - parts = module.split(".") - - # Skip the class name itself if it's at the end - if parts and parts[-1] == block_class.__name__: - parts = parts[:-1] - - # Look backwards for a meaningful, specific module name - for i in range(len(parts) - 1, -1, -1): - part = parts[i] - # Skip common structural names - if part in common_modules: - continue - # Found a specific module name - return part + "_" - - return None - - -def get_function_name(block_class, *, use_module_prefix: bool = True) -> str: - """Get the full function name for a block class. - - Args: - block_class: The block class - use_module_prefix: Whether to include module-based prefix - - Returns: - Function name like 'prepare_ghz' or 'steane_prepare_logical_0' - """ - # Get base name from class - class_name = block_class.__name__ - base_name = class_to_function_name(class_name) - - # Add module prefix if requested - if use_module_prefix: - prefix = get_module_prefix(block_class) - if prefix and not base_name.startswith(prefix.rstrip("_")): - return prefix + base_name - - return base_name - - -# Example usage and tests -if __name__ == "__main__": - # Test class name conversion - test_cases = [ - ("PrepareGHZ", "prepare_ghz"), - ("ApplyXCorrection", "apply_x_correction"), - ("QPEStep", "qpe_step"), - ("PrepareLogical0", "prepare_logical_0"), - ("CNOTGate", "cnot_gate"), - ("Phase90", "phase_90"), - ("TOffoli3", "t_offoli_3"), - ] - - print("Class name conversions:") - for input_name, expected in test_cases: - result = class_to_function_name(input_name) - status = "PASS" if result == expected else "FAIL" - print(f" {status} {input_name} -> {result} (expected: {expected})") - - # Test module prefix extraction - print("\nModule prefix extraction:") - - class MockClass: - pass - - # Test different module paths - test_modules = [ - ("pecos.qeclib.steane.PrepareLogical0", "steane_"), - ("pecos.qeclib.surface.MeasureStabilizers", "surface_"), - ("pecos.qeclib.bacon_shor.ExtractSyndrome", "bacon_shor_"), - ("mypackage.circuits.teleport.BellPair", "teleport_"), - ("mypackage.circuits.BellPair", None), # 'circuits' is common - ("pecos.slr.blocks.CustomBlock", None), # 'blocks' is common - ("__main__.MyBlock", None), - ] - - for module_path, expected_prefix in test_modules: - MockClass.__module__ = module_path - MockClass.__name__ = module_path.split(".")[-1] - prefix = get_module_prefix(MockClass) - status = "PASS" if prefix == expected_prefix else "FAIL" - print(f" {status} {module_path} -> {prefix} (expected: {expected_prefix})") diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/qubit_state_validator.py b/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/qubit_state_validator.py deleted file mode 100644 index c1ee4e74a..000000000 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/qubit_state_validator.py +++ /dev/null @@ -1,329 +0,0 @@ -# Copyright 2026 The PECOS Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. - -"""Qubit state validation for SLR programs. - -This module validates that quantum gates are only applied to prepared qubit slots, -detecting compile-time errors when gates are applied to unprepared/measured qubits. - -The validation follows the two-state model from the QAlloc design: -- UNPREPARED: Initial state or after measurement - cannot apply gates -- PREPARED: After preparation - ready for gate operations - -Validation errors are collected and can be reported as compile-time errors. -""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from enum import Enum, auto -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from pecos.slr import Block as SLRBlock - - -class ValidationSlotState(Enum): - """State of a qubit slot for validation purposes.""" - - UNPREPARED = auto() # Initial state or after measurement - PREPARED = auto() # After preparation, ready for gates - - -@dataclass -class StateViolation: - """A validation error: gate applied to unprepared slot.""" - - array_name: str - index: int - position: int - gate_name: str - message: str - - def __str__(self) -> str: - return f"{self.array_name}[{self.index}] at position {self.position}: {self.message}" - - -@dataclass -class QubitStateTracker: - """Tracks the preparation state of qubit slots through program execution. - - Used to validate that gates are only applied to prepared qubits. - """ - - # Map from (array_name, index) to current state - slot_states: dict[tuple[str, int], ValidationSlotState] = field( - default_factory=dict, - ) - - # Collected violations - violations: list[StateViolation] = field(default_factory=list) - - # Position counter for tracking operation order - position: int = 0 - - def get_state(self, array_name: str, index: int) -> ValidationSlotState: - """Get the current state of a slot. Defaults to UNPREPARED.""" - return self.slot_states.get((array_name, index), ValidationSlotState.UNPREPARED) - - def mark_prepared(self, array_name: str, index: int) -> None: - """Mark a slot as prepared (after Prep/Init/Reset).""" - self.slot_states[(array_name, index)] = ValidationSlotState.PREPARED - - def mark_unprepared(self, array_name: str, index: int) -> None: - """Mark a slot as unprepared (after measurement).""" - self.slot_states[(array_name, index)] = ValidationSlotState.UNPREPARED - - def validate_gate(self, array_name: str, index: int, gate_name: str) -> bool: - """Validate that a gate can be applied to this slot. - - Returns True if valid, False if violation detected. - """ - state = self.get_state(array_name, index) - if state == ValidationSlotState.UNPREPARED: - self.violations.append( - StateViolation( - array_name=array_name, - index=index, - position=self.position, - gate_name=gate_name, - message=f"Gate '{gate_name}' applied to unprepared qubit. " - f"Call prepare() before applying gates.", - ), - ) - return False - return True - - def has_violations(self) -> bool: - """Check if any violations were detected.""" - return len(self.violations) > 0 - - def get_violations(self) -> list[StateViolation]: - """Get all detected violations.""" - return self.violations.copy() - - def clear_violations(self) -> None: - """Clear all violations.""" - self.violations.clear() - - -class QubitStateValidator: - """Validates qubit state requirements in SLR programs. - - Walks through the program operations and validates that: - 1. Gates are only applied to prepared qubits - 2. Measurements transition qubits to unprepared - 3. Preparations transition qubits to prepared - - Usage: - validator = QubitStateValidator() - violations = validator.validate(block, variable_context) - - if violations: - for v in violations: - print(f"Error: {v}") - """ - - # Operations that prepare qubits - PREPARATION_OPS = frozenset({"Prep", "Init", "Reset", "PrepZ", "PrepX", "PrepY"}) - - # Operations that consume/measure qubits - MEASUREMENT_OPS = frozenset({"Measure", "MeasZ", "MeasX", "MeasY"}) - - def __init__(self, *, strict: bool = True): - """Initialize the validator. - - Args: - strict: If True, all qubits start unprepared and must be explicitly prepared. - If False, qubits are assumed prepared initially (legacy compatibility). - """ - self.strict = strict - self.tracker = QubitStateTracker() - - def validate( - self, - block: SLRBlock, - variable_context: dict[str, Any] | None = None, - ) -> list[StateViolation]: - """Validate qubit states in a block. - - Args: - block: The SLR block to validate. - variable_context: Optional context of variables (QReg, CReg, etc.). - - Returns: - List of StateViolation objects for any detected errors. - """ - self.tracker = QubitStateTracker() - variable_context = variable_context or {} - - # In non-strict mode, mark all known qubits as prepared initially - if not self.strict: - self._initialize_prepared(variable_context) - - # Validate all operations - if hasattr(block, "ops"): - for op in block.ops: - self._validate_operation(op, variable_context) - self.tracker.position += 1 - - return self.tracker.get_violations() - - def _initialize_prepared(self, variable_context: dict[str, Any]) -> None: - """Mark all qubits as initially prepared (legacy mode).""" - for var in variable_context.values(): - if hasattr(var, "size") and hasattr(var, "sym"): - # Check if it's a quantum register - var_type = type(var).__name__ - if var_type in ("QReg", "QAlloc"): - for i in range(var.size): - self.tracker.mark_prepared(var.sym, i) - - def _validate_operation( - self, - op: Any, - variable_context: dict[str, Any], - ) -> None: - """Validate a single operation.""" - op_name = type(op).__name__ - - if op_name in self.MEASUREMENT_OPS: - self._handle_measurement(op) - elif op_name in self.PREPARATION_OPS: - self._handle_preparation(op) - elif op_name == "If": - self._validate_if_block(op, variable_context) - elif op_name in ("For", "While", "Repeat"): - self._validate_loop_block(op, variable_context) - elif op_name == "Parallel": - self._validate_parallel_block(op, variable_context) - elif hasattr(op, "qargs"): - # This is a quantum gate - self._validate_gate(op) - elif hasattr(op, "ops"): - # Nested block - recurse - for nested_op in op.ops: - self._validate_operation(nested_op, variable_context) - - def _handle_measurement(self, op: Any) -> None: - """Handle measurement: transitions qubits to unprepared.""" - if hasattr(op, "qargs") and op.qargs: - for qarg in op.qargs: - if self._has_reg_and_index(qarg): - self.tracker.mark_unprepared(qarg.reg.sym, qarg.index) - - def _handle_preparation(self, op: Any) -> None: - """Handle preparation: transitions qubits to prepared.""" - if hasattr(op, "qargs") and op.qargs: - for qarg in op.qargs: - if self._has_reg_and_index(qarg): - self.tracker.mark_prepared(qarg.reg.sym, qarg.index) - - def _validate_gate(self, op: Any) -> None: - """Validate a quantum gate: all qubits must be prepared.""" - gate_name = type(op).__name__ - if hasattr(op, "qargs") and op.qargs: - for qarg in op.qargs: - if self._has_reg_and_index(qarg): - self.tracker.validate_gate(qarg.reg.sym, qarg.index, gate_name) - - def _validate_if_block( - self, - if_block: Any, - variable_context: dict[str, Any], - ) -> None: - """Validate an if block. - - For if/else, we need to be conservative: a qubit is only considered - prepared after the if block if it's prepared in BOTH branches. - """ - # Save state before if - state_before = dict(self.tracker.slot_states) - - # Validate then block - if hasattr(if_block, "ops"): - for op in if_block.ops: - self._validate_operation(op, variable_context) - - # Save state after then - state_after_then = dict(self.tracker.slot_states) - - # Reset to state before if and validate else - self.tracker.slot_states = dict(state_before) - - if hasattr(if_block, "else_block") and if_block.else_block and hasattr(if_block.else_block, "ops"): - for op in if_block.else_block.ops: - self._validate_operation(op, variable_context) - - state_after_else = dict(self.tracker.slot_states) - - # Merge states: only prepared if prepared in BOTH branches - merged_state = {} - all_keys = set(state_after_then.keys()) | set(state_after_else.keys()) - for key in all_keys: - then_state = state_after_then.get(key, ValidationSlotState.UNPREPARED) - else_state = state_after_else.get(key, ValidationSlotState.UNPREPARED) - # Only prepared if prepared in both branches - if then_state == ValidationSlotState.PREPARED and else_state == ValidationSlotState.PREPARED: - merged_state[key] = ValidationSlotState.PREPARED - else: - merged_state[key] = ValidationSlotState.UNPREPARED - - self.tracker.slot_states = merged_state - - def _validate_loop_block( - self, - loop_block: Any, - variable_context: dict[str, Any], - ) -> None: - """Validate a loop block. - - For loops, we validate the body once but assume the state after - the loop could be any state that occurs during the loop. - """ - # Validate loop body - if hasattr(loop_block, "ops"): - for op in loop_block.ops: - self._validate_operation(op, variable_context) - - def _validate_parallel_block( - self, - parallel_block: Any, - variable_context: dict[str, Any], - ) -> None: - """Validate a parallel block.""" - if hasattr(parallel_block, "ops"): - for op in parallel_block.ops: - self._validate_operation(op, variable_context) - - def _has_reg_and_index(self, qarg: Any) -> bool: - """Check if a qubit argument has reg and index attributes.""" - return hasattr(qarg, "reg") and hasattr(qarg.reg, "sym") and hasattr(qarg, "index") - - -def validate_qubit_states( - block: SLRBlock, - variable_context: dict[str, Any] | None = None, - *, - strict: bool = True, -) -> list[StateViolation]: - """Convenience function to validate qubit states in a block. - - Args: - block: The SLR block to validate. - variable_context: Optional context of variables. - strict: If True, qubits must be explicitly prepared before use. - - Returns: - List of StateViolation objects for any detected errors. - """ - validator = QubitStateValidator(strict=strict) - return validator.validate(block, variable_context) diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/qubit_usage_analyzer.py b/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/qubit_usage_analyzer.py deleted file mode 100644 index ded78c7e4..000000000 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/qubit_usage_analyzer.py +++ /dev/null @@ -1,259 +0,0 @@ -"""Analyzer for qubit usage patterns to optimize allocation strategies.""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from enum import Enum -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from pecos.slr import Block as SLRBlock - - -class QubitRole(Enum): - """Role classification for quantum registers.""" - - DATA = "data" # Long-lived data qubits - ANCILLA = "ancilla" # Short-lived ancilla qubits - UNKNOWN = "unknown" # Not yet classified - - -@dataclass -class QubitUsageStats: - """Statistics about how a quantum register is used.""" - - name: str - size: int - - # Usage patterns - measurement_count: int = 0 - reset_count: int = 0 - gate_count: int = 0 - - # Lifetime tracking - first_use_position: int | None = None - last_use_position: int | None = None - measurement_positions: list[int] = field(default_factory=list) - reset_positions: list[int] = field(default_factory=list) - - # Access patterns - individual_accesses: set[int] = field(default_factory=set) - full_array_accesses: int = 0 - - # Structural hints - is_struct_member: bool = False - struct_name: str | None = None - - @property - def lifetime(self) -> int: - """Calculate the lifetime of this register.""" - if self.first_use_position is None or self.last_use_position is None: - return 0 - return self.last_use_position - self.first_use_position - - @property - def measure_reset_ratio(self) -> float: - """Ratio of measurements+resets to total operations.""" - total_ops = self.measurement_count + self.reset_count + self.gate_count - if total_ops == 0: - return 0.0 - return (self.measurement_count + self.reset_count) / total_ops - - @property - def individual_access_ratio(self) -> float: - """Ratio of individual element accesses to size.""" - if self.size == 0: - return 0.0 - return len(self.individual_accesses) / self.size - - def classify_role(self) -> QubitRole: - """Classify the role of this register based on usage patterns.""" - # Explicit ancilla naming patterns - if any(pattern in self.name.lower() for pattern in ["ancilla", "anc", "syndrome", "flag"]): - return QubitRole.ANCILLA - - # Explicit data naming patterns - if any(pattern in self.name.lower() for pattern in ["data", "logical", "code"]): - return QubitRole.DATA - - # Pattern-based classification - # High measure/reset ratio suggests ancilla - if self.measure_reset_ratio > 0.7: - return QubitRole.ANCILLA - - # Short lifetime with measurements suggests ancilla - if self.lifetime < 10 and self.measurement_count > 0: - return QubitRole.ANCILLA - - # Part of a struct (like QEC code) suggests data - if self.is_struct_member: - return QubitRole.DATA - - # Default to data for long-lived qubits - if self.lifetime > 20: - return QubitRole.DATA - - return QubitRole.UNKNOWN - - -class QubitUsageAnalyzer: - """Analyzes qubit usage patterns to inform allocation strategies.""" - - def __init__(self): - self.register_stats: dict[str, QubitUsageStats] = {} - self.position_counter = 0 - - def analyze_block( - self, - block: SLRBlock, - struct_info: dict[str, dict] | None = None, - ) -> dict[str, QubitUsageStats]: - """Analyze a block and return usage statistics for each quantum register.""" - # Reset state - self.register_stats.clear() - self.position_counter = 0 - - # First, collect all quantum registers - if hasattr(block, "vars"): - for var in block.vars: - if type(var).__name__ == "QReg" and hasattr(var, "sym"): - stats = QubitUsageStats( - name=var.sym, - size=getattr(var, "size", 1), - ) - - # Check if part of a struct - if struct_info: - for struct_name, info in struct_info.items(): - if var.sym in info.get("var_names", {}).values(): - stats.is_struct_member = True - stats.struct_name = struct_name - break - - self.register_stats[var.sym] = stats - - # Analyze operations - if hasattr(block, "ops"): - for op in block.ops: - self._analyze_operation(op) - self.position_counter += 1 - - return self.register_stats - - def _analyze_operation(self, op) -> None: - """Analyze a single operation for qubit usage patterns.""" - op_type = type(op).__name__ - - if op_type == "Measure": - self._analyze_measurement(op) - elif op_type in ["Prep", "Reset"]: - self._analyze_reset(op) - elif hasattr(op, "qargs"): - self._analyze_gate(op) - elif hasattr(op, "ops"): - # Nested block - for nested_op in op.ops: - self._analyze_operation(nested_op) - - def _analyze_measurement(self, meas) -> None: - """Analyze measurement operations.""" - if hasattr(meas, "qargs") and meas.qargs: - for qarg in meas.qargs: - reg_name = self._get_register_name(qarg) - if reg_name and reg_name in self.register_stats: - stats = self.register_stats[reg_name] - stats.measurement_count += 1 - stats.measurement_positions.append(self.position_counter) - self._update_lifetime(stats) - - # Track access pattern - if hasattr(qarg, "index"): - stats.individual_accesses.add(qarg.index) - elif hasattr(qarg, "size"): - stats.full_array_accesses += 1 - - def _analyze_reset(self, reset_op) -> None: - """Analyze reset/prep operations.""" - if hasattr(reset_op, "qargs") and reset_op.qargs: - for qarg in reset_op.qargs: - reg_name = self._get_register_name(qarg) - if reg_name and reg_name in self.register_stats: - stats = self.register_stats[reg_name] - stats.reset_count += 1 - stats.reset_positions.append(self.position_counter) - self._update_lifetime(stats) - - def _analyze_gate(self, gate) -> None: - """Analyze gate operations.""" - if hasattr(gate, "qargs") and gate.qargs: - for qarg in gate.qargs: - # Handle tuple arguments (e.g., CX gates) - if isinstance(qarg, tuple): - for sub_qarg in qarg: - self._record_gate_usage(sub_qarg) - else: - self._record_gate_usage(qarg) - - def _record_gate_usage(self, qarg) -> None: - """Record usage from a gate operation.""" - reg_name = self._get_register_name(qarg) - if reg_name and reg_name in self.register_stats: - stats = self.register_stats[reg_name] - stats.gate_count += 1 - self._update_lifetime(stats) - - # Track access pattern - if hasattr(qarg, "index"): - stats.individual_accesses.add(qarg.index) - - def _get_register_name(self, qarg) -> str | None: - """Extract register name from a qubit argument.""" - if hasattr(qarg, "reg") and hasattr(qarg.reg, "sym"): - return qarg.reg.sym - if hasattr(qarg, "sym"): - return qarg.sym - return None - - def _update_lifetime(self, stats: QubitUsageStats) -> None: - """Update lifetime tracking for a register.""" - if stats.first_use_position is None: - stats.first_use_position = self.position_counter - stats.last_use_position = self.position_counter - - def get_allocation_recommendations(self) -> dict[str, dict]: - """Get allocation recommendations based on usage analysis.""" - recommendations = {} - - for reg_name, stats in self.register_stats.items(): - role = stats.classify_role() - - if role == QubitRole.ANCILLA: - # Ancillas benefit from dynamic allocation - recommendations[reg_name] = { - "allocation": "dynamic", - "reason": f"High measure/reset ratio ({stats.measure_reset_ratio:.2f})", - "keep_packed": False, - "pre_allocate": False, - } - elif role == QubitRole.DATA: - # Data qubits should stay bundled - recommendations[reg_name] = { - "allocation": "static", - "reason": ( - "Long-lived data qubits" - if stats.is_struct_member - else f"Low measure/reset ratio ({stats.measure_reset_ratio:.2f})" - ), - "keep_packed": True, - "pre_allocate": True, - } - else: - # Default conservative approach - recommendations[reg_name] = { - "allocation": "static", - "reason": "Unknown usage pattern", - "keep_packed": True, - "pre_allocate": True, - } - - return recommendations diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/scope_manager.py b/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/scope_manager.py deleted file mode 100644 index 0b6683322..000000000 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/scope_manager.py +++ /dev/null @@ -1,236 +0,0 @@ -"""Enhanced scope management for IR-based code generation.""" - -from __future__ import annotations - -from contextlib import contextmanager -from dataclasses import dataclass, field -from enum import Enum - -from pecos.slr.gen_codes.guppy.ir import ResourceState, ScopeContext - -# Maximum array size fallback when actual size cannot be determined from context -# This is a conservative upper bound to ensure all indices are covered -_MAX_ARRAY_SIZE_FALLBACK = 1000 - - -class ScopeType(Enum): - """Type of scope.""" - - MODULE = "module" - FUNCTION = "function" - BLOCK = "block" - IF_THEN = "if_then" - IF_ELSE = "if_else" - LOOP = "loop" - - -@dataclass -class ResourceUsage: - """Track resource usage in a scope.""" - - qreg_name: str - indices: set[int] - is_consumed: bool = False - is_borrowed: bool = False - - -@dataclass -class ScopeInfo: - """Enhanced scope information.""" - - scope_type: ScopeType - context: ScopeContext - resource_usage: dict[str, ResourceUsage] = field(default_factory=dict) - borrowed_resources: set[str] = field(default_factory=set) - returned_resources: set[str] = field(default_factory=set) - - -class ScopeManager: - """Manages scope contexts for IR generation.""" - - def __init__(self): - self.scope_stack: list[ScopeInfo] = [] - self.global_context = ScopeContext() - - @property - def current_scope(self) -> ScopeInfo | None: - """Get the current scope.""" - return self.scope_stack[-1] if self.scope_stack else None - - @property - def current_context(self) -> ScopeContext: - """Get the current scope context.""" - if self.current_scope: - return self.current_scope.context - return self.global_context - - @contextmanager - def enter_scope(self, scope_type: ScopeType): - """Enter a new scope.""" - parent_context = self.current_context - new_context = ScopeContext(parent=parent_context) - new_scope = ScopeInfo(scope_type=scope_type, context=new_context) - - self.scope_stack.append(new_scope) - try: - yield new_scope - finally: - # Analyze resource flow when exiting scope - self._analyze_scope_exit(new_scope) - self.scope_stack.pop() - - def _analyze_scope_exit(self, scope: ScopeInfo) -> None: - """Analyze resource usage when exiting a scope.""" - # For conditional scopes, propagate resource usage to parent - if scope.scope_type in [ScopeType.IF_THEN, ScopeType.IF_ELSE] and self.current_scope: - for res_name, usage in scope.resource_usage.items(): - if res_name not in self.current_scope.resource_usage: - self.current_scope.resource_usage[res_name] = ResourceUsage( - qreg_name=usage.qreg_name, - indices=set(), - ) - # Merge usage - parent_usage = self.current_scope.resource_usage[res_name] - parent_usage.indices.update(usage.indices) - if usage.is_consumed: - parent_usage.is_consumed = True - - def track_resource_usage( - self, - qreg_name: str, - indices: set[int], - *, - consumed: bool = False, - ) -> None: - """Track usage of a quantum resource in current scope.""" - if not self.current_scope: - return - - if qreg_name not in self.current_scope.resource_usage: - self.current_scope.resource_usage[qreg_name] = ResourceUsage( - qreg_name=qreg_name, - indices=set(), - ) - - usage = self.current_scope.resource_usage[qreg_name] - usage.indices.update(indices) - if consumed: - usage.is_consumed = True - - def mark_resource_borrowed(self, qreg_name: str) -> None: - """Mark a resource as borrowed in current scope.""" - if self.current_scope: - self.current_scope.borrowed_resources.add(qreg_name) - if qreg_name in self.current_scope.resource_usage: - self.current_scope.resource_usage[qreg_name].is_borrowed = True - - def is_in_loop(self) -> bool: - """Check if currently inside a loop scope.""" - return any(scope.scope_type == ScopeType.LOOP for scope in self.scope_stack) - - def is_in_conditional_within_loop(self) -> bool: - """Check if currently inside a conditional (if) within a loop.""" - in_loop = False - in_conditional = False - - for scope in self.scope_stack: - if scope.scope_type == ScopeType.LOOP: - in_loop = True - elif scope.scope_type in (ScopeType.IF_THEN, ScopeType.IF_ELSE) and in_loop: - in_conditional = True - - return in_loop and in_conditional - - def mark_resource_returned(self, qreg_name: str) -> None: - """Mark a resource as returned from current scope.""" - if self.current_scope: - self.current_scope.returned_resources.add(qreg_name) - - def get_unconsumed_resources(self) -> dict[str, set[int]]: - """Get all unconsumed quantum resources in current scope.""" - unconsumed = {} - - # Look through all variables in current context - context = self.current_context - for var_name, var_info in context.variables.items(): - if ( - (var_info.var_type == "quantum" and var_info.state != ResourceState.CONSUMED) - and var_info.is_array - and var_info.size - ): - # Check which indices are consumed - consumed_indices = set() - if self.current_scope and var_name in self.current_scope.resource_usage: - usage = self.current_scope.resource_usage[var_name] - consumed_indices = set(range(var_info.size)) if usage.is_consumed else usage.indices - - # Find unconsumed indices - all_indices = set(range(var_info.size)) - unconsumed_indices = all_indices - consumed_indices - - if unconsumed_indices: - unconsumed[var_name] = unconsumed_indices - - return unconsumed - - def analyze_conditional_branches( - self, - then_scope: ScopeInfo, - else_scope: ScopeInfo | None, - context: ScopeContext | None = None, - ) -> dict[str, set[int]]: - """Analyze resource consumption across conditional branches.""" - # Get resources consumed in then branch - then_consumed = {} - for res_name, usage in then_scope.resource_usage.items(): - if usage.is_consumed: - # Get actual array size from context if available - if context: - var_info = context.lookup_variable(res_name) - if var_info and var_info.size: - then_consumed[res_name] = set(range(var_info.size)) - else: - then_consumed[res_name] = set(range(_MAX_ARRAY_SIZE_FALLBACK)) - else: - then_consumed[res_name] = set(range(_MAX_ARRAY_SIZE_FALLBACK)) - elif usage.indices: - then_consumed[res_name] = usage.indices - - # Get resources consumed in else branch (if exists) - else_consumed = {} - if else_scope: - for res_name, usage in else_scope.resource_usage.items(): - if usage.is_consumed: - # Get actual array size from context if available - if context: - var_info = context.lookup_variable(res_name) - if var_info and var_info.size: - else_consumed[res_name] = set(range(var_info.size)) - else: - else_consumed[res_name] = set( - range(_MAX_ARRAY_SIZE_FALLBACK), - ) - else: - else_consumed[res_name] = set(range(_MAX_ARRAY_SIZE_FALLBACK)) - elif usage.indices: - else_consumed[res_name] = usage.indices - - # Find resources that need to be balanced - all_resources = set(then_consumed.keys()) | set(else_consumed.keys()) - unbalanced = {} - - for res_name in all_resources: - then_indices = then_consumed.get(res_name, set()) - else_indices = else_consumed.get(res_name, set()) - - if then_indices != else_indices: - # Find indices consumed in one branch but not the other - else_indices - then_indices - missing_in_else = then_indices - else_indices - - # For now, we track indices missing in else branch - # (consumed in then but not else) - if missing_in_else: - unbalanced[res_name] = missing_in_else - - return unbalanced diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/unified_resource_planner.py b/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/unified_resource_planner.py deleted file mode 100644 index 9c42cc813..000000000 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/unified_resource_planner.py +++ /dev/null @@ -1,547 +0,0 @@ -"""Unified resource planning framework for Guppy code generation. - -This module provides a holistic approach to resource management by combining: -1. Array unpacking decisions (rule-based from unpacking_rules.py) -2. Local allocation analysis (computed directly from usage patterns) -3. Data flow analysis (precise element-level tracking from data_flow.py) - -The unified planner makes coordinated decisions about BOTH unpacking and allocation, -eliminating conflicts and enabling cross-cutting optimizations. -""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from enum import Enum, auto -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from pecos.slr import Block as SLRBlock - from pecos.slr.gen_codes.guppy.data_flow import DataFlowAnalysis - from pecos.slr.gen_codes.guppy.ir_analyzer import ArrayAccessInfo, UnpackingPlan - - -class ResourceStrategy(Enum): - """Unified strategy for how to manage a quantum/classical register. - - This combines both unpacking and allocation decisions into a coherent plan. - """ - - # Keep as packed array, pre-allocate all elements - PACKED_PREALLOCATED = auto() - - # Keep as packed array, allocate elements dynamically as needed - PACKED_DYNAMIC = auto() - - # Unpack into individual variables, pre-allocate all - UNPACKED_PREALLOCATED = auto() - - # Unpack into individual variables, allocate some locally - UNPACKED_MIXED = auto() - - # Unpack completely, all elements allocated locally when first used - UNPACKED_LOCAL = auto() - - -class DecisionPriority(Enum): - """Priority level for resource planning decisions.""" - - REQUIRED = auto() # Semantic necessity (would fail otherwise) - RECOMMENDED = auto() # Strong evidence for this approach - OPTIONAL = auto() # Minor benefit - DISCOURAGED = auto() # Minor drawback - FORBIDDEN = auto() # Would cause errors - - -@dataclass -class ResourcePlan: - """Unified plan for managing a single register (array or unpacked). - - This combines unpacking and allocation decisions into a coherent strategy. - """ - - array_name: str - size: int - is_classical: bool - - # Unified strategy - strategy: ResourceStrategy - - # Fine-grained control - elements_to_unpack: set[int] = field(default_factory=set) # Which to unpack - elements_to_allocate_locally: set[int] = field( - default_factory=set, - ) # Which to allocate locally - elements_requiring_replacement: set[int] = field( - default_factory=set, - ) # Which need Prep after Measure - - # Decision reasoning - priority: DecisionPriority = DecisionPriority.OPTIONAL - reasons: list[str] = field(default_factory=list) - evidence: dict[str, any] = field(default_factory=dict) - - @property - def needs_unpacking(self) -> bool: - """Check if this register needs to be unpacked.""" - return self.strategy in ( - ResourceStrategy.UNPACKED_PREALLOCATED, - ResourceStrategy.UNPACKED_MIXED, - ResourceStrategy.UNPACKED_LOCAL, - ) - - @property - def uses_dynamic_allocation(self) -> bool: - """Check if this register uses any dynamic allocation.""" - return self.strategy in ( - ResourceStrategy.PACKED_DYNAMIC, - ResourceStrategy.UNPACKED_MIXED, - ResourceStrategy.UNPACKED_LOCAL, - ) - - def get_explanation(self) -> str: - """Get human-readable explanation of the plan.""" - lines = [ - f"Resource Plan for '{self.array_name}' (size={self.size}, " - f"{'classical' if self.is_classical else 'quantum'}):", - f" Strategy: {self.strategy.name}", - f" Priority: {self.priority.name}", - ] - - if self.elements_to_unpack: - lines.append(f" Elements to unpack: {sorted(self.elements_to_unpack)}") - - if self.elements_to_allocate_locally: - lines.append( - f" Local allocation: {sorted(self.elements_to_allocate_locally)}", - ) - - if self.elements_requiring_replacement: - lines.append( - f" Need replacement: {sorted(self.elements_requiring_replacement)}", - ) - - if self.reasons: - lines.append(" Reasons:") - lines.extend(f" - {reason}" for reason in self.reasons) - - return "\n".join(lines) - - -@dataclass -class UnifiedResourceAnalysis: - """Complete resource analysis for a block. - - Contains coordinated plans for all registers. - """ - - plans: dict[str, ResourcePlan] = field(default_factory=dict) - global_recommendations: list[str] = field(default_factory=list) - _original_unpacking_plan: UnpackingPlan | None = field(default=None, repr=False) - - def get_plan(self, array_name: str) -> ResourcePlan | None: - """Get the resource plan for a specific array.""" - return self.plans.get(array_name) - - def set_original_unpacking_plan(self, plan: UnpackingPlan) -> None: - """Store the original UnpackingPlan from IRAnalyzer for backward compatibility.""" - self._original_unpacking_plan = plan - - def get_report(self) -> str: - """Generate comprehensive resource planning report.""" - lines = [ - "=" * 70, - "UNIFIED RESOURCE PLANNING REPORT", - "=" * 70, - "", - ] - - if self.global_recommendations: - lines.append("Global Recommendations:") - lines.extend(f" - {rec}" for rec in self.global_recommendations) - lines.append("") - - for array_name, plan in sorted(self.plans.items()): - lines.append(plan.get_explanation()) - lines.append("") - - lines.extend( - [ - "=" * 70, - f"Total registers analyzed: {len(self.plans)}", - f"Unpacking recommended: {sum(1 for p in self.plans.values() if p.needs_unpacking)}", - f"Dynamic allocation: {sum(1 for p in self.plans.values() if p.uses_dynamic_allocation)}", - "=" * 70, - ], - ) - - return "\n".join(lines) - - def get_unpacking_plan(self) -> UnpackingPlan: - """Get the UnpackingPlan from IRAnalyzer. - - The UnifiedResourcePlanner internally runs IRAnalyzer, so we always - have the original unpacking plan available. - - Returns: - UnpackingPlan from IRAnalyzer with all detailed state preserved - """ - # We always have the original plan because UnifiedResourcePlanner - # runs IRAnalyzer internally during analyze() - if self._original_unpacking_plan is None: - msg = "get_unpacking_plan() called but no original plan available" - raise RuntimeError(msg) - - return self._original_unpacking_plan - - -class UnifiedResourcePlanner: - """Unified planner that coordinates unpacking and allocation decisions. - - This planner integrates: - 1. Data flow analysis (precise element-level tracking) - 2. Unpacking rules (semantic requirements from usage patterns) - 3. Local allocation analysis (computed from consumption & reuse patterns) - - The result is a coordinated ResourcePlan for each register that makes - coherent decisions about both unpacking and allocation. - """ - - def __init__(self): - self.analysis: UnifiedResourceAnalysis | None = None - self.original_unpacking_plan: UnpackingPlan | None = None - - def analyze( - self, - block: SLRBlock, - variable_context: dict[str, any], - *, - array_access_info: dict[str, ArrayAccessInfo] | None = None, - data_flow_analysis: DataFlowAnalysis | None = None, - ) -> UnifiedResourceAnalysis: - """Perform unified resource planning for a block. - - Args: - block: The SLR block to analyze - variable_context: Context of variables in the block - array_access_info: Optional pre-computed array access info from IRAnalyzer - data_flow_analysis: Optional pre-computed data flow analysis - - Returns: - UnifiedResourceAnalysis with coordinated plans for all registers - """ - self.analysis = UnifiedResourceAnalysis() - - # If we don't have the required analyses, compute them now - if array_access_info is None: - from pecos.slr.gen_codes.guppy.ir_analyzer import IRAnalyzer - - analyzer = IRAnalyzer() - plan = analyzer.analyze_block(block, variable_context) - array_access_info = plan.all_analyzed_arrays - # Store the original unpacking plan - self.original_unpacking_plan = plan - - if data_flow_analysis is None: - from pecos.slr.gen_codes.guppy.data_flow import DataFlowAnalyzer - - dfa = DataFlowAnalyzer() - data_flow_analysis = dfa.analyze(block, variable_context) - - # Now perform unified planning for each array - for array_name, access_info in array_access_info.items(): - plan = self._create_unified_plan( - array_name, - access_info, - data_flow_analysis, - ) - self.analysis.plans[array_name] = plan - - # Add global recommendations - self._add_global_recommendations() - - # Store the original unpacking plan in the analysis for get_unpacking_plan() - if self.original_unpacking_plan: - self.analysis.set_original_unpacking_plan(self.original_unpacking_plan) - - return self.analysis - - def _create_unified_plan( - self, - array_name: str, - access_info: ArrayAccessInfo, - data_flow: DataFlowAnalysis, - ) -> ResourcePlan: - """Create a unified resource plan for a single array. - - This is the core decision logic that coordinates unpacking and allocation. - """ - plan = ResourcePlan( - array_name=array_name, - size=access_info.size, - is_classical=access_info.is_classical, - strategy=ResourceStrategy.PACKED_PREALLOCATED, # Default - ) - - # Collect evidence from different analyses - self._collect_evidence(plan, access_info, data_flow) - - # Determine which elements can be allocated locally - self._determine_local_allocation(plan, access_info, data_flow) - - # Make coordinated decision based on all evidence - self._decide_strategy(plan, access_info, data_flow) - - return plan - - def _collect_evidence( - self, - plan: ResourcePlan, - access_info: ArrayAccessInfo, - data_flow: DataFlowAnalysis, - ) -> None: - """Collect evidence from all analyses.""" - evidence = plan.evidence - - # Evidence from array access patterns (counts for decisions) - evidence["has_individual_access"] = access_info.has_individual_access - evidence["all_elements_accessed"] = access_info.all_elements_accessed - evidence["has_full_array_access"] = bool(access_info.full_array_accesses) - evidence["elements_accessed"] = len(access_info.element_accesses) - evidence["elements_consumed"] = len(access_info.elements_consumed) - evidence["has_operations_between"] = access_info.has_operations_between - evidence["has_conditionals"] = access_info.has_conditionals_between - - # Copy element-level information for get_unpacking_plan() - evidence["element_accesses"] = access_info.element_accesses - evidence["elements_consumed_set"] = access_info.elements_consumed - - # Evidence from data flow analysis (element-level precision) - for (arr_name, idx), flow_info in data_flow.element_flows.items(): - if arr_name == plan.array_name and flow_info.has_use_after_consumption(): - plan.elements_requiring_replacement.add(idx) - - # Evidence from conditional tracking (element-level) - conditionally_accessed = set() - for arr_name, idx in data_flow.conditional_accesses: - if arr_name == plan.array_name: - conditionally_accessed.add(idx) - evidence["conditionally_accessed_elements"] = conditionally_accessed - - def _determine_local_allocation( - self, - plan: ResourcePlan, - access_info: ArrayAccessInfo, - _data_flow: DataFlowAnalysis, - ) -> None: - """Determine which elements can be allocated locally. - - Elements can be allocated locally if they are: - - Quantum qubits (classical arrays don't use local allocation) - - Consumed (measured) and not reused - - Not in conditional scopes or loops (single-scope usage) - """ - if plan.is_classical: - return # Classical arrays don't use local allocation - - # Find elements that are consumed and not reused - for idx in access_info.elements_consumed: - # Check if this element is reused after consumption - if idx in plan.elements_requiring_replacement: - continue # This element is reused, can't allocate locally - - # Check if used in conditionals (prevents local allocation) - if idx in access_info.conditionally_accessed_elements: - continue # Conditional usage prevents local allocation - - # This element is a good candidate for local allocation - plan.elements_to_allocate_locally.add(idx) - - def _decide_strategy( - self, - plan: ResourcePlan, - access_info: ArrayAccessInfo, - _data_flow: DataFlowAnalysis, - ) -> None: - """Make unified strategy decision based on collected evidence. - - Decision tree (in priority order): - 1. Check for REQUIRED unpacking (semantic necessity) - 2. Check for FORBIDDEN unpacking (would cause errors) - 3. Check for allocation optimization opportunities - 4. Make quality-based decisions - - Note: Local allocation candidates are already determined in - _determine_local_allocation() and stored in plan.elements_to_allocate_locally - """ - ev = plan.evidence - - # Rule 1: Full array operations FORBID unpacking - if ev["has_full_array_access"]: - plan.strategy = ResourceStrategy.PACKED_PREALLOCATED - plan.priority = DecisionPriority.FORBIDDEN - plan.reasons.append( - "Full array operations require packed representation", - ) - # Clear local allocation - packed arrays don't use it - plan.elements_to_allocate_locally.clear() - return - - # Rule 2: No individual access = no unpacking needed - if not ev["has_individual_access"]: - # Check if allocation optimizer suggests dynamic allocation - if plan.elements_to_allocate_locally: - plan.strategy = ResourceStrategy.PACKED_DYNAMIC - plan.priority = DecisionPriority.RECOMMENDED - plan.reasons.append("Dynamic allocation recommended by optimizer") - else: - plan.strategy = ResourceStrategy.PACKED_PREALLOCATED - plan.priority = DecisionPriority.OPTIONAL - plan.reasons.append("No individual element access detected") - # Clear local allocation - packed arrays don't use it - plan.elements_to_allocate_locally.clear() - return - - # Rule 3: Quantum arrays with operations after measurement REQUIRE unpacking - if not plan.is_classical and ev["has_operations_between"]: - # Check if we can use local allocation - if plan.elements_to_allocate_locally: - plan.strategy = ResourceStrategy.UNPACKED_MIXED - plan.elements_to_unpack = set(range(plan.size)) - # Local elements already determined in _determine_local_allocation() - plan.priority = DecisionPriority.REQUIRED - plan.reasons.append( - "Operations after measurement require unpacking (with local allocation)", - ) - else: - plan.strategy = ResourceStrategy.UNPACKED_PREALLOCATED - plan.elements_to_unpack = set(range(plan.size)) - plan.priority = DecisionPriority.REQUIRED - plan.reasons.append( - "Operations after measurement require unpacking", - ) - return - - # Rule 4: Individual quantum measurements REQUIRE unpacking - if not plan.is_classical and ev["elements_consumed"] > 0: - # Determine unpacking strategy based on allocation - if plan.elements_to_allocate_locally: - # Some elements can be allocated locally - plan.strategy = ResourceStrategy.UNPACKED_MIXED - plan.elements_to_unpack = set(range(plan.size)) - # Local elements already determined in _determine_local_allocation() - plan.priority = DecisionPriority.REQUIRED - plan.reasons.append( - f"Individual quantum measurements require unpacking " - f"({len(plan.elements_to_allocate_locally)} elements local)", - ) - else: - plan.strategy = ResourceStrategy.UNPACKED_PREALLOCATED - plan.elements_to_unpack = set(range(plan.size)) - plan.priority = DecisionPriority.REQUIRED - plan.reasons.append( - "Individual quantum measurements require unpacking", - ) - return - - # Rule 5: Conditional element access REQUIRES unpacking - conditional_elements = ev.get("conditionally_accessed_elements", set()) - if conditional_elements: - # Only unpack elements that are actually accessed (not just in conditionals) - elements_needing_unpack = conditional_elements & access_info.element_accesses - - if elements_needing_unpack: - # Check allocation strategy - if plan.elements_to_allocate_locally: - plan.strategy = ResourceStrategy.UNPACKED_MIXED - # Local elements already determined in _determine_local_allocation() - else: - plan.strategy = ResourceStrategy.UNPACKED_PREALLOCATED - - plan.elements_to_unpack = set(range(plan.size)) - plan.priority = DecisionPriority.REQUIRED - plan.reasons.append( - f"Conditional access to elements {sorted(elements_needing_unpack)} requires unpacking", - ) - return - - # Rule 6: Single element access - prefer direct indexing - if ev["elements_accessed"] == 1: - plan.strategy = ResourceStrategy.PACKED_PREALLOCATED - plan.priority = DecisionPriority.RECOMMENDED - plan.reasons.append( - "Single element access - direct indexing preferred", - ) - return - - # Rule 7: Classical arrays with multiple accesses benefit from unpacking - if plan.is_classical and ev["elements_accessed"] > 1: - plan.strategy = ResourceStrategy.UNPACKED_PREALLOCATED - plan.elements_to_unpack = set(range(plan.size)) - plan.priority = DecisionPriority.RECOMMENDED - plan.reasons.append( - f"Classical array with {ev['elements_accessed']} accesses - unpacking improves readability", - ) - return - - # Rule 9: Partial array usage - if ev["elements_accessed"] > 0 and not ev["all_elements_accessed"]: - access_ratio = ev["elements_accessed"] / plan.size - if access_ratio > 0.5: - plan.strategy = ResourceStrategy.UNPACKED_PREALLOCATED - plan.elements_to_unpack = set(range(plan.size)) - plan.priority = DecisionPriority.OPTIONAL - plan.reasons.append( - f"Partial array usage ({access_ratio:.0%}) - unpacking for clarity", - ) - return - - # Low access ratio - keep as array - plan.strategy = ResourceStrategy.PACKED_PREALLOCATED - plan.priority = DecisionPriority.OPTIONAL - plan.reasons.append( - f"Low access ratio ({access_ratio:.0%}) - keeping as array", - ) - return - - # Default: Keep as packed, pre-allocated (simplest approach) - plan.strategy = ResourceStrategy.PACKED_PREALLOCATED - plan.priority = DecisionPriority.OPTIONAL - plan.reasons.append("Default strategy - no strong evidence for alternatives") - - def _add_global_recommendations(self) -> None: - """Add global recommendations based on overall analysis.""" - if not self.analysis: - return - - # Count strategies - strategy_counts = {} - for plan in self.analysis.plans.values(): - strategy = plan.strategy - strategy_counts[strategy] = strategy_counts.get(strategy, 0) + 1 - - # Recommend patterns - total = len(self.analysis.plans) - if total == 0: - return - - unpacked_count = sum(1 for p in self.analysis.plans.values() if p.needs_unpacking) - dynamic_count = sum(1 for p in self.analysis.plans.values() if p.uses_dynamic_allocation) - - if unpacked_count > total * 0.7: - self.analysis.global_recommendations.append( - f"High unpacking ratio ({unpacked_count}/{total}) - " - "consider if element-level APIs would be more natural", - ) - - if dynamic_count > 0: - self.analysis.global_recommendations.append( - f"Dynamic allocation used for {dynamic_count}/{total} registers - ensure proper lifetime management", - ) - - # Check for potential conflicts - required_plans = [p for p in self.analysis.plans.values() if p.priority == DecisionPriority.REQUIRED] - if len(required_plans) == total and total > 1: - self.analysis.global_recommendations.append( - "All registers require unpacking - this may indicate complex control flow", - ) diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/unpacking_rules.py b/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/unpacking_rules.py deleted file mode 100644 index ec6984b81..000000000 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/unpacking_rules.py +++ /dev/null @@ -1,237 +0,0 @@ -"""Rule-based decision tree for array unpacking in Guppy code generation. - -This module provides a cleaner, more maintainable approach to deciding when arrays -need to be unpacked, replacing the complex heuristic logic with explicit rules. -""" - -from __future__ import annotations - -from dataclasses import dataclass -from enum import Enum, auto -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from pecos.slr.gen_codes.guppy.ir_analyzer import ArrayAccessInfo - - -class UnpackingReason(Enum): - """Enumeration of reasons why an array might need unpacking.""" - - # Required unpacking (semantic necessity) - INDIVIDUAL_QUANTUM_MEASUREMENT = auto() # Measuring individual qubits requires unpacking - OPERATIONS_AFTER_MEASUREMENT = auto() # Using qubits after measurement requires replacement - CONDITIONAL_ELEMENT_ACCESS = auto() # Accessing elements conditionally requires unpacking - - # Optional unpacking (code quality) - MULTIPLE_INDIVIDUAL_ACCESSES = auto() # Multiple element accesses cleaner when unpacked - PARTIAL_ARRAY_USAGE = auto() # Not all elements used together - - # No unpacking needed - FULL_ARRAY_ONLY = auto() # Only full array operations (e.g., measure_array) - SINGLE_ELEMENT_ONLY = auto() # Only one element accessed (use direct indexing) - NO_INDIVIDUAL_ACCESS = auto() # No individual element access - - -class UnpackingDecision(Enum): - """Decision outcome for array unpacking.""" - - MUST_UNPACK = auto() # Semantically required - SHOULD_UNPACK = auto() # Improves code quality - SHOULD_NOT_UNPACK = auto() # Better to keep as array - MUST_NOT_UNPACK = auto() # Would cause errors - - -@dataclass -class DecisionResult: - """Result of unpacking decision with reasoning.""" - - decision: UnpackingDecision - reason: UnpackingReason - explanation: str - - @property - def should_unpack(self) -> bool: - """Whether the array should be unpacked.""" - return self.decision in ( - UnpackingDecision.MUST_UNPACK, - UnpackingDecision.SHOULD_UNPACK, - ) - - -class UnpackingDecisionTree: - """Rule-based decision tree for determining if an array needs unpacking. - - This replaces the complex heuristic logic in ArrayAccessInfo.needs_unpacking - with an explicit, testable decision tree. - - Decision rules are applied in order of priority: - 1. Check for conditions that REQUIRE unpacking (semantic necessity) - 2. Check for conditions that FORBID unpacking (would cause errors) - 3. Check for conditions where unpacking IMPROVES code quality - 4. Default to not unpacking (prefer simpler code) - """ - - def decide(self, info: ArrayAccessInfo) -> DecisionResult: - """Determine if an array should be unpacked based on access patterns. - - Args: - info: Information about how the array is accessed - - Returns: - DecisionResult with the decision and reasoning - """ - # Rule 1: Full array operations forbid unpacking - if info.full_array_accesses: - return DecisionResult( - decision=UnpackingDecision.MUST_NOT_UNPACK, - reason=UnpackingReason.FULL_ARRAY_ONLY, - explanation=( - f"Array '{info.array_name}' has full-array operations " - f"(e.g., measure_array) at positions {info.full_array_accesses}. " - "Unpacking would prevent these operations." - ), - ) - - # Rule 2: No individual access means no unpacking needed - if not info.has_individual_access: - return DecisionResult( - decision=UnpackingDecision.SHOULD_NOT_UNPACK, - reason=UnpackingReason.NO_INDIVIDUAL_ACCESS, - explanation=(f"Array '{info.array_name}' has no individual element access. Keeping as array."), - ) - - # Rule 3: Operations after measurement REQUIRES unpacking (quantum arrays only) - # This is because measured qubits are consumed and need to be replaced - if not info.is_classical and info.has_operations_between: - return DecisionResult( - decision=UnpackingDecision.MUST_UNPACK, - reason=UnpackingReason.OPERATIONS_AFTER_MEASUREMENT, - explanation=( - f"Quantum array '{info.array_name}' has operations on qubits " - "after measurement. This requires unpacking to handle qubit " - "replacement correctly." - ), - ) - - # Rule 4: Individual quantum measurements REQUIRE unpacking - # This avoids MoveOutOfSubscriptError when measuring from array indices - if not info.is_classical and info.elements_consumed: - return DecisionResult( - decision=UnpackingDecision.MUST_UNPACK, - reason=UnpackingReason.INDIVIDUAL_QUANTUM_MEASUREMENT, - explanation=( - f"Quantum array '{info.array_name}' has individual element " - f"measurements (indices: {sorted(info.elements_consumed)}). " - "This requires unpacking to avoid MoveOutOfSubscriptError." - ), - ) - - # Rule 5: Conditional element access REQUIRES unpacking - # Elements accessed in conditionals need to be separate variables - # NEW: Use precise element-level tracking if available - if hasattr(info, "conditionally_accessed_elements") and info.conditionally_accessed_elements: - # Use precise tracking - only unpack if conditionally accessed elements - # are also individually accessed - conditional_and_accessed = info.conditionally_accessed_elements & info.element_accesses - if conditional_and_accessed: - return DecisionResult( - decision=UnpackingDecision.MUST_UNPACK, - reason=UnpackingReason.CONDITIONAL_ELEMENT_ACCESS, - explanation=( - f"Array '{info.array_name}' has elements " - f"{sorted(conditional_and_accessed)} accessed in conditional " - "blocks. This requires unpacking for proper control flow handling." - ), - ) - elif info.has_conditionals_between: - # Fallback to old heuristic if precise tracking not available - return DecisionResult( - decision=UnpackingDecision.MUST_UNPACK, - reason=UnpackingReason.CONDITIONAL_ELEMENT_ACCESS, - explanation=( - f"Array '{info.array_name}' has elements accessed in conditional " - "blocks. This requires unpacking for proper control flow handling." - ), - ) - - # Rule 6: Single element access should use direct indexing (no unpack) - # This avoids PlaceNotUsedError when unpacking all but using only one - if len(info.element_accesses) == 1: - return DecisionResult( - decision=UnpackingDecision.SHOULD_NOT_UNPACK, - reason=UnpackingReason.SINGLE_ELEMENT_ONLY, - explanation=( - f"Array '{info.array_name}' has only one element accessed " - f"(index {next(iter(info.element_accesses))}). " - "Using direct array indexing instead of unpacking." - ), - ) - - # Rule 7: Classical arrays with multiple individual accesses should unpack - # This produces cleaner code (e.g., c0, c1 instead of c[0], c[1]) - if info.is_classical and len(info.element_accesses) > 1: - return DecisionResult( - decision=UnpackingDecision.SHOULD_UNPACK, - reason=UnpackingReason.MULTIPLE_INDIVIDUAL_ACCESSES, - explanation=( - f"Classical array '{info.array_name}' has multiple individual " - f"element accesses ({len(info.element_accesses)} elements). " - "Unpacking produces cleaner code." - ), - ) - - # Rule 8: Partial array usage (not all elements accessed) - # If accessing most elements individually, unpacking may be clearer - if not info.all_elements_accessed and info.has_individual_access: - # Only unpack if accessing a significant portion (> 50%) - access_ratio = len(info.element_accesses) / info.size - if access_ratio > 0.5: - return DecisionResult( - decision=UnpackingDecision.SHOULD_UNPACK, - reason=UnpackingReason.PARTIAL_ARRAY_USAGE, - explanation=( - f"Array '{info.array_name}' has {len(info.element_accesses)} " - f"of {info.size} elements accessed individually " - f"({access_ratio:.0%}). Unpacking for clarity." - ), - ) - return DecisionResult( - decision=UnpackingDecision.SHOULD_NOT_UNPACK, - reason=UnpackingReason.PARTIAL_ARRAY_USAGE, - explanation=( - f"Array '{info.array_name}' has only {len(info.element_accesses)} " - f"of {info.size} elements accessed individually " - f"({access_ratio:.0%}). Keeping as array." - ), - ) - - # Default: Don't unpack (prefer simpler code) - return DecisionResult( - decision=UnpackingDecision.SHOULD_NOT_UNPACK, - reason=UnpackingReason.NO_INDIVIDUAL_ACCESS, - explanation=( - f"Array '{info.array_name}' does not meet criteria for unpacking. Keeping as array for simpler code." - ), - ) - - -def should_unpack_array(info: ArrayAccessInfo, *, verbose: bool = False) -> bool: - """Convenience function to determine if an array should be unpacked. - - Args: - info: Information about how the array is accessed - verbose: If True, print the decision reasoning - - Returns: - True if the array should be unpacked, False otherwise - """ - decision_tree = UnpackingDecisionTree() - result = decision_tree.decide(info) - - if verbose: - print(f"Array '{info.array_name}' unpacking decision:") - print(f" Decision: {result.decision.name}") - print(f" Reason: {result.reason.name}") - print(f" Explanation: {result.explanation}") - - return result.should_unpack diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/variable_state.py b/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/variable_state.py deleted file mode 100644 index f5805f6ce..000000000 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/guppy/variable_state.py +++ /dev/null @@ -1,154 +0,0 @@ -"""Unified variable-state tracking for the Guppy IR generator. - -The Guppy generator translates SLR programs (high-level quantum DSL) to -Guppy source. Guppy uses linear types: every qubit must be used exactly -once, and arrays-of-qubits get "moved" into and out of operations rather -than mutated in place. - -Translating SLR to Guppy means tracking, for each SLR variable, *what -Guppy variable currently holds it*. The form changes over the lifetime -of the SLR variable -- it might be a whole array, get unpacked into -element variables for individual access, get refreshed by a function -return, get partially consumed, etc. - -Historically the IRGuppyGenerator did this with ~6+ separate dicts -(`unpacked_vars`, `refreshed_arrays`, `array_remapping`, `index_mapping`, -`variable_remapping`, `function_var_remapping`, `replaced_qubits`, -`fresh_variables_to_track`, ...). Different code generation sites -consult different subsets of these dicts; sites that miss a state -transition emit Guppy that violates linearity ("AlreadyUsedError", -"WrongNumberOfUnpacksError", etc.). - -This module replaces that with one model: each SLR variable has a -*current binding* describing its Guppy form right now. Operations on the -variable consult the binding; transitions update it. Code-generation -sites that need the variable in a particular form call helpers like -`ensure_whole()` which emit reconstruction statements transparently. - -The migration is incremental. While the legacy dicts still exist, this -module shadows them: writes go to both, reads prefer this module. Once -all read sites are migrated, the legacy dicts can be removed. -""" - -from __future__ import annotations - -from dataclasses import dataclass, field - - -@dataclass(frozen=True) -class WholeArray: - """SLR variable is currently bound to a single Guppy array variable. - - `guppy_name` is the live identifier; subsequent ops can reference - `guppy_name` directly or index into it via `guppy_name[i]`. - """ - - guppy_name: str - - -@dataclass(frozen=True) -class UnpackedArray: - """SLR variable was unpacked into per-element Guppy variables. - - `element_names[i]` is the Guppy variable for original SLR index - `i` -- unless `index_mapping` is set, in which case mapping - `original_index -> position_in_element_names` is used (this happens - when a function call returned a partially-consumed array). - """ - - element_names: tuple[str, ...] - index_mapping: tuple[tuple[int, int], ...] = () # (orig_idx, position) - - def position_for(self, original_index: int) -> int | None: - """Return the position in `element_names` for an SLR index. - - With no `index_mapping`, returns `original_index` directly when in - bounds. With a mapping, looks up the position; returns None for - SLR indices that aren't present in the partial array. - """ - if not self.index_mapping: - return original_index if original_index < len(self.element_names) else None - for orig, pos in self.index_mapping: - if orig == original_index: - return pos - return None - - -@dataclass(frozen=True) -class Consumed: - """SLR variable is fully consumed; subsequent references are bugs. - - `reason` is a short human-readable note for diagnostics ("measured", - "passed to function as @owned", etc.). - """ - - reason: str = "" - - -Binding = WholeArray | UnpackedArray | Consumed - - -@dataclass -class VariableState: - """Current Guppy bindings for SLR variables in one generation context. - - A "context" is typically one Guppy function being generated -- the - main function or one of the extracted sub-block functions. Bindings - are local to a context; the same SLR variable name in different - contexts can have different bindings. - """ - - bindings: dict[str, Binding] = field(default_factory=dict) - - def bind_whole(self, slr_name: str, guppy_name: str) -> None: - """Record that `slr_name` is currently held by Guppy var `guppy_name`.""" - self.bindings[slr_name] = WholeArray(guppy_name) - - def bind_unpacked( - self, - slr_name: str, - element_names: list[str], - index_mapping: dict[int, int] | None = None, - ) -> None: - """Record that `slr_name` was unpacked into per-element Guppy vars.""" - mapping_tuple = tuple(sorted(index_mapping.items())) if index_mapping else () - self.bindings[slr_name] = UnpackedArray(tuple(element_names), mapping_tuple) - - def bind_consumed(self, slr_name: str, reason: str = "") -> None: - """Record that `slr_name` is no longer accessible.""" - self.bindings[slr_name] = Consumed(reason) - - def get(self, slr_name: str) -> Binding | None: - """Return current binding, or None if `slr_name` is unknown here.""" - return self.bindings.get(slr_name) - - def is_unpacked(self, slr_name: str) -> bool: - """True iff `slr_name` is currently in unpacked form.""" - return isinstance(self.bindings.get(slr_name), UnpackedArray) - - def is_consumed(self, slr_name: str) -> bool: - """True iff `slr_name` has been consumed.""" - return isinstance(self.bindings.get(slr_name), Consumed) - - def ensure_whole(self, slr_name: str) -> tuple[list[str], str | None]: - """Ensure `slr_name` is bound as a whole array; emit prep code if not. - - Returns (preparation_lines, guppy_name). The caller emits the - preparation_lines (Guppy source as `array(elem_0, elem_1, ...)` - repacking) before whatever it does with `guppy_name`. Returns - ([], guppy_name) when already whole. Returns ([], None) when - `slr_name` is consumed or unknown -- caller should treat as a - programming error. - - After repack, the binding is updated to WholeArray so subsequent - callers don't repack again. - """ - binding = self.bindings.get(slr_name) - if isinstance(binding, WholeArray): - return [], binding.guppy_name - if isinstance(binding, UnpackedArray): - elements = ", ".join(binding.element_names) - line = f"{slr_name} = array({elements})" - self.bindings[slr_name] = WholeArray(slr_name) - return [line], slr_name - return [], None diff --git a/python/quantum-pecos/src/pecos/slr/gen_codes/qir_gate_mapping.py b/python/quantum-pecos/src/pecos/slr/gen_codes/qir_gate_mapping.py index f0fbfd9c0..444686b14 100644 --- a/python/quantum-pecos/src/pecos/slr/gen_codes/qir_gate_mapping.py +++ b/python/quantum-pecos/src/pecos/slr/gen_codes/qir_gate_mapping.py @@ -15,6 +15,7 @@ from typing import TYPE_CHECKING import pecos as pc +from pecos.slr.angle import rad from pecos.slr.qeclib import qubit as q if TYPE_CHECKING: @@ -55,7 +56,7 @@ def __init__(self, gate: QG): R1XY = QG("u1q") RZZ = QG("rzz") SZZ = QG("zz") - Prep = QG("reset") + PZ = QG("reset") # These dagger/adjoint gates are generated with slightly different names Sdg = QG("s__adj") @@ -68,22 +69,22 @@ def __init__(self, gate: QG): SX = QG.decompose( lambda sx: [ - q.RX[pc.f64.frac_pi_2](sx.qargs[0]), + q.RX(rad(pc.f64.frac_pi_2), sx.qargs[0]), ], ) SXdg = QG.decompose( lambda sxdg: [ - q.RX[-pc.f64.frac_pi_2](sxdg.qargs[0]), + q.RX(rad(-pc.f64.frac_pi_2), sxdg.qargs[0]), ], ) SY = QG.decompose( lambda sy: [ - q.RY[pc.f64.frac_pi_2](sy.qargs[0]), + q.RY(rad(pc.f64.frac_pi_2), sy.qargs[0]), ], ) SYdg = QG.decompose( lambda sydg: [ - q.RY[-pc.f64.frac_pi_2](sydg.qargs[0]), + q.RY(rad(-pc.f64.frac_pi_2), sydg.qargs[0]), ], ) @@ -112,7 +113,7 @@ def __init__(self, gate: QG): lambda f4dg: [ q.SXdg(f4dg.qargs[0]), # q.SZdg(f4dg.qargs[0]), - q.RZ[-pc.f64.frac_pi_2](f4dg.qargs[0]), + q.RZ(rad(-pc.f64.frac_pi_2), f4dg.qargs[0]), ], ) diff --git a/python/quantum-pecos/src/pecos/slr/misc.py b/python/quantum-pecos/src/pecos/slr/misc.py index a1f40322b..46bf87a4e 100644 --- a/python/quantum-pecos/src/pecos/slr/misc.py +++ b/python/quantum-pecos/src/pecos/slr/misc.py @@ -10,13 +10,17 @@ # specific language governing permissions and limitations under the License. from __future__ import annotations +import re from typing import TYPE_CHECKING +from pecos.slr.block import Block +from pecos.slr.fund import Statement +from pecos.slr.vars import Bit, CReg, SymbolicBit + if TYPE_CHECKING: from pecos.slr.vars import Elem, QReg, Qubit, Reg -from pecos.slr.block import Block -from pecos.slr.fund import Statement +_TAG_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") class Barrier(Statement): @@ -89,3 +93,82 @@ def __init__(self, *return_vars) -> None: *return_vars: Variables to return, in order. Can be QReg, Qubit, Bit, or other variables. """ self.return_vars = return_vars + + +class Print(Statement): + """Emit an intermediate streamed value at the call site. + + Lowers to Guppy's ``result(name, value)``. Scope-orthogonal side-effect: + does not touch block ownership or compile-time return shape. + + The emitted Guppy tag is ``f"{namespace}.{tag}"``. Default namespace is + ``"result"``. If ``tag`` is not provided it is derived from ``value``'s + name (CReg name, or ``f"{reg}_{index}"`` for a Bit). + + Args: + value: A CReg or Bit (CReg element). Only these are supported. + Expression values (e.g. ``c[0] ^ c[1]``), SymbolicBit, and other + types are rejected at construction time. + tag: Explicit tag string overriding the derived name. Must match + ``[A-Za-z_][A-Za-z0-9_]*`` (Python identifier rules). + namespace: Tag prefix. Default ``"result"``. Must match + ``[A-Za-z_][A-Za-z0-9_]*``. + + Example: + Main( + c := CReg("c", 2), + ..., + Print(c), # tag "result.c" + Print(c[0], tag="first"), # tag "result.first" + Print(c, namespace="debug"), # tag "debug.c" + ) + """ + + def __init__(self, value, *, tag: str | None = None, namespace: str = "result") -> None: + """Construction-time validation per `v2-print.md`. + + Validates value type, tag/namespace character rules, and derives the + default tag from the value's name. AST/Guppy-level checks (path- + signature consistency in If/Elif, inline-CReg definite-assignment) + run later during emission. + """ + if isinstance(value, SymbolicBit): + msg = "Print does not support SymbolicBit (LoopVar-indexed) values." + raise TypeError(msg) + if not isinstance(value, (CReg, Bit)): + msg = ( + f"Print(value, ...) requires a CReg or Bit value; got {type(value).__name__}. " + "Expression values (e.g. c[0] ^ c[1]) are deferred and must be passed with explicit tag=...; " + "Only CReg and Bit values are supported." + ) + raise TypeError(msg) + + if not _TAG_RE.match(namespace): + msg = ( + f"Print namespace {namespace!r} must match [A-Za-z_][A-Za-z0-9_]* " + "(Python identifier rules). The dot is reserved as the namespace-tag separator." + ) + raise ValueError(msg) + + if tag is None: + tag = self._derive_tag(value) + if not _TAG_RE.match(tag): + msg = ( + f"Print tag {tag!r} must match [A-Za-z_][A-Za-z0-9_]* " + "(Python identifier rules). The dot is reserved as the namespace-tag separator. " + "Tags derived from non-identifier register names are rejected; pass tag=... explicitly." + ) + raise ValueError(msg) + + self.value = value + self.tag = tag + self.namespace = namespace + + @staticmethod + def _derive_tag(value: CReg | Bit) -> str: + if isinstance(value, CReg): + return value.sym + if isinstance(value, Bit): + return f"{value.reg.sym}_{value.index}" + msg = f"Cannot derive Print tag from {type(value).__name__}" + raise TypeError(msg) diff --git a/python/quantum-pecos/src/pecos/slr/qeclib/color488/color488_class.py b/python/quantum-pecos/src/pecos/slr/qeclib/color488/color488_class.py index b89ffe09b..4086bc7fb 100644 --- a/python/quantum-pecos/src/pecos/slr/qeclib/color488/color488_class.py +++ b/python/quantum-pecos/src/pecos/slr/qeclib/color488/color488_class.py @@ -102,7 +102,7 @@ def prep_z_bare(self, syndromes: list[CReg]) -> Block: """ block = Block() block.extend( - qb.Prep(self.d), + qb.PZ(self.d), ) for s in syndromes: block.extend( diff --git a/python/quantum-pecos/src/pecos/slr/qeclib/color488/syn_extract/bare.py b/python/quantum-pecos/src/pecos/slr/qeclib/color488/syn_extract/bare.py index 35f9ecce7..ef3ebe035 100644 --- a/python/quantum-pecos/src/pecos/slr/qeclib/color488/syn_extract/bare.py +++ b/python/quantum-pecos/src/pecos/slr/qeclib/color488/syn_extract/bare.py @@ -1,9 +1,9 @@ """Bare syndrome extraction implementations for Color488 codes.""" +import math from itertools import chain, cycle, repeat from typing import Any -import pecos as pc from pecos.slr import Block, Comment, CReg, Parallel, QReg from pecos.slr.qeclib.generic.check import Check @@ -89,7 +89,7 @@ def __init__(self, data: QReg, ancillas: QReg, checks: list, syn: CReg) -> None: super().__init__() annotations = Block() - num_parallel_blocks = 2 * pc.ceil(len(checks) / len(ancillas)) + num_parallel_blocks = 2 * math.ceil(len(checks) / len(ancillas)) par_blocks = [Parallel() for _ in range(num_parallel_blocks)] # iterator for parallelizing circuits for one round of ancilla use diff --git a/python/quantum-pecos/src/pecos/slr/qeclib/generic/check.py b/python/quantum-pecos/src/pecos/slr/qeclib/generic/check.py index 6c747f6ba..762abf74c 100644 --- a/python/quantum-pecos/src/pecos/slr/qeclib/generic/check.py +++ b/python/quantum-pecos/src/pecos/slr/qeclib/generic/check.py @@ -18,10 +18,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar from pecos.slr import Barrier, Block, Comment -from pecos.slr.qeclib.qubit import CX, CY, CZ, H, Measure, Prep +from pecos.slr.qeclib.qubit import CX, CY, CZ, PZ, H, Measure if TYPE_CHECKING: from pecos.slr import Bit, Qubit @@ -35,6 +35,17 @@ class Check(Block): a sequence of Pauli operators to data qubits controlled by an ancilla qubit. """ + # Scratch-ancilla effect: emit BlockDecl + BlockCall instead + # of inlining. Data qubits pass through unchanged; `a` is a reset-reused + # scratch ancilla (prepped + measured inside, allocated internally in + # Guppy so callers like SynExtractBare can reuse one physical slot across + # sequential Checks); `out` is the live_preserved measurement-result bit. + block_inputs: ClassVar[dict[str, str]] = { + "d": "live_preserved", + "a": "scratch", + "out": "live_preserved", + } + def __init__( self, d: list[Qubit], @@ -61,6 +72,11 @@ def __init__( Exception: If invalid Pauli operator is specified. """ super().__init__() + # SLR -> AST converter reads these to bind block_inputs param names + # to outer-scope refs. + self.d = d + self.a = a + self.out = out n: int = len(d) @@ -83,7 +99,7 @@ def __init__( self.extend( Comment(f"Measure check {ps}"), - Prep(a), + PZ(a), H(a), ) diff --git a/python/quantum-pecos/src/pecos/slr/qeclib/generic/check_1flag.py b/python/quantum-pecos/src/pecos/slr/qeclib/generic/check_1flag.py index 1118c68cc..eb0bbb87f 100644 --- a/python/quantum-pecos/src/pecos/slr/qeclib/generic/check_1flag.py +++ b/python/quantum-pecos/src/pecos/slr/qeclib/generic/check_1flag.py @@ -18,10 +18,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar from pecos.slr import Barrier, Block, Comment -from pecos.slr.qeclib.qubit import CH, CX, CY, CZ, H, Measure, Prep +from pecos.slr.qeclib.qubit import CH, CX, CY, CZ, PZ, H, Measure if TYPE_CHECKING: from pecos.slr import Bit, Qubit @@ -35,6 +35,22 @@ class Check1Flag(Block): to detect errors during the syndrome extraction process. """ + # Scratch-ancilla effect: emit BlockDecl + BlockCall. + # `a` and `flag` are both reset-reused scratch ancillas (prepped + + # measured inside, allocated internally in Guppy so callers can reuse + # physical slots across sequential checks); data passes through; the two + # out bits are live_preserved measurement-result write-backs. The body's + # single `PZ(a, flag)` is split into `PZ(a), PZ(flag)` (option + # (a) -- byte-identical in QASM/Guppy/Selene, confirmed 2026-05-16; + # one PrepareOp per scratch input avoids multi-destination substitution). + block_inputs: ClassVar[dict[str, str]] = { + "d": "live_preserved", + "a": "scratch", + "flag": "scratch", + "out": "live_preserved", + "out_flag": "live_preserved", + } + def __init__( self, d: list[Qubit], @@ -65,6 +81,13 @@ def __init__( Exception: If invalid operator is specified. """ super().__init__() + # SLR -> AST converter reads these to bind block_inputs param names + # to outer-scope refs. + self.d = d + self.a = a + self.flag = flag + self.out = out + self.out_flag = out_flag n: int = len(d) @@ -85,7 +108,8 @@ def __init__( self.extend( Comment(f"Measure check {ops}"), - Prep(a, flag), + PZ(a), + PZ(flag), H(a), ) if with_barriers: diff --git a/python/quantum-pecos/src/pecos/slr/qeclib/qubit/__init__.py b/python/quantum-pecos/src/pecos/slr/qeclib/qubit/__init__.py index 5c5f4343d..718082481 100644 --- a/python/quantum-pecos/src/pecos/slr/qeclib/qubit/__init__.py +++ b/python/quantum-pecos/src/pecos/slr/qeclib/qubit/__init__.py @@ -15,9 +15,10 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. +from pecos.slr.angle import Angle, rad, turns from pecos.slr.qeclib.qubit.measures import Measure -from pecos.slr.qeclib.qubit.preps import Prep -from pecos.slr.qeclib.qubit.rots import CRZ, RX, RY, RZ, RZZ +from pecos.slr.qeclib.qubit.preps import PNX, PNY, PNZ, PX, PY, PZ +from pecos.slr.qeclib.qubit.rots import CRX, CRY, CRZ, RX, RY, RZ, RZZ from pecos.slr.qeclib.qubit.sq_face_rots import F4, F, F4dg, Fdg from pecos.slr.qeclib.qubit.sq_hadamards import H from pecos.slr.qeclib.qubit.sq_noncliffords import T, Tdg @@ -38,11 +39,19 @@ __all__ = [ "CH", + "CRX", + "CRY", "CRZ", "CX", "CY", "CZ", "F4", + "PNX", + "PNY", + "PNZ", + "PX", + "PY", + "PZ", "RX", "RY", "RZ", @@ -53,12 +62,12 @@ "SYY", "SZ", "SZZ", + "Angle", "F", "F4dg", "Fdg", "H", "Measure", - "Prep", "SXXdg", "SXdg", "SYYdg", @@ -70,4 +79,6 @@ "X", "Y", "Z", + "rad", + "turns", ] diff --git a/python/quantum-pecos/src/pecos/slr/qeclib/qubit/preps.py b/python/quantum-pecos/src/pecos/slr/qeclib/qubit/preps.py index 614173775..ac42dcbf9 100644 --- a/python/quantum-pecos/src/pecos/slr/qeclib/qubit/preps.py +++ b/python/quantum-pecos/src/pecos/slr/qeclib/qubit/preps.py @@ -19,5 +19,25 @@ from pecos.slr.qeclib.qubit.qgate_base import QGate -class Prep(QGate): - """Preparing/resetting a qubit to the zero state.""" +class PZ(QGate): + """Prepare/reset a qubit to |0> (+Z eigenstate).""" + + +class PNZ(QGate): + """Prepare/reset a qubit to |1> (-Z eigenstate).""" + + +class PX(QGate): + """Prepare/reset a qubit to |+> (+X eigenstate).""" + + +class PNX(QGate): + """Prepare/reset a qubit to |-> (-X eigenstate).""" + + +class PY(QGate): + """Prepare/reset a qubit to |+i> (+Y eigenstate).""" + + +class PNY(QGate): + """Prepare/reset a qubit to |-i> (-Y eigenstate).""" diff --git a/python/quantum-pecos/src/pecos/slr/qeclib/qubit/qgate_base.py b/python/quantum-pecos/src/pecos/slr/qeclib/qubit/qgate_base.py index fdec19921..aedf5969d 100644 --- a/python/quantum-pecos/src/pecos/slr/qeclib/qubit/qgate_base.py +++ b/python/quantum-pecos/src/pecos/slr/qeclib/qubit/qgate_base.py @@ -48,6 +48,11 @@ class QGate: qsize = 1 csize = 0 has_parameters = False + # Number of leading angle parameters for a parameterized gate. The + # SLR call convention is angle(s)-FIRST: `RX(theta, q)`, + # `CRZ(theta, control, target)`, `RZZ(theta, q0, q1)`. A + # parameterized gate must set `num_params` (and `has_parameters`). + num_params = 0 def __init__(self, *qargs: Qubit) -> None: """Initialize a quantum gate. @@ -84,15 +89,19 @@ def copy(self) -> Self: return copy.copy(self) def __getitem__(self, *params: complex) -> Self: - """Set gate parameters using square bracket notation.""" - g = self.copy() - - if params and not self.has_parameters: - msg = "This gate does not accept parameters. You might of meant to put qubits in square brackets." - raise Exception(msg) - g.params = params + """Reject the legacy bracket-parameter form. - return g + The SLR API now takes rotation angles as leading positional + arguments -- ``RX(theta, q)`` -- not via brackets. The old + ``RX[theta](q)`` form is removed (angles-first is the single + supported convention); raise a clear migration error. + """ + msg = ( + f"The bracket-parameter form `{self.sym}[angle](qubit)` is no longer " + f"supported. Pass the angle as a leading positional argument instead: " + f"`{self.sym}(angle, qubit)` (angles come before qubit ids)." + ) + raise TypeError(msg) def qubits(self, *qargs: Qubit) -> None: """Add qubits to the gate. @@ -102,18 +111,100 @@ def qubits(self, *qargs: Qubit) -> None: """ self(*qargs) - def __call__(self, *qargs: Qubit) -> Self: - """Create a new gate instance with specified qubits. + def __call__(self, *args: Qubit | complex) -> Self: + """Create a new gate instance from angle(s) and qubit(s). + + For a parameterized gate the first `num_params` arguments are + the rotation angle(s); the remaining arguments are the qubits: + `RX(theta, q)`, `CRZ(theta, control, target)`. For a + non-parameterized gate every argument is a qubit. Args: - *qargs: Variable number of qubits to apply the gate to. + *args: `num_params` leading angle parameter(s) (if any) + followed by the qubit(s) the gate acts on. Returns: - New gate instance with the specified qubits. + New gate instance with the specified params + qubits. """ g = self.copy() - g.add_qargs(qargs) + if self.has_parameters: + n = self.num_params + if len(args) < n: + msg = ( + f"{self.sym} is a parameterized gate; call it as " + f"`{self.sym}(angle, qubit...)` with {n} leading angle " + f"parameter(s) before the qubit(s). Got {len(args)} argument(s)." + ) + raise TypeError(msg) + params = tuple(args[:n]) + qargs = tuple(args[n:]) + # Typed-angle guard: each angle slot must be a typed `Angle` + # (built with `rad(...)` / `turns(...)`), and each qubit slot + # must be a quantum qubit shape. This rejects the classic + # mis-ordered call (`RX(q, 0.5)` instead of `RX(rad(0.5), q)`) + # AND the now-removed bare-float form (`RX(0.5, q)`) loudly at + # the call, so a typo can never reach codegen as a no-op or as + # a rotation on a classical register. Qubit slots accept ONLY + # `Qubit`/`QReg`/`SymbolicQubit` -- NOT the broad `Var` (which + # also covers classical `CReg`/`Bit`/`SymbolicBit`). + from pecos.slr.angle import Angle # noqa: PLC0415 (avoid import cycle) + from pecos.slr.vars import QReg, Qubit, SymbolicQubit, Var # noqa: PLC0415 (avoid import cycle) + + qubit_types = (Qubit, QReg, SymbolicQubit) + for p in params: + if isinstance(p, Angle): + continue + if isinstance(p, Var): + msg = ( + f"{self.sym}: a register/qubit reference {p!r} was passed in an angle " + f"position. Call as `{self.sym}(angle, qubit...)` -- angles come before qubit " + "ids, and the angle must be a typed `Angle` (use `rad(...)` / `turns(...)`)." + ) + raise TypeError(msg) + if isinstance(p, (bool, int, float, complex)): + msg = ( + f"{self.sym}: bare numeric angle {p!r} is no longer accepted. Wrap it in a " + f"typed `Angle`: `{self.sym}(rad({p}), qubit...)` (radians) or " + f"`{self.sym}(turns(...), qubit...)`." + ) + raise TypeError(msg) + msg = ( + f"{self.sym}: angle parameter {p!r} must be a typed `Angle` " + "built with `rad(...)` / `turns(...)`." + ) + raise TypeError(msg) + for qa in qargs: + if not isinstance(qa, qubit_types): + kind = "classical register/bit" if isinstance(qa, Var) else "non-qubit" + msg = ( + f"{self.sym}: a {kind} {qa!r} was passed in a qubit position. " + f"Call as `{self.sym}(angle, qubit...)` with {n} leading angle parameter(s); " + "qubit positions accept only qubits/QRegs." + ) + raise TypeError(msg) + # Construction-time arity guard: a parameterized call must + # supply enough qubits, so a malformed `gate(angle)` (no + # qubit) or `RZZ(angle, q[0])` (one short) fails loud here + # rather than surviving to QIR/QASM. A whole `QReg` broadcasts + # (its size is only known on expansion), so the explicit-qubit + # count is only checked when no register is passed. + if not qargs: + msg = ( + f"{self.sym}: a parameterized gate needs at least one qubit; got only angle(s). " + f"Call as `{self.sym}(angle, qubit...)`." + ) + raise TypeError(msg) + if not any(isinstance(qa, QReg) for qa in qargs) and len(qargs) < self.qsize: + msg = ( + f"{self.sym}: needs at least {self.qsize} qubit(s) for a {self.qsize}-qubit gate, " + f"got {len(qargs)}. (Pass a whole QReg to broadcast.)" + ) + raise TypeError(msg) + g.params = params + g.add_qargs(qargs) + else: + g.add_qargs(args) return g diff --git a/python/quantum-pecos/src/pecos/slr/qeclib/qubit/qubit.py b/python/quantum-pecos/src/pecos/slr/qeclib/qubit/qubit.py index 75faad679..a0e9c0240 100644 --- a/python/quantum-pecos/src/pecos/slr/qeclib/qubit/qubit.py +++ b/python/quantum-pecos/src/pecos/slr/qeclib/qubit/qubit.py @@ -63,9 +63,9 @@ def cz(*qargs: Qubit) -> tq_cliffords.CZ: return tq_cliffords.CZ(*qargs) @staticmethod - def pz(*qargs: Qubit) -> preps.Prep: - """Measurement gate.""" - return preps.Prep(*qargs) + def pz(*qargs: Qubit) -> preps.PZ: + """Prepare/reset to |0> (Z basis).""" + return preps.PZ(*qargs) @staticmethod def mz( diff --git a/python/quantum-pecos/src/pecos/slr/qeclib/qubit/rots.py b/python/quantum-pecos/src/pecos/slr/qeclib/qubit/rots.py index 96d70a939..c57a565f0 100644 --- a/python/quantum-pecos/src/pecos/slr/qeclib/qubit/rots.py +++ b/python/quantum-pecos/src/pecos/slr/qeclib/qubit/rots.py @@ -26,6 +26,7 @@ class RXGate(QGate): """ has_parameters = True + num_params = 1 RX = RXGate() @@ -38,6 +39,7 @@ class RYGate(QGate): """ has_parameters = True + num_params = 1 RY = RYGate() @@ -50,6 +52,7 @@ class RZGate(QGate): """ has_parameters = True + num_params = 1 RZ = RZGate() @@ -62,11 +65,40 @@ class RZZGate(TQGate): """ has_parameters = True + num_params = 1 RZZ = RZZGate() +class CRXGate(TQGate): + """Controlled-RX gate. + + This gate applies an RX rotation to the target qubit controlled by + the control qubit. The rotation angle is specified as a parameter. + """ + + has_parameters = True + num_params = 1 + + +CRX = CRXGate() + + +class CRYGate(TQGate): + """Controlled-RY gate. + + This gate applies an RY rotation to the target qubit controlled by + the control qubit. The rotation angle is specified as a parameter. + """ + + has_parameters = True + num_params = 1 + + +CRY = CRYGate() + + class CRZGate(TQGate): """Controlled-RZ gate. @@ -75,6 +107,7 @@ class CRZGate(TQGate): """ has_parameters = True + num_params = 1 CRZ = CRZGate() diff --git a/python/quantum-pecos/src/pecos/slr/qeclib/steane/gates_sq/hadamards.py b/python/quantum-pecos/src/pecos/slr/qeclib/steane/gates_sq/hadamards.py index aad20599d..5a05ca30e 100644 --- a/python/quantum-pecos/src/pecos/slr/qeclib/steane/gates_sq/hadamards.py +++ b/python/quantum-pecos/src/pecos/slr/qeclib/steane/gates_sq/hadamards.py @@ -15,6 +15,8 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. +from typing import ClassVar + from pecos.slr import Block, Comment, QReg from pecos.slr.qeclib import qubit @@ -28,6 +30,8 @@ class H(Block): Y -> -Y """ + block_inputs: ClassVar[dict[str, str]] = {"q": "live_preserved"} + def __init__(self, q: QReg) -> None: """Initialize a logical Hadamard gate on the Steane code. @@ -42,7 +46,9 @@ def __init__(self, q: QReg) -> None: msg = f"Size of register {len(q.elems)} != 7" raise Exception(msg) - super().__init__( + super().__init__() + self.q = q + self.extend( Comment("Logical H"), qubit.H(q), ) diff --git a/python/quantum-pecos/src/pecos/slr/qeclib/steane/gates_sq/paulis.py b/python/quantum-pecos/src/pecos/slr/qeclib/steane/gates_sq/paulis.py index b34ac9591..4ff48c79c 100644 --- a/python/quantum-pecos/src/pecos/slr/qeclib/steane/gates_sq/paulis.py +++ b/python/quantum-pecos/src/pecos/slr/qeclib/steane/gates_sq/paulis.py @@ -15,6 +15,8 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. +from typing import ClassVar + from pecos.slr import Block, Comment, QReg from pecos.slr.qeclib import qubit @@ -28,6 +30,8 @@ class X(Block): Y -> -Y """ + block_inputs: ClassVar[dict[str, str]] = {"q": "live_preserved"} + def __init__(self, q: QReg) -> None: """Initialize a logical Pauli X gate on the Steane code. @@ -42,7 +46,9 @@ def __init__(self, q: QReg) -> None: msg = f"Size of register {len(q.elems)} != 7" raise Exception(msg) - super().__init__( + super().__init__() + self.q = q + self.extend( Comment("Logical X"), qubit.X(q[4]), qubit.X(q[5]), @@ -59,6 +65,8 @@ class Y(Block): Y -> Y """ + block_inputs: ClassVar[dict[str, str]] = {"q": "live_preserved"} + def __init__(self, q: QReg) -> None: """Initialize a logical Pauli Y gate on the Steane code. @@ -73,7 +81,9 @@ def __init__(self, q: QReg) -> None: msg = f"Size of register {len(q.elems)} != 7" raise Exception(msg) - super().__init__( + super().__init__() + self.q = q + self.extend( Comment("Logical Y"), qubit.Y(q[4]), qubit.Y(q[5]), @@ -90,6 +100,8 @@ class Z(Block): Y -> -Y """ + block_inputs: ClassVar[dict[str, str]] = {"q": "live_preserved"} + def __init__(self, q: QReg) -> None: """Initialize a logical Pauli Z gate on the Steane code. @@ -104,7 +116,9 @@ def __init__(self, q: QReg) -> None: msg = f"Size of register {len(q.elems)} != 7" raise Exception(msg) - super().__init__( + super().__init__() + self.q = q + self.extend( Comment("Logical Z"), qubit.Z(q[4]), qubit.Z(q[5]), diff --git a/python/quantum-pecos/src/pecos/slr/qeclib/steane/gates_tq/transversal_tq.py b/python/quantum-pecos/src/pecos/slr/qeclib/steane/gates_tq/transversal_tq.py index 6b1ed60b7..820b863ea 100644 --- a/python/quantum-pecos/src/pecos/slr/qeclib/steane/gates_tq/transversal_tq.py +++ b/python/quantum-pecos/src/pecos/slr/qeclib/steane/gates_tq/transversal_tq.py @@ -15,6 +15,8 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. +from typing import ClassVar + from pecos.slr import Barrier, Block, Comment, QReg from pecos.slr.qeclib import qubit @@ -26,6 +28,14 @@ class CX(Block): two logical qubits encoded in the Steane code. """ + # Declare resource effects so SLR -> AST emits BlockDecl + BlockCall + # instead of inlining. Both registers are live_preserved -- the body applies a + # transversal CX pairwise (no internal measurements). + block_inputs: ClassVar[dict[str, str]] = { + "q1": "live_preserved", + "q2": "live_preserved", + } + def __init__(self, q1: QReg, q2: QReg, *, barrier: bool = True) -> None: """Initialize a transversal logical CX gate on two Steane code logical qubits. @@ -47,6 +57,9 @@ def __init__(self, q1: QReg, q2: QReg, *, barrier: bool = True) -> None: raise Exception(msg) super().__init__() + # SLR -> AST converter reads these to bind block_inputs param names to outer-scope refs. + self.q1 = q1 + self.q2 = q2 self.extend( Comment("Transversal Logical CX"), ) @@ -78,6 +91,11 @@ class CY(Block): two logical qubits encoded in the Steane code. """ + block_inputs: ClassVar[dict[str, str]] = { + "q1": "live_preserved", + "q2": "live_preserved", + } + def __init__(self, q1: QReg, q2: QReg) -> None: """Initialize a transversal logical CY gate on two Steane code logical qubits. @@ -110,6 +128,8 @@ def __init__(self, q1: QReg, q2: QReg) -> None: ), Barrier(q1, q2), ) + self.q1 = q1 + self.q2 = q2 class CZ(Block): @@ -119,6 +139,11 @@ class CZ(Block): two logical qubits encoded in the Steane code. """ + block_inputs: ClassVar[dict[str, str]] = { + "q1": "live_preserved", + "q2": "live_preserved", + } + def __init__(self, q1: QReg, q2: QReg) -> None: """Initialize a transversal logical CZ gate on two Steane code logical qubits. @@ -151,6 +176,8 @@ def __init__(self, q1: QReg, q2: QReg) -> None: ), Barrier(q1, q2), ) + self.q1 = q1 + self.q2 = q2 class SZZ(Block): diff --git a/python/quantum-pecos/src/pecos/slr/qeclib/steane/meas/measure_x.py b/python/quantum-pecos/src/pecos/slr/qeclib/steane/meas/measure_x.py index 58e5013b4..3b0b75993 100644 --- a/python/quantum-pecos/src/pecos/slr/qeclib/steane/meas/measure_x.py +++ b/python/quantum-pecos/src/pecos/slr/qeclib/steane/meas/measure_x.py @@ -38,7 +38,7 @@ def __init__(self, d: list[Qubit], a: QReg, out: CReg) -> None: self.extend( Comment("Measure logical X with no flagging"), - qubit.Prep(a[0]), + qubit.PZ(a[0]), qubit.H(a[0]), qubit.CX( (d[0], a[0]), diff --git a/python/quantum-pecos/src/pecos/slr/qeclib/steane/meas/measure_z.py b/python/quantum-pecos/src/pecos/slr/qeclib/steane/meas/measure_z.py index 95de90d03..40e48d038 100644 --- a/python/quantum-pecos/src/pecos/slr/qeclib/steane/meas/measure_z.py +++ b/python/quantum-pecos/src/pecos/slr/qeclib/steane/meas/measure_z.py @@ -38,7 +38,7 @@ def __init__(self, d: list[Qubit], a: QReg, out: CReg) -> None: self.extend( Comment("Measure logical Z with no flagging"), - qubit.Prep(a[0]), + qubit.PZ(a[0]), qubit.H(a[0]), qubit.CZ( (d[0], a[0]), diff --git a/python/quantum-pecos/src/pecos/slr/qeclib/steane/preps/encoding_circ.py b/python/quantum-pecos/src/pecos/slr/qeclib/steane/preps/encoding_circ.py index e7e222212..68d866c00 100644 --- a/python/quantum-pecos/src/pecos/slr/qeclib/steane/preps/encoding_circ.py +++ b/python/quantum-pecos/src/pecos/slr/qeclib/steane/preps/encoding_circ.py @@ -46,7 +46,7 @@ def __init__(self, q: QReg) -> None: self.extend( Comment("\nEncoding circuit"), Comment("---------------"), - qubit.Prep( + qubit.PZ( q[0], q[1], q[2], diff --git a/python/quantum-pecos/src/pecos/slr/qeclib/steane/preps/pauli_states.py b/python/quantum-pecos/src/pecos/slr/qeclib/steane/preps/pauli_states.py index af0e34ecf..7f3774516 100644 --- a/python/quantum-pecos/src/pecos/slr/qeclib/steane/preps/pauli_states.py +++ b/python/quantum-pecos/src/pecos/slr/qeclib/steane/preps/pauli_states.py @@ -18,7 +18,7 @@ from pecos.slr import Barrier, Bit, Block, Comment, If, QReg, Qubit, Repeat from pecos.slr.misc import Return from pecos.slr.qeclib import qubit -from pecos.slr.qeclib.qubit import Prep +from pecos.slr.qeclib.qubit import PZ from pecos.slr.qeclib.steane.gates_sq import sqrt_paulis from pecos.slr.qeclib.steane.gates_sq.hadamards import H from pecos.slr.qeclib.steane.gates_sq.paulis import X, Z @@ -102,7 +102,7 @@ def __init__( if reset_ancilla: self.extend( Comment(), - Prep(a), + PZ(a), ) self.extend( @@ -165,8 +165,8 @@ def __init__( if reset: self.extend( - Prep(q), - Prep(a), + PZ(q), + PZ(a), Barrier(q, a), ) diff --git a/python/quantum-pecos/src/pecos/slr/qeclib/steane/preps/plus_h_state.py b/python/quantum-pecos/src/pecos/slr/qeclib/steane/preps/plus_h_state.py index db3b946a4..18d3aaaa9 100644 --- a/python/quantum-pecos/src/pecos/slr/qeclib/steane/preps/plus_h_state.py +++ b/python/quantum-pecos/src/pecos/slr/qeclib/steane/preps/plus_h_state.py @@ -17,6 +17,7 @@ import pecos as pc from pecos.slr import Bit, Block, Comment, CReg, If, QReg, Repeat +from pecos.slr.angle import rad from pecos.slr.misc import Return from pecos.slr.qeclib import qubit from pecos.slr.qeclib.generic.check_1flag import Check1Flag @@ -81,8 +82,8 @@ def __init__( # non-fault-tolerantly encode logical |+H> # ---------------------------------------- self.extend( - qubit.Prep(d[6]), - qubit.RY[pc.f64.frac_pi_4](d[6]), + qubit.PZ(d[6]), + qubit.RY(rad(pc.f64.frac_pi_4), d[6]), EncodingCircuit(d), ) diff --git a/python/quantum-pecos/src/pecos/slr/qeclib/steane/preps/t_plus_state.py b/python/quantum-pecos/src/pecos/slr/qeclib/steane/preps/t_plus_state.py index c6061319c..7e8e1a18b 100644 --- a/python/quantum-pecos/src/pecos/slr/qeclib/steane/preps/t_plus_state.py +++ b/python/quantum-pecos/src/pecos/slr/qeclib/steane/preps/t_plus_state.py @@ -42,7 +42,7 @@ def __init__(self, q: QReg) -> None: """ super().__init__( Comment("Initialize logical |T> = T|+>\n============================="), - qubit.Prep(q[6]), + qubit.PZ(q[6]), qubit.H(q[6]), qubit.T(q[6]), EncodingCircuit(q), @@ -69,7 +69,7 @@ def __init__(self, q: QReg) -> None: """ super().__init__( Comment("Initialize logical |T> = T|+>\n============================="), - qubit.Prep(q[6]), + qubit.PZ(q[6]), qubit.H(q[6]), qubit.Tdg(q[6]), EncodingCircuit(q), diff --git a/python/quantum-pecos/src/pecos/slr/qeclib/steane/syn_extract/six_check_nonflagging.py b/python/quantum-pecos/src/pecos/slr/qeclib/steane/syn_extract/six_check_nonflagging.py index 3ccc9831f..a245006cc 100644 --- a/python/quantum-pecos/src/pecos/slr/qeclib/steane/syn_extract/six_check_nonflagging.py +++ b/python/quantum-pecos/src/pecos/slr/qeclib/steane/syn_extract/six_check_nonflagging.py @@ -50,7 +50,7 @@ def __init__(self, data: QReg, ancillas: QReg, syn_x: CReg, syn_z: CReg) -> None syn_x.set(0), syn_z.set(0), Comment(), - gq.Prep( + gq.PZ( q[0], q[8], q[9], @@ -102,7 +102,7 @@ def __init__(self, data: QReg, ancillas: QReg, syn_x: CReg, syn_z: CReg) -> None Comment(), Comment("// Z check 1, X check 2, X check 3"), Comment(), - gq.Prep( + gq.PZ( q[0], q[8], q[9], diff --git a/python/quantum-pecos/src/pecos/slr/qeclib/steane/syn_extract/three_parallel_flagging.py b/python/quantum-pecos/src/pecos/slr/qeclib/steane/syn_extract/three_parallel_flagging.py index 06d99c0f8..6c8d5e83a 100644 --- a/python/quantum-pecos/src/pecos/slr/qeclib/steane/syn_extract/three_parallel_flagging.py +++ b/python/quantum-pecos/src/pecos/slr/qeclib/steane/syn_extract/three_parallel_flagging.py @@ -68,7 +68,7 @@ def __init__( Comment("X check 1, Z check 2, Z check 3"), Comment("==============================="), Comment(), - gq.Prep( + gq.PZ( q[0], q[8], q[9], @@ -175,7 +175,7 @@ def __init__( Comment("Z check 1, X check 2, X check 3"), Comment("==============================="), Comment(), - gq.Prep( + gq.PZ( q[0], q[8], q[9], diff --git a/python/quantum-pecos/src/pecos/slr/qeclib/surface/macrolibs/preps/project_pauli.py b/python/quantum-pecos/src/pecos/slr/qeclib/surface/macrolibs/preps/project_pauli.py index debe63dca..5a476bba5 100644 --- a/python/quantum-pecos/src/pecos/slr/qeclib/surface/macrolibs/preps/project_pauli.py +++ b/python/quantum-pecos/src/pecos/slr/qeclib/surface/macrolibs/preps/project_pauli.py @@ -11,28 +11,10 @@ """Pauli projection preparation blocks for surface code operations.""" -from pecos.slr import Block, QReg, Qubit +from pecos.slr import Block, Qubit from pecos.slr.qeclib.qubit.qubit import PhysicalQubit as Q -class PrepZ(Block): - """Prepare the +Z operator.""" - - def __init__(self, q: QReg, data_indices: list[int]) -> None: - """Initialize the +Z state preparation block. - - Args: - q: Quantum register containing the qubits. - data_indices: List of indices for data qubits to prepare in +Z state. - """ - super().__init__() - - for i in data_indices: - self.extend( - Q.pz(q[i]), - ) - - class PrepProjectZ(Block): """Prepare the +Z operator.""" @@ -44,7 +26,9 @@ def __init__(self, qs: list[Qubit]) -> None: """ super().__init__() - self.extend( - PrepZ(*qs), - ) + # Prepare each data qubit in |0> with the qubit-level primitive. + # (`qs` is a list[Qubit]; the register-indexed PrepZ block this + # used to call had a `(QReg, list[int])` API that did not match + # the qubit-list shape -- it was dead/broken and was removed.) + self.extend(*(Q.pz(q) for q in qs)) # TODO: Measure the X checks diff --git a/python/quantum-pecos/src/pecos/slr/slr_converter.py b/python/quantum-pecos/src/pecos/slr/slr_converter.py index 34090386d..86e40e543 100644 --- a/python/quantum-pecos/src/pecos/slr/slr_converter.py +++ b/python/quantum-pecos/src/pecos/slr/slr_converter.py @@ -43,6 +43,7 @@ def __init__(self, block: Main | None = None, *, optimize_parallel: bool = True) optimize_parallel: Whether to apply ParallelOptimizer transformation (default: True). Only affects blocks containing Parallel() statements. """ + self._original_block = block self._block = block self._optimize_parallel = optimize_parallel @@ -106,32 +107,37 @@ def _generate_qasm(self, *, include_header: bool = True) -> str: def _generate_guppy(self) -> str: """Generate Guppy code using AST-based codegen.""" - from pecos.slr.ast.codegen.guppy import ast_to_guppy + from pecos.slr.ast.codegen.guppy import ast_to_guppy, validate_slr_for_guppy_v1 + validate_slr_for_guppy_v1(self._original_block) ast = self._to_ast() return ast_to_guppy(ast) def _generate_qir(self, *, bytecode: bool = False) -> str | bytes: """Generate QIR code using AST-based codegen.""" - if bytecode: - # QIR bytecode requires the old generator - from pecos.slr.gen_codes.gen_qir import QIRGenerator - - if QIRGenerator is None: - msg = ( - "Trying to compile QIR without the appropriate optional dependencies install. " - "Use optional dependency group `qir` or `all`" - ) - raise ImportError(msg) - - generator = QIRGenerator(_internal=True) - generator.generate_block(self._block) - return generator.get_bc() - from pecos.slr.ast.codegen.qir import ast_to_qir ast = self._to_ast() - return ast_to_qir(ast) + ir_text = ast_to_qir(ast) + if not bytecode: + return ir_text + + try: + from pecos_rslib_llvm import binding + except ImportError as exc: + msg = ( + "Trying to compile QIR without the appropriate optional dependencies install. " + "Use optional dependency group `qir` or `all`" + ) + raise ImportError(msg) from exc + + try: + bc = binding.parse_assembly(ir_text).as_bitcode() + except RuntimeError as exc: + msg = f"Failed to compile QIR to bitcode: {exc}" + raise RuntimeError(msg) from exc + binding.shutdown() + return bc def qasm(self, *, skip_headers: bool = False, add_versions: bool = False) -> str: """Generate QASM code. @@ -180,25 +186,51 @@ def hugr(self): ImportError: If guppylang is not available RuntimeError: If compilation fails """ - # Generate Guppy code - self._generate_guppy() + return self._compile_hugr() - # Compile to HUGR - try: - from pecos.slr.gen_codes.guppy.hugr_compiler import HugrCompiler - except ImportError as e: - msg = "Failed to import HugrCompiler. Make sure guppylang is installed." - raise ImportError(msg) from e + def _compile_hugr(self): + from pecos.slr.ast.codegen.entry_wrapper import build_no_arg_entry_wrapper, truncate_source_for_error + + guppy_code = self._generate_guppy() + program = self._to_ast() - # HugrCompiler needs the generator object for its internal state - # For now, fall back to the old path - from pecos.slr.gen_codes.guppy import IRGuppyGenerator + wrapper, _info = build_no_arg_entry_wrapper(program) + full_source = guppy_code + wrapper - generator = IRGuppyGenerator(_internal=True) - generator.generate_block(self._block) + import linecache + import sys + import tempfile + from contextlib import suppress + from pathlib import Path - compiler = HugrCompiler(generator) - return compiler.compile_to_hugr() + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + temp_file = Path(f.name) + f.write(full_source) + + module_name = f"_ast_guppy_generated_{temp_file.stem}" + linecache.cache[str(temp_file)] = ( + len(full_source), + None, + full_source.splitlines(keepends=True), + str(temp_file), + ) + + try: + try: + return _load_and_compile_entry(temp_file, module_name) + except Exception as exc: + truncated = truncate_source_for_error(full_source) + msg = ( + f"Failed to compile AST-generated Guppy to HUGR.\n\n" + f"Error: {type(exc).__name__}: {exc}\n\n" + f"Generated Guppy source (truncated):\n{truncated}" + ) + raise RuntimeError(msg) from exc + finally: + sys.modules.pop(module_name, None) + linecache.cache.pop(str(temp_file), None) + with suppress(OSError, FileNotFoundError): + temp_file.unlink() def stim(self) -> stim.Circuit: """Generate a Stim circuit from the SLR block. @@ -278,3 +310,31 @@ def from_quantum_circuit(cls, qc, *, optimize_parallel: bool = True): optimizer = ParallelOptimizer() slr_block = optimizer.transform(slr_block) return slr_block + + +def _load_and_compile_entry(temp_file, module_name: str): + """Import the AST-generated module from `temp_file` and compile its `entry()`. + + Failures here (spec creation, exec_module raising on import/decorator/syntax + errors, missing `entry`, Guppy compile errors) propagate to the caller so a + single outer except can attach the generated source to the error message. + """ + import importlib.util + import sys + + spec = importlib.util.spec_from_file_location(module_name, temp_file) + if spec is None or spec.loader is None: + msg = "Failed to create module spec for AST-generated Guppy source" + raise RuntimeError(msg) + + module = importlib.util.module_from_spec(spec) + module.__file__ = str(temp_file) + sys.modules[module_name] = module + spec.loader.exec_module(module) + + entry_func = getattr(module, "entry", None) + if entry_func is None: + msg = "No entry function found in AST-generated Guppy source" + raise RuntimeError(msg) + + return entry_func.compile() diff --git a/python/quantum-pecos/src/pecos/slr/transforms/parallel_optimizer.py b/python/quantum-pecos/src/pecos/slr/transforms/parallel_optimizer.py index 78f85f745..7d49d359d 100644 --- a/python/quantum-pecos/src/pecos/slr/transforms/parallel_optimizer.py +++ b/python/quantum-pecos/src/pecos/slr/transforms/parallel_optimizer.py @@ -92,9 +92,22 @@ def _transform_block(self, block: Block) -> Block: elif isinstance(block, Main) and type(block) is Main: new_block = Main(*new_ops) elif isinstance(block, Block): - # Use isinstance to handle Block subclasses - new_block = Block(*new_ops) - # Preserve block metadata if available + # For Block subclasses, preserve class identity so class-level + # attributes (block_inputs, block_returns, etc.) and any instance + # state set by the subclass __init__ (e.g. `self.q1`, `self.q2` + # bound for `block_inputs` lookup) survive the optimization pass. + # Reconstructing via `Block(*new_ops)` here would erase that state + # and cause converted Blocks to silently fall back to + # the v1 flatten path. + cls = type(block) + if cls is Block: + new_block = Block(*new_ops) + else: + new_block = cls.__new__(cls) + new_block.__dict__.update(block.__dict__) + new_block.ops = list(new_ops) + # The metadata copy below is redundant for the subclass case (already + # carried via __dict__) but kept for the plain-Block branch. if hasattr(block, "block_name"): new_block.block_name = block.block_name if hasattr(block, "block_module"): @@ -152,15 +165,34 @@ def _transform_parallel(self, parallel: Parallel) -> Block | Parallel: def _can_optimize_parallel(self, parallel: Parallel) -> bool: """Check if a Parallel block can be safely optimized. - Returns False if the block contains control flow (If/Repeat). + Returns False if the block contains control flow (If/Repeat) or any + converted Block (a Block subclass with class-level `block_inputs`). + Splatting a converted Block's body into the surrounding Parallel + would destroy its scope boundary -- the BlockCall lowering needs the + Block to remain a single op in the AST. """ for op in parallel.ops: if isinstance(op, If | Repeat): return False - if isinstance(op, Block) and self._contains_control_flow(op): - return False + if isinstance(op, Block): + if hasattr(type(op), "block_inputs"): + return False + if self._contains_control_flow(op): + return False + if self._contains_converted_block(op): + return False return True + def _contains_converted_block(self, block: Block) -> bool: + """Check if a Block (recursively) contains any converted Block.""" + for op in block.ops: + if isinstance(op, Block): + if hasattr(type(op), "block_inputs"): + return True + if self._contains_converted_block(op): + return True + return False + def _contains_control_flow(self, block: Block) -> bool: """Check if a block contains any control flow operations.""" for op in block.ops: diff --git a/python/quantum-pecos/src/pecos/slr/vars.py b/python/quantum-pecos/src/pecos/slr/vars.py index afd0d31f5..6d2d11990 100644 --- a/python/quantum-pecos/src/pecos/slr/vars.py +++ b/python/quantum-pecos/src/pecos/slr/vars.py @@ -215,17 +215,14 @@ def resolve(self, value: int) -> Qubit: class CReg(Reg, PyCOp): - def __init__(self, sym: str, size: int, *, result: bool = True) -> None: - """ - Representation for a collection of bits. + def __init__(self, sym: str, size: int) -> None: + """Representation for a collection of bits. Args: - sym: - size: - result: Whether this register is a result register (default True) + sym: Register name. + size: Number of bits. """ super().__init__(sym, size, elem_type=Bit) - self.result = result @property def _symbolic_elem_type(self) -> type[SymbolicBit]: diff --git a/python/quantum-pecos/tests/conftest.py b/python/quantum-pecos/tests/conftest.py index 7076b0d55..62d2c1a37 100644 --- a/python/quantum-pecos/tests/conftest.py +++ b/python/quantum-pecos/tests/conftest.py @@ -5,3 +5,15 @@ # in isolation. Import the installed/source package first so later # ``import pecos`` statements resolve to the public PECOS package. import pecos + + +def pytest_configure(config): + """Register markers at the test-tree root so they are known for ANY + invocation (e.g. running a single file directly), not only when + pytest happens to pick a ``pyproject.toml`` whose + ``[tool.pytest.ini_options].markers`` lists them.""" + config.addinivalue_line( + "markers", + "slow: mark tests that provide extra integration coverage but are " + "excluded from the default fast Python test lane", + ) diff --git a/python/quantum-pecos/tests/pecos/integration/state_sim_tests/gate_matrix_def.py b/python/quantum-pecos/tests/pecos/integration/state_sim_tests/gate_matrix_def.py index 9b88d83df..66b32d943 100644 --- a/python/quantum-pecos/tests/pecos/integration/state_sim_tests/gate_matrix_def.py +++ b/python/quantum-pecos/tests/pecos/integration/state_sim_tests/gate_matrix_def.py @@ -482,6 +482,73 @@ def CH() -> pc.Array: assert eqv2phase(CH(), ch_def) +def CRZ(theta: float) -> pc.Array: + """Controlled-RZ(theta) gate. Convention: block-diag(I, RZ(theta)). + + Decomposition (2q-minimal: 1 RZZ + 2 single-qubit RZ): + CRZ(theta) = (RZ(theta/2) o RZ(theta/2)) . RZZ(-theta/2) + Works because PECOS_RZ and PECOS_RZZ share the same e^{i.t/2} + global-phase convention. RZ on control absorbs the c=1-only phase + that the bare RZZ-based form would leave (it would otherwise be + a *relative* phase, not a global one, and thus observable). + """ + return oporder_multiply( + [ + RZZ(-theta / 2), + (RZ(theta / 2), RZ(theta / 2)), + ], + ) + + +for _ in range(5): + crz_th = pc.random.random() + assert eqv2phase(CRZ(crz_th), (project_zero & I) + (project_one & RZ(crz_th))) + + +def CRX(theta: float) -> pc.Array: + """Controlled-RX(theta) gate. Convention: block-diag(I, RX(theta)). + + Decomposition (2q-minimal: 1 RZZ, via H conjugation of CRZ): + CRX(theta) = (I o H) . CRZ(theta) . (I o H) + """ + return oporder_multiply( + [ + (I, H()), + CRZ(theta), + (I, H()), + ], + ) + + +for _ in range(5): + crx_th = pc.random.random() + assert eqv2phase(CRX(crx_th), (project_zero & I) + (project_one & RX(crx_th))) + + +def CRY(theta: float) -> pc.Array: + """Controlled-RY(theta) gate. Convention: block-diag(I, RY(theta)). + + Decomposition (2q-minimal: 1 RZZ, via (S.H) conjugation of CRZ): + CRY(theta) = (I o (S.H)) . CRZ(theta) . (I o (H.Sdg)) + Identity used: S.X.Sdg = Y, so S.Rx.Sdg = Ry; combined with + H.Rz.H = Rx gives S.H.Rz.H.Sdg = Ry. + """ + return oporder_multiply( + [ + (I, Sdg()), + (I, H()), + CRZ(theta), + (I, H()), + (I, S()), + ], + ) + + +for _ in range(5): + cry_th = pc.random.random() + assert eqv2phase(CRY(cry_th), (project_zero & I) + (project_one & RY(cry_th))) + + def Toffoli() -> pc.Array: """C3 gate: Toffoli. diff --git a/python/quantum-pecos/tests/pecos/integration/state_sim_tests/test_statevec.py b/python/quantum-pecos/tests/pecos/integration/state_sim_tests/test_statevec.py index fe2cb5641..76cf1037e 100644 --- a/python/quantum-pecos/tests/pecos/integration/state_sim_tests/test_statevec.py +++ b/python/quantum-pecos/tests/pecos/integration/state_sim_tests/test_statevec.py @@ -313,6 +313,9 @@ def _apply(gate: dict, **params: object) -> None: _apply({"SWAP": {(3, 0)}}) _apply({"Tdg": {3, 1}}) _apply({"RXX": {(1, 3)}}, angles=(pc.f64.frac_pi_4,)) + _apply({"CRZ": {(0, 1)}}, angles=(pc.f64.pi / 5,)) + _apply({"CRX": {(2, 3)}}, angles=(pc.f64.pi / 7,)) + _apply({"CRY": {(1, 2)}}, angles=(pc.f64.pi / 6,)) _apply({"Q": {0, 1, 2}}) _apply({"Qd": {0, 3}}) _apply({"R": {0}}) @@ -339,6 +342,38 @@ def _apply(gate: dict, **params: object) -> None: check_measurement(simulator, qc) +def test_controlled_rotations_statevec() -> None: + """CRX/CRY/CRZ via the StateVec backend (cross-codegen QC support). + + Verifies the 1-RZZ default implementations in + `ArbitraryRotationGateable::{crx,cry,crz}` produce the textbook + block-diag(I, R*(theta)) action by comparing against direct + R*(theta) on the target with the control prepared in |1>. + """ + theta = pc.f64.pi / 3 + for sym, direct_sym in [("CRZ", "RZ"), ("CRX", "RX"), ("CRY", "RY")]: + # c=0: |00> stays |00>. + sim_c0 = StateVec(2) + sim_c0.backend.run_2q_gate(sym, (0, 1), {"angle": theta}) + baseline_c0 = StateVec(2) + assert pc.allclose( + sim_c0.backend.vector, + baseline_c0.backend.vector, + ), f"{sym}|00> changed the state -- expected identity when control=0" + + # c=1: |10> -> |1, R*(theta)|0>>. Compare against direct R* on target. + sim_c1 = StateVec(2) + sim_c1.backend.run_1q_gate("X", 0, None) + sim_c1.backend.run_2q_gate(sym, (0, 1), {"angle": theta}) + sim_direct = StateVec(2) + sim_direct.backend.run_1q_gate("X", 0, None) + sim_direct.backend.run_1q_gate(direct_sym, 1, {"angle": theta}) + assert pc.allclose( + sim_c1.backend.vector, + sim_direct.backend.vector, + ), f"{sym}(theta) when c=1 must equal direct {direct_sym}(theta) on target" + + @pytest.mark.parametrize( "simulator", [ diff --git a/python/quantum-pecos/tests/pecos/regression/test_qasm/examples/test_logical_steane_code_program.py b/python/quantum-pecos/tests/pecos/regression/test_qasm/examples/test_logical_steane_code_program.py index 87fb61868..34bb09c6f 100644 --- a/python/quantum-pecos/tests/pecos/regression/test_qasm/examples/test_logical_steane_code_program.py +++ b/python/quantum-pecos/tests/pecos/regression/test_qasm/examples/test_logical_steane_code_program.py @@ -13,7 +13,7 @@ from collections.abc import Callable -from pecos.slr import Barrier, CReg, If, Main +from pecos.slr import Barrier, CReg, If, Main, Return from pecos.slr.qeclib.steane.steane_class import Steane @@ -48,6 +48,7 @@ def telep(prep_basis: str, meas_basis: str) -> Main: If(m_bell[0] == 0).Then(sout.z()), # Final output stored in `m_out[0]` sout.m(meas_basis, m_out[0]), + Return(m_bell, m_out), ) @@ -100,6 +101,7 @@ def t_gate(prep_basis: str, meas_basis: str) -> Main: If(sin.t_meas == 1).Then(sin.sz()), # Final output stored in `m_out[1]` sin.m(meas_basis, m_out[1]), + Return(m_reject, m_t, m_out), ) diff --git a/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/qubit/test_preps.py b/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/qubit/test_preps.py index fe8a718af..14c6de500 100644 --- a/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/qubit/test_preps.py +++ b/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/qubit/test_preps.py @@ -18,8 +18,8 @@ def test_Prep(compare_qasm: Callable[..., None]) -> None: - """Test Prep gate QASM regression.""" + """Test PZ gate QASM regression.""" q = QReg("q_test", 1) - prog = qubit.Prep(q[0]) + prog = qubit.PZ(q[0]) compare_qasm(prog) diff --git a/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/qubit/test_rots.py b/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/qubit/test_rots.py index 80b118f47..3495b4959 100644 --- a/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/qubit/test_rots.py +++ b/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/qubit/test_rots.py @@ -14,33 +14,33 @@ from collections.abc import Callable import pecos as pc -from pecos.slr import QReg +from pecos.slr import QReg, rad from pecos.slr.qeclib import qubit def test_RX(compare_qasm: Callable[..., None]) -> None: """Test RX rotation gate QASM regression.""" q = QReg("q_test", 1) - prog = qubit.RX[pc.f64.pi / 3](q[0]) + prog = qubit.RX(rad(pc.f64.pi / 3), q[0]) compare_qasm(prog) def test_RY(compare_qasm: Callable[..., None]) -> None: """Test RY rotation gate QASM regression.""" q = QReg("q_test", 1) - prog = qubit.RY[pc.f64.pi / 3](q[0]) + prog = qubit.RY(rad(pc.f64.pi / 3), q[0]) compare_qasm(prog) def test_RZ(compare_qasm: Callable[..., None]) -> None: """Test RZ rotation gate QASM regression.""" q = QReg("q_test", 1) - prog = qubit.RZ[pc.f64.pi / 3](q[0]) + prog = qubit.RZ(rad(pc.f64.pi / 3), q[0]) compare_qasm(prog) def test_RZZ(compare_qasm: Callable[..., None]) -> None: """Test RZZ two-qubit rotation gate QASM regression.""" q = QReg("q_test", 4) - prog = qubit.RZZ[pc.f64.pi / 3](q[1], q[3]) + prog = qubit.RZZ(rad(pc.f64.pi / 3), q[1], q[3]) compare_qasm(prog) diff --git a/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/meas/test_destructive_meas.py b/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/meas/test_destructive_meas.py index e202da4ef..41e911e17 100644 --- a/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/meas/test_destructive_meas.py +++ b/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/meas/test_destructive_meas.py @@ -13,7 +13,7 @@ from collections.abc import Callable -from pecos.slr import CReg, QReg +from pecos.slr import CReg, QReg, Return from pecos.slr.qeclib.steane.meas.destructive_meas import ( MeasDecode, Measure, @@ -32,6 +32,7 @@ def test_MeasureX(compare_qasm: Callable[..., None]) -> None: for barrier in [True, False]: block = MeasureX(q, meas_creg, log_raw, barrier=barrier) + block.extend(Return(meas_creg, log_raw)) compare_qasm(block, barrier) @@ -43,6 +44,7 @@ def test_MeasureY(compare_qasm: Callable[..., None]) -> None: for barrier in [True, False]: block = MeasureY(q, meas_creg, log_raw, barrier=barrier) + block.extend(Return(meas_creg, log_raw)) compare_qasm(block, barrier) @@ -54,6 +56,7 @@ def test_MeasureZ(compare_qasm: Callable[..., None]) -> None: for barrier in [True, False]: block = MeasureZ(q, meas_creg, log_raw, barrier=barrier) + block.extend(Return(meas_creg, log_raw)) compare_qasm(block, barrier) @@ -65,6 +68,7 @@ def test_Measure(compare_qasm: Callable[..., None]) -> None: for meas_basis in ["X", "Y", "Z"]: block = Measure(q, meas_creg, log_raw, meas_basis=meas_basis) + block.extend(Return(meas_creg, log_raw)) compare_qasm(block, meas_basis) diff --git a/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/meas/test_measure_x.py b/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/meas/test_measure_x.py index c7f611ccc..6ad397dde 100644 --- a/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/meas/test_measure_x.py +++ b/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/meas/test_measure_x.py @@ -13,7 +13,7 @@ from collections.abc import Callable -from pecos.slr import CReg, QReg +from pecos.slr import CReg, QReg, Return from pecos.slr.qeclib.steane.meas.measure_x import NoFlagMeasureX @@ -24,4 +24,5 @@ def test_MeasureX(compare_qasm: Callable[..., None]) -> None: out = CReg("out_test", 1) block = NoFlagMeasureX(q, a, out) + block.extend(Return(out)) compare_qasm(block) diff --git a/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/meas/test_measure_z.py b/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/meas/test_measure_z.py index 768fdd30f..53f78b8a4 100644 --- a/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/meas/test_measure_z.py +++ b/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/meas/test_measure_z.py @@ -13,7 +13,7 @@ from collections.abc import Callable -from pecos.slr import CReg, QReg +from pecos.slr import CReg, QReg, Return from pecos.slr.qeclib.steane.meas.measure_z import NoFlagMeasureZ @@ -24,4 +24,5 @@ def test_MeasureX(compare_qasm: Callable[..., None]) -> None: out = CReg("out_test", 1) block = NoFlagMeasureZ(q, a, out) + block.extend(Return(out)) compare_qasm(block) diff --git a/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/preps/test_pauli_states.py b/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/preps/test_pauli_states.py index b51375fb0..39d460af0 100644 --- a/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/preps/test_pauli_states.py +++ b/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/preps/test_pauli_states.py @@ -13,7 +13,7 @@ from collections.abc import Callable -from pecos.slr import CReg, QReg +from pecos.slr import CReg, QReg, Return from pecos.slr.qeclib.steane.preps.pauli_states import ( LogZeroRot, PrepEncodingFTZero, @@ -37,6 +37,7 @@ def test_PrepZeroVerify(compare_qasm: Callable[..., None]) -> None: init_bit = CReg("init_bit", 1) for reset_ancilla in [True, False]: block = PrepZeroVerify(q, a[0], init_bit[0], reset_ancilla=reset_ancilla) + block.extend(Return(init_bit)) compare_qasm(block, reset_ancilla) @@ -68,6 +69,7 @@ def test_PrepRUS(compare_qasm: Callable[..., None]) -> None: state, first_round_reset=first_round_reset, ) + block.extend(Return(init)) compare_qasm(block, limit, state, first_round_reset) diff --git a/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/qec/test_qec_3parallel.py b/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/qec/test_qec_3parallel.py index 4102f5d34..cf1826259 100644 --- a/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/qec/test_qec_3parallel.py +++ b/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/qec/test_qec_3parallel.py @@ -13,7 +13,7 @@ from collections.abc import Callable -from pecos.slr import CReg, QReg +from pecos.slr import CReg, QReg, Return from pecos.slr.qeclib.steane.qec.qec_3parallel import ParallelFlagQECActiveCorrection @@ -47,4 +47,18 @@ def test_ParallelFlagQECActiveCorrection(compare_qasm: Callable[..., None]) -> N pf[1], scratch, ) + block.extend( + Return( + flag_x, + flag_z, + flags, + syn_x, + syn_z, + last_raw_syn_x, + last_raw_syn_z, + syndromes, + pf, + scratch, + ), + ) compare_qasm(block) diff --git a/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/syn_extract/test_six_check_nonflagging.py b/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/syn_extract/test_six_check_nonflagging.py index a07004242..fb8b8424f 100644 --- a/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/syn_extract/test_six_check_nonflagging.py +++ b/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/syn_extract/test_six_check_nonflagging.py @@ -13,7 +13,7 @@ from collections.abc import Callable -from pecos.slr import CReg, QReg +from pecos.slr import CReg, QReg, Return from pecos.slr.qeclib.steane.syn_extract.six_check_nonflagging import SixUnflaggedSyn @@ -25,4 +25,5 @@ def test_SixUnflaggedSyn(compare_qasm: Callable[..., None]) -> None: syn_z = CReg("syn_z_test", 3) block = SixUnflaggedSyn(q, a, syn_x, syn_z) + block.extend(Return(syn_x, syn_z)) compare_qasm(block) diff --git a/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/syn_extract/test_three_parallel_flagging.py b/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/syn_extract/test_three_parallel_flagging.py index 9f7b9c951..9a8d3465f 100644 --- a/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/syn_extract/test_three_parallel_flagging.py +++ b/python/quantum-pecos/tests/pecos/regression/test_qasm/pecos/qeclib/steane/syn_extract/test_three_parallel_flagging.py @@ -13,7 +13,7 @@ from collections.abc import Callable -from pecos.slr import CReg, QReg +from pecos.slr import CReg, QReg, Return from pecos.slr.qeclib.steane.syn_extract.three_parallel_flagging import ( ThreeParallelFlaggingXZZ, ThreeParallelFlaggingZXX, @@ -39,6 +39,7 @@ def test_ThreeParallelFlaggingXZZ(compare_qasm: Callable[..., None]) -> None: last_raw_syn_x, last_raw_syn_z, ) + block.extend(Return(flag_x, flag_z, flags, last_raw_syn_x, last_raw_syn_z)) compare_qasm(block) @@ -61,4 +62,5 @@ def test_ThreeParallelFlaggingZXX(compare_qasm: Callable[..., None]) -> None: last_raw_syn_x, last_raw_syn_z, ) + block.extend(Return(flag_x, flag_z, flags, last_raw_syn_x, last_raw_syn_z)) compare_qasm(block) diff --git a/python/quantum-pecos/tests/pecos/regression/test_qasm/random_cases/test_control_flow.py b/python/quantum-pecos/tests/pecos/regression/test_qasm/random_cases/test_control_flow.py index e6037aab7..7c185623b 100644 --- a/python/quantum-pecos/tests/pecos/regression/test_qasm/random_cases/test_control_flow.py +++ b/python/quantum-pecos/tests/pecos/regression/test_qasm/random_cases/test_control_flow.py @@ -13,7 +13,7 @@ from collections.abc import Callable -from pecos.slr import Block, CReg, If, Main, Parallel, QReg, Repeat +from pecos.slr import Block, CReg, If, Main, Parallel, QReg, Repeat, Return from pecos.slr.qeclib import qubit as qb @@ -25,6 +25,7 @@ def test_phys_teleport(compare_qasm: Callable[..., None]) -> None: qb.H(q[0]), qb.CX(q[0], q[1]), qb.Measure(q) > c, + Return(c), ) compare_qasm(prog, filename="phys.teleport") @@ -44,6 +45,7 @@ def test_phys_tele_block_block(compare_qasm: Callable[..., None]) -> None: qb.H(q[1]), ), ), + Return(c), ) compare_qasm(prog, filename="phys.tele_block_block") @@ -60,6 +62,7 @@ def test_phys_tele_if(compare_qasm: Callable[..., None]) -> None: If(c == 0).Then( qb.H(q[0]), ), + Return(c), ) compare_qasm(prog, filename="phys.tele_if") @@ -79,6 +82,7 @@ def test_phys_tele_if_block_block(compare_qasm: Callable[..., None]) -> None: qb.H(q[1]), ), ), + Return(c), ) compare_qasm(prog, filename="phys.tele_if_block_block") @@ -94,7 +98,7 @@ def test_phys_tele_block_telep_block(compare_qasm: Callable[..., None]) -> None: qb.CX(q[0], q[1]), qb.Measure(q) > c, Block( - qb.Prep(q), + qb.PZ(q), qb.H(q[0]), qb.CX(q[0], q[1]), qb.Measure(q) > c2, @@ -102,6 +106,7 @@ def test_phys_tele_block_telep_block(compare_qasm: Callable[..., None]) -> None: qb.H(q[0]), ), ), + Return(c, c2), ) compare_qasm(prog, filename="phys.tele_block_telep_block") @@ -117,6 +122,7 @@ def test_phys_repeat(compare_qasm: Callable[..., None]) -> None: qb.CX(q[0], q[1]), qb.Measure(q) > c, ), + Return(c), ) compare_qasm(prog, filename="phys.tele_repeat") @@ -136,6 +142,7 @@ def test_phys_parallel() -> None: qb.Y(q[3]), ), qb.Measure(q) > c, + Return(c), ) qasm = SlrConverter(prog).qasm() @@ -168,6 +175,7 @@ def test_phys_nested_parallel() -> None: qb.Z(q[3]), ), qb.Measure(q) > c, + Return(c), ) qasm = SlrConverter(prog).qasm() @@ -198,6 +206,7 @@ def test_phys_parallel_in_if() -> None: ), ), qb.Measure(q[1:4]) > c[1:4], + Return(c), ) qasm = SlrConverter(prog).qasm() diff --git a/python/quantum-pecos/tests/pecos/regression/test_qasm/random_cases/test_permute.py b/python/quantum-pecos/tests/pecos/regression/test_qasm/random_cases/test_permute.py index 8df835e7d..bf42ad276 100644 --- a/python/quantum-pecos/tests/pecos/regression/test_qasm/random_cases/test_permute.py +++ b/python/quantum-pecos/tests/pecos/regression/test_qasm/random_cases/test_permute.py @@ -11,7 +11,7 @@ """Testing SLR->QASM permute cases.""" -from pecos.slr import Block, CReg, Main, Permute, SlrConverter +from pecos.slr import Block, CReg, Main, Permute, Return, SlrConverter from pecos.slr.qeclib.steane.steane_class import Steane @@ -25,6 +25,7 @@ def test_permute1() -> None: Permute(a.a, b.a), a.mx(meas[0]), b.my(meas[1]), + Return(meas), ) qasm = SlrConverter(prog).qasm() @@ -54,6 +55,7 @@ def my_permute(a: Steane, b: Steane) -> Block: my_permute(a, b), a.mx(meas[0]), b.my(meas[1]), + Return(meas), ) qasm = SlrConverter(prog).qasm() @@ -74,6 +76,7 @@ def test_permute3() -> None: b := Steane("b"), meas := CReg("meas", 1), a.px(), + Return(meas), ) for _i in range(1): prog.extend( @@ -106,6 +109,7 @@ def test_permute4() -> None: b := Steane("b"), meas := CReg("meas", 1), a.px(), + Return(meas), ) for _i in range(1): prog.extend( diff --git a/python/quantum-pecos/tests/pecos/regression/test_qasm/random_cases/test_slr_phys.py b/python/quantum-pecos/tests/pecos/regression/test_qasm/random_cases/test_slr_phys.py index 1a6d81387..041b0c5ff 100644 --- a/python/quantum-pecos/tests/pecos/regression/test_qasm/random_cases/test_slr_phys.py +++ b/python/quantum-pecos/tests/pecos/regression/test_qasm/random_cases/test_slr_phys.py @@ -13,7 +13,9 @@ QReg, Qubit, Repeat, + Return, SlrConverter, + rad, ) from pecos.slr.qeclib import qubit as p from pecos.slr.qeclib.steane.steane_class import Steane @@ -62,6 +64,7 @@ def telep(prep_basis: str, meas_basis: str) -> str: If(m_bell[0] == 0).Then(sout.z()), # Final output stored in `m_out[0]` sout.m(meas_basis, m_out[0]), + Return(m_bell, m_out), ) @@ -73,6 +76,7 @@ def test_bell() -> None: p.H(q[0]), p.CX(q[0], q[1]), p.Measure(q) > m, + Return(m), ) qasm = ( @@ -98,6 +102,7 @@ def test_bell_qir() -> None: p.H(q[0]), p.CX(q[0], q[1]), p.Measure(q) > m, + Return(m), ) qir = SlrConverter(prog).qir() @@ -113,6 +118,7 @@ def test_bell_qreg_qir() -> None: p.H(q), p.CX(q[0], q[1]), p.Measure(q) > m, + Return(m), ) qir = SlrConverter(prog).qir() @@ -126,8 +132,8 @@ class Bell(Block): def __init__(self, q0: Qubit, q1: Qubit, m0: Bit, m1: Bit) -> None: super().__init__() self.extend( - p.Prep(q0), - p.Prep(q1), + p.PZ(q0), + p.PZ(q1), p.H(q0), p.CX(q0, q1), p.Measure(q0) > m0, @@ -139,6 +145,7 @@ def __init__(self, q0: Qubit, q1: Qubit, m0: Bit, m1: Bit) -> None: m := CReg("m", 2), c := CReg("c", 4), If(c == 1).Then(Bell(q0=q[0], q1=q[1], m0=m[0], m1=m[1])), + Return(m, c), ) qasm = ( @@ -171,6 +178,7 @@ def test_strange_program() -> None: c.set(b & 1), Permute([q[0], q[1]], [q[1], q[0]]), p.H(q[0]), + Return(c, b), ) qasm = ( @@ -200,7 +208,7 @@ def test_control_flow_qir() -> None: prog = Main( q := QReg("q", 2), m := CReg("m", 2), - m_hidden := CReg("m_hidden", 2, result=False), + m_hidden := CReg("m_hidden", 2), Repeat(3).block( p.H(q[0]), ), @@ -213,7 +221,7 @@ def test_control_flow_qir() -> None: ), ) .Else( - p.RX[0.3](q[0]), + p.RX(rad(0.3), q[0]), ), If(m < m_hidden).Then( p.H(q[0]), @@ -223,8 +231,9 @@ def test_control_flow_qir() -> None: p.SZdg(q[0]), p.CX(q[0], q[1]), Barrier(q[1], q[0]), - p.RX[0.3](q[0]), + p.RX(rad(0.3), q[0]), p.Measure(q) > m, + Return(m), ) qir = SlrConverter(prog).qir() assert "__quantum__qis__h__body" in qir @@ -241,6 +250,7 @@ def test_plus_qir() -> None: m.set(2), n.set(2), o.set(m + n), + Return(m, n, o), ) qir = SlrConverter(prog).qir() assert "add" in qir @@ -259,6 +269,7 @@ def test_nested_xor_qir() -> None: n.set(2), o.set(2), p[0].set((m[0] ^ n[0]) ^ o[0]), + Return(m, n, o, p), ) qir = SlrConverter(prog).qir() assert "xor" in qir @@ -275,6 +286,7 @@ def test_minus_qir() -> None: m.set(2), n.set(2), o.set(m - n), + Return(m, n, o), ) qir = SlrConverter(prog).qir() assert "sub" in qir diff --git a/python/quantum-pecos/tests/pecos/regression/test_qasm/regression_qasm/pecos.slr.qeclib.qubit.preps.Prep.qasm b/python/quantum-pecos/tests/pecos/regression/test_qasm/regression_qasm/pecos.slr.qeclib.qubit.preps.PZ.qasm similarity index 100% rename from python/quantum-pecos/tests/pecos/regression/test_qasm/regression_qasm/pecos.slr.qeclib.qubit.preps.Prep.qasm rename to python/quantum-pecos/tests/pecos/regression/test_qasm/regression_qasm/pecos.slr.qeclib.qubit.preps.PZ.qasm diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/analysis/test_depth_analyzer.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/analysis/test_depth_analyzer.py index 5b542e421..e073a3a80 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/analysis/test_depth_analyzer.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/analysis/test_depth_analyzer.py @@ -34,20 +34,20 @@ def test_single_gate(self) -> None: """Single gate has depth 1.""" prog = Main( q := QReg("q", 1), - qb.Prep(q[0]), + qb.PZ(q[0]), qb.H(q[0]), ) ast = slr_to_ast(prog) result = analyze_depth(ast) - assert result.depth == 2 # Prep + H + assert result.depth == 2 # PZ + H def test_sequential_gates_same_qubit(self) -> None: """Sequential gates on same qubit add to depth.""" prog = Main( q := QReg("q", 1), - qb.Prep(q[0]), + qb.PZ(q[0]), qb.H(q[0]), qb.X(q[0]), qb.Z(q[0]), @@ -56,14 +56,14 @@ def test_sequential_gates_same_qubit(self) -> None: result = analyze_depth(ast) - assert result.depth == 4 # Prep + H + X + Z + assert result.depth == 4 # PZ + H + X + Z def test_parallel_gates_different_qubits(self) -> None: """Gates on different qubits can run in parallel.""" prog = Main( q := QReg("q", 2), - qb.Prep(q[0]), - qb.Prep(q[1]), + qb.PZ(q[0]), + qb.PZ(q[1]), qb.H(q[0]), qb.X(q[1]), ) @@ -71,8 +71,8 @@ def test_parallel_gates_different_qubits(self) -> None: result = analyze_depth(ast) - # q[0]: Prep(1) -> H(2) - # q[1]: Prep(1) -> X(2) + # q[0]: PZ(1) -> H(2) + # q[1]: PZ(1) -> X(2) # Both paths have depth 2 assert result.depth == 2 @@ -84,8 +84,8 @@ def test_two_qubit_gate_depth(self) -> None: """Two-qubit gate increases depth for both qubits.""" prog = Main( q := QReg("q", 2), - qb.Prep(q[0]), - qb.Prep(q[1]), + qb.PZ(q[0]), + qb.PZ(q[1]), qb.CX(q[0], q[1]), ) ast = slr_to_ast(prog) @@ -100,8 +100,8 @@ def test_two_qubit_gate_waits_for_both(self) -> None: """Two-qubit gate waits for both qubits to be ready.""" prog = Main( q := QReg("q", 2), - qb.Prep(q[0]), - qb.Prep(q[1]), + qb.PZ(q[0]), + qb.PZ(q[1]), qb.H(q[0]), # q[0] now at depth 2 qb.CX(q[0], q[1]), # Must wait for q[0] ) @@ -109,17 +109,17 @@ def test_two_qubit_gate_waits_for_both(self) -> None: result = analyze_depth(ast) - # q[0]: Prep(1) -> H(2) -> CX(3) - # q[1]: Prep(1) -> (wait) -> CX(3) + # q[0]: PZ(1) -> H(2) -> CX(3) + # q[1]: PZ(1) -> (wait) -> CX(3) assert result.depth == 3 def test_chain_of_two_qubit_gates(self) -> None: """Chain of two-qubit gates increases depth.""" prog = Main( q := QReg("q", 3), - qb.Prep(q[0]), - qb.Prep(q[1]), - qb.Prep(q[2]), + qb.PZ(q[0]), + qb.PZ(q[1]), + qb.PZ(q[2]), qb.CX(q[0], q[1]), # Depth 2 qb.CX(q[1], q[2]), # Depth 3 (waits for q[1]) ) @@ -138,8 +138,8 @@ def test_bell_state_depth(self) -> None: """Bell state has depth 3 (prep + H + CX).""" prog = Main( q := QReg("q", 2), - qb.Prep(q[0]), - qb.Prep(q[1]), + qb.PZ(q[0]), + qb.PZ(q[1]), qb.H(q[0]), qb.CX(q[0], q[1]), ) @@ -147,8 +147,8 @@ def test_bell_state_depth(self) -> None: result = analyze_depth(ast) - # q[0]: Prep(1) -> H(2) -> CX(3) - # q[1]: Prep(1) -> (wait) -> CX(3) + # q[0]: PZ(1) -> H(2) -> CX(3) + # q[1]: PZ(1) -> (wait) -> CX(3) assert result.depth == 3 @@ -159,7 +159,7 @@ def test_repeat_adds_depth(self) -> None: """Repeat loop adds depth for each iteration.""" prog = Main( q := QReg("q", 1), - qb.Prep(q[0]), + qb.PZ(q[0]), Repeat(cond=3).block( qb.H(q[0]), ), @@ -168,7 +168,7 @@ def test_repeat_adds_depth(self) -> None: result = analyze_depth(ast) - # Prep(1) + H(2) + H(3) + H(4) + # PZ(1) + H(2) + H(3) + H(4) assert result.depth == 4 @@ -181,9 +181,9 @@ def test_syndrome_extraction_depth(self) -> None: data := QReg("data", 2), ancilla := QReg("ancilla", 1), c := CReg("c", 1), - qb.Prep(data[0]), - qb.Prep(data[1]), - qb.Prep(ancilla[0]), + qb.PZ(data[0]), + qb.PZ(data[1]), + qb.PZ(ancilla[0]), qb.CX(data[0], ancilla[0]), qb.CX(data[1], ancilla[0]), qb.Measure(ancilla[0]) > c[0], @@ -209,12 +209,12 @@ def test_analyzer_reusable(self) -> None: prog1 = Main( q := QReg("q", 1), - qb.Prep(q[0]), + qb.PZ(q[0]), qb.H(q[0]), ) prog2 = Main( q := QReg("q", 1), - qb.Prep(q[0]), + qb.PZ(q[0]), qb.H(q[0]), qb.X(q[0]), ) @@ -232,8 +232,8 @@ def test_result_string_representation(self) -> None: """DepthResult has useful string representation.""" prog = Main( q := QReg("q", 2), - qb.Prep(q[0]), - qb.Prep(q[1]), + qb.PZ(q[0]), + qb.PZ(q[1]), qb.H(q[0]), qb.CX(q[0], q[1]), ) diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/analysis/test_qubit_state_validator.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/analysis/test_qubit_state_validator.py index 0b73439f2..feb642c17 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/analysis/test_qubit_state_validator.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/analysis/test_qubit_state_validator.py @@ -56,7 +56,7 @@ def test_gate_with_prep_passes(self) -> None: """Gate on prepared qubit should pass.""" prog = Main( q := QReg("q", 1), - qb.Prep(q[0]), + qb.PZ(q[0]), qb.H(q[0]), ) ast = slr_to_ast(prog) @@ -69,7 +69,7 @@ def test_multiple_gates_after_prep(self) -> None: """Multiple gates after prep should all pass.""" prog = Main( q := QReg("q", 1), - qb.Prep(q[0]), + qb.PZ(q[0]), qb.H(q[0]), qb.X(q[0]), qb.Z(q[0]), @@ -89,7 +89,7 @@ def test_gate_after_measure_fails(self) -> None: prog = Main( q := QReg("q", 1), c := CReg("c", 1), - qb.Prep(q[0]), + qb.PZ(q[0]), qb.H(q[0]), qb.Measure(q[0]) > c[0], qb.X(q[0]), # Gate after measure @@ -107,10 +107,10 @@ def test_reprep_after_measure_passes(self) -> None: prog = Main( q := QReg("q", 1), c := CReg("c", 1), - qb.Prep(q[0]), + qb.PZ(q[0]), qb.H(q[0]), qb.Measure(q[0]) > c[0], - qb.Prep(q[0]), # Re-prep + qb.PZ(q[0]), # Re-prep qb.X(q[0]), ) ast = slr_to_ast(prog) @@ -127,8 +127,8 @@ def test_two_qubit_gate_both_prepared(self) -> None: """Two-qubit gate with both qubits prepared should pass.""" prog = Main( q := QReg("q", 2), - qb.Prep(q[0]), - qb.Prep(q[1]), + qb.PZ(q[0]), + qb.PZ(q[1]), qb.CX(q[0], q[1]), ) ast = slr_to_ast(prog) @@ -141,7 +141,7 @@ def test_two_qubit_gate_one_unprepared(self) -> None: """Two-qubit gate with one qubit unprepared should fail.""" prog = Main( q := QReg("q", 2), - qb.Prep(q[0]), + qb.PZ(q[0]), # q[1] not prepared qb.CX(q[0], q[1]), ) @@ -169,12 +169,12 @@ class TestQubitStateValidatorControlFlow: """Control flow tests.""" def test_if_branch_prep_not_sufficient(self) -> None: - """Prep in only one branch is not sufficient.""" + """PZ in only one branch is not sufficient.""" prog = Main( q := QReg("q", 1), c := CReg("c", 1), If(c[0] == 1).Then( - qb.Prep(q[0]), + qb.PZ(q[0]), ), # After if, q[0] might not be prepared (else branch doesn't prep) qb.H(q[0]), @@ -186,16 +186,16 @@ def test_if_branch_prep_not_sufficient(self) -> None: assert len(violations) == 1 def test_if_both_branches_prep(self) -> None: - """Prep in both branches is sufficient.""" + """PZ in both branches is sufficient.""" prog = Main( q := QReg("q", 1), c := CReg("c", 1), If(c[0] == 1) .Then( - qb.Prep(q[0]), + qb.PZ(q[0]), ) .Else( - qb.Prep(q[0]), + qb.PZ(q[0]), ), qb.H(q[0]), # Safe - prepared in both branches ) @@ -210,7 +210,7 @@ def test_if_body_uses_prep_from_before(self) -> None: prog = Main( q := QReg("q", 1), c := CReg("c", 1), - qb.Prep(q[0]), + qb.PZ(q[0]), If(c[0] == 1).Then( qb.H(q[0]), ), @@ -225,7 +225,7 @@ def test_repeat_uses_prep_from_before(self) -> None: """Repeat body can use prep from before.""" prog = Main( q := QReg("q", 1), - qb.Prep(q[0]), + qb.PZ(q[0]), Repeat(cond=3).block( qb.H(q[0]), ), @@ -246,11 +246,11 @@ def test_syndrome_extraction_pattern(self) -> None: data := QReg("data", 2), ancilla := QReg("ancilla", 1), c := CReg("c", 1), - # Prep data qubits - qb.Prep(data[0]), - qb.Prep(data[1]), - # Prep ancilla - qb.Prep(ancilla[0]), + # PZ data qubits + qb.PZ(data[0]), + qb.PZ(data[1]), + # PZ ancilla + qb.PZ(ancilla[0]), # Syndrome extraction qb.CX(data[0], ancilla[0]), qb.CX(data[1], ancilla[0]), @@ -269,16 +269,16 @@ def test_repeated_syndrome_extraction(self) -> None: data := QReg("data", 2), ancilla := QReg("ancilla", 1), c := CReg("c", 1), - # Prep everything - qb.Prep(data[0]), - qb.Prep(data[1]), - qb.Prep(ancilla[0]), + # PZ everything + qb.PZ(data[0]), + qb.PZ(data[1]), + qb.PZ(ancilla[0]), # First round qb.CX(data[0], ancilla[0]), qb.CX(data[1], ancilla[0]), qb.Measure(ancilla[0]) > c[0], # Second round - need to re-prep ancilla - qb.Prep(ancilla[0]), + qb.PZ(ancilla[0]), qb.CX(data[0], ancilla[0]), qb.CX(data[1], ancilla[0]), qb.Measure(ancilla[0]) > c[0], @@ -305,7 +305,7 @@ def test_validator_reusable(self) -> None: prog2 = Main( q := QReg("q", 1), - qb.Prep(q[0]), + qb.PZ(q[0]), qb.H(q[0]), # No violation ) ast2 = slr_to_ast(prog2) diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/analysis/test_resource_counter.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/analysis/test_resource_counter.py index 5eaa13e26..4b08f3355 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/analysis/test_resource_counter.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/analysis/test_resource_counter.py @@ -84,7 +84,7 @@ def test_single_qubit_gates(self) -> None: """Single-qubit gates are counted.""" prog = Main( q := QReg("q", 1), - qb.Prep(q[0]), + qb.PZ(q[0]), qb.H(q[0]), qb.X(q[0]), qb.Z(q[0]), @@ -93,7 +93,7 @@ def test_single_qubit_gates(self) -> None: result = count_resources(ast) - assert result.total_gates == 3 # H, X, Z (Prep is not a gate) + assert result.total_gates == 3 # H, X, Z (PZ is not a gate) assert result.single_qubit_gates == 3 assert result.two_qubit_gates == 0 @@ -101,8 +101,8 @@ def test_two_qubit_gates(self) -> None: """Two-qubit gates are counted.""" prog = Main( q := QReg("q", 2), - qb.Prep(q[0]), - qb.Prep(q[1]), + qb.PZ(q[0]), + qb.PZ(q[1]), qb.CX(q[0], q[1]), qb.CZ(q[0], q[1]), ) @@ -118,8 +118,8 @@ def test_mixed_gates(self) -> None: """Mixed single and two-qubit gates are counted correctly.""" prog = Main( q := QReg("q", 2), - qb.Prep(q[0]), - qb.Prep(q[1]), + qb.PZ(q[0]), + qb.PZ(q[1]), qb.H(q[0]), qb.CX(q[0], q[1]), qb.X(q[1]), @@ -136,8 +136,8 @@ def test_gate_counts_by_type(self) -> None: """Gates are counted by type.""" prog = Main( q := QReg("q", 2), - qb.Prep(q[0]), - qb.Prep(q[1]), + qb.PZ(q[0]), + qb.PZ(q[1]), qb.H(q[0]), qb.H(q[1]), qb.CX(q[0], q[1]), @@ -159,8 +159,8 @@ def test_measurements_counted(self) -> None: prog = Main( q := QReg("q", 2), c := CReg("c", 2), - qb.Prep(q[0]), - qb.Prep(q[1]), + qb.PZ(q[0]), + qb.PZ(q[1]), qb.Measure(q[0]) > c[0], qb.Measure(q[1]) > c[1], ) @@ -174,8 +174,8 @@ def test_preparations_counted(self) -> None: """Preparations are counted.""" prog = Main( q := QReg("q", 2), - qb.Prep(q[0]), - qb.Prep(q[1]), + qb.PZ(q[0]), + qb.PZ(q[1]), ) ast = slr_to_ast(prog) @@ -191,7 +191,7 @@ def test_repeat_multiplies_resources(self) -> None: """Repeat loop multiplies gate counts.""" prog = Main( q := QReg("q", 1), - qb.Prep(q[0]), + qb.PZ(q[0]), Repeat(cond=5).block( qb.H(q[0]), qb.X(q[0]), @@ -215,9 +215,9 @@ def test_syndrome_extraction_resources(self) -> None: data := QReg("data", 2), ancilla := QReg("ancilla", 1), c := CReg("c", 1), - qb.Prep(data[0]), - qb.Prep(data[1]), - qb.Prep(ancilla[0]), + qb.PZ(data[0]), + qb.PZ(data[1]), + qb.PZ(ancilla[0]), qb.CX(data[0], ancilla[0]), qb.CX(data[1], ancilla[0]), qb.Measure(ancilla[0]) > c[0], @@ -257,7 +257,7 @@ def test_result_string_representation(self) -> None: prog = Main( q := QReg("q", 2), _c := CReg("c", 1), - qb.Prep(q[0]), + qb.PZ(q[0]), qb.H(q[0]), qb.CX(q[0], q[1]), ) diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/optimizations/test_identity_removal.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/optimizations/test_identity_removal.py index f3130837d..a287c4907 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/optimizations/test_identity_removal.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/optimizations/test_identity_removal.py @@ -13,7 +13,7 @@ import math -from pecos.slr import CReg, If, Main, QReg, Repeat +from pecos.slr import CReg, If, Main, QReg, Repeat, rad from pecos.slr.ast import slr_to_ast from pecos.slr.ast.optimizations import IdentityRemovalPass from pecos.slr.qeclib import qubit as qb @@ -26,7 +26,7 @@ def test_rz_zero_removed(self) -> None: """RZ(0) is removed.""" prog = Main( q := QReg("q", 1), - qb.RZ[0](q[0]), + qb.RZ(rad(0), q[0]), ) ast = slr_to_ast(prog) @@ -39,7 +39,7 @@ def test_rx_zero_removed(self) -> None: """RX(0) is removed.""" prog = Main( q := QReg("q", 1), - qb.RX[0](q[0]), + qb.RX(rad(0), q[0]), ) ast = slr_to_ast(prog) @@ -52,7 +52,7 @@ def test_ry_zero_removed(self) -> None: """RY(0) is removed.""" prog = Main( q := QReg("q", 1), - qb.RY[0](q[0]), + qb.RY(rad(0), q[0]), ) ast = slr_to_ast(prog) @@ -65,7 +65,7 @@ def test_rz_2pi_removed(self) -> None: """RZ(2*pi) is removed.""" prog = Main( q := QReg("q", 1), - qb.RZ[2 * math.pi](q[0]), + qb.RZ(rad(2 * math.pi), q[0]), ) ast = slr_to_ast(prog) @@ -78,7 +78,7 @@ def test_rz_4pi_removed(self) -> None: """RZ(4*pi) is removed (multiple of 2*pi).""" prog = Main( q := QReg("q", 1), - qb.RZ[4 * math.pi](q[0]), + qb.RZ(rad(4 * math.pi), q[0]), ) ast = slr_to_ast(prog) @@ -95,7 +95,7 @@ def test_rz_nonzero_not_removed(self) -> None: """RZ(0.5) is not removed.""" prog = Main( q := QReg("q", 1), - qb.RZ[0.5](q[0]), + qb.RZ(rad(0.5), q[0]), ) ast = slr_to_ast(prog) @@ -108,7 +108,7 @@ def test_rz_pi_not_removed(self) -> None: """RZ(pi) is not removed.""" prog = Main( q := QReg("q", 1), - qb.RZ[math.pi](q[0]), + qb.RZ(rad(math.pi), q[0]), ) ast = slr_to_ast(prog) @@ -141,7 +141,7 @@ def test_removal_inside_if(self) -> None: q := QReg("q", 1), c := CReg("c", 1), If(c[0] == 1).Then( - qb.RZ[0](q[0]), + qb.RZ(rad(0), q[0]), qb.H(q[0]), ), ) @@ -157,7 +157,7 @@ def test_removal_inside_repeat(self) -> None: prog = Main( q := QReg("q", 1), Repeat(cond=3).block( - qb.RX[0](q[0]), + qb.RX(rad(0), q[0]), ), ) @@ -175,10 +175,10 @@ def test_multiple_identity_gates(self) -> None: """Multiple identity gates are removed.""" prog = Main( q := QReg("q", 1), - qb.RZ[0](q[0]), + qb.RZ(rad(0), q[0]), qb.H(q[0]), - qb.RX[0](q[0]), - qb.RY[2 * math.pi](q[0]), + qb.RX(rad(0), q[0]), + qb.RY(rad(2 * math.pi), q[0]), ) ast = slr_to_ast(prog) @@ -191,10 +191,10 @@ def test_mixed_with_nonidentity(self) -> None: """Identity gates removed among non-identity gates.""" prog = Main( q := QReg("q", 1), - qb.RZ[0](q[0]), # Removed - qb.RZ[0.5](q[0]), # Kept - qb.RX[0](q[0]), # Removed - qb.RX[0.5](q[0]), # Kept + qb.RZ(rad(0), q[0]), # Removed + qb.RZ(rad(0.5), q[0]), # Kept + qb.RX(rad(0), q[0]), # Removed + qb.RX(rad(0.5), q[0]), # Kept ) ast = slr_to_ast(prog) diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/optimizations/test_pipeline.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/optimizations/test_pipeline.py index 978e0ee51..ab34609bc 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/optimizations/test_pipeline.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/optimizations/test_pipeline.py @@ -13,7 +13,7 @@ import math -from pecos.slr import Main, QReg +from pecos.slr import Main, QReg, rad from pecos.slr.ast import slr_to_ast from pecos.slr.ast.nodes import GateKind, GateOp, LiteralExpr from pecos.slr.ast.optimizations import ( @@ -78,8 +78,8 @@ def test_level_2_rotation_merging(self) -> None: """Level 2 adds rotation merging.""" prog = Main( q := QReg("q", 1), - qb.RZ[0.5](q[0]), - qb.RZ[0.3](q[0]), + qb.RZ(rad(0.5), q[0]), + qb.RZ(rad(0.3), q[0]), ) ast = slr_to_ast(prog) @@ -92,7 +92,7 @@ def test_level_3_identity_removal(self) -> None: """Level 3 adds identity removal.""" prog = Main( q := QReg("q", 1), - qb.RZ[0](q[0]), + qb.RZ(rad(0), q[0]), ) ast = slr_to_ast(prog) @@ -137,8 +137,8 @@ def test_pipeline_fixed_point(self) -> None: prog = Main( q := QReg("q", 1), - qb.RZ[0.5](q[0]), - qb.RZ[-0.5](q[0]), + qb.RZ(rad(0.5), q[0]), + qb.RZ(rad(-0.5), q[0]), ) ast = slr_to_ast(prog) @@ -206,13 +206,13 @@ def test_default_pipeline_all_optimizations(self) -> None: # Circuit with multiple optimization opportunities prog = Main( q := QReg("q", 1), - qb.RZ[0](q[0]), # Identity removal + qb.RZ(rad(0), q[0]), # Identity removal qb.X(q[0]), # Gate cancellation qb.X(q[0]), qb.SZ(q[0]), # Inverse cancellation qb.SZdg(q[0]), - qb.RZ[0.5](q[0]), # Rotation merging - qb.RZ[0.3](q[0]), + qb.RZ(rad(0.5), q[0]), # Rotation merging + qb.RZ(rad(0.3), q[0]), ) ast = slr_to_ast(prog) @@ -224,7 +224,7 @@ def test_default_pipeline_all_optimizations(self) -> None: assert isinstance(gate, GateOp) assert gate.gate == GateKind.RZ assert isinstance(gate.params[0], LiteralExpr) - assert abs(gate.params[0].value - 0.8) < 1e-10 + assert abs(gate.params[0].value.value.to_radians_signed() - 0.8) < 1e-10 class TestPipelinePassTracking: @@ -257,8 +257,8 @@ def test_total_optimizations(self) -> None: q := QReg("q", 1), qb.X(q[0]), qb.X(q[0]), - qb.RZ[0.5](q[0]), - qb.RZ[0.3](q[0]), + qb.RZ(rad(0.5), q[0]), + qb.RZ(rad(0.3), q[0]), ) ast = slr_to_ast(prog) @@ -305,8 +305,8 @@ def test_rotation_to_identity_chain(self) -> None: """Chain of rotations that sum to identity.""" prog = Main( q := QReg("q", 1), - qb.RZ[math.pi](q[0]), - qb.RZ[math.pi](q[0]), # Sum is 2*pi = identity + qb.RZ(rad(math.pi), q[0]), + qb.RZ(rad(math.pi), q[0]), # Sum is 2*pi = identity ) ast = slr_to_ast(prog) diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/optimizations/test_rotation_merging.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/optimizations/test_rotation_merging.py index 0a34abfd3..57bc855c0 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/optimizations/test_rotation_merging.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/optimizations/test_rotation_merging.py @@ -13,7 +13,7 @@ import math -from pecos.slr import CReg, If, Main, QReg, Repeat +from pecos.slr import CReg, If, Main, QReg, Repeat, rad from pecos.slr.ast import slr_to_ast from pecos.slr.ast.nodes import GateKind, GateOp, LiteralExpr from pecos.slr.ast.optimizations import RotationMergingPass @@ -27,8 +27,8 @@ def test_rz_rz_merges(self) -> None: """RZ+RZ on same qubit merges.""" prog = Main( q := QReg("q", 1), - qb.RZ[0.5](q[0]), - qb.RZ[0.3](q[0]), + qb.RZ(rad(0.5), q[0]), + qb.RZ(rad(0.3), q[0]), ) ast = slr_to_ast(prog) @@ -43,14 +43,14 @@ def test_rz_rz_merges(self) -> None: assert gate.gate == GateKind.RZ assert len(gate.params) == 1 assert isinstance(gate.params[0], LiteralExpr) - assert abs(gate.params[0].value - 0.8) < 1e-10 + assert abs(gate.params[0].value.value.to_radians_signed() - 0.8) < 1e-10 def test_rx_rx_merges(self) -> None: """RX+RX on same qubit merges.""" prog = Main( q := QReg("q", 1), - qb.RX[math.pi / 4](q[0]), - qb.RX[math.pi / 4](q[0]), + qb.RX(rad(math.pi / 4), q[0]), + qb.RX(rad(math.pi / 4), q[0]), ) ast = slr_to_ast(prog) @@ -63,14 +63,14 @@ def test_rx_rx_merges(self) -> None: assert isinstance(gate, GateOp) assert gate.gate == GateKind.RX assert isinstance(gate.params[0], LiteralExpr) - assert abs(gate.params[0].value - math.pi / 2) < 1e-10 + assert abs(gate.params[0].value.value.to_radians_signed() - math.pi / 2) < 1e-10 def test_ry_ry_merges(self) -> None: """RY+RY on same qubit merges.""" prog = Main( q := QReg("q", 1), - qb.RY[0.1](q[0]), - qb.RY[0.2](q[0]), + qb.RY(rad(0.1), q[0]), + qb.RY(rad(0.2), q[0]), ) ast = slr_to_ast(prog) @@ -83,7 +83,7 @@ def test_ry_ry_merges(self) -> None: assert isinstance(gate, GateOp) assert gate.gate == GateKind.RY assert isinstance(gate.params[0], LiteralExpr) - assert abs(gate.params[0].value - 0.3) < 1e-10 + assert abs(gate.params[0].value.value.to_radians_signed() - 0.3) < 1e-10 class TestRotationMergingNoMerge: @@ -93,8 +93,8 @@ def test_different_rotation_types_no_merge(self) -> None: """Different rotation types do not merge.""" prog = Main( q := QReg("q", 1), - qb.RX[0.5](q[0]), - qb.RZ[0.3](q[0]), + qb.RX(rad(0.5), q[0]), + qb.RZ(rad(0.3), q[0]), ) ast = slr_to_ast(prog) @@ -107,8 +107,8 @@ def test_different_qubits_no_merge(self) -> None: """Rotations on different qubits do not merge.""" prog = Main( q := QReg("q", 2), - qb.RZ[0.5](q[0]), - qb.RZ[0.3](q[1]), + qb.RZ(rad(0.5), q[0]), + qb.RZ(rad(0.3), q[1]), ) ast = slr_to_ast(prog) @@ -121,9 +121,9 @@ def test_interleaved_rotations_no_merge(self) -> None: """Interleaved rotations do not merge.""" prog = Main( q := QReg("q", 1), - qb.RZ[0.5](q[0]), + qb.RZ(rad(0.5), q[0]), qb.H(q[0]), # Separates the RZ gates - qb.RZ[0.3](q[0]), + qb.RZ(rad(0.3), q[0]), ) ast = slr_to_ast(prog) @@ -142,8 +142,8 @@ def test_merge_inside_if(self) -> None: q := QReg("q", 1), c := CReg("c", 1), If(c[0] == 1).Then( - qb.RZ[0.5](q[0]), - qb.RZ[0.3](q[0]), + qb.RZ(rad(0.5), q[0]), + qb.RZ(rad(0.3), q[0]), ), ) @@ -158,8 +158,8 @@ def test_merge_inside_repeat(self) -> None: prog = Main( q := QReg("q", 1), Repeat(cond=3).block( - qb.RX[0.1](q[0]), - qb.RX[0.2](q[0]), + qb.RX(rad(0.1), q[0]), + qb.RX(rad(0.2), q[0]), ), ) @@ -177,9 +177,9 @@ def test_three_rotations_merge_to_one(self) -> None: """Three consecutive rotations merge to one (requires multiple passes).""" prog = Main( q := QReg("q", 1), - qb.RZ[0.1](q[0]), - qb.RZ[0.2](q[0]), - qb.RZ[0.3](q[0]), + qb.RZ(rad(0.1), q[0]), + qb.RZ(rad(0.2), q[0]), + qb.RZ(rad(0.3), q[0]), ) ast = slr_to_ast(prog) @@ -198,4 +198,4 @@ def test_three_rotations_merge_to_one(self) -> None: gate = result2.program.body[0] assert isinstance(gate, GateOp) assert isinstance(gate.params[0], LiteralExpr) - assert abs(gate.params[0].value - 0.6) < 1e-10 + assert abs(gate.params[0].value.value.to_radians_signed() - 0.6) < 1e-10 diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_codegen_guppy.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_codegen_guppy.py index 4c9df090d..9cf53c5af 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_codegen_guppy.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_codegen_guppy.py @@ -12,9 +12,10 @@ """Tests for AST to Guppy code generator.""" import pytest -from pecos.slr import CReg, If, Main, QReg, Repeat +from pecos.slr import CReg, If, Main, QReg, Repeat, Return from pecos.slr.ast import AstToGuppy, ast_to_guppy, slr_to_ast from pecos.slr.qeclib import qubit as qb +from pecos.slr.qeclib.steane.preps.encoding_circ import EncodingCircuit class TestAstToGuppyBasic: @@ -28,7 +29,8 @@ def test_empty_program(self) -> None: code = ast_to_guppy(ast) assert "from guppylang import guppy" in code - assert "from guppylang.std import quantum" in code + assert "from guppylang.std.builtins import array, owned" in code + assert "from guppylang.std.quantum import discard, measure, qubit" in code assert "@guppy" in code assert "def main" in code @@ -71,9 +73,9 @@ def test_single_qubit_gate(self) -> None: code = ast_to_guppy(ast) - # Should generate gate with reassignment for linearity - assert "quantum.h" in code - assert "q[0] = quantum.h(q[0])" in code + assert "q_0, = q" in code + assert "q_0 = h(q_0)" in code + assert "discard(q_0)" in code def test_two_qubit_gate(self) -> None: """Two-qubit gate generates tuple assignment.""" @@ -85,9 +87,8 @@ def test_two_qubit_gate(self) -> None: code = ast_to_guppy(ast) - # Two-qubit gates return tuple - assert "quantum.cx" in code - assert "q[0], q[1] = quantum.cx" in code + assert "q_0, q_1 = q" in code + assert "q_0, q_1 = cx(q_0, q_1)" in code def test_multiple_gates(self) -> None: """Multiple gates generate correct sequence.""" @@ -101,31 +102,30 @@ def test_multiple_gates(self) -> None: code = ast_to_guppy(ast) - assert "quantum.h" in code - assert "quantum.x" in code - assert "quantum.cz" in code + assert "q_0 = h(q_0)" in code + assert "q_1 = x(q_1)" in code + assert "q_0, q_1 = cz(q_0, q_1)" in code class TestAstToGuppyPrepMeasure: - """Prep and measure code generation tests.""" + """PZ and measure code generation tests.""" def test_measure_with_result(self) -> None: - """Measure with result generates variable and return.""" + """Measure with explicit Return generates variable and return.""" prog = Main( q := QReg("q", 1), c := CReg("c", 1), qb.Measure(q[0]) > c[0], + Return(c), ) ast = slr_to_ast(prog) code = ast_to_guppy(ast) - assert "quantum.measure" in code - # Measurement results use local variable names (c_0 instead of c[0]) - assert "c_0 = quantum.measure(q[0])" in code - # Return type should be bool since all qubits are measured - assert "-> bool:" in code - assert "return c_0" in code + assert "c = array(False)" in code + assert "c[0] = measure(q_0)" in code + assert "-> array[bool, 1]:" in code + assert "return c" in code class TestAstToGuppyControlFlow: @@ -145,7 +145,7 @@ def test_if_statement(self) -> None: code = ast_to_guppy(ast) assert "if" in code - assert "quantum.h" in code + assert "q_0 = h(q_0)" in code def test_if_else_statement(self) -> None: """If-else statement generates both branches.""" @@ -166,8 +166,8 @@ def test_if_else_statement(self) -> None: assert "if" in code assert "else:" in code - assert "quantum.h" in code - assert "quantum.x" in code + assert "q_0 = h(q_0)" in code + assert "q_0 = x(q_0)" in code def test_repeat_statement(self) -> None: """Repeat statement generates for-range loop.""" @@ -183,7 +183,7 @@ def test_repeat_statement(self) -> None: # Repeat becomes for _ in range(n) assert "for _ in range(3):" in code - assert "quantum.h" in code + assert "q_0 = h(q_0)" in code class TestAstToGuppyQEC: @@ -208,8 +208,27 @@ def test_syndrome_extraction(self) -> None: assert "ancilla: array[qubit, 1]" in code # Check operations - assert "quantum.cx" in code - assert "quantum.measure" in code + assert "data_0, ancilla_0 = cx(data_0, ancilla_0)" in code + assert "data_1, ancilla_0 = cx(data_1, ancilla_0)" in code + assert "c[0] = measure(ancilla_0)" in code + + def test_qeclib_block_internal_return_does_not_leak_as_main_return(self) -> None: + """S5/M2 provenance guard. + + A qeclib composite block's internal `Return` is a flattened + block-boundary handoff, NOT the Main return -- it is elided at + convert time. `EncodingCircuit` emits a single final root + `ReturnOp(values=('q',))`; a position/count detector would wrongly + make it `return q`. Post-S5 it must be `main(...) -> None` with no + return line. + """ + prog = Main( + q := QReg("q", 7), + EncodingCircuit(q), + ) + code = ast_to_guppy(slr_to_ast(prog)) + assert "-> None:" in code + assert "\n return " not in code class TestAstToGuppyGenerator: @@ -235,8 +254,8 @@ def test_generator_reusable(self) -> None: code1 = "\n".join(generator.generate(ast1)) code2 = "\n".join(generator.generate(ast2)) - assert "q[0]" in code1 - assert "r[0]" in code2 + assert "q_0 = h(q_0)" in code1 + assert "r_0 = x(r_0)" in code2 def test_indentation(self) -> None: """Generated code has proper indentation for nested blocks.""" @@ -291,10 +310,10 @@ def test_full_pipeline(self) -> None: assert "from guppylang import guppy" in code assert "@guppy" in code assert "def main" in code - assert "quantum.h" in code - assert "quantum.cx" in code + assert "q_0 = h(q_0)" in code + assert "q_0, q_1 = cx(q_0, q_1)" in code assert "if" in code - assert "quantum.x" in code + assert "q_2 = x(q_2)" in code def test_bell_state_circuit(self) -> None: """Test a simple Bell state circuit.""" @@ -318,5 +337,5 @@ def test_bell_state_circuit(self) -> None: assert any("def main" in line for line in lines) # Check gates are in function body (indented) - gate_lines = [line for line in lines if "quantum." in line] + gate_lines = [line for line in lines if " = h(" in line or " = cx(" in line] assert all(line.startswith(" ") for line in gate_lines) diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_codegen_qasm.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_codegen_qasm.py index c6499f7f6..e4694f05b 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_codegen_qasm.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_codegen_qasm.py @@ -141,7 +141,7 @@ def test_phase_gates(self) -> None: class TestAstToQasmPrepMeasure: - """Prep and measure code generation tests.""" + """PZ and measure code generation tests.""" def test_measure_with_result(self) -> None: """Measure with result generates arrow syntax.""" @@ -157,10 +157,10 @@ def test_measure_with_result(self) -> None: assert "measure q[0] -> c[0];" in code def test_prep_reset(self) -> None: - """Prep generates reset operation.""" + """PZ generates reset operation.""" prog = Main( q := QReg("q", 1), - qb.Prep(q[0]), + qb.PZ(q[0]), ) ast = slr_to_ast(prog) diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_codegen_qir.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_codegen_qir.py index fae4f5305..6cf6fd09d 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_codegen_qir.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_codegen_qir.py @@ -146,7 +146,7 @@ def test_two_qubit_cz_gate(self) -> None: class TestAstToQirPrepMeasure: - """Prep and measure code generation tests.""" + """PZ and measure code generation tests.""" def test_measurement(self) -> None: """Measurement generates mz_to_creg_bit call.""" @@ -163,10 +163,10 @@ def test_measurement(self) -> None: assert "mz_to_creg_bit" in llvm_ir def test_prep_reset(self) -> None: - """Prep generates reset_body call.""" + """PZ generates reset_body call.""" prog = Main( q := QReg("q", 1), - qb.Prep(q[0]), + qb.PZ(q[0]), ) ast = slr_to_ast(prog) @@ -232,7 +232,7 @@ def test_results_output(self) -> None: """Result CReg generates int_record_output call.""" prog = Main( q := QReg("q", 1), - c := CReg("c", 1, result=True), + c := CReg("c", 1), qb.Measure(q[0]) > c[0], ) ast = slr_to_ast(prog) diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_codegen_quantum_circuit.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_codegen_quantum_circuit.py index 402dda193..f1f1b1576 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_codegen_quantum_circuit.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_codegen_quantum_circuit.py @@ -148,7 +148,7 @@ def test_multiple_gates_different_qubits(self) -> None: class TestAstToQuantumCircuitPrepMeasure: - """Prep and measure code generation tests.""" + """PZ and measure code generation tests.""" def test_measurement(self) -> None: """Measurement creates tick with Measure operation.""" @@ -167,10 +167,10 @@ def test_measurement(self) -> None: assert 0 in tick["Measure"] def test_prep_reset(self) -> None: - """Prep creates tick with RESET operation.""" + """PZ creates tick with RESET operation.""" prog = Main( q := QReg("q", 1), - qb.Prep(q[0]), + qb.PZ(q[0]), ) ast = slr_to_ast(prog) @@ -390,7 +390,7 @@ def test_circuit_with_repeated_syndrome(self) -> None: qb.CX(data[0], ancilla[0]), qb.CX(data[1], ancilla[0]), qb.Measure(ancilla[0]) > c[0], - qb.Prep(ancilla[0]), + qb.PZ(ancilla[0]), ), ) diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_codegen_stim.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_codegen_stim.py index 24970736c..14c42dd8c 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_codegen_stim.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_codegen_stim.py @@ -168,7 +168,7 @@ def test_multiple_gates(self) -> None: class TestAstToStimPrepMeasure: - """Prep and measure code generation tests.""" + """PZ and measure code generation tests.""" def test_measurement(self) -> None: """Measurement generates M instruction.""" @@ -202,10 +202,10 @@ def test_multiple_measurements(self) -> None: assert "1" in code def test_prep_reset(self) -> None: - """Prep generates R (reset) instruction.""" + """PZ generates R (reset) instruction.""" prog = Main( q := QReg("q", 1), - qb.Prep(q[0]), + qb.PZ(q[0]), ) ast = slr_to_ast(prog) @@ -299,7 +299,7 @@ def test_repeated_syndrome_extraction(self) -> None: qb.CX(data[0], ancilla[0]), qb.CX(data[1], ancilla[0]), qb.Measure(ancilla[0]) > c[0], - qb.Prep(ancilla[0]), + qb.PZ(ancilla[0]), ), ) ast = slr_to_ast(prog) diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_converter.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_converter.py index 948164a1d..62e0ab209 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_converter.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_converter.py @@ -71,7 +71,6 @@ def test_program_with_creg(self) -> None: assert isinstance(decl, RegisterDecl) assert decl.name == "c" assert decl.size == 3 - assert decl.is_result is True def test_program_with_both_regs(self) -> None: """Program with QReg and CReg has both declarations.""" @@ -143,13 +142,13 @@ def test_multiple_gates(self) -> None: class TestSlrToAstPrepMeasure: - """Prep and Measure conversion tests.""" + """PZ and Measure conversion tests.""" def test_prep_operation(self) -> None: - """Prep converts to PrepareOp with correct allocator and slots.""" + """PZ converts to PrepareOp with correct allocator and slots.""" prog = Main( q := QReg("q", 2), - qb.Prep(q[0]), + qb.PZ(q[0]), ) ast = slr_to_ast(prog) @@ -299,9 +298,9 @@ def test_syndrome_extraction_pattern(self) -> None: ancilla := QReg("ancilla", 1), c := CReg("c", 1), # Initialize - qb.Prep(data[0]), - qb.Prep(data[1]), - qb.Prep(ancilla[0]), + qb.PZ(data[0]), + qb.PZ(data[1]), + qb.PZ(ancilla[0]), # Syndrome extraction qb.CX(data[0], ancilla[0]), qb.CX(data[1], ancilla[0]), @@ -331,9 +330,9 @@ def test_round_trip_preserves_structure(self) -> None: prog = Main( q := QReg("q", 3), c := CReg("c", 1), - qb.Prep(q[0]), - qb.Prep(q[1]), - qb.Prep(q[2]), + qb.PZ(q[0]), + qb.PZ(q[1]), + qb.PZ(q[2]), qb.H(q[0]), qb.CX(q[0], q[1]), qb.CX(q[1], q[2]), diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_nodes.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_nodes.py index 2a9644c78..3a71e0c81 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_nodes.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_nodes.py @@ -350,16 +350,12 @@ def test_allocator_decl_with_parent(self) -> None: assert decl.parent == "base" def test_register_decl(self) -> None: - """RegisterDecl stores name and size.""" + """RegisterDecl stores name and size (no is_result field post-3b).""" decl = RegisterDecl(name="c", size=5) assert decl.name == "c" assert decl.size == 5 - assert decl.is_result is True - - def test_register_decl_not_result(self) -> None: - """RegisterDecl can mark register as not a result.""" - decl = RegisterDecl(name="scratch", size=3, is_result=False) - assert decl.is_result is False + with pytest.raises(AttributeError): + _ = decl.is_result class TestProgram: diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_permute.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_permute.py index e48e86d4b..ecf21f25f 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_permute.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_permute.py @@ -105,8 +105,26 @@ def test_permute_guppy_codegen(self) -> None: ast = slr_to_ast(prog) guppy = generate(ast, "guppy") - # Should contain swap code - assert "Swap" in guppy or "_temp_" in guppy or "a, b = b, a" in guppy + # Qubits are remapped logically in the Guppy slot state. + assert "# Permute: a -> b, b -> a" in guppy + assert "b_0 = x(b_0)" in guppy + + def test_creg_permute_guppy_uses_mem_swap(self) -> None: + """Test CReg Permute uses Guppy's in-place swap helper.""" + prog = Main( + c := CReg("c", 2), + d := CReg("d", 2), + c[0].set(1), + d[1].set(1), + Permute(c, d), + ) + ast = slr_to_ast(prog) + + guppy = generate(ast, "guppy") + + assert "from guppylang.std.mem import mem_swap" in guppy + assert "mem_swap(c[0], d[0])" in guppy + assert "mem_swap(c[1], d[1])" in guppy def test_permute_qasm_codegen(self) -> None: """Test Permute generates QASM comment.""" diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_roundtrip.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_roundtrip.py index d4faf033a..b43a533e5 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_roundtrip.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_roundtrip.py @@ -137,7 +137,7 @@ def test_qec_syndrome_qasm_structure(self) -> None: data := QReg("data", 2), ancilla := QReg("ancilla", 1), c := CReg("c", 1), - qb.Prep(ancilla[0]), + qb.PZ(ancilla[0]), qb.CX(data[0], ancilla[0]), qb.CX(data[1], ancilla[0]), qb.Measure(ancilla[0]) > c[0], @@ -202,11 +202,11 @@ def test_bell_state_guppy_structure(self) -> None: # Verify Guppy imports and structure assert "from guppylang import guppy" in guppy - assert "from guppylang.std import quantum" in guppy + assert "from guppylang.std.builtins import array, owned" in guppy assert "@guppy" in guppy assert "def main" in guppy.lower() - assert "quantum.h" in guppy - assert "quantum.cx" in guppy + assert "q_0 = h(q_0)" in guppy + assert "q_0, q_1 = cx(q_0, q_1)" in guppy def test_measurement_guppy_structure(self) -> None: """Test measurement generates correct Guppy structure.""" @@ -219,7 +219,7 @@ def test_measurement_guppy_structure(self) -> None: ast = slr_to_ast(prog) guppy = ast_to_guppy(ast) - assert "quantum.measure" in guppy + assert "c[0] = measure(q_0)" in guppy class TestRoundTripStim: @@ -439,8 +439,9 @@ def test_same_qubit_order_all_generators(self) -> None: # Guppy guppy = ast_to_guppy(ast) - assert "a[0]" in guppy - assert "b[0]" in guppy + assert "a_0 = h(a_0)" in guppy + assert "b_0 = h(b_0)" in guppy + assert "a_0, b_0 = cx(a_0, b_0)" in guppy def test_gate_sequence_preserved_all_generators(self) -> None: """Test that gate sequence is preserved in all generators.""" @@ -462,9 +463,9 @@ def test_gate_sequence_preserved_all_generators(self) -> None: # Guppy - check order guppy = ast_to_guppy(ast) - h_pos = guppy.find("quantum.h") - x_pos = guppy.find("quantum.x") - z_pos = guppy.find("quantum.z") + h_pos = guppy.find("q_0 = h(q_0)") + x_pos = guppy.find("q_0 = x(q_0)") + z_pos = guppy.find("q_0 = z(q_0)") assert h_pos < x_pos < z_pos, "Gate order not preserved in Guppy" diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_serialize.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_serialize.py index 1cf2a427b..71e5111c0 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_serialize.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_serialize.py @@ -448,15 +448,17 @@ def test_empty_program(self) -> None: assert len(restored.declarations) == 0 assert len(restored.body) == 0 - def test_register_with_is_result_false(self) -> None: - """Test RegisterDecl with is_result=False.""" - decl = RegisterDecl(name="scratch", size=4, is_result=False) + def test_register_decl_roundtrip(self) -> None: + """RegisterDecl round-trips through dict (no is_result field post-3b).""" + decl = RegisterDecl(name="scratch", size=4) data = ast_to_dict(decl) restored = dict_to_ast(data) assert isinstance(restored, RegisterDecl) - assert restored.is_result is False + assert restored.name == "scratch" + assert restored.size == 4 + assert "is_result" not in data def test_allocator_without_parent(self) -> None: """Test AllocatorDecl without parent.""" diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_visitor.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_visitor.py index 0ac53494f..0c1191163 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_visitor.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_ast_visitor.py @@ -191,8 +191,8 @@ def visit_gate(self, node: GateOp) -> str: def visit_prepare(self, node: PrepareOp) -> str: if node.slots: slots = ", ".join(str(s) for s in node.slots) - return f"Prep({node.allocator}[{slots}])" - return f"Prep({node.allocator}[*])" + return f"PZ({node.allocator}[{slots}])" + return f"PZ({node.allocator}[*])" def visit_measure(self, node: MeasureOp) -> str: targets = ", ".join(str(t) for t in node.targets) @@ -209,7 +209,7 @@ def visit_measure(self, node: MeasureOp) -> str: prog = Program(name="test", body=(prep, gate1, gate2, measure)) code = visitor.visit(prog) - assert "Prep(q[0, 1])" in code + assert "PZ(q[0, 1])" in code assert "H(q[0])" in code assert "CX(q[0], q[1])" in code assert "Measure(q[0])" in code @@ -273,3 +273,71 @@ def visit_gate(self, node: GateOp) -> None: assert "H" in visitor.visited assert "X" in visitor.visited + + +class TestVisitorDispatchCompleteness: + """Safety net for the centralized `_DISPATCH` (replaced per-node + `accept()`): a new concrete AST node without a dispatch entry must + fail loudly here -- this is what catching a missing `accept()` + used to do implicitly. + """ + + @staticmethod + def _concrete_node_names() -> set[str]: + import pecos.slr.ast.nodes as nodes_mod + + def all_subclasses(cls: type) -> set[type]: + out: set[type] = set() + for sub in cls.__subclasses__(): + out.add(sub) + out |= all_subclasses(sub) + return out + + # Only nodes shipped in `pecos.slr.ast.nodes` -- user/test + # subclasses (e.g. `MyGate(GateOp)`) are intentionally resolved + # by MRO in BaseVisitor.visit and must NOT be required in + # _DISPATCH, so scope the enumeration to the nodes module. + nodes = {c for c in all_subclasses(nodes_mod.AstNode) if c.__module__ == nodes_mod.__name__} + # Intermediate/abstract bases (AstNode, Expression, Statement, + # TypeExpr, Declaration, BlockArg) are never instantiated directly + # and are correctly absent from _DISPATCH. + bases = {base for cls in nodes for base in cls.__bases__ if base in nodes or base is nodes_mod.AstNode} + return {cls.__name__ for cls in nodes if cls not in bases} + + def test_every_concrete_node_has_a_dispatch_entry(self) -> None: + from pecos.slr.ast.visitor import _DISPATCH + + missing = sorted(self._concrete_node_names() - set(_DISPATCH)) + assert ( + not missing + ), f"concrete AST nodes with no _DISPATCH entry: {missing} (add them to pecos.slr.ast.visitor._DISPATCH)" + + def test_no_stale_or_invalid_dispatch_entries(self) -> None: + from pecos.slr.ast.visitor import _DISPATCH, BaseVisitor + + concrete = self._concrete_node_names() + stale = sorted(set(_DISPATCH) - concrete) + assert not stale, f"_DISPATCH keys that are not concrete nodes: {stale}" + bad = sorted(v for v in _DISPATCH.values() if not callable(getattr(BaseVisitor, v, None))) + assert not bad, f"_DISPATCH values not methods on BaseVisitor: {bad}" + + def test_subclass_of_concrete_node_dispatches_via_mro(self) -> None: + """Visitor-refactor rationale: the old `node.accept(self)` + double-dispatch was inherited, so a user subclass of a concrete + node dispatched to the base node's `visit_*`. MRO lookup in + `BaseVisitor.visit` must preserve that (a bare class-name match + would wrongly raise). + """ + + class MyGate(GateOp): + pass + + class Recorder(BaseVisitor[str]): + def visit_gate(self, node: GateOp) -> str: + return f"gate:{node.gate.name}" + + def default_result(self) -> str: + return "" + + node = MyGate(gate=GateKind.H, targets=(SlotRef(allocator="q", index=0),)) + assert Recorder().visit(node) == "gate:H" diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_codegen_equivalence.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_codegen_equivalence.py index 03e32b2be..2af00ddb2 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_codegen_equivalence.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_codegen_equivalence.py @@ -15,11 +15,10 @@ """ import pytest -from pecos.slr import CReg, If, Main, Permute, QReg, Repeat +from pecos.slr import CReg, If, Main, Permute, QReg, Repeat, rad from pecos.slr.ast import slr_to_ast from pecos.slr.ast.codegen import generate from pecos.slr.gen_codes import ( - GuppyGenerator, QASMGenerator, QIRGenerator, QuantumCircuitGenerator, @@ -174,38 +173,6 @@ def test_clifford_gates_stim(self) -> None: assert direct_lines == ast_lines -class TestGuppyEquivalence: - """Compare Guppy output from direct SLR vs AST.""" - - def test_bell_state_guppy_structure(self) -> None: - """Test Bell state produces structurally similar Guppy.""" - prog = Main( - q := QReg("q", 2), - qb.H(q[0]), - qb.CX(q[0], q[1]), - ) - - # AST path first (before direct generator mutates prog) - ast = slr_to_ast(prog) - ast_guppy = generate(ast, "guppy") - - # Direct SLR path - gen = GuppyGenerator(_internal=True) - gen.generate_block(prog) - direct_guppy = gen.get_output() - - # Both should have key elements - assert "@guppy" in direct_guppy or "guppy" in direct_guppy.lower() - assert "@guppy" in ast_guppy or "guppy" in ast_guppy.lower() - - # Both should have H and CX gates - assert "quantum.h" in direct_guppy.lower() or ".h(" in direct_guppy.lower() - assert "quantum.h" in ast_guppy.lower() or ".h(" in ast_guppy.lower() - - assert "quantum.cx" in direct_guppy.lower() or ".cx(" in direct_guppy.lower() - assert "quantum.cx" in ast_guppy.lower() or ".cx(" in ast_guppy.lower() - - class TestQIREquivalence: """Compare QIR output from direct SLR vs AST.""" @@ -388,7 +355,7 @@ def test_rx_gate_qasm(self) -> None: prog = Main( q := QReg("q", 1), - qb.RX[math.pi / 4](q[0]), + qb.RX(rad(math.pi / 4), q[0]), ) # AST path first @@ -409,7 +376,7 @@ def test_ry_gate_qasm(self) -> None: prog = Main( q := QReg("q", 1), - qb.RY[math.pi / 2](q[0]), + qb.RY(rad(math.pi / 2), q[0]), ) ast = slr_to_ast(prog) @@ -428,7 +395,7 @@ def test_rz_gate_qasm(self) -> None: prog = Main( q := QReg("q", 1), - qb.RZ[math.pi](q[0]), + qb.RZ(rad(math.pi), q[0]), ) ast = slr_to_ast(prog) @@ -447,9 +414,9 @@ def test_multiple_rotations_qasm(self) -> None: prog = Main( q := QReg("q", 2), - qb.RX[math.pi / 4](q[0]), - qb.RY[math.pi / 2](q[1]), - qb.RZ[math.pi](q[0]), + qb.RX(rad(math.pi / 4), q[0]), + qb.RY(rad(math.pi / 2), q[1]), + qb.RZ(rad(math.pi), q[0]), ) ast = slr_to_ast(prog) @@ -699,9 +666,9 @@ def test_qft_like_circuit_qasm(self) -> None: prog = Main( q := QReg("q", 3), qb.H(q[0]), - qb.RZ[math.pi / 2](q[0]), + qb.RZ(rad(math.pi / 2), q[0]), qb.H(q[1]), - qb.RZ[math.pi / 4](q[1]), + qb.RZ(rad(math.pi / 4), q[1]), qb.H(q[2]), ) diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_integration.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_integration.py index 5f0d743f4..2b2829418 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_integration.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_integration.py @@ -23,7 +23,7 @@ import math import pytest -from pecos.slr import CReg, If, Main, QReg, Repeat +from pecos.slr import CReg, If, Main, QReg, Repeat, rad from pecos.slr.ast import slr_to_ast from pecos.slr.ast.codegen import ( CodegenOptions, @@ -160,9 +160,9 @@ def test_rotation_gates_roundtrip(self) -> None: """Rotation gates with float params round-trip.""" prog = Main( q := QReg("q", 1), - qb.RZ[0.5](q[0]), - qb.RX[math.pi / 4](q[0]), - qb.RY[1.234567890123](q[0]), + qb.RZ(rad(0.5), q[0]), + qb.RX(rad(math.pi / 4), q[0]), + qb.RY(rad(1.234567890123), q[0]), ) ast = slr_to_ast(prog) diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_pretty_print.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_pretty_print.py index 577cc90e5..7194a7c65 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_pretty_print.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_pretty_print.py @@ -13,7 +13,7 @@ import math -from pecos.slr import CReg, If, Main, QReg, Repeat +from pecos.slr import CReg, If, Main, QReg, Repeat, rad from pecos.slr.ast import slr_to_ast from pecos.slr.ast.nodes import ( AllocatorDecl, @@ -121,16 +121,16 @@ def test_rotation_gates(self) -> None: """Rotation gates with parameters print correctly.""" prog = Main( q := QReg("q", 1), - qb.RZ[0.5](q[0]), - qb.RX[math.pi](q[0]), + qb.RZ(rad(0.5), q[0]), + qb.RX(rad(math.pi), q[0]), ) ast = slr_to_ast(prog) result = pretty_print(ast) - assert "qb.RZ[0.5](q[0])" in result - # Pi value should be formatted - assert "qb.RX[" in result + assert "qb.RZ(rad(0.5), q[0])" in result + # Angle-first form `qb.RX(theta, q)`, angle rendered first. + assert "qb.RX(rad(3.141592653589793), q[0])" in result class TestPrettyPrintControlFlow: @@ -296,12 +296,12 @@ def test_rotation_gate_statement(self) -> None: stmt = GateOp( gate=GateKind.RZ, targets=(SlotRef(allocator="q", index=0),), - params=(LiteralExpr(value=0.25),), + params=(LiteralExpr(value=rad(0.25)),), ) result = format_statement(stmt) - assert result == "qb.RZ[0.25](q[0])" + assert result == "qb.RZ(rad(0.25), q[0])" class TestPrettyPrintIndentation: diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_serialize.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_serialize.py index 0f01a809b..0530b94cd 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/test_serialize.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/test_serialize.py @@ -15,7 +15,7 @@ import math import pytest -from pecos.slr import CReg, If, Main, QReg, Repeat +from pecos.slr import CReg, If, Main, QReg, Repeat, rad from pecos.slr.ast import slr_to_ast from pecos.slr.ast.nodes import ( AllocatorDecl, @@ -236,8 +236,8 @@ def test_rotation_gates(self) -> None: """Rotation gates with float params round-trip.""" prog = Main( q := QReg("q", 1), - qb.RZ[0.5](q[0]), - qb.RX[math.pi](q[0]), + qb.RZ(rad(0.5), q[0]), + qb.RX(rad(math.pi), q[0]), ) ast = slr_to_ast(prog) @@ -247,7 +247,8 @@ def test_rotation_gates(self) -> None: # Find RZ gate rz_gates = [s for s in restored.body if isinstance(s, GateOp) and s.gate == GateKind.RZ] assert len(rz_gates) == 1 - assert rz_gates[0].params[0].value == 0.5 + # The typed angle round-trips by exact fixed-point fraction + unit. + assert rz_gates[0].params[0].value == rad(0.5) def test_measurement(self) -> None: """Measurement with classical register round-trips.""" diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/validation/test_pipeline.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/validation/test_pipeline.py index cd8b884e7..4bd20b9a3 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/validation/test_pipeline.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/validation/test_pipeline.py @@ -13,7 +13,7 @@ import math -from pecos.slr import CReg, If, Main, QReg +from pecos.slr import CReg, If, Main, QReg, rad from pecos.slr.ast import slr_to_ast from pecos.slr.ast.nodes import ( AllocatorDecl, @@ -54,8 +54,8 @@ def test_validate_with_rotation(self) -> None: """Program with rotation gates passes.""" prog = Main( q := QReg("q", 1), - qb.RZ[0.5](q[0]), - qb.RX[math.pi](q[0]), + qb.RZ(rad(0.5), q[0]), + qb.RX(rad(math.pi), q[0]), ) ast = slr_to_ast(prog) @@ -71,7 +71,7 @@ def test_validate_complex_circuit(self) -> None: qb.H(q[0]), qb.CX(q[0], q[1]), qb.CX(q[1], q[2]), - qb.RZ[0.5](q[0]), + qb.RZ(rad(0.5), q[0]), If(c[0] == 1).Then( qb.X(q[0]), ), diff --git a/python/quantum-pecos/tests/pecos/slr/ast_tests/validation/test_type_checker.py b/python/quantum-pecos/tests/pecos/slr/ast_tests/validation/test_type_checker.py index d5fb45963..90a750b38 100644 --- a/python/quantum-pecos/tests/pecos/slr/ast_tests/validation/test_type_checker.py +++ b/python/quantum-pecos/tests/pecos/slr/ast_tests/validation/test_type_checker.py @@ -13,7 +13,7 @@ import math -from pecos.slr import CReg, If, Main, QReg, Repeat +from pecos.slr import CReg, If, Main, QReg, Repeat, rad from pecos.slr.ast import slr_to_ast from pecos.slr.ast.nodes import ( AllocatorDecl, @@ -53,8 +53,8 @@ def test_valid_rotation_gates(self) -> None: """Valid rotation gate with angle parameter.""" prog = Main( q := QReg("q", 1), - qb.RZ[0.5](q[0]), - qb.RX[math.pi](q[0]), + qb.RZ(rad(0.5), q[0]), + qb.RX(rad(math.pi), q[0]), ) ast = slr_to_ast(prog) diff --git a/python/quantum-pecos/tests/pecos/slr/test_qubit_state_validator.py b/python/quantum-pecos/tests/pecos/slr/test_qubit_state_validator.py deleted file mode 100644 index a606506d7..000000000 --- a/python/quantum-pecos/tests/pecos/slr/test_qubit_state_validator.py +++ /dev/null @@ -1,309 +0,0 @@ -# Copyright 2026 The PECOS Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. - -"""Tests for QubitStateValidator - compile-time detection of unprepared qubit usage.""" - -import pytest -from pecos.slr import CReg, If, Main, QReg -from pecos.slr.gen_codes.guppy.qubit_state_validator import ( - QubitStateValidator, - StateViolation, - validate_qubit_states, -) -from pecos.slr.qeclib import qubit as qb - - -class TestQubitStateValidatorBasic: - """Basic validation tests.""" - - def test_gate_without_prep_strict_mode(self) -> None: - """In strict mode, gate without prep is a violation.""" - prog = Main( - q := QReg("q", 2), - qb.H(q[0]), # No prep before H - should be violation - ) - - violations = validate_qubit_states(prog, strict=True) - - assert len(violations) == 1 - assert violations[0].array_name == "q" - assert violations[0].index == 0 - assert "H" in violations[0].gate_name - assert "unprepared" in violations[0].message.lower() - - def test_gate_without_prep_non_strict_mode(self) -> None: - """In non-strict mode (legacy), qubits start prepared.""" - prog = Main( - q := QReg("q", 2), - qb.H(q[0]), # No prep but non-strict mode - OK - ) - - variable_context = {"q": prog.vars.get("q")} - violations = validate_qubit_states(prog, variable_context, strict=False) - - assert len(violations) == 0 - - def test_prep_then_gate_is_valid(self) -> None: - """Prep followed by gate is valid.""" - prog = Main( - q := QReg("q", 2), - qb.Prep(q[0]), - qb.H(q[0]), # Prep before H - valid - ) - - violations = validate_qubit_states(prog, strict=True) - - assert len(violations) == 0 - - def test_measure_then_gate_is_violation(self) -> None: - """Gate after measurement without re-prep is violation.""" - prog = Main( - q := QReg("q", 2), - c := CReg("c", 2), - qb.Prep(q[0]), - qb.Measure(q[0]) > c[0], - qb.H(q[0]), # After measure, no re-prep - violation - ) - - violations = validate_qubit_states(prog, strict=True) - - assert len(violations) == 1 - assert violations[0].array_name == "q" - assert violations[0].index == 0 - - def test_measure_reprep_gate_is_valid(self) -> None: - """Measure, re-prep, then gate is valid.""" - prog = Main( - q := QReg("q", 2), - c := CReg("c", 2), - qb.Prep(q[0]), - qb.Measure(q[0]) > c[0], - qb.Prep(q[0]), # Re-prep after measure - qb.H(q[0]), # Now valid - ) - - violations = validate_qubit_states(prog, strict=True) - - assert len(violations) == 0 - - -class TestQubitStateValidatorMultiQubit: - """Tests with multiple qubits.""" - - def test_independent_qubit_tracking(self) -> None: - """Each qubit's state is tracked independently.""" - prog = Main( - q := QReg("q", 3), - qb.Prep(q[0]), - qb.Prep(q[1]), - # q[2] not prepared - qb.H(q[0]), # OK - qb.H(q[1]), # OK - qb.H(q[2]), # Violation - not prepared - ) - - violations = validate_qubit_states(prog, strict=True) - - assert len(violations) == 1 - assert violations[0].index == 2 - - def test_two_qubit_gate_both_prepared(self) -> None: - """Two-qubit gate requires both qubits prepared.""" - prog = Main( - q := QReg("q", 2), - qb.Prep(q[0]), - qb.Prep(q[1]), - qb.CX(q[0], q[1]), # Both prepared - valid - ) - - violations = validate_qubit_states(prog, strict=True) - - assert len(violations) == 0 - - def test_two_qubit_gate_one_unprepared(self) -> None: - """Two-qubit gate with one unprepared qubit is violation.""" - prog = Main( - q := QReg("q", 2), - qb.Prep(q[0]), - # q[1] not prepared - qb.CX(q[0], q[1]), # q[1] unprepared - violation - ) - - violations = validate_qubit_states(prog, strict=True) - - assert len(violations) == 1 - assert violations[0].index == 1 - - def test_two_qubit_gate_both_unprepared(self) -> None: - """Two-qubit gate with both unprepared is two violations.""" - prog = Main( - q := QReg("q", 2), - qb.CX(q[0], q[1]), # Both unprepared - two violations - ) - - violations = validate_qubit_states(prog, strict=True) - - assert len(violations) == 2 - - -class TestQubitStateValidatorConditionals: - """Tests with conditional blocks.""" - - def test_if_block_both_branches_prepare(self) -> None: - """If both branches prepare, qubit is prepared after.""" - prog = Main( - q := QReg("q", 1), - c := CReg("c", 1), - If(c[0] == 1) - .Then( - qb.Prep(q[0]), - ) - .Else( - qb.Prep(q[0]), - ), - qb.H(q[0]), # Prepared in both branches - valid - ) - - variable_context = {"q": prog.vars.get("q"), "c": prog.vars.get("c")} - violations = validate_qubit_states(prog, variable_context, strict=True) - - assert len(violations) == 0 - - def test_if_block_only_then_prepares(self) -> None: - """If only then branch prepares, qubit may be unprepared after.""" - prog = Main( - q := QReg("q", 1), - c := CReg("c", 1), - If(c[0] == 1).Then( - qb.Prep(q[0]), - ), - # No else - q[0] may not be prepared - qb.H(q[0]), # May be unprepared - violation - ) - - variable_context = {"q": prog.vars.get("q"), "c": prog.vars.get("c")} - violations = validate_qubit_states(prog, variable_context, strict=True) - - # Should detect violation - qubit not prepared in else branch - assert len(violations) >= 1 - - -class TestQubitStateValidatorQECPattern: - """Tests for typical QEC patterns.""" - - def test_syndrome_extraction_pattern(self) -> None: - """Typical syndrome extraction: prep, use, measure, re-prep cycle.""" - prog = Main( - data := QReg("data", 2), - ancilla := QReg("ancilla", 1), - c := CReg("c", 1), - # Initialize data qubits - qb.Prep(data[0]), - qb.Prep(data[1]), - # Syndrome extraction round 1 - qb.Prep(ancilla[0]), - qb.CX(data[0], ancilla[0]), - qb.CX(data[1], ancilla[0]), - qb.Measure(ancilla[0]) > c[0], - # Syndrome extraction round 2 - qb.Prep(ancilla[0]), # Re-prep after measure - qb.CX(data[0], ancilla[0]), - qb.CX(data[1], ancilla[0]), - qb.Measure(ancilla[0]) > c[0], - ) - - violations = validate_qubit_states(prog, strict=True) - - assert len(violations) == 0 - - def test_missing_reprep_in_qec_cycle(self) -> None: - """Detect missing re-prep in QEC cycle.""" - prog = Main( - data := QReg("data", 1), - ancilla := QReg("ancilla", 1), - c := CReg("c", 1), - # Initialize - qb.Prep(data[0]), - qb.Prep(ancilla[0]), - # Round 1 - qb.CX(data[0], ancilla[0]), - qb.Measure(ancilla[0]) > c[0], - # Round 2 - MISSING re-prep of ancilla - qb.CX(data[0], ancilla[0]), # ancilla[0] is unprepared - violation - ) - - violations = validate_qubit_states(prog, strict=True) - - assert len(violations) == 1 - assert violations[0].array_name == "ancilla" - assert violations[0].index == 0 - - -class TestStateViolation: - """Tests for StateViolation dataclass.""" - - def test_string_representation(self) -> None: - """StateViolation has readable string representation.""" - violation = StateViolation( - array_name="q", - index=2, - position=5, - gate_name="H", - message="Test message", - ) - - s = str(violation) - assert "q[2]" in s - assert "position 5" in s - assert "Test message" in s - - -class TestValidatorClass: - """Tests for QubitStateValidator class directly.""" - - def test_validator_reusable(self) -> None: - """Validator can be reused for multiple blocks.""" - validator = QubitStateValidator(strict=True) - - prog1 = Main( - q := QReg("q", 1), - qb.H(q[0]), - ) - - prog2 = Main( - q := QReg("q", 1), - qb.Prep(q[0]), - qb.H(q[0]), - ) - - violations1 = validator.validate(prog1) - violations2 = validator.validate(prog2) - - assert len(violations1) == 1 - assert len(violations2) == 0 - - def test_strict_flag(self) -> None: - """Strict flag controls initial state assumption.""" - prog = Main( - q := QReg("q", 1), - qb.H(q[0]), - ) - - strict_validator = QubitStateValidator(strict=True) - non_strict_validator = QubitStateValidator(strict=False) - - variable_context = {"q": prog.vars.get("q")} - - strict_violations = strict_validator.validate(prog) - non_strict_violations = non_strict_validator.validate(prog, variable_context) - - assert len(strict_violations) == 1 - assert len(non_strict_violations) == 0 diff --git a/python/quantum-pecos/tests/pecos/unit/test_slr_converter_guppy.py b/python/quantum-pecos/tests/pecos/unit/test_slr_converter_guppy.py index 3da0e0ed3..8a4aa247c 100644 --- a/python/quantum-pecos/tests/pecos/unit/test_slr_converter_guppy.py +++ b/python/quantum-pecos/tests/pecos/unit/test_slr_converter_guppy.py @@ -9,33 +9,16 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. -"""Tests for SlrConverter Guppy functionality.""" - -from pecos.slr import CReg, Main, Parallel, QReg, SlrConverter -from pecos.slr.qeclib import qubit as qb -from pecos.slr.qeclib.steane.steane_class import Steane +"""Tests for SlrConverter Guppy functionality. +The v1 AST -> Guppy emitter is exercised via compile-and-run acceptance tests +under ``tests/slr_tests/ast_guppy/``. Tests here cover basic structural sanity +of `SlrConverter.hugr()` (AST-routed post-cutover, wraps `main` in a no-arg +`entry()` and compiles that). +""" -def test_slr_converter_guppy_simple() -> None: - """Test SlrConverter.guppy() with a simple program.""" - prog = Main( - q := QReg("q", 2), - c := CReg("c", 2), - qb.H(q[0]), - qb.CX(q[0], q[1]), - qb.Measure(q) > c, - ) - - guppy_code = SlrConverter(prog).guppy() - - # Check that the generated code is valid Python - # AST codegen uses simplified imports - assert "from guppylang import guppy" in guppy_code - assert "@guppy" in guppy_code - # AST codegen uses array parameters - assert "def main(q:" in guppy_code - assert "quantum.h(" in guppy_code - assert "quantum.cx(" in guppy_code +from pecos.slr import CReg, Main, QReg, SlrConverter +from pecos.slr.qeclib import qubit as qb def test_slr_converter_guppy_does_not_have_undefined_variables() -> None: @@ -99,88 +82,6 @@ def test_slr_converter_hugr_simple() -> None: assert hasattr(hugr, "modules") -def test_slr_converter_steane_guppy_generation() -> None: - """Test that Steane code can generate Guppy code without undefined variables.""" - prog = Main( - c := Steane("c"), - c.px(), # Simple Pauli-X operation - ) - - # This should generate valid Guppy code without undefined variables - guppy_code = SlrConverter(prog).guppy() - - # AST codegen uses array parameters instead of local declarations - # Check that c_a is declared as parameter or that c_a[i] is properly accessed - # Either c_a: array[...] @owned in params, or c_a = array(...) in body - c_a_in_params = "c_a: array[qubit" in guppy_code - c_a_declared = "c_a =" in guppy_code or "c_a=" in guppy_code - - # If c_a appears in the code, it should be in params or declared - assert ("c_a" not in guppy_code) or c_a_in_params or c_a_declared - - # Code should have quantum operations - assert "quantum." in guppy_code - - -def test_slr_converter_steane_hugr_compilation() -> None: - """Test that Steane code can compile to HUGR. - - This test verifies that the Steane code implementation can be successfully - compiled to HUGR format through guppylang. The test ensures that: - - 1. Ancilla arrays (like c_a) are properly detected and excluded from structs - 2. These arrays are passed to functions with @owned annotation - 3. Arrays are unpacked to individual variables to avoid MoveOutOfSubscriptError - 4. The unpacked variables are used instead of array indexing in function bodies - - The solution works by: - - Detecting ancilla qubits based on usage patterns (frequent measurement/reset) - - Excluding them from struct packing to keep them as separate arrays - - Unpacking @owned ancilla arrays at the start of functions - - Using the unpacked variables (e.g., c_a_0) instead of array access (c_a[0]) - - Note: The guppy code generation itself works correctly, but the final - compilation to HUGR fails due to API mismatch between guppylang-internals - (expecting hugr.build module) and hugr 0.13.0 (which doesn't have it). - """ - prog = Main( - c := Steane("c"), - c.px(), - ) - - # This should work once guppylang supports the required patterns - hugr = SlrConverter(prog).hugr() - assert hugr is not None - assert hasattr(hugr, "modules") - - -def test_slr_converter_parallel_blocks_guppy() -> None: - """Test Guppy generation with parallel blocks.""" - prog = Main( - q := QReg("q", 4), - c := CReg("c", 4), - Parallel( - qb.H(q[0]), - qb.X(q[1]), - qb.H(q[2]), - qb.X(q[3]), - ), - qb.Measure(q) > c, - ) - - guppy_code = SlrConverter(prog).guppy() - - # Should contain the gates - assert "quantum.h(" in guppy_code - assert "quantum.x(" in guppy_code - assert "quantum.measure" in guppy_code - - # Should not have undefined variables - undefined_vars = ["c_a", "c_a_0"] - for var in undefined_vars: - assert var not in guppy_code, f"Generated code contains undefined variable: {var}" - - def test_slr_converter_guppy_has_main_function() -> None: """Test that generated Guppy code has a proper main function.""" prog = Main( @@ -195,25 +96,3 @@ def test_slr_converter_guppy_has_main_function() -> None: # Should have main function with array parameters assert "def main(" in guppy_code assert "@guppy" in guppy_code - - -def test_slr_converter_guppy_imports() -> None: - """Test that generated Guppy code has correct imports.""" - prog = Main( - q := QReg("q", 1), - c := CReg("c", 1), - qb.H(q[0]), - qb.Measure(q) > c, - ) - - guppy_code = SlrConverter(prog).guppy() - - # AST codegen uses simplified imports - required_imports = [ - "from guppylang import guppy", - "from guppylang.std import quantum", - "from guppylang.std.quantum import qubit", - ] - - for imp in required_imports: - assert imp in guppy_code, f"Missing import: {imp}" diff --git a/python/quantum-pecos/tests/qec/surface/test_circuit_fuzz.py b/python/quantum-pecos/tests/qec/surface/test_circuit_fuzz.py index d368f82d9..976829c83 100644 --- a/python/quantum-pecos/tests/qec/surface/test_circuit_fuzz.py +++ b/python/quantum-pecos/tests/qec/surface/test_circuit_fuzz.py @@ -492,9 +492,9 @@ def test_sz_phase_single_qubit(self): sim = SparseStab(3) # Qubit 0: data (|+>), Qubit 1: ancilla 1 (|+Y>), Qubit 2: ancilla 2 (|+Y>) - # Prep |+> + # PZ |+> sim.run_gate("H", {0}) - # Prep |+Y> = S|+> + # PZ |+Y> = S|+> sim.run_gate("H", {1}) sim.run_gate("SZ", {1}) sim.run_gate("H", {2}) diff --git a/python/quantum-pecos/tests/qec/test_dem_sampler_vs_stim.py b/python/quantum-pecos/tests/qec/test_dem_sampler_vs_stim.py index e21072a7f..dd4f7d943 100644 --- a/python/quantum-pecos/tests/qec/test_dem_sampler_vs_stim.py +++ b/python/quantum-pecos/tests/qec/test_dem_sampler_vs_stim.py @@ -793,7 +793,7 @@ def test_random_circuit_fault_locations_match(self, seed: int) -> None: name = instruction.name targets = instruction.targets_copy() if name == "R": - stim_op_count += len(targets) # Prep error after + stim_op_count += len(targets) # PZ error after elif name in ("H", "S", "S_DAG"): stim_op_count += len(targets) # Gate error after elif name == "CX": diff --git a/python/quantum-pecos/tests/qec/test_fault_catalog.py b/python/quantum-pecos/tests/qec/test_fault_catalog.py index bd2950276..311faef88 100644 --- a/python/quantum-pecos/tests/qec/test_fault_catalog.py +++ b/python/quantum-pecos/tests/qec/test_fault_catalog.py @@ -221,7 +221,7 @@ def test_prep_fault_with_no_effect_included(self): fault = prep_locs[0].faults[0] assert fault.kind == "prep_flip" assert fault.pauli is None - # Prep X through H becomes Z which doesn't flip MZ → empty + # PZ X through H becomes Z which doesn't flip MZ → empty assert fault.measurements == [] diff --git a/python/quantum-pecos/tests/qec/test_meas_sampling_generality.py b/python/quantum-pecos/tests/qec/test_meas_sampling_generality.py index f8315c5f5..1532453b9 100644 --- a/python/quantum-pecos/tests/qec/test_meas_sampling_generality.py +++ b/python/quantum-pecos/tests/qec/test_meas_sampling_generality.py @@ -92,12 +92,12 @@ def det_rate(results): class TestPrepFaultAbsorption: - """Prep faults propagate forward but get absorbed at PZ/MZ.""" + """PZ faults propagate forward but get absorbed at PZ/MZ.""" def test_prep_fault_reaches_next_measurement(self): """A prep fault on PZ(ancilla) should flip the next ancilla MZ.""" tc = build_two_round_x_check() - # Prep-only noise + # PZ-only noise depol = depolarizing().p1(0).p2(0).p_meas(0).p_prep(0.01) shots = 50000 @@ -110,11 +110,11 @@ def det_rate(results): meas_rate = det_rate(meas_r) stab_rate = det_rate(stab_r) - # Prep faults fire the detector (X error → detected at MZ) + # PZ faults fire the detector (X error → detected at MZ) assert stab_rate > 0.005, f"Stabilizer should see prep faults: {stab_rate}" assert ( abs(meas_rate - stab_rate) / stab_rate < 0.15 - ), f"Prep fault rate mismatch: dem={meas_rate:.4f} stab={stab_rate:.4f}" + ), f"PZ fault rate mismatch: dem={meas_rate:.4f} stab={stab_rate:.4f}" def test_prep_fault_does_not_cross_reset(self): """A prep fault should NOT propagate past a subsequent PZ on the same qubit.""" diff --git a/python/quantum-pecos/tests/selene/test_hugr_to_ast.py b/python/quantum-pecos/tests/selene/test_hugr_to_ast.py index cbc5d051c..a654c60f7 100644 --- a/python/quantum-pecos/tests/selene/test_hugr_to_ast.py +++ b/python/quantum-pecos/tests/selene/test_hugr_to_ast.py @@ -74,7 +74,7 @@ def single_h() -> bool: assert alloc_decls[0].name == "q" assert alloc_decls[0].capacity == 1 - # Check body has Prep, H, Measure + # Check body has PZ, H, Measure assert len(ast.body) == 3 # First should be PrepareOp @@ -261,7 +261,7 @@ def minimal() -> bool: ast = guppy_to_ast(minimal) - # Should have Prep and Measure only + # Should have PZ and Measure only assert len(ast.body) == 2 assert isinstance(ast.body[0], PrepareOp) assert isinstance(ast.body[1], MeasureOp) diff --git a/python/quantum-pecos/tests/slr/pecos/unit/slr/test_quantum_permutation.py b/python/quantum-pecos/tests/slr/pecos/unit/slr/test_quantum_permutation.py index ed541f2e9..755e607a5 100644 --- a/python/quantum-pecos/tests/slr/pecos/unit/slr/test_quantum_permutation.py +++ b/python/quantum-pecos/tests/slr/pecos/unit/slr/test_quantum_permutation.py @@ -3,7 +3,7 @@ import re import pytest -from pecos.slr import CReg, Main, Permute, QReg, SlrConverter +from pecos.slr import CReg, Main, Permute, QReg, SlrConverter, rad from pecos.slr.qeclib import qubit # QASM Tests @@ -361,9 +361,9 @@ def test_rotation_gates_with_permutation() -> None: a, b, # Apply initial gates to track qubit allocation - qubit.RX[0.1](a[0]), # Track as "original a[0]" - qubit.RY[0.2](a[1]), # Track as "original a[1]" - qubit.RZ[0.3](b[0]), # Track as "original b[0]" + qubit.RX(rad(0.1), a[0]), # Track as "original a[0]" + qubit.RY(rad(0.2), a[1]), # Track as "original a[1]" + qubit.RZ(rad(0.3), b[0]), # Track as "original b[0]" qubit.SZ(b[1]), # Track as "original b[1]" # Apply permutation Permute( @@ -371,8 +371,8 @@ def test_rotation_gates_with_permutation() -> None: [b[0], a[0]], ), # Apply gates after permutation - qubit.RX[0.4](a[0]), # Should apply to "original b[0]" - qubit.RY[0.5](b[0]), # Should apply to "original a[0]" + qubit.RX(rad(0.4), a[0]), # Should apply to "original b[0]" + qubit.RY(rad(0.5), b[0]), # Should apply to "original a[0]" qubit.T(a[1]), # Should apply to "original a[1]" qubit.Tdg(b[1]), # Should apply to "original b[1]" ) diff --git a/python/quantum-pecos/tests/slr_tests/ast_guppy/__init__.py b/python/quantum-pecos/tests/slr_tests/ast_guppy/__init__.py new file mode 100644 index 000000000..e5eab195b --- /dev/null +++ b/python/quantum-pecos/tests/slr_tests/ast_guppy/__init__.py @@ -0,0 +1,16 @@ +"""AST -> Guppy v1 acceptance tests. + +These tests exercise the SLR -> AST -> Guppy lowering path +(`SlrConverter.guppy()` and downstream codegens at +`pecos/slr/ast/codegen/guppy.py`). They are the v1 acceptance +contract: each test is the spec for one feature in the v1 supported +set. + +Tests start as xfail while the AST Guppy emitter is being rewritten. +As features land, the xfail mark comes off the corresponding test. + +Post-cutover, `SlrConverter.hugr()` is also AST-routed (wraps `main` +in a no-arg `entry()` and compiles that). Acceptance tests prefer +`SlrConverter.guppy()` / `_harness.assert_ast_guppy_compiles` so +failures point at the parameterized function, not the entry wrapper. +""" diff --git a/python/quantum-pecos/tests/slr_tests/ast_guppy/_harness.py b/python/quantum-pecos/tests/slr_tests/ast_guppy/_harness.py new file mode 100644 index 000000000..83b2002c4 --- /dev/null +++ b/python/quantum-pecos/tests/slr_tests/ast_guppy/_harness.py @@ -0,0 +1,96 @@ +"""Compile harness for AST -> Guppy v1 acceptance tests. + +Provides a single primitive: `assert_ast_guppy_compiles(prog)`. Takes an +SLR `Main`/`Block`, runs it through `SlrConverter.guppy()` (which is the +AST path: `slr_to_ast` -> `AstToGuppy`), writes the source to a temp +file, imports it as a fresh module, and calls `main.compile_function()` +on the resulting Guppy function. + +Post-cutover, `SlrConverter.hugr()` also routes through the AST path +(wrapping `main(...)` in a no-arg `entry()` and calling +`entry.compile()`). This harness compiles `main.compile_function()` +directly so failures point at the parameterized function, not at the +entry wrapper. +""" + +from __future__ import annotations + +import importlib.util +import sys +import tempfile +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING + +from pecos.slr import SlrConverter + +if TYPE_CHECKING: + from pecos.slr import Block + + +@dataclass(frozen=True) +class CompileFailureError(AssertionError): + """Raised when generated Guppy source fails to compile. + + Carries the generated source for diagnostics. The exception type + inherits from AssertionError so pytest renders it as a normal + test failure instead of an internal error. + """ + + source: str + cause: BaseException + + def __str__(self) -> str: + cause_msg = f"{type(self.cause).__name__}: {self.cause}" + # Truncate the source in repr; full source is on .source for inspection. + max_lines = 80 + lines = self.source.splitlines() + shown = "\n".join(lines[:max_lines]) + suffix = f"\n... ({len(lines) - max_lines} more lines truncated)" if len(lines) > max_lines else "" + return f"{cause_msg}\n--- generated Guppy source ---\n{shown}{suffix}" + + +def ast_guppy_source(slr_program: Block) -> str: + """Return the Guppy source the AST path would emit, without compiling.""" + return SlrConverter(slr_program).guppy() + + +def assert_ast_guppy_compiles(slr_program: Block) -> None: + """Run SLR -> AST -> Guppy source -> compile_function. Raise on failure. + + The "main" function in the generated source is compiled via Guppy's + `compile_function()` (works for parameterized functions, unlike + `compile()` which expects a no-arg entrypoint). + """ + source = ast_guppy_source(slr_program) + + # Import as a fresh module from a temp file so Guppy can attribute + # source spans correctly in any error messages. + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + path = Path(f.name) + f.write(source) + + spec = importlib.util.spec_from_file_location(f"_ast_guppy_test_{path.stem}", path) + if spec is None or spec.loader is None: + msg = f"Failed to create import spec for generated source at {path}" + raise RuntimeError(msg) + + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + try: + spec.loader.exec_module(module) + except BaseException as exc: + raise CompileFailureError(source=source, cause=exc) from exc + + main = getattr(module, "main", None) + if main is None: + msg = "Generated Guppy source has no `main` function" + raise CompileFailureError( + source=source, + cause=AttributeError(msg), + ) + + try: + main.compile_function() + except BaseException as exc: + raise CompileFailureError(source=source, cause=exc) from exc diff --git a/python/quantum-pecos/tests/slr_tests/ast_guppy/_selene_harness.py b/python/quantum-pecos/tests/slr_tests/ast_guppy/_selene_harness.py new file mode 100644 index 000000000..e6a8868a1 --- /dev/null +++ b/python/quantum-pecos/tests/slr_tests/ast_guppy/_selene_harness.py @@ -0,0 +1,203 @@ +# Copyright 2026 The PECOS Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +"""Selene behavioral test harness for the AST -> Guppy v1 emitter. + +Compile-only tests via `_harness.assert_ast_guppy_compiles` prove +linearity + HUGR construction. They do not prove that observable +outcomes match SLR intent (wrong CReg ordering, wrong permutation +mapping, swapped reset/discard semantics all type-check). + +This harness runs an SLR program through the AST path and executes +the result via Selene +(`pecos.sim(pecos.Guppy(entry)).classical(pecos.selene_engine())`), +returning per-shot measurement bits as a list of dicts. + +Behavioral assertions on the result table are the v1 oracle. +""" + +from __future__ import annotations + +import importlib.util +import sys +import tempfile +from pathlib import Path +from typing import TYPE_CHECKING + +from pecos import Guppy, selene_engine, sim +from pecos.slr import SlrConverter +from pecos.slr.ast import RegisterDecl, slr_to_ast +from pecos.slr.ast.codegen.entry_wrapper import RETURN_TAG_NAMESPACE, build_no_arg_entry_wrapper + +if TYPE_CHECKING: + import pecos_rslib + from pecos.slr import Block + + +_DEFAULT_SHOTS = 100 +_DEFAULT_SEED = 42 + + +def run_ast_guppy_via_selene( + slr_program: Block, + *, + shots: int = _DEFAULT_SHOTS, + seed: int = _DEFAULT_SEED, +) -> list[dict[str, int]]: + """Run an SLR program through the AST -> Guppy -> Selene path. + + Returns a list of per-shot measurement records. Each record is a + `dict[str, int]` keyed by Guppy result names ("measurement_0", + "measurement_1", ...) with bit values 0 or 1. + + The AST-emitted `main(q: array[qubit, N] @ owned) -> ...` is + wrapped in a no-arg `entry()` that allocates the qubits, calls + main, and returns the result CRegs unpacked as a flat tuple of + bools. Selene's Guppy adapter requires a no-arg entrypoint. + """ + ast_source = SlrConverter(slr_program).guppy() + program = slr_to_ast(slr_program) + + # Opt-in named return tags. The wrapper emits + # `result("__pecos_return.", )` per returned CReg, so Selene + # keys outputs by CReg NAME -- immune to internal (non-returned) + # measurements like the Steane RUS verify. `_shot_records` reads those + # tags and re-exports the existing public `measurement_N` shape, so all + # `run_ast_guppy_via_selene` consumers stay unchanged. The production + # wrapper (default `emit_return_result_tags=False`) is untouched. + wrapper, info = build_no_arg_entry_wrapper(program, emit_return_result_tags=True) + # The returned CRegs (explicit `Return(...)`, in listed order) are the + # source of truth for the public `measurement_N` order. The implicit + # result-CReg path was removed, so a program with no `Return` has no + # measurement record at all. + if info.explicit_return is None: + msg = ( + "Selene behavioral test requires an explicit `Return(...)`. " + "The implicit result-CReg return was removed; a program with " + "no `Return` compiles to `entry() -> None` and has no measurement " + "record." + ) + raise ValueError(msg) + record_cregs: list[RegisterDecl] = [] + for value in info.explicit_return.values: + name = value if isinstance(value, str) else getattr(value, "name", None) + if name in info.all_creg_sizes: + record_cregs.append( + RegisterDecl(name=name, size=info.all_creg_sizes[name]), + ) + if not record_cregs: + # Strict: an explicit `Return(...)` + # with no CRegs (e.g. `Return(q)`) has no measurement record -- + # fail loudly rather than silently mis-shape the result table. + msg = ( + "Selene behavioral test requires at least one returned CReg " + "(explicit `Return(...)`). Returning only QRegs/values " + "yields no measurement record. Declare CRegs and write " + "measurement bits into them." + ) + raise ValueError(msg) + + full_source = ast_source + wrapper + + entry_func = _import_entry_function(full_source) + total_qubits = sum(info.allocator_sizes.values()) + + result = sim(Guppy(entry_func)).classical(selene_engine()).qubits(max(total_qubits, 1)).seed(seed).run(shots) + + return _shot_records(result, record_cregs) + + +def _import_entry_function(source: str) -> object: + """Write source to a temp file, import, and return the `entry` callable.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + path = Path(f.name) + f.write(source) + + spec = importlib.util.spec_from_file_location(f"_selene_test_{path.stem}", path) + if spec is None or spec.loader is None: + msg = f"Failed to create import spec for generated source at {path}" + raise RuntimeError(msg) + + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + + entry = getattr(module, "entry", None) + if entry is None: + msg = "Wrapped Guppy source has no `entry` function" + raise RuntimeError(msg) + return entry + + +def _shot_records(result: pecos_rslib.ShotVec, record_cregs: list[RegisterDecl]) -> list[dict[str, int]]: + """Re-export the named `__pecos_return.` tags as the public shape. + + The wrapper emits `result("__pecos_return.", )` per returned + CReg, so Selene's `to_dict()` keys outputs by CReg name and each + shot value is a list of that CReg's bits in declaration order. We flatten + the returned CRegs (in `Return(...)` order) into the historical public + shape `{"measurement_0": .., "measurement_1": .., ...}` so existing + `run_ast_guppy_via_selene` consumers are unchanged -- but now reading the + correct bits (immune to internal, non-returned measurements). + """ + raw = result.to_dict() if hasattr(result, "to_dict") else result + if not isinstance(raw, dict): + msg = f"Unexpected Selene result shape: {type(raw).__name__}" + raise TypeError(msg) + + tag_keys = [f"{RETURN_TAG_NAMESPACE}.{decl.name}" for decl in record_cregs] + missing = [k for k in tag_keys if k not in raw] + if missing: + msg = ( + f"Selene result is missing return tags {missing}; got keys " + f"{sorted(raw)}. Expected the opt-in wrapper to emit " + f'`result("{RETURN_TAG_NAMESPACE}.", )` per ' + "returned CReg." + ) + raise KeyError(msg) + + shot_count = len(raw[tag_keys[0]]) if tag_keys else 0 + records: list[dict[str, int]] = [] + for shot_idx in range(shot_count): + record: dict[str, int] = {} + counter = 0 + for decl, key in zip(record_cregs, tag_keys, strict=True): + shot_val = raw[key][shot_idx] + # Selene shapes a size-1 CReg result tag as a scalar int per + # shot, and a size>1 CReg as a list of `size` ints per shot. + # Be explicit and fail LOUD on any other shape -- a silent + # mis-count is the exact bug class this guard prevents (do not let an + # unexpected Selene type, e.g. a numpy array/generator, be + # silently wrapped as one bit). + if isinstance(shot_val, (list, tuple)): + bits = list(shot_val) + elif isinstance(shot_val, int): # bool is an int subclass + bits = [shot_val] + else: + msg = ( + f"Return tag {key!r} shot {shot_idx}: unexpected Selene " + f"value shape {type(shot_val).__name__} ({shot_val!r}); " + "expected a scalar int (size-1 CReg) or a list of ints " + "(size>1 CReg). Selene's result-tag output shape may " + "have changed -- update _shot_records deliberately." + ) + raise TypeError(msg) + if len(bits) != decl.size: + msg = ( + f"Return tag {key!r} shot {shot_idx} has {len(bits)} bits, " + f"expected {decl.size} (CReg {decl.name!r})." + ) + raise ValueError(msg) + for bit in bits: + record[f"measurement_{counter}"] = int(bit) + counter += 1 + records.append(record) + return records diff --git a/python/quantum-pecos/tests/slr_tests/ast_guppy/audit_runner.py b/python/quantum-pecos/tests/slr_tests/ast_guppy/audit_runner.py new file mode 100644 index 000000000..755fda854 --- /dev/null +++ b/python/quantum-pecos/tests/slr_tests/ast_guppy/audit_runner.py @@ -0,0 +1,872 @@ +# Copyright 2026 The PECOS Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +"""Audit runner for cutover gap discovery. + +Iterates a curated list of `(source_label, slr_program_factory)` +pairs from PECOS examples, qeclib, and existing test fixtures. +Runs each through `SlrConverter.hugr()` (now AST-routed by default +post-cutover) and captures any failures. + +This is NOT a pytest test file. It's an audit tool run +during the cutover. Output is the seed for new +rows in the audit manifest. + +Invocation: + cd /home/ciaranra/Repos/PECOS + uv run python python/quantum-pecos/tests/slr_tests/ast_guppy/audit_runner.py + +For each program: emits one of +- OK