
TF AMX BF16 Training 2023.1 AI Kit #1320

Merged
@@ -0,0 +1,397 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "d4aec8e8-84cc-41f4-8ca1-916e9f28fc36",
"metadata": {},
"source": [
"# Intel TensorFlow AMX BF16 Training\n",
"This code sample will train a DistilBERT model while using Intel Optimized TensorFlow. The model will be trained using FP32 and BF16 precision, including the use of Intel(R) Advanced Matrix Extensions (AMX) on BF16. AMX is supported on BF16 data type starting with the 4th Generation of Xeon Scalable Processors. The training time will be compared, showcasing the speedup of AMX."
]
},
{
"cell_type": "markdown",
"id": "ec0f38f9-6165-4ff4-b0a9-6a2fabdd06bb",
"metadata": {},
"source": [
"## Environment Setup\n",
"Ensure the TensorFlow kernel is activated before running this notebook."
]
},
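{
"cell_type": "markdown",
"id": "env-sanity-check-note",
"metadata": {},
"source": [
"As a quick sanity check (a minimal sketch; the exact version string depends on your installation), you can confirm that the active kernel provides the expected Intel Optimized TensorFlow build:\n",
"```python\n",
"import tensorflow as tf\n",
"print(tf.__version__)  # an Intel Optimized TensorFlow 2.x build is expected here\n",
"```"
]
},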
{
"cell_type": "markdown",
"id": "0ac7dfbc-ab47-4c06-ab82-b180449fcea1",
"metadata": {},
"source": [
"# Imports, Dataset, Hyperparameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ff40939-c380-49fe-af59-9650e0dfaa75",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"import numpy as np\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras.layers import Dense, Input\n",
"from tensorflow.keras.optimizers import Adam\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.callbacks import ModelCheckpoint\n",
"import transformers\n",
"import time\n",
"import matplotlib.pyplot as plt\n",
"import argparse"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc9d8241-ab47-49a2-b563-8cbe57a562f1",
"metadata": {},
"outputs": [],
"source": [
"is_tune_model = True # or False\n",
"log_dir = \"logs\"\n",
"profiling_needed = False\n",
"execution_mode = \"graph\"\n",
"load_weights_dir = \"weights\"\n",
"save_weights_dir = \"weights\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "81a5c73e-b4e5-4023-bca7-bd08386e5521",
"metadata": {},
"outputs": [],
"source": [
"if execution_mode == \"graph\":\n",
" tf.compat.v1.disable_eager_execution()"
]
},
{
"cell_type": "markdown",
"id": "23c380c1-85ab-4375-92fe-c2be1ddd538c",
"metadata": {},
"source": [
"# Identify Supported ISA\n",
"We identify the underlying supported ISA to determine whether AMX is supported. You must use a 4th Gen Intel® Xeon® Scalable Processor or newer must to run this sample."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "33320ac8-2ac4-453d-b800-75b0086f98c3",
"metadata": {},
"outputs": [],
"source": [
"# Check if hardware supports AMX\n",
"\n",
"from cpuinfo import get_cpu_info\n",
"info = get_cpu_info()\n",
"flags = info['flags']\n",
"amx_supported = False\n",
"for flag in flags:\n",
" if \"amx\" in flag:\n",
" amx_supported = True\n",
" print(\"AMX is supported on current hardware. Code sample can be run.\\n\")\n",
"if not amx_supported:\n",
" print(\"AMX is not supported on current hardware. Code sample cannot be run.\\n\")\n",
" sys.exit(\"AMX is not supported on current hardware. Code sample cannot be run.\\n\")\n"
]
},
{
"cell_type": "markdown",
"id": "2d65438f-a695-493d-8605-5f678f11ac99",
"metadata": {},
"source": [
"If the message \"AMX is not supported on current hardware. Code sample cannot be run.\" is printed above, the hardware being used does not support AMX. Therefore, this code sample cannot proceed"
]
},
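{
"cell_type": "markdown",
"id": "onednn-verbose-note",
"metadata": {},
"source": [
"Optionally, you can also verify at runtime that AMX kernels are actually dispatched by turning on oneDNN verbose logging before running any training cell (a sketch; the output is large and is written to stdout):\n",
"```python\n",
"import os\n",
"os.environ[\"ONEDNN_VERBOSE\"] = \"1\"  # log each oneDNN primitive; AMX kernels typically show up as avx512_core_amx entries\n",
"```"
]
},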
{
"cell_type": "markdown",
"id": "924c0e12-7c52-45f3-b650-f5bcf759897f",
"metadata": {},
"source": [
"# Build the Model\n",
"The functions below will build up the DistilBERT model based on the whether AMX should be enabled, and whether to use FP32 or BF16 data type. The environment variable ONEDNN_MAX_CPU_ISA is used to enable or disable AMX. For more information, refer to the oneDNN documentation on CPU Dispatcher Control. To use BF16 in operations, use the tf.keras.mixed_precision.set_global_policy('mixed_bfloat16') function."
]
},
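{
"cell_type": "markdown",
"id": "mixed-precision-policy-note",
"metadata": {},
"source": [
"For reference, the Keras mixed precision API mentioned above is invoked as shown below (illustrative only; the BF16 cells in this notebook instead enable oneDNN's graph-level auto mixed precision through tf.config.optimizer.set_experimental_options):\n",
"```python\n",
"from tensorflow.keras import mixed_precision\n",
"mixed_precision.set_global_policy('mixed_bfloat16')  # layers compute in BF16 while variables stay in FP32\n",
"```"
]
},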
{
"cell_type": "code",
"execution_count": null,
"id": "3f121ef6-79cb-4975-85a1-6cd7a2ec3c03",
"metadata": {},
"outputs": [],
"source": [
"def bert_encode(texts, tokenizer, max_len=512):\n",
" all_tokens = []\n",
" \n",
" for text in texts:\n",
" text = tokenizer.tokenize(text)\n",
" \n",
" text = text[:max_len-2]\n",
" input_sequence = [\"[CLS]\"] + text + [\"[SEP]\"]\n",
" pad_len = max_len - len(input_sequence)\n",
" \n",
" tokens = tokenizer.convert_tokens_to_ids(input_sequence)\n",
" tokens += [0] * pad_len\n",
" pad_masks = [1] * len(input_sequence) + [0] * pad_len\n",
" segment_ids = [0] * max_len\n",
" \n",
" all_tokens.append(tokens)\n",
" \n",
" return np.array(all_tokens)\n",
" \n",
"def build_model(transformer, max_len=512):\n",
" input_word_ids = Input(shape=(max_len,), dtype=tf.int32, name=\"input_word_ids\")\n",
" sequence_output = transformer(input_word_ids)[0]\n",
" cls_token = sequence_output[:, 0, :]\n",
" out = Dense(1, activation='sigmoid')(cls_token)\n",
" \n",
" model = Model(inputs=input_word_ids, outputs=out)\n",
" model.compile(Adam(lr=1e-5), loss='binary_crossentropy', metrics=['accuracy'])\n",
" \n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "de61f744-d2af-403f-b948-154431b4624e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"train = pd.read_csv(\"data/train.csv\")\n",
"test = pd.read_csv(\"data/test.csv\")\n",
"classified_results = pd.read_csv(\"data/sample_submission.csv\")\n",
"\n",
"# load distilbert uncased pre-trained model and corresponding tokenizer from hugging face\n",
"transformer_layer = transformers.TFDistilBertModel.from_pretrained('distilbert-base-uncased')\n",
"tokenizer = transformers.DistilBertTokenizer.from_pretrained('distilbert-base-uncased')"
]
},
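{
"cell_type": "markdown",
"id": "dataset-columns-note",
"metadata": {},
"source": [
"The training cells below assume train.csv from the disaster-tweets dataset provides a text column and a binary target column; a quick optional peek confirms the layout:\n",
"```python\n",
"print(train[[\"text\", \"target\"]].head())  # the two columns consumed by bert_encode and model.fit below\n",
"```"
]
},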
{
"cell_type": "markdown",
"id": "2f1a474d-7070-46f1-8d3e-569f656d3695",
"metadata": {
"tags": []
},
"source": [
"# Training with FP32 and BF16, including AMX\n",
"Train the DistilBERT model in three different cases:\n",
"\n",
"1. FP32 (baseline)\n",
"2. BF16 without AMX\n",
"3. BF16 with AMX\n",
"\n",
"The training time is recorded"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8922b858-baf7-4439-af02-bc23c50d93e1",
"metadata": {},
"outputs": [],
"source": [
"# FP32 (baseline)\n",
"# build model\n",
"model = build_model(transformer_layer, max_len=160)\n",
"\n",
"# fine tune model according to disaster tweets dataset\n",
"if is_tune_model:\n",
" train_input = bert_encode(train.text.values, tokenizer, max_len=160)\n",
" train_labels = train.target.values\n",
" start_time = time.time()\n",
" train_history = model.fit(train_input, train_labels, validation_split=0.2, epochs=1, batch_size=16)\n",
" end_time = time.time()\n",
" # save model weights so we don't have to fine tune it every time\n",
" os.makedirs(save_weights_dir, exist_ok=True)\n",
" model.save_weights(save_weights_dir + \"/model_weights.h5\")\n",
"\n",
"else:\n",
" try:\n",
" model.load_weights(load_weights_dir + \"/model_weights.h5\")\n",
" except FileNotFoundError:\n",
" sys.exit(\"\\n\\nTuned model weights not available. Tune model first by setting parameter -t=True\")\n",
"\n",
"fp32_training_time = end_time-start_time\n",
"print(\"Training model with FP32\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9d3bdf2-2748-4c05-bb0e-ebfdbe0d01ad",
"metadata": {},
"outputs": [],
"source": [
"# BF16 without AMX\n",
"os.environ[\"ONEDNN_MAX_CPU_ISA\"] = \"AVX512_BF16\"\n",
"tf.config.optimizer.set_experimental_options({'auto_mixed_precision_onednn_bfloat16':True})\n",
"\n",
"transformer_layer = transformers.TFDistilBertModel.from_pretrained('distilbert-base-uncased')\n",
"tokenizer = transformers.DistilBertTokenizer.from_pretrained('distilbert-base-uncased')\n",
"model = build_model(transformer_layer, max_len=160)\n",
"\n",
"# fine tune model according to disaster tweets dataset\n",
"if is_tune_model:\n",
" train_input = bert_encode(train.text.values, tokenizer, max_len=160)\n",
" train_labels = train.target.values\n",
" start_time = time.time()\n",
" train_history = model.fit(train_input, train_labels, validation_split=0.2, epochs=1, batch_size=16)\n",
" end_time = time.time()\n",
" # save model weights so we don't have to fine tune it every time\n",
" os.makedirs(save_weights_dir, exist_ok=True)\n",
" model.save_weights(save_weights_dir + \"/bf16_model_weights.h5\")\n",
"\n",
"else:\n",
" try:\n",
" model.load_weights(load_weights_dir + \"/bf16_model_weights.h5\")\n",
" except FileNotFoundError:\n",
" sys.exit(\"\\n\\nTuned model weights not available. Tune model first by setting parameter -t=True\")\n",
"\n",
"bf16_noAmx_training_time = end_time-start_time\n",
"print(\"Training model with BF16 without AMX\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f00e926f-ebe3-4e5f-aa47-27f8836de067",
"metadata": {},
"outputs": [],
"source": [
"# BF16 with AMX\n",
"os.environ[\"ONEDNN_MAX_CPU_ISA\"] = \"AMX_BF16\"\n",
"\n",
"transformer_layer = transformers.TFDistilBertModel.from_pretrained('distilbert-base-uncased')\n",
"tokenizer = transformers.DistilBertTokenizer.from_pretrained('distilbert-base-uncased')\n",
"model = build_model(transformer_layer, max_len=160)\n",
"\n",
"# fine tune model according to disaster tweets dataset\n",
"if is_tune_model:\n",
" train_input = bert_encode(train.text.values, tokenizer, max_len=160)\n",
" train_labels = train.target.values\n",
" start_time = time.time()\n",
" train_history = model.fit(train_input, train_labels, validation_split=0.2, epochs=1, batch_size=16)\n",
" end_time = time.time()\n",
" # save model weights so we don't have to fine tune it every time\n",
" os.makedirs(save_weights_dir, exist_ok=True)\n",
" model.save_weights(save_weights_dir + \"/AMX_bf16_model_weights.h5\")\n",
"\n",
"else:\n",
" try:\n",
" model.load_weights(load_weights_dir + \"/AMX_bf16_model_weights.h5\")\n",
" except FileNotFoundError:\n",
" sys.exit(\"\\n\\nTuned model weights not available. Tune model first by setting parameter -t=True\")\n",
"\n",
"bf16_withAmx_training_time = end_time-start_time\n",
"print(\"Training model with BF16 with AMX\")"
]
},
{
"cell_type": "markdown",
"id": "3db44c05-8ddf-4fbb-93e5-4af67277ae82",
"metadata": {},
"source": [
"# Summary of Results\n",
"The following cells below will summarize the training time for all three cases and display graphs to show the performance speedup."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9ca0d726-b133-41f0-93d7-5381488a20e4",
"metadata": {},
"outputs": [],
"source": [
"print(\"Summary\")\n",
"print(\"FP32 training time: %.3f\" %fp32_training_time)\n",
"print(\"BF16 without AMX training time: %.3f\" %bf16_noAmx_training_time)\n",
"print(\"BF16 with AMX training time: %.3f\" %bf16_withAmx_training_time)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1d01731-74f6-4049-b78f-f745d1176ec7",
"metadata": {},
"outputs": [],
"source": [
"plt.figure()\n",
"plt.title(\"DistilBERT Training Time\")\n",
"plt.xlabel(\"Test Case\")\n",
"plt.ylabel(\"Training Time (seconds)\")\n",
"plt.bar([\"FP32\", \"BF16 no AMX\", \"BF16 with AMX\"], [fp32_training_time, bf16_noAmx_training_time, bf16_withAmx_training_time])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1340b8cc-4832-4c0b-a399-cf4c820a6d3a",
"metadata": {},
"outputs": [],
"source": [
"speedup_from_fp32 = fp32_training_time / bf16_withAmx_training_time\n",
"print(\"BF16 with AMX is %.2fX faster than FP32\" %speedup_from_fp32)\n",
"speedup_from_bf16 = bf16_noAmx_training_time / bf16_withAmx_training_time\n",
"print(\"BF16 with AMX is %.2fX faster than BF16 without AMX\" %speedup_from_bf16)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4814af2f-7a4a-4e4c-8b0a-a8a97f633239",
"metadata": {},
"outputs": [],
"source": [
"plt.figure()\n",
"plt.title(\"AMX Speedup\")\n",
"plt.xlabel(\"Test Case\")\n",
"plt.ylabel(\"Speedup\")\n",
"plt.bar([\"FP32\", \"BF16 no AMX\"], [speedup_from_fp32, speedup_from_bf16])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b5e3b1f-dbc2-467d-a62d-dda2211b2012",
"metadata": {},
"outputs": [],
"source": [
"print('[CODE_SAMPLE_COMPLETED_SUCCESFULLY]')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "tf_amx_bf16_2",
"language": "python",
"name": "tf_amx_bf16_2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}