Skip to content

Commit eb861ac

Browse files
authored
[mlir][python] Enable python bindings for Index dialect (llvm#85827)
This small patch enables python bindings for the index dialect. --------- Co-authored-by: Steven Varoumas <[email protected]>
1 parent d209d13 commit eb861ac

File tree

4 files changed

+264
-0
lines changed

4 files changed

+264
-0
lines changed

mlir/python/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,15 @@ declare_mlir_dialect_python_bindings(
108108
dialects/complex.py
109109
DIALECT_NAME complex)
110110

111+
declare_mlir_dialect_python_bindings(
112+
ADD_TO_PARENT MLIRPythonSources.Dialects
113+
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
114+
TD_FILE dialects/IndexOps.td
115+
SOURCES
116+
dialects/index.py
117+
DIALECT_NAME index
118+
GEN_ENUM_BINDINGS)
119+
111120
declare_mlir_dialect_python_bindings(
112121
ADD_TO_PARENT MLIRPythonSources.Dialects
113122
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"

mlir/python/mlir/dialects/IndexOps.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//===-- IndexOps.td - Entry point for Index bindings -----*- tablegen -*---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef PYTHON_BINDINGS_INDEX_OPS
10+
#define PYTHON_BINDINGS_INDEX_OPS
11+
12+
include "mlir/Dialect/Index/IR/IndexOps.td"
13+
14+
#endif

mlir/python/mlir/dialects/index.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from ._index_ops_gen import *
6+
from ._index_enum_gen import *
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
# RUN: %PYTHON %s | FileCheck %s
2+
3+
from mlir.ir import *
4+
from mlir.dialects import index, arith
5+
6+
7+
def run(f):
8+
print("\nTEST:", f.__name__)
9+
with Context() as ctx, Location.unknown():
10+
module = Module.create()
11+
with InsertionPoint(module.body):
12+
f(ctx)
13+
print(module)
14+
15+
16+
# CHECK-LABEL: TEST: testConstantOp
17+
@run
18+
def testConstantOp(ctx):
19+
a = index.ConstantOp(value=42)
20+
# CHECK: %{{.*}} = index.constant 42
21+
22+
23+
# CHECK-LABEL: TEST: testBoolConstantOp
24+
@run
25+
def testBoolConstantOp(ctx):
26+
a = index.BoolConstantOp(value=True)
27+
# CHECK: %{{.*}} = index.bool.constant true
28+
29+
30+
# CHECK-LABEL: TEST: testAndOp
31+
@run
32+
def testAndOp(ctx):
33+
a = index.ConstantOp(value=42)
34+
r = index.AndOp(a, a)
35+
# CHECK: %{{.*}} = index.and %{{.*}}, %{{.*}}
36+
37+
38+
# CHECK-LABEL: TEST: testOrOp
39+
@run
40+
def testOrOp(ctx):
41+
a = index.ConstantOp(value=42)
42+
r = index.OrOp(a, a)
43+
# CHECK: %{{.*}} = index.or %{{.*}}, %{{.*}}
44+
45+
46+
# CHECK-LABEL: TEST: testXOrOp
47+
@run
48+
def testXOrOp(ctx):
49+
a = index.ConstantOp(value=42)
50+
r = index.XOrOp(a, a)
51+
# CHECK: %{{.*}} = index.xor %{{.*}}, %{{.*}}
52+
53+
54+
# CHECK-LABEL: TEST: testCastSOp
55+
@run
56+
def testCastSOp(ctx):
57+
a = index.ConstantOp(value=42)
58+
b = arith.ConstantOp(value=23, result=IntegerType.get_signless(64))
59+
c = index.CastSOp(input=a, output=IntegerType.get_signless(32))
60+
d = index.CastSOp(input=b, output=IndexType.get())
61+
# CHECK: %{{.*}} = index.casts %{{.*}} : index to i32
62+
# CHECK: %{{.*}} = index.casts %{{.*}} : i64 to index
63+
64+
65+
# CHECK-LABEL: TEST: testCastUOp
66+
@run
67+
def testCastUOp(ctx):
68+
a = index.ConstantOp(value=42)
69+
b = arith.ConstantOp(value=23, result=IntegerType.get_signless(64))
70+
c = index.CastUOp(input=a, output=IntegerType.get_signless(32))
71+
d = index.CastUOp(input=b, output=IndexType.get())
72+
# CHECK: %{{.*}} = index.castu %{{.*}} : index to i32
73+
# CHECK: %{{.*}} = index.castu %{{.*}} : i64 to index
74+
75+
76+
# CHECK-LABEL: TEST: testCeilDivSOp
77+
@run
78+
def testCeilDivSOp(ctx):
79+
a = index.ConstantOp(value=42)
80+
r = index.CeilDivSOp(a, a)
81+
# CHECK: %{{.*}} = index.ceildivs %{{.*}}, %{{.*}}
82+
83+
84+
# CHECK-LABEL: TEST: testCeilDivUOp
85+
@run
86+
def testCeilDivUOp(ctx):
87+
a = index.ConstantOp(value=42)
88+
r = index.CeilDivUOp(a, a)
89+
# CHECK: %{{.*}} = index.ceildivu %{{.*}}, %{{.*}}
90+
91+
92+
# CHECK-LABEL: TEST: testCmpOp
93+
@run
94+
def testCmpOp(ctx):
95+
a = index.ConstantOp(value=42)
96+
b = index.ConstantOp(value=23)
97+
pred = AttrBuilder.get("IndexCmpPredicateAttr")("slt", context=ctx)
98+
r = index.CmpOp(pred, lhs=a, rhs=b)
99+
# CHECK: %{{.*}} = index.cmp slt(%{{.*}}, %{{.*}})
100+
101+
102+
# CHECK-LABEL: TEST: testAddOp
103+
@run
104+
def testAddOp(ctx):
105+
a = index.ConstantOp(value=42)
106+
r = index.AddOp(a, a)
107+
# CHECK: %{{.*}} = index.add %{{.*}}, %{{.*}}
108+
109+
110+
# CHECK-LABEL: TEST: testSubOp
111+
@run
112+
def testSubOp(ctx):
113+
a = index.ConstantOp(value=42)
114+
r = index.SubOp(a, a)
115+
# CHECK: %{{.*}} = index.sub %{{.*}}, %{{.*}}
116+
117+
118+
# CHECK-LABEL: TEST: testMulOp
119+
@run
120+
def testMulOp(ctx):
121+
a = index.ConstantOp(value=42)
122+
r = index.MulOp(a, a)
123+
# CHECK: %{{.*}} = index.mul %{{.*}}, %{{.*}}
124+
125+
126+
# CHECK-LABEL: TEST: testDivSOp
127+
@run
128+
def testDivSOp(ctx):
129+
a = index.ConstantOp(value=42)
130+
r = index.DivSOp(a, a)
131+
# CHECK: %{{.*}} = index.divs %{{.*}}, %{{.*}}
132+
133+
134+
# CHECK-LABEL: TEST: testDivUOp
135+
@run
136+
def testDivUOp(ctx):
137+
a = index.ConstantOp(value=42)
138+
r = index.DivUOp(a, a)
139+
# CHECK: %{{.*}} = index.divu %{{.*}}, %{{.*}}
140+
141+
142+
# CHECK-LABEL: TEST: testFloorDivSOp
143+
@run
144+
def testFloorDivSOp(ctx):
145+
a = index.ConstantOp(value=42)
146+
r = index.FloorDivSOp(a, a)
147+
# CHECK: %{{.*}} = index.floordivs %{{.*}}, %{{.*}}
148+
149+
150+
# CHECK-LABEL: TEST: testMaxSOp
151+
@run
152+
def testMaxSOp(ctx):
153+
a = index.ConstantOp(value=42)
154+
b = index.ConstantOp(value=23)
155+
r = index.MaxSOp(a, b)
156+
# CHECK: %{{.*}} = index.maxs %{{.*}}, %{{.*}}
157+
158+
159+
# CHECK-LABEL: TEST: testMaxUOp
160+
@run
161+
def testMaxUOp(ctx):
162+
a = index.ConstantOp(value=42)
163+
b = index.ConstantOp(value=23)
164+
r = index.MaxUOp(a, b)
165+
# CHECK: %{{.*}} = index.maxu %{{.*}}, %{{.*}}
166+
167+
168+
# CHECK-LABEL: TEST: testMinSOp
169+
@run
170+
def testMinSOp(ctx):
171+
a = index.ConstantOp(value=42)
172+
b = index.ConstantOp(value=23)
173+
r = index.MinSOp(a, b)
174+
# CHECK: %{{.*}} = index.mins %{{.*}}, %{{.*}}
175+
176+
177+
# CHECK-LABEL: TEST: testMinUOp
178+
@run
179+
def testMinUOp(ctx):
180+
a = index.ConstantOp(value=42)
181+
b = index.ConstantOp(value=23)
182+
r = index.MinUOp(a, b)
183+
# CHECK: %{{.*}} = index.minu %{{.*}}, %{{.*}}
184+
185+
186+
# CHECK-LABEL: TEST: testRemSOp
187+
@run
188+
def testRemSOp(ctx):
189+
a = index.ConstantOp(value=42)
190+
b = index.ConstantOp(value=23)
191+
r = index.RemSOp(a, b)
192+
# CHECK: %{{.*}} = index.rems %{{.*}}, %{{.*}}
193+
194+
195+
# CHECK-LABEL: TEST: testRemUOp
196+
@run
197+
def testRemUOp(ctx):
198+
a = index.ConstantOp(value=42)
199+
b = index.ConstantOp(value=23)
200+
r = index.RemUOp(a, b)
201+
# CHECK: %{{.*}} = index.remu %{{.*}}, %{{.*}}
202+
203+
204+
# CHECK-LABEL: TEST: testShlOp
205+
@run
206+
def testShlOp(ctx):
207+
a = index.ConstantOp(value=42)
208+
b = index.ConstantOp(value=3)
209+
r = index.ShlOp(a, b)
210+
# CHECK: %{{.*}} = index.shl %{{.*}}, %{{.*}}
211+
212+
213+
# CHECK-LABEL: TEST: testShrSOp
214+
@run
215+
def testShrSOp(ctx):
216+
a = index.ConstantOp(value=42)
217+
b = index.ConstantOp(value=3)
218+
r = index.ShrSOp(a, b)
219+
# CHECK: %{{.*}} = index.shrs %{{.*}}, %{{.*}}
220+
221+
222+
# CHECK-LABEL: TEST: testShrUOp
223+
@run
224+
def testShrUOp(ctx):
225+
a = index.ConstantOp(value=42)
226+
b = index.ConstantOp(value=3)
227+
r = index.ShrUOp(a, b)
228+
# CHECK: %{{.*}} = index.shru %{{.*}}, %{{.*}}
229+
230+
231+
# CHECK-LABEL: TEST: testSizeOfOp
232+
@run
233+
def testSizeOfOp(ctx):
234+
r = index.SizeOfOp()
235+
# CHECK: %{{.*}} = index.sizeof

0 commit comments

Comments
 (0)