Skip to content

Commit c28def9

Browse files
authored
fix compile tutorial (#2596)
1 parent d986068 commit c28def9

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

intermediate_source/torch_compile_tutorial.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,7 @@ def init_model():
139139
# ``mode`` argument, which we will discuss below.
140140

141141
def evaluate(mod, inp):
142-
with torch.no_grad():
143-
return mod(inp)
142+
return mod(inp)
144143

145144
model = init_model()
146145

@@ -151,8 +150,9 @@ def evaluate(mod, inp):
151150
evaluate_opt = torch.compile(evaluate, mode="reduce-overhead")
152151

153152
inp = generate_data(16)[0]
154-
print("eager:", timed(lambda: evaluate(model, inp))[1])
155-
print("compile:", timed(lambda: evaluate_opt(model, inp))[1])
153+
with torch.no_grad():
154+
print("eager:", timed(lambda: evaluate(model, inp))[1])
155+
print("compile:", timed(lambda: evaluate_opt(model, inp))[1])
156156

157157
######################################################################
158158
# Notice that ``torch.compile`` takes a lot longer to complete
@@ -165,7 +165,8 @@ def evaluate(mod, inp):
165165
eager_times = []
166166
for i in range(N_ITERS):
167167
inp = generate_data(16)[0]
168-
_, eager_time = timed(lambda: evaluate(model, inp))
168+
with torch.no_grad():
169+
_, eager_time = timed(lambda: evaluate(model, inp))
169170
eager_times.append(eager_time)
170171
print(f"eager eval time {i}: {eager_time}")
171172

@@ -174,7 +175,8 @@ def evaluate(mod, inp):
174175
compile_times = []
175176
for i in range(N_ITERS):
176177
inp = generate_data(16)[0]
177-
_, compile_time = timed(lambda: evaluate_opt(model, inp))
178+
with torch.no_grad():
179+
_, compile_time = timed(lambda: evaluate_opt(model, inp))
178180
compile_times.append(compile_time)
179181
print(f"compile eval time {i}: {compile_time}")
180182
print("~" * 10)
@@ -183,6 +185,7 @@ def evaluate(mod, inp):
183185
eager_med = np.median(eager_times)
184186
compile_med = np.median(compile_times)
185187
speedup = eager_med / compile_med
188+
assert(speedup > 1)
186189
print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
187190
print("~" * 10)
188191

@@ -239,6 +242,7 @@ def train(mod, data):
239242
eager_med = np.median(eager_times)
240243
compile_med = np.median(compile_times)
241244
speedup = eager_med / compile_med
245+
assert(speedup > 1)
242246
print(f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
243247
print("~" * 10)
244248

0 commit comments

Comments
 (0)