Skip to content

Commit 519175c

Browse files
authored
[mlir][emitc] Refactor ArithToEmitC: perform sign adaptation, type conversions / cast insertion in a single place (#95789)
Factor EmitC type signedness adaptation and cast operations in ArithToEmitC using adaptValueType and adaptIntegralTypeSignedness.
1 parent 8af8602 commit 519175c

File tree

2 files changed

+33
-52
lines changed

2 files changed

+33
-52
lines changed

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 26 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -270,19 +270,11 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
270270

271271
bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
272272
emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
273-
Type arithmeticType = type;
274-
if (type.isUnsignedInteger() != needsUnsigned) {
275-
arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
276-
/*isSigned=*/!needsUnsigned);
277-
}
278-
Value lhs = adaptor.getLhs();
279-
Value rhs = adaptor.getRhs();
280-
if (arithmeticType != type) {
281-
lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
282-
lhs);
283-
rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
284-
rhs);
285-
}
273+
274+
Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
275+
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
276+
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
277+
286278
rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
287279
return success();
288280
}
@@ -356,37 +348,26 @@ class CastConversion : public OpConversionPattern<ArithOp> {
356348
return success();
357349
}
358350

359-
bool isTruncation = operandType.getIntOrFloatBitWidth() >
360-
opReturnType.getIntOrFloatBitWidth();
351+
bool isTruncation =
352+
(isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
353+
operandType.getIntOrFloatBitWidth() >
354+
opReturnType.getIntOrFloatBitWidth());
361355
bool doUnsigned = castToUnsigned || isTruncation;
362356

363-
Type castType = opReturnType;
364-
// If the op is a ui variant and the type wanted as
365-
// return type isn't unsigned, we need to issue an unsigned type to do
366-
// the conversion.
367-
if (castType.isUnsignedInteger() != doUnsigned) {
368-
castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(),
369-
/*isSigned=*/!doUnsigned);
370-
}
357+
// Adapt the signedness of the result (bitwidth-preserving cast)
358+
// This is needed e.g., if the return type is signless.
359+
Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
371360

372-
Value actualOp = adaptor.getIn();
373-
// Adapt the signedness of the operand if necessary
374-
if (operandType.isUnsignedInteger() != doUnsigned) {
375-
Type correctSignednessType =
376-
rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
377-
/*isSigned=*/!doUnsigned);
378-
actualOp = rewriter.template create<emitc::CastOp>(
379-
op.getLoc(), correctSignednessType, actualOp);
380-
}
361+
// Adapt the signedness of the operand (bitwidth-preserving cast)
362+
Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
363+
Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
381364

382-
auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType,
383-
actualOp);
365+
// Actual cast (may change bitwidth)
366+
auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
367+
castDestType, actualOp);
384368

385369
// Cast to the expected output type
386-
if (castType != opReturnType) {
387-
result = rewriter.template create<emitc::CastOp>(op.getLoc(),
388-
opReturnType, result);
389-
}
370+
auto result = adaptValueType(cast, rewriter, opReturnType);
390371

391372
rewriter.replaceOp(op, result);
392373
return success();
@@ -438,8 +419,6 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
438419
return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
439420
}
440421

441-
Value lhs = adaptor.getLhs();
442-
Value rhs = adaptor.getRhs();
443422
Type arithmeticType = type;
444423
if ((type.isSignlessInteger() || type.isSignedInteger()) &&
445424
!bitEnumContainsAll(op.getOverflowFlags(),
@@ -449,20 +428,15 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
449428
arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
450429
/*isSigned=*/false);
451430
}
452-
if (arithmeticType != type) {
453-
lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
454-
lhs);
455-
rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
456-
rhs);
457-
}
458431

459-
Value result = rewriter.template create<EmitCOp>(op.getLoc(),
460-
arithmeticType, lhs, rhs);
432+
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
433+
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
434+
435+
Value arithmeticResult = rewriter.template create<EmitCOp>(
436+
op.getLoc(), arithmeticType, lhs, rhs);
437+
438+
Value result = adaptValueType(arithmeticResult, rewriter, type);
461439

462-
if (arithmeticType != type) {
463-
result =
464-
rewriter.template create<emitc::CastOp>(op.getLoc(), type, result);
465-
}
466440
rewriter.replaceOp(op, result);
467441
return success();
468442
}

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,13 @@ func.func @arith_trunci(%arg0: i32) -> i8 {
477477
// CHECK: emitc.cast %[[Trunc]] : ui8 to i8
478478
%truncd = arith.trunci %arg0 : i32 to i8
479479

480+
// CHECK: %[[Const:.*]] = "emitc.constant"
481+
// CHECK-SAME: value = 1
482+
// CHECK-SAME: () -> i32
483+
// CHECK: %[[AndOne:.*]] = emitc.bitwise_and %[[Arg0]], %[[Const]] : (i32, i32) -> i32
484+
// CHECK: %[[Conv:.*]] = emitc.cast %[[AndOne]] : i32 to i1
485+
%bool = arith.trunci %arg0 : i32 to i1
486+
480487
return %truncd : i8
481488
}
482489

0 commit comments

Comments
 (0)