[mlir][linalg] Reimplement SimplifyPackToExpandShape and SimplifyUnPackToCollapseShape for more cases.#204971
Conversation
|
Hello @JerryShih 👋 Thank you for submitting a Pull Request (PR) to the LLVM Project. Since this is your first PR, here are a few useful links covering our main contribution policies and review practices.
Please reply to this message to confirm that you have read these policies, especially the LLVM AI Tool Use Policy, and that any AI tool usage has been noted in the PR description. Frequently asked questionsHow do I add reviewers? This PR will be automatically labeled, and the relevant teams will be notified. For some parts of the project, reviewers may also be added automatically. You can also add reviewers manually using the Reviewers section on this page. If you cannot use that section, it is probably because you do not have write permissions for the repository. In that case, you can request a review by tagging reviewers in a comment using What if there are no comments? If you have not received any comments on your PR after a week, you can request a review by pinging the PR with a comment such as “Ping”. The common courtesy ping rate is once a week. Please remember that you are asking for volunteer time from other developers. Are any special GitHub settings required to contribute to LLVM? We only require contributors to have a public email address associated with their GitHub commits, see this section of LLVM Developer Policy for details. If you have questions, feel free to leave a comment on this PR, or ask on LLVM Discord or LLVM Discourse. Thank you, |
|
@llvm/pr-subscribers-mlir-linalg Author: Jerry Shih (JerryShih) ChangesIf there is no transposition/padding semantic for pack/unpack, Patch is 23.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/204971.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 95383e6262f71..e6954814848b3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -195,10 +195,10 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
```
}];
let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
- TensorOrMemRef<[AnyType]>:$dest,
+ TensorOrMemRef<[AnyType]>:$dest,
Optional<AnyType>:$padding_value,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
- DenseI64ArrayAttr:$inner_dims_pos,
+ DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
let results = (outs Optional<AnyRankedTensor>:$result);
@@ -235,7 +235,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
ArrayRef<int64_t> outerDimsPerm = {});
// Returns the shape of the packed type. It is a shared helper that helps
- // type inference methods in a way that ensures that they agree on which
+ // type inference methods in a way that ensures that they agree on which
// dimensions are dynamic.
static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
@@ -252,7 +252,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
ArrayRef<OpFoldResult> innerTiles);
// Same as above function but here dynamic dimensions are assumed
- // to require padding.
+ // to require padding except the unit-tile size dims.
static bool requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outputShape,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 1a56c5a483e73..ce02f8d0bc174 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5621,13 +5621,16 @@ bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
invertPermutationVector(outerDimsPerm));
}
for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
- if (ShapedType::isDynamic(inputShape[pos]) ||
- ShapedType::isDynamic(outputTileSizes[pos]))
- return true;
std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
if (!constantTile)
return true;
assert(*constantTile != 0 && "static tile size can't be zero");
+ // No padding is needed for unit tile size.
+ if(*constantTile == 1)
+ continue;
+ if (ShapedType::isDynamic(inputShape[pos]) ||
+ ShapedType::isDynamic(outputTileSizes[pos]))
+ return true;
if (inputShape[pos] % (*constantTile) != 0)
return true;
}
@@ -5900,22 +5903,6 @@ static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
return true;
}
-/// Returns true if the pack op does not need a padding value.
-static bool paddingIsNotNeeded(PackOp op) {
- auto srcType = op.getSourceType();
- auto innerDimsPos = op.getInnerDimsPos();
- auto innerTiles = op.getStaticInnerTiles();
- if (ShapedType::isDynamicShape(innerTiles))
- return false;
- for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
- if (srcType.isDynamicDim(pos) && tileSize != 1)
- return false;
- }
- return !PackOp::requirePaddingValue(
- srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
- op.getOuterDimsPerm(), op.getMixedTiles());
-}
-
/// Returns true if the `srcShape` or `destShape` is different from the one in
/// `packOp` and populates each with the inferred static shape.
static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
@@ -5969,7 +5956,13 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
}
// Fold optional PaddingValue operand away if padding is not needed.
- if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
+ // Reject the dynamic tile size here.
+ if (packOp.getPaddingValue() &&
+ !ShapedType::isDynamicShape(packOp.getStaticInnerTiles()) &&
+ !requirePaddingValueStrict(
+ packOp.getSourceType().getShape(), packOp.getInnerDimsPos(),
+ packOp.getDestType().getShape(), packOp.getOuterDimsPerm(),
+ packOp.getMixedTiles())) {
rewriter.startOpModification(packOp);
packOp.getPaddingValueMutable().clear();
rewriter.finalizeOpModification(packOp);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 993eae62535c3..76024ba1fda77 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -18,31 +18,27 @@ namespace mlir {
namespace linalg {
namespace {
-/// Returns the number of shape sizes that is either dynamic or greater than 1.
-static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
- return llvm::count_if(
- shape, [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
-}
-
-/// Returns success() if there is only 1 dimension size in non-packed domain
-/// being greater than 1 and packing only happens on the dimension.
-/// Note: this method should only be used by pack/unpack to reshape conversion.
-/// It assumes that non-unit inner tile size must be used by the non-unit
-/// dimension.
-static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
- ArrayRef<int64_t> srcShape,
- ArrayRef<int64_t> innerPackTileSize) {
- if (getNumGtOneDims(srcShape) > 1) {
- return rewriter.notifyMatchFailure(
- op, "expects non-packed domain to have at most one non-unit dims");
+/// Returns `true` if there is no need of transposition for the packed layout
+/// except the unit tile size.
+static bool isPackWithoutTranspose(ArrayRef<int64_t> dimsPos,
+ ArrayRef<int64_t> tileSize) {
+ SmallVector<int64_t> seqPos;
+ if (dimsPos.empty()) {
+ seqPos = llvm::to_vector<4>(llvm::seq<int64_t>(0, tileSize.size()));
+ dimsPos = seqPos;
}
- // Non-unit inner tile size must be used by the non-unit dimension. If not, it
- // will faill on getting reassociation maps.
- if (getNumGtOneDims(innerPackTileSize) > 1) {
- return rewriter.notifyMatchFailure(
- op, "expects at most one non-unit inner tiles");
+
+ int64_t lastNonUnitPos = 0;
+ for (auto [pos, tile] : llvm::zip_equal(dimsPos, tileSize)) {
+ if ((ShapedType::isDynamic(tile) || tile > 1)) {
+ if (pos < lastNonUnitPos) {
+ return false;
+ }
+ lastNonUnitPos = pos;
+ }
}
- return success();
+
+ return true;
}
// If the `linalgOp` represents a transpose, return the permutation vector for
@@ -88,25 +84,6 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
.getResult();
}
- /// Returns success() if it is only packing on the innermost dimension.
- LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
- PackOp packOp) const {
- auto outerDimsPerm = packOp.getOuterDimsPerm();
- if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
- return rewriter.notifyMatchFailure(
- packOp,
- "expects outer_dims_perm is empty or an identity permutation");
- }
-
- int64_t srcRank = packOp.getSourceRank();
- ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
- if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
- return rewriter.notifyMatchFailure(
- packOp, "expects packing at the innermost dimension");
- }
- return success();
- }
-
LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
if (packOp.getPaddingValue())
@@ -115,19 +92,24 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
if (!packOp.hasPureTensorSemantics())
return failure();
+ PackingMetadata packingMetadata;
ShapedType sourceType = packOp.getSourceType();
- if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
- failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
- packOp.getStaticTiles())) &&
- !packOp.isLikePad()) {
- return failure();
- }
-
ShapedType destType = packOp.getDestType();
+ ArrayRef<int64_t> outputShape = destType.getShape();
+ SmallVector<int64_t> packInverseDestPerm =
+ getPackInverseDestPerm(packOp, packingMetadata);
+ SmallVector<int64_t> transpPerm =
+ invertPermutationVector(packInverseDestPerm);
+
+ if (!isPackWithoutTranspose(transpPerm, outputShape))
+ return rewriter.notifyMatchFailure(packOp,
+ "expects no transpose behavior");
+
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
- return failure();
+ return rewriter.notifyMatchFailure(
+ packOp, "unable to get reshape reassociation indices");
FailureOr<Value> expanded =
insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
*reassociation);
@@ -151,49 +133,38 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
operand, reassociation);
}
- /// Returns success() if it is unpacking on the innermost dimension.
- LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
- UnPackOp unpackOp) const {
- auto outerDimsPerm = unpackOp.getOuterDimsPerm();
- if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
- return rewriter.notifyMatchFailure(
- unpackOp,
- "expects outer_dims_perm is empty or an identity permutation");
- }
-
- ShapedType sourceType = unpackOp.getSourceType();
- ShapedType destType = unpackOp.getDestType();
- if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
- return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
-
- ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
- if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
- return rewriter.notifyMatchFailure(
- unpackOp, "expects unpacking on the innermost dimension");
- }
-
- return success();
- }
-
LogicalResult matchAndRewrite(UnPackOp unpackOp,
PatternRewriter &rewriter) const override {
// TODO: Support Memref UnPackOp. Temporarily return failure.
if (!unpackOp.hasPureTensorSemantics())
return failure();
+ ShapedType sourceType = unpackOp.getSourceType();
ShapedType destType = unpackOp.getDestType();
- if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
- failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
- unpackOp.getStaticTiles())) &&
- !unpackOp.isLikeUnPad()) {
- return failure();
- }
- ShapedType sourceType = unpackOp.getSourceType();
+ if (PackOp::requirePaddingValueStrict(
+ destType.getShape(), unpackOp.getInnerDimsPos(),
+ sourceType.getShape(), unpackOp.getOuterDimsPerm(),
+ unpackOp.getMixedTiles()))
+ return rewriter.notifyMatchFailure(unpackOp,
+ "expects no padding behavior");
+
+ PackingMetadata metadata;
+ ArrayRef<int64_t> inputShape = sourceType.getShape();
+ SmallVector<int64_t> unpackInverseSrcPerm =
+ getUnPackInverseSrcPerm(unpackOp, metadata);
+ SmallVector<int64_t> transpPerm =
+ invertPermutationVector(unpackInverseSrcPerm);
+
+ if (!isPackWithoutTranspose(transpPerm, inputShape))
+ return rewriter.notifyMatchFailure(unpackOp,
+ "expects no transpose behavior");
+
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
- return failure();
+ return rewriter.notifyMatchFailure(
+ unpackOp, "unable to get reshape reassociation indices");
Value collapsed = insertCollapse(
rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
getReassociationIndicesAttribute(rewriter, *reassociation));
diff --git a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
index 6979770154bab..e1b4e8a047ff4 100644
--- a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
@@ -157,9 +157,12 @@ func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
// -----
+// There is no enough info to check whether there is no padding from the
+// dynamic input/output shapes.
+//
// CHECK-LABEL: func.func @unpack_dynamic
-// CHECK: tensor.collapse
-// CHECK-NOT: linalg.unpack
+// CHECK-NOT: tensor.collapse
+// CHECK: linalg.unpack
func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
@@ -394,3 +397,196 @@ func.func @unpad_like_unpack_with_transpose(%arg0: tensor<32x1x16x64xf32>) -> te
%0 = linalg.unpack %arg0 inner_dims_pos = [1] inner_tiles = [64] into %empty : tensor<32x1x16x64xf32> -> tensor<32x64x16xf32>
return %0 : tensor<32x64x16xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_3d_to_5d(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<3x32x64xf32>)
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3], [4]] output_shape [3, 1, 1, 32, 64] : tensor<3x32x64xf32> into tensor<3x1x1x32x64xf32>
+// CHECK: return %[[EXPANDED]] : tensor<3x1x1x32x64xf32>
+func.func @pack_3d_to_5d(%arg0: tensor<3x32x64xf32>) -> tensor<3x1x1x32x64xf32> {
+ %empty = tensor.empty() : tensor<3x1x1x32x64xf32>
+ %0 = linalg.pack %arg0 inner_dims_pos = [1, 2] inner_tiles = [32, 64] into %empty : tensor<3x32x64xf32> -> tensor<3x1x1x32x64xf32>
+ return %0 : tensor<3x1x1x32x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_3d_to_5d_with_outer_dims_perm(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<3x32x64xf32>)
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3], [4]] output_shape [3, 1, 1, 32, 64] : tensor<3x32x64xf32> into tensor<3x1x1x32x64xf32>
+// CHECK: return %[[EXPANDED]] : tensor<3x1x1x32x64xf32>
+func.func @pack_3d_to_5d_with_outer_dims_perm(%arg0: tensor<3x32x64xf32>) -> tensor<3x1x1x32x64xf32> {
+ %empty = tensor.empty() : tensor<3x1x1x32x64xf32>
+ %0 = linalg.pack %arg0 outer_dims_perm = [0, 2, 1] inner_dims_pos = [1, 2] inner_tiles = [32, 64] into %empty : tensor<3x32x64xf32> -> tensor<3x1x1x32x64xf32>
+ return %0 : tensor<3x1x1x32x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_3d_to_5d_dynamic_shape(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<32x?x64xf32>)
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3, 4]] output_shape [32, 1, %[[DIM1]], 1, 64] : tensor<32x?x64xf32> into tensor<32x1x?x1x64xf32>
+// CHECK: return %[[EXPANDED]] : tensor<32x1x?x1x64xf32>
+func.func @pack_3d_to_5d_dynamic_shape(%arg0: tensor<32x?x64xf32>) -> tensor<32x1x?x1x64xf32> {
+ %c1 = arith.constant 1 : index
+ %dim1 = tensor.dim %arg0, %c1 : tensor<32x?x64xf32>
+ %empty = tensor.empty(%dim1) : tensor<32x1x?x1x64xf32>
+ %0 = linalg.pack %arg0 outer_dims_perm = [0, 2, 1] inner_dims_pos = [1, 2] inner_tiles = [1, 64] into %empty : tensor<32x?x64xf32> -> tensor<32x1x?x1x64xf32>
+ return %0 : tensor<32x1x?x1x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_nd_with_non_unit_outer_tile_dims_perm(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<3x3x32x64xf32>)
+// CHECK-NOT: tensor.expand_shape
+// CHECK: linalg.pack
+func.func @pack_nd_with_non_unit_outer_tile_dims_perm(%arg0: tensor<3x3x32x64xf32>) -> tensor<3x3x1x1x32x64xf32> {
+ %empty = tensor.empty() : tensor<3x3x1x1x32x64xf32>
+ %0 = linalg.pack %arg0 outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [2, 3] inner_tiles = [32, 64] into %empty : tensor<3x3x32x64xf32> -> tensor<3x3x1x1x32x64xf32>
+ return %0 : tensor<3x3x1x1x32x64xf32>
+
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_with_non_unit_packed_dims(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x4xf32>)
+// CHECK-NOT: tensor.expand_shape
+// CHECK: linalg.pack
+func.func @pack_with_non_unit_packed_dims(%arg0: tensor<4x4xf32>) -> tensor<2x2x2x2xf32> {
+ %empty = tensor.empty() : tensor<2x2x2x2xf32>
+ %0 = linalg.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %empty : tensor<4x4xf32> -> tensor<2x2x2x2xf32>
+ return %0 : tensor<2x2x2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_with_non_unit_inner_tile_dims_perm(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<32x32xf32>)
+// CHECK-NOT: tensor.expand_shape
+// CHECK: linalg.pack
+func.func @pack_with_non_unit_inner_tile_dims_perm(%arg0: tensor<32x32xf32>) -> tensor<1x1x32x32xf32> {
+ %empty = tensor.empty() : tensor<1x1x32x32xf32>
+ %0 = linalg.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [32, 32] into %empty : tensor<32x32xf32> -> tensor<1x1x32x32xf32>
+ return %0 : tensor<1x1x32x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_5d_to_3d(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<3x1x1x32x64xf32>)
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2, 3], [4]] : tensor<3x1x1x32x64xf32> into tensor<3x32x64xf32>
+// CHECK: return %[[COLLAPSED]] : tensor<3x32x64xf32>
+func.func @unpack_5d_to_3d(%arg0: tensor<3x1x1x32x64xf32>) -> tensor<3x32x64xf32> {
+ %empty = tensor.empty() : tensor<3x32x64xf32>
+ %0 = linalg.unpack %arg0 inner_dims_pos = [1, 2] inner_tiles = [32, 64] into %empty : tensor<3x1x1x32x64xf32> -> tensor<3x32x64xf32>
+ return %0 : tensor<3x32x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_5d_to_3d_with_outer_dims_perm(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<3x1x1x32x64xf32>)
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2, 3], [4]] : tensor<3x1x1x32x64xf32> into tensor<3x32x64xf32>
+// CHECK: return %[[COLLAPSED]] : tensor<3x32x64xf32>
+func.func @unpack_5d_to_3d_with_outer_dims_perm(%arg0: tensor<3x1x1x32x64xf32>) -> tensor<3x32x64xf32> {
+ %empty = tensor.empty() : tensor<3x32x64xf32>
+ %0 = linalg.unpack %arg0 outer_dims_perm = [0, 2, 1] inner_dims_pos = [1, 2] inner_tiles = [32, 64] into %empty : tensor<3x1x1x32x64xf32> -> tensor<3x32x64xf32>
+ return %0 : tensor<3x32x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_5d_to_3d_dynamic_shape(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<32x1x?x1x64xf32>)
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3, 4]] : tensor<32x1x?x1x64xf32> into tensor<32x?x64xf32>
+// CHECK: return %[[COLLAPSED]] : tensor<32x?x64xf32>
+func.func @unpack_5d_to_3d_dynamic_shape(%arg0: tensor<32x1x?x1x64xf32>) -> tensor<32x?x64xf32> {
+ %c2 = arith.constant 2 : index
+ %dim2 = tensor.dim %arg0, %c2 : tensor<32x1x?x1x64xf32>
+ %empty = tensor.empty(%dim2) : tensor<32x?x64xf32>
+ %0 = linalg.unpack %arg0 outer_dims_perm = [0, 2, 1] inner_dims_pos = [1, 2] inner_tiles = [1, 64] into %empty : tensor<32x1x?x1x64xf32> -> tensor<32x?x64xf32>
+ return %0 : tensor<32x?x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_nd_with_non_unit_outer_tile_dims_perm(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<3x3x1x1x32x64xf32>)
+// CHECK-NOT: tensor.collapse_shape
+// CHECK: linalg.unpack
+func.func @unpack_nd_with_non_unit_outer_tile_dims_perm(%arg0: tensor<3x3x1x1x32x64xf32>) -> tensor<3x3x32x64xf32> {
+ %empty = tensor.empty() : tensor<3x3x32x64xf32>
+ %0 = linalg.unpack %arg0 outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [2, 3] inner_tiles = [32, 64] into %empty : tensor<3x3x1x1x32x64xf32> -> tensor<3x3x32x64xf32...
[truncated]
|
|
@llvm/pr-subscribers-mlir Author: Jerry Shih (JerryShih) ChangesIf there is no transposition/padding semantic for pack/unpack, Patch is 23.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/204971.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 95383e6262f71..e6954814848b3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -195,10 +195,10 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
```
}];
let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
- TensorOrMemRef<[AnyType]>:$dest,
+ TensorOrMemRef<[AnyType]>:$dest,
Optional<AnyType>:$padding_value,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
- DenseI64ArrayAttr:$inner_dims_pos,
+ DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
let results = (outs Optional<AnyRankedTensor>:$result);
@@ -235,7 +235,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
ArrayRef<int64_t> outerDimsPerm = {});
// Returns the shape of the packed type. It is a shared helper that helps
- // type inference methods in a way that ensures that they agree on which
+ // type inference methods in a way that ensures that they agree on which
// dimensions are dynamic.
static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
@@ -252,7 +252,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
ArrayRef<OpFoldResult> innerTiles);
// Same as above function but here dynamic dimensions are assumed
- // to require padding.
+ // to require padding except the unit-tile size dims.
static bool requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outputShape,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 1a56c5a483e73..ce02f8d0bc174 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5621,13 +5621,16 @@ bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
invertPermutationVector(outerDimsPerm));
}
for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
- if (ShapedType::isDynamic(inputShape[pos]) ||
- ShapedType::isDynamic(outputTileSizes[pos]))
- return true;
std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
if (!constantTile)
return true;
assert(*constantTile != 0 && "static tile size can't be zero");
+ // No padding is needed for unit tile size.
+ if(*constantTile == 1)
+ continue;
+ if (ShapedType::isDynamic(inputShape[pos]) ||
+ ShapedType::isDynamic(outputTileSizes[pos]))
+ return true;
if (inputShape[pos] % (*constantTile) != 0)
return true;
}
@@ -5900,22 +5903,6 @@ static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
return true;
}
-/// Returns true if the pack op does not need a padding value.
-static bool paddingIsNotNeeded(PackOp op) {
- auto srcType = op.getSourceType();
- auto innerDimsPos = op.getInnerDimsPos();
- auto innerTiles = op.getStaticInnerTiles();
- if (ShapedType::isDynamicShape(innerTiles))
- return false;
- for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
- if (srcType.isDynamicDim(pos) && tileSize != 1)
- return false;
- }
- return !PackOp::requirePaddingValue(
- srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
- op.getOuterDimsPerm(), op.getMixedTiles());
-}
-
/// Returns true if the `srcShape` or `destShape` is different from the one in
/// `packOp` and populates each with the inferred static shape.
static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
@@ -5969,7 +5956,13 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
}
// Fold optional PaddingValue operand away if padding is not needed.
- if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
+ // Reject the dynamic tile size here.
+ if (packOp.getPaddingValue() &&
+ !ShapedType::isDynamicShape(packOp.getStaticInnerTiles()) &&
+ !requirePaddingValueStrict(
+ packOp.getSourceType().getShape(), packOp.getInnerDimsPos(),
+ packOp.getDestType().getShape(), packOp.getOuterDimsPerm(),
+ packOp.getMixedTiles())) {
rewriter.startOpModification(packOp);
packOp.getPaddingValueMutable().clear();
rewriter.finalizeOpModification(packOp);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 993eae62535c3..76024ba1fda77 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -18,31 +18,27 @@ namespace mlir {
namespace linalg {
namespace {
-/// Returns the number of shape sizes that is either dynamic or greater than 1.
-static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
- return llvm::count_if(
- shape, [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
-}
-
-/// Returns success() if there is only 1 dimension size in non-packed domain
-/// being greater than 1 and packing only happens on the dimension.
-/// Note: this method should only be used by pack/unpack to reshape conversion.
-/// It assumes that non-unit inner tile size must be used by the non-unit
-/// dimension.
-static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
- ArrayRef<int64_t> srcShape,
- ArrayRef<int64_t> innerPackTileSize) {
- if (getNumGtOneDims(srcShape) > 1) {
- return rewriter.notifyMatchFailure(
- op, "expects non-packed domain to have at most one non-unit dims");
+/// Returns `true` if there is no need of transposition for the packed layout
+/// except the unit tile size.
+static bool isPackWithoutTranspose(ArrayRef<int64_t> dimsPos,
+ ArrayRef<int64_t> tileSize) {
+ SmallVector<int64_t> seqPos;
+ if (dimsPos.empty()) {
+ seqPos = llvm::to_vector<4>(llvm::seq<int64_t>(0, tileSize.size()));
+ dimsPos = seqPos;
}
- // Non-unit inner tile size must be used by the non-unit dimension. If not, it
- // will faill on getting reassociation maps.
- if (getNumGtOneDims(innerPackTileSize) > 1) {
- return rewriter.notifyMatchFailure(
- op, "expects at most one non-unit inner tiles");
+
+ int64_t lastNonUnitPos = 0;
+ for (auto [pos, tile] : llvm::zip_equal(dimsPos, tileSize)) {
+ if ((ShapedType::isDynamic(tile) || tile > 1)) {
+ if (pos < lastNonUnitPos) {
+ return false;
+ }
+ lastNonUnitPos = pos;
+ }
}
- return success();
+
+ return true;
}
// If the `linalgOp` represents a transpose, return the permutation vector for
@@ -88,25 +84,6 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
.getResult();
}
- /// Returns success() if it is only packing on the innermost dimension.
- LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
- PackOp packOp) const {
- auto outerDimsPerm = packOp.getOuterDimsPerm();
- if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
- return rewriter.notifyMatchFailure(
- packOp,
- "expects outer_dims_perm is empty or an identity permutation");
- }
-
- int64_t srcRank = packOp.getSourceRank();
- ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
- if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
- return rewriter.notifyMatchFailure(
- packOp, "expects packing at the innermost dimension");
- }
- return success();
- }
-
LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
if (packOp.getPaddingValue())
@@ -115,19 +92,24 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
if (!packOp.hasPureTensorSemantics())
return failure();
+ PackingMetadata packingMetadata;
ShapedType sourceType = packOp.getSourceType();
- if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
- failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
- packOp.getStaticTiles())) &&
- !packOp.isLikePad()) {
- return failure();
- }
-
ShapedType destType = packOp.getDestType();
+ ArrayRef<int64_t> outputShape = destType.getShape();
+ SmallVector<int64_t> packInverseDestPerm =
+ getPackInverseDestPerm(packOp, packingMetadata);
+ SmallVector<int64_t> transpPerm =
+ invertPermutationVector(packInverseDestPerm);
+
+ if (!isPackWithoutTranspose(transpPerm, outputShape))
+ return rewriter.notifyMatchFailure(packOp,
+ "expects no transpose behavior");
+
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
- return failure();
+ return rewriter.notifyMatchFailure(
+ packOp, "unable to get reshape reassociation indices");
FailureOr<Value> expanded =
insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
*reassociation);
@@ -151,49 +133,38 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
operand, reassociation);
}
- /// Returns success() if it is unpacking on the innermost dimension.
- LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
- UnPackOp unpackOp) const {
- auto outerDimsPerm = unpackOp.getOuterDimsPerm();
- if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
- return rewriter.notifyMatchFailure(
- unpackOp,
- "expects outer_dims_perm is empty or an identity permutation");
- }
-
- ShapedType sourceType = unpackOp.getSourceType();
- ShapedType destType = unpackOp.getDestType();
- if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
- return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
-
- ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
- if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
- return rewriter.notifyMatchFailure(
- unpackOp, "expects unpacking on the innermost dimension");
- }
-
- return success();
- }
-
LogicalResult matchAndRewrite(UnPackOp unpackOp,
PatternRewriter &rewriter) const override {
// TODO: Support Memref UnPackOp. Temporarily return failure.
if (!unpackOp.hasPureTensorSemantics())
return failure();
+ ShapedType sourceType = unpackOp.getSourceType();
ShapedType destType = unpackOp.getDestType();
- if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
- failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
- unpackOp.getStaticTiles())) &&
- !unpackOp.isLikeUnPad()) {
- return failure();
- }
- ShapedType sourceType = unpackOp.getSourceType();
+ if (PackOp::requirePaddingValueStrict(
+ destType.getShape(), unpackOp.getInnerDimsPos(),
+ sourceType.getShape(), unpackOp.getOuterDimsPerm(),
+ unpackOp.getMixedTiles()))
+ return rewriter.notifyMatchFailure(unpackOp,
+ "expects no padding behavior");
+
+ PackingMetadata metadata;
+ ArrayRef<int64_t> inputShape = sourceType.getShape();
+ SmallVector<int64_t> unpackInverseSrcPerm =
+ getUnPackInverseSrcPerm(unpackOp, metadata);
+ SmallVector<int64_t> transpPerm =
+ invertPermutationVector(unpackInverseSrcPerm);
+
+ if (!isPackWithoutTranspose(transpPerm, inputShape))
+ return rewriter.notifyMatchFailure(unpackOp,
+ "expects no transpose behavior");
+
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
- return failure();
+ return rewriter.notifyMatchFailure(
+ unpackOp, "unable to get reshape reassociation indices");
Value collapsed = insertCollapse(
rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
getReassociationIndicesAttribute(rewriter, *reassociation));
diff --git a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
index 6979770154bab..e1b4e8a047ff4 100644
--- a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
@@ -157,9 +157,12 @@ func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
// -----
+// There is no enough info to check whether there is no padding from the
+// dynamic input/output shapes.
+//
// CHECK-LABEL: func.func @unpack_dynamic
-// CHECK: tensor.collapse
-// CHECK-NOT: linalg.unpack
+// CHECK-NOT: tensor.collapse
+// CHECK: linalg.unpack
func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
@@ -394,3 +397,196 @@ func.func @unpad_like_unpack_with_transpose(%arg0: tensor<32x1x16x64xf32>) -> te
%0 = linalg.unpack %arg0 inner_dims_pos = [1] inner_tiles = [64] into %empty : tensor<32x1x16x64xf32> -> tensor<32x64x16xf32>
return %0 : tensor<32x64x16xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_3d_to_5d(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<3x32x64xf32>)
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3], [4]] output_shape [3, 1, 1, 32, 64] : tensor<3x32x64xf32> into tensor<3x1x1x32x64xf32>
+// CHECK: return %[[EXPANDED]] : tensor<3x1x1x32x64xf32>
+func.func @pack_3d_to_5d(%arg0: tensor<3x32x64xf32>) -> tensor<3x1x1x32x64xf32> {
+ %empty = tensor.empty() : tensor<3x1x1x32x64xf32>
+ %0 = linalg.pack %arg0 inner_dims_pos = [1, 2] inner_tiles = [32, 64] into %empty : tensor<3x32x64xf32> -> tensor<3x1x1x32x64xf32>
+ return %0 : tensor<3x1x1x32x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_3d_to_5d_with_outer_dims_perm(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<3x32x64xf32>)
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3], [4]] output_shape [3, 1, 1, 32, 64] : tensor<3x32x64xf32> into tensor<3x1x1x32x64xf32>
+// CHECK: return %[[EXPANDED]] : tensor<3x1x1x32x64xf32>
+func.func @pack_3d_to_5d_with_outer_dims_perm(%arg0: tensor<3x32x64xf32>) -> tensor<3x1x1x32x64xf32> {
+ %empty = tensor.empty() : tensor<3x1x1x32x64xf32>
+ %0 = linalg.pack %arg0 outer_dims_perm = [0, 2, 1] inner_dims_pos = [1, 2] inner_tiles = [32, 64] into %empty : tensor<3x32x64xf32> -> tensor<3x1x1x32x64xf32>
+ return %0 : tensor<3x1x1x32x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_3d_to_5d_dynamic_shape(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<32x?x64xf32>)
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3, 4]] output_shape [32, 1, %[[DIM1]], 1, 64] : tensor<32x?x64xf32> into tensor<32x1x?x1x64xf32>
+// CHECK: return %[[EXPANDED]] : tensor<32x1x?x1x64xf32>
+func.func @pack_3d_to_5d_dynamic_shape(%arg0: tensor<32x?x64xf32>) -> tensor<32x1x?x1x64xf32> {
+ %c1 = arith.constant 1 : index
+ %dim1 = tensor.dim %arg0, %c1 : tensor<32x?x64xf32>
+ %empty = tensor.empty(%dim1) : tensor<32x1x?x1x64xf32>
+ %0 = linalg.pack %arg0 outer_dims_perm = [0, 2, 1] inner_dims_pos = [1, 2] inner_tiles = [1, 64] into %empty : tensor<32x?x64xf32> -> tensor<32x1x?x1x64xf32>
+ return %0 : tensor<32x1x?x1x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_nd_with_non_unit_outer_tile_dims_perm(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<3x3x32x64xf32>)
+// CHECK-NOT: tensor.expand_shape
+// CHECK: linalg.pack
+func.func @pack_nd_with_non_unit_outer_tile_dims_perm(%arg0: tensor<3x3x32x64xf32>) -> tensor<3x3x1x1x32x64xf32> {
+ %empty = tensor.empty() : tensor<3x3x1x1x32x64xf32>
+ %0 = linalg.pack %arg0 outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [2, 3] inner_tiles = [32, 64] into %empty : tensor<3x3x32x64xf32> -> tensor<3x3x1x1x32x64xf32>
+ return %0 : tensor<3x3x1x1x32x64xf32>
+
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_with_non_unit_packed_dims(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x4xf32>)
+// CHECK-NOT: tensor.expand_shape
+// CHECK: linalg.pack
+func.func @pack_with_non_unit_packed_dims(%arg0: tensor<4x4xf32>) -> tensor<2x2x2x2xf32> {
+ %empty = tensor.empty() : tensor<2x2x2x2xf32>
+ %0 = linalg.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %empty : tensor<4x4xf32> -> tensor<2x2x2x2xf32>
+ return %0 : tensor<2x2x2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_with_non_unit_inner_tile_dims_perm(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<32x32xf32>)
+// CHECK-NOT: tensor.expand_shape
+// CHECK: linalg.pack
+func.func @pack_with_non_unit_inner_tile_dims_perm(%arg0: tensor<32x32xf32>) -> tensor<1x1x32x32xf32> {
+ %empty = tensor.empty() : tensor<1x1x32x32xf32>
+ %0 = linalg.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [32, 32] into %empty : tensor<32x32xf32> -> tensor<1x1x32x32xf32>
+ return %0 : tensor<1x1x32x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_5d_to_3d(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<3x1x1x32x64xf32>)
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2, 3], [4]] : tensor<3x1x1x32x64xf32> into tensor<3x32x64xf32>
+// CHECK: return %[[COLLAPSED]] : tensor<3x32x64xf32>
+func.func @unpack_5d_to_3d(%arg0: tensor<3x1x1x32x64xf32>) -> tensor<3x32x64xf32> {
+ %empty = tensor.empty() : tensor<3x32x64xf32>
+ %0 = linalg.unpack %arg0 inner_dims_pos = [1, 2] inner_tiles = [32, 64] into %empty : tensor<3x1x1x32x64xf32> -> tensor<3x32x64xf32>
+ return %0 : tensor<3x32x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_5d_to_3d_with_outer_dims_perm(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<3x1x1x32x64xf32>)
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2, 3], [4]] : tensor<3x1x1x32x64xf32> into tensor<3x32x64xf32>
+// CHECK: return %[[COLLAPSED]] : tensor<3x32x64xf32>
+func.func @unpack_5d_to_3d_with_outer_dims_perm(%arg0: tensor<3x1x1x32x64xf32>) -> tensor<3x32x64xf32> {
+ %empty = tensor.empty() : tensor<3x32x64xf32>
+ %0 = linalg.unpack %arg0 outer_dims_perm = [0, 2, 1] inner_dims_pos = [1, 2] inner_tiles = [32, 64] into %empty : tensor<3x1x1x32x64xf32> -> tensor<3x32x64xf32>
+ return %0 : tensor<3x32x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_5d_to_3d_dynamic_shape(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<32x1x?x1x64xf32>)
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3, 4]] : tensor<32x1x?x1x64xf32> into tensor<32x?x64xf32>
+// CHECK: return %[[COLLAPSED]] : tensor<32x?x64xf32>
+func.func @unpack_5d_to_3d_dynamic_shape(%arg0: tensor<32x1x?x1x64xf32>) -> tensor<32x?x64xf32> {
+ %c2 = arith.constant 2 : index
+ %dim2 = tensor.dim %arg0, %c2 : tensor<32x1x?x1x64xf32>
+ %empty = tensor.empty(%dim2) : tensor<32x?x64xf32>
+ %0 = linalg.unpack %arg0 outer_dims_perm = [0, 2, 1] inner_dims_pos = [1, 2] inner_tiles = [1, 64] into %empty : tensor<32x1x?x1x64xf32> -> tensor<32x?x64xf32>
+ return %0 : tensor<32x?x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_nd_with_non_unit_outer_tile_dims_perm(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<3x3x1x1x32x64xf32>)
+// CHECK-NOT: tensor.collapse_shape
+// CHECK: linalg.unpack
+func.func @unpack_nd_with_non_unit_outer_tile_dims_perm(%arg0: tensor<3x3x1x1x32x64xf32>) -> tensor<3x3x32x64xf32> {
+ %empty = tensor.empty() : tensor<3x3x32x64xf32>
+ %0 = linalg.unpack %arg0 outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [2, 3] inner_tiles = [32, 64] into %empty : tensor<3x3x1x1x32x64xf32> -> tensor<3x3x32x64xf32...
[truncated]
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
…() for padding_value canonicalization. Move paddingIsNotNeeded() logic into requirePaddingValueStrict().
…pe for more cases. If there is no transposition/padding semantic for pack/unpack, we could try to use `mlir::getReassociationIndicesForReshape()` to get the CollapseShapeOp/ExpandShapeOp form.
0158fa6 to
cb6b549
Compare
checked. |
|
@hanhanW |
| if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) { | ||
| // Reject the dynamic tile size here. | ||
| if (packOp.getPaddingValue() && | ||
| !ShapedType::isDynamicShape(packOp.getStaticInnerTiles()) && |
There was a problem hiding this comment.
@banach-space @joker-eph
could you help to check this pr?
The requirePaddingValueStrict() will check the folding result of tile size. But it looks like the canonicalize pass only checking for static size.
llvm-project/mlir/test/Dialect/Linalg/canonicalize.mlir
Lines 1550 to 1565 in e0cc08d
If there is no transposition/padding semantic for pack/unpack,
we could try to use
mlir::getReassociationIndicesForReshape()to get the CollapseShapeOp/ExpandShapeOp form.