Commit 75bb712

fix
1 parent e808d19 commit 75bb712

File tree

9 files changed: +3247 -70 lines changed


paddlenlp/quantization/hadamard_utils.py

Lines changed: 26 additions & 35 deletions
@@ -34,42 +34,27 @@ def matmul_hadU(X):
 
 
 def random_hadamard_matrix(size, dtype, quantization_config):
-    if not quantization_config.hadamard_is_block:
+    if quantization_config.hadamard_block_size < 0:
         A = paddle.randint(low=0, high=2, shape=[size, size]).astype("float32") * 2 - 1
         Q, _ = paddle.linalg.qr(A)
         return Q.astype(dtype), 1
     else:
-        if quantization_config.hadamard_block_size != -1:
-            assert size % quantization_config.hadamard_block_size == 0, "Please choose a correct block_size"
-            num_blocks = size // quantization_config.hadamard_block_size
-            Q = paddle.diag(paddle.ones((quantization_config.hadamard_block_size,), dtype="float32"))
-            block = matmul_hadU(Q)
-            return block, quantization_config.hadamard_block_size
-        else:
-            num_blocks = size
-            while not (num_blocks % 2):
-                num_blocks = num_blocks // 2
-            block_size = size // num_blocks
-            Q = paddle.diag(paddle.ones((block_size,), dtype="float32"))
-            block = matmul_hadU(Q)
-            large_matrix = paddle.zeros([size, size])
-
-            for i in range(num_blocks):
-                start_row = i * block_size
-                start_col = i * block_size
-                large_matrix[start_row : start_row + block_size, start_col : start_col + block_size] = block
-            return large_matrix.cast(dtype), block_size
-
-
-def hadamard_matmul(input, side, hadamard_maxtrix, block_size):
+        assert size % quantization_config.hadamard_block_size == 0, "Please choose a correct block_size"
+        Q = paddle.diag(paddle.ones((quantization_config.hadamard_block_size,), dtype="float32"))
+        block = matmul_hadU(Q)
+        print("random_hadamard_matrix", block, quantization_config.hadamard_block_size)
+        return block, quantization_config.hadamard_block_size
+
+
+def hadamard_matmul(input, side, hadamard_matrix, block_size):
     # left -> H.T@input right -> input@H
     origin_shape = input.shape
     input = input.reshape([-1, origin_shape[-1]])
     if side == "left":
         # H.T@input -> (input.T@H).T
         input = input.transpose([1, 0])
     block_num = input.shape[-1] // block_size
-    output = input.reshape([-1, block_num, block_size]) @ hadamard_maxtrix
+    output = input.reshape([-1, block_num, block_size]) @ hadamard_matrix
     output = output.reshape([-1, block_num * block_size])
     if side == "left":
         output = output.transpose([1, 0])
@@ -81,17 +66,23 @@ def hadamard_matmul(input, side, hadamard_maxtrix, block_size):
 def apply_hadamard_matmul(x, side, quantization_config=None, dequant=False):
     if getattr(infohub, "hadamard") is None:
         setattr(infohub, "hadamard", {})
-    if side == "left":
-        x_shape = x.shape[0]
+
+    if quantization_config.hadamard_block_size < 0:
+        if side == "left":
+            block_size = x.shape[0]
+        else:
+            block_size = x.shape[-1]
     else:
-        x_shape = x.shape[-1]
-    if x_shape in infohub.hadamard:
-        hadamard_maxtrix, block_size = infohub.hadamard[x_shape]
+        block_size = quantization_config.hadamard_block_size
+
+    if block_size in infohub.hadamard:
+        hadamard_matrix, hadamard_scale = infohub.hadamard[block_size]
     else:
-        hadamard_matrix, block_size = random_hadamard_matrix(x_shape, x.dtype, quantization_config)
-        infohub.hadamard[x_shape] = (hadamard_matrix, block_size)
-    if block_size > 1:
-        target_x = hadamard_matmul(x, side, hadamard_maxtrix, block_size)
+        hadamard_matrix, hadamard_scale = random_hadamard_matrix(block_size, x.dtype, quantization_config)
+        infohub.hadamard[block_size] = (hadamard_matrix, hadamard_scale)
+
+    if hadamard_scale > 1:
+        target_x = hadamard_matmul(x, side, hadamard_matrix, block_size)
     else:
         if dequant:
             hadamard_matrix = hadamard_matrix.T
@@ -100,4 +91,4 @@ def apply_hadamard_matmul(x, side, quantization_config=None, dequant=False):
         else:
             target_x = hadamard_matrix.T @ x
 
-    return target_x, block_size
+    return target_x, hadamard_scale
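Note: the reshape trick in hadamard_matmul applies one small block matrix to every block_size-wide slice of the last dimension, which is equivalent to multiplying by the block-diagonal matrix that the deleted branch used to build explicitly. A minimal standalone sketch of that equivalence (not part of the commit; H_block, x and the shapes are illustrative):

import paddle

block_size, num_blocks = 4, 3
size = block_size * num_blocks

H_block = paddle.randn([block_size, block_size])  # any square block works for the equivalence check
x = paddle.randn([2, size])                       # e.g. activations whose last dim is a multiple of block_size

# Blockwise multiply via reshape, as hadamard_matmul does for side == "right".
out_blockwise = (x.reshape([-1, num_blocks, block_size]) @ H_block).reshape([-1, size])

# Same result with an explicit block-diagonal matrix (the old code path).
H_full = paddle.zeros([size, size])
for i in range(num_blocks):
    s = i * block_size
    H_full[s : s + block_size, s : s + block_size] = H_block

print(paddle.allclose(out_blockwise, x @ H_full))  # True, up to float tolerance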

paddlenlp/quantization/qat_utils.py

Lines changed: 9 additions & 6 deletions
@@ -38,11 +38,12 @@ def quantize(
     group=None,
 ):
     if apply_hadamard:
-        target_x, block_size = apply_hadamard_matmul(x, side, quantization_config)
+        target_x, hadamard_scale = apply_hadamard_matmul(x, side, quantization_config)
     else:
         target_x = x
-        block_size = 1
+        hadamard_scale = 1
     qmin, qmax = QMAX_QMIN_MAPPING[weight_quantize_algo + "_" + tensor_type]
+    print("apply_hadamard", apply_hadamard, qmin, qmax, tensor_type, hadamard_scale)
     if tensor_type == "activation":
         if act_scale is not None:
             if training:
@@ -51,7 +52,8 @@ def quantize(
                 if state > quantization_config.apply_online_actscale_step:
                     scale = act_scale
             else:
-                scale = act_scale
+                # scale = act_scale
+                scale = paddle.max(paddle.abs(target_x)) / qmax
         else:
             scale = paddle.max(paddle.abs(target_x)) / qmax
         if weight_quantize_algo in ["a8w8linear", "a8w4linear"]:
@@ -66,7 +68,7 @@ def quantize(
                 paddle.distributed.all_reduce(scale, op=paddle.distributed.ReduceOp.MAX, group=group, sync_op=True)
             quant_x = paddle.clip((target_x / scale).round(), qmin, qmax).astype("int8").T
             scale.stop_gradient = True
-            scale = scale.squeeze(0) / block_size
+            scale = scale.squeeze(0) / hadamard_scale
         else:
             raise NotImplementedError(f"Unknown {weight_quantize_algo}.")
     else:
@@ -79,8 +81,8 @@ def dequantize(quant_x, scale, tensor_type, weight_quantize_algo, apply_hadamard
     if weight_quantize_algo in ["a8w8linear", "a8w4linear"]:
         x = quant_x.T.astype(scale.dtype)
         if apply_hadamard:
-            x, block_size = apply_hadamard_matmul(x, side, dequant=True)
-            x *= scale / block_size
+            x, hadamard_scale = apply_hadamard_matmul(x, side, dequant=True)
+            x *= scale / hadamard_scale
         else:
             x *= scale
     else:
@@ -112,6 +114,7 @@ def int8_forward(
     )
 
     out = paddle.matmul(quant_x, quant_w.T).astype(scale_w.dtype) * (scale_x * scale_w)
+    # out = paddle.matmul(x, quant_w.T.astype("bfloat16")*scale_w)
     if bias is not None:
         out += bias
     return out
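Note: the activation and weight paths above reduce to symmetric absmax quantization: take scale = max(|x|) / qmax, round-and-clip into int8, and fold the Hadamard normalization into the stored scale. A rough standalone sketch of that roundtrip (illustrative only; qmax = 127 and the helper name are assumptions, not this module's API):

import paddle

def absmax_int8_roundtrip(x, hadamard_scale=1.0, qmax=127, qmin=-127):
    # Per-tensor symmetric scale, as in the online fallback branch above.
    scale = paddle.max(paddle.abs(x)) / qmax
    quant_x = paddle.clip((x / scale).round(), qmin, qmax).astype("int8")
    # The Hadamard normalization is folded into the stored scale.
    return quant_x, scale / hadamard_scale

x = paddle.randn([8, 16], dtype="float32")
q, s = absmax_int8_roundtrip(x)
x_rec = q.astype("float32") * s
print(float(paddle.max(paddle.abs(x - x_rec))))  # small quantization error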

paddlenlp/quantization/quantization_config.py

Lines changed: 1 addition & 3 deletions
@@ -60,8 +60,7 @@ def __init__(
         ignore_modules=None,
         group_size=-1,
         apply_hadamard=False,
-        hadamard_is_block=True,
-        hadamard_block_size=-1,
+        hadamard_block_size=32,
         quant_input_grad=False,
         apply_online_actscale_step=200,
         scale_epsilon=0,
@@ -135,7 +134,6 @@ def __init__(
         self.quant_input_grad = quant_input_grad
         self.apply_online_actscale_step = apply_online_actscale_step
         self.scale_epsilon = scale_epsilon
-        self.hadamard_is_block = hadamard_is_block
         self.moving_rate = moving_rate
         self.hadamard_block_size = hadamard_block_size
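Note: after this change the only Hadamard knob left is hadamard_block_size. A negative value falls back to a full-size random orthogonal rotation in random_hadamard_matrix, and any other value selects the block-Hadamard path (default 32). A hedged usage sketch, assuming the class defined in this file is QuantizationConfig and the remaining constructor arguments keep their defaults:

from paddlenlp.quantization.quantization_config import QuantizationConfig

# Block Hadamard with 32x32 blocks (the new default).
cfg_block = QuantizationConfig(apply_hadamard=True, hadamard_block_size=32)

# A negative block size falls back to a full-size random orthogonal matrix.
cfg_full = QuantizationConfig(apply_hadamard=True, hadamard_block_size=-1)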

paddlenlp/quantization/quantization_linear.py

Lines changed: 4 additions & 1 deletion
@@ -471,6 +471,7 @@ def forward(self, x):
             )
         else:
             input_parallel = x
+        print("input_parallel", input_parallel.shape, self.quant_weight.shape, self.quant_scale.shape)
 
         output_parallel = quant_weight_linear(
             x=input_parallel,
@@ -490,11 +491,13 @@ def forward(self, x):
         )
         if self.training:
             self.state += 1
-
+        print("output_parallel", output_parallel.shape)
+        print(self.gather_output, self.is_mp, self.gather_output and self.is_mp)
         if self.gather_output and self.is_mp:
             output = mp_ops._c_concat(output_parallel, group=self.model_parallel_group)
         else:
             output = output_parallel
+        print("output", output.shape)
         return output
 
 
paddlenlp/transformers/conversion_utils.py

Lines changed: 3 additions & 4 deletions
@@ -62,7 +62,7 @@
 
 def add_quant_mapping(name_action_mappings, quantization_config):
     mapping_keys = list(name_action_mappings.keys())
-    pattern = r"(?:^|\.)layers(\.[a-zA-Z0-9_]+)+\.weight$"
+    pattern = r"^(?:.*\.)?layers(\.[a-zA-Z0-9_]+)*\.weight$"
     for key in mapping_keys:
         if re.match(pattern, key):
             quant_key = key.replace("weight", "quant_weight")
@@ -1233,16 +1233,15 @@ def get_tensor_parallel_convert_actions(
         base_model_prefix=None,
     ):
         name_action_mappings = cls._get_tensor_parallel_mappings(config, is_split=is_split)
-        if config.quantization_config.is_weight_quantize():
-            name_action_mappings = add_quant_mapping(name_action_mappings, config.quantization_config)
-
         state_keys_map = cls._resolve_prefix_keys(
             name_action_mappings.keys(), loaded_state_dict_keys, ignore_error, base_model_prefix=base_model_prefix
        )
         for k, v in state_keys_map.items():
             if k not in name_action_mappings:
                 continue
             name_action_mappings[v] = name_action_mappings.pop(k)
+        if config.quantization_config.is_weight_quantize():
+            name_action_mappings = add_quant_mapping(name_action_mappings, config.quantization_config)
         return name_action_mappings
 
     @classmethod
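Note: the relaxed pattern now also matches keys that carry a model prefix before layers (re.match anchors at the start of the string, so the old (?:^|\.) alternative never fired for such keys), and with * instead of + it tolerates no extra component between layers and weight. A quick comparison on made-up key names:

import re

old = r"(?:^|\.)layers(\.[a-zA-Z0-9_]+)+\.weight$"
new = r"^(?:.*\.)?layers(\.[a-zA-Z0-9_]+)*\.weight$"

for key in ["layers.0.self_attn.q_proj.weight", "llama.layers.0.mlp.gate_proj.weight"]:
    print(key, bool(re.match(old, key)), bool(re.match(new, key)))
# layers.0.self_attn.q_proj.weight True True
# llama.layers.0.mlp.gate_proj.weight False True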

paddlenlp/transformers/llama/modeling.py

Lines changed: 1 addition & 1 deletion
@@ -933,7 +933,6 @@ def forward(
     ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
         # [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism)
-
         if self.fuse_attention_qkv:
             mix_layer = self.qkv_proj(hidden_states)
             # NOTE for GQA attention fusion (compatible with MHA and MQA):
@@ -987,6 +986,7 @@ def forward(
             query_states = paddle.reshape_(query_states, [0, 0, self.num_heads, self.head_dim])
         else:
             query_states = self.q_proj(hidden_states)
+
             key_states = self.k_proj(hidden_states)
             value_states = self.v_proj(hidden_states)
 
paddlenlp/transformers/model_utils.py

Lines changed: 22 additions & 17 deletions
@@ -400,23 +400,29 @@ def _load_part_state_dict(
                 continue
 
             py_safe_slice_ = f.get_slice(key)
-            if quantization_linear_list is not None:
-                if key.split(".weight")[0] in quantization_linear_list:
-                    weight = paddle.Tensor.__call__(py_safe_slice_[:], zero_copy=True)
-                    key_name = key.split(".weight")[0]
-                    quant_key_name = key_name + ".quant_weight"
-                    quant_state_dict = convert_to_weight_quantize_state_dict(
-                        state_dict={key_name: weight},
-                        name=key_name,
-                        quantization_config=quantization_config,
-                        dtype=dtype,
-                        weight_quantize_algo=parse_weight_quantize_algo(quantization_config, quant_key_name),
+            if quantization_linear_list is not None and key.split(".weight")[0] in quantization_linear_list:
+                weight = paddle.Tensor.__call__(py_safe_slice_[:], zero_copy=True)
+                key_name = key.split(".weight")[0]
+                quant_key_name = key_name + ".quant_weight"
+                quant_scale_name = key_name + ".quant_scale"
+                quant_state_dict = convert_to_weight_quantize_state_dict(
+                    state_dict={key: weight},
+                    name=key_name,
+                    quantization_config=quantization_config,
+                    dtype=dtype,
+                    weight_quantize_algo=parse_weight_quantize_algo(quantization_config, quant_key_name),
+                )
+                if quant_key_name in tensor_parallel_split_mapping:
+                    quant_state_dict[quant_key_name] = tensor_parallel_split_mapping[quant_key_name](
+                        quant_state_dict[quant_key_name]
                     )
-                    if quant_key_name in tensor_parallel_split_mapping:
-                        quant_state_dict[quant_key_name] = tensor_parallel_split_mapping[quant_key_name](
-                            quant_state_dict[quant_key_name]
+                if quant_scale_name in tensor_parallel_split_mapping:
+                    quant_state_dict[quant_scale_name] = tensor_parallel_split_mapping[quant_scale_name](
+                        quant_state_dict[quant_scale_name]
                     )
-                    part_state_dict.update(quant_state_dict)
+                for key in list(quant_state_dict.keys()):
+                    quant_state_dict[key] = paddle.Tensor.__call__(quant_state_dict[key], zero_copy=True)
+                part_state_dict.update(quant_state_dict)
             else:
                 if key in tensor_parallel_split_mapping:
                     weight = tensor_parallel_split_mapping[key](py_safe_slice_)
@@ -518,8 +524,6 @@ def load_state_dict(
     for k in list(state_dict.keys()):
         if "quant" not in k:
             state_dict[k] = paddle.Tensor.__call__(state_dict.pop(k), zero_copy=True)
-        else:
-            print("aaaaaaa", k)
 
     if len(scale_dict) != 0:
         if ckpt_quant_stage == "O0":
@@ -1001,6 +1005,7 @@ def _load_state_dict_into_meta_model(
 
         if old_param is not None:
             param = param.astype(dtype=old_param.dtype)
+        print("meta", param_name, param.shape)
         with paddle.no_grad():
             model_state_dict[param_name].get_tensor()._share_data_with(param.value().get_tensor())
             param.value().get_tensor()._clear()
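Note: the renamed keys follow the derivation used above: the .weight suffix is stripped and the quantized tensors are stored under .quant_weight and .quant_scale, which is also how they are looked up in tensor_parallel_split_mapping. A small illustration (the checkpoint key below is a made-up example):

key = "llama.layers.0.self_attn.q_proj.weight"  # hypothetical checkpoint key
key_name = key.split(".weight")[0]
quant_key_name = key_name + ".quant_weight"
quant_scale_name = key_name + ".quant_scale"
print(quant_key_name)    # llama.layers.0.self_attn.q_proj.quant_weight
print(quant_scale_name)  # llama.layers.0.self_attn.q_proj.quant_scale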

0 commit comments
