Skip to content

Commit e7d09ce

Browse files
kurapov-peterrengolinmakslevental
authored
[MLIR][Linalg] Ternary Op & Linalg select (llvm#91461)
Following llvm#90236, adding `select` to linalg as `arith.select`. No implicit type casting. OpDSL doesn't expose a type restriction for bool, but I saw no reason in adding it (put a separate symbolic type and check the semantics in the builder). --------- Co-authored-by: Renato Golin <[email protected]> Co-authored-by: Maksim Levental <[email protected]>
1 parent 995a8af commit e7d09ce

File tree

11 files changed

+255
-3
lines changed

11 files changed

+255
-3
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
6868
def BinaryFnAttr : EnumAttr<Linalg_Dialect, BinaryFn, "binary_fn"> {
6969
let assemblyFormat = "`<` $value `>`";
7070
}
71+
def TernaryFnAttr : EnumAttr<Linalg_Dialect, TernaryFn, "ternary_fn"> {
72+
let assemblyFormat = "`<` $value `>`";
73+
}
7174
def TypeFnAttr : EnumAttr<Linalg_Dialect, TypeFn, "type_fn"> {
7275
let assemblyFormat = "`<` $value `>`";
7376
}

mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ def BinaryFn : I32EnumAttr<"BinaryFn", "", [
4949
let genSpecializedAttr = 0;
5050
let cppNamespace = "::mlir::linalg";
5151
}
52+
def TernaryFn : I32EnumAttr<"TernaryFn", "", [
53+
I32EnumAttrCase<"select", 0>
54+
]> {
55+
let genSpecializedAttr = 0;
56+
let cppNamespace = "::mlir::linalg";
57+
}
5258
def TypeFn : I32EnumAttr<"TypeFn", "", [
5359
I32EnumAttrCase<"cast_signed", 0>,
5460
I32EnumAttrCase<"cast_unsigned", 1>

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,63 @@ structured_op: !LinalgStructuredOpConfig
10081008
- !ScalarExpression
10091009
scalar_arg: rhs
10101010
--- !LinalgOpConfig
1011+
metadata: !LinalgOpMetadata
1012+
name: select
1013+
cpp_class_name: SelectOp
1014+
doc: |-
1015+
Chooses one value based on a binary condition supplied as its first operand.
1016+
1017+
The shapes and element types must be identical. The appropriate casts,
1018+
broadcasts and reductions should be done previously to calling this op.
1019+
1020+
This means reduction/broadcast/element cast semantics is explicit. Further
1021+
passes can take that into account when lowering this code. For example,
1022+
a `linalg.broadcast` + `linalg.select` sequence can be lowered to a
1023+
`linalg.generic` with different affine maps for the two operands.
1024+
structured_op: !LinalgStructuredOpConfig
1025+
args:
1026+
- !LinalgOperandDefConfig
1027+
name: cond
1028+
kind: input_tensor
1029+
type_var: U
1030+
shape_map: affine_map<() -> ()>
1031+
- !LinalgOperandDefConfig
1032+
name: lhs
1033+
kind: input_tensor
1034+
type_var: T1
1035+
shape_map: affine_map<() -> ()>
1036+
- !LinalgOperandDefConfig
1037+
name: rhs
1038+
kind: input_tensor
1039+
type_var: T1
1040+
shape_map: affine_map<() -> ()>
1041+
- !LinalgOperandDefConfig
1042+
name: O
1043+
kind: output_tensor
1044+
type_var: T1
1045+
shape_map: affine_map<() -> ()>
1046+
indexing_maps: !LinalgIndexingMapsConfig
1047+
static_indexing_maps:
1048+
- affine_map<() -> ()>
1049+
- affine_map<() -> ()>
1050+
- affine_map<() -> ()>
1051+
- affine_map<() -> ()>
1052+
iterator_types: []
1053+
assignments:
1054+
- !ScalarAssign
1055+
arg: O
1056+
value: !ScalarExpression
1057+
scalar_fn:
1058+
kind: ternary
1059+
fn_name: select
1060+
operands:
1061+
- !ScalarExpression
1062+
scalar_arg: cond
1063+
- !ScalarExpression
1064+
scalar_arg: lhs
1065+
- !ScalarExpression
1066+
scalar_arg: rhs
1067+
--- !LinalgOpConfig
10111068
metadata: !LinalgOpMetadata
10121069
name: matmul
10131070
cpp_class_name: MatmulOp

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,25 @@ class RegionBuilderHelper {
492492
llvm_unreachable("unsupported binary function");
493493
}
494494

495+
// Build the ternary functions defined by OpDSL.
496+
Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
497+
Value arg2) {
498+
bool headBool =
499+
isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
500+
bool tailFloatingPoint =
501+
isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
502+
bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg1);
503+
OpBuilder::InsertionGuard g(builder);
504+
builder.setInsertionPointToEnd(&block);
505+
switch (ternaryFn) {
506+
case TernaryFn::select:
507+
if (!headBool && !(tailFloatingPoint || tailInteger))
508+
llvm_unreachable("unsupported non numeric type");
509+
return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
510+
}
511+
llvm_unreachable("unsupported ternary function");
512+
}
513+
495514
// Build the type functions defined by OpDSL.
496515
Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
497516
switch (typeFn) {

mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,8 @@ def __repr__(self):
262262
class FunctionKind(Enum):
263263
UNARY = 0
264264
BINARY = 1
265-
TYPE = 2
265+
TERNARY = 2
266+
TYPE = 3
266267

267268

268269
class UnaryFnType:
@@ -339,6 +340,33 @@ class BinaryFn:
339340
powf = BinaryFnType("powf")
340341

341342

343+
class TernaryFnType:
344+
"""Ternary function.
345+
346+
A ternary function takes three tensor expressions and returns the
347+
function evaluation result.
348+
"""
349+
350+
def __init__(self, fn_name: str):
351+
self.fn_name = fn_name
352+
353+
def __call__(
354+
self, arg0: TensorExpression, arg1: TensorExpression, arg2: TensorExpression
355+
) -> "TensorFn":
356+
return TensorFn(
357+
FunctionKind.TERNARY, self.fn_name, None, None, [arg0, arg1, arg2]
358+
)
359+
360+
def __repr__(self):
361+
return f"{self.fn_name}"
362+
363+
364+
class TernaryFn:
365+
"""Ternary function namespace."""
366+
367+
select = TernaryFnType("select")
368+
369+
342370
class TypeFnType:
343371
"""Type conversion function.
344372
@@ -437,7 +465,8 @@ class OperandKind(Enum):
437465
INDEX_ATTR = 3
438466
UNARY_FN_ATTR = 4
439467
BINARY_FN_ATTR = 5
440-
TYPE_FN_ATTR = 6
468+
TERNARY_FN_ATTR = 6
469+
TYPE_FN_ATTR = 7
441470

442471

443472
class OperandDef:
@@ -489,6 +518,7 @@ def is_attribute(self) -> bool:
489518
self.kind == OperandKind.INDEX_ATTR
490519
or self.kind == OperandKind.UNARY_FN_ATTR
491520
or self.kind == OperandKind.BINARY_FN_ATTR
521+
or self.kind == OperandKind.TERNARY_FN_ATTR
492522
or self.kind == OperandKind.TYPE_FN_ATTR
493523
)
494524

@@ -670,6 +700,33 @@ def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
670700
return ReduceFnUse(None, self, *reduce_dims)
671701

672702

703+
class TernaryFnAttrDef:
704+
"""Ternary function attribute definition.
705+
706+
Ternary function attributes provide a way to make the arithmetic computation
707+
parametrizable. Every attribute specifies a default Ternary function
708+
that may be overwritten at operation instantiation time.
709+
"""
710+
711+
def __init__(self, default: "TernaryFnType"):
712+
if not isinstance(default, TernaryFnType):
713+
raise ValueError(
714+
f"TernaryFnAttrDef requires default of type TernaryFnType "
715+
f"but got {default}"
716+
)
717+
self.operand_def = OperandDef(
718+
OperandKind.TERNARY_FN_ATTR, default_fn=default.fn_name
719+
)
720+
721+
def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> TensorFn:
722+
return TensorFn(
723+
FunctionKind.TERNARY, None, self.operand_def, None, [arg0, arg1]
724+
)
725+
726+
def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
727+
return ReduceFnUse(None, self, *reduce_dims)
728+
729+
673730
class TypeFnAttrDef:
674731
"""Type conversion function attribute definition.
675732

mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def prepare_common_structured_op(
6060
in [
6161
OperandKind.UNARY_FN_ATTR,
6262
OperandKind.BINARY_FN_ATTR,
63+
OperandKind.TERNARY_FN_ATTR,
6364
OperandKind.TYPE_FN_ATTR,
6465
]
6566
]
@@ -180,6 +181,12 @@ def prepare_common_structured_op(
180181
f"Attribute {fn_attr.name} needs to be of type "
181182
f"BinaryFnType but got {type(attr_val)}"
182183
)
184+
elif attr_kind == OperandKind.TERNARY_FN_ATTR:
185+
if not isinstance(fn, TernaryFnType):
186+
raise ValueError(
187+
f"Attribute {fn_attr.name} needs to be of type "
188+
f"TernaryFnType but got {type(attr_val)}"
189+
)
183190
else:
184191
if not isinstance(fn, TypeFnType):
185192
raise ValueError(

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,26 @@ def powf(
351351
O[None] = BinaryFn.powf(lhs[None], rhs[None])
352352

353353

354+
@linalg_structured_op
355+
def select(
356+
cond=TensorDef(U),
357+
lhs=TensorDef(T1),
358+
rhs=TensorDef(T1),
359+
O=TensorDef(T1, output=True),
360+
):
361+
"""Chooses one value based on a binary condition supplied as its first operand.
362+
363+
The shapes and element types must be identical. The appropriate casts,
364+
broadcasts and reductions should be done previously to calling this op.
365+
366+
This means reduction/broadcast/element cast semantics is explicit. Further
367+
passes can take that into account when lowering this code. For example,
368+
a `linalg.broadcast` + `linalg.select` sequence can be lowered to a
369+
`linalg.generic` with different affine maps for the two operands.
370+
"""
371+
O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None])
372+
373+
354374
@linalg_structured_op
355375
def matmul(
356376
A=TensorDef(T1, S.M, S.K),

mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,31 @@ func.func @generalize_powf(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
791791

792792
// -----
793793

794+
func.func @generalize_select(%cond: memref<7x14x21xi1>, %lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
795+
%out: memref<7x14x21xf32>) {
796+
linalg.select ins(%cond, %lhs, %rhs: memref<7x14x21xi1>, memref<7x14x21xf32>, memref<7x14x21xf32>)
797+
outs(%out: memref<7x14x21xf32>)
798+
return
799+
}
800+
801+
// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
802+
803+
// CHECK: func @generalize_select
804+
// CHECK-SAME: (%[[COND:.+]]: memref<7x14x21xi1>, %[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
805+
// CHECK-SAME: %[[OUT:.+]]: memref<7x14x21xf32>)
806+
807+
// CHECK: linalg.generic
808+
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]], #[[MAP]]]
809+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
810+
// CHECK-SAME: ins(%[[COND]], %[[LHS]], %[[RHS]] : memref<7x14x21xi1>, memref<7x14x21xf32>, memref<7x14x21xf32>)
811+
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
812+
813+
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i1, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32, %[[BBARG3:.+]]: f32)
814+
// CHECK-NEXT: %[[select:.+]] = arith.select %[[BBARG0]], %[[BBARG1]], %[[BBARG2]] : f32
815+
// CHECK-NEXT: linalg.yield %[[select]] : f32
816+
817+
818+
// -----
794819

795820
// CHECK-LABEL: func @fill_tensor
796821
func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {

mlir/test/Dialect/Linalg/named-ops-fail.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,3 +334,19 @@ func.func @powf_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %a
334334
return
335335
}
336336

337+
// -----
338+
339+
func.func @select_type_cast(%arg0: memref<4x8x16xi1>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>, %arg3: memref<4x8x16xf32>) {
340+
// CHECK: op failed to verify that all of {true_value, false_value, result} have same type
341+
linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xi1>, memref<4x8x16xf16>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
342+
return
343+
}
344+
345+
// -----
346+
347+
func.func @select_wrong_condition_type(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>, %arg3: memref<4x8x16xf32>) {
348+
// CHECK: op operand #0 must be bool-like, but got 'f32'
349+
linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xf32>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
350+
return
351+
}
352+

mlir/test/Dialect/Linalg/named-ops.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1924,3 +1924,37 @@ func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vec
19241924
%1 = linalg.fill ins(%v : vector<2x4xf32>) outs(%e1 : tensor<vector<2x4xf32>>) -> tensor<vector<2x4xf32>>
19251925
return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
19261926
}
1927+
1928+
// -----
1929+
1930+
// CHECK-LABEL: func @select_dynamic
1931+
func.func @select_dynamic(%arg0: memref<?x?x?xi1>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>, %arg3: memref<?x?x?xf32>) {
1932+
// CHECK: linalg.select
1933+
// CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : memref<?x?x?xi1>, memref<?x?x?xf32>, memref<?x?x?xf32>)
1934+
// CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
1935+
linalg.select ins(%arg0, %arg1, %arg2 : memref<?x?x?xi1>, memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg3: memref<?x?x?xf32>)
1936+
return
1937+
}
1938+
1939+
// -----
1940+
1941+
// CHECK-LABEL: func @select_static
1942+
func.func @select_static(%arg0: memref<4x8x16xi1>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>, %arg3: memref<4x8x16xf32>) {
1943+
// CHECK: linalg.select
1944+
// CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : memref<4x8x16xi1>, memref<4x8x16xf32>, memref<4x8x16xf32>)
1945+
// CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
1946+
linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xi1>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
1947+
return
1948+
}
1949+
1950+
// -----
1951+
1952+
// CHECK-LABEL: func @select_tensor
1953+
func.func @select_tensor(%arg0: tensor<4x8x16xi1>, %arg1: tensor<4x8x16xf32>, %arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
1954+
%0 = tensor.empty() : tensor<4x8x16xf32>
1955+
// CHECK: linalg.select
1956+
// CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : tensor<4x8x16xi1>, tensor<4x8x16xf32>, tensor<4x8x16xf32>)
1957+
// CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xf32>)
1958+
%1 = linalg.select ins(%arg0, %arg1, %arg2 : tensor<4x8x16xi1>, tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
1959+
return %1 : tensor<4x8x16xf32>
1960+
}

0 commit comments

Comments
 (0)