diff --git a/lib/Optimizer/Transforms/ExpandMeasurements.cpp b/lib/Optimizer/Transforms/ExpandMeasurements.cpp index 45117d5ee47..5cc9fd48ab3 100644 --- a/lib/Optimizer/Transforms/ExpandMeasurements.cpp +++ b/lib/Optimizer/Transforms/ExpandMeasurements.cpp @@ -32,6 +32,17 @@ bool usesIndividualQubit(A x) { // Generalized pattern for expanding a multiple qubit measurement (whether it is // mx, my, or mz) to a series of individual measurements. +// +// Handles both result-type families that the vector form of `quake.mz`/`mx`/ +// `my` can carry: +// - `!cc.stdvec` -- the legacy form. The only legitimate +// consumer is `quake.discriminate`, so the rewrite folds the per-element +// measurements straight into a `cc.stdvec_init -> !cc.stdvec`. +// - `!cc.stdvec` -- the handle-vector value can have +// non-discriminate consumers. Those consumers expect a value of the +// original handle-stdvec type, so the rewrite additionally builds a +// per-element handle buffer and folds it into a `cc.stdvec_init -> +// !cc.stdvec` that replaces all remaining uses. template class ExpandRewritePattern : public OpRewritePattern { public: @@ -40,6 +51,42 @@ class ExpandRewritePattern : public OpRewritePattern { LogicalResult matchAndRewrite(A measureOp, PatternRewriter &rewriter) const override { auto loc = measureOp.getLoc(); + auto *ctx = rewriter.getContext(); + + // The dynamic-legality predicate filters out the scalar forms, so by + // construction the result type here is `!cc.stdvec` for some X. + auto stdvecResTy = + dyn_cast(measureOp.getMeasOut().getType()); + auto handleTy = cudaq::cc::MeasureHandleType::get(ctx); + bool isHandleResult = + isa(stdvecResTy.getElementType()); + + // Per-element scalar result type tracks the original stdvec element + // type. For handle inputs we measure into `!cc.measure_handle` per + // qubit. + Type perElemTy = isHandleResult + ? 
static_cast(handleTy) + : static_cast(quake::MeasureType::get(ctx)); + + // Classify users so we only allocate the buffers we actually need, and + // collect the discriminate users at the same time. The legacy + // `!quake.measure` path has only `quake.discriminate` consumers by + // construction; the handle path may have either, both, or none. + SmallVector discUsers; + bool hasNonDiscUser = false; + for (auto *u : measureOp.getMeasOut().getUsers()) { + if (auto d = dyn_cast(u)) + discUsers.push_back(d); + else + hasNonDiscUser = true; + } + // Allocation policy: + // - Legacy `!cc.stdvec` always allocates the i1 buffer. + // - `!cc.stdvec` allocates each buffer only when a + // consumer in that element-type class is present. + bool needI1Buf = !isHandleResult || !discUsers.empty(); + bool needHandleBuf = isHandleResult && hasNonDiscUser; + // 1. Determine the total number of qubits we need to measure. This // determines the size of the buffer of bools to create to store the results // in. @@ -56,25 +103,47 @@ class ExpandRewritePattern : public OpRewritePattern { totalToRead = arith::AddIOp::create(rewriter, loc, totalToRead, vecSz); } - // 2. Create the buffer. + // 2. Create the buffers (one per output kind we actually need). auto i1Ty = rewriter.getI1Type(); auto i8Ty = rewriter.getI8Type(); - Value buff = cudaq::cc::AllocaOp::create(rewriter, loc, i8Ty, totalToRead); + Value i1Buff; + if (needI1Buf) + i1Buff = cudaq::cc::AllocaOp::create(rewriter, loc, i8Ty, totalToRead); + Value handleBuff; + if (needHandleBuf) + handleBuff = + cudaq::cc::AllocaOp::create(rewriter, loc, handleTy, totalToRead); + + // Per-element store helper. Each qubit is measured exactly once with + // `perElemTy`; the resulting value is fanned out to whichever buffers we + // allocated (i1 for discriminate consumers, handle for non-discriminate + // consumers). 
+ auto storePerElement = [&](OpBuilder &builder, Location loc, Value meas, + Value offset) { + if (needI1Buf) { + auto bit = quake::DiscriminateOp::create(builder, loc, i1Ty, meas); + auto addr = cudaq::cc::ComputePtrOp::create( + builder, loc, cudaq::cc::PointerType::get(i8Ty), i1Buff, offset); + auto bitByte = cudaq::cc::CastOp::create( + builder, loc, i8Ty, bit, cudaq::cc::CastOpMode::Unsigned); + cudaq::cc::StoreOp::create(builder, loc, bitByte, addr); + } + if (needHandleBuf) { + auto addr = cudaq::cc::ComputePtrOp::create( + builder, loc, cudaq::cc::PointerType::get(handleTy), handleBuff, + offset); + cudaq::cc::StoreOp::create(builder, loc, meas, addr); + } + }; // 3. Measure each individual qubit and insert the result, in order, into - // the buffer. For registers/vectors, loop over the entire set of qubits. + // the buffer. For registers, loop over the entire set of qubits. Value buffOff = arith::ConstantIntOp::create(rewriter, loc, 0, 64); Value one = arith::ConstantIntOp::create(rewriter, loc, 1, 64); - auto measTy = quake::MeasureType::get(rewriter.getContext()); for (auto v : measureOp.getTargets()) { if (isa(v.getType())) { - auto meas = A::create(rewriter, loc, measTy, v).getMeasOut(); - auto bit = quake::DiscriminateOp::create(rewriter, loc, i1Ty, meas); - Value addr = cudaq::cc::ComputePtrOp::create( - rewriter, loc, cudaq::cc::PointerType::get(i8Ty), buff, buffOff); - auto bitByte = cudaq::cc::CastOp::create( - rewriter, loc, i8Ty, bit, cudaq::cc::CastOpMode::Unsigned); - cudaq::cc::StoreOp::create(rewriter, loc, bitByte, addr); + auto meas = A::create(rewriter, loc, perElemTy, v).getMeasOut(); + storePerElement(rewriter, loc, meas, buffOff); buffOff = arith::AddIOp::create(rewriter, loc, buffOff, one); } else { assert(isa(v.getType())); @@ -84,37 +153,56 @@ class ExpandRewritePattern : public OpRewritePattern { [&](OpBuilder &builder, Location loc, Region &, Block &block) { Value iv = block.getArgument(0); Value qv = 
quake::ExtractRefOp::create(builder, loc, v, iv); - auto meas = A::create(builder, loc, measTy, qv); - auto bit = quake::DiscriminateOp::create(builder, loc, i1Ty, - meas.getMeasOut()); + auto meas = A::create(builder, loc, perElemTy, qv); if (auto registerName = measureOp.getRegisterNameAttr()) meas.setRegisterName(registerName); Value offset = arith::AddIOp::create(builder, loc, iv, buffOff); - auto addr = cudaq::cc::ComputePtrOp::create( - builder, loc, cudaq::cc::PointerType::get(i8Ty), buff, - offset); - auto bitByte = cudaq::cc::CastOp::create( - builder, loc, i8Ty, bit, cudaq::cc::CastOpMode::Unsigned); - cudaq::cc::StoreOp::create(builder, loc, bitByte, addr); + storePerElement(builder, loc, meas.getMeasOut(), offset); }); buffOff = arith::AddIOp::create(rewriter, loc, buffOff, vecSz); } } - // 4. Use the buffer as an initialization expression and create the - // std::vec value. - auto stdvecTy = cudaq::cc::StdvecType::get(rewriter.getContext(), i1Ty); - for (auto *out : measureOp.getMeasOut().getUsers()) - if (auto disc = dyn_cast_if_present(out)) { - auto ptrArrI1Ty = - cudaq::cc::PointerType::get(cudaq::cc::ArrayType::get(i1Ty)); + // 4. Replace each `quake.discriminate` consumer with a + // `cc.stdvec_init -> !cc.stdvec` over the i1 buffer. + if (needI1Buf) { + auto stdvecI1Ty = cudaq::cc::StdvecType::get(ctx, i1Ty); + auto ptrArrI1Ty = + cudaq::cc::PointerType::get(cudaq::cc::ArrayType::get(i1Ty)); + for (auto disc : discUsers) { auto buffCast = - cudaq::cc::CastOp::create(rewriter, loc, ptrArrI1Ty, buff); + cudaq::cc::CastOp::create(rewriter, loc, ptrArrI1Ty, i1Buff); rewriter.template replaceOpWithNewOp( - disc, stdvecTy, buffCast, totalToRead); + disc, stdvecI1Ty, buffCast, totalToRead); } + } + + // 5. 
For the handle path with non-discriminate consumers, build a
+    // `cc.stdvec_init -> !cc.stdvec<!cc.measure_handle>` over the handle
+    // buffer and route the original result's remaining users to it via
+    // `replaceOp` (one atomic substitution).
+    Value replacementVal;
+    if (needHandleBuf) {
+      auto stdvecHandleTy = cudaq::cc::StdvecType::get(ctx, handleTy);
+      auto handleStdvec = cudaq::cc::StdvecInitOp::create(
+          rewriter, loc, stdvecHandleTy, handleBuff, totalToRead);
+      replacementVal = handleStdvec.getResult();
+    }
+
+    // The pass is scheduled before wire lowering, so the variadic `$wires`
+    // result group is structurally empty here.
+    assert(measureOp.getWires().empty() &&
+           "`expand-measurements` runs before wire lowering");

-    rewriter.eraseOp(measureOp);
+    // Step 5 builds a handle-vector replacement exactly when the
+    // user-classification scan found a non-discriminate consumer; check the
+    // two agree so a null value can never be fed through to a live user.
+    assert((replacementVal != nullptr) == hasNonDiscUser &&
+           "handle-vector replacement must exist iff a non-discriminate "
+           "consumer was present");
+
+    // When there is no handle-vector replacement (legacy path, or a handle
+    // result whose only consumers were the discriminates folded in step 4),
+    // there is nothing left to substitute: calling `replaceOp` with a null
+    // value is invalid, so simply erase the now-dead measurement instead.
+    if (replacementVal)
+      rewriter.replaceOp(measureOp, replacementVal);
+    else
+      rewriter.eraseOp(measureOp);
     return success();
   }
 };
@@ -124,6 +212,78 @@ using MxRewrite = ExpandRewritePattern<quake::MxOp>;
 using MyRewrite = ExpandRewritePattern<quake::MyOp>;
 using MzRewrite = ExpandRewritePattern<quake::MzOp>;

+// Expand `quake.discriminate : !cc.stdvec<!cc.measure_handle> ->
+// !cc.stdvec<i1>` when the input handle vector is *not* the direct result
+// of a measurement op. The bridge emits this shape for `cudaq::to_bools`
+// applied to a handle vector that has crossed an SSA boundary
+// (e.g. function argument, kernel return), where the measurement-op
+// pattern above cannot reach the underlying `quake.mz/mx/my`. It loops
+// over the handle vector, discriminates each element, and rewraps the
+// resulting bytes as a `!cc.stdvec<i1>`. The direct-from-measurement
+// case stays handled by `ExpandRewritePattern` to avoid an extra
+// per-element load.
+class ExpandStdvecHandleDiscriminate + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(quake::DiscriminateOp disc, + PatternRewriter &rewriter) const override { + Value handleVec = disc.getMeasurement(); + auto stdvecTy = dyn_cast(handleVec.getType()); + if (!stdvecTy || + !isa(stdvecTy.getElementType())) + return failure(); + if (handleVec.getDefiningOp()) + return failure(); + + auto loc = disc.getLoc(); + auto *ctx = rewriter.getContext(); + auto i1Ty = rewriter.getI1Type(); + auto i8Ty = rewriter.getI8Type(); + auto i64Ty = rewriter.getI64Type(); + auto handleTy = cudaq::cc::MeasureHandleType::get(ctx); + + Value vecSize = + cudaq::cc::StdvecSizeOp::create(rewriter, loc, i64Ty, handleVec); + auto handleArrPtrTy = + cudaq::cc::PointerType::get(cudaq::cc::ArrayType::get(handleTy)); + Value handleData = cudaq::cc::StdvecDataOp::create( + rewriter, loc, handleArrPtrTy, handleVec); + // Output is held in an i8 buffer, then bitcast to `!cc.ptr>` for the wrap. This matches the convention used by the + // measurement-op pattern above (steps 2 + 4) so downstream passes see + // the same shape regardless of which path produced the i1 vector. 
+ Value i1Buff = cudaq::cc::AllocaOp::create(rewriter, loc, i8Ty, vecSize); + + cudaq::opt::factory::createInvariantLoop( + rewriter, loc, vecSize, + [&](OpBuilder &builder, Location loc, Region &, Block &block) { + Value iv = block.getArgument(0); + Value handleAddr = cudaq::cc::ComputePtrOp::create( + builder, loc, cudaq::cc::PointerType::get(handleTy), handleData, + iv); + Value handleVal = cudaq::cc::LoadOp::create(builder, loc, handleAddr); + Value bit = + quake::DiscriminateOp::create(builder, loc, i1Ty, handleVal); + Value byteAddr = cudaq::cc::ComputePtrOp::create( + builder, loc, cudaq::cc::PointerType::get(i8Ty), i1Buff, iv); + Value bitByte = cudaq::cc::CastOp::create( + builder, loc, i8Ty, bit, cudaq::cc::CastOpMode::Unsigned); + cudaq::cc::StoreOp::create(builder, loc, bitByte, byteAddr); + }); + + auto stdvecI1Ty = cudaq::cc::StdvecType::get(ctx, i1Ty); + auto ptrArrI1Ty = + cudaq::cc::PointerType::get(cudaq::cc::ArrayType::get(i1Ty)); + Value buffCast = + cudaq::cc::CastOp::create(rewriter, loc, ptrArrI1Ty, i1Buff); + rewriter.replaceOpWithNewOp(disc, stdvecI1Ty, + buffCast, vecSize); + return success(); + } +}; + /// Convert a `quake.reset` with a `veq` argument into a loop over the elements /// of the `veq` and `quake.reset` on each of them. class ResetRewrite : public OpRewritePattern { @@ -156,7 +316,8 @@ class ExpandMeasurementsPass auto *op = getOperation(); auto *ctx = &getContext(); RewritePatternSet patterns(ctx); - patterns.insert(ctx); + patterns.insert(ctx); ConversionTarget target(*ctx); target.addLegalDialect(); @@ -169,6 +330,25 @@ class ExpandMeasurementsPass target.addDynamicallyLegalOp([](quake::ResetOp r) { return !isa(r.getTargets().getType()); }); + target.addDynamicallyLegalOp( + [](quake::DiscriminateOp d) { + // Scalar discriminate is always legal. 
+ auto stdvecTy = + dyn_cast(d.getMeasurement().getType()); + if (!stdvecTy) + return true; + // Vector discriminate of legacy `!quake.measure` is folded as + // a side-effect of the measurement-op rewrite (step 4); leave + // it legal here so the driver does not look for a standalone + // pattern. + if (!isa(stdvecTy.getElementType())) + return true; + // Vector discriminate of `!cc.measure_handle` whose source is + // a measurement op is similarly folded (step 4 again). Only + // the indirect case needs `ExpandStdvecHandleDiscriminate`. + return d.getMeasurement() + .getDefiningOp() != nullptr; + }); if (failed(applyPartialConversion(op, target, std::move(patterns)))) { op->emitOpError("could not expand measurements"); signalPassFailure(); diff --git a/test/Transforms/expand_measurements_handle.qke b/test/Transforms/expand_measurements_handle.qke new file mode 100644 index 00000000000..dcb81e96134 --- /dev/null +++ b/test/Transforms/expand_measurements_handle.qke @@ -0,0 +1,227 @@ +// ========================================================================== // +// Copyright (c) 2026 NVIDIA Corporation & Affiliates. // +// All rights reserved. // +// // +// This source code and the accompanying materials are made available under // +// the terms of the Apache License 2.0 which accompanies this distribution. 
// +// ========================================================================== // + +// RUN: cudaq-opt --expand-measurements %s | FileCheck %s + +func.func @handle_return(%v: !quake.veq) -> !cc.stdvec { + %mz = quake.mz %v : (!quake.veq) -> !cc.stdvec + return %mz : !cc.stdvec +} + +// CHECK-LABEL: func.func @handle_return( +// CHECK-SAME: %[[V:.*]]: !quake.veq) -> !cc.stdvec +// CHECK: %[[N:.*]] = quake.veq_size %[[V]] : (!quake.veq) -> i64 +// CHECK: %[[BUF:.*]] = cc.alloca !cc.measure_handle{{\[}}%{{.*}} : i64] +// CHECK: cc.loop while (({{.*}}) -> (i64)) { +// CHECK: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : i64 +// CHECK: cc.condition %{{.*}}(%{{.*}} : i64) +// CHECK: } do { +// CHECK: ^bb0(%[[IV:.*]]: i64): +// CHECK: %[[Q:.*]] = quake.extract_ref %[[V]]{{\[}}%[[IV]]] : (!quake.veq, i64) -> !quake.ref +// CHECK: %[[H:.*]] = quake.mz %[[Q]] : (!quake.ref) -> !cc.measure_handle +// CHECK: %[[ADDR:.*]] = cc.compute_ptr %[[BUF]] +// CHECK-SAME: : (!cc.ptr>, i64) -> !cc.ptr +// CHECK: cc.store %[[H]], %[[ADDR]] : !cc.ptr +// CHECK: cc.continue %[[IV]] : i64 +// CHECK: } step { +// CHECK: ^bb0(%[[SV:.*]]: i64): +// CHECK: %{{.*}} = arith.addi %[[SV]], %{{.*}} : i64 +// CHECK: cc.continue %{{.*}} : i64 +// CHECK: } {invariant} +// CHECK: %[[VEC:.*]] = cc.stdvec_init %[[BUF]] +// CHECK-SAME: : (!cc.ptr>, i64) -> !cc.stdvec +// CHECK: return %[[VEC]] : !cc.stdvec + + +func.func @handle_disc(%v: !quake.veq) -> !cc.stdvec { + %mz = quake.mz %v : (!quake.veq) -> !cc.stdvec + %d = quake.discriminate %mz : (!cc.stdvec) -> !cc.stdvec + return %d : !cc.stdvec +} + +// CHECK-LABEL: func.func @handle_disc( +// CHECK: %[[I8BUF:.*]] = cc.alloca i8{{\[}}%{{.*}} : i64] +// CHECK-NOT: cc.alloca !cc.measure_handle +// CHECK: cc.loop while (({{.*}}) -> (i64)) { +// CHECK: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : i64 +// CHECK: cc.condition %{{.*}}(%{{.*}} : i64) +// CHECK: } do { +// CHECK: ^bb0(%[[IV:.*]]: i64): +// CHECK: %[[H:.*]] = quake.mz %{{.*}} : (!quake.ref) -> 
!cc.measure_handle +// CHECK: %[[BIT:.*]] = quake.discriminate %[[H]] : (!cc.measure_handle) -> i1 +// CHECK: %[[ADDR:.*]] = cc.compute_ptr %[[I8BUF]]{{\[}}%{{.*}}] +// CHECK-SAME: : (!cc.ptr>, i64) -> !cc.ptr +// CHECK: %[[BYTE:.*]] = cc.cast unsigned %[[BIT]] : (i1) -> i8 +// CHECK: cc.store %[[BYTE]], %[[ADDR]] : !cc.ptr +// CHECK: cc.continue %[[IV]] : i64 +// CHECK: } step { +// CHECK: ^bb0(%{{.*}}: i64): +// CHECK: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : i64 +// CHECK: cc.continue %{{.*}} : i64 +// CHECK: } {invariant} +// CHECK: %[[CAST:.*]] = cc.cast %[[I8BUF]] : (!cc.ptr>) -> !cc.ptr> +// CHECK: %[[VEC:.*]] = cc.stdvec_init %[[CAST]] +// CHECK-SAME: : (!cc.ptr>, i64) -> !cc.stdvec +// CHECK: return %[[VEC]] : !cc.stdvec + + +func.func @handle_mixed(%v: !quake.veq) -> !cc.stdvec { + %mz = quake.mz %v : (!quake.veq) -> !cc.stdvec + %d = quake.discriminate %mz : (!cc.stdvec) -> !cc.stdvec + %slot = cc.alloca !cc.stdvec + cc.store %d, %slot : !cc.ptr> + return %mz : !cc.stdvec +} + +// CHECK-LABEL: func.func @handle_mixed( +// CHECK-DAG: %[[I8BUF:.*]] = cc.alloca i8{{\[}}%{{.*}} : i64] +// CHECK-DAG: %[[HBUF:.*]] = cc.alloca !cc.measure_handle{{\[}}%{{.*}} : i64] +// CHECK: cc.loop while (({{.*}}) -> (i64)) { +// CHECK: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : i64 +// CHECK: cc.condition %{{.*}}(%{{.*}} : i64) +// CHECK: } do { +// CHECK: ^bb0(%[[IV:.*]]: i64): +// CHECK: %[[H:.*]] = quake.mz %{{.*}} : (!quake.ref) -> !cc.measure_handle +// CHECK: %[[BIT:.*]] = quake.discriminate %[[H]] : (!cc.measure_handle) -> i1 +// CHECK: %[[I8ADDR:.*]] = cc.compute_ptr %[[I8BUF]] +// CHECK: cc.store %{{.*}}, %[[I8ADDR]] : !cc.ptr +// CHECK: %[[HADDR:.*]] = cc.compute_ptr %[[HBUF]] +// CHECK: cc.store %[[H]], %[[HADDR]] : !cc.ptr +// CHECK: cc.continue %[[IV]] : i64 +// CHECK: } step { +// CHECK: ^bb0(%{{.*}}: i64): +// CHECK: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : i64 +// CHECK: cc.continue %{{.*}} : i64 +// CHECK: } {invariant} +// CHECK: cc.stdvec_init %{{.*}} : 
(!cc.ptr>, i64) -> !cc.stdvec +// CHECK: %[[HVEC:.*]] = cc.stdvec_init %[[HBUF]] +// CHECK-SAME: : (!cc.ptr>, i64) -> !cc.stdvec +// CHECK: return %[[HVEC]] : !cc.stdvec + + +func.func @handle_store(%v: !quake.veq) { + %mz = quake.mz %v : (!quake.veq) -> !cc.stdvec + %slot = cc.alloca !cc.stdvec + cc.store %mz, %slot : !cc.ptr> + return +} + +// CHECK-LABEL: func.func @handle_store( +// CHECK: %[[HBUF:.*]] = cc.alloca !cc.measure_handle{{\[}}%{{.*}} : i64] +// CHECK-NOT: cc.alloca i8 +// CHECK: cc.loop while (({{.*}}) -> (i64)) { +// CHECK: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : i64 +// CHECK: cc.condition %{{.*}}(%{{.*}} : i64) +// CHECK: } do { +// CHECK: ^bb0(%[[IV:.*]]: i64): +// CHECK: %[[H:.*]] = quake.mz %{{.*}} : (!quake.ref) -> !cc.measure_handle +// CHECK: %[[ADDR:.*]] = cc.compute_ptr %[[HBUF]] +// CHECK: cc.store %[[H]], %[[ADDR]] : !cc.ptr +// CHECK: cc.continue %[[IV]] : i64 +// CHECK: } step { +// CHECK: ^bb0(%{{.*}}: i64): +// CHECK: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : i64 +// CHECK: cc.continue %{{.*}} : i64 +// CHECK: } {invariant} +// CHECK: %[[HVEC:.*]] = cc.stdvec_init %[[HBUF]] +// CHECK-SAME: : (!cc.ptr>, i64) -> !cc.stdvec +// CHECK: cc.store %[[HVEC]], %{{.*}} : !cc.ptr> + + +func.func @handle_ref_and_veq(%q: !quake.ref, %v: !quake.veq) -> !cc.stdvec { + %mz = quake.mz %q, %v : (!quake.ref, !quake.veq) -> !cc.stdvec + return %mz : !cc.stdvec +} + +// CHECK-LABEL: func.func @handle_ref_and_veq( +// CHECK-SAME: %[[Q:.*]]: !quake.ref, %[[V:.*]]: !quake.veq +// CHECK: %[[HBUF:.*]] = cc.alloca !cc.measure_handle{{\[}}%{{.*}} : i64] +// CHECK: %[[H_REF:.*]] = quake.mz %[[Q]] : (!quake.ref) -> !cc.measure_handle +// CHECK: %[[REF_ADDR:.*]] = cc.compute_ptr %[[HBUF]] +// CHECK: cc.store %[[H_REF]], %[[REF_ADDR]] : !cc.ptr +// CHECK: cc.loop while (({{.*}}) -> (i64)) { +// CHECK: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : i64 +// CHECK: cc.condition %{{.*}}(%{{.*}} : i64) +// CHECK: } do { +// CHECK: ^bb0(%[[IV:.*]]: i64): +// CHECK: 
%[[QV:.*]] = quake.extract_ref %[[V]]{{\[}}%[[IV]]] : (!quake.veq, i64) -> !quake.ref +// CHECK: %[[H_VEQ:.*]] = quake.mz %[[QV]] : (!quake.ref) -> !cc.measure_handle +// CHECK: %[[VEQ_ADDR:.*]] = cc.compute_ptr %[[HBUF]] +// CHECK: cc.store %[[H_VEQ]], %[[VEQ_ADDR]] : !cc.ptr +// CHECK: cc.continue %[[IV]] : i64 +// CHECK: } step { +// CHECK: ^bb0(%{{.*}}: i64): +// CHECK: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : i64 +// CHECK: cc.continue %{{.*}} : i64 +// CHECK: } {invariant} +// CHECK: cc.stdvec_init %[[HBUF]] +// CHECK-SAME: : (!cc.ptr>, i64) -> !cc.stdvec + + +func.func @handle_return_mx(%v: !quake.veq) -> !cc.stdvec { + %mx = quake.mx %v : (!quake.veq) -> !cc.stdvec + return %mx : !cc.stdvec +} + +// CHECK-LABEL: func.func @handle_return_mx( +// CHECK: cc.loop while (({{.*}}) -> (i64)) { +// CHECK: } do { +// CHECK: quake.mx %{{.*}} : (!quake.ref) -> !cc.measure_handle +// CHECK: } step { +// CHECK: } {invariant} +// CHECK: cc.stdvec_init %{{.*}} : (!cc.ptr>, i64) -> !cc.stdvec + + +func.func @handle_return_my(%v: !quake.veq) -> !cc.stdvec { + %my = quake.my %v : (!quake.veq) -> !cc.stdvec + return %my : !cc.stdvec +} + +// CHECK-LABEL: func.func @handle_return_my( +// CHECK: cc.loop while (({{.*}}) -> (i64)) { +// CHECK: } do { +// CHECK: quake.my %{{.*}} : (!quake.ref) -> !cc.measure_handle +// CHECK: } step { +// CHECK: } {invariant} +// CHECK: cc.stdvec_init %{{.*}} : (!cc.ptr>, i64) -> !cc.stdvec + + +func.func @to_bools_indirect(%hv: !cc.stdvec) -> !cc.stdvec { + %d = quake.discriminate %hv : (!cc.stdvec) -> !cc.stdvec + return %d : !cc.stdvec +} + +// CHECK-LABEL: func.func @to_bools_indirect( +// CHECK-SAME: %[[HV:.*]]: !cc.stdvec) -> !cc.stdvec +// CHECK: %[[N:.*]] = cc.stdvec_size %[[HV]] : (!cc.stdvec) -> i64 +// CHECK: %[[HDATA:.*]] = cc.stdvec_data %[[HV]] +// CHECK-SAME: : (!cc.stdvec) -> !cc.ptr> +// CHECK: %[[I8BUF:.*]] = cc.alloca i8{{\[}}%[[N]] : i64] +// CHECK: cc.loop while (({{.*}}) -> (i64)) { +// CHECK: %{{.*}} = arith.cmpi slt, 
%{{.*}}, %[[N]] : i64 +// CHECK: cc.condition %{{.*}}(%{{.*}} : i64) +// CHECK: } do { +// CHECK: ^bb0(%[[IV:.*]]: i64): +// CHECK: %[[HADDR:.*]] = cc.compute_ptr %[[HDATA]]{{\[}}%[[IV]]] +// CHECK-SAME: : (!cc.ptr>, i64) -> !cc.ptr +// CHECK: %[[H:.*]] = cc.load %[[HADDR]] : !cc.ptr +// CHECK: %[[BIT:.*]] = quake.discriminate %[[H]] : (!cc.measure_handle) -> i1 +// CHECK: %[[I8ADDR:.*]] = cc.compute_ptr %[[I8BUF]]{{\[}}%[[IV]]] +// CHECK-SAME: : (!cc.ptr>, i64) -> !cc.ptr +// CHECK: %[[BYTE:.*]] = cc.cast unsigned %[[BIT]] : (i1) -> i8 +// CHECK: cc.store %[[BYTE]], %[[I8ADDR]] : !cc.ptr +// CHECK: cc.continue %[[IV]] : i64 +// CHECK: } step { +// CHECK: ^bb0(%{{.*}}: i64): +// CHECK: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : i64 +// CHECK: cc.continue %{{.*}} : i64 +// CHECK: } {invariant} +// CHECK: %[[CAST:.*]] = cc.cast %[[I8BUF]] : (!cc.ptr>) -> !cc.ptr> +// CHECK: %[[VEC:.*]] = cc.stdvec_init %[[CAST]], %[[N]] +// CHECK-SAME: : (!cc.ptr>, i64) -> !cc.stdvec +// CHECK: return %[[VEC]] : !cc.stdvec