Skip to content

[mlir][emitc] Refactor emitc.apply op #72569

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 32 additions & 19 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -63,31 +63,23 @@ def EmitC_AddOp : EmitC_BinaryOp<"add", []> {
let hasVerifier = 1;
}

def EmitC_ApplyOp : EmitC_Op<"apply", []> {
let summary = "Apply operation";
def EmitC_AddressOfOp : EmitC_Op<"address_of", []> {
let summary = "Address operation";
let description = [{
With the `apply` operation the operators & (address of) and * (contents of)
can be applied to a single operand.
This operation models the C & (address of) operator for a single operand which
must be an emitc.variable. It returns an emitc pointer to the variable.

Example:

```mlir
// Custom form of applying the & operator.
%0 = emitc.apply "&"(%arg0) : (i32) -> !emitc.ptr<i32>

// Generic form of the same operation.
%0 = "emitc.apply"(%arg0) {applicableOperator = "&"}
: (i32) -> !emitc.ptr<i32>

%0 = emitc.address_of %arg0 : (i32) -> !emitc.ptr<i32>
```
}];
let arguments = (ins
Arg<StrAttr, "the operator to apply">:$applicableOperator,
AnyType:$operand
);
let results = (outs AnyType:$result);
let arguments = (ins AnyType:$var);
let results = (outs EmitC_PointerType:$result);
let assemblyFormat = [{
$applicableOperator `(` $operand `)` attr-dict `:` functional-type($operand, results)
$var attr-dict `:` functional-type($var, $result)
}];
let hasVerifier = 1;
}
Expand Down Expand Up @@ -222,6 +214,27 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
let hasVerifier = 1;
}

def EmitC_DereferenceOp : EmitC_Op<"dereference", []> {
let summary = "Dereference operation";
let description = [{
This operation models the C * (dereference) operator for a single operand which
must be of !emitc.ptr<> type. It returns the value pointed to by the pointer.

Example:

```mlir
// Custom form of applying the & operator.
%0 = emitc.dereference %arg0 : (!emitc.ptr<i32>) -> i32
```
}];
let arguments = (ins EmitC_PointerType:$pointer);
let results = (outs AnyType:$result);
let assemblyFormat = [{
$pointer attr-dict `:` functional-type($pointer, $result)
}];
let hasVerifier = 1;
}

def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
let summary = "Division operation";
let description = [{
Expand Down Expand Up @@ -448,12 +461,12 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {

Since folding is not supported, it can be used with pointers.
As an example, it is valid to create pointers to `variable` operations
by using `apply` operations and pass these to a `call` operation.
by using `address_of` operations and pass these to a `call` operation.
```mlir
%0 = "emitc.variable"() {value = 0 : i32} : () -> i32
%1 = "emitc.variable"() {value = 0 : i32} : () -> i32
%2 = emitc.apply "&"(%0) : (i32) -> !emitc.ptr<i32>
%3 = emitc.apply "&"(%1) : (i32) -> !emitc.ptr<i32>
%2 = emitc.address_of %0 : (i32) -> !emitc.ptr<i32>
%3 = emitc.address_of %1 : (i32) -> !emitc.ptr<i32>
emitc.call "write"(%2, %3) : (!emitc.ptr<i32>, !emitc.ptr<i32>) -> ()
```
}];
Expand Down
41 changes: 28 additions & 13 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,23 +72,21 @@ LogicalResult AddOp::verify() {
}

//===----------------------------------------------------------------------===//
// ApplyOp
// AddressOfOp
//===----------------------------------------------------------------------===//

LogicalResult ApplyOp::verify() {
StringRef applicableOperatorStr = getApplicableOperator();

// Applicable operator must not be empty.
if (applicableOperatorStr.empty())
return emitOpError("applicable operator must not be empty");
LogicalResult AddressOfOp::verify() {
Value variable = getVar();
auto variableDef = dyn_cast_if_present<VariableOp>(variable.getDefiningOp());
if (!variableDef)
return emitOpError() << "requires operand to be a variable";

// Only `*` and `&` are supported.
if (applicableOperatorStr != "&" && applicableOperatorStr != "*")
return emitOpError("applicable operator is illegal");
Type variableType = variable.getType();
emitc::PointerType resultType = getResult().getType();
Type pointeeType = resultType.getPointee();

Operation *op = getOperand().getDefiningOp();
if (op && dyn_cast<ConstantOp>(op))
return emitOpError("cannot apply to constant");
if (variableType != pointeeType)
return emitOpError("requires variable to be of type pointed to by result");

return success();
}
Expand Down Expand Up @@ -189,6 +187,23 @@ LogicalResult emitc::ConstantOp::verify() {

OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }

//===----------------------------------------------------------------------===//
// DereferenceOp
//===----------------------------------------------------------------------===//

LogicalResult DereferenceOp::verify() {
auto pointer = getPointer();
emitc::PointerType pointerType = pointer.getType();
Type pointeeType = pointerType.getPointee();
Type resultType = getResult().getType();

if (pointeeType != resultType)
return emitOpError()
<< "requires result to be of type pointed to by operand";

return success();
}

//===----------------------------------------------------------------------===//
// ForOp
//===----------------------------------------------------------------------===//
Expand Down
46 changes: 30 additions & 16 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,19 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
return emitter.emitAttribute(operation->getLoc(), value);
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::AddressOfOp addressOfOp) {
raw_ostream &os = emitter.ostream();
Operation &op = *addressOfOp.getOperation();

if (failed(emitter.emitAssignPrefix(op)))
return failure();
os << "&";
os << emitter.getOrCreateName(addressOfOp.getOperand());

return success();
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::ConstantOp constantOp) {
Operation *operation = constantOp.getOperation();
Expand Down Expand Up @@ -461,30 +474,30 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
return success();
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::ApplyOp applyOp) {
static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
raw_ostream &os = emitter.ostream();
Operation &op = *applyOp.getOperation();
Operation &op = *castOp.getOperation();

if (failed(emitter.emitAssignPrefix(op)))
return failure();
os << applyOp.getApplicableOperator();
os << emitter.getOrCreateName(applyOp.getOperand());
os << "(";
if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType())))
return failure();
os << ") ";
os << emitter.getOrCreateName(castOp.getOperand());

return success();
}

static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
static LogicalResult printOperation(CppEmitter &emitter,
emitc::DereferenceOp dereferenceOp) {
raw_ostream &os = emitter.ostream();
Operation &op = *castOp.getOperation();
Operation &op = *dereferenceOp.getOperation();

if (failed(emitter.emitAssignPrefix(op)))
return failure();
os << "(";
if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType())))
return failure();
os << ") ";
os << emitter.getOrCreateName(castOp.getOperand());
os << "*";
os << emitter.getOrCreateName(dereferenceOp.getOperand());

return success();
}
Expand Down Expand Up @@ -949,10 +962,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
.Case<cf::BranchOp, cf::CondBranchOp>(
[&](auto op) { return printOperation(*this, op); })
// EmitC ops.
.Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp, emitc::CallOp,
emitc::CastOp, emitc::CmpOp, emitc::ConstantOp, emitc::DivOp,
emitc::ForOp, emitc::IfOp, emitc::IncludeOp, emitc::MulOp,
emitc::RemOp, emitc::SubOp, emitc::VariableOp>(
.Case<emitc::AddOp, emitc::AddressOfOp, emitc::AssignOp,
emitc::CallOp, emitc::CastOp, emitc::CmpOp, emitc::ConstantOp,
emitc::DereferenceOp, emitc::DivOp, emitc::ForOp, emitc::IfOp,
emitc::IncludeOp, emitc::MulOp, emitc::RemOp, emitc::SubOp,
emitc::VariableOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
Expand Down
22 changes: 7 additions & 15 deletions mlir/test/Dialect/EmitC/invalid_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -72,26 +72,18 @@ func.func @dense_template_argument(%arg : i32) {

// -----

func.func @empty_operator(%arg : i32) {
// expected-error @+1 {{'emitc.apply' op applicable operator must not be empty}}
%2 = emitc.apply ""(%arg) : (i32) -> !emitc.ptr<i32>
return
}

// -----

func.func @illegal_operator(%arg : i32) {
// expected-error @+1 {{'emitc.apply' op applicable operator is illegal}}
%2 = emitc.apply "+"(%arg) : (i32) -> !emitc.ptr<i32>
func.func @illegal_address_of_operand() {
%1 = "emitc.constant"(){value = 42: i32} : () -> i32
// expected-error @+1 {{'emitc.address_of' op requires operand to be a variable}}
%2 = emitc.address_of %1 : (i32) -> !emitc.ptr<i32>
return
}

// -----

func.func @illegal_operand() {
%1 = "emitc.constant"(){value = 42: i32} : () -> i32
// expected-error @+1 {{'emitc.apply' op cannot apply to constant}}
%2 = emitc.apply "&"(%1) : (i32) -> !emitc.ptr<i32>
func.func @illegal_dereference_operand(%arg0 : !emitc.ptr<i32>) {
// expected-error @+1 {{'emitc.dereference' op requires result to be of type pointed to by operand}}
%2 = emitc.dereference %arg0 : (!emitc.ptr<i32>) -> (f32)
return
}

Expand Down
13 changes: 10 additions & 3 deletions mlir/test/Dialect/EmitC/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ func.func @c() {
return
}

func.func @a(%arg0: i32, %arg1: i32) {
%1 = "emitc.apply"(%arg0) {applicableOperator = "&"} : (i32) -> !emitc.ptr<i32>
%2 = emitc.apply "&"(%arg1) : (i32) -> !emitc.ptr<i32>
func.func @a() {
%arg0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
%1 = "emitc.address_of"(%arg0) : (i32) -> !emitc.ptr<i32>
%2 = emitc.address_of %arg0 : (i32) -> !emitc.ptr<i32>
return
}

Expand All @@ -47,6 +48,12 @@ func.func @div_int(%arg0: i32, %arg1: i32) {
return
}

func.func @dereference(%arg0: !emitc.ptr<i32>) {
%1 = "emitc.dereference"(%arg0) : (!emitc.ptr<i32>) -> (i32)
%2 = emitc.dereference %arg0 : (!emitc.ptr<i32>) -> (i32)
return
}

func.func @div_float(%arg0: f32, %arg1: f32) {
%1 = "emitc.div" (%arg0, %arg1) : (f32, f32) -> f32
return
Expand Down
7 changes: 4 additions & 3 deletions mlir/test/Target/Cpp/common-cpp.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,11 @@ func.func @opaque_types(%arg0: !emitc.opaque<"bool">, %arg1: !emitc.opaque<"char
return %2 : !emitc.opaque<"status_t">
}

func.func @apply(%arg0: i32) -> !emitc.ptr<i32> {
func.func @apply() -> !emitc.ptr<i32> {
%arg0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
// CHECK: int32_t* [[V2]] = &[[V1]];
%0 = emitc.apply "&"(%arg0) : (i32) -> !emitc.ptr<i32>
%0 = emitc.address_of %arg0 : (i32) -> !emitc.ptr<i32>
// CHECK: int32_t [[V3]] = *[[V2]];
%1 = emitc.apply "*"(%0) : (!emitc.ptr<i32>) -> (i32)
%1 = emitc.dereference %0 : (!emitc.ptr<i32>) -> (i32)
return %0 : !emitc.ptr<i32>
}