Skip to content

Commit 9baca01

Browse files
authored
[mlir][tensor] Consolidate tensor fold patterns and rename related file (#192820)
This PR moves `MergeConsecutiveExtractSlice` from `MergeConsecutiveInsertExtractSlicePatterns.cpp` to `FoldTensorSubsetOps.cpp`, and removes the duplicate `MergeConsecutiveInsertSlice` pattern in favor of `InsertSliceOfInsertSliceFolder`, which already exists in `FoldTensorSubsetOps.cpp` and provides equivalent functionality with greater stability. Since the merge-related patterns have been fully migrated out, `MergeConsecutiveInsertExtractSlicePatterns.cpp` is renamed to `DropRedundantRankExpansionPatterns.cpp` to better reflect its remaining responsibilities.
1 parent 6a06c8b commit 9baca01

5 files changed

Lines changed: 96 additions & 77 deletions

File tree

mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,11 @@
11
add_mlir_dialect_library(MLIRTensorTransforms
22
BufferizableOpInterfaceImpl.cpp
33
ConcatOpPatterns.cpp
4+
DropRedundantRankExpansionPatterns.cpp
45
EmptyOpPatterns.cpp
56
ExtractSliceFromReshapeUtils.cpp
67
FoldTensorSubsetOps.cpp
78
IndependenceTransforms.cpp
8-
MergeConsecutiveInsertExtractSlicePatterns.cpp
99
ReshapePatterns.cpp
1010
RewriteAsConstant.cpp
1111
RuntimeOpVerification.cpp

mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp renamed to mlir/lib/Dialect/Tensor/Transforms/DropRedundantRankExpansionPatterns.cpp

Lines changed: 1 addition & 69 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
//===- MergeConsecutiveInsertExtractSlicePatterns.cpp ---------------------===//
1+
//===- DropRedundantRankExpansionPatterns.cpp -----------------------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -18,66 +18,6 @@ using namespace mlir;
1818
using namespace mlir::tensor;
1919

2020
namespace {
21-
/// Merges consecutive tensor.extract_slice ops into one.
22-
// TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps.
23-
struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
24-
using OpRewritePattern::OpRewritePattern;
25-
26-
LogicalResult matchAndRewrite(ExtractSliceOp nextOp,
27-
PatternRewriter &rewriter) const override {
28-
auto prevOp = nextOp.getSource().getDefiningOp<ExtractSliceOp>();
29-
if (!prevOp)
30-
return failure();
31-
32-
SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
33-
if (failed(affine::mergeOffsetsSizesAndStrides(
34-
rewriter, nextOp.getLoc(), prevOp, nextOp, prevOp.getDroppedDims(),
35-
newOffsets, newSizes, newStrides)))
36-
return failure();
37-
38-
rewriter.replaceOpWithNewOp<ExtractSliceOp>(nextOp, nextOp.getType(),
39-
prevOp.getSource(), newOffsets,
40-
newSizes, newStrides);
41-
return success();
42-
}
43-
};
44-
45-
/// Merges consecutive tensor.insert_slice ops into one.
46-
// TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps.
47-
template <typename OpTy>
48-
struct MergeConsecutiveInsertSlice : public OpRewritePattern<OpTy> {
49-
using OpRewritePattern<OpTy>::OpRewritePattern;
50-
51-
LogicalResult matchAndRewrite(OpTy nextOp,
52-
PatternRewriter &rewriter) const override {
53-
auto prevOp = nextOp.getSource().template getDefiningOp<InsertSliceOp>();
54-
if (!prevOp)
55-
return failure();
56-
57-
if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
58-
return failure();
59-
60-
// The first insert_slice op should be rank reducing to make sure we cover
61-
// the full source tensor to be inserted in the second insert_slice op.
62-
SliceVerificationResult result =
63-
isRankReducedType(prevOp.getDestType(), prevOp.getSourceType());
64-
if (result != SliceVerificationResult::Success)
65-
return failure();
66-
67-
// Dynamic dimensions can pass rank reducing check in the above, e.g,
68-
// inserting <?xf32> into <1x?x1xf32>. For such cases we cannot be certain
69-
// the dynamic size covers the full tensor.
70-
if (!prevOp.getSourceType().hasStaticShape() ||
71-
!prevOp.getDestType().hasStaticShape())
72-
return failure();
73-
74-
rewriter.replaceOpWithNewOp<OpTy>(
75-
nextOp, prevOp.getSource(), nextOp.getDest(), nextOp.getMixedOffsets(),
76-
nextOp.getMixedSizes(), nextOp.getMixedStrides());
77-
return success();
78-
}
79-
};
80-
8121
/// Drop redundant rank expansion of insert_slice that are directly followed
8222
/// by extract_slice. E.g.:
8323
/// %0 = tensor.insert_slice ... : tensor<5x10xf32> into tensor<1x1x5x10xf32>
@@ -227,14 +167,6 @@ struct DropRedundantRankExpansionOnInsertSliceOfExtractSlice final
227167
};
228168
} // namespace
229169

230-
void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
231-
RewritePatternSet &patterns) {
232-
patterns.add<MergeConsecutiveExtractSlice,
233-
MergeConsecutiveInsertSlice<InsertSliceOp>,
234-
MergeConsecutiveInsertSlice<ParallelInsertSliceOp>>(
235-
patterns.getContext());
236-
}
237-
238170
void mlir::tensor::populateDropRedundantInsertSliceRankExpansionPatterns(
239171
RewritePatternSet &patterns) {
240172
patterns.add<DropRedundantRankExpansionOnExtractSliceOfInsertSlice,

mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp

Lines changed: 35 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -220,19 +220,48 @@ struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
220220
}
221221
};
222222

223-
void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
224-
populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
225-
patterns.add<InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
226-
InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
227-
patterns.getContext());
228-
}
223+
struct MergeConsecutiveExtractSlice
224+
: public OpRewritePattern<tensor::ExtractSliceOp> {
225+
using OpRewritePattern::OpRewritePattern;
226+
227+
LogicalResult matchAndRewrite(tensor::ExtractSliceOp nextOp,
228+
PatternRewriter &rewriter) const override {
229+
auto prevOp = nextOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
230+
if (!prevOp)
231+
return failure();
232+
233+
SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
234+
if (failed(affine::mergeOffsetsSizesAndStrides(
235+
rewriter, nextOp.getLoc(), prevOp, nextOp, prevOp.getDroppedDims(),
236+
newOffsets, newSizes, newStrides)))
237+
return failure();
238+
239+
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
240+
nextOp, nextOp.getType(), prevOp.getSource(), newOffsets, newSizes,
241+
newStrides);
242+
return success();
243+
}
244+
};
229245

230246
void tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(
231247
RewritePatternSet &patterns) {
232248
patterns.add<TransferReadOfExtractSliceOpFolder,
233249
InsertSliceOfTransferWriteOpFolder>(patterns.getContext());
234250
}
235251

252+
void tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
253+
RewritePatternSet &patterns) {
254+
patterns.add<MergeConsecutiveExtractSlice,
255+
InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
256+
InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
257+
patterns.getContext());
258+
}
259+
260+
void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
261+
populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
262+
populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
263+
}
264+
236265
//===----------------------------------------------------------------------===//
237266
// Pass registration
238267
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -80,7 +80,7 @@ func.func @insert_slice_rank_reducing_dynamic_shape(
8080
}
8181

8282
// CHECK-LABEL: func.func @insert_slice_rank_reducing_dynamic_shape
83-
// CHECK-COUNT-2: tensor.insert_slice
83+
// CHECK-COUNT-1: tensor.insert_slice
8484

8585
// -----
8686

mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir

Lines changed: 58 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -415,3 +415,61 @@ func.func @parallel_insert_slice_of_insert_slice_dynamic(
415415
}
416416
return %0: tensor<12x34xf32>
417417
}
418+
419+
// -----
420+
421+
// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
422+
// CHECK-LABEL: func.func @extract_slice_same_rank
423+
// CHECK-SAME: (%[[SOURCE:.+]]: tensor<?x?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index)
424+
// CHECK: %[[OFFSET:.+]] = affine.apply #[[$map]]()[%[[OFFSET1]], %[[OFFSET0]]]
425+
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][7, 9, 11, %[[OFFSET]]] [8, 16, 32, %[[SIZE1]]] [1, 1, 1, 1]
426+
// CHECK: return %[[EXTRACT]] : tensor<8x16x32x?xf32>
427+
func.func @extract_slice_same_rank(
428+
%src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<8x16x32x?xf32> {
429+
%0 = tensor.extract_slice %src[0, 1, 2, %offset0] [128, 128, 128, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x128x128x?xf32>
430+
%1 = tensor.extract_slice %0[7, 8, 9, %offset1] [8, 16, 32, %size1] [1, 1, 1, 1] : tensor<128x128x128x?xf32> to tensor<8x16x32x?xf32>
431+
return %1: tensor<8x16x32x?xf32>
432+
}
433+
434+
// -----
435+
436+
// CHECK-LABEL: func.func @extract_slice_rank_reducing_consumer
437+
// CHECK: tensor.extract_slice %{{.+}}[7, 9, 11, %{{.+}}] [1, 16, 1, %{{.+}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<16x?xf32>
438+
func.func @extract_slice_rank_reducing_consumer(
439+
%src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<16x?xf32> {
440+
%0 = tensor.extract_slice %src[0, 1, 2, %offset0] [128, 128, 128, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x128x128x?xf32>
441+
%1 = tensor.extract_slice %0[7, 8, 9, %offset1] [1, 16, 1, %size1] [1, 1, 1, 1] : tensor<128x128x128x?xf32> to tensor<16x?xf32>
442+
return %1: tensor<16x?xf32>
443+
}
444+
445+
// -----
446+
447+
// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
448+
// CHECK-LABEL: func.func @extract_slice_rank_reducing_producer
449+
// CHECK-SAME: (%[[SRC:.+]]: tensor<?x?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index)
450+
// CHECK: %[[OFFSET:.+]] = affine.apply #[[$map]]()[%[[OFFSET1]], %[[OFFSET0]]]
451+
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SRC]][0, 8, 2, %[[OFFSET]]] [1, 8, 1, %[[SIZE1]]] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<8x?xf32>
452+
// CHECK: return %[[EXTRACT]] : tensor<8x?xf32>
453+
func.func @extract_slice_rank_reducing_producer(
454+
%src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<8x?xf32> {
455+
%0 = tensor.extract_slice %src[0, 1, 2, %offset0] [1, 128, 1, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x?xf32>
456+
%1 = tensor.extract_slice %0[7, %offset1] [8, %size1] [1, 1] : tensor<128x?xf32> to tensor<8x?xf32>
457+
return %1: tensor<8x?xf32>
458+
}
459+
460+
// -----
461+
462+
// CHECK: #[[$map_0:.+]] = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
463+
// CHECK: #[[$map_1:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
464+
// CHECK-LABEL: func.func @extract_slice_non_one_stride
465+
// CHECK-SAME: (%[[SRC:.+]]: tensor<?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index, %[[STRIDE0:.+]]: index, %[[STRIDE1:.+]]: index)
466+
// CHECK: %[[OFFSET:.+]] = affine.apply #[[$map_0]]()[%[[OFFSET1]], %[[STRIDE0]], %[[OFFSET0]]]
467+
// CHECK: %[[STRIDE:.+]] = affine.apply #[[$map_1]]()[%[[STRIDE1]], %[[STRIDE0]]]
468+
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SRC]][%[[OFFSET]]] [%[[SIZE1]]] [%[[STRIDE]]] : tensor<?xf32> to tensor<?xf32>
469+
// CHECK: return %[[EXTRACT]] : tensor<?xf32>
470+
func.func @extract_slice_non_one_stride(
471+
%src: tensor<?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index, %stride0: index, %stride1: index) -> tensor<?xf32> {
472+
%0 = tensor.extract_slice %src[%offset0] [%size0] [%stride0] : tensor<?xf32> to tensor<?xf32>
473+
%1 = tensor.extract_slice %0[%offset1] [%size1] [%stride1] : tensor<?xf32> to tensor<?xf32>
474+
return %1: tensor<?xf32>
475+
}

0 commit comments

Comments (0)