
Commit d6fe2e1

Merge pull request pytorch#1 from zhuhaozhe/hz/add_inductor_debug_doc
add profile part

File tree

1 file changed (+285, -1 lines)

intermediate_source/inductor_debug_cpu.rst

Lines changed: 285 additions & 1 deletion
@@ -302,7 +302,291 @@ Note that there exists a debugging tool provided by PyTorch, called `Minifier <h

Performance profiling
---------------------

In this section, we describe how to analyze the performance of a TorchInductor-compiled model.
First, we choose an eager-mode model as the baseline and set up a benchmark to compare the
end-to-end performance of the eager model and the inductor model.

.. code-block:: python

    import time

    import torch
    from transformers import T5ForConditionalGeneration

    # initialize an eager model as the baseline
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    seq_length = 1024
    bs = 4
    vocab_size = model.config.vocab_size
    input = torch.randint(0, vocab_size, (bs, seq_length), dtype=torch.int64)
    input_dict = {"input_ids": input}
    input_dict["decoder_input_ids"] = input

    # initialize the inductor model and trigger compilation with one call
    compiled_model = torch.compile(model)
    compiled_model(**input_dict)

    eager_t = 0
    inductor_t = 0

    # warm up, then time the eager model
    for _ in range(100):
        model(**input_dict)
    for _ in range(1000):
        eager_start = time.time()
        model(**input_dict)
        eager_end = time.time()
        eager_t += eager_end - eager_start

    # warm up, then time the inductor model
    for _ in range(100):
        compiled_model(**input_dict)
    for _ in range(1000):
        inductor_start = time.time()
        compiled_model(**input_dict)
        inductor_end = time.time()
        inductor_t += inductor_end - inductor_start

    print(model.__class__)
    print("eager use:", eager_t)
    print("inductor use:", inductor_t)
    print("ratio:", eager_t / inductor_t)

Output:

.. code-block:: shell

    eager use: 410.12550354003906
    inductor use: 478.59081745147705
    ratio: 0.8569439458198976

We see that the inductor model spent more time than the eager model, which does not meet
our expectation. To dig into op-level performance, we can use the
`PyTorch Profiler <https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html>`_.
We can enable kernel profiling in inductor by setting:

.. code-block:: python

    from torch._inductor import config
    config.cpp.enable_kernel_profile = True

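Note that ``enable_kernel_profile`` is a codegen option, so it should be set before the model is
compiled. A minimal sketch of the intended order (reusing the ``model`` and ``input_dict``
objects defined in the benchmark above):

.. code-block:: python

    from torch._inductor import config

    # set the flag before compiling so that the generated C++ kernels are wrapped in
    # RECORD_FUNCTION and show up under their kernel names in the profiler output
    config.cpp.enable_kernel_profile = True
    compiled_model = torch.compile(model)
    compiled_model(**input_dict)  # trigger compilation
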
Following the steps in the `PyTorch Profiler <https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html>`_ recipe,
we can get the profiling table and trace files.

.. code-block:: python

    import os
    import time

    from torch.profiler import profile, schedule, ProfilerActivity

    RESULT_DIR = "./prof_trace"  # placeholder directory for the exported chrome traces
    os.makedirs(RESULT_DIR, exist_ok=True)
    nwarmup = 10

    my_schedule = schedule(
        skip_first=10,
        wait=5,
        warmup=5,
        active=1,
        repeat=5)

    def trace_handler(p):
        output = p.key_averages().table(sort_by="self_cpu_time_total", row_limit=20)
        print(output)
        p.export_chrome_trace(RESULT_DIR + "/" + str(p.step_num) + ".json")

    # warm up before profiling
    for _ in range(nwarmup):
        model(**input_dict)

    total = 0
    with profile(
        activities=[ProfilerActivity.CPU],
        schedule=my_schedule,
        on_trace_ready=trace_handler
    ) as p:
        for _ in range(100):
            begin = time.time()
            model(**input_dict)
            end = time.time()
            total += (end - begin)
            p.step()
    print("latency: {} ms".format(1000 * total / 100))

We get the following profiling table for the eager model:

.. code-block:: shell

    -----------------------  ------------  ------------  ------------  ------------  ------------  ------------
                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
    -----------------------  ------------  ------------  ------------  ------------  ------------  ------------
                   aten::mm        33.33%     138.616ms        33.33%     138.616ms       1.429ms            97
                 aten::add_        19.38%      80.596ms        19.38%      80.596ms       4.242ms            19
                  aten::bmm        18.78%      78.104ms        18.78%      78.104ms       2.170ms            36
             aten::_softmax        11.32%      47.082ms        11.32%      47.082ms       2.616ms            18
                aten::copy_         3.89%      16.190ms         3.89%      16.190ms     103.121us           157
              ProfilerStep*         3.53%      14.702ms       100.00%     415.949ms     415.949ms             1
                  aten::add         2.37%       9.849ms         2.39%       9.958ms     144.319us            69
                  aten::mul         1.13%       4.693ms         1.14%       4.726ms      65.639us            72
            aten::clamp_min         0.85%       3.541ms         0.85%       3.541ms     295.083us            12
         aten::index_select         0.84%       3.480ms         1.06%       4.401ms       1.100ms             4
               aten::linear         0.63%       2.637ms        33.95%     141.194ms       1.456ms            97
                  aten::pow         0.61%       2.520ms         0.61%       2.554ms      79.812us            32
               aten::matmul         0.50%       2.067ms        56.53%     235.132ms       1.768ms           133
               aten::select         0.22%     900.000us         0.22%     910.000us     113.750us             8
                  aten::log         0.18%     740.000us         0.18%     740.000us     370.000us             2
         aten::_unsafe_view         0.17%     718.000us         0.17%     718.000us       3.840us           187
                  aten::sum         0.17%     715.000us         0.20%     831.000us      25.969us            32
            aten::transpose         0.15%     642.000us         0.18%     741.000us       3.963us           187
              aten::reshape         0.15%     622.000us         3.66%      15.241ms      88.098us           173
                aten::fill_         0.15%     613.000us         0.15%     613.000us      15.718us            39
    -----------------------  ------------  ------------  ------------  ------------  ------------  ------------
    Self CPU time total: 415.949ms

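The table for the inductor model can be collected in the same way, by running the same profiling
loop against the compiled model (a minimal sketch, reusing ``compiled_model``, ``my_schedule``,
and ``trace_handler`` from the snippets above):

.. code-block:: python

    # profile the inductor model with the same schedule and trace handler
    total = 0
    with profile(
        activities=[ProfilerActivity.CPU],
        schedule=my_schedule,
        on_trace_ready=trace_handler
    ) as p:
        for _ in range(100):
            begin = time.time()
            compiled_model(**input_dict)
            end = time.time()
            total += (end - begin)
            p.step()
    print("latency: {} ms".format(1000 * total / 100))
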
And we get the following table for the inductor model:

.. code-block:: shell

    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                                           mkl::_mkl_linear        28.24%     133.979ms        28.39%     134.689ms       1.389ms            97
                                                  aten::bmm        15.65%      74.250ms        15.65%      74.251ms       2.063ms            36
                               graph_0_cpp_fused__softmax_7         4.24%      20.123ms         4.24%      20.123ms      20.123ms             1
                              graph_0_cpp_fused__softmax_42         4.17%      19.773ms         4.17%      19.773ms      19.773ms             1
                              graph_0_cpp_fused__softmax_35         4.16%      19.751ms         4.16%      19.751ms      19.751ms             1
                              graph_0_cpp_fused__softmax_21         4.15%      19.674ms         4.15%      19.674ms      19.674ms             1
                              graph_0_cpp_fused__softmax_14         4.14%      19.654ms         4.14%      19.654ms      19.654ms             1
                              graph_0_cpp_fused__softmax_28         4.13%      19.576ms         4.13%      19.576ms      19.576ms             1
                              graph_0_cpp_fused__softmax_56         2.83%      13.404ms         2.83%      13.404ms      13.404ms             1
                              graph_0_cpp_fused__softmax_80         2.82%      13.371ms         2.82%      13.371ms      13.371ms             1
                              graph_0_cpp_fused__softmax_68         2.81%      13.323ms         2.81%      13.323ms      13.323ms             1
                              graph_0_cpp_fused__softmax_92         2.80%      13.297ms         2.80%      13.297ms      13.297ms             1
                             graph_0_cpp_fused__softmax_104         2.78%      13.208ms         2.78%      13.208ms      13.208ms             1
                               graph_0_cpp_fused__softmax_2         2.63%      12.468ms         2.63%      12.468ms      12.468ms             1
                                              ProfilerStep*         1.61%       7.616ms       100.00%     474.360ms     474.360ms             1
                              graph_0_cpp_fused__softmax_73         0.49%       2.320ms         0.49%       2.320ms       2.320ms             1
                              graph_0_cpp_fused__softmax_85         0.49%       2.309ms         0.49%       2.309ms       2.309ms             1
                              graph_0_cpp_fused__softmax_97         0.48%       2.283ms         0.48%       2.283ms       2.283ms             1
                              graph_0_cpp_fused__softmax_61         0.48%       2.268ms         0.48%       2.268ms       2.268ms             1
                              graph_0_cpp_fused__softmax_49         0.48%       2.255ms         0.48%       2.255ms       2.255ms             1
    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
    Self CPU time total: 474.360ms

We can search for the most time-consuming generated kernel, ``graph_0_cpp_fused__softmax_7``,
in ``output_code.py`` to see its generated code:

.. code-block:: python

    cpp_fused__softmax_7 = async_compile.cpp('''
    #include <ATen/record_function.h>
    #include "/tmp/torchinductor_root/gv/cgv6n5aotqjo5w4vknjibhengeycuattfto532hkxpozszcgxr3x.h"
    extern "C" void kernel(float* in_out_ptr0,
                           const float* in_ptr1,
                           float* out_ptr0,
                           float* out_ptr1)
    {
        RECORD_FUNCTION("graph_0_cpp_fused__softmax_7", c10::ArrayRef<c10::IValue>({}));
        auto in_ptr0 = in_out_ptr0;
        #pragma omp parallel num_threads(32)
        {
            {
                #pragma omp for collapse(2)
                for(long i0=static_cast<long>(0L); i0<static_cast<long>(4L); i0+=static_cast<long>(1L))
                {
                    for(long i1=static_cast<long>(0L); i1<static_cast<long>(8L); i1+=static_cast<long>(1L))
                    {
                        #pragma GCC ivdep
                        for(long i2=static_cast<long>(0L); i2<static_cast<long>(1024L); i2+=static_cast<long>(1L))
                        {
                            {
                                float tmp_acc0 = -std::numeric_limits<float>::infinity();
                                for(long i3=static_cast<long>(0L); i3<static_cast<long>(1024L); i3+=static_cast<long>(1L))
                                {
                                    auto tmp0 = in_ptr0[static_cast<long>(i3 + (1024L*i2) + (1048576L*i1) + (8388608L*i0))];
                                    auto tmp1 = static_cast<long>(i3 + ((-1L)*i2));
                                    auto tmp2 = static_cast<long>(0);
                                    auto tmp3 = tmp1 > tmp2;
                                    auto tmp4 = static_cast<long>(tmp3);
                                    auto tmp5 = static_cast<long>(16);
                                    auto tmp6 = decltype(tmp4)(tmp4 * tmp5);
                                    auto tmp7 = tmp6 + tmp2;
                                    auto tmp8 = std::abs(tmp1);
                                    auto tmp9 = static_cast<long>(8);
                                    auto tmp10 = tmp8 < tmp9;
                                    auto tmp11 = static_cast<float>(tmp8);
                                    auto tmp12 = static_cast<float>(8.0);
                                    auto tmp13 = tmp11 / tmp12;
                                    auto tmp14 = std::log(tmp13);
                                    auto tmp15 = static_cast<float>(2.772588722239781);
                                    auto tmp16 = tmp14 / tmp15;
                                    auto tmp17 = decltype(tmp16)(tmp16 * tmp12);
                                    auto tmp18 = static_cast<long>(tmp17);
                                    auto tmp19 = tmp18 + tmp9;
                                    auto tmp20 = static_cast<long>(15);
                                    auto tmp21 = min_propagate_nan(tmp19, tmp20);
                                    auto tmp22 = tmp10 ? tmp8 : tmp21;
                                    auto tmp23 = tmp7 + tmp22;
                                    auto tmp24 = in_ptr1[static_cast<long>(i1 + (8L*tmp23))];
                                    auto tmp25 = static_cast<float>(0.0);
                                    auto tmp26 = tmp24 + tmp25;
                                    auto tmp27 = tmp0 + tmp26;
                                    tmp_acc0 = max_propagate_nan(tmp_acc0, tmp27);
                                }
                                out_ptr0[static_cast<long>(i2 + (1024L*i1) + (8192L*i0))] = tmp_acc0;
                            }
                        }
                    }
                }
            }
            {
                #pragma omp for collapse(2)
                for(long i0=static_cast<long>(0L); i0<static_cast<long>(4L); i0+=static_cast<long>(1L))
                {
                    for(long i1=static_cast<long>(0L); i1<static_cast<long>(8L); i1+=static_cast<long>(1L))
                    {
                        #pragma GCC ivdep
                        for(long i2=static_cast<long>(0L); i2<static_cast<long>(1024L); i2+=static_cast<long>(1L))
                        {
                            #pragma GCC ivdep
                            for(long i3=static_cast<long>(0L); i3<static_cast<long>(1024L); i3+=static_cast<long>(1L))
                            {
                                auto tmp0 = in_out_ptr0[static_cast<long>(i3 + (1024L*i2) + (1048576L*i1) + (8388608L*i0))];
                                auto tmp28 = out_ptr0[static_cast<long>(i2 + (1024L*i1) + (8192L*i0))];
                                auto tmp1 = static_cast<long>(i3 + ((-1L)*i2));
                                auto tmp2 = static_cast<long>(0);
                                auto tmp3 = tmp1 > tmp2;
                                auto tmp4 = static_cast<long>(tmp3);
                                auto tmp5 = static_cast<long>(16);
                                auto tmp6 = decltype(tmp4)(tmp4 * tmp5);
                                auto tmp7 = tmp6 + tmp2;
                                auto tmp8 = std::abs(tmp1);
                                auto tmp9 = static_cast<long>(8);
                                auto tmp10 = tmp8 < tmp9;
                                auto tmp11 = static_cast<float>(tmp8);
                                auto tmp12 = static_cast<float>(8.0);
                                auto tmp13 = tmp11 / tmp12;
                                auto tmp14 = std::log(tmp13);
                                auto tmp15 = static_cast<float>(2.772588722239781);
                                auto tmp16 = tmp14 / tmp15;
                                auto tmp17 = decltype(tmp16)(tmp16 * tmp12);
                                auto tmp18 = static_cast<long>(tmp17);
                                auto tmp19 = tmp18 + tmp9;
                                auto tmp20 = static_cast<long>(15);
                                auto tmp21 = min_propagate_nan(tmp19, tmp20);
                                auto tmp22 = tmp10 ? tmp8 : tmp21;
                                auto tmp23 = tmp7 + tmp22;
                                auto tmp24 = in_ptr1[static_cast<long>(i1 + (8L*tmp23))];
                                auto tmp25 = static_cast<float>(0.0);
                                auto tmp26 = tmp24 + tmp25;
                                auto tmp27 = tmp0 + tmp26;
                                auto tmp29 = tmp27 - tmp28;
                                in_out_ptr0[static_cast<long>(i3 + (1024L*i2) + (1048576L*i1) + (8388608L*i0))] = tmp29;
                            }
                        }
                    }
                }
            }
            {
                #pragma omp for
                for(long i0=static_cast<long>(0L); i0<static_cast<long>(33554432L); i0+=static_cast<long>(16L))
                {
                    auto tmp0 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + static_cast<long>(i0));
                    auto tmp1 = tmp0.exp();
                    tmp1.store(in_out_ptr0 + static_cast<long>(i0));
                }
            }
            {
                #pragma omp for
                for(long i0=static_cast<long>(0L); i0<static_cast<long>(32768L); i0+=static_cast<long>(1L))
                {
                    {
                        #pragma omp declare reduction(+:at::vec::Vectorized<float>:omp_out += omp_in) initializer(omp_priv={{0}})
                        float tmp_acc0 = 0;
                        auto tmp_acc0_vec = at::vec::Vectorized<float>(tmp_acc0);
                        for(long i1=static_cast<long>(0L); i1<static_cast<long>(1024L); i1+=static_cast<long>(16L))
                        {
                            auto tmp0 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + static_cast<long>(i1 + (1024L*i0)));
                            tmp_acc0_vec += tmp0;
                        }
                        tmp_acc0 += at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>&y) {return x + y;}, tmp_acc0_vec);
                        out_ptr1[static_cast<long>(i0)] = tmp_acc0;
                    }
                }
            }
        }
    }
    ''')

With the kernel name ``cpp_fused__softmax_*`` and the profiling results considered together
(the fused softmax kernels above account for roughly 200 ms of the inductor run, while
``aten::_softmax`` takes only about 47 ms in total in the eager run), we may suspect that the
generated code for ``softmax`` is inefficient. We encourage you to report an issue with all the
findings above.
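
If it helps the report, a small standalone script that exercises the suspected pattern can be
attached. The following is only a hypothetical sketch (the shapes are taken from the loop bounds
of the generated kernel, ``4 x 8 x 1024 x 1024``; the real T5 attention additionally gathers a
relative position bias, so this only approximates the fused pattern):

.. code-block:: python

    import torch

    # hypothetical minimal reproducer: softmax over an attention-score-shaped tensor
    # plus an additive bias, compiled with inductor
    def fused_softmax(score, bias):
        return torch.softmax(score + bias, dim=-1)

    compiled_fn = torch.compile(fused_softmax)
    score = torch.randn(4, 8, 1024, 1024)
    bias = torch.randn(4, 8, 1024, 1024)
    compiled_fn(score, bias)  # trigger compilation; profile as shown above
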
Future work
