Commit a36ebd3

up
1 parent 78ee0e6 commit a36ebd3

4 files changed: +67, -25 lines changed

examples/models/llama/README.md
4 additions, 0 deletions

@@ -412,6 +412,10 @@ python -m examples.models.llama.export_llama \
     -d fp32
 ```
 
+A few notes:
+- If your model shares embedding/unembedding weights (as Llama1B and Llama3B do), you can add `--use_shared_embedding` to take advantage of this and reduce memory. When this option is enabled, a third argument controls whether the embedding is quantized with weight zeros. For example, `-E "torchao:4,32,true"` quantizes the embedding to 4 bits with group_size=32 and weight zeros (this is also the default if you simply use `-E "torchao:4,32"`), whereas `-E "torchao:4,32,false"` quantizes the embedding to 4 bits with group_size=32 using scales only. If `--use_shared_embedding` is specified, the unembedding (i.e., the final linear layer) is quantized in the same way, but additionally uses 8-bit dynamically quantized activations.
+- To do channelwise quantization, set group_size to 0. This works for both linear and embedding layers.
+
 Once the model is exported, we need to build ExecuTorch and the runner with the low-bit kernels.
 
 The first step is to install ExecuTorch (the same as step 3.1 above):
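For illustration, an export command exercising the new flags might look like the sketch below. Only the embedding-related flags are shown; the `...` stands for the model/checkpoint/output arguments from the full command in the README section above, so treat this as a sketch rather than a verified invocation.

```
python -m examples.models.llama.export_llama \
  ... \
  -E "torchao:4,32,false" \
  --use_shared_embedding \
  -d fp32
```

Here `-E "torchao:4,32,false"` requests 4-bit, group_size=32, scales-only quantization for the shared embedding/unembedding, and `-E "torchao:4,0,false"` would instead request channelwise (group_size=0) quantization.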

examples/models/llama/export_llama_lib.py
29 additions, 13 deletions

@@ -161,6 +161,11 @@ def build_args_parser() -> argparse.ArgumentParser:
         type=str,
         help="type of embedding quantization, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
     )
+    parser.add_argument(
+        "--use_shared_embedding",
+        action="store_true",
+        help="Whether the embedding/unembedding weights should be shared. Only available with torchao kernels.",
+    )
     parser.add_argument(
         "--pt2e_quantize",
         default=None,

@@ -664,6 +669,15 @@ def _validate_args(args):
     if args.num_sharding > 0 and not args.qnn:
         raise ValueError("Model shard is only supported with qnn backend now.")
 
+    if args.use_shared_embedding:
+        if not (
+            args.embedding_quantize is not None
+            and args.embedding_quantize.startswith("torchao:")
+        ):
+            raise ValueError(
+                "Shared embedding is only supported with torchao quantization."
+            )
+
     if (
         args.quantization_mode is not None
         and args.quantization_mode.startswith("torchao:")

@@ -1111,6 +1125,21 @@ def _get_source_transforms( # noqa
 
         transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)
 
+    if args.embedding_quantize:
+        """
+        When this option is selected, it finds all embedding layers and transforms
+        them into the equivalent quantized embedding modules.
+
+        There are cases where the checkpoint is already quantized, for example
+        when use_spin_quant is enabled. In that case, the appropriate
+        transformations based on the given checkpoint are applied first, and
+        this will be a no-op.
+        """
+        modelname = f"{modelname}_e"
+        transforms.append(get_quant_embedding_transform(args))
+
+    # quantization_mode should be applied after embedding_quantize
+    # to support shared_embedding
     if args.quantization_mode:
         """
         When this option is selected, it finds all linear layers and transforms

@@ -1130,19 +1159,6 @@ def _get_source_transforms( # noqa
             get_quant_weight_transform(args, dtype_override, verbose_export())
         )
 
-    if args.embedding_quantize:
-        """
-        When this option is selected, it finds all embedding layers and transforms
-        into quantized embedding equivalent module.
-
-        There are cases where the checkpoint is already quantized, for example
-        on use_spin_quant is enabled. In that case, it will do the appropriate
-        transformations based on the given checkpoint first. In those cases,
-        this wil be a no-op.
-        """
-        modelname = f"{modelname}_e"
-        transforms.append(get_quant_embedding_transform(args))
-
     if args.expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)
 
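The reordering in `_get_source_transforms` matters because the returned transforms are applied in order: with `--use_shared_embedding`, the embedding transform also quantizes the tied unembedding (final linear) layer, so it has to run before the generic linear-weight transform added by `quantization_mode`. A minimal sketch of the intended flow (simplified pseudocode, not the actual ExecuTorch call sites):

```python
# Sketch only: shows why the embedding transform is appended first.
# Assumes each transform is a callable taking and returning the model.
transforms = []

if args.embedding_quantize:
    # With --use_shared_embedding this also handles the tied unembedding weights.
    transforms.append(get_quant_embedding_transform(args))

if args.quantization_mode:
    # Quantizes the remaining linear layers; must not run before the shared
    # embedding/unembedding has been taken care of.
    transforms.append(
        get_quant_weight_transform(args, dtype_override, verbose_export())
    )

for transform in transforms:
    model = transform(model)
```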

examples/models/llama/source_transformation/quantize.py
33 additions, 11 deletions

@@ -105,9 +105,7 @@ def quantize( # noqa C901
             model,
             Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=getattr(torch, f"int{bitwidth}"),
-                granularity=(
-                    PerRow() if group_size in [0, -1] else PerGroup(group_size)
-                ),
+                granularity=(PerRow() if group_size == 0 else PerGroup(group_size)),
                 has_weight_zeros=False,
             ),
         )

@@ -752,19 +750,43 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
 
 def get_quant_embedding_transform(args):
     if args.embedding_quantize.startswith("torchao:"):
-        bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
+        from torchao.experimental.quant_api import (
+            EmbeddingQuantizer,
+            SharedEmbeddingQuantizer,
+        )
+        from torchao.quantization.granularity import PerGroup, PerRow
+
+        quant_args = args.embedding_quantize.split(":")[1].split(",")
+        if len(quant_args) == 2:
+            bitwidth, group_size = quant_args
+            has_weight_zeros = True
+        else:
+            bitwidth, group_size, has_weight_zeros = quant_args
+
+        if group_size in ["none", "None", "0"]:
+            group_size = 0
+
         group_size = int(group_size)
         bitwidth = int(bitwidth)
-        from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
+        has_weight_zeros = str(has_weight_zeros).lower() == "true"
+        weight_dtype = getattr(torch, f"int{bitwidth}")
+        granularity = PerRow() if group_size == 0 else PerGroup(group_size)
 
         def _torchao_embedding_quantizer(model):
             with torch.no_grad():
-                model = IntxWeightEmbeddingQuantizer(
-                    device="cpu",
-                    precision=torch.float32,
-                    bitwidth=bitwidth,
-                    groupsize=group_size,
-                ).quantize(model)
+                if not args.use_shared_embedding:
+                    EmbeddingQuantizer(
+                        weight_dtype=weight_dtype,
+                        granularity=granularity,
+                        has_weight_zeros=has_weight_zeros,
+                        use_fallback=False,
+                    ).quantize(model)
+                else:
+                    SharedEmbeddingQuantizer(
+                        weight_dtype=weight_dtype,
+                        granularity=granularity,
+                        has_weight_zeros=has_weight_zeros,
+                    ).quantize(model)
             return model
 
         return _torchao_embedding_quantizer
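As a standalone illustration of the torchao API adopted above (a minimal sketch based only on the constructor arguments visible in this diff; it assumes the torchao pin in `third-party/ao` exposes these experimental quantizers, and the toy module below is hypothetical):

```python
import torch
import torch.nn as nn
from torchao.experimental.quant_api import EmbeddingQuantizer
from torchao.quantization.granularity import PerGroup

class Toy(nn.Module):
    """Hypothetical module with a single embedding layer."""

    def __init__(self):
        super().__init__()
        self.tok_embeddings = nn.Embedding(1024, 64)

    def forward(self, indices: torch.Tensor) -> torch.Tensor:
        return self.tok_embeddings(indices)

model = Toy()
with torch.no_grad():
    # Mirrors the non-shared path above: 4-bit weights, group_size=32, scales only.
    EmbeddingQuantizer(
        weight_dtype=getattr(torch, "int4"),
        granularity=PerGroup(32),
        has_weight_zeros=False,
        use_fallback=False,
    ).quantize(model)
```

`SharedEmbeddingQuantizer` takes the same `weight_dtype`/`granularity`/`has_weight_zeros` arguments and, per the README note, additionally quantizes the tied unembedding layer with 8-bit dynamically quantized activations.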

third-party/ao

Submodule ao updated 22 files
