# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import os
from typing import Callable, Dict, List, Optional, Union
@@ -1956,7 +1957,7 @@ def _load_norm_into_transformer(
        prefix = prefix or cls.transformer_name
        for key in list(state_dict.keys()):
            if key.split(".")[0] == prefix:
-                state_dict[key.replace(f"{prefix}.", "")] = state_dict.pop(key)
+                state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)

        # Find invalid keys
        transformer_state_dict = transformer.state_dict()
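The switch from `str.replace()` to slicing is subtle but meaningful: `replace()` removes every occurrence of the prefix substring, not just the leading one, so a key that happens to repeat the prefix later in its path would be mangled. A minimal sketch with made-up names, only to show the difference in plain Python string behavior:

```python
# Hypothetical prefix/key pair chosen purely to illustrate the difference.
prefix = "proj"
key = "proj.proj.out.weight"

print(key.replace(f"{prefix}.", ""))  # "out.weight"       -- both occurrences stripped
print(key[len(f"{prefix}.") :])       # "proj.out.weight"  -- only the leading prefix stripped
```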
@@ -2278,6 +2279,7 @@ def unload_lora_weights(self):
        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
        if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
            transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
+            transformer._transformer_norm_layers = None

    @classmethod
    def _maybe_expand_transformer_param_shape_or_error_(
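Clearing `_transformer_norm_layers` after restoring it matters because the same attribute acts as the guard for the restore branch; leaving the backup around would let a later unload re-apply a stale copy of the norm weights. A toy sketch of the backup/restore/clear pattern (not the diffusers API itself, just the idea on a bare `LayerNorm`):

```python
import torch

norm = torch.nn.LayerNorm(4)
backup = {k: v.clone() for k, v in norm.state_dict().items()}  # snapshot taken before overriding

with torch.no_grad():
    norm.weight.mul_(2.0)  # pretend a loaded state dict replaced the norm scale

if backup:  # mirrors `if ... and transformer._transformer_norm_layers:`
    norm.load_state_dict(backup, strict=False)
    backup = None  # the added line: drop the backup once it has been restored

assert backup is None  # a second "unload" now skips the restore entirely
```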
@@ -2303,15 +2305,6 @@ def _maybe_expand_transformer_param_shape_or_error_(
            if key.split(".")[0] == prefix:
                state_dict[key.replace(f"{prefix}.", "")] = state_dict.pop(key)

-        def get_submodule(module, name):
-            for part in name.split("."):
-                if len(name) == 0:
-                    break
-                if not hasattr(module, part):
-                    raise AttributeError(f"Submodule '{part}' not found in '{module}'.")
-                module = getattr(module, part)
-            return module
-
        # Expand transformer parameter shapes if they don't match lora
        has_param_with_shape_update = False

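The deleted helper re-implemented something `torch.nn.Module` already provides: `get_submodule()` resolves a dotted path and raises `AttributeError` when a segment cannot be found. A small sketch with a throwaway model (module names are arbitrary):

```python
import torch

model = torch.nn.Sequential()
model.add_module("block", torch.nn.Sequential())
model.block.add_module("proj", torch.nn.Linear(4, 8))

print(model.get_submodule("block.proj"))  # Linear(in_features=4, out_features=8, bias=True)

try:
    model.get_submodule("block.missing")
except AttributeError as err:
    print(err)  # torch reports which attribute could not be resolved
```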
@@ -2320,12 +2313,9 @@ def get_submodule(module, name):
                module_weight = module.weight.data
                module_bias = module.bias.data if hasattr(module, "bias") else None
                bias = module_bias is not None
-                name_split = name.split(".")

-                lora_A_name = f"{name}.lora_A"
-                lora_B_name = f"{name}.lora_B"
-                lora_A_weight_name = f"{lora_A_name}.weight"
-                lora_B_weight_name = f"{lora_B_name}.weight"
+                lora_A_weight_name = f"{name}.lora_A.weight"
+                lora_B_weight_name = f"{name}.lora_B.weight"

                if lora_A_weight_name not in state_dict.keys():
                    continue
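For context on what `lora_A_weight_name` / `lora_B_weight_name` are expected to look up: in the PEFT-style layout assumed here, a LoRA on a Linear layer stores `A` with shape `(rank, in_features)` and `B` with shape `(out_features, rank)`. A rough sketch with an illustrative module name and made-up dimensions:

```python
import torch

name, rank, in_features, out_features = "some_linear", 16, 64, 128  # illustrative values only
state_dict = {
    f"{name}.lora_A.weight": torch.randn(rank, in_features),   # A: (rank, in_features)
    f"{name}.lora_B.weight": torch.randn(out_features, rank),  # B: (out_features, rank)
}

# The weight shapes carry the Linear dimensions that the expansion check compares against:
print(state_dict[f"{name}.lora_A.weight"].shape[1])  # in_features  -> 64
print(state_dict[f"{name}.lora_B.weight"].shape[0])  # out_features -> 128
```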
@@ -2353,9 +2343,8 @@ def get_submodule(module, name):
                    )

                    has_param_with_shape_update = True
-                    parent_module_name = ".".join(name_split[:-1])
-                    current_module_name = name_split[-1]
-                    parent_module = get_submodule(transformer, parent_module_name)
+                    parent_module_name, _, current_module_name = name.rpartition(".")
+                    parent_module = transformer.get_submodule(parent_module_name)

                    expanded_module = torch.nn.Linear(
                        in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
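`str.rpartition(".")` splits on the last dot in a single call, and it degrades gracefully for a top-level module: the parent comes back as the empty string, which `Module.get_submodule("")` resolves to the root module itself. A quick check with made-up names:

```python
name = "transformer_blocks.0.attn.to_q"     # nested module, name is illustrative
parent, _, child = name.rpartition(".")
print(parent, "|", child)                   # transformer_blocks.0.attn | to_q

top_level = "context_embedder"              # no dot: parent is empty
parent, _, child = top_level.rpartition(".")
print(repr(parent), "|", child)             # '' | context_embedder
```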