Skip to content

Commit e692af8

Browse files
authored
[MLIR] Update APInt construction to correctly set isSigned/implicitTrunc (llvm#110466)
This fixes all the places in MLIR that hit the new assertion added in llvm#106524, in preparation for enabling it by default. That is, cases where the value passed to the APInt constructor is not an N-bit signed/unsigned integer, where N is the bit width and signedness is determined by the isSigned flag. The fixes either set the correct value for isSigned, or set the implicitTrunc flag to retain the old behavior. I've left TODOs for the latter case in some places, where I think that it may be worthwhile to stop doing implicit truncation in the future. Note that the assertion is currently still disabled by default, so this patch is mostly NFC. This is just the MLIR changes split off from llvm#80309.
1 parent 8b20f1b commit e692af8

File tree

9 files changed

+24
-12
lines changed

9 files changed

+24
-12
lines changed

mlir/include/mlir/IR/BuiltinAttributes.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -701,8 +701,10 @@ def Builtin_IntegerAttr : Builtin_Attr<"Integer", "integer",
701701
return $_get(type.getContext(), type, apValue);
702702
}
703703

704+
// TODO: Avoid implicit trunc?
704705
IntegerType intTy = ::llvm::cast<IntegerType>(type);
705-
APInt apValue(intTy.getWidth(), value, intTy.isSignedInteger());
706+
APInt apValue(intTy.getWidth(), value, intTy.isSignedInteger(),
707+
/*implicitTrunc=*/true);
706708
return $_get(type.getContext(), type, apValue);
707709
}]>
708710
];

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -749,7 +749,8 @@ class AsmParser {
749749
// zero for non-negated integers.
750750
result =
751751
(IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT).getLimitedValue();
752-
if (APInt(uintResult.getBitWidth(), result) != uintResult)
752+
if (APInt(uintResult.getBitWidth(), result, /*isSigned=*/true,
753+
/*implicitTrunc=*/true) != uintResult)
753754
return emitError(loc, "integer value too large");
754755
return success();
755756
}

mlir/lib/Conversion/TosaToArith/TosaToArith.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ Type matchContainerType(Type element, Type container) {
4343
TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
4444
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
4545
Type eTy = shapedTy.getElementType();
46-
APInt valueInt(eTy.getIntOrFloatBitWidth(), value);
46+
APInt valueInt(eTy.getIntOrFloatBitWidth(), value, /*isSigned=*/true);
4747
return DenseIntElementsAttr::get(shapedTy, valueInt);
4848
}
4949

mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ static ParseResult parseSwitchOpCases(
528528
int64_t value = 0;
529529
if (failed(parser.parseInteger(value)))
530530
return failure();
531-
values.push_back(APInt(bitWidth, value));
531+
values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
532532

533533
Block *destination;
534534
SmallVector<OpAsmParser::UnresolvedOperand> operands;

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ static ParseResult parseSwitchOpCases(
598598
int64_t value = 0;
599599
if (failed(parser.parseInteger(value)))
600600
return failure();
601-
values.push_back(APInt(bitWidth, value));
601+
values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
602602

603603
Block *destination;
604604
SmallVector<OpAsmParser::UnresolvedOperand> operands;

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1404,7 +1404,7 @@ static ParseResult parseMembersIndex(OpAsmParser &parser,
14041404
if (parser.parseInteger(value))
14051405
return failure();
14061406
shapeTmp++;
1407-
values.push_back(APInt(32, value));
1407+
values.push_back(APInt(32, value, /*isSigned=*/true));
14081408
return success();
14091409
};
14101410

mlir/lib/IR/Builders.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,10 @@ DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
238238
}
239239

240240
IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
241-
return IntegerAttr::get(getIntegerType(32), APInt(32, value));
241+
// The APInt always uses isSigned=true here because we accept the value
242+
// as int32_t.
243+
return IntegerAttr::get(getIntegerType(32),
244+
APInt(32, value, /*isSigned=*/true));
242245
}
243246

244247
IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
@@ -256,14 +259,19 @@ IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
256259
}
257260

258261
IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
259-
return IntegerAttr::get(getIntegerType(8), APInt(8, value));
262+
// The APInt always uses isSigned=true here because we accept the value
263+
// as int8_t.
264+
return IntegerAttr::get(getIntegerType(8),
265+
APInt(8, value, /*isSigned=*/true));
260266
}
261267

262268
IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
263269
if (type.isIndex())
264270
return IntegerAttr::get(type, APInt(64, value));
265-
return IntegerAttr::get(
266-
type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger()));
271+
// TODO: Avoid implicit trunc?
272+
return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), value,
273+
type.isSignedInteger(),
274+
/*implicitTrunc=*/true));
267275
}
268276

269277
IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1287,7 +1287,8 @@ LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
12871287
} words = {operands[2], operands[3]};
12881288
value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
12891289
} else if (bitwidth <= 32) {
1290-
value = APInt(bitwidth, operands[2], /*isSigned=*/true);
1290+
value = APInt(bitwidth, operands[2], /*isSigned=*/true,
1291+
/*implicitTrunc=*/true);
12911292
}
12921293

12931294
auto attr = opBuilder.getIntegerAttr(intType, value);

mlir/unittests/Dialect/SPIRV/SerializationTest.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ TEST_F(SerializationTest, SignlessVsSignedIntegerConstantBitExtension) {
176176
IntegerType::get(&context, 16, IntegerType::Signless);
177177
auto signedInt16Type = IntegerType::get(&context, 16, IntegerType::Signed);
178178
// Check the bit extension of same value under different signedness semantics.
179-
APInt signlessIntConstVal(signlessInt16Type.getWidth(), -1,
179+
APInt signlessIntConstVal(signlessInt16Type.getWidth(), 0xffff,
180180
signlessInt16Type.getSignedness());
181181
APInt signedIntConstVal(signedInt16Type.getWidth(), -1,
182182
signedInt16Type.getSignedness());

0 commit comments

Comments
 (0)