Skip to content

Commit ca051df

Browse files
authored
[mlir][utils] Add script to verify canonicalizations against Alive2 (#91867)
This script takes IR before and after canonicalization, translates it into llvm IR and converts it to format suitable for Alive2 https://alive2.llvm.org/ce/ This is primarily for arith canonicalizations verification, but technically it can be adapted for any dialect translatable to llvm. Usage `python verify_canon.py canonicalize.mlir -f func1 func2 ...` Example output: https://alive2.llvm.org/ce/z/KhQs4J Initial discussion: #91646 (review)
1 parent 69e1312 commit ca051df

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
# This script is a helper to verify canonicalization patterns using Alive2
6+
# https://alive2.llvm.org/ce/.
7+
# It performs the following steps:
8+
# - Filters out the provided test functions.
9+
# - Runs the canonicalization pass on the remaining functions.
10+
# - Lowers both the original and the canonicalized functions to LLVM IR.
11+
# - Prints the canonicalized and the original functions side-by-side in a format
12+
# that can be copied into Alive2 for verification.
13+
# Example: `python verify_canon.py canonicalize.mlir -f func1 func2 func3`
14+
15+
import subprocess
16+
import tempfile
17+
import sys
18+
from pathlib import Path
19+
from argparse import ArgumentParser
20+
21+
22+
def filter_funcs(ir, funcs):
23+
if not funcs:
24+
return ir
25+
26+
funcs_str = ",".join(funcs)
27+
return subprocess.check_output(
28+
["mlir-opt", f"--symbol-privatize=exclude={funcs_str}", "--symbol-dce"],
29+
input=ir,
30+
)
31+
32+
33+
def add_func_prefix(src, prefix):
34+
return src.replace("@", "@" + prefix)
35+
36+
37+
def merge_ir(chunks):
38+
files = []
39+
for chunk in chunks:
40+
tmp = tempfile.NamedTemporaryFile(suffix=".ll")
41+
tmp.write(chunk)
42+
tmp.flush()
43+
files.append(tmp)
44+
45+
return subprocess.check_output(["llvm-link", "-S"] + [f.name for f in files])
46+
47+
48+
if __name__ == "__main__":
49+
parser = ArgumentParser()
50+
parser.add_argument("file")
51+
parser.add_argument("-f", "--func-names", nargs="+", default=[])
52+
args = parser.parse_args()
53+
54+
file = args.file
55+
funcs = args.func_names
56+
57+
orig_ir = Path(file).read_bytes()
58+
orig_ir = filter_funcs(orig_ir, funcs)
59+
60+
to_llvm_args = ["--convert-to-llvm"]
61+
orig_args = ["mlir-opt"] + to_llvm_args
62+
canon_args = ["mlir-opt", "-canonicalize"] + to_llvm_args
63+
translate_args = ["mlir-translate", "-mlir-to-llvmir"]
64+
65+
orig = subprocess.check_output(orig_args, input=orig_ir)
66+
canonicalized = subprocess.check_output(canon_args, input=orig_ir)
67+
68+
orig = subprocess.check_output(translate_args, input=orig)
69+
canonicalized = subprocess.check_output(translate_args, input=canonicalized)
70+
71+
enc = "utf-8"
72+
orig = bytes(add_func_prefix(orig.decode(enc), "src_"), enc)
73+
canonicalized = bytes(add_func_prefix(canonicalized.decode(enc), "tgt_"), enc)
74+
75+
res = merge_ir([orig, canonicalized])
76+
77+
print(res.decode(enc))

0 commit comments

Comments
 (0)