Skip to content

Commit 27c6d55

Browse files
authored
[mlir][python] generate value builders (#68308)
This PR adds the additional generation of what I'm calling "value builders" (a term I'm not married to) that look like this: ```python def empty(sizes, element_type, *, loc=None, ip=None): return get_result_or_results(tensor.EmptyOp(sizes=sizes, element_type=element_type, loc=loc, ip=ip)) ``` which instantiates a `tensor.EmptyOp` and then immediately grabs the result (`OpResult`) and then returns that *instead of a handle to the op*. What's the point of adding these when `EmptyOp.result` already exists? My claim/feeling/intuition is that eDSL users are more comfortable with a value centric programming model (i.e., passing values as operands) as opposed to an operator instantiation programming model. Thus this change enables (or at least goes towards) the bindings supporting such a user and use case. For example, ```python i32 = IntegerType.get_signless(32) ... ten1 = tensor.empty((10, 10), i32) ten2 = tensor.empty((10, 10), i32) ten3 = arith.addi(ten1, ten2) ``` Note, in order to present a "pythonic" API and enable "pythonic" eDSLs, the generated identifiers (op names and operand names) are snake case instead of camel case and thus `llvm::convertToSnakeFromCamelCase` needed a small fix. Thus this PR is stacked on top of #68375. In addition, as a kind of victory lap, this PR adds a "rangefor" that looks and acts exactly like python's `range` but emits `scf.for`.
1 parent 8dd9615 commit 27c6d55

File tree

6 files changed

+209
-17
lines changed

6 files changed

+209
-17
lines changed

mlir/python/mlir/dialects/_ods_common.py

+15
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"get_default_loc_context",
1414
"get_op_result_or_value",
1515
"get_op_results_or_values",
16+
"get_op_result_or_op_results",
1617
"segmented_accessor",
1718
]
1819

@@ -167,3 +168,17 @@ def get_op_results_or_values(
167168
return arg.results
168169
else:
169170
return [get_op_result_or_value(element) for element in arg]
171+
172+
173+
def get_op_result_or_op_results(
174+
op: _Union[_cext.ir.OpView, _cext.ir.Operation],
175+
) -> _Union[_cext.ir.Operation, _cext.ir.OpResult, _Sequence[_cext.ir.OpResult]]:
176+
if isinstance(op, _cext.ir.OpView):
177+
op = op.operation
178+
return (
179+
list(get_op_results_or_values(op))
180+
if len(op.results) > 1
181+
else get_op_result_or_value(op)
182+
if len(op.results) > 0
183+
else op
184+
)

mlir/python/mlir/dialects/_scf_ops_ext.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
except ImportError as e:
88
raise RuntimeError("Error loading imports from extension module") from e
99

10-
from typing import Any, Optional, Sequence, Union
10+
from typing import Optional, Sequence, Union
11+
1112
from ._ods_common import (
1213
get_op_result_or_value as _get_op_result_or_value,
1314
get_op_results_or_values as _get_op_results_or_values,
@@ -25,7 +26,7 @@ def __init__(
2526
iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
2627
*,
2728
loc=None,
28-
ip=None
29+
ip=None,
2930
):
3031
"""Creates an SCF `for` operation.
3132

mlir/python/mlir/dialects/scf.py

+38
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,42 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
from typing import Optional, Sequence
6+
57
from ._scf_ops_gen import *
8+
from .arith import constant
9+
from ..ir import *
10+
11+
12+
def for_(
13+
start,
14+
stop=None,
15+
step=None,
16+
iter_args: Optional[Sequence[Value]] = None,
17+
*,
18+
loc=None,
19+
ip=None,
20+
):
21+
if step is None:
22+
step = 1
23+
if stop is None:
24+
stop = start
25+
start = 0
26+
params = [start, stop, step]
27+
for i, p in enumerate(params):
28+
if isinstance(p, int):
29+
p = constant(p)
30+
elif isinstance(p, float):
31+
raise ValueError(f"{p=} must be int.")
32+
params[i] = p
33+
34+
for_op = ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
35+
iv = for_op.induction_variable
36+
iter_args = tuple(for_op.inner_iter_args)
37+
with InsertionPoint(for_op.body):
38+
if len(iter_args) > 1:
39+
yield iv, iter_args
40+
elif len(iter_args) == 1:
41+
yield iv, iter_args[0]
42+
else:
43+
yield iv

mlir/test/mlir-tblgen/op-python-bindings.td

+69-1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
6060
Optional<AnyType>:$variadic2);
6161
}
6262

63+
// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None)
64+
// CHECK: return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
65+
6366
// CHECK: @_ods_cext.register_operation(_Dialect)
6467
// CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
6568
// CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results"
@@ -104,6 +107,9 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
104107
Variadic<AnyType>:$variadic2);
105108
}
106109

110+
// CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
111+
// CHECK: return _get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
112+
107113

108114
// CHECK: @_ods_cext.register_operation(_Dialect)
109115
// CHECK: class AttributedOp(_ods_ir.OpView):
@@ -151,6 +157,9 @@ def AttributedOp : TestOp<"attributed_op"> {
151157
UnitAttr:$unitAttr, I32Attr:$in);
152158
}
153159

160+
// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None)
161+
// CHECK: return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))
162+
154163
// CHECK: @_ods_cext.register_operation(_Dialect)
155164
// CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
156165
// CHECK-LABEL: OPERATION_NAME = "test.attributed_op_with_operands"
@@ -184,6 +193,9 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
184193
let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr<F32Attr>:$is);
185194
}
186195

196+
// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None)
197+
// CHECK: return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))
198+
187199
// CHECK: @_ods_cext.register_operation(_Dialect)
188200
// CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
189201
// CHECK-LABEL: OPERATION_NAME = "test.default_valued_attrs"
@@ -205,6 +217,9 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
205217
let results = (outs);
206218
}
207219

220+
// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None)
221+
// CHECK: return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))
222+
208223
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
209224
def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
210225
// CHECK: def __init__(self, type_, *, loc=None, ip=None):
@@ -220,13 +235,19 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
220235
let results = (outs AnyType:$res, AnyType);
221236
}
222237

238+
// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
239+
// CHECK: return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))
240+
223241
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
224242
def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
225243
// CHECK: def __init__(self, res, _gen_res_1, type_, *, loc=None, ip=None):
226244
let arguments = (ins TypeAttr:$type);
227245
let results = (outs AnyType:$res, Variadic<AnyType>);
228246
}
229247

248+
// CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None)
249+
// CHECK: return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))
250+
230251
// CHECK: @_ods_cext.register_operation(_Dialect)
231252
// CHECK: class EmptyOp(_ods_ir.OpView):
232253
// CHECK-LABEL: OPERATION_NAME = "test.empty"
@@ -241,6 +262,8 @@ def EmptyOp : TestOp<"empty">;
241262
// CHECK: attributes=attributes, results=results, operands=operands,
242263
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))
243264

265+
// CHECK: def empty(*, loc=None, ip=None)
266+
// CHECK: return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))
244267

245268
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
246269
def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
@@ -252,6 +275,9 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
252275
let results = (outs I32:$i32, F32:$f32);
253276
}
254277

278+
// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
279+
// CHECK: return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))
280+
255281
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
256282
def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
257283
// CHECK: def __init__(self, *, loc=None, ip=None):
@@ -262,6 +288,9 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
262288
let results = (outs AnyType, AnyType, AnyType);
263289
}
264290

291+
// CHECK: def infer_result_types_op(*, loc=None, ip=None)
292+
// CHECK: return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))
293+
265294
// CHECK: @_ods_cext.register_operation(_Dialect)
266295
// CHECK: class MissingNamesOp(_ods_ir.OpView):
267296
// CHECK-LABEL: OPERATION_NAME = "test.missing_names"
@@ -297,6 +326,9 @@ def MissingNamesOp : TestOp<"missing_names"> {
297326
let results = (outs I32:$i32, AnyFloat, I64:$i64);
298327
}
299328

329+
// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None)
330+
// CHECK: return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
331+
300332
// CHECK: @_ods_cext.register_operation(_Dialect)
301333
// CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
302334
// CHECK-LABEL: OPERATION_NAME = "test.one_optional_operand"
@@ -323,9 +355,11 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
323355
// CHECK: @builtins.property
324356
// CHECK: def optional(self):
325357
// CHECK: return None if len(self.operation.operands) < 2 else self.operation.operands[1]
326-
327358
}
328359

360+
// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None)
361+
// CHECK: return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))
362+
329363
// CHECK: @_ods_cext.register_operation(_Dialect)
330364
// CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
331365
// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand"
@@ -355,6 +389,9 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
355389
let arguments = (ins AnyType:$non_variadic, Variadic<AnyType>:$variadic);
356390
}
357391

392+
// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None)
393+
// CHECK: return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))
394+
358395
// CHECK: @_ods_cext.register_operation(_Dialect)
359396
// CHECK: class OneVariadicResultOp(_ods_ir.OpView):
360397
// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_result"
@@ -385,6 +422,9 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
385422
let results = (outs Variadic<AnyType>:$variadic, AnyType:$non_variadic);
386423
}
387424

425+
// CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None)
426+
// CHECK: return _get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))
427+
388428
// CHECK: @_ods_cext.register_operation(_Dialect)
389429
// CHECK: class PythonKeywordOp(_ods_ir.OpView):
390430
// CHECK-LABEL: OPERATION_NAME = "test.python_keyword"
@@ -405,6 +445,10 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
405445
// CHECK: return self.operation.operands[0]
406446
let arguments = (ins AnyType:$in);
407447
}
448+
449+
// CHECK: def python_keyword(in_, *, loc=None, ip=None)
450+
// CHECK: return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))
451+
408452
// CHECK-LABEL: OPERATION_NAME = "test.same_results"
409453
def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
410454
// CHECK: def __init__(self, in1, in2, *, loc=None, ip=None):
@@ -416,13 +460,19 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
416460
let results = (outs AnyType:$res);
417461
}
418462

463+
// CHECK: def same_results(in1, in2, *, loc=None, ip=None)
464+
// CHECK: return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))
465+
419466
// CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
420467
def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
421468
// CHECK: def __init__(self, res, in1, in2, *, loc=None, ip=None):
422469
let arguments = (ins AnyType:$in1, AnyType:$in2);
423470
let results = (outs Variadic<AnyType>:$res);
424471
}
425472

473+
// CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None)
474+
// CHECK: return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip))
475+
426476

427477
// CHECK: @_ods_cext.register_operation(_Dialect)
428478
// CHECK: class SameVariadicOperandSizeOp(_ods_ir.OpView):
@@ -447,6 +497,9 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
447497
Variadic<AnyType>:$variadic2);
448498
}
449499

500+
// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
501+
// CHECK: return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
502+
450503
// CHECK: @_ods_cext.register_operation(_Dialect)
451504
// CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
452505
// CHECK-LABEL: OPERATION_NAME = "test.same_variadic_result"
@@ -470,6 +523,9 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
470523
Variadic<AnyType>:$variadic2);
471524
}
472525

526+
// CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
527+
// CHECK: return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
528+
473529
// CHECK: @_ods_cext.register_operation(_Dialect)
474530
// CHECK: class SimpleOp(_ods_ir.OpView):
475531
// CHECK-LABEL: OPERATION_NAME = "test.simple"
@@ -507,6 +563,9 @@ def SimpleOp : TestOp<"simple"> {
507563
let results = (outs I64:$i64, AnyFloat:$f64);
508564
}
509565

566+
// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None)
567+
// CHECK: return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
568+
510569
// CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
511570
// CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
512571
def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
@@ -531,6 +590,9 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
531590
// CHECK: return self.regions[2:]
532591
}
533592

593+
// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None)
594+
// CHECK: return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
595+
534596
// CHECK: class VariadicRegionOp(_ods_ir.OpView):
535597
// CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
536598
def VariadicRegionOp : TestOp<"variadic_region"> {
@@ -551,6 +613,9 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
551613
// CHECK: return self.regions[0:]
552614
}
553615

616+
// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None)
617+
// CHECK: return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
618+
554619
// CHECK: @_ods_cext.register_operation(_Dialect)
555620
// CHECK: class WithSuccessorsOp(_ods_ir.OpView):
556621
// CHECK-LABEL: OPERATION_NAME = "test.with_successors"
@@ -562,3 +627,6 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
562627
let successors = (successor AnySuccessor:$successor,
563628
VariadicSuccessor<AnySuccessor>:$successors);
564629
}
630+
631+
// CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
632+
// CHECK: return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))

mlir/test/python/dialects/scf.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from mlir.dialects import arith
55
from mlir.dialects import func
66
from mlir.dialects import scf
7-
from mlir.dialects import builtin
87

98

109
def constructAndPrintInModule(f):
@@ -54,6 +53,28 @@ def induction_var(lb, ub, step):
5453
# CHECK: scf.yield %[[IV]]
5554

5655

56+
# CHECK-LABEL: TEST: testForSugar
57+
@constructAndPrintInModule
58+
def testForSugar():
59+
index_type = IndexType.get()
60+
range = scf.for_
61+
62+
@func.FuncOp.from_py_func(index_type, index_type, index_type)
63+
def range_loop(lb, ub, step):
64+
for i in range(lb, ub, step):
65+
add = arith.addi(i, i)
66+
scf.yield_([])
67+
return
68+
69+
70+
# CHECK: func.func @range_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) {
71+
# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
72+
# CHECK: %0 = arith.addi %[[IV]], %[[IV]] : index
73+
# CHECK: }
74+
# CHECK: return
75+
# CHECK: }
76+
77+
5778
@constructAndPrintInModule
5879
def testOpsAsArguments():
5980
index_type = IndexType.get()

0 commit comments

Comments
 (0)