-
Notifications
You must be signed in to change notification settings - Fork 13.6k
mlir: add an operation to EmitC for function template instantiation #100895
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1260,5 +1260,32 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> { | |
let assemblyFormat = "$value `[` $indices `]` attr-dict `:` functional-type(operands, results)"; | ||
} | ||
|
||
def EmitC_InstantiateFunctionTemplateOp : EmitC_Op<"instantiate_function_template", []> { | ||
let summary = "Instantiate template operation"; | ||
let description = [{ | ||
Instantiate a function template with a given set of types | ||
(given by the values as argument to this operation) to obtain | ||
a function pointer. | ||
|
||
Example: | ||
|
||
```mlir | ||
%c1 = "emitc.constant"() <{value = 7 : i32}> : () -> i32 | ||
%0 = emitc.instantiate_function_template "func_template"<%c1> : (i32) -> !emitc.ptr<!emitc.opaque<"void">> | ||
``` | ||
Translates to the C++: | ||
```c++ | ||
int32_t v1 = 7; | ||
void* v2 = &func_template<decltype(v1)>; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we separate the application of template instantiation from taking its address (which emitc has an op for)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I could, but right now there isn't a way to use the returned value except taking a pointer of it, as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can add an OTOH, is it even legal to cast function pointers to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think it's technically undefined because some kinds of hardware cannot support such a cast. However, the runtime library (CUTLASS) I'm trying to interact with only accepts void* function pointers for a particular function, so its not invalid either. |
||
``` | ||
}]; | ||
let arguments = (ins | ||
Arg<StrAttr, "the C++ function to instantiate">:$callee, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do you need to derive the type from an SSA value? Why not instead use an ArrayAttr (of TypeAttr) directly here? Is there really a use case where it is difficult to infer that information? When you create this operation, you would could just have a builder with an API like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the case where the types are known inside of emitc, this operation is not actually needed (use of verbatim and string formatting with a |
||
Variadic<EmitCType>:$args | ||
); | ||
let results = (outs EmitC_PointerType); | ||
let assemblyFormat = "$callee `<` $args `>` attr-dict `:` functional-type($args, results)"; | ||
} | ||
|
||
|
||
#endif // MLIR_DIALECT_EMITC_IR_EMITC |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -656,6 +656,30 @@ static LogicalResult printOperation(CppEmitter &emitter, | |
return success(); | ||
} | ||
|
||
static LogicalResult | ||
printOperation(CppEmitter &emitter, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You are missing a test for this translation There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Where do those tests live? I grepped around but couldnt find where the others are. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The tests are under There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added, thanks |
||
emitc::InstantiateFunctionTemplateOp instOp) { | ||
|
||
raw_ostream &os = emitter.ostream(); | ||
Operation &op = *instOp.getOperation(); | ||
|
||
if (failed(emitter.emitAssignPrefix(op))) | ||
return failure(); | ||
os << "&" << instOp.getCallee() << "<"; | ||
|
||
auto emitArgs = [&](mlir::Value val) -> LogicalResult { | ||
os << "decltype("; | ||
if (failed(emitter.emitOperand(val))) | ||
return failure(); | ||
os << ")"; | ||
return success(); | ||
}; | ||
if (failed(interleaveCommaWithError(instOp.getArgs(), os, emitArgs))) | ||
return failure(); | ||
os << ">"; | ||
return success(); | ||
} | ||
|
||
static LogicalResult printOperation(CppEmitter &emitter, | ||
emitc::ApplyOp applyOp) { | ||
raw_ostream &os = emitter.ostream(); | ||
|
@@ -1508,6 +1532,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { | |
.Case<func::CallOp, func::FuncOp, func::ReturnOp>( | ||
[&](auto op) { return printOperation(*this, op); }) | ||
.Case<emitc::LiteralOp>([&](auto op) { return success(); }) | ||
.Case<emitc::InstantiateFunctionTemplateOp>( | ||
[&](auto op) { return printOperation(*this, op); }) | ||
.Default([&](Operation *) { | ||
return op.emitOpError("unable to find printer for op"); | ||
}); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT | ||
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP | ||
|
||
func.func @emitc_instantiate_template() { | ||
%c1 = "emitc.constant"() <{value = 7 : i32}> : () -> i32 | ||
%0 = emitc.instantiate_function_template "func_template"<%c1> : (i32) -> !emitc.ptr<!emitc.opaque<"void">> | ||
return | ||
} | ||
// CPP-DEFAULT: void emitc_instantiate_template() { | ||
// CPP-DEFAULT-NEXT: int32_t [[V0:[^ ]*]] = 7; | ||
// CPP-DEFAULT-NEXT: void* [[V1:[^ ]*]] = &func_template<decltype([[V0]])>; | ||
|
||
// CPP-DECLTOP: void emitc_instantiate_template() { | ||
// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]]; | ||
// CPP-DECLTOP-NEXT: void* [[V1:[^ ]*]]; | ||
// CPP-DECLTOP-NEXT: [[V0]] = 7; | ||
// CPP-DECLTOP-NEXT: [[V1]] = &func_template<decltype([[V0]])>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please give an example
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done