Skip to content

Commit f49df3e

Browse files
sdasgup3TensorFlow MLIR Team
authored and
TensorFlow MLIR Team
committed
Integrate StableHLO at openxla/stablehlo@38bb2f9b
PiperOrigin-RevId: 708389837
1 parent da1c36c commit f49df3e

29 files changed

+791
-49
lines changed

stablehlo/BUILD

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,21 @@ gentbl_cc_library(
407407
],
408408
)
409409

410+
gentbl_cc_library(
411+
name = "stablehlo_complex_math_expander_inc_gen",
412+
tbl_outs = [
413+
(
414+
["--gen-rewriters"],
415+
"stablehlo/transforms/StablehloComplexMathExpanderPatterns.h.inc",
416+
),
417+
],
418+
tblgen = "@llvm-project//mlir:mlir-tblgen",
419+
td_file = "stablehlo/transforms/StablehloComplexMathExpanderPatterns.td",
420+
deps = [
421+
":stablehlo_ops_td_files",
422+
],
423+
)
424+
410425
cc_test(
411426
name = "example_add",
412427
srcs = [
@@ -1140,6 +1155,7 @@ cc_library(
11401155
"stablehlo/transforms/StablehloAggressiveSimplification.cpp",
11411156
"stablehlo/transforms/StablehloCanonicalizeDynamism.cpp",
11421157
"stablehlo/transforms/StablehloCompatibilityExpander.cpp",
1158+
"stablehlo/transforms/StablehloComplexMathExpander.cpp",
11431159
"stablehlo/transforms/StablehloConvertToSignless.cpp",
11441160
"stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp",
11451161
"stablehlo/transforms/StablehloLegalizeDeprecatedOps.cpp",
@@ -1168,6 +1184,7 @@ cc_library(
11681184
":linalg_passes",
11691185
":stablehlo_aggressive_simplification_inc_gen",
11701186
":stablehlo_compatibility_expander_inc_gen",
1187+
":stablehlo_complex_math_expander_inc_gen",
11711188
":stablehlo_legalize_deprecated_ops_inc_gen",
11721189
":stablehlo_ops",
11731190
":stablehlo_ops_inc_gen",

stablehlo/BUILD.bazel

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ cc_library(
276276
":chlo_attrs_inc_gen",
277277
":chlo_enums_inc_gen",
278278
":chlo_ops_inc_gen",
279+
":stablehlo_assembly_format",
279280
":stablehlo_type_inference",
280281
"@llvm-project//llvm:Support",
281282
"@llvm-project//mlir:BytecodeOpInterface",
@@ -370,6 +371,21 @@ gentbl_cc_library(
370371
],
371372
)
372373

374+
gentbl_cc_library(
375+
name = "stablehlo_create_complex_math_expander_inc_gen",
376+
tbl_outs = [
377+
(
378+
["--gen-rewriters"],
379+
"stablehlo/transforms/StablehloComplexMathExpanderPatterns.h.inc",
380+
),
381+
],
382+
tblgen = "@llvm-project//mlir:mlir-tblgen",
383+
td_file = "stablehlo/transforms/StablehloComplexMathExpanderPatterns.td",
384+
deps = [
385+
":stablehlo_ops_td_files",
386+
],
387+
)
388+
373389
cc_library(
374390
name = "interpreter_ops",
375391
srcs = [
@@ -1120,6 +1136,7 @@ cc_library(
11201136
"stablehlo/transforms/StablehloAggressiveSimplification.cpp",
11211137
"stablehlo/transforms/StablehloCanonicalizeDynamism.cpp",
11221138
"stablehlo/transforms/StablehloCompatibilityExpander.cpp",
1139+
"stablehlo/transforms/StablehloComplexMathExpander.cpp",
11231140
"stablehlo/transforms/StablehloConvertToSignless.cpp",
11241141
"stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp",
11251142
"stablehlo/transforms/StablehloLegalizeDeprecatedOps.cpp",
@@ -1148,6 +1165,7 @@ cc_library(
11481165
":linalg_passes",
11491166
":stablehlo_aggressive_simplification_inc_gen",
11501167
":stablehlo_create_compatibility_expander_inc_gen",
1168+
":stablehlo_create_complex_math_expander_inc_gen",
11511169
":stablehlo_legalize_deprecated_ops_inc_gen",
11521170
":stablehlo_ops",
11531171
":stablehlo_ops_inc_gen",

stablehlo/WORKSPACE.bazel

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ workspace(name = "stablehlo")
1717

1818
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
1919

20-
LLVM_COMMIT = "0876c11ceeb093904decc4d89bef213d483a5656"
20+
LLVM_COMMIT = "e86910337f98e57f5b9253f7d80d5b916eb1d97e"
2121

22-
LLVM_SHA256 = "8379577a71645bbba89dea08beba32b3e56b833da7340ba5be7efa3986c8f8ed"
22+
LLVM_SHA256 = "4ca0eff0ca86ed6f2fdb7682354fdf4c85151d90ac9fb6e55a868e4191359e9f"
2323

2424
http_archive(
2525
name = "llvm-raw",

stablehlo/build_tools/math/README.md

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ following requirements:
3131

3232
- Python 3.11 or newer
3333
- mpmath 1.3 or newer
34-
- functional_algorithms 0.11.1 or newer
34+
- functional_algorithms 0.12 or newer
3535

3636
that can be installed via pypi:
3737

@@ -62,7 +62,7 @@ To execute generated tests from a `build` directory, use:
6262

6363
```sh
6464
for t in $(ls ../stablehlo/tests/math/*.mlir); \
65-
do echo $t && ( bin/stablehlo-opt --chlo-legalize-to-stablehlo $t \
65+
do echo $t && ( bin/stablehlo-opt --stablehlo-complex-math-expander --chlo-legalize-to-stablehlo $t \
6666
| bin/stablehlo-translate --interpret 2>&1 | grep "^ULP difference" ) ; done
6767
```
6868

@@ -77,6 +77,14 @@ build/bin/stablehlo-opt --chlo-legalize-to-stablehlo --split-input-file --verify
7777

7878
and copy relevant checks to `chlo_legalize_to_stablehlo.mlir`.
7979

80+
A similar procedure is applied for updating
81+
`stablehlo/tests/stablehlo_complex_math_expander.mlir`:
82+
83+
```sh
84+
build/bin/stablehlo-opt --stablehlo-complex-math-expander --split-input-file --verify-diagnostics \
85+
stablehlo/tests/stablehlo_complex_math_expander.mlir | python llvm-project/mlir/utils/generate-test-checks.py | less
86+
```
87+
8088
## A procedure for adding a new algorithm to an existing operation
8189

8290
1. Implement a new algorithm in
@@ -98,6 +106,10 @@ and copy relevant checks to `chlo_legalize_to_stablehlo.mlir`.
98106
7. Add a record of the operation to
99107
`generate_ChloDecompositionPatternsMath.py`, see the for-loop in
100108
`main` function.
109+
- If the operation is a StableHLO operation on complex inputs, add
110+
it to `stable-complex-math-expander` pass: update
111+
`populateStablehloComplexMathExpanderPatterns` function in
112+
`stablehlo/transforms/StablehloComplexMathExpander.cpp`.
101113
8. Generate new implementations by running
102114
`generate_ChloDecompositionPatternsMath.py` and remove existing
103115
implementations in

stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def get_functional_algorithms_required_version():
4444
)
4545

4646

47-
def main():
47+
def main(kind="CHLO"):
4848
try:
4949
import functional_algorithms as fa
5050
except ImportError as msg:
@@ -64,6 +64,11 @@ def main():
6464
warnings.warn(msg)
6565
return
6666

67+
output_filename = dict(
68+
CHLO="ChloDecompositionPatternsMath.td",
69+
StableHLO="StablehloComplexMathExpanderPatterns.td",
70+
)[kind]
71+
6772
output_file = os.path.relpath(
6873
os.path.normpath(
6974
os.path.join(
@@ -72,8 +77,9 @@ def main():
7277
"..",
7378
"stablehlo",
7479
"transforms",
75-
"ChloDecompositionPatternsMath.td",
76-
)),
80+
output_filename,
81+
)
82+
),
7783
os.getcwd(),
7884
)
7985

@@ -98,7 +104,10 @@ def main():
98104
("CHLO_AtanhOp", "complex_atanh", ("z:complex",)),
99105
("CHLO_SquareOp", "complex_square", ("z:complex",)),
100106
("CHLO_SquareOp", "real_square", ("x:float",)),
107+
("StableHLO_Log1pOp", "complex_log1p", ("z:complex",)),
101108
]:
109+
if not chloname.startswith(kind):
110+
continue
102111
print(f'Generating {chloname} from {fname}{args}')
103112
func = getattr(fa.algorithms, fname, None)
104113
if func is None:
@@ -115,6 +124,17 @@ def main():
115124
sources[-1] += src
116125
source = "\n\n".join(sources) + "\n"
117126

127+
if chloname.startswith("StableHLO_"):
128+
# an ugly hack to fix the definition of stablehlo complex math
129+
# functions. TODO(pearu): add the corresponding feature to
130+
# functional_algorithms stablehlo printer
131+
NameOp = chloname.split("_", 1)[1]
132+
source = source.replace(
133+
f"def : Pat<({chloname}",
134+
f"def {NameOp}_ComplexElementType_ComplexMathExpander :"
135+
f" Pat<({chloname}",
136+
)
137+
118138
if os.path.isfile(output_file):
119139
f = open(output_file, "r")
120140
content = f.read()
@@ -146,10 +166,32 @@ def main():
146166
147167
This file is generated using functional_algorithms tool ({fa.__version__}).
148168
See build_tools/math/README.md for more information.""") + "\n")
169+
170+
if kind == "StableHLO":
171+
f.write("""\
172+
include "mlir/IR/OpBase.td"
173+
include "stablehlo/dialect/StablehloOps.td"
174+
175+
class StableHLO_ComparisonDirectionValue<string enumStr> :
176+
ConstantAttr<StableHLO_ComparisonDirectionAttr,
177+
"::mlir::stablehlo::ComparisonDirection::" # enumStr>;
178+
179+
class StableHLO_ConstantLike<string value> : NativeCodeCall<
180+
"::mlir::stablehlo::getConstantLike($_builder, $_loc, " # value # ", $0)">;
181+
182+
def ComplexElementType : Type<
183+
CPred<"isa<ComplexType>(cast<ShapedType>($_self).getElementType())">,
184+
"Complex element type">;
185+
186+
def StableHLO_ConstantLikeMaxFiniteValue : NativeCodeCall<
187+
"::mlir::stablehlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">;
188+
189+
""")
149190
f.write(source)
150191
f.close()
151192
print(f"Created {output_file}")
152193

153194

154195
if __name__ == "__main__":
155-
main()
196+
main(kind="CHLO")
197+
main(kind="StableHLO")

stablehlo/build_tools/math/generate_tests.py

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -43,27 +43,33 @@
4343
default_max_ulp_difference = 1
4444

4545
operations = [
46-
# The following dictionaries may have additional keys like
47-
#
48-
# size - defines the number of samples: size ** 2
49-
#
50-
# max_ulp_difference - the maximal allowed ULP difference between
51-
# function and reference values
52-
#
53-
# extra_prec_multiplier - the precison multiplier for mpmath.mp
54-
# that defines the precision of computing reference values:
55-
# mpmath.mp.prec * extra_prec_multiplier
56-
#
57-
# When unspecifed, these parameters are retrieved from
58-
# functional_algorithms database of support functions.
59-
#
60-
dict(name="asin", mpmath_name="arcsin"),
61-
dict(name="acos", mpmath_name="arccos"),
62-
dict(name="atan", mpmath_name="arctan"),
63-
dict(name="asinh", mpmath_name="arcsinh"),
64-
dict(name="acosh", mpmath_name="arccosh"),
65-
dict(name="atanh", mpmath_name="arctanh"),
66-
dict(name="square", mpmath_name="square"),
46+
# The following dictionaries may have additional keys like
47+
#
48+
# size - defines the number of samples: size ** 2
49+
#
50+
# max_ulp_difference - the maximal allowed ULP difference between
51+
# function and reference values
52+
#
53+
# extra_prec_multiplier - the precison multiplier for mpmath.mp
54+
# that defines the precision of computing reference values:
55+
# mpmath.mp.prec * extra_prec_multiplier
56+
#
57+
# When unspecifed, these parameters are retrieved from
58+
# functional_algorithms database of support functions.
59+
#
60+
dict(name="asin", mpmath_name="arcsin"),
61+
dict(name="acos", mpmath_name="arccos"),
62+
dict(name="atan", mpmath_name="arctan"),
63+
dict(name="asinh", mpmath_name="arcsinh"),
64+
dict(name="acosh", mpmath_name="arccosh"),
65+
dict(name="atanh", mpmath_name="arctanh"),
66+
dict(name="square", mpmath_name="square"),
67+
dict(
68+
name="log_plus_one",
69+
mpmath_name="log1p",
70+
namespace="stablehlo",
71+
passes="--stablehlo-complex-math-expander",
72+
),
6773
]
6874

6975

@@ -127,19 +133,24 @@ def main():
127133
for op in operations:
128134
opname = op["name"]
129135
mpmath_opname = op.get("mpmath_name", opname)
136+
namespace = op.get("namespace", "chlo")
130137
size_re = size_im = op.get("size", default_size)
131-
138+
passes = op.get("passes", "--chlo-legalize-to-stablehlo")
132139
for dtype in [np.complex64, np.complex128, np.float32, np.float64]:
133140
params = fa.utils.function_validation_parameters(opname, dtype)
134141
max_ulp_difference = op.get(
135-
"max_ulp_difference",
136-
params.get("max_valid_ulp_count", default_max_ulp_difference))
142+
"max_ulp_difference",
143+
params.get("max_valid_ulp_count", default_max_ulp_difference),
144+
)
137145

138146
nmp = fa.utils.numpy_with_mpmath(
139-
extra_prec_multiplier = op.get(
140-
"extra_prec_multiplier",
141-
params.get("extra_prec_multiplier", default_extra_prec_multiplier)),
142-
flush_subnormals=flush_subnormals,
147+
extra_prec_multiplier=op.get(
148+
"extra_prec_multiplier",
149+
params.get(
150+
"extra_prec_multiplier", default_extra_prec_multiplier
151+
),
152+
),
153+
flush_subnormals=flush_subnormals,
143154
)
144155

145156
fi = np.finfo(dtype)
@@ -180,7 +191,7 @@ def main():
180191
main_func = m.make_function("main", "", "", "public")
181192

182193
ref_samples = main_func.call("samples")
183-
actual = main_func.composite(f"chlo.{opname}", ref_samples)
194+
actual = main_func.composite(f"{namespace}.{opname}", ref_samples)
184195
expected = main_func.call("expected")
185196

186197
main_func.void_call(
@@ -202,8 +213,10 @@ def main():
202213
continue
203214

204215
f = open(fname, "w")
205-
f.write("// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s |"
206-
" stablehlo-translate --interpret\n")
216+
f.write(
217+
f"// RUN: stablehlo-opt {passes} %s |"
218+
" stablehlo-translate --interpret\n"
219+
)
207220
f.write(
208221
"// This file is generated, see build_tools/math/README.md for more"
209222
" information.\n")

stablehlo/docs/generated/stablehlo_passes.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,33 @@ func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
8686
```
8787
-target : The target version. Must be a version of the form #.#.#.
8888
```
89+
### `-stablehlo-complex-math-expander`
90+
91+
_Expander for StableHLO complex math operations._
92+
93+
StableHLO complex math operations are decompositions using
94+
StableHLO real math operations.
95+
96+
This statement is based on the assumption that no hardware exists
97+
that supports complex numbers nor complex math operations
98+
natively. This means that the fallback mechanisms on complex math
99+
operations that compilers may implement, are redundant. With
100+
enabling this pass, all StableHLO complex math operations will be
101+
expanded.
102+
103+
```mlir
104+
func.func @sqrt_op_complex(%arg0: tensor<4xcomplex<f64>>) -> tensor<4xcomplex<f64>> {
105+
%1 = stablehlo.sqrt %arg0 : tensor<4xcomplex<f64>>
106+
func.return %1 : tensor<4xcomplex<f64>>
107+
}
108+
109+
==>
110+
111+
func.func @sqrt_op_complex(%arg0: tensor<4xcomplex<f64>>) -> tensor<4xcomplex<f64>> {
112+
TBD
113+
return %2 : tensor<4xcomplex<f64>>
114+
}
115+
```
89116
### `-stablehlo-convert-to-signless`
90117

91118
_Pass to transform the IR to be on signless integers._

stablehlo/docs/spec.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4066,7 +4066,8 @@ Performs element-wise logarithm plus one operation on `operand` tensor and
40664066
produces a `result` tensor. Depending on the element type, does the following:
40674067

40684068
* For floats: `logp1` from IEEE-754.
4069-
* For complex numbers: complex logarithm plus one.
4069+
* For complex numbers:
4070+
`complex(log(hypot(real(x) + 1, imag(x))), atan2(imag(x), real(x) + 1))`
40704071
* For quantized types:
40714072
`dequantize_op_quantize(log_plus_one, operand, type(result))`.
40724073

0 commit comments

Comments
 (0)