Skip to content

Commit ff614a5

Browse files
[mlir][Interfaces] LISH: Add helpers for hyperrectangular subsets (#70628)
The majority of subset ops operate on hyperrectangular subsets. This commit adds a new optional interface method (`getAccessedHyperrectangularSlice`) that can be implemented by such subset ops. If implemented, the other `operatesOn...` interface methods of the `SubsetOpInterface` do not have to be implemented anymore. The comparison logic for hyperrectangular subsets (is disjoint/equivalent) is implemented with `ValueBoundsOpInterface`. This makes the subset hoisting more powerful: simple cases where two different SSA values always have the same runtime value can now be supported.
1 parent 5b6ceaf commit ff614a5

File tree

11 files changed

+285
-97
lines changed

11 files changed

+285
-97
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,8 @@ def Bufferization_MaterializeInDestinationOp
220220
AllElementTypesMatch<["source", "dest"]>,
221221
BufferizableOpInterface, DestinationStyleOpInterface,
222222
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
223-
DeclareOpInterfaceMethods<SubsetOpInterface>,
223+
DeclareOpInterfaceMethods<SubsetOpInterface,
224+
["operatesOnEquivalentSubset", "operatesOnDisjointSubset"]>,
224225
DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
225226
["getSourceOperand", "getValuesNeededToBuildSubsetExtraction",
226227
"buildSubsetExtraction", "isEquivalentSubset"]>,

mlir/include/mlir/IR/OpDefinition.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,11 @@ class OpFoldResult : public PointerUnion<Attribute, Value> {
268268

269269
public:
270270
void dump() const { llvm::errs() << *this << "\n"; }
271+
272+
MLIRContext *getContext() const {
273+
return is<Attribute>() ? get<Attribute>().getContext()
274+
: get<Value>().getContext();
275+
}
271276
};
272277

273278
// Temporarily exit the MLIR namespace to add casting support as later code in

mlir/include/mlir/Interfaces/SubsetOpInterface.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLIR_INTERFACES_SUBSETOPINTERFACE_H_
1111

1212
#include "mlir/IR/OpDefinition.h"
13+
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
1314

1415
namespace mlir {
1516
class SubsetOpInterface;
@@ -27,10 +28,23 @@ OpOperand &defaultGetDestinationOperand(Operation *op);
2728
/// `DestinationStyleOpInterface`.
2829
OpResult defaultGetUpdatedDestination(Operation *op);
2930

30-
/// Default implementation of `isEquivalentSubset`.
31+
/// Default implementation of `SubsetInsertionOpInterface::isEquivalentSubset`.
3132
bool defaultIsEquivalentSubset(Operation *op, Value candidate,
3233
function_ref<bool(Value, Value)> equivalenceFn);
3334

35+
/// Default implementation of `SubsetOpInterface::operatesOnEquivalentSubset`.
36+
bool defaultOperatesOnEquivalentSubset(
37+
Operation *op, SubsetOpInterface candidate,
38+
function_ref<bool(Value, Value)> equivalenceFn);
39+
40+
/// Default implementation of `SubsetOpInterface::operatesOnDisjointSubset`.
41+
bool defaultOperatesOnDisjointSubset(
42+
Operation *op, SubsetOpInterface candidate,
43+
function_ref<bool(Value, Value)> equivalenceFn);
44+
45+
/// Return the container that the given subset op is operating on.
46+
Value getTensorContainer(Operation *op);
47+
3448
/// Verify `SubsetOpInterface`.
3549
LogicalResult verifySubsetOpInterface(SubsetOpInterface op);
3650

mlir/include/mlir/Interfaces/SubsetOpInterface.td

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,6 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
3232
hyperrectangular slice.
3333
- `tensor.gather/scatter` describe the subset as list of indices. (Not
3434
implemented yet.)
35-
36-
Note: This interface does not expose any interface methods to get a
37-
description of the accessed subset. That is because there is currently no
38-
efficient way to describe arbitrary subsets. This interface merely provides
39-
interface methods to check if two subsets are equivalent or disjoint.
4035
}];
4136

4237
let cppNamespace = "::mlir";
@@ -46,31 +41,75 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
4641
Return "true" if this op and the given candidate subset op operate on
4742
equivalent subsets. Return "false" if the two subsets are disjoint
4843
or cannot be proven to be equivalent.
44+
45+
This interface method does not have to be implemented if
46+
`getAccessedHyperrectangularSlice` is implemented.
4947
}],
5048
/*retType=*/"bool",
5149
/*methodName=*/"operatesOnEquivalentSubset",
5250
/*args=*/(ins
5351
"::mlir::SubsetOpInterface":$candidate,
54-
"::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
52+
"::llvm::function_ref<bool(Value, Value)>":$equivalenceFn),
53+
/*methodBody=*/"",
54+
/*defaultImplementation=*/[{
55+
return ::mlir::detail::defaultOperatesOnEquivalentSubset(
56+
$_op, candidate, equivalenceFn);
57+
}]
5558
>,
5659
InterfaceMethod<
5760
/*desc=*/[{
5861
Return "true" if this op and the given candidate subset op operate on
5962
disjoint subsets. Return "false" if the two subsets are equivalent,
6063
overlapping or cannot be proven to be disjoint.
64+
65+
This interface method does not have to be implemented if
66+
`getAccessedHyperrectangularSlice` is implemented.
6167
}],
6268
/*retType=*/"bool",
6369
/*methodName=*/"operatesOnDisjointSubset",
6470
/*args=*/(ins
6571
"::mlir::SubsetOpInterface":$candidate,
66-
"::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
72+
"::llvm::function_ref<bool(Value, Value)>":$equivalenceFn),
73+
/*methodBody=*/"",
74+
/*defaultImplementation=*/[{
75+
return ::mlir::detail::defaultOperatesOnDisjointSubset(
76+
$_op, candidate, equivalenceFn);
77+
}]
78+
>,
79+
InterfaceMethod<
80+
/*desc=*/[{
81+
If this op operates on a hyperrectangular subset, return a
82+
description of the subset in terms of offsets, sizes and strides.
83+
Otherwise, return "failure".
84+
85+
This interface method is a convenience method for the most common case
86+
of hyperrectangular subset ops. It is optional. If it is implemented,
87+
`operatesOnEquivalentSubset` and `operatesOnDisjointSubset` do not
88+
have to be implemented.
89+
}],
90+
/*retType=*/"::mlir::FailureOr<::mlir::HyperrectangularSlice>",
91+
/*methodName=*/"getAccessedHyperrectangularSlice",
92+
/*args=*/(ins),
93+
/*methodBody=*/"",
94+
/*defaultImplementation=*/[{
95+
return ::mlir::failure();
96+
}]
6797
>,
6898
];
6999

70100
let verify = [{
71101
return ::mlir::detail::verifySubsetOpInterface(
72102
::mlir::cast<::mlir::SubsetOpInterface>($_op));
73103
}];
104+
105+
let extraClassDeclaration = [{
106+
/// Return the container that this operation is operating on. In case of an
107+
/// extraction op, the container is the source tensor. In case of an
108+
/// insertion op, the container is the destination tensor.
109+
Value getTensorContainer() {
110+
return ::mlir::detail::getTensorContainer(getOperation());
111+
}
112+
}];
74113
}
75114

76115
def SubsetExtractionOpInterface

mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,31 @@
2121
namespace mlir {
2222
class OffsetSizeAndStrideOpInterface;
2323

24+
/// A hyperrectangular slice, represented as a list of offsets, sizes and
25+
/// strides.
26+
class HyperrectangularSlice {
27+
public:
28+
HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
29+
ArrayRef<OpFoldResult> sizes,
30+
ArrayRef<OpFoldResult> strides);
31+
32+
/// Create a hyperrectangular slice with unit strides.
33+
HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
34+
ArrayRef<OpFoldResult> sizes);
35+
36+
/// Infer a hyperrectangular slice from `OffsetSizeAndStrideOpInterface`.
37+
HyperrectangularSlice(OffsetSizeAndStrideOpInterface op);
38+
39+
ArrayRef<OpFoldResult> getMixedOffsets() const { return mixedOffsets; }
40+
ArrayRef<OpFoldResult> getMixedSizes() const { return mixedSizes; }
41+
ArrayRef<OpFoldResult> getMixedStrides() const { return mixedStrides; }
42+
43+
private:
44+
SmallVector<OpFoldResult> mixedOffsets;
45+
SmallVector<OpFoldResult> mixedSizes;
46+
SmallVector<OpFoldResult> mixedStrides;
47+
};
48+
2449
using ValueDimList = SmallVector<std::pair<Value, std::optional<int64_t>>>;
2550

2651
/// A helper class to be used with `ValueBoundsOpInterface`. This class stores a
@@ -182,12 +207,34 @@ class ValueBoundsConstraintSet {
182207
std::optional<int64_t> dim1 = std::nullopt,
183208
std::optional<int64_t> dim2 = std::nullopt);
184209

210+
/// Compute whether the given values/attributes are equal. Return "failure" if
211+
/// equality could not be determined.
212+
///
213+
/// `ofr1`/`ofr2` must be of index type.
214+
static FailureOr<bool> areEqual(OpFoldResult ofr1, OpFoldResult ofr2);
215+
185216
/// Return "true" if the given slices are guaranteed to be overlapping.
186217
/// Return "false" if the given slices are guaranteed to be non-overlapping.
187218
/// Return "failure" if unknown.
188-
static FailureOr<bool>
189-
areOverlappingSlices(OffsetSizeAndStrideOpInterface slice1,
190-
OffsetSizeAndStrideOpInterface slice2);
219+
///
220+
/// Slices are overlapping if for all dimensions:
221+
/// * offset1 + size1 * stride1 <= offset2
222+
/// * and offset2 + size2 * stride2 <= offset1
223+
///
224+
/// Slice are non-overlapping if the above constraint is not satisfied for
225+
/// at least one dimension.
226+
static FailureOr<bool> areOverlappingSlices(MLIRContext *ctx,
227+
HyperrectangularSlice slice1,
228+
HyperrectangularSlice slice2);
229+
230+
/// Return "true" if the given slices are guaranteed to be equivalent.
231+
/// Return "false" if the given slices are guaranteed to be non-equivalent.
232+
/// Return "failure" if unknown.
233+
///
234+
/// Slices are equivalent if their offsets, sizes and strices are equal.
235+
static FailureOr<bool> areEquivalentSlices(MLIRContext *ctx,
236+
HyperrectangularSlice slice1,
237+
HyperrectangularSlice slice2);
191238

192239
/// Add a bound for the given index-typed value or shaped value. This function
193240
/// returns a builder that adds the bound.

mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp

Lines changed: 6 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -17,73 +17,12 @@ using namespace mlir::tensor;
1717

1818
namespace {
1919

20-
/// Return the tensor that the given subset op operates on.
21-
Value getContainerOperand(SubsetOpInterface op) {
22-
if (auto extractionOp =
23-
dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
24-
return extractionOp.getSourceOperand().get();
25-
if (auto insertionOp =
26-
dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
27-
return insertionOp.getDestinationOperand().get();
28-
llvm_unreachable("expected SubsetExtraction/InsertionOpInterface");
29-
}
30-
31-
/// Return "true" if the two ops operate on an equivalent subset.
32-
/// `equivalenceFn` is used to determine equivalence of tensors. Return "false"
33-
/// if the two ops operate non-equivalent subsets, if equivalence cannot be
34-
/// determined or if `op1` is not a subset op.
35-
template <typename OpTy>
36-
bool operateOnEquivalentSubsets(
37-
OpTy op1, SubsetOpInterface op2,
38-
function_ref<bool(Value, Value)> equivalenceFn) {
39-
auto offsetsSizesAndStrides2 =
40-
dyn_cast<OffsetSizeAndStrideOpInterface>(op2.getOperation());
41-
if (!offsetsSizesAndStrides2)
42-
return false;
43-
if (!sameOffsetsSizesAndStrides(op1, offsetsSizesAndStrides2,
44-
isEqualConstantIntOrValue))
45-
return false;
46-
return equivalenceFn(
47-
getContainerOperand(cast<SubsetOpInterface>(op1.getOperation())),
48-
getContainerOperand(op2));
49-
}
50-
51-
/// Return "true" if the two ops operate on a disjoint subsets.
52-
/// `equivalenceFn` is used to determine equivalence of tensors. Return "false"
53-
/// if the two ops operate non-disjoint subsets, if disjointness cannot be
54-
/// determined or if `op1` is not a subset op.
55-
template <typename OpTy>
56-
bool operateOnDisjointSubsets(OpTy op1, SubsetOpInterface op2,
57-
function_ref<bool(Value, Value)> equivalenceFn) {
58-
auto offsetsSizesAndStrides2 =
59-
dyn_cast<OffsetSizeAndStrideOpInterface>(op2.getOperation());
60-
if (!offsetsSizesAndStrides2)
61-
return false;
62-
FailureOr<bool> overlappingSlices =
63-
ValueBoundsConstraintSet::areOverlappingSlices(op1,
64-
offsetsSizesAndStrides2);
65-
if (failed(overlappingSlices) || *overlappingSlices)
66-
return false;
67-
return equivalenceFn(
68-
getContainerOperand(cast<SubsetOpInterface>(op1.getOperation())),
69-
getContainerOperand(op2));
70-
}
71-
7220
struct ExtractSliceOpSubsetOpInterface
7321
: public SubsetOpInterface::ExternalModel<ExtractSliceOpSubsetOpInterface,
7422
tensor::ExtractSliceOp> {
75-
bool operatesOnEquivalentSubset(
76-
Operation *op, SubsetOpInterface candidate,
77-
function_ref<bool(Value, Value)> equivalenceFn) const {
78-
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
79-
return operateOnEquivalentSubsets(extractSliceOp, candidate, equivalenceFn);
80-
}
81-
82-
bool operatesOnDisjointSubset(
83-
Operation *op, SubsetOpInterface candidate,
84-
function_ref<bool(Value, Value)> equivalenceFn) const {
85-
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
86-
return operateOnDisjointSubsets(extractSliceOp, candidate, equivalenceFn);
23+
FailureOr<HyperrectangularSlice>
24+
getAccessedHyperrectangularSlice(Operation *op) const {
25+
return HyperrectangularSlice(cast<OffsetSizeAndStrideOpInterface>(op));
8726
}
8827
};
8928

@@ -99,18 +38,9 @@ template <typename OpTy>
9938
struct InsertSliceLikeOpSubsetOpInterface
10039
: public SubsetOpInterface::ExternalModel<
10140
InsertSliceLikeOpSubsetOpInterface<OpTy>, OpTy> {
102-
bool operatesOnEquivalentSubset(
103-
Operation *op, SubsetOpInterface candidate,
104-
function_ref<bool(Value, Value)> equivalenceFn) const {
105-
auto insertSliceOp = cast<OpTy>(op);
106-
return operateOnEquivalentSubsets(insertSliceOp, candidate, equivalenceFn);
107-
}
108-
109-
bool operatesOnDisjointSubset(
110-
Operation *op, SubsetOpInterface candidate,
111-
function_ref<bool(Value, Value)> equivalenceFn) const {
112-
auto insertSliceOp = cast<OpTy>(op);
113-
return operateOnDisjointSubsets(insertSliceOp, candidate, equivalenceFn);
41+
FailureOr<HyperrectangularSlice>
42+
getAccessedHyperrectangularSlice(Operation *op) const {
43+
return HyperrectangularSlice(cast<OffsetSizeAndStrideOpInterface>(op));
11444
}
11545
};
11646

mlir/lib/Interfaces/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,12 @@ add_mlir_library(MLIRSubsetOpInterface
9393
DEPENDS
9494
MLIRDestinationStyleOpInterface
9595
MLIRSubsetOpInterfaceIncGen
96+
MLIRValueBoundsOpInterface
9697

9798
LINK_LIBS PUBLIC
9899
MLIRDestinationStyleOpInterface
99100
MLIRIR
101+
MLIRValueBoundsOpInterface
100102
)
101103

102104
add_mlir_interface_library(TilingInterface)

mlir/lib/Interfaces/SubsetOpInterface.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/Interfaces/SubsetOpInterface.h"
1010
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
11+
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
1112

1213
#include "mlir/Interfaces/SubsetOpInterface.cpp.inc"
1314

@@ -40,6 +41,54 @@ bool detail::defaultIsEquivalentSubset(
4041
candidate.getDefiningOp<SubsetOpInterface>(), equivalenceFn);
4142
}
4243

44+
bool detail::defaultOperatesOnEquivalentSubset(
45+
Operation *op, SubsetOpInterface candidate,
46+
function_ref<bool(Value, Value)> equivalenceFn) {
47+
auto subsetOp = cast<SubsetOpInterface>(op);
48+
FailureOr<HyperrectangularSlice> slice =
49+
subsetOp.getAccessedHyperrectangularSlice();
50+
assert(succeeded(slice) &&
51+
"operatesOnEquivalentSubset must be implemented if "
52+
"getAccessedHyperrectangularSlice is not implemented");
53+
FailureOr<HyperrectangularSlice> otherSlice =
54+
candidate.getAccessedHyperrectangularSlice();
55+
if (failed(otherSlice))
56+
return false;
57+
if (!equivalenceFn(subsetOp.getTensorContainer(),
58+
candidate.getTensorContainer()))
59+
return false;
60+
FailureOr<bool> equivalent = ValueBoundsConstraintSet::areEquivalentSlices(
61+
op->getContext(), *slice, *otherSlice);
62+
return succeeded(equivalent) && *equivalent;
63+
}
64+
65+
bool detail::defaultOperatesOnDisjointSubset(
66+
Operation *op, SubsetOpInterface candidate,
67+
function_ref<bool(Value, Value)> equivalenceFn) {
68+
auto subsetOp = cast<SubsetOpInterface>(op);
69+
FailureOr<HyperrectangularSlice> slice =
70+
subsetOp.getAccessedHyperrectangularSlice();
71+
assert(succeeded(slice) &&
72+
"defaultOperatesOnDisjointSubset must be implemented if "
73+
"getAccessedHyperrectangularSlice is not implemented");
74+
FailureOr<HyperrectangularSlice> otherSlice =
75+
candidate.getAccessedHyperrectangularSlice();
76+
if (failed(otherSlice))
77+
return false;
78+
if (!equivalenceFn(subsetOp.getTensorContainer(),
79+
candidate.getTensorContainer()))
80+
return false;
81+
FailureOr<bool> overlapping = ValueBoundsConstraintSet::areOverlappingSlices(
82+
op->getContext(), *slice, *otherSlice);
83+
return succeeded(overlapping) && !*overlapping;
84+
}
85+
86+
Value detail::getTensorContainer(Operation *op) {
87+
if (auto insertionOp = dyn_cast<::mlir::SubsetInsertionOpInterface>(op))
88+
return insertionOp.getDestinationOperand().get();
89+
return cast<::mlir::SubsetExtractionOpInterface>(op).getSourceOperand().get();
90+
}
91+
4392
LogicalResult detail::verifySubsetOpInterface(SubsetOpInterface op) {
4493
if (!(isa<SubsetExtractionOpInterface>(op.getOperation()) ^
4594
isa<SubsetInsertionOpInterface>(op.getOperation())))

0 commit comments

Comments
 (0)