Skip to content

Commit 95a012f

Browse files
authored
Adds command line to translate de model into code (#49)
* Adds command line to translate de model into code * doc
1 parent 5e3668d commit 95a012f

File tree

4 files changed

+174
-0
lines changed

4 files changed

+174
-0
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.1.3
55
+++++
66

7+
* :pr:`49`: adds command line to export a model into code
78
* :pr:`48`: support for subgraph in light API
89
* :pr:`47`: extends export onnx to code to support inner API
910
* :pr:`46`: adds an export to convert an onnx graph into light API code
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import os
2+
import tempfile
3+
import unittest
4+
from contextlib import redirect_stdout
5+
from io import StringIO
6+
from onnx import TensorProto
7+
from onnx.helper import (
8+
make_graph,
9+
make_model,
10+
make_node,
11+
make_opsetid,
12+
make_tensor_value_info,
13+
)
14+
from onnx_array_api.ext_test_case import ExtTestCase
15+
from onnx_array_api._command_lines_parser import (
16+
get_main_parser,
17+
get_parser_translate,
18+
main,
19+
)
20+
21+
22+
class TestCommandLines1(ExtTestCase):
23+
def test_main_parser(self):
24+
st = StringIO()
25+
with redirect_stdout(st):
26+
get_main_parser().print_help()
27+
text = st.getvalue()
28+
self.assertIn("translate", text)
29+
30+
def test_parser_translate(self):
31+
st = StringIO()
32+
with redirect_stdout(st):
33+
get_parser_translate().print_help()
34+
text = st.getvalue()
35+
self.assertIn("model", text)
36+
37+
def test_command_translate(self):
38+
X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
39+
Y = make_tensor_value_info("Y", TensorProto.FLOAT, [5, 6])
40+
Z = make_tensor_value_info("Z", TensorProto.FLOAT, [None, None])
41+
graph = make_graph(
42+
[
43+
make_node("Add", ["X", "Y"], ["res"]),
44+
make_node("Cos", ["res"], ["Z"]),
45+
],
46+
"g",
47+
[X, Y],
48+
[Z],
49+
)
50+
onnx_model = make_model(graph, opset_imports=[make_opsetid("", 18)])
51+
52+
with tempfile.TemporaryDirectory() as root:
53+
model_file = os.path.join(root, "model.onnx")
54+
with open(model_file, "wb") as f:
55+
f.write(onnx_model.SerializeToString())
56+
57+
args = ["translate", "-m", model_file]
58+
st = StringIO()
59+
with redirect_stdout(st):
60+
main(args)
61+
62+
code = st.getvalue()
63+
self.assertIn("model = make_model(", code)
64+
65+
args = ["translate", "-m", model_file, "-a", "light"]
66+
st = StringIO()
67+
with redirect_stdout(st):
68+
main(args)
69+
70+
code = st.getvalue()
71+
self.assertIn("start(opset=", code)
72+
73+
74+
if __name__ == "__main__":
75+
unittest.main(verbosity=2)

onnx_array_api/__main__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from ._command_lines_parser import main
2+
3+
if __name__ == "__main__":
4+
main()
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import sys
2+
import onnx
3+
from typing import Any, List, Optional
4+
from argparse import ArgumentParser
5+
from textwrap import dedent
6+
7+
8+
def get_main_parser() -> ArgumentParser:
9+
parser = ArgumentParser(
10+
prog="onnx-array-api",
11+
description="onnx-array-api main command line.",
12+
epilog="Type 'python -m onnx_array_api <cmd> --help' "
13+
"to get help for a specific command.",
14+
)
15+
parser.add_argument(
16+
"cmd",
17+
choices=["translate"],
18+
help=dedent(
19+
"""
20+
Selects a command.
21+
22+
'translate' exports an onnx graph into a piece of code replicating it.
23+
"""
24+
),
25+
)
26+
return parser
27+
28+
29+
def get_parser_translate() -> ArgumentParser:
30+
parser = ArgumentParser(
31+
prog="translate",
32+
description=dedent(
33+
"""
34+
Translates an onnx model into a piece of code to replicate it.
35+
The result is printed on the standard output.
36+
"""
37+
),
38+
epilog="This is mostly used to write unit tests without adding "
39+
"an onnx file to the repository.",
40+
)
41+
parser.add_argument(
42+
"-m",
43+
"--model",
44+
type=str,
45+
required=True,
46+
help="onnx model to translate",
47+
)
48+
parser.add_argument(
49+
"-a",
50+
"--api",
51+
choices=["onnx", "light"],
52+
default="onnx",
53+
help="API to choose, API from onnx package or light API.",
54+
)
55+
return parser
56+
57+
58+
def _cmd_translate(argv: List[Any]):
59+
from .light_api import translate
60+
61+
parser = get_parser_translate()
62+
args = parser.parse_args(argv[1:])
63+
onx = onnx.load(args.model)
64+
code = translate(onx, api=args.api)
65+
print(code)
66+
67+
68+
def main(argv: Optional[List[Any]] = None):
69+
fcts = dict(translate=_cmd_translate)
70+
71+
if argv is None:
72+
argv = sys.argv[1:]
73+
if (len(argv) <= 1 and argv[0] not in fcts) or argv[-1] in ("--help", "-h"):
74+
if len(argv) < 2:
75+
parser = get_main_parser()
76+
parser.parse_args(argv)
77+
else:
78+
parsers = dict(translate=get_parser_translate)
79+
cmd = argv[0]
80+
if cmd not in parsers:
81+
raise ValueError(
82+
f"Unknown command {cmd!r}, it should be in {list(sorted(parsers))}."
83+
)
84+
parser = parsers[cmd]()
85+
parser.parse_args(argv[1:])
86+
raise RuntimeError("The programme should have exited before.")
87+
88+
cmd = argv[0]
89+
if cmd in fcts:
90+
fcts[cmd](argv)
91+
else:
92+
raise ValueError(
93+
f"Unknown command {cmd!r}, use --help to get the list of known command."
94+
)

0 commit comments

Comments
 (0)