Skip to content

[mlir][emitc] Add EmitC index types #93155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 17, 2024
Merged
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ bool isIntegerIndexOrOpaqueType(Type type);

/// Determines whether \p type is a valid floating-point type in EmitC.
bool isSupportedFloatType(mlir::Type type);

/// Determines whether \p type is a emitc.size_t/ssize_t type.
bool isPointerWideType(mlir::Type type);

} // namespace emitc
} // namespace mlir

Expand Down
5 changes: 3 additions & 2 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
def CExpression : NativeOpTrait<"emitc::CExpression">;

// Types only used in binary arithmetic operations.
def IntegerIndexOrOpaqueType : AnyTypeOf<[EmitCIntegerType, Index, EmitC_OpaqueType]>;
def IntegerIndexOrOpaqueType : Type<CPred<"emitc::isIntegerIndexOrOpaqueType($_self)">,
"integer, index or opaque type supported by EmitC">;
def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[EmitCFloatType, IntegerIndexOrOpaqueType]>;

def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> {
Expand Down Expand Up @@ -470,7 +471,7 @@ def EmitC_ForOp : EmitC_Op<"for",
upper bound and step respectively, and defines an SSA value for its
induction variable. It has one region capturing the loop body. The induction
variable is represented as an argument of this region. This SSA value is a
signless integer or index. The step is a value of same type.
signless integer, or an index. The step is a value of same type.

This operation has no result. The body region must contain exactly one block
that terminates with `emitc.yield`. Calling ForOp::build will create such a
Expand Down
28 changes: 28 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def EmitC_ArrayType : EmitC_Type<"Array", "array", [ShapedTypeInterface]> {

static bool isValidElementType(Type type) {
return type.isIntOrIndexOrFloat() ||
emitc::isPointerWideType(type) ||
llvm::isa<PointerType, OpaqueType>(type);
}
}];
Expand Down Expand Up @@ -130,4 +131,31 @@ def EmitC_PointerType : EmitC_Type<"Pointer", "ptr"> {
let assemblyFormat = "`<` qualified($pointee) `>`";
}

def EmitC_SignedSizeT : EmitC_Type<"SignedSizeT", "ssize_t"> {
let summary = "EmitC signed size type";
let description = [{
Data type representing all values of `emitc.size_t`, plus -1.
It corresponds to `ssize_t` found in `<sys/types.h>`.

Use of this type causes the code to be non-C99 compliant.
}];
}

def EmitC_PtrDiffT : EmitC_Type<"PtrDiffT", "ptrdiff_t"> {
let summary = "EmitC signed pointer diff type";
let description = [{
Signed data type as wide as platform-specific pointer types.
In particular, it is as wide as `emitc.size_t`.
It corresponds to `ptrdiff_t` found in `<stddef.h>`.
}];
}

def EmitC_SizeT : EmitC_Type<"SizeT", "size_t"> {
let summary = "EmitC unsigned size type";
let description = [{
Unsigned data type as wide as platform-specific pointer types.
It corresponds to `size_t` found in `<stddef.h>`.
}];
}

#endif // MLIR_DIALECT_EMITC_IR_EMITCTYPES
26 changes: 26 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/Transforms/TypeConversions.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//===- TypeConversions.h - Convert signless types into C/C++ types -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_EMITC_TRANSFORMS_TYPECONVERSIONS_H
#define MLIR_DIALECT_EMITC_TRANSFORMS_TYPECONVERSIONS_H

#include <optional>

namespace mlir {
class TypeConverter;
class Type;
void populateEmitCSizeTTypeConversions(TypeConverter &converter);

namespace emitc {
std::optional<Type> getUnsignedTypeFor(Type ty);
std::optional<Type> getSignedTypeFor(Type ty);
} // namespace emitc

} // namespace mlir

#endif // MLIR_DIALECT_EMITC_TRANSFORMS_TYPECONVERSIONS_H
18 changes: 14 additions & 4 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ bool mlir::emitc::isSupportedEmitCType(Type type) {
return !llvm::isa<emitc::ArrayType>(elemType) &&
isSupportedEmitCType(elemType);
}
if (type.isIndex())
if (type.isIndex() || emitc::isPointerWideType(type))
return true;
if (llvm::isa<IntegerType>(type))
return isSupportedIntegerType(type);
Expand Down Expand Up @@ -110,7 +110,7 @@ bool mlir::emitc::isSupportedIntegerType(Type type) {

bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) {
return llvm::isa<IndexType, emitc::OpaqueType>(type) ||
isSupportedIntegerType(type);
isSupportedIntegerType(type) || isPointerWideType(type);
}

bool mlir::emitc::isSupportedFloatType(Type type) {
Expand All @@ -126,6 +126,11 @@ bool mlir::emitc::isSupportedFloatType(Type type) {
return false;
}

bool mlir::emitc::isPointerWideType(Type type) {
return isa<emitc::SignedSizeTType, emitc::SizeTType, emitc::PtrDiffTType>(
type);
}

/// Check that the type of the initial value is compatible with the operations
/// result type.
static LogicalResult verifyInitializationAttribute(Operation *op,
Expand All @@ -142,6 +147,9 @@ static LogicalResult verifyInitializationAttribute(Operation *op,
Type resultType = op->getResult(0).getType();
Type attrType = cast<TypedAttr>(value).getType();

if (isPointerWideType(resultType) && attrType.isIndex())
return success();

if (resultType != attrType)
return op->emitOpError()
<< "requires attribute to either be an #emitc.opaque attribute or "
Expand Down Expand Up @@ -227,9 +235,11 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
Type input = inputs.front(), output = outputs.front();

return ((llvm::isa<IntegerType, FloatType, IndexType, emitc::OpaqueType,
emitc::PointerType>(input)) &&
emitc::PointerType, emitc::SignedSizeTType,
emitc::SizeTType, emitc::PtrDiffTType>(input)) &&
(llvm::isa<IntegerType, FloatType, IndexType, emitc::OpaqueType,
emitc::PointerType>(output)));
emitc::PointerType, emitc::SignedSizeTType,
emitc::SizeTType, emitc::PtrDiffTType>(output)));
}

//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIREmitCTransforms
Transforms.cpp
FormExpressions.cpp
TypeConversions.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC/Transforms
Expand Down
64 changes: 64 additions & 0 deletions mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//===- TypeConversions.cpp - Convert signless types into C/C++ types ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include <optional>

using namespace mlir;

namespace {

std::optional<Value> materializeAsUnrealizedCast(OpBuilder &builder,
Type resultType,
ValueRange inputs,
Location loc) {
if (inputs.size() != 1)
return std::nullopt;

return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
}

} // namespace

void mlir::populateEmitCSizeTTypeConversions(TypeConverter &converter) {
converter.addConversion(
[](IndexType type) { return emitc::SizeTType::get(type.getContext()); });

converter.addSourceMaterialization(materializeAsUnrealizedCast);
converter.addTargetMaterialization(materializeAsUnrealizedCast);
converter.addArgumentMaterialization(materializeAsUnrealizedCast);
}

/// Get an unsigned data type as wide as \p ty.
std::optional<Type> mlir::emitc::getUnsignedTypeFor(Type ty) {
if (ty.isInteger())
return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(),
IntegerType::SignednessSemantics::Unsigned);
if (isa<emitc::PtrDiffTType, emitc::SignedSizeTType>(ty))
return emitc::SizeTType::get(ty.getContext());
if (isSupportedIntegerType(ty))
return ty;
return {};
}

/// Get a signed data type as wide as \p ty that supports arithmetic on negative
/// values.
std::optional<Type> mlir::emitc::getSignedTypeFor(Type ty) {
if (ty.isInteger())
return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(),
IntegerType::SignednessSemantics::Signed);
if (isa<emitc::SizeTType>(ty))
return emitc::PtrDiffTType::get(ty.getContext());
if (isSupportedIntegerType(ty))
return ty;
return {};
}
6 changes: 6 additions & 0 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1570,6 +1570,12 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
}
if (auto iType = dyn_cast<IndexType>(type))
return (os << "size_t"), success();
if (auto sType = dyn_cast<emitc::SizeTType>(type))
return (os << "size_t"), success();
if (auto sType = dyn_cast<emitc::SignedSizeTType>(type))
return (os << "ssize_t"), success();
if (auto sType = dyn_cast<emitc::PtrDiffTType>(type))
return (os << "ptrdiff_t"), success();
if (auto tType = dyn_cast<TensorType>(type)) {
if (!tType.hasRank())
return emitError(loc, "cannot emit unranked tensor type");
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Dialect/EmitC/invalid_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -170,31 +170,31 @@ func.func @add_float_pointer(%arg0: f32, %arg1: !emitc.ptr<f32>) {
// -----

func.func @div_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
// expected-error @+1 {{'emitc.div' op operand #0 must be floating-point type supported by EmitC or integer type supported by EmitC or index or EmitC opaque type, but got 'tensor<i32>'}}
// expected-error @+1 {{'emitc.div' op operand #0 must be floating-point type supported by EmitC or integer, index or opaque type supported by EmitC, but got 'tensor<i32>'}}
%1 = "emitc.div" (%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
return
}

// -----

func.func @mul_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer type supported by EmitC or index or EmitC opaque type, but got 'tensor<i32>'}}
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer, index or opaque type supported by EmitC, but got 'tensor<i32>'}}
%1 = "emitc.mul" (%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
return
}

// -----

func.func @rem_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
// expected-error @+1 {{'emitc.rem' op operand #0 must be integer type supported by EmitC or index or EmitC opaque type, but got 'tensor<i32>'}}
// expected-error @+1 {{'emitc.rem' op operand #0 must be integer, index or opaque type supported by EmitC, but got 'tensor<i32>'}}
%1 = "emitc.rem" (%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
return
}

// -----

func.func @rem_float(%arg0: f32, %arg1: f32) {
// expected-error @+1 {{'emitc.rem' op operand #0 must be integer type supported by EmitC or index or EmitC opaque type, but got 'f32'}}
// expected-error @+1 {{'emitc.rem' op operand #0 must be integer, index or opaque type supported by EmitC, but got 'f32'}}
%1 = "emitc.rem" (%arg0, %arg1) : (f32, f32) -> f32
return
}
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/EmitC/invalid_types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@ func.func @illegal_array_with_tensor_element_type(
// -----

func.func @illegal_integer_type(%arg0: i11, %arg1: i11) -> i11 {
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer type supported by EmitC or index or EmitC opaque type, but got 'i11'}}
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer, index or opaque type supported by EmitC, but got 'i11'}}
%mul = "emitc.mul" (%arg0, %arg1) : (i11, i11) -> i11
return
}

// -----

func.func @illegal_float_type(%arg0: f80, %arg1: f80) {
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer type supported by EmitC or index or EmitC opaque type, but got 'f80'}}
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer, index or opaque type supported by EmitC, but got 'f80'}}
%mul = "emitc.mul" (%arg0, %arg1) : (f80, f80) -> f80
return
}
Expand Down
2 changes: 2 additions & 0 deletions mlir/test/Dialect/EmitC/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ func.func @cast(%arg0: i32) {

func.func @c() {
%1 = "emitc.constant"(){value = 42 : i32} : () -> i32
%2 = "emitc.constant"(){value = 42 : index} : () -> !emitc.size_t
%3 = "emitc.constant"(){value = 42 : index} : () -> !emitc.ssize_t
return
}

Expand Down
20 changes: 19 additions & 1 deletion mlir/test/Dialect/EmitC/types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@ func.func @array_types(
// CHECK-SAME: !emitc.array<30x!emitc.ptr<i32>>,
%arg2: !emitc.array<30x!emitc.ptr<i32>>,
// CHECK-SAME: !emitc.array<30x!emitc.opaque<"int">>
%arg3: !emitc.array<30x!emitc.opaque<"int">>
%arg3: !emitc.array<30x!emitc.opaque<"int">>,
// CHECK-SAME: !emitc.array<30x!emitc.size_t>
%arg4: !emitc.array<30x!emitc.size_t>,
// CHECK-SAME: !emitc.array<30x!emitc.ssize_t>
%arg5: !emitc.array<30x!emitc.ssize_t>,
// CHECK-SAME: !emitc.array<30x!emitc.ptrdiff_t>
%arg6: !emitc.array<30x!emitc.ptrdiff_t>
) {
return
}
Expand Down Expand Up @@ -53,3 +59,15 @@ func.func @pointer_types() {

return
}

// CHECK-LABEL: func @size_types()
func.func @size_types() {
// CHECK-NEXT: !emitc.ssize_t
emitc.call_opaque "f"() {template_args = [!emitc.ssize_t]} : () -> ()
// CHECK-NEXT: !emitc.size_t
emitc.call_opaque "f"() {template_args = [!emitc.size_t]} : () -> ()
// CHECK-NEXT: !emitc.ptrdiff_t
emitc.call_opaque "f"() {template_args = [!emitc.ptrdiff_t]} : () -> ()

return
}
12 changes: 12 additions & 0 deletions mlir/test/Target/Cpp/types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,15 @@ func.func @ptr_types() {

return
}

// CHECK-LABEL: void size_types() {
func.func @size_types() {
// CHECK-NEXT: f<ssize_t>();
emitc.call_opaque "f"() {template_args = [!emitc.ssize_t]} : () -> ()
// CHECK-NEXT: f<size_t>();
emitc.call_opaque "f"() {template_args = [!emitc.size_t]} : () -> ()
// CHECK-NEXT: f<ptrdiff_t>();
emitc.call_opaque "f"() {template_args = [!emitc.ptrdiff_t]} : () -> ()

return
}
Loading