Skip to content

Fix eval compile timing in torch.compile tutorial #2596

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 10, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions intermediate_source/torch_compile_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,7 @@ def init_model():
# ``mode`` argument, which we will discuss below.

def evaluate(mod, inp):
with torch.no_grad():
return mod(inp)
return mod(inp)

model = init_model()

Expand All @@ -151,8 +150,9 @@ def evaluate(mod, inp):
evaluate_opt = torch.compile(evaluate, mode="reduce-overhead")

inp = generate_data(16)[0]
print("eager:", timed(lambda: evaluate(model, inp))[1])
print("compile:", timed(lambda: evaluate_opt(model, inp))[1])
with torch.no_grad():
print("eager:", timed(lambda: evaluate(model, inp))[1])
print("compile:", timed(lambda: evaluate_opt(model, inp))[1])

######################################################################
# Notice that ``torch.compile`` takes a lot longer to complete
Expand All @@ -165,7 +165,8 @@ def evaluate(mod, inp):
eager_times = []
for i in range(N_ITERS):
inp = generate_data(16)[0]
_, eager_time = timed(lambda: evaluate(model, inp))
with torch.no_grad():
_, eager_time = timed(lambda: evaluate(model, inp))
eager_times.append(eager_time)
print(f"eager eval time {i}: {eager_time}")

Expand All @@ -174,7 +175,8 @@ def evaluate(mod, inp):
compile_times = []
for i in range(N_ITERS):
inp = generate_data(16)[0]
_, compile_time = timed(lambda: evaluate_opt(model, inp))
with torch.no_grad():
_, compile_time = timed(lambda: evaluate_opt(model, inp))
compile_times.append(compile_time)
print(f"compile eval time {i}: {compile_time}")
print("~" * 10)
Expand All @@ -183,6 +185,7 @@ def evaluate(mod, inp):
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert(speedup > 1)
print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)

Expand Down Expand Up @@ -239,6 +242,7 @@ def train(mod, data):
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert(speedup > 1)
print(f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)

Expand Down