@@ -139,8 +139,7 @@ def init_model():
139
139
# ``mode`` argument, which we will discuss below.
140
140
141
141
def evaluate (mod , inp ):
142
- with torch .no_grad ():
143
- return mod (inp )
142
+ return mod (inp )
144
143
145
144
model = init_model ()
146
145
@@ -151,8 +150,9 @@ def evaluate(mod, inp):
151
150
evaluate_opt = torch .compile (evaluate , mode = "reduce-overhead" )
152
151
153
152
inp = generate_data (16 )[0 ]
154
- print ("eager:" , timed (lambda : evaluate (model , inp ))[1 ])
155
- print ("compile:" , timed (lambda : evaluate_opt (model , inp ))[1 ])
153
+ with torch .no_grad ():
154
+ print ("eager:" , timed (lambda : evaluate (model , inp ))[1 ])
155
+ print ("compile:" , timed (lambda : evaluate_opt (model , inp ))[1 ])
156
156
157
157
######################################################################
158
158
# Notice that ``torch.compile`` takes a lot longer to complete
@@ -165,7 +165,8 @@ def evaluate(mod, inp):
165
165
eager_times = []
166
166
for i in range (N_ITERS ):
167
167
inp = generate_data (16 )[0 ]
168
- _ , eager_time = timed (lambda : evaluate (model , inp ))
168
+ with torch .no_grad ():
169
+ _ , eager_time = timed (lambda : evaluate (model , inp ))
169
170
eager_times .append (eager_time )
170
171
print (f"eager eval time { i } : { eager_time } " )
171
172
@@ -174,7 +175,8 @@ def evaluate(mod, inp):
174
175
compile_times = []
175
176
for i in range (N_ITERS ):
176
177
inp = generate_data (16 )[0 ]
177
- _ , compile_time = timed (lambda : evaluate_opt (model , inp ))
178
+ with torch .no_grad ():
179
+ _ , compile_time = timed (lambda : evaluate_opt (model , inp ))
178
180
compile_times .append (compile_time )
179
181
print (f"compile eval time { i } : { compile_time } " )
180
182
print ("~" * 10 )
@@ -183,6 +185,7 @@ def evaluate(mod, inp):
183
185
eager_med = np .median (eager_times )
184
186
compile_med = np .median (compile_times )
185
187
speedup = eager_med / compile_med
188
+ assert (speedup > 1 )
186
189
print (f"(eval) eager median: { eager_med } , compile median: { compile_med } , speedup: { speedup } x" )
187
190
print ("~" * 10 )
188
191
@@ -239,6 +242,7 @@ def train(mod, data):
239
242
eager_med = np .median (eager_times )
240
243
compile_med = np .median (compile_times )
241
244
speedup = eager_med / compile_med
245
+ assert (speedup > 1 )
242
246
print (f"(train) eager median: { eager_med } , compile median: { compile_med } , speedup: { speedup } x" )
243
247
print ("~" * 10 )
244
248
0 commit comments