Skip to content

Commit d889b08

Browse files
committed
Add shift operations
1 parent f64e96a commit d889b08

File tree

3 files changed

+222
-0
lines changed

3 files changed

+222
-0
lines changed

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

+88
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
1919
#include "mlir/IR/BuiltinAttributes.h"
2020
#include "mlir/IR/BuiltinTypes.h"
21+
#include "mlir/IR/Region.h"
2122
#include "mlir/Support/LogicalResult.h"
2223
#include "mlir/Transforms/DialectConversion.h"
2324

@@ -478,6 +479,90 @@ class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
478479
}
479480
};
480481

482+
template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
483+
class ShiftOpConversion : public OpConversionPattern<ArithOp> {
484+
public:
485+
using OpConversionPattern<ArithOp>::OpConversionPattern;
486+
487+
LogicalResult
488+
matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
489+
ConversionPatternRewriter &rewriter) const override {
490+
491+
Type type = this->getTypeConverter()->convertType(op.getType());
492+
if (type && !(isa_and_nonnull<IntegerType>(type) ||
493+
emitc::isPointerWideType(type))) {
494+
return rewriter.notifyMatchFailure(
495+
op, "expected integer or size_t/ssize_t type");
496+
}
497+
498+
if (type.isInteger(1)) {
499+
return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
500+
}
501+
502+
Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);
503+
504+
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
505+
// Shift amount interpreted as unsigned per Arith dialect spec.
506+
Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
507+
/*needsUnsigned=*/true);
508+
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);
509+
510+
// Add a runtime check for overflow
511+
Value width;
512+
if (emitc::isPointerWideType(type)) {
513+
Value eight = rewriter.create<emitc::ConstantOp>(
514+
op.getLoc(), rhsType, rewriter.getIndexAttr(8));
515+
emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>(
516+
op.getLoc(), rhsType, "sizeof", SmallVector<Value, 1>({eight}));
517+
width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
518+
sizeOfCall.getResult(0));
519+
} else {
520+
width = rewriter.create<emitc::ConstantOp>(
521+
op.getLoc(), rhsType,
522+
rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
523+
}
524+
525+
Value excessCheck = rewriter.create<emitc::CmpOp>(
526+
op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);
527+
528+
// Any concrete value is a valid refinement of poison.
529+
Value poison = rewriter.create<emitc::ConstantOp>(
530+
op.getLoc(), arithmeticType,
531+
(isa<IntegerType>(arithmeticType)
532+
? rewriter.getIntegerAttr(arithmeticType, 0)
533+
: rewriter.getIndexAttr(0)));
534+
535+
emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
536+
op.getLoc(), arithmeticType, /*do_not_inline=*/false);
537+
Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
538+
auto currentPoint = rewriter.getInsertionPoint();
539+
rewriter.setInsertionPointToStart(&bodyBlock);
540+
Value arithmeticResult =
541+
rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs);
542+
Value resultOrPoison = rewriter.create<emitc::ConditionalOp>(
543+
op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison);
544+
rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison);
545+
rewriter.setInsertionPoint(op->getBlock(), currentPoint);
546+
547+
Value result = adaptValueType(ternary, rewriter, type);
548+
549+
rewriter.replaceOp(op, result);
550+
return success();
551+
}
552+
};
553+
554+
template <typename ArithOp, typename EmitCOp>
555+
class SignedShiftOpConversion final
556+
: public ShiftOpConversion<ArithOp, EmitCOp, false> {
557+
using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
558+
};
559+
560+
template <typename ArithOp, typename EmitCOp>
561+
class UnsignedShiftOpConversion final
562+
: public ShiftOpConversion<ArithOp, EmitCOp, true> {
563+
using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
564+
};
565+
481566
class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
482567
public:
483568
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
@@ -619,6 +704,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
619704
BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
620705
BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
621706
BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
707+
UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
708+
SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
709+
UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
622710
CmpFOpConversion,
623711
CmpIOpConversion,
624712
SelectOpConversion,

mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir

+24
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,27 @@ func.func @arith_extsi_i1_to_i32(%arg0: i1) {
8686
%idx = arith.extsi %arg0 : i1 to i32
8787
return
8888
}
89+
90+
// -----
91+
92+
func.func @arith_shli_i1(%arg0: i1, %arg1: i1) {
93+
// expected-error @+1 {{failed to legalize operation 'arith.shli'}}
94+
%shli = arith.shli %arg0, %arg1 : i1
95+
return
96+
}
97+
98+
// -----
99+
100+
func.func @arith_shrsi_i1(%arg0: i1, %arg1: i1) {
101+
// expected-error @+1 {{failed to legalize operation 'arith.shrsi'}}
102+
%shrsi = arith.shrsi %arg0, %arg1 : i1
103+
return
104+
}
105+
106+
// -----
107+
108+
func.func @arith_shrui_i1(%arg0: i1, %arg1: i1) {
109+
// expected-error @+1 {{failed to legalize operation 'arith.shrui'}}
110+
%shrui = arith.shrui %arg0, %arg1 : i1
111+
return
112+
}

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

+110
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,116 @@ func.func @arith_signed_integer_div_rem(%arg0: i32, %arg1: i32) {
144144

145145
// -----
146146

147+
// CHECK-LABEL: arith_shift_left
148+
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
149+
func.func @arith_shift_left(%arg0: i32, %arg1: i32) {
150+
// CHECK-DAG: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
151+
// CHECK-DAG: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
152+
// CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32
153+
// CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1
154+
// CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0
155+
// CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32
156+
// CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
157+
// CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : ui32
158+
// CHECK: emitc.yield %[[Ternary]] : ui32
159+
// CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32
160+
%1 = arith.shli %arg0, %arg1 : i32
161+
return
162+
}
163+
164+
// -----
165+
166+
// CHECK-LABEL: arith_shift_right
167+
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
168+
func.func @arith_shift_right(%arg0: i32, %arg1: i32) {
169+
// CHECK-DAG: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
170+
// CHECK-DAG: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
171+
// CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32
172+
// CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1
173+
// CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}ui32
174+
// CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32
175+
// CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
176+
// CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : ui32
177+
// CHECK: emitc.yield %[[Ternary]] : ui32
178+
// CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32
179+
%2 = arith.shrui %arg0, %arg1 : i32
180+
181+
// CHECK-DAG: %[[SC2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
182+
// CHECK-DAG: %[[SSizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32
183+
// CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[SC2]], %[[SSizeConstant]] : (ui32, ui32) -> i1
184+
// CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}i32
185+
// CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : i32
186+
// CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[ARG0]], %[[SC2]] : (i32, ui32) -> i32
187+
// CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : i32
188+
// CHECK: emitc.yield %[[STernary]] : i32
189+
%3 = arith.shrsi %arg0, %arg1 : i32
190+
191+
return
192+
}
193+
194+
// -----
195+
196+
// CHECK-LABEL: arith_shift_left_index
197+
// CHECK-SAME: %[[AMOUNT:.*]]: i32
198+
func.func @arith_shift_left_index(%amount: i32) {
199+
%cst0 = "arith.constant"() {value = 42 : index} : () -> (index)
200+
%cast1 = arith.index_cast %amount : i32 to index
201+
// CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t
202+
// CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ssize_t
203+
// CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ssize_t to !emitc.size_t
204+
// CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index
205+
// CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t
206+
// CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
207+
// CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
208+
// CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0
209+
// CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t
210+
// CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
211+
// CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : !emitc.size_t
212+
// CHECK: emitc.yield %[[Ternary]] : !emitc.size_t
213+
%1 = arith.shli %cst0, %cast1 : index
214+
return
215+
}
216+
217+
// -----
218+
219+
// CHECK-LABEL: arith_shift_right_index
220+
// CHECK-SAME: %[[AMOUNT:.*]]: i32
221+
func.func @arith_shift_right_index(%amount: i32) {
222+
// CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t
223+
// CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ssize_t
224+
// CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ssize_t to !emitc.size_t
225+
%arg0 = "arith.constant"() {value = 42 : index} : () -> (index)
226+
%arg1 = arith.index_cast %amount : i32 to index
227+
228+
// CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index
229+
// CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t
230+
// CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
231+
// CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
232+
// CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.size_t
233+
// CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t
234+
// CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
235+
// CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : !emitc.size_t
236+
// CHECK: emitc.yield %[[Ternary]] : !emitc.size_t
237+
%2 = arith.shrui %arg0, %arg1 : index
238+
239+
// CHECK-DAG: %[[SC1:[^ ]*]] = emitc.cast %[[C1]] : !emitc.size_t to !emitc.ssize_t
240+
// CHECK-DAG: %[[SByte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index{{.*}}!emitc.size_t
241+
// CHECK-DAG: %[[SSizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[SByte]]) : (!emitc.size_t) -> !emitc.size_t
242+
// CHECK-DAG: %[[SSizeConstant:[^ ]*]] = emitc.mul %[[SByte]], %[[SSizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
243+
// CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SSizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
244+
// CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.ssize_t
245+
// CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : !emitc.ssize_t
246+
// CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[SC1]], %[[AmountIdx]] : (!emitc.ssize_t, !emitc.size_t) -> !emitc.ssize_t
247+
// CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : !emitc.ssize_t
248+
// CHECK: emitc.yield %[[STernary]] : !emitc.ssize_t
249+
// CHECK: emitc.cast %[[SShiftRes]] : !emitc.ssize_t to !emitc.size_t
250+
%3 = arith.shrsi %arg0, %arg1 : index
251+
252+
return
253+
}
254+
255+
// -----
256+
147257
func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () {
148258
// CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32>
149259
%0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>

0 commit comments

Comments
 (0)