Skip to content

Commit 18e7dcb

Browse files
authored
[mlir][emitc] Arith to EmitC: handle floating-point<->integer conversions (#87614)
Add support for floating-point to integer, integer to floating-point conversions. Floating point conversions to 1-bit integer types are not handled at the moment, as these don't map directly to boolean conversions.
1 parent 1d43cdc commit 18e7dcb

File tree

3 files changed

+196
-1
lines changed

3 files changed

+196
-1
lines changed

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,96 @@ class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
201201
}
202202
};
203203

204+
// Floating-point to integer conversions.
205+
template <typename CastOp>
206+
class FtoICastOpConversion : public OpConversionPattern<CastOp> {
207+
public:
208+
FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
209+
: OpConversionPattern<CastOp>(typeConverter, context) {}
210+
211+
LogicalResult
212+
matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
213+
ConversionPatternRewriter &rewriter) const override {
214+
215+
Type operandType = adaptor.getIn().getType();
216+
if (!emitc::isSupportedFloatType(operandType))
217+
return rewriter.notifyMatchFailure(castOp,
218+
"unsupported cast source type");
219+
220+
Type dstType = this->getTypeConverter()->convertType(castOp.getType());
221+
if (!dstType)
222+
return rewriter.notifyMatchFailure(castOp, "type conversion failed");
223+
224+
// Float-to-i1 casts are not supported: any value with 0 < value < 1 must be
225+
// truncated to 0, whereas a boolean conversion would return true.
226+
if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1))
227+
return rewriter.notifyMatchFailure(castOp,
228+
"unsupported cast destination type");
229+
230+
// Convert to unsigned if it's the "ui" variant
231+
// Signless is interpreted as signed, so no need to cast for "si"
232+
Type actualResultType = dstType;
233+
if (isa<arith::FPToUIOp>(castOp)) {
234+
actualResultType =
235+
rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
236+
/*isSigned=*/false);
237+
}
238+
239+
Value result = rewriter.create<emitc::CastOp>(
240+
castOp.getLoc(), actualResultType, adaptor.getOperands());
241+
242+
if (isa<arith::FPToUIOp>(castOp)) {
243+
result = rewriter.create<emitc::CastOp>(castOp.getLoc(), dstType, result);
244+
}
245+
rewriter.replaceOp(castOp, result);
246+
247+
return success();
248+
}
249+
};
250+
251+
// Integer to floating-point conversions.
252+
template <typename CastOp>
253+
class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
254+
public:
255+
ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
256+
: OpConversionPattern<CastOp>(typeConverter, context) {}
257+
258+
LogicalResult
259+
matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
260+
ConversionPatternRewriter &rewriter) const override {
261+
// Vectors in particular are not supported
262+
Type operandType = adaptor.getIn().getType();
263+
if (!emitc::isSupportedIntegerType(operandType))
264+
return rewriter.notifyMatchFailure(castOp,
265+
"unsupported cast source type");
266+
267+
Type dstType = this->getTypeConverter()->convertType(castOp.getType());
268+
if (!dstType)
269+
return rewriter.notifyMatchFailure(castOp, "type conversion failed");
270+
271+
if (!emitc::isSupportedFloatType(dstType))
272+
return rewriter.notifyMatchFailure(castOp,
273+
"unsupported cast destination type");
274+
275+
// Convert to unsigned if it's the "ui" variant
276+
// Signless is interpreted as signed, so no need to cast for "si"
277+
Type actualOperandType = operandType;
278+
if (isa<arith::UIToFPOp>(castOp)) {
279+
actualOperandType =
280+
rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
281+
/*isSigned=*/false);
282+
}
283+
Value fpCastOperand = adaptor.getIn();
284+
if (actualOperandType != operandType) {
285+
fpCastOperand = rewriter.template create<emitc::CastOp>(
286+
castOp.getLoc(), actualOperandType, fpCastOperand);
287+
}
288+
rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
289+
290+
return success();
291+
}
292+
};
293+
204294
} // namespace
205295

206296
//===----------------------------------------------------------------------===//
@@ -222,7 +312,11 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
222312
IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
223313
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
224314
CmpIOpConversion,
225-
SelectOpConversion
315+
SelectOpConversion,
316+
ItoFCastOpConversion<arith::SIToFPOp>,
317+
ItoFCastOpConversion<arith::UIToFPOp>,
318+
FtoICastOpConversion<arith::FPToSIOp>,
319+
FtoICastOpConversion<arith::FPToUIOp>
226320
>(typeConverter, ctx);
227321
// clang-format on
228322
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// RUN: mlir-opt -split-input-file -convert-arith-to-emitc -verify-diagnostics %s
2+
3+
func.func @arith_cast_tensor(%arg0: tensor<5xf32>) -> tensor<5xi32> {
4+
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
5+
%t = arith.fptosi %arg0 : tensor<5xf32> to tensor<5xi32>
6+
return %t: tensor<5xi32>
7+
}
8+
9+
// -----
10+
11+
func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> {
12+
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
13+
%t = arith.fptosi %arg0 : vector<5xf32> to vector<5xi32>
14+
return %t: vector<5xi32>
15+
}
16+
17+
// -----
18+
19+
func.func @arith_cast_bf16(%arg0: bf16) -> i32 {
20+
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
21+
%t = arith.fptosi %arg0 : bf16 to i32
22+
return %t: i32
23+
}
24+
25+
// -----
26+
27+
func.func @arith_cast_f16(%arg0: f16) -> i32 {
28+
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
29+
%t = arith.fptosi %arg0 : f16 to i32
30+
return %t: i32
31+
}
32+
33+
34+
// -----
35+
36+
func.func @arith_cast_to_bf16(%arg0: i32) -> bf16 {
37+
// expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
38+
%t = arith.sitofp %arg0 : i32 to bf16
39+
return %t: bf16
40+
}
41+
42+
// -----
43+
44+
func.func @arith_cast_to_f16(%arg0: i32) -> f16 {
45+
// expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
46+
%t = arith.sitofp %arg0 : i32 to f16
47+
return %t: f16
48+
}
49+
50+
// -----
51+
52+
func.func @arith_cast_fptosi_i1(%arg0: f32) -> i1 {
53+
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
54+
%t = arith.fptosi %arg0 : f32 to i1
55+
return %t: i1
56+
}
57+
58+
// -----
59+
60+
func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 {
61+
// expected-error @+1 {{failed to legalize operation 'arith.fptoui'}}
62+
%t = arith.fptoui %arg0 : f32 to i1
63+
return %t: i1
64+
}
65+

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,39 @@ func.func @arith_cmpi_predicates(%arg0: i32, %arg1: i32) {
141141

142142
return
143143
}
144+
145+
// -----
146+
147+
func.func @arith_float_to_int_cast_ops(%arg0: f32, %arg1: f64) {
148+
// CHECK: emitc.cast %arg0 : f32 to i32
149+
%0 = arith.fptosi %arg0 : f32 to i32
150+
151+
// CHECK: emitc.cast %arg1 : f64 to i32
152+
%1 = arith.fptosi %arg1 : f64 to i32
153+
154+
// CHECK: emitc.cast %arg0 : f32 to i16
155+
%2 = arith.fptosi %arg0 : f32 to i16
156+
157+
// CHECK: emitc.cast %arg1 : f64 to i16
158+
%3 = arith.fptosi %arg1 : f64 to i16
159+
160+
// CHECK: %[[CAST0:.*]] = emitc.cast %arg0 : f32 to ui32
161+
// CHECK: emitc.cast %[[CAST0]] : ui32 to i32
162+
%4 = arith.fptoui %arg0 : f32 to i32
163+
164+
return
165+
}
166+
167+
func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
168+
// CHECK: emitc.cast %arg0 : i8 to f32
169+
%0 = arith.sitofp %arg0 : i8 to f32
170+
171+
// CHECK: emitc.cast %arg1 : i64 to f32
172+
%1 = arith.sitofp %arg1 : i64 to f32
173+
174+
// CHECK: %[[CAST_UNS:.*]] = emitc.cast %arg0 : i8 to ui8
175+
// CHECK: emitc.cast %[[CAST_UNS]] : ui8 to f32
176+
%2 = arith.uitofp %arg0 : i8 to f32
177+
178+
return
179+
}

0 commit comments

Comments
 (0)