-
-
Notifications
You must be signed in to change notification settings - Fork 55
✨ Add TensorIterator
#1730
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
MatthiasReumann
wants to merge
14
commits into
main
Choose a base branch
from
feat/tensor-iterator
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
✨ Add TensorIterator
#1730
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
10a4a2a
Add TensorIterator
MatthiasReumann fb54971
Add qtensor-utils unit test
MatthiasReumann f60a182
🎨 pre-commit fixes
pre-commit-ci[bot] 3df1be0
Update CHANGELOG.md
MatthiasReumann d9153d4
Fix linting
MatthiasReumann 365356a
Add scf.for to unit test
MatthiasReumann 029e32d
🎨 pre-commit fixes
pre-commit-ci[bot] edff31d
Add missing includes
MatthiasReumann 174685d
Merge branch 'main' into feat/tensor-iterator
MatthiasReumann 1ccc7aa
Merge branch 'main' into feat/tensor-iterator
MatthiasReumann 092f9b6
Merge branch 'main' into feat/tensor-iterator
burgholzer 6e2049c
:pencil2: adding this to the main MQT CC entry in the changelog
burgholzer f3358fe
:art: removing redundant namespace qualifiers
burgholzer 4e09cbe
Merge branch 'main' into feat/tensor-iterator
denialhaag File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,86 @@ | ||
| /* | ||
| * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM | ||
| * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH | ||
| * All rights reserved. | ||
| * | ||
| * SPDX-License-Identifier: MIT | ||
| * | ||
| * Licensed under the MIT License | ||
| */ | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <mlir/IR/Builders.h> | ||
| #include <mlir/IR/BuiltinTypes.h> | ||
| #include <mlir/IR/Operation.h> | ||
| #include <mlir/IR/Value.h> | ||
|
|
||
| #include <iterator> | ||
|
|
||
| namespace mlir::qtensor { | ||
|
|
||
| /** | ||
| * @brief A bidirectional_iterator traversing the tensor chain. | ||
| **/ | ||
| class [[nodiscard]] TensorIterator { | ||
| public: | ||
| using iterator_category = std::bidirectional_iterator_tag; | ||
| using difference_type = std::ptrdiff_t; | ||
| using value_type = Operation*; | ||
|
|
||
| TensorIterator() : op_(nullptr), tensor_(nullptr), isSentinel_(false) {} | ||
| explicit TensorIterator(TypedValue<RankedTensorType> tensor) | ||
| : op_(tensor.getDefiningOp()), tensor_(tensor), isSentinel_(false) {} | ||
|
|
||
| /// @returns the operation the iterator points to. | ||
| [[nodiscard]] Operation* operation() const { return op_; } | ||
|
|
||
| /// @returns the operation the iterator points to. | ||
| [[nodiscard]] Operation* operator*() const { return operation(); } | ||
|
|
||
| /// @returns the tensor the iterator points to. | ||
| [[nodiscard]] TypedValue<RankedTensorType> tensor() const; | ||
|
|
||
| TensorIterator& operator++() { | ||
| forward(); | ||
| return *this; | ||
| } | ||
|
|
||
| TensorIterator operator++(int) { | ||
| auto tmp = *this; | ||
| operator++(); | ||
| return tmp; | ||
| } | ||
|
|
||
| TensorIterator& operator--() { | ||
| backward(); | ||
| return *this; | ||
| } | ||
|
|
||
| TensorIterator operator--(int) { | ||
| auto tmp = *this; | ||
| operator--(); | ||
| return tmp; | ||
| } | ||
|
|
||
| bool operator==(const TensorIterator& other) const { | ||
| return other.tensor_ == tensor_ && other.op_ == op_ && | ||
| other.isSentinel_ == isSentinel_; | ||
| } | ||
|
|
||
| bool operator==([[maybe_unused]] std::default_sentinel_t s) const { | ||
| return isSentinel_; | ||
| } | ||
|
|
||
| private: | ||
| /// @brief Move to the next operation on the tensor def-use chain. | ||
| void forward(); | ||
|
|
||
| /// @brief Move to the previous operation on the tensor def-use chain. | ||
| void backward(); | ||
|
|
||
| Operation* op_; | ||
| TypedValue<RankedTensorType> tensor_; | ||
| bool isSentinel_; | ||
| }; | ||
| } // namespace mlir::qtensor |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,3 +8,4 @@ | |
|
|
||
| add_subdirectory(IR) | ||
| add_subdirectory(Transforms) | ||
| add_subdirectory(Utils) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| # Copyright (c) 2023 - 2026 Chair for Design Automation, TUM | ||
| # Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH | ||
| # All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: MIT | ||
| # | ||
| # Licensed under the MIT License | ||
|
|
||
| file(GLOB_RECURSE UTILS_CPP "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp") | ||
|
|
||
| add_mlir_dialect_library( | ||
| MLIRQTensorUtils | ||
| ${UTILS_CPP} | ||
| ADDITIONAL_HEADER_DIRS | ||
| ${PROJECT_SOURCE_DIR}/mlir/include/mlir/Dialect/QTensor | ||
| DEPENDS | ||
| MLIRQTensorOpsIncGen | ||
| LINK_LIBS | ||
| PUBLIC | ||
| MLIRQTensorDialect) | ||
|
|
||
| mqt_mlir_target_use_project_options(MLIRQTensorUtils) | ||
|
|
||
| # collect header files | ||
| file(GLOB_RECURSE UTILS_HEADERS_SOURCE | ||
| "${MQT_MLIR_SOURCE_INCLUDE_DIR}/mlir/Dialect/QTensor/Utils/*.h") | ||
| file(GLOB_RECURSE UTILS_HEADERS_BUILD | ||
| "${MQT_MLIR_BUILD_INCLUDE_DIR}/mlir/Dialect/QTensor/Utils/*.inc") | ||
|
|
||
| # add public headers using file sets | ||
| target_sources( | ||
| MLIRQTensorUtils | ||
| PUBLIC FILE_SET | ||
| HEADERS | ||
| BASE_DIRS | ||
| ${MQT_MLIR_SOURCE_INCLUDE_DIR} | ||
| FILES | ||
| ${UTILS_HEADERS_SOURCE} | ||
| FILE_SET | ||
| HEADERS | ||
| BASE_DIRS | ||
| ${MQT_MLIR_BUILD_INCLUDE_DIR} | ||
| FILES | ||
| ${UTILS_HEADERS_BUILD}) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,114 @@ | ||
| /* | ||
| * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM | ||
| * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH | ||
| * All rights reserved. | ||
| * | ||
| * SPDX-License-Identifier: MIT | ||
| * | ||
| * Licensed under the MIT License | ||
| */ | ||
|
|
||
| #include "mlir/Dialect/QTensor/Utils/TensorIterator.h" | ||
|
|
||
| #include "mlir/Dialect/QTensor/IR/QTensorOps.h" | ||
|
|
||
| #include <llvm/ADT/TypeSwitch.h> | ||
| #include <llvm/Support/ErrorHandling.h> | ||
| #include <mlir/Dialect/SCF/IR/SCF.h> | ||
| #include <mlir/IR/Builders.h> | ||
| #include <mlir/IR/Value.h> | ||
| #include <mlir/Support/LLVM.h> | ||
|
|
||
| #include <cassert> | ||
| #include <iterator> | ||
|
|
||
| namespace mlir::qtensor { | ||
| TypedValue<RankedTensorType> TensorIterator::tensor() const { | ||
| // A tensor deallocation doesn't have an OpResult. | ||
| if (isa<DeallocOp>(op_)) { | ||
| return nullptr; | ||
| } | ||
| return tensor_; | ||
| } | ||
|
|
||
| void TensorIterator::forward() { | ||
| // If the iterator is a sentinel already, there is nothing to do. | ||
| if (isSentinel_) { | ||
| return; | ||
| } | ||
|
|
||
| // Find the user-operation of the tensor SSA value. | ||
| assert(tensor_.hasOneUse() && "expected linear typing"); | ||
| op_ = *(tensor_.user_begin()); | ||
|
|
||
| // A deallocation defines the end of the tensor's life-chain. | ||
| if (isa<DeallocOp, scf::YieldOp>(op_)) { | ||
| isSentinel_ = true; | ||
| return; | ||
| } | ||
|
|
||
| // Find the output from the input tensor SSA value. | ||
| if (!(isa<AllocOp, FromElementsOp>(op_))) { | ||
| TypeSwitch<Operation*>(op_) | ||
| .Case<ExtractOp>([&](ExtractOp op) { tensor_ = op.getOutTensor(); }) | ||
| .Case<InsertOp>([&](InsertOp op) { tensor_ = op.getResult(); }) | ||
| .Case<scf::ForOp>([&](scf::ForOp op) { | ||
| tensor_ = cast<TypedValue<RankedTensorType>>( | ||
| op.getTiedLoopResult(&*(tensor_.use_begin()))); | ||
| }) | ||
| .Default([&](Operation* op) { | ||
| report_fatal_error("unknown op in def-use chain: " + | ||
| op->getName().getStringRef()); | ||
| }); | ||
| } | ||
| } | ||
|
|
||
| void TensorIterator::backward() { | ||
| // If the iterator is a sentinel, reactivate the iterator. | ||
| if (isSentinel_) { | ||
| isSentinel_ = false; | ||
| return; | ||
| } | ||
|
|
||
| // For deallocations and scf::YieldOps, tensor_ is an OpOperand. | ||
| // Hence, only get the def-op. | ||
| if (isa<DeallocOp, scf::YieldOp>(op_)) { | ||
| op_ = tensor_.getDefiningOp(); | ||
| return; | ||
| } | ||
|
|
||
| // Allocations and FromElements define the start of the tensor's life-chain. | ||
| // Consequently, stop and early exit. | ||
| if (isa<AllocOp, FromElementsOp>(op_)) { | ||
| return; | ||
| } | ||
|
|
||
| // Find the input from the output tensor SSA value. | ||
| TypeSwitch<Operation*>(op_) | ||
| .Case<ExtractOp>([&](ExtractOp op) { tensor_ = op.getTensor(); }) | ||
| .Case<InsertOp>([&](InsertOp op) { tensor_ = op.getDest(); }) | ||
| .Case<scf::ForOp>([&](scf::ForOp op) { | ||
| if (auto res = dyn_cast<OpResult>(tensor_)) { | ||
| OpOperand* operand = op.getTiedLoopInit(res); | ||
| tensor_ = cast<TypedValue<RankedTensorType>>(operand->get()); | ||
| return; | ||
| } | ||
|
|
||
| llvm::reportFatalInternalError( | ||
| "expected scf.for result for tied init lookup"); | ||
| }) | ||
| .Default([&](Operation* op) { | ||
| llvm::reportFatalInternalError("unknown op in def-use chain: " + | ||
| op->getName().getStringRef()); | ||
| }); | ||
|
|
||
| // Get the operation that produces the tensor value. | ||
| // If the current tensor SSA value is a BlockArgument (no defining op), the | ||
| // operation will be a nullptr. | ||
| op_ = tensor_.getDefiningOp(); | ||
| } | ||
|
|
||
| static_assert(std::bidirectional_iterator<TensorIterator>); | ||
| static_assert(std::sentinel_for<std::default_sentinel_t, TensorIterator>, | ||
| "std::default_sentinel_t must be a sentinel for TensorIterator."); | ||
| } // namespace mlir::qtensor | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,3 +7,4 @@ | |
| # Licensed under the MIT License | ||
|
|
||
| add_subdirectory(IR) | ||
| add_subdirectory(Utils) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| # Copyright (c) 2023 - 2026 Chair for Design Automation, TUM | ||
| # Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH | ||
| # All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: MIT | ||
| # | ||
| # Licensed under the MIT License | ||
|
|
||
| set(qtensor_utils_target mqt-core-mlir-unittest-qtensor-utils) | ||
| add_executable(${qtensor_utils_target} test_tensoriterator.cpp) | ||
| target_link_libraries(${qtensor_utils_target} PRIVATE GTest::gtest_main MLIRQTensorDialect | ||
| MLIRQTensorUtils MLIRQCOProgramBuilder) | ||
| mqt_mlir_configure_unittest_target(${qtensor_utils_target}) | ||
|
|
||
| gtest_discover_tests(${qtensor_utils_target} PROPERTIES LABELS mqt-mlir-unittests DISCOVERY_TIMEOUT | ||
| 60) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about a
qco::IfOpthat uses a tensor value? Same below