
Commit 341f318

Updates torchao pin to enable shared embedding quantization (pytorch#9548)
Updates torchao pin to enable shared embedding quantization.
1 parent 94ec549 commit 341f318

4 files changed: +67 -25 lines

examples/models/llama/README.md

Lines changed: 4 additions & 0 deletions
@@ -412,6 +412,10 @@ python -m examples.models.llama.export_llama \
     -d fp32
 ```

+A few notes:
+- If your model shares embedding/unembedding weights (as Llama1B and Llama3B do), you can add `--use_shared_embedding` to take advantage of this and reduce memory. When this option is enabled, a third argument controls whether embeddings are quantized with weight zeros. For example, `-E "torchao:4,32,true"` quantizes the embedding to 4 bits with group_size=32 and uses weight zeros (this is also the default if you simply use `-E "torchao:4,32"`), whereas `-E "torchao:4,32,false"` quantizes to 4 bits with group_size=32 using scales only. If `--use_shared_embedding` is specified, the unembedding (i.e., the final linear layer) is quantized in the same way, but additionally uses 8-bit dynamically quantized activations.
+- To do channelwise quantization, set group_size to 0. This works for both linear and embedding layers.
+
 Once the model is exported, we need to build ExecuTorch and the runner with the low-bit kernels.

 The first step is to install ExecuTorch (the same as step 3.1 above):
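
As a quick illustration of the `-E "torchao:<bitwidth>,<groupsize>[,<has_weight_zeros>]"` format described in the notes above, here is a small standalone sketch (not part of the commit) of how the two- and three-field forms map to settings; the parsing mirrors the change to quantize.py further down.

```python
# Sketch only: mirrors the "torchao:<bitwidth>,<groupsize>[,<has_weight_zeros>]" spec.
def parse_embedding_spec(spec: str):
    assert spec.startswith("torchao:")
    fields = spec.split(":")[1].split(",")
    if len(fields) == 2:
        bitwidth, group_size = fields
        has_weight_zeros = "true"  # default when the third field is omitted
    else:
        bitwidth, group_size, has_weight_zeros = fields
    return int(bitwidth), int(group_size), has_weight_zeros.lower() == "true"

print(parse_embedding_spec("torchao:4,32"))        # (4, 32, True)  -> weight zeros (default)
print(parse_embedding_spec("torchao:4,32,false"))  # (4, 32, False) -> scales only
print(parse_embedding_spec("torchao:4,0"))         # (4, 0, True)   -> group_size 0 = channelwise
```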

examples/models/llama/export_llama_lib.py

Lines changed: 29 additions & 13 deletions
@@ -155,6 +155,11 @@ def build_args_parser() -> argparse.ArgumentParser:
         type=str,
         help="type of embedding quantization, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
     )
+    parser.add_argument(
+        "--use_shared_embedding",
+        action="store_true",
+        help="Whether the embedding/unembedding weights should be shared. Only available with torchao kernels.",
+    )
     parser.add_argument(
         "--pt2e_quantize",
         default=None,
@@ -684,6 +689,15 @@ def _validate_args(args):
     if args.num_sharding > 0 and not args.qnn:
         raise ValueError("Model shard is only supported with qnn backend now.")

+    if args.use_shared_embedding:
+        if not (
+            args.embedding_quantize is not None
+            and args.embedding_quantize.startswith("torchao:")
+        ):
+            raise ValueError(
+                "Shared embedding is only supported with torchao quantization."
+            )
+
     if (
         args.quantization_mode is not None
         and args.quantization_mode.startswith("torchao:")
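
To make the new constraint concrete, a small standalone check (not the repo's own code) showing which flag combinations the added validation accepts; the helper name is hypothetical.

```python
def shared_embedding_allowed(embedding_quantize, use_shared_embedding=True):
    # Mirrors the validation above: --use_shared_embedding requires a "torchao:..." -E spec.
    if not use_shared_embedding:
        return True
    return embedding_quantize is not None and embedding_quantize.startswith("torchao:")

assert shared_embedding_allowed("torchao:4,32")   # OK: torchao-quantized embedding
assert not shared_embedding_allowed("8,1024")     # rejected: non-torchao -E spec
assert not shared_embedding_allowed(None)         # rejected: -E not given at all
```
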
@@ -1122,6 +1136,21 @@ def _get_source_transforms( # noqa

         transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)

+    if args.embedding_quantize:
+        """
+        When this option is selected, it finds all embedding layers and transforms
+        them into the equivalent quantized embedding modules.
+
+        There are cases where the checkpoint is already quantized, for example
+        when use_spin_quant is enabled. In that case, it will do the appropriate
+        transformations based on the given checkpoint first. In those cases,
+        this will be a no-op.
+        """
+        modelname = f"{modelname}_e"
+        transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))
+
+    # quantization_mode should be applied after embedding_quantize
+    # to support shared_embedding
     if args.quantization_mode:
         """
         When this option is selected, it finds all linear layers and transforms
@@ -1145,19 +1174,6 @@ def _get_source_transforms( # noqa
             )
         )

-    if args.embedding_quantize:
-        """
-        When this option is selected, it finds all embedding layers and transforms
-        them into the equivalent quantized embedding modules.
-
-        There are cases where the checkpoint is already quantized, for example
-        when use_spin_quant is enabled. In that case, it will do the appropriate
-        transformations based on the given checkpoint first. In those cases,
-        this will be a no-op.
-        """
-        modelname = f"{modelname}_e"
-        transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))
-
     if args.expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)
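
The reordering above is what the new comment refers to: the embedding transform is appended, and therefore applied, before the linear-quantization transform, because with `--use_shared_embedding` the embedding quantizer also takes over the tied unembedding (final linear) layer. A minimal sketch of that ordering, with hypothetical transform names standing in for the real ones:

```python
from typing import Callable, List

import torch.nn as nn

# Hypothetical stand-ins for the real source transforms; only the ordering matters here.
def embedding_transform(model: nn.Module) -> nn.Module:
    # With --use_shared_embedding, this also rewrites the tied unembedding linear.
    return model

def linear_transform(model: nn.Module) -> nn.Module:
    # Generic linear quantization; runs second so it never touches the shared layer first.
    return model

transforms: List[Callable[[nn.Module], nn.Module]] = [embedding_transform, linear_transform]

model: nn.Module = nn.Identity()  # placeholder for the loaded eager model
for t in transforms:  # applied in list order, as in _get_source_transforms
    model = t(model)
```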

examples/models/llama/source_transformation/quantize.py

Lines changed: 33 additions & 11 deletions
@@ -124,9 +124,7 @@ def quantize( # noqa C901
             model,
             Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=getattr(torch, f"int{bitwidth}"),
-                granularity=(
-                    PerRow() if group_size in [0, -1] else PerGroup(group_size)
-                ),
+                granularity=(PerRow() if group_size == 0 else PerGroup(group_size)),
                 has_weight_zeros=False,
             ),
         )
@@ -786,19 +784,43 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:

 def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
     if args.embedding_quantize.startswith("torchao:"):
-        bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
+        from torchao.experimental.quant_api import (
+            EmbeddingQuantizer,
+            SharedEmbeddingQuantizer,
+        )
+        from torchao.quantization.granularity import PerGroup, PerRow
+
+        quant_args = args.embedding_quantize.split(":")[1].split(",")
+        if len(quant_args) == 2:
+            bitwidth, group_size = quant_args
+            has_weight_zeros = True
+        else:
+            bitwidth, group_size, has_weight_zeros = quant_args
+
+        if group_size in ["none", "None", "0"]:
+            group_size = 0
+
         group_size = int(group_size)
         bitwidth = int(bitwidth)
-        from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
+        has_weight_zeros = bool(has_weight_zeros)
+        weight_dtype = getattr(torch, f"int{bitwidth}")
+        granularity = PerRow() if group_size == 0 else PerGroup(group_size)

         def _torchao_embedding_quantizer(model):
             with torch.no_grad():
-                model = IntxWeightEmbeddingQuantizer(
-                    device="cpu",
-                    precision=torch.float32,
-                    bitwidth=bitwidth,
-                    groupsize=group_size,
-                ).quantize(model)
+                if not args.use_shared_embedding:
+                    EmbeddingQuantizer(
+                        weight_dtype=weight_dtype,
+                        granularity=granularity,
+                        has_weight_zeros=has_weight_zeros,
+                        use_fallback=False,
+                    ).quantize(model)
+                else:
+                    SharedEmbeddingQuantizer(
+                        weight_dtype=weight_dtype,
+                        granularity=granularity,
+                        has_weight_zeros=has_weight_zeros,
+                    ).quantize(model)
             return model

     return _torchao_embedding_quantizer
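
For context, a minimal sketch of how the two torchao quantizers used above might be applied to a toy model with tied embedding/unembedding weights. The quantizer names and constructor arguments come from the diff; the toy module (`TinyTiedLM`, its layer names, and the 4-bit / group_size=32 settings) are illustrative assumptions, not the actual export flow.

```python
import torch
import torch.nn as nn
from torchao.experimental.quant_api import EmbeddingQuantizer, SharedEmbeddingQuantizer
from torchao.quantization.granularity import PerGroup


class TinyTiedLM(nn.Module):
    """Toy model whose unembedding (output projection) shares the embedding weight."""

    def __init__(self, vocab: int = 128, dim: int = 64):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab, dim)
        self.output = nn.Linear(dim, vocab, bias=False)
        self.output.weight = self.tok_embeddings.weight  # weight tying, as in Llama1B/3B

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        return self.output(self.tok_embeddings(idx))


# Roughly equivalent to -E "torchao:4,32" without --use_shared_embedding:
# only the embedding table is quantized (weight zeros on by default).
m1 = TinyTiedLM()
EmbeddingQuantizer(
    weight_dtype=torch.int4,
    granularity=PerGroup(32),
    has_weight_zeros=True,
    use_fallback=False,
).quantize(m1)

# Roughly equivalent to -E "torchao:4,32" with --use_shared_embedding:
# the shared embedding/unembedding weight is quantized once, and the unembedding
# additionally uses 8-bit dynamically quantized activations.
m2 = TinyTiedLM()
SharedEmbeddingQuantizer(
    weight_dtype=torch.int4,
    granularity=PerGroup(32),
    has_weight_zeros=True,
).quantize(m2)
```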

third-party/ao

Submodule ao updated 23 files
