@@ -158,8 +158,61 @@ getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
158
158
return getTypeConversionFailure (rewriter, op, op->getResultTypes ().front ());
159
159
}
160
160
161
+ // TODO: Move to some common place?
162
+ static std::string getDecorationString (spirv::Decoration decor) {
163
+ return llvm::convertToSnakeFromCamelCase (stringifyDecoration (decor));
164
+ }
165
+
161
166
namespace {
162
167
168
+ // / Converts elementwise unary, binary and ternary arith operations to SPIR-V
169
+ // / operations. Op can potentially support overflow flags.
170
+ template <typename Op, typename SPIRVOp>
171
+ struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
172
+ using OpConversionPattern<Op>::OpConversionPattern;
173
+
174
+ LogicalResult
175
+ matchAndRewrite (Op op, typename Op::Adaptor adaptor,
176
+ ConversionPatternRewriter &rewriter) const override {
177
+ assert (adaptor.getOperands ().size () <= 3 );
178
+ auto converter = this ->template getTypeConverter <SPIRVTypeConverter>();
179
+ Type dstType = converter->convertType (op.getType ());
180
+ if (!dstType) {
181
+ return rewriter.notifyMatchFailure (
182
+ op->getLoc (),
183
+ llvm::formatv (" failed to convert type {0} for SPIR-V" , op.getType ()));
184
+ }
185
+
186
+ if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
187
+ !getElementTypeOrSelf (op.getType ()).isIndex () &&
188
+ dstType != op.getType ()) {
189
+ return op.emitError (" bitwidth emulation is not implemented yet on "
190
+ " unsigned op pattern version" );
191
+ }
192
+
193
+ auto overflowFlags = arith::IntegerOverflowFlags::none;
194
+ if (auto overflowIface =
195
+ dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
196
+ if (converter->getTargetEnv ().allows (
197
+ spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
198
+ overflowFlags = overflowIface.getOverflowAttr ().getValue ();
199
+ }
200
+
201
+ auto newOp = rewriter.template replaceOpWithNewOp <SPIRVOp>(
202
+ op, dstType, adaptor.getOperands ());
203
+
204
+ if (bitEnumContainsAny (overflowFlags, arith::IntegerOverflowFlags::nsw))
205
+ newOp->setAttr (getDecorationString (spirv::Decoration::NoSignedWrap),
206
+ rewriter.getUnitAttr ());
207
+
208
+ if (bitEnumContainsAny (overflowFlags, arith::IntegerOverflowFlags::nuw))
209
+ newOp->setAttr (getDecorationString (spirv::Decoration::NoUnsignedWrap),
210
+ rewriter.getUnitAttr ());
211
+
212
+ return success ();
213
+ }
214
+ };
215
+
163
216
// ===----------------------------------------------------------------------===//
164
217
// ConstantOp
165
218
// ===----------------------------------------------------------------------===//
@@ -1154,9 +1207,9 @@ void mlir::arith::populateArithToSPIRVPatterns(
1154
1207
patterns.add <
1155
1208
ConstantCompositeOpPattern,
1156
1209
ConstantScalarOpPattern,
1157
- spirv::ElementwiseOpPattern <arith::AddIOp, spirv::IAddOp>,
1158
- spirv::ElementwiseOpPattern <arith::SubIOp, spirv::ISubOp>,
1159
- spirv::ElementwiseOpPattern <arith::MulIOp, spirv::IMulOp>,
1210
+ ElementwiseArithOpPattern <arith::AddIOp, spirv::IAddOp>,
1211
+ ElementwiseArithOpPattern <arith::SubIOp, spirv::ISubOp>,
1212
+ ElementwiseArithOpPattern <arith::MulIOp, spirv::IMulOp>,
1160
1213
spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
1161
1214
spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
1162
1215
spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
0 commit comments