
[mlir][linalg] Type conversion of operands in new elementwise-op. #131542


Status: Open · wants to merge 1 commit into base: main
8 changes: 6 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -563,13 +563,16 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
The number of dims of the iterator-types are inferred from the rank of
the result type.

Numeric casting is performed on the input operand, promoting it to the same
data type as the result.
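As an aside (not part of the patch), the practical difference between the default `cast_signed` promotion and `cast_unsigned` for integer operands is sign-extension versus zero-extension. A minimal Python sketch of that distinction, modeled on two's-complement widening (function names are illustrative, not MLIR API):

```python
# Sketch: signed vs unsigned promotion of an i16 operand to a wider type,
# modeled as two's-complement sign-extension vs zero-extension.

def cast_signed(value: int, from_bits: int = 16) -> int:
    """Sign-extend: reinterpret the low bits as a signed number."""
    mask = (1 << from_bits) - 1
    value &= mask
    sign_bit = 1 << (from_bits - 1)
    return value - (1 << from_bits) if value & sign_bit else value

def cast_unsigned(value: int, from_bits: int = 16) -> int:
    """Zero-extend: reinterpret the low bits as an unsigned number."""
    return value & ((1 << from_bits) - 1)

# 0xFFFF is -1 as a signed i16 but 65535 as an unsigned i16.
print(cast_signed(0xFFFF))    # -1
print(cast_unsigned(0xFFFF))  # 65535
```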
> **Review comment (Contributor), on lines +566 to +567:** Could you document that there's default casting and that it can be specialised with the `cast` attribute?

Example:

Defining a unary linalg.elemwise with default indexing-map:
```mlir
%exp = linalg.elemwise
kind=#linalg.elemwise_kind<exp>
ins(%x : tensor<4x16x8xf32>)
ins(%x : tensor<4x16x8xf16>)
outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
```

> **Review comment (Contributor):** So you have changed this example so that there's casting. But what kind of casting? And why is it crucial? It would be good to expand the docs.

@@ -587,7 +590,8 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
ElementwiseKindAttr:$kind,
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
> **Review comment (Contributor):** I think it would be better to make this a list of `TypeFnAttr` which allows for a sentinel `none` if no casting is required. Some parser/printer helpers can allow something like `[cast_signed, -]` to say no casting is needed. I'd also say this list should be of the size of the number of `ins` operands.

);

let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
31 changes: 25 additions & 6 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4250,17 +4250,36 @@ void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
SmallVector<Value> yields;
Value result;

TypeFn castVal = TypeFn::cast_signed;
auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
return attr.getName() == "cast";
});

if (castIter != attrs.end()) {
if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
castVal = attr.getValue();
}

if (arityGroup == ElementwiseArityGroup::Unary) {
result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
Value val0 = helper.buildTypeFn(castVal, block.getArgument(1).getType(),
block.getArgument(0));
result = helper.buildUnaryFn(kind.unaryFn, val0);

> **Review comment (Contributor):** These `val0` and `val1` are quite enigmatic. I don't quite see what these mean. Could you use more descriptive names? Thanks!

} else if (arityGroup == ElementwiseArityGroup::Binary) {
result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
block.getArgument(1));
Value val0 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
block.getArgument(0));
Value val1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
block.getArgument(1));
result = helper.buildBinaryFn(kind.binaryFn, val0, val1);

} else if (arityGroup == ElementwiseArityGroup::Ternary) {
result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
block.getArgument(1), block.getArgument(2));

// select op's select-arg (block arg 0) must remain bool.
Value val1 = helper.buildTypeFn(castVal, block.getArgument(3).getType(),
block.getArgument(1));
Value val2 = helper.buildTypeFn(castVal, block.getArgument(3).getType(),
block.getArgument(2));
result =
helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0), val1, val2);
} else
assert(false && "found unhandled category in elemwise");
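For illustration (not from the patch), the ternary branch above deliberately skips casting block argument 0, so a select's `i1` condition stays boolean while the two value operands are promoted. A hypothetical IR sketch, assuming a `select` elementwise kind and these operand names:

```mlir
// Hypothetical: %cond stays i1; %b would be promoted f16 -> f32
// by the default cast before being consumed in the region.
%r = linalg.elementwise
       kind=#linalg.elementwise_kind<select>
       ins(%cond, %b, %c : tensor<8xi1>, tensor<8xf16>, tensor<8xf32>)
       outs(%d : tensor<8xf32>) -> tensor<8xf32>
```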

24 changes: 24 additions & 0 deletions mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
@@ -163,3 +163,27 @@ func.func @ternary(%A : tensor<32x16xi1>, %B: tensor<8x16x32xf32>, %C : tensor<8
outs(%D: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
return %r : tensor<8x16x32xf32>
}

// -----

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
//
// CHECK: @cast_f16_to_f32(%[[A:.+]]: tensor<16x8xf16>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>)
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: ins(%[[A]], %[[B]]
// CHECK-SAME: outs(%[[C]]
//
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
// CHECK: %[[CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
// CHECK: %[[MUL:.+]] = arith.mulf %[[CAST]], %[[B_ARG]] : f32
// CHECK: linalg.yield %[[MUL]] : f32
//
func.func @cast_f16_to_f32(%A : tensor<16x8xf16>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
> **Review comment (Contributor):** [nit] You are not following the existing naming convention from this file. Also, a test with a non-default `cast` attribute would also be helpful.
%r = linalg.elementwise
kind=#linalg.elementwise_kind<mul>
ins(%A, %B: tensor<16x8xf16>, tensor<16x8xf32>)
outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
return %r : tensor<16x8xf32>
}
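Spelled out, the generalized form that the CHECK lines above match looks roughly like this (a sketch reconstructed from the CHECKs, not verbatim compiler output):

```mlir
#map = affine_map<(d0, d1) -> (d0, d1)>
%r = linalg.generic
       {indexing_maps = [#map, #map, #map],
        iterator_types = ["parallel", "parallel"]}
       ins(%A, %B : tensor<16x8xf16>, tensor<16x8xf32>)
       outs(%C : tensor<16x8xf32>) {
  ^bb0(%a: f16, %b: f32, %c: f32):
    %cast = arith.extf %a : f16 to f32
    %mul = arith.mulf %cast, %b : f32
    linalg.yield %mul : f32
} -> tensor<16x8xf32>
```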
38 changes: 38 additions & 0 deletions mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
@@ -88,3 +88,41 @@ func.func @redundant_maps(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
outs(%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
return %r : tensor<1x2x3x4x5xi32>
}

// -----

// CHECK: @convert_f16_to_f32(%[[A:.+]]: tensor<16x8xf16>, %[[B:.+]]: tensor<16x8xf32>,
// CHECK-SAME: %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
// CHECK: {{.*}} = linalg.elementwise
// CHECK-SAME: kind=#linalg.elementwise_kind<div>
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf16>, tensor<16x8xf32>)
// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
//
func.func @convert_f16_to_f32(%A: tensor<16x8xf16>, %B: tensor<16x8xf32>,
> **Review comment (Contributor):** Note, you are not following the naming convention documented at the top.
%C: tensor<16x8xf32>) -> tensor<16x8xf32> {
%r = linalg.elementwise
kind=#linalg.elementwise_kind<div>
ins(%A, %B: tensor<16x8xf16>, tensor<16x8xf32>)
outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
return %r : tensor<16x8xf32>
}


// -----

// CHECK: @explicit_cast(%[[A:.+]]: tensor<16x8xi16>, %[[B:.+]]: tensor<16x8xi32>,
// CHECK-SAME: %[[C:.+]]: tensor<16x8xi32>) -> tensor<16x8xi32> {
// CHECK: {{.*}} = linalg.elementwise
// CHECK-SAME: kind=#linalg.elementwise_kind<add>
// CHECK-SAME: {cast = #linalg.type_fn<cast_signed>}
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xi16>, tensor<16x8xi32>)
// CHECK-SAME: outs(%[[C]] : tensor<16x8xi32>) -> tensor<16x8xi32>
//
func.func @explicit_cast(%A: tensor<16x8xi16>, %B: tensor<16x8xi32>, %C: tensor<16x8xi32>) -> tensor<16x8xi32> {
%0 = linalg.elementwise
kind=#linalg.elementwise_kind<add>
{cast = #linalg.type_fn<cast_signed>}
> **Review comment (Member):** You also want a test for the unsigned cast.
ins(%A, %B : tensor<16x8xi16>, tensor<16x8xi32>)
outs(%C : tensor<16x8xi32>) -> tensor<16x8xi32>
return %0 : tensor<16x8xi32>
}
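Addressing the reviewer's request above, a sketch of what an unsigned-cast test could look like (hypothetical test case, not part of this PR; assumes `#linalg.type_fn<cast_unsigned>` as the non-default value):

```mlir
// -----

// Hypothetical: the i16 operand is zero-extended rather than sign-extended.
func.func @unsigned_cast(%A: tensor<16x8xi16>, %B: tensor<16x8xi32>,
                         %C: tensor<16x8xi32>) -> tensor<16x8xi32> {
  %0 = linalg.elementwise
         kind=#linalg.elementwise_kind<add>
         {cast = #linalg.type_fn<cast_unsigned>}
         ins(%A, %B : tensor<16x8xi16>, tensor<16x8xi32>)
         outs(%C : tensor<16x8xi32>) -> tensor<16x8xi32>
  return %0 : tensor<16x8xi32>
}
```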