Commit d524068

fix bug
1 parent dca25ae commit d524068

10 files changed, +423 -226 lines changed

10 files changed

+423
-226
lines changed

llm/run_finetune.py

Lines changed: 8 additions & 1 deletion

@@ -162,6 +162,13 @@ def main():
         qlora_weight_blocksize=model_args.qlora_weight_blocksize,
         qlora_weight_double_quant=model_args.qlora_weight_double_quant,
         qlora_weight_double_quant_block_size=model_args.qlora_weight_double_quant_block_size,
+        apply_hadamard=model_args.apply_hadamard,
+        hadamard_is_block=model_args.hadamard_is_block,
+        hadamard_block_size=model_args.hadamard_block_size,
+        quant_input_grad=model_args.quant_input_grad,
+        apply_online_actscale_step=model_args.apply_online_actscale_step,
+        scale_epsilon=model_args.scale_epsilon,
+        moving_rate=model_args.moving_rate,
     )

     model_config = AutoConfig.from_pretrained(
@@ -291,7 +298,7 @@ def neft_post_hook(module, input, output):
         logging.info("Using ReFT with layers: ", reft_layers)
     # init chat_template for tokenizer
     init_chat_template(tokenizer, model_args.model_name_or_path, data_args.chat_template)
-
+    tokenizer.chat_template = None
     # if using chat_template, data_args.eval_with_do_generation must be false
     if tokenizer.chat_template is not None:
         data_args.eval_with_do_generation = False
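
All seven new keyword arguments are read from model_args, so the training-argument dataclass needs matching fields. Below is a minimal sketch of how they could be declared; only the field names come from this diff, while the class name, defaults, and help strings are illustrative assumptions:

from dataclasses import dataclass, field

@dataclass
class QuantArguments:  # hypothetical container; in the repo these fields live on model_args
    apply_hadamard: bool = field(default=False, metadata={"help": "Rotate tensors with a Hadamard transform before quantization (assumed default)."})
    hadamard_is_block: bool = field(default=True, metadata={"help": "Use a block-diagonal Hadamard matrix instead of a full one (assumed default)."})
    hadamard_block_size: int = field(default=32, metadata={"help": "Block size of the block Hadamard matrix (assumed default)."})
    quant_input_grad: bool = field(default=False, metadata={"help": "Also quantize the input gradient in the backward pass (assumed default)."})
    apply_online_actscale_step: int = field(default=200, metadata={"help": "Steps that use online activation scales before switching to moving averages (assumed default)."})
    scale_epsilon: float = field(default=0.0, metadata={"help": "Epsilon added to quantization scales for numerical stability (assumed default)."})
    moving_rate: float = field(default=0.01, metadata={"help": "Moving-average rate for activation scales (assumed default)."})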

ops/src/paddlenlp_kernel/triton/optimizer/adamw_triton.py

Lines changed: 15 additions & 22 deletions

@@ -12,9 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import paddle
 import triton
 import triton.language as tl

+DTYPE_MAPPING = {
+    paddle.bfloat16: tl.bfloat16,
+    paddle.float32: tl.float32,
+    paddle.float16: tl.float16,
+}
+

 @triton.jit
 def adamw_kernel(
@@ -30,10 +37,11 @@ def adamw_kernel(
     beta1_pow_ptr,
     beta2_pow_ptr,
     master_weight_ptr,
-    dtype,
     N,
-    BLOCK_SIZE,
     skip_update_param,
+    param_dtype: tl.constexpr,
+    moment_dtype: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
 ):
     pid = tl.program_id(0)
     offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@@ -45,11 +53,8 @@ def adamw_kernel(
     param = tl.load(param_ptr + offsets, mask=mask).to(tl.float32)
     grad = tl.load(grad_ptr + offsets, mask=mask).to(tl.float32)

-    moment1 = tl.load(moment1_ptr + offsets, mask=mask)
-    moment2 = tl.load(moment2_ptr + offsets, mask=mask)
-    moment_dtype = moment1.dtype
-    moment1 = moment1.to(tl.float32)
-    moment2 = moment2.to(tl.float32)
+    moment1 = tl.load(moment1_ptr + offsets, mask=mask).to(tl.float32)
+    moment2 = tl.load(moment2_ptr + offsets, mask=mask).to(tl.float32)
     lr = tl.load(lr_ptr)
     beta1_pow = tl.load(beta1_pow_ptr)
     beta2_pow = tl.load(beta2_pow_ptr)
@@ -62,13 +67,6 @@
     moment2 = beta2 * moment2 + (1.0 - beta2) * grad * grad
     denom = tl.sqrt(moment2) / tl.sqrt(1.0 - beta2_pow) + epsilon
     param += (moment1 / denom) * (-lr / (1 - beta1_pow))
-    if dtype == 0:
-        param_dtype = tl.float16
-    elif dtype == 1:
-        param_dtype = tl.bfloat16
-    else:
-        param_dtype = tl.float32
-
     # Update param
     if master_weight_ptr is not None:
         tl.store(master_weight_ptr + offsets, param, mask=mask)
@@ -110,12 +108,6 @@ def adamw_triton(
     N = param.numel().item()
     BLOCK_SIZE = 512
     grid = lambda meta: (triton.cdiv(N, BLOCK_SIZE),)
-    if str(param.dtype) == "paddle.float16":
-        dtype = 0
-    elif str(param.dtype) == "paddle.bfloat16":
-        dtype = 1
-    else:
-        dtype = 2
     adamw_kernel[grid](
         param,
         grad,
@@ -129,9 +121,10 @@
         beta1_pow,
         beta2_pow,
         master_weight,
-        dtype,
         N,
-        BLOCK_SIZE,
         skip_update_param,
+        DTYPE_MAPPING[param.dtype],
+        DTYPE_MAPPING[moment1.dtype],
+        BLOCK_SIZE,
     )
     beta1_pow[:], beta2_pow[:] = beta1 * beta1_pow[:], beta2 * beta2_pow[:]
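
The rewrite replaces the integer dtype code and the in-kernel if/elif chain with tl.constexpr dtype parameters looked up through DTYPE_MAPPING, so Triton specializes and caches one compiled kernel per dtype instead of branching at run time. Here is a minimal sketch of the same pattern in isolation, with illustrative names, assuming a CUDA device and a Triton build that accepts Paddle tensors the way the file above does:

import paddle
import triton
import triton.language as tl

DTYPE_MAPPING = {
    paddle.bfloat16: tl.bfloat16,
    paddle.float32: tl.float32,
    paddle.float16: tl.float16,
}

@triton.jit
def cast_copy_kernel(src_ptr, dst_ptr, N, dst_dtype: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    x = tl.load(src_ptr + offsets, mask=mask).to(tl.float32)  # compute in fp32
    tl.store(dst_ptr + offsets, x.to(dst_dtype), mask=mask)   # cast resolved at compile time per constexpr dtype

src = paddle.randn([1024], dtype="float32")
dst = paddle.zeros([1024], dtype="bfloat16")
cast_copy_kernel[(triton.cdiv(1024, 512),)](src, dst, 1024, DTYPE_MAPPING[dst.dtype], 512)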

paddlenlp/quantization/hadamard_utils.py

Lines changed: 27 additions & 0 deletions

@@ -14,6 +14,8 @@

 import paddle

+from paddlenlp.utils import infohub
+

 def matmul_hadU(X):

@@ -74,3 +76,28 @@ def hadamard_matmul(input, side, hadamard_maxtrix, block_size):
     output = output.reshape(origin_shape)

     return output
+
+
+def apply_hadamard_matmul(x, side, quantization_config=None, dequant=False):
+    if getattr(infohub, "hadamard") is None:
+        setattr(infohub, "hadamard", {})
+    if side == "left":
+        x_shape = x.shape[0]
+    else:
+        x_shape = x.shape[-1]
+    if x_shape in infohub.hadamard:
+        hadamard_matrix, block_size = infohub.hadamard[x_shape]
+    else:
+        hadamard_matrix, block_size = random_hadamard_matrix(x_shape, x.dtype, quantization_config)
+        infohub.hadamard[x_shape] = (hadamard_matrix, block_size)
+    if block_size > 1:
+        target_x = hadamard_matmul(x, side, hadamard_matrix, block_size)
+    else:
+        if dequant:
+            hadamard_matrix = hadamard_matrix.T
+        if side == "right":
+            target_x = x @ hadamard_matrix
+        else:
+            target_x = hadamard_matrix.T @ x
+
+    return target_x, block_size
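
For intuition about what the cached rotation buys, here is a small self-contained demonstration using NumPy/SciPy rather than the PaddleNLP API: an orthonormal Hadamard matrix spreads a single outlier channel evenly across all dimensions (shrinking the maximum value a quantizer has to cover), and multiplying by its transpose inverts the rotation exactly, which is what the dequant path relies on.

import numpy as np
from scipy.linalg import hadamard  # SciPy is used only for this illustration

n = 8
H = hadamard(n).astype(np.float32) / np.sqrt(n)  # orthonormal: H @ H.T == I
x = np.zeros((2, n), dtype=np.float32)
x[:, 0] = 100.0                                  # one extreme outlier channel

rotated = x @ H           # "right"-side rotation; outlier energy spreads over all n channels
restored = rotated @ H.T  # exact inverse, analogous to the dequant path

print(np.abs(x).max(), np.abs(rotated).max())    # 100.0 vs. ~35.4
assert np.allclose(x, restored, atol=1e-5)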
