Skip to content

[mlir][python] generate value builders #68308

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 9, 2023

Conversation

makslevental
Copy link
Contributor

@makslevental makslevental commented Oct 5, 2023

This PR adds the additional generation of what I'm calling "value builders" (a term I'm not married to) that look like this:

def empty(sizes, element_type, *, loc=None, ip=None):
    return get_result_or_results(tensor.EmptyOp(sizes=sizes, element_type=element_type, loc=loc, ip=ip))

which instantiates a tensor.EmptyOp and then immediately grabs the result (OpResult) and then returns that instead of a handle to the op.

What's the point of adding these when EmptyOp.result already exists? My claim/feeling/intuition is that eDSL users are more comfortable with a value centric programming model (i.e., passing values as operands) as opposed to an operator instantiation programming model. Thus this change enables (or at least goes towards) the bindings supporting such a user and use case. For example,

i32 = IntegerType.get_signless(32)
...
ten1 = tensor.empty((10, 10), i32)
ten2 = tensor.empty((10, 10), i32)
ten3 = arith.addi(ten1, ten2)

Note, in order to present a "pythonic" API and enable "pythonic" eDSLs, the generated identifiers (op names and operand names) are snake case instead of camel case and thus llvm::convertToSnakeFromCamelCase needed a small fix. Thus this PR is stacked on top of #68375.

In addition, as a kind of victory lap, this PR adds a "rangefor" that looks and acts exactly like python's range but emits scf.for.

@makslevental makslevental force-pushed the value_builders branch 3 times, most recently from a27d4bb to 032ec20 Compare October 5, 2023 16:17
@github-actions
Copy link

github-actions bot commented Oct 5, 2023

✅ With the latest revision this PR passed the Python code formatter.

@makslevental makslevental force-pushed the value_builders branch 4 times, most recently from 5fa1dfe to 5a9046f Compare October 5, 2023 17:35
@makslevental makslevental marked this pull request as ready for review October 5, 2023 18:20
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:python MLIR Python bindings mlir llvm:support llvm:adt labels Oct 5, 2023
@llvmbot
Copy link
Member

llvmbot commented Oct 5, 2023

@llvm/pr-subscribers-llvm-support
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-llvm-adt

Changes

This PR adds the additional generation of what I'm calling "value builders" (a term I'm not married to) that look like this:

def empty(sizes, element_type, *, loc=None, ip=None):
    return get_result_or_results(tensor.EmptyOp(sizes=sizes, element_type=element_type, loc=loc, ip=ip))

which instantiates a tensor.EmptyOp and then immediately grabs the result (OpResult) and then returns that instead of a handle to the op.

What's the point of adding these when EmptyOp.result already exists? My claim/feeling/intuition is that eDSL users are more comfortable with a value centric programming model (i.e., passing values as operands) as opposed to an operator instantiation programming model. Thus this change enables (or at least goes towards) the bindings supporting such a user and use case. For example,

i32 = IntegerType.get_signless(32)
...
ten1 = tensor.empty((10, 10), i32)
ten2 = tensor.empty((10, 10), i32)
ten3 = arith.addi(ten1, ten2)

Note, in order to present a "pythonic" API and enable "pythonic" eDSLs, the generated identifiers (op names and operand names) are snake case instead of camel case and thus llvm::convertToSnakeFromCamelCase needed a small fix; currently runs of caps aren't handled correctly so e.g. something like Intel_OCL_BI is snake cased to intel_o_c_l_b_i (previously discussed on this phabricator patch). The easiest way to fix was to use regexes instead of the existing loop. Full disclosure the regexes were pulled from inflection but it should be pretty clear they're correct.


Patch is 26.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68308.diff

7 Files Affected:

  • (modified) llvm/lib/Support/StringExtras.cpp (+7-11)
  • (modified) llvm/unittests/ADT/StringExtrasTest.cpp (+5)
  • (modified) mlir/python/mlir/dialects/_ods_common.py (+15)
  • (modified) mlir/python/mlir/dialects/_scf_ops_ext.py (+38-2)
  • (modified) mlir/test/mlir-tblgen/op-python-bindings.td (+69-1)
  • (modified) mlir/test/python/dialects/scf.py (+22-1)
  • (modified) mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp (+63-12)
diff --git a/llvm/lib/Support/StringExtras.cpp b/llvm/lib/Support/StringExtras.cpp
index 5683d7005584eb2..fd5a34fb3d6e82c 100644
--- a/llvm/lib/Support/StringExtras.cpp
+++ b/llvm/lib/Support/StringExtras.cpp
@@ -12,6 +12,7 @@
 
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Regex.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cctype>
 
@@ -96,18 +97,13 @@ std::string llvm::convertToSnakeFromCamelCase(StringRef input) {
   if (input.empty())
     return "";
 
-  std::string snakeCase;
-  snakeCase.reserve(input.size());
-  for (char c : input) {
-    if (!std::isupper(c)) {
-      snakeCase.push_back(c);
-      continue;
-    }
-
-    if (!snakeCase.empty() && snakeCase.back() != '_')
-      snakeCase.push_back('_');
-    snakeCase.push_back(llvm::toLower(c));
+  std::string snakeCase = input.str();
+  for (int i = 0; i < 10; ++i) {
+    snakeCase = llvm::Regex("([A-Z]+)([A-Z][a-z])").sub("\\1_\\2", snakeCase);
+    snakeCase = llvm::Regex("([a-z0-9])([A-Z])").sub("\\1_\\2", snakeCase);
   }
+  std::transform(snakeCase.begin(), snakeCase.end(), snakeCase.begin(),
+                 [](unsigned char c) { return std::tolower(c); });
   return snakeCase;
 }
 
diff --git a/llvm/unittests/ADT/StringExtrasTest.cpp b/llvm/unittests/ADT/StringExtrasTest.cpp
index 3f69c91b270a355..fab562f1ed0d594 100644
--- a/llvm/unittests/ADT/StringExtrasTest.cpp
+++ b/llvm/unittests/ADT/StringExtrasTest.cpp
@@ -184,6 +184,11 @@ TEST(StringExtrasTest, ConvertToSnakeFromCamelCase) {
 
   testConvertToSnakeCase("OpName", "op_name");
   testConvertToSnakeCase("opName", "op_name");
+  testConvertToSnakeCase("OPName", "op_name");
+  testConvertToSnakeCase("opNAME", "op_name");
+  testConvertToSnakeCase("opNAMe", "op_na_me");
+  testConvertToSnakeCase("opnameE", "opname_e");
+  testConvertToSnakeCase("OPNameOPName", "op_name_op_name");
   testConvertToSnakeCase("_OpName", "_op_name");
   testConvertToSnakeCase("Op_Name", "op_name");
   testConvertToSnakeCase("", "");
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 7655629a5542520..895c3228139b392 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -13,6 +13,7 @@
     "get_default_loc_context",
     "get_op_result_or_value",
     "get_op_results_or_values",
+    "get_op_result_or_op_results",
     "segmented_accessor",
 ]
 
@@ -167,3 +168,17 @@ def get_op_results_or_values(
         return arg.results
     else:
         return [get_op_result_or_value(element) for element in arg]
+
+
+def get_op_result_or_op_results(
+    op: _Union[_cext.ir.OpView, _cext.ir.Operation],
+) -> _Union[_cext.ir.Operation, _cext.ir.OpResult, _Sequence[_cext.ir.OpResult]]:
+    if isinstance(op, _cext.ir.OpView):
+        op = op.operation
+    return (
+        list(get_op_results_or_values(op))
+        if len(op.results) > 1
+        else get_op_result_or_value(op)
+        if len(op.results) > 0
+        else op
+    )
diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py
index 4b0a31327abb0ee..35bd247a0a1e7f7 100644
--- a/mlir/python/mlir/dialects/_scf_ops_ext.py
+++ b/mlir/python/mlir/dialects/_scf_ops_ext.py
@@ -7,11 +7,13 @@
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
-from typing import Any, Optional, Sequence, Union
+from typing import Optional, Sequence, Union
 from ._ods_common import (
     get_op_result_or_value as _get_op_result_or_value,
     get_op_results_or_values as _get_op_results_or_values,
 )
+from .arith import constant
+from . import scf
 
 
 class ForOp:
@@ -25,7 +27,7 @@ def __init__(
         iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
         *,
         loc=None,
-        ip=None
+        ip=None,
     ):
         """Creates an SCF `for` operation.
 
@@ -104,3 +106,37 @@ def then_block(self):
     def else_block(self):
         """Returns the else block of the if operation."""
         return self.regions[1].blocks[0]
+
+
+def range_(
+    start,
+    stop=None,
+    step=None,
+    iter_args: Optional[Sequence[Value]] = None,
+    *,
+    loc=None,
+    ip=None,
+):
+    if step is None:
+        step = 1
+    if stop is None:
+        stop = start
+        start = 0
+    params = [start, stop, step]
+    for i, p in enumerate(params):
+        if isinstance(p, int):
+            p = constant(p)
+        elif isinstance(p, float):
+            raise ValueError(f"{p=} must be int.")
+        params[i] = p
+
+    for_op = scf.ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
+    iv = for_op.induction_variable
+    iter_args = tuple(for_op.inner_iter_args)
+    with InsertionPoint(for_op.body):
+        if len(iter_args) > 1:
+            yield iv, iter_args
+        elif len(iter_args) == 1:
+            yield iv, iter_args[0]
+        else:
+            yield iv
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index a131209fa45cb6c..be2cfc8adc78179 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -60,6 +60,9 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
                    Optional<AnyType>:$variadic2);
 }
 
+// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results"
@@ -104,6 +107,9 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
                  Variadic<AnyType>:$variadic2);
 }
 
+// CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   return _get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttributedOp(_ods_ir.OpView):
@@ -151,6 +157,9 @@ def AttributedOp : TestOp<"attributed_op"> {
                    UnitAttr:$unitAttr, I32Attr:$in);
 }
 
+// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:     return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.attributed_op_with_operands"
@@ -184,6 +193,9 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
   let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr<F32Attr>:$is);
 }
 
+// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.default_valued_attrs"
@@ -205,6 +217,9 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
   let results = (outs);
 }
 
+// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))
+
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
 def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
   // CHECK: def __init__(self, type_, *, loc=None, ip=None):
@@ -220,6 +235,9 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
   let results = (outs AnyType:$res, AnyType);
 }
 
+// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))
+
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
 def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
   // CHECK: def __init__(self, res, _gen_res_1, type_, *, loc=None, ip=None):
@@ -227,6 +245,9 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir
   let results = (outs AnyType:$res, Variadic<AnyType>);
 }
 
+// CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class EmptyOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.empty"
@@ -241,6 +262,8 @@ def EmptyOp : TestOp<"empty">;
   // CHECK:     attributes=attributes, results=results, operands=operands,
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
 
+// CHECK: def empty(*, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))
 
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
 def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
@@ -252,6 +275,9 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
   let results = (outs I32:$i32, F32:$f32);
 }
 
+// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))
+
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
 def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
   // CHECK:  def __init__(self, *, loc=None, ip=None):
@@ -262,6 +288,9 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
   let results = (outs AnyType, AnyType, AnyType);
 }
 
+// CHECK: def infer_result_types_op(*, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class MissingNamesOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.missing_names"
@@ -297,6 +326,9 @@ def MissingNamesOp : TestOp<"missing_names"> {
   let results = (outs I32:$i32, AnyFloat, I64:$i64);
 }
 
+// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.one_optional_operand"
@@ -323,9 +355,11 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
   // CHECK: @builtins.property
   // CHECK: def optional(self):
   // CHECK:   return None if len(self.operation.operands) < 2 else self.operation.operands[1]
-
 }
 
+// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand"
@@ -355,6 +389,9 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
   let arguments = (ins AnyType:$non_variadic, Variadic<AnyType>:$variadic);
 }
 
+// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneVariadicResultOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.one_variadic_result"
@@ -385,6 +422,9 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
   let results = (outs Variadic<AnyType>:$variadic, AnyType:$non_variadic);
 }
 
+// CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   return _get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class PythonKeywordOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.python_keyword"
@@ -405,6 +445,10 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
   // CHECK:   return self.operation.operands[0]
   let arguments = (ins AnyType:$in);
 }
+
+// CHECK: def python_keyword(in_, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))
+
 // CHECK-LABEL: OPERATION_NAME = "test.same_results"
 def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
   // CHECK: def __init__(self, in1, in2, *, loc=None, ip=None):
@@ -416,6 +460,9 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
   let results = (outs AnyType:$res);
 }
 
+// CHECK: def same_results(in1, in2, *, loc=None, ip=None) -> _ods_ir.OpResult:
+// CHECK:   return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))
+
 // CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
 def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
   // CHECK: def __init__(self, res, in1, in2, *, loc=None, ip=None):
@@ -423,6 +470,9 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
   let results = (outs Variadic<AnyType>:$res);
 }
 
+// CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None) -> _ods_ir.OpResult:
+// CHECK:   return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip))
+
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class SameVariadicOperandSizeOp(_ods_ir.OpView):
@@ -447,6 +497,9 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
                    Variadic<AnyType>:$variadic2);
 }
 
+// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.same_variadic_result"
@@ -470,6 +523,9 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
                  Variadic<AnyType>:$variadic2);
 }
 
+// CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class SimpleOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.simple"
@@ -507,6 +563,9 @@ def SimpleOp : TestOp<"simple"> {
   let results = (outs I64:$i64, AnyFloat:$f64);
 }
 
+// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None) -> _Sequence[_ods_ir.OpResult]:
+// CHECK:   return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
+
 // CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
 def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
@@ -531,6 +590,9 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
   // CHECK:    return self.regions[2:]
 }
 
+// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+
 // CHECK: class VariadicRegionOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
 def VariadicRegionOp : TestOp<"variadic_region"> {
@@ -551,6 +613,9 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
   // CHECK:    return self.regions[0:]
 }
 
+// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class WithSuccessorsOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.with_successors"
@@ -562,3 +627,6 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
   let successors = (successor AnySuccessor:$successor,
                               VariadicSuccessor<AnySuccessor>:$successors);
 }
+
+// CHECK: def with_successors(successor, successors, *, loc=None, ip=None) -> _ods_ir.Operation:
+// CHECK:   return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
\ No newline at end of file
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index 8cb55fdf6a1eb3b..843b87b66871761 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -4,7 +4,7 @@
 from mlir.dialects import arith
 from mlir.dialects import func
 from mlir.dialects import scf
-from mlir.dialects import builtin
+from mlir.dialects._scf_ops_ext import range_
 
 
 def constructAndPrintInModule(f):
@@ -54,6 +54,27 @@ def induction_var(lb, ub, step):
 # CHECK: scf.yield %[[IV]]
 
 
+# CHECK-LABEL: TEST: testForSugar
+@constructAndPrintInModule
+def testForSugar():
+    index_type = IndexType.get()
+
+    @func.FuncOp.from_py_func(index_type, index_type, index_type)
+    def range_loop(lb, ub, step):
+        for i in range_(lb, ub, step):
+            add = arith.addi(i, i)
+            scf.yield_([])
+        return
+
+
+# CHECK: func.func @range_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) {
+# CHECK:   scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+# CHECK:     %0 = arith.addi %[[IV]], %[[IV]] : index
+# CHECK:   }
+# CHECK:   return
+# CHECK: }
+
+
 @constructAndPrintInModule
 def testOpsAsArguments():
     index_type = IndexType.get()
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0b5df7ab70dddb2..1357c84099ccd5a 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -30,7 +30,7 @@ constexpr const char *fileHeader = R"Py(
 # Autogenerated by mlir-tblgen; don't manually edit.
 
 from ._ods_common import _cext as _ods_cext
-from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
 _ods_ir = _ods_cext.ir
 
 try:
@@ -39,6 +39,7 @@ except ImportError:
   _ods_ext_module = None
 
 import builtins
+from typing import Sequence as _Sequence, Union as _Union
 
 )Py";
 
@@ -260,11 +261,16 @@ constexpr const char...
[truncated]

@dwblaikie
Copy link
Collaborator

Could you factor out the libSupport/string manipulation part - smaller/independent patches are easier to manage/review/etc

@makslevental makslevental force-pushed the value_builders branch 3 times, most recently from 09b3e29 to b3a1e3d Compare October 6, 2023 20:17
Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks reasonable, we can always adjust :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants