Skip to content

Commit 4278d9b

Browse files
authored
[mlir][spirv] Lower arith overflow flags to corresponding SPIR-V op decorations (#77714)
1 parent dd5ce45 commit 4278d9b

File tree

2 files changed

+96
-3
lines changed

2 files changed

+96
-3
lines changed

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,61 @@ getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
158158
return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
159159
}
160160

161+
// TODO: Move to some common place?
162+
static std::string getDecorationString(spirv::Decoration decor) {
163+
return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor));
164+
}
165+
161166
namespace {
162167

168+
/// Converts elementwise unary, binary and ternary arith operations to SPIR-V
169+
/// operations. Op can potentially support overflow flags.
170+
template <typename Op, typename SPIRVOp>
171+
struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
172+
using OpConversionPattern<Op>::OpConversionPattern;
173+
174+
LogicalResult
175+
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
176+
ConversionPatternRewriter &rewriter) const override {
177+
assert(adaptor.getOperands().size() <= 3);
178+
auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
179+
Type dstType = converter->convertType(op.getType());
180+
if (!dstType) {
181+
return rewriter.notifyMatchFailure(
182+
op->getLoc(),
183+
llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
184+
}
185+
186+
if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
187+
!getElementTypeOrSelf(op.getType()).isIndex() &&
188+
dstType != op.getType()) {
189+
return op.emitError("bitwidth emulation is not implemented yet on "
190+
"unsigned op pattern version");
191+
}
192+
193+
auto overflowFlags = arith::IntegerOverflowFlags::none;
194+
if (auto overflowIface =
195+
dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
196+
if (converter->getTargetEnv().allows(
197+
spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
198+
overflowFlags = overflowIface.getOverflowAttr().getValue();
199+
}
200+
201+
auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
202+
op, dstType, adaptor.getOperands());
203+
204+
if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
205+
newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap),
206+
rewriter.getUnitAttr());
207+
208+
if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
209+
newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap),
210+
rewriter.getUnitAttr());
211+
212+
return success();
213+
}
214+
};
215+
163216
//===----------------------------------------------------------------------===//
164217
// ConstantOp
165218
//===----------------------------------------------------------------------===//
@@ -1154,9 +1207,9 @@ void mlir::arith::populateArithToSPIRVPatterns(
11541207
patterns.add<
11551208
ConstantCompositeOpPattern,
11561209
ConstantScalarOpPattern,
1157-
spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>,
1158-
spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>,
1159-
spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>,
1210+
ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1211+
ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1212+
ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
11601213
spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
11611214
spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
11621215
spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,

mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1407,3 +1407,43 @@ func.func @float_scalar(%arg0: f16) {
14071407
}
14081408

14091409
} // end module
1410+
1411+
// -----
1412+
1413+
module attributes {
1414+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Kernel], [SPV_KHR_no_integer_wrap_decoration]>, #spirv.resource_limits<>>
1415+
} {
1416+
1417+
// CHECK-LABEL: @ops_flags
1418+
func.func @ops_flags(%arg0: i64, %arg1: i64) {
1419+
// CHECK: %{{.*}} = spirv.IAdd %{{.*}}, %{{.*}} {no_signed_wrap} : i64
1420+
%0 = arith.addi %arg0, %arg1 overflow<nsw> : i64
1421+
// CHECK: %{{.*}} = spirv.ISub %{{.*}}, %{{.*}} {no_unsigned_wrap} : i64
1422+
%1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
1423+
// CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} {no_signed_wrap, no_unsigned_wrap} : i64
1424+
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
1425+
return
1426+
}
1427+
1428+
} // end module
1429+
1430+
1431+
// -----
1432+
1433+
module attributes {
1434+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64], []>, #spirv.resource_limits<>>
1435+
} {
1436+
1437+
// No decorations should be generated is corresponding Extensions/Capabilities are missing
1438+
// CHECK-LABEL: @ops_flags
1439+
func.func @ops_flags(%arg0: i64, %arg1: i64) {
1440+
// CHECK: %{{.*}} = spirv.IAdd %{{.*}}, %{{.*}} : i64
1441+
%0 = arith.addi %arg0, %arg1 overflow<nsw> : i64
1442+
// CHECK: %{{.*}} = spirv.ISub %{{.*}}, %{{.*}} : i64
1443+
%1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
1444+
// CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} : i64
1445+
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
1446+
return
1447+
}
1448+
1449+
} // end module

0 commit comments

Comments
 (0)