Skip to content

mlir: add an operation to EmitC for function template instantiation #100895

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -1260,5 +1260,32 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
let assemblyFormat = "$value `[` $indices `]` attr-dict `:` functional-type(operands, results)";
}

def EmitC_InstantiateFunctionTemplateOp : EmitC_Op<"instantiate_function_template", []> {
let summary = "Instantiate template operation";
let description = [{
Instantiate a function template with a given set of types
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please give an example

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

(given by the values as argument to this operation) to obtain
a function pointer.

Example:

```mlir
%c1 = "emitc.constant"() <{value = 7 : i32}> : () -> i32
%0 = emitc.instantiate_function_template "func_template"<%c1> : (i32) -> !emitc.ptr<!emitc.opaque<"void">>
```
Translates to the C++:
```c++
int32_t v1 = 7;
void* v2 = &func_template<decltype(v1)>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we separate the application of template instantiation from taking its address (which emitc has an op for)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could, but right now there isn't a way to use the returned value except taking a pointer of it, as call excepts a SymbolRefAttr.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can add an call_indirect op, similar to the func dialect. But we would need to also add support for function types in the emitter (or add an generic auto type, but that might have problems with wrongly inferred types if used in other places).

OTOH, is it even legal to cast function pointers to void *?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OTOH, is it even legal to cast function pointers to void *?

I think it's technically undefined because some kinds of hardware cannot support such a cast. However, the runtime library (CUTLASS) I'm trying to interact with only accepts void* function pointers for a particular function, so its not invalid either.

```
}];
let arguments = (ins
Arg<StrAttr, "the C++ function to instantiate">:$callee,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need to derive the type from an SSA value? Why not instead use an ArrayAttr (of TypeAttr) directly here? Is there really a use case where it is difficult to infer that information? When you create this operation, you would could just have a builder with an API like b.create<InstatiateFunctionTemplateOp>(loc, TypeRange(args)) in order to create an ArrayAttr containing all the types.

Copy link
Contributor Author

@rohany rohany Jul 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the case where the types are known inside of emitc, this operation is not actually needed (use of verbatim and string formatting with a TypeRange would suffice). In the case that I'm considering, we need to derive the types from the SSA values themselves. For example, consider generating code using the CuTe C++ library. After a few layout transformations, only the C++ compiler knows the exact types of a CuTe object. In such a case, a decltype is needed to correctly instantiate the function template.

Variadic<EmitCType>:$args
);
let results = (outs EmitC_PointerType);
let assemblyFormat = "$callee `<` $args `>` attr-dict `:` functional-type($args, results)";
}


#endif // MLIR_DIALECT_EMITC_IR_EMITC
26 changes: 26 additions & 0 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,30 @@ static LogicalResult printOperation(CppEmitter &emitter,
return success();
}

static LogicalResult
printOperation(CppEmitter &emitter,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are missing a test for this translation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where do those tests live? I grepped around but couldnt find where the others are.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests are under llvm-project/mlir/test/Target/Cpp

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added, thanks

emitc::InstantiateFunctionTemplateOp instOp) {

raw_ostream &os = emitter.ostream();
Operation &op = *instOp.getOperation();

if (failed(emitter.emitAssignPrefix(op)))
return failure();
os << "&" << instOp.getCallee() << "<";

auto emitArgs = [&](mlir::Value val) -> LogicalResult {
os << "decltype(";
if (failed(emitter.emitOperand(val)))
return failure();
os << ")";
return success();
};
if (failed(interleaveCommaWithError(instOp.getArgs(), os, emitArgs)))
return failure();
os << ">";
return success();
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::ApplyOp applyOp) {
raw_ostream &os = emitter.ostream();
Expand Down Expand Up @@ -1508,6 +1532,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
.Case<func::CallOp, func::FuncOp, func::ReturnOp>(
[&](auto op) { return printOperation(*this, op); })
.Case<emitc::LiteralOp>([&](auto op) { return success(); })
.Case<emitc::InstantiateFunctionTemplateOp>(
[&](auto op) { return printOperation(*this, op); })
.Default([&](Operation *) {
return op.emitOpError("unable to find printer for op");
});
Expand Down
6 changes: 6 additions & 0 deletions mlir/test/Dialect/EmitC/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,12 @@ func.func @test_subscript(%arg0 : !emitc.array<2x3xf32>, %arg1 : !emitc.ptr<i32>
return
}

func.func @test_instantiate_template() {
%c1 = "emitc.constant"() <{value = 7 : i32}> : () -> i32
%0 = emitc.instantiate_function_template "func_template"<%c1> : (i32) -> !emitc.ptr<!emitc.opaque<"void">>
return
}

emitc.verbatim "#ifdef __cplusplus"
emitc.verbatim "extern \"C\" {"
emitc.verbatim "#endif // __cplusplus"
Expand Down
17 changes: 17 additions & 0 deletions mlir/test/Target/Cpp/instantiate_function_template.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP

func.func @emitc_instantiate_template() {
%c1 = "emitc.constant"() <{value = 7 : i32}> : () -> i32
%0 = emitc.instantiate_function_template "func_template"<%c1> : (i32) -> !emitc.ptr<!emitc.opaque<"void">>
return
}
// CPP-DEFAULT: void emitc_instantiate_template() {
// CPP-DEFAULT-NEXT: int32_t [[V0:[^ ]*]] = 7;
// CPP-DEFAULT-NEXT: void* [[V1:[^ ]*]] = &func_template<decltype([[V0]])>;

// CPP-DECLTOP: void emitc_instantiate_template() {
// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]];
// CPP-DECLTOP-NEXT: void* [[V1:[^ ]*]];
// CPP-DECLTOP-NEXT: [[V0]] = 7;
// CPP-DECLTOP-NEXT: [[V1]] = &func_template<decltype([[V0]])>;
Loading