Skip to content

Small fixes to torch.compile tutorial #2601

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 11, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions intermediate_source/torch_compile_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,21 +138,18 @@ def init_model():
# Note that in the call to ``torch.compile``, we have have the additional
# ``mode`` argument, which we will discuss below.

def evaluate(mod, inp):
return mod(inp)

model = init_model()

# Reset since we are using a different mode.
import torch._dynamo
torch._dynamo.reset()

evaluate_opt = torch.compile(evaluate, mode="reduce-overhead")
model_opt = torch.compile(model, mode="reduce-overhead")

inp = generate_data(16)[0]
with torch.no_grad():
print("eager:", timed(lambda: evaluate(model, inp))[1])
print("compile:", timed(lambda: evaluate_opt(model, inp))[1])
print("eager:", timed(lambda: model(inp))[1])
print("compile:", timed(lambda: model_opt(inp))[1])

######################################################################
# Notice that ``torch.compile`` takes a lot longer to complete
Expand All @@ -166,7 +163,7 @@ def evaluate(mod, inp):
for i in range(N_ITERS):
inp = generate_data(16)[0]
with torch.no_grad():
_, eager_time = timed(lambda: evaluate(model, inp))
_, eager_time = timed(lambda: model(inp))
eager_times.append(eager_time)
print(f"eager eval time {i}: {eager_time}")

Expand All @@ -176,7 +173,7 @@ def evaluate(mod, inp):
for i in range(N_ITERS):
inp = generate_data(16)[0]
with torch.no_grad():
_, compile_time = timed(lambda: evaluate_opt(model, inp))
_, compile_time = timed(lambda: model_opt(inp))
compile_times.append(compile_time)
print(f"compile eval time {i}: {compile_time}")
print("~" * 10)
Expand Down Expand Up @@ -250,6 +247,10 @@ def train(mod, data):
# Again, we can see that ``torch.compile`` takes longer in the first
# iteration, as it must compile the model, but in subsequent iterations, we see
# significant speedups compared to eager.
#
# We remark that the speedup numbers presented in this tutorial are for
# demonstration purposes only. Official speedup values can be seen at the
# `TorchInductor performance dashboard <https://hud.pytorch.org/benchmark/compilers>`__.

######################################################################
# Comparison to TorchScript and FX Tracing
Expand Down