
Commit 2c7e596

Adding new oneAPI sample PyTorch AMX BF16/INT8 Inference (#1401)
* Add new oneAPI Sample IPEX Inference Optimization
* Replacing random.randint() with random.sample()
* Add support for IPEX BF16 and INT8 model option
* Revert "Add support for IPEX BF16 and INT8 model option"
  This reverts commit 2b987db.
* Adding new oneAPI sample PyTorch AMX BF16/INT8 Inference
* update Features-and-Functionality README with latest changes
* update Features-and-Functionality README with PT AMX BF16/INT8 Inference
* README review updates
* add missing * to PyTorch on README
1 parent 3a8529c commit 2c7e596
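
At its core, the new sample benchmarks ResNet50 and BERT under FP32, BF16, and INT8 and compares AMX against VNNI execution. For readers skimming the commit, the BF16 path it adds reduces to the pattern sketched below; this is a minimal excerpt adapted from the new script (it assumes torch, torchvision, and intel_extension_for_pytorch are installed and omits the warm-up and timing loops):

import torch
import intel_extension_for_pytorch as ipex  # IPEX provides the BF16 optimization pass
from torchvision import models

# Load a pretrained ResNet50 and a dummy input, as the sample does
model = models.resnet50(pretrained=True).eval()
data = torch.rand(1, 3, 224, 224)

# Convert the model for BF16 execution (AMX-accelerated on supporting CPUs),
# then trace and freeze it with TorchScript before running inference
model = ipex.optimize(model, dtype=torch.bfloat16)
with torch.no_grad(), torch.cpu.amp.autocast():
    model = torch.jit.trace(model, data)
    model = torch.jit.freeze(model)
    output = model(data)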

File tree

9 files changed: +1255 -1 lines changed

AI-and-Analytics/Features-and-Functionality/IntelPyTorch_InferenceOptimizations_AMX_BF16_INT8/IntelPyTorch_InferenceOptimizations_AMX_BF16_INT8.ipynb

+535 lines (large diffs are not rendered by default)
@@ -0,0 +1,237 @@
#!/usr/bin/env python
# encoding: utf-8

'''
==============================================================
Copyright © 2023 Intel Corporation

SPDX-License-Identifier: MIT
==============================================================
'''

import os
from time import time
import matplotlib.pyplot as plt
import torch
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert
import torchvision
from torchvision import models
from transformers import BertModel

NUM_SAMPLES = 1000 # number of samples to perform inference on
SUPPORTED_MODELS = ["resnet50", "bert"] # models supported by this code sample

# BERT sample data parameters
BERT_BATCH_SIZE = 1
BERT_SEQ_LENGTH = 512

"""
Function to perform inference on Resnet50 and BERT
"""
def runInference(model, data, modelName="resnet50", dataType="FP32", amx=True):
    """
    Input parameters
        model: the PyTorch model object used for inference
        data: a sample input into the model
        modelName: str representing the name of the model, supported values - resnet50, bert
        dataType: str representing the data type for model parameters, supported values - FP32, BF16, INT8
        amx: set to False to disable AMX on BF16, Default: True
    Return value
        inference_time: the time in seconds it takes to perform inference with the model
    """

    # Display run case
    if amx:
        isa_text = "AVX512_CORE_AMX"
    else:
        isa_text = "AVX512_CORE_VNNI"
    print("%s %s inference with %s" %(modelName, dataType, isa_text))

    # Configure environment variable
    if not amx:
        os.environ["ONEDNN_MAX_CPU_ISA"] = "AVX512_CORE_VNNI"
    else:
        os.environ["ONEDNN_MAX_CPU_ISA"] = "DEFAULT"

    # Special variables for specific models
    if "bert" == modelName:
        d = torch.randint(model.config.vocab_size, size=[BERT_BATCH_SIZE, BERT_SEQ_LENGTH]) # sample data input for torchscript and inference

    # Prepare model for inference based on precision (FP32, BF16, INT8)
    if "INT8" == dataType:
        # Quantize model to INT8 if needed (one time)
        model_filename = "quantized_model_%s.pt" %modelName
        if not os.path.exists(model_filename):
            qconfig = ipex.quantization.default_static_qconfig
            prepared_model = prepare(model, qconfig, example_inputs=data, inplace=False)
            converted_model = convert(prepared_model)
            with torch.no_grad():
                if "resnet50" == modelName:
                    traced_model = torch.jit.trace(converted_model, data)
                elif "bert" == modelName:
                    traced_model = torch.jit.trace(converted_model, (d,), check_trace=False, strict=False)
                else:
                    raise Exception("ERROR: modelName %s is not supported. Choose from %s" %(modelName, SUPPORTED_MODELS))
            traced_model.save(model_filename)

        # Load INT8 model for inference
        model = torch.jit.load(model_filename)
        model.eval()
        model = torch.jit.freeze(model)
    elif "BF16" == dataType:
        model = ipex.optimize(model, dtype=torch.bfloat16)
        with torch.no_grad():
            with torch.cpu.amp.autocast():
                if "resnet50" == modelName:
                    model = torch.jit.trace(model, data)
                elif "bert" == modelName:
                    model = torch.jit.trace(model, (d,), check_trace=False, strict=False)
                else:
                    raise Exception("ERROR: modelName %s is not supported. Choose from %s" %(modelName, SUPPORTED_MODELS))
                model = torch.jit.freeze(model)
    else: # FP32
        with torch.no_grad():
            if "resnet50" == modelName:
                model = torch.jit.trace(model, data)
            elif "bert" == modelName:
                model = torch.jit.trace(model, (d,), check_trace=False, strict=False)
            else:
                raise Exception("ERROR: modelName %s is not supported. Choose from %s" %(modelName, SUPPORTED_MODELS))
            model = torch.jit.freeze(model)

    # Run inference
    with torch.no_grad():
        if "BF16" == dataType:
            with torch.cpu.amp.autocast():
                # Warm up
                for i in range(20):
                    model(data)

                # Measure latency
                start_time = time()
                for i in range(NUM_SAMPLES):
                    model(data)
                end_time = time()
        else:
            # Warm up
            for i in range(20):
                model(data)

            # Measure latency
            start_time = time()
            for i in range(NUM_SAMPLES):
                model(data)
            end_time = time()
    inference_time = end_time - start_time
    print("Inference on %d samples took %.3f seconds" %(NUM_SAMPLES, inference_time))

    return inference_time

"""
Prints out results and displays figures summarizing output.
"""
def summarizeResults(modelName="", results=None):
    """
    Input parameters
        modelName: a str representing the name of the model
        results: a dict with the run case and its corresponding time in seconds
    Return value
        None
    """

    # Inference time results
    print("\nSummary for %s (%d samples)" %(modelName, NUM_SAMPLES))
    for key in results.keys():
        print("%s inference time: %.3f seconds" %(key, results[key]))

    # Create bar chart with inference time results
    plt.figure()
    plt.title("%s Inference Time (%d samples)" %(modelName, NUM_SAMPLES))
    plt.xlabel("Run Case")
    plt.ylabel("Inference Time (seconds)")
    plt.bar(results.keys(), results.values())

    # Calculate speedup when using AMX
    print("\n")
    bf16_with_amx_speedup = results["FP32"] / results["BF16_with_AMX"]
    print("BF16 with AMX is %.2fX faster than FP32" %bf16_with_amx_speedup)
    int8_with_vnni_speedup = results["FP32"] / results["INT8_with_VNNI"]
    print("INT8 with VNNI is %.2fX faster than FP32" %int8_with_vnni_speedup)
    int8_with_amx_speedup = results["FP32"] / results["INT8_with_AMX"]
    print("INT8 with AMX is %.2fX faster than FP32" %int8_with_amx_speedup)
    print("\n\n")

    # Create bar chart with speedup results
    plt.figure()
    plt.title("%s AMX BF16/INT8 Speedup over FP32" %modelName)
    plt.xlabel("Run Case")
    plt.ylabel("Speedup")
    plt.bar(results.keys(),
        [1, bf16_with_amx_speedup, int8_with_vnni_speedup, int8_with_amx_speedup]
    )

"""
175+
Perform all types of inference in main function
176+
177+
Inference run cases for both Resnet50 and BERT
178+
1) FP32 (baseline)
179+
2) BF16 using AVX512_CORE_AMX
180+
3) INT8 using AVX512_CORE_VNNI
181+
4) INT8 using AVX512_CORE_AMX
182+
"""
183+
def main():
184+
# Check if hardware supports AMX
185+
import sys
186+
sys.path.append('../../')
187+
import version_check
188+
from cpuinfo import get_cpu_info
189+
info = get_cpu_info()
190+
flags = info['flags']
191+
amx_supported = False
192+
for flag in flags:
193+
if "amx" in flag:
194+
amx_supported = True
195+
break
196+
if not amx_supported:
197+
print("AMX is not supported on current hardware. Code sample cannot be run.\n")
198+
return
199+
200+
# ResNet50
201+
resnet_model = models.resnet50(pretrained=True)
202+
resnet_data = torch.rand(1, 3, 224, 224)
203+
resnet_model.eval()
204+
fp32_resnet_inference_time = runInference(resnet_model, resnet_data, modelName="resnet50", dataType="FP32", amx=True)
205+
bf16_amx_resnet_inference_time = runInference(resnet_model, resnet_data, modelName="resnet50", dataType="BF16", amx=True)
206+
int8_with_vnni_resnet_inference_time = runInference(resnet_model, resnet_data, modelName="resnet50", dataType="INT8", amx=False)
207+
int8_amx_resnet_inference_time = runInference(resnet_model, resnet_data, modelName="resnet50", dataType="INT8", amx=True)
208+
results_resnet = {
209+
"FP32": fp32_resnet_inference_time,
210+
"BF16_with_AMX": bf16_amx_resnet_inference_time,
211+
"INT8_with_VNNI": int8_with_vnni_resnet_inference_time,
212+
"INT8_with_AMX": int8_amx_resnet_inference_time
213+
}
214+
summarizeResults("ResNet50", results_resnet)
215+
216+
# BERT
217+
bert_model = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-uncased')
218+
bert_data = torch.randint(bert_model.config.vocab_size, size=[BERT_BATCH_SIZE, BERT_SEQ_LENGTH])
219+
bert_model.eval()
220+
fp32_bert_inference_time = runInference(bert_model, bert_data, modelName="bert", dataType="FP32", amx=True)
221+
bf16_amx_bert_inference_time = runInference(bert_model, bert_data, modelName="bert", dataType="BF16", amx=True)
222+
int8_with_vnni_bert_inference_time = runInference(bert_model, bert_data, modelName="bert", dataType="INT8", amx=False)
223+
int8_amx_bert_inference_time = runInference(bert_model, bert_data, modelName="bert", dataType="INT8", amx=True)
224+
results_bert = {
225+
"FP32": fp32_bert_inference_time,
226+
"BF16_with_AMX": bf16_amx_bert_inference_time,
227+
"INT8_with_VNNI": int8_with_vnni_bert_inference_time,
228+
"INT8_with_AMX": int8_amx_bert_inference_time
229+
}
230+
summarizeResults("BERT", results_bert)
231+
232+
# Display graphs
233+
plt.show()
234+
235+
if __name__ == '__main__':
236+
main()
237+
print('[CODE_SAMPLE_COMPLETED_SUCCESFULLY]')
@@ -0,0 +1,7 @@
Copyright Intel Corporation

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
