-
Notifications
You must be signed in to change notification settings - Fork 723
add IntelPytorch Quantization code samples #1301
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
71490af
add IntelPytorch Quantization code samples
ZhaoqiongZ bfac55e
fix the spelling error in the README file
ZhaoqiongZ 1881af0
use john's README with grammar fix and title change
ZhaoqiongZ df3031b
Rename third-party-grograms.txt to third-party-programs.txt
jimmytwei File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
617 changes: 617 additions & 0 deletions
617
...tics/Features-and-Functionality/IntelPytorch_Quantization/IntelPytorch_Quantization.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
154 changes: 154 additions & 0 deletions
154
...alytics/Features-and-Functionality/IntelPytorch_Quantization/IntelPytorch_Quantization.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
#!/usr/bin/env python | ||
# encoding: utf-8 | ||
|
||
''' | ||
============================================================== | ||
Copyright © 2022 Intel Corporation | ||
SPDX-License-Identifier: MIT | ||
============================================================== | ||
''' | ||
|
||
import torch | ||
import torchvision | ||
import tqdm | ||
import os | ||
from time import time | ||
import intel_extension_for_pytorch as ipex | ||
from intel_extension_for_pytorch.quantization import prepare, convert | ||
|
||
# Hyperparameters and constants | ||
LR = 0.001 | ||
DOWNLOAD = True | ||
DATA = 'datasets/cifar10/' | ||
WARMUP = 3 | ||
ITERS = 100 | ||
|
||
|
||
def inference(model, data): | ||
# Warmup for several iteration. | ||
for i in range(WARMUP): | ||
out = model(data) | ||
|
||
# Benchmark: accumulate inference time for multi iteration and calculate the average inference time. | ||
print("Inference ...") | ||
inference_time = 0 | ||
for i in range(ITERS): | ||
start_time = time() | ||
_ = model(data) | ||
end_time = time() | ||
inference_time = inference_time + (end_time - start_time) | ||
|
||
|
||
|
||
inference_time = inference_time / ITERS | ||
print("Inference Time Avg: ", inference_time) | ||
return inference_time | ||
|
||
|
||
def staticQuantize(model_fp32, data, calibration_data_loader): | ||
# Acquire inference times for static quantization INT8 model | ||
qconfig_static = ipex.quantization.default_static_qconfig | ||
# Alternatively, define your own qconfig: | ||
# from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig | ||
# qconfig = QConfig(activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8), | ||
# weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)) | ||
prepared_model_static = prepare(model_fp32, qconfig_static, example_inputs=data, inplace=False) | ||
print("Calibration with Static Quantization ...") | ||
for batch_idx, (data, target) in enumerate(calibration_data_loader): | ||
prepared_model_static(data) | ||
if batch_idx % 10 == 0: | ||
print("Batch %d/%d complete, continue ..." %(batch_idx+1, len(calibration_data_loader))) | ||
print("Calibration Done") | ||
|
||
converted_model_static = convert(prepared_model_static) | ||
with torch.no_grad(): | ||
traced_model_static = torch.jit.trace(converted_model_static, data) | ||
traced_model_static = torch.jit.freeze(traced_model_static) | ||
|
||
traced_model_static.save("quantized_model_static.pt") | ||
return traced_model_static | ||
|
||
def dynamicQuantize(model_fp32, data): | ||
# Acquire inference times for dynamic quantization INT8 model | ||
qconfig_dynamic = ipex.quantization.default_dynamic_qconfig | ||
print("Quantize Model with Dynamic Quantization ...") | ||
|
||
prepared_model_dynamic = prepare(model_fp32, qconfig_dynamic, example_inputs=data, inplace=False) | ||
|
||
converted_model_dynamic = convert(prepared_model_dynamic) | ||
with torch.no_grad(): | ||
traced_model_dynamic = torch.jit.trace(converted_model_dynamic, data) | ||
traced_model_dynamic = torch.jit.freeze(traced_model_dynamic) | ||
|
||
# save the quantized static model | ||
traced_model_dynamic.save("quantized_model_dynamic.pt") | ||
return traced_model_dynamic | ||
|
||
|
||
""" | ||
Perform all types of training in main function | ||
""" | ||
def main(): | ||
|
||
# Load dataset | ||
transform = torchvision.transforms.Compose([ | ||
torchvision.transforms.Resize((224, 224)), | ||
torchvision.transforms.ToTensor(), | ||
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) | ||
]) | ||
test_dataset = torchvision.datasets.CIFAR10( | ||
root=DATA, | ||
train=False, | ||
transform=transform, | ||
download=DOWNLOAD, | ||
) | ||
calibration_data_loader = torch.utils.data.DataLoader( | ||
dataset=test_dataset, | ||
batch_size=128 | ||
) | ||
|
||
data = torch.rand(1, 3, 224, 224) | ||
model_fp32 = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT) | ||
model_fp32.eval() | ||
|
||
if not os.path.exists('quantized_model_static.pt'): | ||
# Static Quantizaton & Save Model to quantized_model_static.pt | ||
print('quantize the model with static quantization') | ||
staticQuantize(model_fp32, data, calibration_data_loader) | ||
|
||
if not os.path.exists('quantized_model_dynamic.pt'): | ||
# Dynamic Quantization & Save Model to quantized_model_dynamic.pt | ||
print('quantize the model with dynamic quantization') | ||
dynamicQuantize(model_fp32, data) | ||
|
||
print("Inference with FP32") | ||
fp32_inference_time = inference(model_fp32, data) | ||
|
||
traced_model_static = torch.jit.load('quantized_model_static.pt') | ||
traced_model_static.eval() | ||
traced_model_static = torch.jit.freeze(traced_model_static) | ||
print("Inference with Static INT8") | ||
int8_inference_time_static = inference(traced_model_static, data) | ||
|
||
traced_model_dynamic = torch.jit.load('quantized_model_dynamic.pt') | ||
traced_model_dynamic.eval() | ||
traced_model_dynamic = torch.jit.freeze(traced_model_dynamic) | ||
print("Inference with Dynamic INT8") | ||
int8_inference_time_dynamic = inference(traced_model_dynamic, data) | ||
|
||
# Inference time results | ||
print("Summary") | ||
print("FP32 inference time: %.3f" %fp32_inference_time) | ||
print("INT8 static quantization inference time: %.3f" %int8_inference_time_static) | ||
print("INT8 dynamic quantization inference time: %.3f" %int8_inference_time_dynamic) | ||
|
||
# Calculate speedup when using quantization | ||
speedup_from_fp32_static = fp32_inference_time / int8_inference_time_static | ||
print("Staic INT8 %.2fX faster than FP32" %speedup_from_fp32_static) | ||
speedup_from_fp32_dynamic = fp32_inference_time / int8_inference_time_dynamic | ||
print("Dynamic INT8 %.2fX faster than FP32" %speedup_from_fp32_dynamic) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() | ||
print('[CODE_SAMPLE_COMPLETED_SUCCESFULLY]') |
7 changes: 7 additions & 0 deletions
7
AI-and-Analytics/Features-and-Functionality/IntelPytorch_Quantization/License.txt
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Copyright Intel Corporation | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
159 changes: 159 additions & 0 deletions
159
AI-and-Analytics/Features-and-Functionality/IntelPytorch_Quantization/README.md
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
# `Optimize PyTorch Models using Intel® Extension for PyTorch* Quantization` Sample | ||
|
||
The `Optimize PyTorch Models using Intel® Extension for PyTorch* Quantization` sample demonstrates how to quantize a ResNet50 model that is calibrated by the CIFAR10 dataset using the Intel® Extension for PyTorch*. | ||
|
||
The Intel® Extension for PyTorch* extends PyTorch* with optimizations for extra performance boost on Intel® hardware. While most of the optimizations will be included in future PyTorch* releases, the extension delivers up-to-date features and optimizations for PyTorch on Intel® hardware. For example, newer optimizations include AVX-512 Vector Neural Network Instructions (AVX512 VNNI) and Intel® Advanced Matrix Extensions (Intel® AMX). | ||
|
||
| Area | Description | ||
|:--- |:--- | ||
| What you will learn | Inference performance improvements using Intel® Extension for PyTorch* (IPEX) with feature quantization | ||
| Time to complete | 5 minutes | ||
| Category | Concepts and Functionality | ||
|
||
## Purpose | ||
|
||
The Intel® Extension for PyTorch* gives users the ability to speed up inference on Intel® Xeon Scalable processors with INT8 data format and specialized computer instructions. The INT8 data format uses quarter the bit width of floating-point-32 (FP32), lowering the amount of memory needed and execution time to process. | ||
|
||
## Prerequisites | ||
|
||
| Optimized for | Description | ||
|:--- |:--- | ||
| OS | Ubuntu* 18.04 or newer | ||
| Hardware | Intel® Xeon® Scalable Processor family | ||
| Software | Intel® Extension for PyTorch* | ||
|
||
### For Local Development Environments | ||
|
||
You will need to download and install the following toolkits, tools, and components to use the sample. | ||
|
||
- **Intel® AI Analytics Toolkit (AI Kit)** | ||
|
||
You can get the AI Kit from [Intel® oneAPI Toolkits](https://www.intel.com/content/www/us/en/developer/tools/oneapi/toolkits.html#analytics-kit). <br> See [*Get Started with the Intel® AI Analytics Toolkit for Linux**](https://www.intel.com/content/www/us/en/develop/documentation/get-started-with-ai-linux) for AI Kit installation information and post-installation steps and scripts. | ||
|
||
- **Jupyter Notebook** | ||
|
||
Install using PIP: `$pip install notebook`. <br> Alternatively, see [*Installing Jupyter*](https://jupyter.org/install) for detailed installation instructions. | ||
|
||
- **Additional Packages** | ||
|
||
You will need to install these additional packages: **Matplotlib** and **tqdm**. | ||
``` | ||
python -m pip install matplotlib | ||
python -m pip install tqdm | ||
``` | ||
|
||
### For Intel® DevCloud | ||
|
||
The necessary tools and components are already installed in the environment. You do not need to install additional components. See *[Intel® DevCloud for oneAPI](https://DevCloud.intel.com/oneapi/get_started/)* for information. | ||
|
||
## Key Implementation Details | ||
|
||
This code sample quantizes a ResNet50 model that is calibrated with the CIFAR10 dataset while using Intel® Extension for PyTorch*. The model is inferenced using FP32 and INT8 precision, including the use of Intel® Advanced Matrix Extensions (AMX). AMX is supported on BF16 and INT8 data types starting with 4th Gen Xeon Scalable Processors. The inference time will be compared, showcasing the speedup of INT8. | ||
|
||
The sample tutorial contains one Jupyter Notebook and a Python script. You can use either. | ||
|
||
### Jupyter Notebook | ||
|
||
| Notebook | Description | ||
|:--- |:--- | ||
|`IntelPyTorch_Quantization.ipynb` | Performs inference with IPEX quantization in Jupyter Notebook. | ||
|
||
### Python Scripts | ||
|
||
| Script | Description | ||
|:--- |:--- | ||
|`IntelPyTorch_Quantization.py` | The script performs inference with IPEX quantization and compares the performance against the baseline | ||
|
||
## Set Environment Variables | ||
|
||
When working with the command-line interface (CLI), you should configure the oneAPI toolkits using environment variables. Set up your CLI environment by sourcing the `setvars` script every time you open a new terminal window. This practice ensures that your compiler, libraries, and tools are ready for development. | ||
|
||
## Run the `Optimize PyTorch Models using Intel® Extension for PyTorch* Quantization` Sample | ||
|
||
### On Linux* | ||
|
||
> **Note**: If you have not already done so, set up your CLI | ||
> environment by sourcing the `setvars` script in the root of your oneAPI installation. | ||
> | ||
> Linux*: | ||
> - For system wide installations: `. /opt/intel/oneapi/setvars.sh` | ||
> - For private installations: ` . ~/intel/oneapi/setvars.sh` | ||
> - For non-POSIX shells, like csh, use the following command: `bash -c 'source <install-dir>/setvars.sh ; exec csh'` | ||
> | ||
> For more information on configuring environment variables, see *[Use the setvars Script with Linux* or macOS*](https://www.intel.com/content/www/us/en/develop/documentation/oneapi-programming-guide/top/oneapi-development-environment-setup/use-the-setvars-script-with-linux-or-macos.html)*. | ||
#### Activate Conda | ||
|
||
1. Activate the Conda environment. | ||
``` | ||
conda activate pytorch | ||
``` | ||
2. Activate Conda environment without Root access (Optional). | ||
|
||
By default, the AI Kit is installed in the `/opt/intel/oneapi` folder and requires root privileges to manage it. | ||
|
||
You can choose to activate Conda environment without root access. To bypass root access to manage your Conda environment, clone and activate your desired Conda environment using the following commands similar to the following. | ||
|
||
``` | ||
conda create --name user_pytorch --clone pytorch | ||
conda activate user_pytorch | ||
``` | ||
|
||
#### Running the Jupyter Notebook | ||
|
||
1. Change to the sample directory. | ||
2. Launch Jupyter Notebook. | ||
``` | ||
jupyter notebook --ip=0.0.0.0 --port 8888 --allow-root | ||
``` | ||
3. Follow the instructions to open the URL with the token in your browser. | ||
4. Locate and select the Notebook. | ||
``` | ||
IntelPyTorch_Quantization.ipynb | ||
``` | ||
5. Change your Jupyter Notebook kernel to **PyTorch (AI kit)**. | ||
6. Run every cell in the Notebook in sequence. | ||
|
||
#### Running on the Command Line (Optional) | ||
|
||
1. Change to the sample directory. | ||
2. Run the script. | ||
``` | ||
python IntelPyTorch_Quantization.py | ||
``` | ||
|
||
### Run the `Optimize PyTorch Models using Intel® Extension for PyTorch* Quantization` Sample on Intel® DevCloud | ||
|
||
1. If you do not already have an account, request an Intel® DevCloud account at [*Create an Intel® DevCloud Account*](https://intelsoftwaresites.secure.force.com/DevCloud/oneapi). | ||
2. On a Linux* system, open a terminal. | ||
3. SSH into Intel® DevCloud. | ||
``` | ||
ssh DevCloud | ||
``` | ||
> **Note**: You can find information about configuring your Linux system and connecting to Intel DevCloud at Intel® DevCloud for oneAPI [Get Started](https://DevCloud.intel.com/oneapi/get_started). | ||
|
||
4. Follow the instructions to open the URL with the token in your browser. | ||
5. Locate and select the Notebook. | ||
``` | ||
IntelPyTorch_Quantization.ipynb | ||
```` | ||
6. Change the kernel to **PyTorch (AI kit)**. | ||
7. Run every cell in the Notebook in sequence. | ||
|
||
### Troubleshooting | ||
|
||
If you receive an error message, troubleshoot the problem using the **Diagnostics Utility for Intel® oneAPI Toolkits**. The diagnostic utility provides configuration and system checks to help find missing dependencies, permissions errors, and other issues. See the *[Diagnostics Utility for Intel® oneAPI Toolkits User Guide](https://www.intel.com/content/www/us/en/develop/documentation/diagnostic-utility-user-guide/top.html)* for more information on using the utility. | ||
|
||
## Example Output | ||
|
||
If successful, the sample displays `[CODE_SAMPLE_COMPLETED_SUCCESSFULLY]`. Additionally, the sample generates performance and analysis diagrams for comparison. | ||
|
||
The following image shows approximate performance speed increases using AMX INT8 during inference. | ||
|
||
 | ||
|
||
## License | ||
|
||
Code samples are licensed under the MIT license. See | ||
[License.txt](https://github.com/oneapi-src/oneAPI-samples/blob/master/License.txt) for details. | ||
|
||
Third party program Licenses can be found here: [third-party-programs.txt](https://github.com/oneapi-src/oneAPI-samples/blob/master/third-party-programs.txt) |
Binary file added
BIN
+28.8 KB
...eatures-and-Functionality/IntelPytorch_Quantization/assets/relative_speedup.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 2 additions & 0 deletions
2
AI-and-Analytics/Features-and-Functionality/IntelPytorch_Quantization/requirements.txt
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
tqdm | ||
matplotlib |
29 changes: 29 additions & 0 deletions
29
AI-and-Analytics/Features-and-Functionality/IntelPytorch_Quantization/sample.json
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
{ | ||
"guid": "54B3F469-1E58-400B-9FCA-8BE08F680DCA", | ||
"name": "Optimize PyTorch Models using Intel® Extension for PyTorch* (IPEX) Quantization", | ||
"categories": ["Toolkit/oneAPI AI And Analytics/Features And Functionality"], | ||
"description": "Applying IPEX Quantization Optimizations to a PyTorch workload in a step-by-step manner to gain performance boost in inference.", | ||
"builder": ["cli"], | ||
"languages": [{ | ||
"python": {} | ||
}], | ||
"os": ["linux"], | ||
"targetDevice": ["CPU"], | ||
"ciTests": { | ||
"linux": [{ | ||
"env": ["source ${ONEAPI_ROOT}/setvars.sh --force", | ||
"conda env remove -n user_pytorch", | ||
"conda create --name user_pytorch --clone pytorch", | ||
"conda activate user_pytorch", | ||
"pip install -r requirements.txt", | ||
"~/.conda/envs/user_pytorch/bin/python -m ipykernel install --user --name=user_pytorch" | ||
], | ||
"id": "ipex_inference_optimization", | ||
"steps": [ | ||
"source activate user_pytorch", | ||
"python IntelPytorch_Quantization.py" | ||
] | ||
}] | ||
}, | ||
"expertise": "Concepts and Functionality" | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ZhaoqiongZ For name, do you want to use: "IPEX Static and Dynamic Quantization" or keep the same name?
If latter, please update README.md.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi Jimmy, the one john changed is good, so I will just keep it.