-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][python] generate value builders #68308
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
a27d4bb
to
032ec20
Compare
✅ With the latest revision this PR passed the Python code formatter. |
5fa1dfe
to
5a9046f
Compare
@llvm/pr-subscribers-llvm-support @llvm/pr-subscribers-llvm-adt ChangesThis PR adds the additional generation of what I'm calling "value builders" (a term I'm not married to) that look like this: def empty(sizes, element_type, *, loc=None, ip=None):
return get_result_or_results(tensor.EmptyOp(sizes=sizes, element_type=element_type, loc=loc, ip=ip)) which instantiates a What's the point of adding these when i32 = IntegerType.get_signless(32)
...
ten1 = tensor.empty((10, 10), i32)
ten2 = tensor.empty((10, 10), i32)
ten3 = arith.addi(ten1, ten2) Note, in order to present a "pythonic" API and enable "pythonic" eDSLs, the generated identifiers (op names and operand names) are snake case instead of camel case and thus Patch is 26.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68308.diff 7 Files Affected:
diff --git a/llvm/lib/Support/StringExtras.cpp b/llvm/lib/Support/StringExtras.cpp
index 5683d7005584eb2..fd5a34fb3d6e82c 100644
--- a/llvm/lib/Support/StringExtras.cpp
+++ b/llvm/lib/Support/StringExtras.cpp
@@ -12,6 +12,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"
#include <cctype>
@@ -96,18 +97,13 @@ std::string llvm::convertToSnakeFromCamelCase(StringRef input) {
if (input.empty())
return "";
- std::string snakeCase;
- snakeCase.reserve(input.size());
- for (char c : input) {
- if (!std::isupper(c)) {
- snakeCase.push_back(c);
- continue;
- }
-
- if (!snakeCase.empty() && snakeCase.back() != '_')
- snakeCase.push_back('_');
- snakeCase.push_back(llvm::toLower(c));
+ std::string snakeCase = input.str();
+ for (int i = 0; i < 10; ++i) {
+ snakeCase = llvm::Regex("([A-Z]+)([A-Z][a-z])").sub("\\1_\\2", snakeCase);
+ snakeCase = llvm::Regex("([a-z0-9])([A-Z])").sub("\\1_\\2", snakeCase);
}
+ std::transform(snakeCase.begin(), snakeCase.end(), snakeCase.begin(),
+ [](unsigned char c) { return std::tolower(c); });
return snakeCase;
}
diff --git a/llvm/unittests/ADT/StringExtrasTest.cpp b/llvm/unittests/ADT/StringExtrasTest.cpp
index 3f69c91b270a355..fab562f1ed0d594 100644
--- a/llvm/unittests/ADT/StringExtrasTest.cpp
+++ b/llvm/unittests/ADT/StringExtrasTest.cpp
@@ -184,6 +184,11 @@ TEST(StringExtrasTest, ConvertToSnakeFromCamelCase) {
testConvertToSnakeCase("OpName", "op_name");
testConvertToSnakeCase("opName", "op_name");
+ testConvertToSnakeCase("OPName", "op_name");
+ testConvertToSnakeCase("opNAME", "op_name");
+ testConvertToSnakeCase("opNAMe", "op_na_me");
+ testConvertToSnakeCase("opnameE", "opname_e");
+ testConvertToSnakeCase("OPNameOPName", "op_name_op_name");
testConvertToSnakeCase("_OpName", "_op_name");
testConvertToSnakeCase("Op_Name", "op_name");
testConvertToSnakeCase("", "");
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 7655629a5542520..895c3228139b392 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -13,6 +13,7 @@
"get_default_loc_context",
"get_op_result_or_value",
"get_op_results_or_values",
+ "get_op_result_or_op_results",
"segmented_accessor",
]
@@ -167,3 +168,17 @@ def get_op_results_or_values(
return arg.results
else:
return [get_op_result_or_value(element) for element in arg]
+
+
+def get_op_result_or_op_results(
+ op: _Union[_cext.ir.OpView, _cext.ir.Operation],
+) -> _Union[_cext.ir.Operation, _cext.ir.OpResult, _Sequence[_cext.ir.OpResult]]:
+ if isinstance(op, _cext.ir.OpView):
+ op = op.operation
+ return (
+ list(get_op_results_or_values(op))
+ if len(op.results) > 1
+ else get_op_result_or_value(op)
+ if len(op.results) > 0
+ else op
+ )
diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py
index 4b0a31327abb0ee..35bd247a0a1e7f7 100644
--- a/mlir/python/mlir/dialects/_scf_ops_ext.py
+++ b/mlir/python/mlir/dialects/_scf_ops_ext.py
@@ -7,11 +7,13 @@
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
-from typing import Any, Optional, Sequence, Union
+from typing import Optional, Sequence, Union
from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
)
+from .arith import constant
+from . import scf
class ForOp:
@@ -25,7 +27,7 @@ def __init__(
iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
*,
loc=None,
- ip=None
+ ip=None,
):
"""Creates an SCF `for` operation.
@@ -104,3 +106,37 @@ def then_block(self):
def else_block(self):
"""Returns the else block of the if operation."""
return self.regions[1].blocks[0]
+
+
+def range_(
+ start,
+ stop=None,
+ step=None,
+ iter_args: Optional[Sequence[Value]] = None,
+ *,
+ loc=None,
+ ip=None,
+):
+ if step is None:
+ step = 1
+ if stop is None:
+ stop = start
+ start = 0
+ params = [start, stop, step]
+ for i, p in enumerate(params):
+ if isinstance(p, int):
+ p = constant(p)
+ elif isinstance(p, float):
+ raise ValueError(f"{p=} must be int.")
+ params[i] = p
+
+ for_op = scf.ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
+ iv = for_op.induction_variable
+ iter_args = tuple(for_op.inner_iter_args)
+ with InsertionPoint(for_op.body):
+ if len(iter_args) > 1:
+ yield iv, iter_args
+ elif len(iter_args) == 1:
+ yield iv, iter_args[0]
+ else:
+ yield iv
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index a131209fa45cb6c..be2cfc8adc78179 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -60,6 +60,9 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
Optional<AnyType>:$variadic2);
}
+// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results"
@@ -104,6 +107,9 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
Variadic<AnyType>:$variadic2);
}
+// CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK: return _get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttributedOp(_ods_ir.OpView):
@@ -151,6 +157,9 @@ def AttributedOp : TestOp<"attributed_op"> {
UnitAttr:$unitAttr, I32Attr:$in);
}
+// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.attributed_op_with_operands"
@@ -184,6 +193,9 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr<F32Attr>:$is);
}
+// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.default_valued_attrs"
@@ -205,6 +217,9 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
let results = (outs);
}
+// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))
+
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
// CHECK: def __init__(self, type_, *, loc=None, ip=None):
@@ -220,6 +235,9 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
let results = (outs AnyType:$res, AnyType);
}
+// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK: return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))
+
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
// CHECK: def __init__(self, res, _gen_res_1, type_, *, loc=None, ip=None):
@@ -227,6 +245,9 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir
let results = (outs AnyType:$res, Variadic<AnyType>);
}
+// CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK: return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class EmptyOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.empty"
@@ -241,6 +262,8 @@ def EmptyOp : TestOp<"empty">;
// CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))
+// CHECK: def empty(*, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
@@ -252,6 +275,9 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
let results = (outs I32:$i32, F32:$f32);
}
+// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK: return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))
+
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
// CHECK: def __init__(self, *, loc=None, ip=None):
@@ -262,6 +288,9 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
let results = (outs AnyType, AnyType, AnyType);
}
+// CHECK: def infer_result_types_op(*, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK: return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class MissingNamesOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.missing_names"
@@ -297,6 +326,9 @@ def MissingNamesOp : TestOp<"missing_names"> {
let results = (outs I32:$i32, AnyFloat, I64:$i64);
}
+// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK: return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.one_optional_operand"
@@ -323,9 +355,11 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
// CHECK: @builtins.property
// CHECK: def optional(self):
// CHECK: return None if len(self.operation.operands) < 2 else self.operation.operands[1]
-
}
+// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand"
@@ -355,6 +389,9 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
let arguments = (ins AnyType:$non_variadic, Variadic<AnyType>:$variadic);
}
+// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicResultOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_result"
@@ -385,6 +422,9 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
let results = (outs Variadic<AnyType>:$variadic, AnyType:$non_variadic);
}
+// CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK: return _get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class PythonKeywordOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.python_keyword"
@@ -405,6 +445,10 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
// CHECK: return self.operation.operands[0]
let arguments = (ins AnyType:$in);
}
+
+// CHECK: def python_keyword(in_, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))
+
// CHECK-LABEL: OPERATION_NAME = "test.same_results"
def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
// CHECK: def __init__(self, in1, in2, *, loc=None, ip=None):
@@ -416,6 +460,9 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
let results = (outs AnyType:$res);
}
+// CHECK: def same_results(in1, in2, *, loc=None, ip=None) -> _ods_ir.OpResult:
+// CHECK: return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))
+
// CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
// CHECK: def __init__(self, res, in1, in2, *, loc=None, ip=None):
@@ -423,6 +470,9 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
let results = (outs Variadic<AnyType>:$res);
}
+// CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None) -> _ods_ir.OpResult:
+// CHECK: return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SameVariadicOperandSizeOp(_ods_ir.OpView):
@@ -447,6 +497,9 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
Variadic<AnyType>:$variadic2);
}
+// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.same_variadic_result"
@@ -470,6 +523,9 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
Variadic<AnyType>:$variadic2);
}
+// CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK: return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SimpleOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.simple"
@@ -507,6 +563,9 @@ def SimpleOp : TestOp<"simple"> {
let results = (outs I64:$i64, AnyFloat:$f64);
}
+// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK: return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
+
// CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
@@ -531,6 +590,9 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
// CHECK: return self.regions[2:]
}
+// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+
// CHECK: class VariadicRegionOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
def VariadicRegionOp : TestOp<"variadic_region"> {
@@ -551,6 +613,9 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
// CHECK: return self.regions[0:]
}
+// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class WithSuccessorsOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.with_successors"
@@ -562,3 +627,6 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
let successors = (successor AnySuccessor:$successor,
VariadicSuccessor<AnySuccessor>:$successors);
}
+
+// CHECK: def with_successors(successor, successors, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK: return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
\ No newline at end of file
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index 8cb55fdf6a1eb3b..843b87b66871761 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -4,7 +4,7 @@
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import scf
-from mlir.dialects import builtin
+from mlir.dialects._scf_ops_ext import range_
def constructAndPrintInModule(f):
@@ -54,6 +54,27 @@ def induction_var(lb, ub, step):
# CHECK: scf.yield %[[IV]]
+# CHECK-LABEL: TEST: testForSugar
+@constructAndPrintInModule
+def testForSugar():
+ index_type = IndexType.get()
+
+ @func.FuncOp.from_py_func(index_type, index_type, index_type)
+ def range_loop(lb, ub, step):
+ for i in range_(lb, ub, step):
+ add = arith.addi(i, i)
+ scf.yield_([])
+ return
+
+
+# CHECK: func.func @range_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) {
+# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+# CHECK: %0 = arith.addi %[[IV]], %[[IV]] : index
+# CHECK: }
+# CHECK: return
+# CHECK: }
+
+
@constructAndPrintInModule
def testOpsAsArguments():
index_type = IndexType.get()
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0b5df7ab70dddb2..1357c84099ccd5a 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -30,7 +30,7 @@ constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
from ._ods_common import _cext as _ods_cext
-from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
_ods_ir = _ods_cext.ir
try:
@@ -39,6 +39,7 @@ except ImportError:
_ods_ext_module = None
import builtins
+from typing import Sequence as _Sequence, Union as _Union
)Py";
@@ -260,11 +261,16 @@ constexpr const char...
[truncated]
|
Could you factor out the libSupport/string manipulation part - smaller/independent patches are easier to manage/review/etc |
09b3e29
to
b3a1e3d
Compare
b3a1e3d
to
cfc3131
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks reasonable, we can always adjust :)
Stray word
This PR adds the additional generation of what I'm calling "value builders" (a term I'm not married to) that look like this:
which instantiates a
tensor.EmptyOp
and then immediately grabs the result (OpResult
) and then returns that instead of a handle to the op.What's the point of adding these when
EmptyOp.result
already exists? My claim/feeling/intuition is that eDSL users are more comfortable with a value centric programming model (i.e., passing values as operands) as opposed to an operator instantiation programming model. Thus this change enables (or at least goes towards) the bindings supporting such a user and use case. For example,Note, in order to present a "pythonic" API and enable "pythonic" eDSLs, the generated identifiers (op names and operand names) are snake case instead of camel case and thus
llvm::convertToSnakeFromCamelCase
needed a small fix. Thus this PR is stacked on top of #68375.In addition, as a kind of victory lap, this PR adds a "rangefor" that looks and acts exactly like python's
range
but emitsscf.for
.