Skip to content

Commit c1f0a77

Browse files
authored
Supports OrtValue in function ort_profile (#22)
* Supports OrtValue in function ort_profile * improves post processing of the profiler * support args_op_name
1 parent c82f9f3 commit c1f0a77

File tree

4 files changed

+147
-4
lines changed

4 files changed

+147
-4
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@ build/*
1010
.eggs/*
1111
.hypothesis/*
1212
*egg-info/*
13+
onnxruntime_profile*
14+
prof
1315
_doc/auto_examples/*
1416
_doc/examples/_cache/*
17+
_doc/examples/onnxruntime_profile*
1518
_doc/examples/plot_*.png
1619
_doc/examples/plot_*.xlsx
1720
_doc/examples/data/*.optimized.onnx

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@ Change Logs
44
0.2.0
55
+++++
66

7+
* :pr:`22`: support OrtValue in function :func:`ort_profile`
78
* :pr:`17`: implements ArrayAPI
89
* :pr:`3`: fixes Array API with onnxruntime and scikit-learn

_unittests/ut_ort/test_ort_profile.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
from onnx_array_api.ext_test_case import ExtTestCase
77
from onnx_array_api.ort.ort_optimizers import ort_optimized_model
88
from onnx_array_api.ort.ort_profile import ort_profile, merge_ort_profile
9+
from onnxruntime.capi._pybind_state import (
10+
OrtValue as C_OrtValue,
11+
OrtDevice as C_OrtDevice,
12+
)
913

1014

1115
class TestOrtProfile(ExtTestCase):
@@ -28,7 +32,76 @@ def myloss(x, y):
2832
self.assertRaise(lambda: ort_optimized_model(onx, "NO"), ValueError)
2933
optimized = ort_optimized_model(onx)
3034
prof = ort_profile(optimized, feeds)
31-
prof.to_csv("prof.csv", index=False)
35+
self.assertIsInstance(prof, DataFrame)
36+
prof = ort_profile(optimized, feeds, as_df=False)
37+
self.assertIsInstance(prof, list)
38+
39+
def test_ort_profile_first_it_out(self):
    # Profiles a small jitted loss, then checks the set of event names on
    # the raw profile and on the aggregated profiles.
    def l1_loss(x, y):
        return absolute(x - y).sum()

    def l2_loss(x, y):
        return ((x - y) ** 2).sum()

    def myloss(x, y):
        return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1])

    x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
    y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)

    # The function must be called once before the ONNX graph is retrieved.
    jitted_myloss = jit_onnx(myloss)
    jitted_myloss(x, y)
    onx = jitted_myloss.get_onnx()
    feeds = {"x0": x, "x1": y}

    self.assertRaise(lambda: ort_optimized_model(onx, "NO"), ValueError)
    optimized = ort_optimized_model(onx)

    expected_events = {
        "session_initialization",
        "model_loading_array",
        "model_run",
        "SequentialExecutor::Execute",
        "fence_before",
        "fence_after",
        "kernel_time",
    }

    prof = ort_profile(optimized, feeds)
    self.assertEqual(set(prof["event_name"]), expected_events)

    # First pass relies on the default value of agg_op_name, second pass
    # sets agg_op_name=False explicitly.
    # NOTE(review): ort_profile declares agg_op_name=False as its default in
    # this diff, so both passes look equivalent -- confirm whether
    # agg_op_name=True was meant to be covered.
    for extra in ({}, {"agg_op_name": False}):
        agg = ort_profile(optimized, feeds, first_it_out=True, agg=True, **extra)
        self.assertIsInstance(agg, DataFrame)
        self.assertLess(agg.shape[0], prof.shape[0])
        agg_events = set(agg.reset_index(drop=False)["event_name"])
        self.assertEqual(agg_events, expected_events)
78+
79+
def test_ort_profile_ort_value(self):
80+
def to_ort_value(m):
81+
device = C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0)
82+
ort_value = C_OrtValue.ortvalue_from_numpy(m, device)
83+
return ort_value
84+
85+
def l1_loss(x, y):
86+
return absolute(x - y).sum()
87+
88+
def l2_loss(x, y):
89+
return ((x - y) ** 2).sum()
90+
91+
def myloss(x, y):
92+
return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1])
93+
94+
jitted_myloss = jit_onnx(myloss)
95+
x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
96+
y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)
97+
jitted_myloss(x, y)
98+
onx = jitted_myloss.get_onnx()
99+
np_feeds = {"x0": x, "x1": y}
100+
feeds = {k: to_ort_value(v) for k, v in np_feeds.items()}
101+
102+
self.assertRaise(lambda: ort_optimized_model(onx, "NO"), ValueError)
103+
optimized = ort_optimized_model(onx)
104+
prof = ort_profile(optimized, feeds)
32105
self.assertIsInstance(prof, DataFrame)
33106
prof = ort_profile(optimized, feeds, as_df=False)
34107
self.assertIsInstance(prof, list)

onnx_array_api/ort/ort_profile.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,56 @@
66
from pandas import DataFrame
77

88

9+
def post_process_df_profile(
    df: DataFrame,
    first_it_out: bool = False,
    agg: bool = False,
    agg_op_name: bool = True,
) -> DataFrame:
    """
    Post-processes a dataframe obtained after profiling onnxruntime.
    It adds a column `event_name` for a more explicit event name and
    a column `iteration` for the iteration number. The input dataframe
    is not modified, the function works on a copy.

    :param df: dataframe holding the profiling rows, it must contain
        a column `name`, and, if *agg* is True, the columns `dur`, `cat`,
        `args_node_index`, `args_op_name`, `args_provider`
    :param first_it_out: leave the first iteration
        out of the aggregation (adds a column `it==0` to the aggregation keys)
    :param agg: aggregate the result (sum of column `dur`)
    :param agg_op_name: if True, the operator index (`args_node_index`) is
        removed from the aggregation keys, rows are then aggregated
        on the operator name
    :return: DataFrame
    """
    events = ("kernel_time", "fence_after", "fence_before")

    def sep_event(s):
        # Event names ending with a known suffix are reduced to that suffix,
        # any other name is kept unchanged.
        for e in events:
            if s.endswith(e):
                return e
        return s

    df = df.copy()
    df["event_name"] = df["name"].apply(sep_event)

    # Every row named "SequentialExecutor::Execute" starts a new iteration.
    # Rows seen before the first one keep iteration -1. The cumulative sum
    # does not depend on the index labels, unlike a positional loop based on
    # df.loc[i, ...] which only works with a default RangeIndex.
    starts = df["name"] == "SequentialExecutor::Execute"
    df["iteration"] = starts.astype("int64").cumsum() - 1

    if not agg:
        return df

    agg_cols = ["cat", "args_node_index", "args_op_name", "args_provider", "event_name"]
    if first_it_out:
        # 1 for rows with iteration <= 0, 0 for all the others
        df["it==0"] = (df["iteration"] <= 0).astype(int)
        agg_cols.insert(0, "it==0")
    if agg_op_name:
        agg_cols.remove("args_node_index")
    # Missing keys are filled before grouping so those rows are kept
    # (groupby drops rows with missing keys by default).
    for c in agg_cols:
        df[c] = df[c].fillna("")
    df["dur"] = df["dur"].fillna(0)
    aggregated = df[agg_cols + ["dur"]].groupby(agg_cols).sum()
    return aggregated
57+
58+
959
def ort_profile(
1060
filename_or_bytes: Union[str, bytes, ModelProto],
1161
feeds: Dict[str, numpy.ndarray],
@@ -14,6 +64,9 @@ def ort_profile(
1464
repeat: int = 10,
1565
as_df: bool = True,
1666
providers: Optional[List[str]] = None,
67+
first_it_out: bool = False,
68+
agg: bool = False,
69+
agg_op_name: bool = False,
1770
**kwargs,
1871
) -> Union[List, DataFrame]:
1972
"""
@@ -27,6 +80,9 @@ def ort_profile(
2780
:param as_df: if True, returns the profile as a DataFrame, otherwise as a list of rows
2881
:param providers: list of providers to use when initializing the inference session,
2982
if None, the default value is `["CPUExecutionProvider"]`
83+
:param first_it_out: if aggregated, leaves the first iteration out
84+
:param agg: aggregate by event
85+
:param agg_op_name: aggregate on operator name or operator index
3086
:param kwargs: additional parameters when initializing the inference session
3187
:return: DataFrame or list
3288
"""
@@ -45,8 +101,16 @@ def ort_profile(
45101
if providers is None:
46102
providers = ["CPUExecutionProvider"]
47103
sess = InferenceSession(obj, sess_options, providers=providers, **kwargs)
48-
for i in range(repeat):
49-
sess.run(None, feeds)
104+
first = list(feeds.values())[0]
105+
106+
if isinstance(first, numpy.ndarray):
107+
for i in range(repeat):
108+
sess.run(None, feeds)
109+
else:
110+
out_names = [o.name for o in sess.get_outputs()]
111+
for i in range(repeat):
112+
sess._sess.run_with_ort_values(feeds, out_names, None)
113+
50114
prof = sess.end_profiling()
51115
with open(prof, "r") as f:
52116
content = f.read()
@@ -68,7 +132,9 @@ def ort_profile(
68132
break
69133
rows.append(row)
70134
if as_df:
71-
return DataFrame(rows)
135+
return post_process_df_profile(
136+
DataFrame(rows), first_it_out=first_it_out, agg=agg, agg_op_name=agg_op_name
137+
)
72138
return rows
73139

74140

0 commit comments

Comments
 (0)