Skip to content

Commit 7f0ab5e

Browse files
committed
Refactor ArithToEmitC: adaptIntegralTypeSignedness
1 parent 3cead57 commit 7f0ab5e

File tree

4 files changed

+40
-52
lines changed

4 files changed

+40
-52
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def EmitC_CastOp : EmitC_Op<"cast",
288288
let arguments = (ins EmitCType:$source);
289289
let results = (outs EmitCType:$dest);
290290
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
291+
let hasFolder = 1;
291292
}
292293

293294
def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 26 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -270,19 +270,11 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
270270

271271
bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
272272
emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
273-
Type arithmeticType = type;
274-
if (type.isUnsignedInteger() != needsUnsigned) {
275-
arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
276-
/*isSigned=*/!needsUnsigned);
277-
}
278-
Value lhs = adaptor.getLhs();
279-
Value rhs = adaptor.getRhs();
280-
if (arithmeticType != type) {
281-
lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
282-
lhs);
283-
rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
284-
rhs);
285-
}
273+
274+
Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
275+
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
276+
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
277+
286278
rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
287279
return success();
288280
}
@@ -328,37 +320,26 @@ class CastConversion : public OpConversionPattern<ArithOp> {
328320
return success();
329321
}
330322

331-
bool isTruncation = operandType.getIntOrFloatBitWidth() >
332-
opReturnType.getIntOrFloatBitWidth();
323+
bool isTruncation =
324+
(isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
325+
operandType.getIntOrFloatBitWidth() >
326+
opReturnType.getIntOrFloatBitWidth());
333327
bool doUnsigned = castToUnsigned || isTruncation;
334328

335-
Type castType = opReturnType;
336-
// If the op is a ui variant and the type wanted as
337-
// return type isn't unsigned, we need to issue an unsigned type to do
338-
// the conversion.
339-
if (castType.isUnsignedInteger() != doUnsigned) {
340-
castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(),
341-
/*isSigned=*/!doUnsigned);
342-
}
329+
// Adapt the signedness of the result (bitwidth-preserving cast)
330+
// This is needed e.g., if the return type is signless.
331+
Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
343332

344-
Value actualOp = adaptor.getIn();
345-
// Adapt the signedness of the operand if necessary
346-
if (operandType.isUnsignedInteger() != doUnsigned) {
347-
Type correctSignednessType =
348-
rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
349-
/*isSigned=*/!doUnsigned);
350-
actualOp = rewriter.template create<emitc::CastOp>(
351-
op.getLoc(), correctSignednessType, actualOp);
352-
}
333+
// Adapt the signedness of the operand (bitwidth-preserving cast)
334+
Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
335+
Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
353336

354-
auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType,
355-
actualOp);
337+
// Actual cast (may change bitwidth)
338+
auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
339+
castDestType, actualOp);
356340

357341
// Cast to the expected output type
358-
if (castType != opReturnType) {
359-
result = rewriter.template create<emitc::CastOp>(op.getLoc(),
360-
opReturnType, result);
361-
}
342+
auto result = adaptValueType(cast, rewriter, opReturnType);
362343

363344
rewriter.replaceOp(op, result);
364345
return success();
@@ -410,8 +391,6 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
410391
return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
411392
}
412393

413-
Value lhs = adaptor.getLhs();
414-
Value rhs = adaptor.getRhs();
415394
Type arithmeticType = type;
416395
if ((type.isSignlessInteger() || type.isSignedInteger()) &&
417396
!bitEnumContainsAll(op.getOverflowFlags(),
@@ -421,20 +400,15 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
421400
arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
422401
/*isSigned=*/false);
423402
}
424-
if (arithmeticType != type) {
425-
lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
426-
lhs);
427-
rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
428-
rhs);
429-
}
430403

431-
Value result = rewriter.template create<EmitCOp>(op.getLoc(),
432-
arithmeticType, lhs, rhs);
404+
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
405+
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
406+
407+
Value arithmeticResult = rewriter.template create<EmitCOp>(
408+
op.getLoc(), arithmeticType, lhs, rhs);
409+
410+
Value result = adaptValueType(arithmeticResult, rewriter, type);
433411

434-
if (arithmeticType != type) {
435-
result =
436-
rewriter.template create<emitc::CastOp>(op.getLoc(), type, result);
437-
}
438412
rewriter.replaceOp(op, result);
439413
return success();
440414
}

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,12 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
241241
emitc::isSupportedFloatType(output) || isa<emitc::PointerType>(output)));
242242
}
243243

244+
OpFoldResult emitc::CastOp::fold(FoldAdaptor adaptor) {
245+
if (getOperand().getType() == getResult().getType())
246+
return getOperand();
247+
return nullptr;
248+
}
249+
244250
//===----------------------------------------------------------------------===//
245251
// CallOpaqueOp
246252
//===----------------------------------------------------------------------===//

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,13 @@ func.func @arith_trunci(%arg0: i32) -> i8 {
466466
// CHECK: emitc.cast %[[Trunc]] : ui8 to i8
467467
%truncd = arith.trunci %arg0 : i32 to i8
468468

469+
// CHECK: %[[Const:.*]] = "emitc.constant"
470+
// CHECK-SAME: value = 1
471+
// CHECK-SAME: () -> i32
472+
// CHECK: %[[AndOne:.*]] = emitc.bitwise_and %[[Arg0]], %[[Const]] : (i32, i32) -> i32
473+
// CHECK: %[[Conv:.*]] = emitc.cast %[[AndOne]] : i32 to i1
474+
%bool = arith.trunci %arg0 : i32 to i1
475+
469476
return %truncd : i8
470477
}
471478

0 commit comments

Comments
 (0)