Skip to content

Commit 2861856

Browse files
authored
[mlir][Vector] Add utility for computing scalable value bounds (#83876)
This adds a new API built with the `ValueBoundsConstraintSet` to compute the bounds of possibly scalable quantities. It uses knowledge of the range of vscale (which is defined by the target architecture), to solve for the bound as either a constant or an expression in terms of vscale. The result is an `AffineMap` that will always take at most one parameter, vscale, and returns a single result, which is the bound of `value`. The API is defined as follows: ```c++ FailureOr<ConstantOrScalableBound> vector::ScalableValueBoundsConstraintSet::computeScalableBound( Value value, std::optional<int64_t> dim, unsigned vscaleMin, unsigned vscaleMax, presburger::BoundType boundType, bool closedUB = true, StopConditionFn stopCondition = nullptr); ``` Note: `ConstantOrScalableBound` is a thin wrapper over the `AffineMap` with a utility for converting the bound to a single quantity (i.e. a size and scalable flag). We believe this API could prove useful downstream in IREE (which uses a similar analysis to hoist allocas, which currently fails for scalable vectors).
1 parent 2152094 commit 2861856

10 files changed

+555
-28
lines changed
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
//===- ScalableValueBoundsConstraintSet.h - Scalable Value Bounds ---------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H
10+
#define MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H
11+
12+
#include "mlir/Analysis/Presburger/IntegerRelation.h"
13+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
14+
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
15+
16+
namespace mlir::vector {
17+
18+
namespace detail {
19+
20+
/// Parent class for the value bounds RTTIExtends. Uses protected inheritance to
21+
/// hide all ValueBoundsConstraintSet methods by default (as some do not use the
22+
/// ScalableValueBoundsConstraintSet, so may produce unexpected results).
23+
struct ValueBoundsConstraintSet : protected ::mlir::ValueBoundsConstraintSet {
24+
using ::mlir::ValueBoundsConstraintSet::ValueBoundsConstraintSet;
25+
};
26+
} // namespace detail
27+
28+
/// A version of `ValueBoundsConstraintSet` that can solve for scalable bounds.
29+
struct ScalableValueBoundsConstraintSet
30+
: public llvm::RTTIExtends<ScalableValueBoundsConstraintSet,
31+
detail::ValueBoundsConstraintSet> {
32+
ScalableValueBoundsConstraintSet(MLIRContext *context, unsigned vscaleMin,
33+
unsigned vscaleMax)
34+
: RTTIExtends(context), vscaleMin(vscaleMin), vscaleMax(vscaleMax){};
35+
36+
using RTTIExtends::bound;
37+
using RTTIExtends::StopConditionFn;
38+
39+
/// A thin wrapper over an `AffineMap` which can represent a constant bound,
40+
/// or a scalable bound (in terms of vscale). The `AffineMap` will always
41+
/// take at most one parameter, vscale, and returns a single result, which is
42+
/// the bound of value.
43+
struct ConstantOrScalableBound {
44+
AffineMap map;
45+
46+
struct BoundSize {
47+
int64_t baseSize{0};
48+
bool scalable{false};
49+
};
50+
51+
/// Get the (possibly) scalable size of the bound, returns failure if
52+
/// the bound cannot be represented as a single quantity.
53+
FailureOr<BoundSize> getSize() const;
54+
};
55+
56+
/// Computes a (possibly) scalable bound for a given value. This is
57+
/// similar to `ValueBoundsConstraintSet::computeConstantBound()`, but
58+
/// uses knowledge of the range of vscale to compute either a constant
59+
/// bound, an expression in terms of vscale, or failure if no bound can
60+
/// be computed.
61+
///
62+
/// The resulting `AffineMap` will always take at most one parameter,
63+
/// vscale, and return a single result, which is the bound of `value`.
64+
///
65+
/// Note: `vscaleMin` must be `<=` to `vscaleMax`. If `vscaleMin` ==
66+
/// `vscaleMax`, the resulting bound (if found), will be constant.
67+
static FailureOr<ConstantOrScalableBound>
68+
computeScalableBound(Value value, std::optional<int64_t> dim,
69+
unsigned vscaleMin, unsigned vscaleMax,
70+
presburger::BoundType boundType, bool closedUB = true,
71+
StopConditionFn stopCondition = nullptr);
72+
73+
/// Get the value of vscale. Returns `nullptr` vscale as not been encountered.
74+
Value getVscaleValue() const { return vscale; }
75+
76+
/// Sets the value of vscale. Asserts if vscale has already been set.
77+
void setVscale(vector::VectorScaleOp vscaleOp) {
78+
assert(!vscale && "expected vscale to be unset");
79+
vscale = vscaleOp.getResult();
80+
}
81+
82+
/// The minimum possible value of vscale.
83+
unsigned getVscaleMin() const { return vscaleMin; }
84+
85+
/// The maximum possible value of vscale.
86+
unsigned getVscaleMax() const { return vscaleMax; }
87+
88+
static char ID;
89+
90+
private:
91+
const unsigned vscaleMin;
92+
const unsigned vscaleMax;
93+
94+
// This will be set when the first `vector.vscale` operation is found within
95+
// the `ValueBoundsOpInterface` implementation then reused from there on.
96+
Value vscale = nullptr;
97+
};
98+
99+
using ConstantOrScalableBound =
100+
ScalableValueBoundsConstraintSet::ConstantOrScalableBound;
101+
102+
} // namespace mlir::vector
103+
104+
#endif // MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- ValueBoundsOpInterfaceImpl.h - Impl. of ValueBoundsOpInterface -----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_VECTOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H
10+
#define MLIR_DIALECT_VECTOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H
11+
12+
namespace mlir {
13+
class DialectRegistry;
14+
15+
namespace vector {
16+
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
17+
} // namespace vector
18+
} // namespace mlir
19+
20+
#endif // MLIR_DIALECT_VECTOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H

mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
8383
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
8484
#include "mlir/Dialect/UB/IR/UBOps.h"
85+
#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h"
8586
#include "mlir/Dialect/Vector/IR/VectorOps.h"
8687
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
8788
#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
@@ -174,6 +175,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
174175
tosa::registerShardingInterfaceExternalModels(registry);
175176
vector::registerBufferizableOpInterfaceExternalModels(registry);
176177
vector::registerSubsetOpInterfaceExternalModels(registry);
178+
vector::registerValueBoundsOpInterfaceExternalModels(registry);
177179
NVVM::registerNVVMTargetInterfaceExternalModels(registry);
178180
ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
179181
spirv::registerSPIRVTargetInterfaceExternalModels(registry);

mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/IR/Value.h"
1616
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
1717
#include "llvm/ADT/SetVector.h"
18+
#include "llvm/Support/ExtensibleRTTI.h"
1819

1920
#include <queue>
2021

@@ -63,7 +64,8 @@ using ValueDimList = SmallVector<std::pair<Value, std::optional<int64_t>>>;
6364
///
6465
/// Note: Any modification of existing IR invalides the data stored in this
6566
/// class. Adding new operations is allowed.
66-
class ValueBoundsConstraintSet {
67+
class ValueBoundsConstraintSet
68+
: public llvm::RTTIExtends<ValueBoundsConstraintSet, llvm::RTTIRoot> {
6769
protected:
6870
/// Helper class that builds a bound for a shaped value dimension or
6971
/// index-typed value.
@@ -107,6 +109,8 @@ class ValueBoundsConstraintSet {
107109
};
108110

109111
public:
112+
static char ID;
113+
110114
/// The stop condition when traversing the backward slice of a shaped value/
111115
/// index-type value. The traversal continues until the stop condition
112116
/// evaluates to "true" for a value.
@@ -265,6 +269,16 @@ class ValueBoundsConstraintSet {
265269

266270
ValueBoundsConstraintSet(MLIRContext *ctx);
267271

272+
/// Populates the constraint set for a value/map without actually computing
273+
/// the bound. Returns the position for the value/map (via the return value
274+
/// and `posOut` output parameter).
275+
int64_t populateConstraintsSet(Value value,
276+
std::optional<int64_t> dim = std::nullopt,
277+
StopConditionFn stopCondition = nullptr);
278+
int64_t populateConstraintsSet(AffineMap map, ValueDimList mapOperands,
279+
StopConditionFn stopCondition = nullptr,
280+
int64_t *posOut = nullptr);
281+
268282
/// Iteratively process all elements on the worklist until an index-typed
269283
/// value or shaped value meets `stopCondition`. Such values are not processed
270284
/// any further.

mlir/lib/Dialect/Vector/IR/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
add_mlir_dialect_library(MLIRVectorDialect
22
VectorOps.cpp
3+
ValueBoundsOpInterfaceImpl.cpp
4+
ScalableValueBoundsConstraintSet.cpp
35

46
ADDITIONAL_HEADER_DIRS
57
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/IR
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
//===- ScalableValueBoundsConstraintSet.cpp - Scalable Value Bounds -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
10+
11+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
12+
13+
namespace mlir::vector {
14+
15+
FailureOr<ConstantOrScalableBound::BoundSize>
16+
ConstantOrScalableBound::getSize() const {
17+
if (map.isSingleConstant())
18+
return BoundSize{map.getSingleConstantResult(), /*scalable=*/false};
19+
if (map.getNumResults() != 1 || map.getNumInputs() != 1)
20+
return failure();
21+
auto binop = dyn_cast<AffineBinaryOpExpr>(map.getResult(0));
22+
if (!binop || binop.getKind() != AffineExprKind::Mul)
23+
return failure();
24+
auto matchConstant = [&](AffineExpr expr, int64_t &constant) -> bool {
25+
if (auto cst = dyn_cast<AffineConstantExpr>(expr)) {
26+
constant = cst.getValue();
27+
return true;
28+
}
29+
return false;
30+
};
31+
// Match `s0 * cst` or `cst * s0`:
32+
int64_t cst = 0;
33+
auto lhs = binop.getLHS();
34+
auto rhs = binop.getRHS();
35+
if ((matchConstant(lhs, cst) && isa<AffineSymbolExpr>(rhs)) ||
36+
(matchConstant(rhs, cst) && isa<AffineSymbolExpr>(lhs))) {
37+
return BoundSize{cst, /*scalable=*/true};
38+
}
39+
return failure();
40+
}
41+
42+
char ScalableValueBoundsConstraintSet::ID = 0;
43+
44+
FailureOr<ConstantOrScalableBound>
45+
ScalableValueBoundsConstraintSet::computeScalableBound(
46+
Value value, std::optional<int64_t> dim, unsigned vscaleMin,
47+
unsigned vscaleMax, presburger::BoundType boundType, bool closedUB,
48+
StopConditionFn stopCondition) {
49+
using namespace presburger;
50+
51+
assert(vscaleMin <= vscaleMax);
52+
ScalableValueBoundsConstraintSet scalableCstr(value.getContext(), vscaleMin,
53+
vscaleMax);
54+
55+
int64_t pos = scalableCstr.populateConstraintsSet(value, dim, stopCondition);
56+
57+
// Project out all variables apart from vscale.
58+
// This should result in constraints in terms of vscale only.
59+
scalableCstr.projectOut(
60+
[&](ValueDim p) { return p.first != scalableCstr.getVscaleValue(); });
61+
62+
assert(scalableCstr.cstr.getNumDimAndSymbolVars() ==
63+
scalableCstr.positionToValueDim.size() &&
64+
"inconsistent mapping state");
65+
66+
// Check that the only symbols left are vscale.
67+
for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) {
68+
if (i == pos)
69+
continue;
70+
if (scalableCstr.positionToValueDim[i] !=
71+
ValueDim(scalableCstr.getVscaleValue(),
72+
ValueBoundsConstraintSet::kIndexValue)) {
73+
return failure();
74+
}
75+
}
76+
77+
SmallVector<AffineMap, 1> lowerBound(1), upperBound(1);
78+
scalableCstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lowerBound,
79+
&upperBound, closedUB);
80+
81+
auto invalidBound = [](auto &bound) {
82+
return !bound[0] || bound[0].getNumResults() != 1;
83+
};
84+
85+
AffineMap bound = [&] {
86+
if (boundType == BoundType::EQ && !invalidBound(lowerBound) &&
87+
lowerBound[0] == lowerBound[0]) {
88+
return lowerBound[0];
89+
} else if (boundType == BoundType::LB && !invalidBound(lowerBound)) {
90+
return lowerBound[0];
91+
} else if (boundType == BoundType::UB && !invalidBound(upperBound)) {
92+
return upperBound[0];
93+
}
94+
return AffineMap{};
95+
}();
96+
97+
if (!bound)
98+
return failure();
99+
100+
return ConstantOrScalableBound{bound};
101+
}
102+
103+
} // namespace mlir::vector
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h"
10+
11+
#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
12+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
13+
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
14+
15+
using namespace mlir;
16+
17+
namespace mlir::vector {
18+
namespace {
19+
20+
struct VectorScaleOpInterface
21+
: public ValueBoundsOpInterface::ExternalModel<VectorScaleOpInterface,
22+
VectorScaleOp> {
23+
void populateBoundsForIndexValue(Operation *op, Value value,
24+
ValueBoundsConstraintSet &cstr) const {
25+
auto *scalableCstr = dyn_cast<ScalableValueBoundsConstraintSet>(&cstr);
26+
if (!scalableCstr)
27+
return;
28+
auto vscaleOp = cast<VectorScaleOp>(op);
29+
assert(value == vscaleOp.getResult() && "invalid value");
30+
if (auto vscale = scalableCstr->getVscaleValue()) {
31+
// All copies of vscale are equivalent.
32+
scalableCstr->bound(value) == cstr.getExpr(vscale);
33+
} else {
34+
// We know vscale is confined to [vscaleMin, vscaleMax].
35+
scalableCstr->bound(value) >= scalableCstr->getVscaleMin();
36+
scalableCstr->bound(value) <= scalableCstr->getVscaleMax();
37+
scalableCstr->setVscale(vscaleOp);
38+
}
39+
}
40+
};
41+
42+
} // namespace
43+
} // namespace mlir::vector
44+
45+
void mlir::vector::registerValueBoundsOpInterfaceExternalModels(
46+
DialectRegistry &registry) {
47+
registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
48+
vector::VectorScaleOp::attachInterface<vector::VectorScaleOpInterface>(
49+
*ctx);
50+
});
51+
}

0 commit comments

Comments
 (0)