
Commit 2b987db

Add support for IPEX BF16 and INT8 model option
1 parent 6275f33 commit 2b987db

2 files changed: +89 −21 lines
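
With this change, both inference scripts accept two new optional flags alongside --ipex: --bf16 runs the IPEX path in bfloat16 (intended for 4th Gen Intel Xeon Scalable processors or newer), and --int8_model loads an INT8 model generated by Intel Neural Compressor from ./lang_id_commonvoice_model_INT8. Illustrative invocations (the data directory passed to -p is a placeholder, not something defined by this commit):

python inference_commonVoice.py -p ./commonVoice_data --ipex --bf16
python inference_commonVoice.py -p ./commonVoice_data --int8_model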

AI-and-Analytics/End-to-end-Workloads/LanguageIdentification/Inference/inference_commonVoice.py (+46 −12)
@@ -46,33 +46,63 @@ def trim_wav(self, newWavPath , start, end ):
         wavfile.write( newWavPath, self.sampleRate, self.waveData[startSample:endSample])

 class speechbrain_inference:
-    def __init__(self, ipex_op=False):
+    def __init__(self, ipex_op=False, bf16=False, int8_model=False):
         source_model_path = "./lang_id_commonvoice_model"
         self.language_id = EncoderClassifier.from_hparams(source=source_model_path, savedir="tmp")
         print("Model: " + source_model_path)

-        # Optimize for inference with IPEX
-        if ipex_op:
+        if int8_model:
+            # INT8 model
+            source_model_int8_path = "./lang_id_commonvoice_model_INT8"
+            print("Inference with INT8 model: " + source_model_int8_path)
+            from neural_compressor.utils.pytorch import load
+            self.model_int8 = load(source_model_int8_path, self.language_id)
+            self.model_int8.eval()
+        elif ipex_op:
+            # Optimize for inference with IPEX
             print("Optimizing inference with IPEX")
             self.language_id.eval()
             sampleInput = (torch.load("./sample_input_features.pt"), torch.load("./sample_input_wav_lens.pt"))
-            self.language_id.mods["embedding_model"] = ipex.optimize(self.language_id.mods["embedding_model"], sample_input=sampleInput)
+            if bf16:
+                print("BF16 enabled")
+                self.language_id.mods["embedding_model"] = ipex.optimize(self.language_id.mods["embedding_model"], sample_input=sampleInput, dtype=torch.bfloat16)
+            else:
+                self.language_id.mods["embedding_model"] = ipex.optimize(self.language_id.mods["embedding_model"], sample_input=sampleInput)
+
             # Torchscript to resolve performance issues with reorder operations
-            self.language_id.mods["embedding_model"] = torch.jit.trace(self.language_id.mods["embedding_model"], example_inputs=sampleInput)
+            with torch.no_grad():
+                if bf16:
+                    with torch.cpu.amp.autocast():
+                        self.language_id.mods["embedding_model"] = torch.jit.trace(self.language_id.mods["embedding_model"], example_inputs=sampleInput)
+                else:
+                    self.language_id.mods["embedding_model"] = torch.jit.trace(self.language_id.mods["embedding_model"], example_inputs=sampleInput)
+
         return

-    def predict(self, data_path="", verbose=False):
+    def predict(self, data_path="", ipex_op=False, bf16=False, int8_model=False, verbose=False):
         signal = self.language_id.load_audio(data_path)
         inference_start_time = time()
-        prediction = self.language_id.classify_batch(signal)
+
+        if int8_model: # INT8 model from INC
+            prediction = self.model_int8(signal)
+        elif ipex_op: # IPEX
+            with torch.no_grad():
+                if bf16:
+                    with torch.cpu.amp.autocast():
+                        prediction = self.language_id.classify_batch(signal)
+                else:
+                    prediction = self.language_id.classify_batch(signal)
+        else: # default
+            prediction = self.language_id.classify_batch(signal)
+
         inference_end_time = time()
         inference_latency = inference_end_time - inference_start_time
         if verbose:
             print(" Inference latency: %.5f seconds" %(inference_latency))
-
+
         # prediction is a tuple of format (out_prob, score, index) due to modification of speechbrain.pretrained.interfaces.py
         label = self.language_id.hparams.label_encoder.decode_torch(prediction[2])[0]
-
+
         return label, inference_latency

 def main(argv):
@@ -82,13 +112,17 @@ def main(argv):
     parser.add_argument('-d', type=int, default=3, help="Duration of each wave sample in seconds")
     parser.add_argument('-s', type=int, default=5, help="Sample size of waves to be taken from the audio file")
     parser.add_argument('--ipex', action="store_true", default=False, help="Enable Intel Extension for PyTorch (IPEX) optimizations")
+    parser.add_argument('--bf16', action="store_true", default=False, help="Use bfloat16 precision (supported on 4th Gen Xeon Scalable Processors or newer)")
+    parser.add_argument('--int8_model', action="store_true", default=False, help="Run inference with INT8 model generated from Intel Neural Compressor (INC)")
     parser.add_argument('--verbose', action="store_true", default=False, help="Print additional debug info")
     args = parser.parse_args()

     path = args.p
     sample_dur = args.d
     sample_size = args.s
     use_ipex = args.ipex
+    use_bf16 = args.bf16
+    use_int8_model = args.int8_model
     verbose = args.verbose
     print("\nTaking %d samples of %d seconds each" %(sample_size, sample_dur))

@@ -103,7 +137,7 @@ def main(argv):
         writer = csv.writer(f)
         writer.writerow(["Language", "Total Samples", "Correct Predictions", "Accuracy"])

-    speechbrain_inf = speechbrain_inference(ipex_op=use_ipex)
+    speechbrain_inf = speechbrain_inference(ipex_op=use_ipex, bf16=use_bf16, int8_model=use_int8_model)
     for language in languageList:
         print("\nTesting on %s data" %language)
         testDataDirectory = path + "/" + language
@@ -132,7 +166,7 @@ def main(argv):
             newWavPath = 'trim_tmp.wav'
             data.trim_wav(newWavPath, start, start + sample_dur)
             try:
-                label, inference_latency = speechbrain_inf.predict(data_path=newWavPath, verbose=verbose)
+                label, inference_latency = speechbrain_inf.predict(data_path=newWavPath, ipex_op=use_ipex, bf16=use_bf16, int8_model=use_int8_model, verbose=verbose)
                 if verbose:
                     print(" start-end : " + str(start) + " " + str(start + sample_dur) + " prediction : " + label)
                 predict_list.append(label)
@@ -174,4 +208,4 @@ def main(argv):

 if __name__ == "__main__":
     import sys
-    sys.exit(main(sys.argv))
+    sys.exit(main(sys.argv))
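
For context on the BF16 branch added above: the pattern is the standard IPEX one of calling ipex.optimize with dtype=torch.bfloat16 and then tracing under torch.cpu.amp.autocast() so the TorchScript graph is captured with BF16 kernels. A minimal, self-contained sketch of that pattern in isolation (the nn.Linear stand-in and the random sample input are illustrative only, not the sample's embedding_model or its saved sample_input tensors; assumes intel_extension_for_pytorch is installed):

import torch
import intel_extension_for_pytorch as ipex

model = torch.nn.Linear(64, 32).eval()   # stand-in for mods["embedding_model"]
sample_input = torch.randn(8, 64)        # stand-in for the saved .pt sample tensors

# Let IPEX prepare the module for BF16 inference.
model = ipex.optimize(model, dtype=torch.bfloat16, sample_input=sample_input)

# Trace under autocast, mirroring the torch.no_grad()/torch.cpu.amp.autocast()
# wrapping that this commit adds around torch.jit.trace.
with torch.no_grad(), torch.cpu.amp.autocast():
    traced = torch.jit.trace(model, example_inputs=sample_input)

out = traced(sample_input)               # runs the traced BF16 graph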

AI-and-Analytics/End-to-end-Workloads/LanguageIdentification/Inference/inference_custom.py (+43 −9)
@@ -48,25 +48,55 @@ def trim_wav(self, newWavPath , start, end ):


 class speechbrain_inference:
-    def __init__(self, ipex_op=False):
+    def __init__(self, ipex_op=False, bf16=False, int8_model=False):
         source_model_path = "./lang_id_commonvoice_model"
         self.language_id = EncoderClassifier.from_hparams(source=source_model_path, savedir="tmp")
         print("Model: " + source_model_path)

-        # Optimize for inference with IPEX
-        if ipex_op:
+        if int8_model:
+            # INT8 model
+            source_model_int8_path = "./lang_id_commonvoice_model_INT8"
+            print("Inference with INT8 model: " + source_model_int8_path)
+            from neural_compressor.utils.pytorch import load
+            self.model_int8 = load(source_model_int8_path, self.language_id)
+            self.model_int8.eval()
+        elif ipex_op:
+            # Optimize for inference with IPEX
             print("Optimizing inference with IPEX")
             self.language_id.eval()
             sampleInput = (torch.load("./sample_input_features.pt"), torch.load("./sample_input_wav_lens.pt"))
-            self.language_id.mods["embedding_model"] = ipex.optimize(self.language_id.mods["embedding_model"], sample_input=sampleInput)
+            if bf16:
+                print("BF16 enabled")
+                self.language_id.mods["embedding_model"] = ipex.optimize(self.language_id.mods["embedding_model"], sample_input=sampleInput, dtype=torch.bfloat16)
+            else:
+                self.language_id.mods["embedding_model"] = ipex.optimize(self.language_id.mods["embedding_model"], sample_input=sampleInput)
+
             # Torchscript to resolve performance issues with reorder operations
-            self.language_id.mods["embedding_model"] = torch.jit.trace(self.language_id.mods["embedding_model"], example_inputs=sampleInput)
+            with torch.no_grad():
+                if bf16:
+                    with torch.cpu.amp.autocast():
+                        self.language_id.mods["embedding_model"] = torch.jit.trace(self.language_id.mods["embedding_model"], example_inputs=sampleInput)
+                else:
+                    self.language_id.mods["embedding_model"] = torch.jit.trace(self.language_id.mods["embedding_model"], example_inputs=sampleInput)
+
         return

-    def predict(self, data_path="", verbose=False):
+    def predict(self, data_path="", ipex_op=False, bf16=False, int8_model=False, verbose=False):
         signal = self.language_id.load_audio(data_path)
         inference_start_time = time()
-        prediction = self.language_id.classify_batch(signal)
+
+        if int8_model: # INT8 model from INC
+            prediction = self.model_int8(signal)
+        elif ipex_op: # IPEX
+            with torch.no_grad():
+                if bf16:
+                    with torch.cpu.amp.autocast():
+                        prediction = self.language_id.classify_batch(signal)
+                else:
+                    prediction = self.language_id.classify_batch(signal)
+        else: # default
+            prediction = self.language_id.classify_batch(signal)
+
         inference_end_time = time()
         inference_latency = inference_end_time - inference_start_time
         if verbose:
@@ -86,6 +116,8 @@ def main(argv):
     parser.add_argument('-s', type=int, default=100, help="Sample size of waves to be taken from the audio file")
     parser.add_argument('--vad', action="store_true", default=False, help="Use Voice Activity Detection (VAD) to extract only the speech segments of the audio file")
     parser.add_argument('--ipex', action="store_true", default=False, help="Enable Intel Extension for PyTorch (IPEX) optimizations")
+    parser.add_argument('--bf16', action="store_true", default=False, help="Use bfloat16 precision (supported on 4th Gen Xeon Scalable Processors or newer)")
+    parser.add_argument('--int8_model', action="store_true", default=False, help="Run inference with INT8 model generated from Intel Neural Compressor (INC)")
     parser.add_argument('--ground_truth_compare', action="store_true", default=False, help="Enable comparison of prediction labels to ground truth values")
     parser.add_argument('--verbose', action="store_true", default=False, help="Print additional debug info")
     args = parser.parse_args()
@@ -95,6 +127,8 @@ def main(argv):
     sample_size = args.s
     use_vad = args.vad
     use_ipex = args.ipex
+    use_bf16 = args.bf16
+    use_int8_model = args.int8_model
     ground_truth_compare = args.ground_truth_compare
     verbose = args.verbose
     print("\nTaking %d samples of %d seconds each" %(sample_size, sample_dur))
@@ -119,7 +153,7 @@ def main(argv):
         else:
             raise Exception("Ground truth labels file does not exist.")

-    speechbrain_inf = speechbrain_inference(ipex_op=use_ipex)
+    speechbrain_inf = speechbrain_inference(ipex_op=use_ipex, bf16=use_bf16, int8_model=use_int8_model)
     if use_vad:
         from speechbrain.pretrained import VAD
         print("Using Voice Activity Detection")
@@ -232,7 +266,7 @@ def main(argv):
             newWavPath = 'trim_tmp.wav'
             data.trim_wav(newWavPath, start, start + sample_dur)
             try:
-                label, inference_latency = speechbrain_inf.predict(data_path=newWavPath, verbose=verbose)
+                label, inference_latency = speechbrain_inf.predict(data_path=newWavPath, ipex_op=use_ipex, bf16=use_bf16, int8_model=use_int8_model, verbose=verbose)
                 if verbose:
                     print(" start-end : " + str(start) + " " + str(start + sample_dur) + " prediction : " + label)
                 predict_list.append(label)
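
The INT8 path, by contrast, only loads an already-quantized model: it assumes the directory ./lang_id_commonvoice_model_INT8 exists and reconstructs the INT8 model from it with neural_compressor.utils.pytorch.load on top of the FP32 EncoderClassifier. A rough sketch of the produce-then-load roundtrip on a stand-in model (the PostTrainingQuantConfig/quantization.fit calls follow the INC 2.x-style API and, like the toy Linear model, calibration data, and output directory name, are illustrative assumptions rather than the sample's actual quantization script):

import torch
from torch.utils.data import DataLoader, TensorDataset
from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor.utils.pytorch import load

fp32_model = torch.nn.Linear(64, 2).eval()        # stand-in FP32 model
calib_loader = DataLoader(
    TensorDataset(torch.randn(32, 64), torch.zeros(32, dtype=torch.long)),
    batch_size=8)

# Post-training quantization, then save the INT8 artifacts to a directory.
q_model = quantization.fit(model=fp32_model,
                           conf=PostTrainingQuantConfig(),
                           calib_dataloader=calib_loader)
q_model.save("./toy_model_INT8")

# What the new __init__ branch does: rebuild the INT8 model from the saved
# directory using the original FP32 module, then run it in eval mode.
model_int8 = load("./toy_model_INT8", fp32_model)
model_int8.eval()
out = model_int8(torch.randn(8, 64))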
