@@ -153,8 +153,8 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
     return model
 
 
-# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
-def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
+# Adapted from PEFT: https://github.com/huggingface/peft/blob/6d458b300fc2ed82e19f796b53af4c97d03ea604/src/peft/utils/integrations.py#L81
+def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torch.dtype" = None):
     """
     Helper function to dequantize 4bit or 8bit bnb weights.
 
@@ -177,13 +177,16 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
     if state.SCB is None:
         state.SCB = weight.SCB
 
-    im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
-    im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
-    im, Sim = bnb.functional.transform(im, "col32")
-    if state.CxB is None:
-        state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
-    out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
-    return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
+    if hasattr(bnb.functional, "int8_vectorwise_dequant"):
+        # Use bitsandbytes API if available (requires v0.45.0+)
+        dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)
+    else:
+        # Multiply by (scale/127) to dequantize.
+        dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3
+
+    if dtype:
+        dequantized = dequantized.to(dtype)
+    return dequantized
 
 
 def _create_accelerate_new_hook(old_hook):
@@ -205,6 +208,7 @@ def _create_accelerate_new_hook(old_hook):
 
 def _dequantize_and_replace(
     model,
+    dtype,
     modules_to_not_convert=None,
     current_key_name=None,
     quantization_config=None,
@@ -244,7 +248,7 @@ def _dequantize_and_replace(
                 else:
                     state = None
 
-                new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
+                new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state, dtype))
 
                 if bias is not None:
                     new_module.bias = bias
@@ -263,9 +267,10 @@ def _dequantize_and_replace(
         if len(list(module.children())) > 0:
             _, has_been_replaced = _dequantize_and_replace(
                 module,
-                modules_to_not_convert,
-                current_key_name,
-                quantization_config,
+                dtype=dtype,
+                modules_to_not_convert=modules_to_not_convert,
+                current_key_name=current_key_name,
+                quantization_config=quantization_config,
                 has_been_replaced=has_been_replaced,
             )
         # Remove the last key for recursion
@@ -280,6 +285,7 @@ def dequantize_and_replace(
 ):
     model, has_been_replaced = _dequantize_and_replace(
         model,
+        dtype=model.dtype,
         modules_to_not_convert=modules_to_not_convert,
         quantization_config=quantization_config,
    )
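
A note on the fallback branch in `dequantize_bnb_weight`: `state.SCB` holds one absmax scale per output row, and int8 absmax dequantization is `int8_value * scale / 127`; the literal `7.874015718698502e-3` is `1/127` rounded to float32 precision. A minimal sketch of that row-wise math, using made-up example values:

```python
import torch

# Example int8 weight rows and per-row absmax scales (values are illustrative).
int8_weight = torch.tensor([[127, -64], [32, -127]], dtype=torch.int8)
SCB = torch.tensor([0.5, 2.0])

# Same arithmetic as the fallback: each row is scaled by its SCB entry / 127.
dequantized = int8_weight * SCB.view(-1, 1) * 7.874015718698502e-3
print(dequantized)
# tensor([[ 0.5000, -0.2520],
#         [ 0.5039, -2.0000]])
```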
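For context on how this surfaces to users, here is a hypothetical end-to-end sketch; the checkpoint name is an assumption, not part of this diff. Loading in 8-bit and then dequantizing should now restore `torch.nn.Linear` weights in the model's dtype rather than an implicit default:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Assumed example checkpoint; any bnb-quantizable causal LM behaves the same.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)

# model.dequantize() walks the module tree via dequantize_and_replace above;
# with this change it threads model.dtype down to dequantize_bnb_weight.
model.dequantize()
print(model.model.decoder.layers[0].self_attn.q_proj.weight.dtype)  # expect torch.float16
```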