Commit 07d44e7

apply suggestions from review
1 parent 6af2097 commit 07d44e7

File tree: 1 file changed (+7, -18 lines)

src/diffusers/loaders/lora_pipeline.py

Lines changed: 7 additions & 18 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import os
 from typing import Callable, Dict, List, Optional, Union
@@ -1956,7 +1957,7 @@ def _load_norm_into_transformer(
         prefix = prefix or cls.transformer_name
         for key in list(state_dict.keys()):
             if key.split(".")[0] == prefix:
-                state_dict[key.replace(f"{prefix}.", "")] = state_dict.pop(key)
+                state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)

         # Find invalid keys
         transformer_state_dict = transformer.state_dict()
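
Note on the hunk above: str.replace removes every occurrence of the prefix substring anywhere in the key, whereas slicing off len(f"{prefix}.") characters only drops the leading prefix. A minimal sketch of the difference, using a contrived key name that is not taken from the diff:

# Sketch: why slicing the leading prefix is safer than str.replace.
prefix = "transformer"
key = "transformer.blocks.0.transformer.norm.weight"  # contrived key for illustration

via_replace = key.replace(f"{prefix}.", "")  # "blocks.0.norm.weight" -- the inner occurrence is stripped too
via_slice = key[len(f"{prefix}.") :]         # "blocks.0.transformer.norm.weight" -- only the leading prefix goes

print(via_replace)
print(via_slice)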
@@ -2278,6 +2279,7 @@ def unload_lora_weights(self):
         transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
         if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
             transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
+            transformer._transformer_norm_layers = None

     @classmethod
     def _maybe_expand_transformer_param_shape_or_error_(
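
The added assignment clears the cached norm-layer state dict once it has been restored, so the cache cannot be re-applied on a later call. A rough standalone sketch of that restore-and-clear pattern on a plain torch module (the names below are illustrative, not the pipeline's attributes):

import torch

# Cache a copy of the current parameters (stand-in for _transformer_norm_layers).
model = torch.nn.Linear(4, 4)
cached = {k: v.clone() for k, v in model.state_dict().items()}

# ... the weights get modified elsewhere ...
with torch.no_grad():
    model.weight.add_(1.0)

# Restore the cached values, then drop the cache so it is not restored twice.
model.load_state_dict(cached, strict=False)
cached = None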
@@ -2303,15 +2305,6 @@ def _maybe_expand_transformer_param_shape_or_error_(
             if key.split(".")[0] == prefix:
                 state_dict[key.replace(f"{prefix}.", "")] = state_dict.pop(key)

-        def get_submodule(module, name):
-            for part in name.split("."):
-                if len(name) == 0:
-                    break
-                if not hasattr(module, part):
-                    raise AttributeError(f"Submodule '{part}' not found in '{module}'.")
-                module = getattr(module, part)
-            return module
-
         # Expand transformer parameter shapes if they don't match lora
         has_param_with_shape_update = False

@@ -2320,12 +2313,9 @@ def get_submodule(module, name):
             module_weight = module.weight.data
             module_bias = module.bias.data if hasattr(module, "bias") else None
             bias = module_bias is not None
-            name_split = name.split(".")

-            lora_A_name = f"{name}.lora_A"
-            lora_B_name = f"{name}.lora_B"
-            lora_A_weight_name = f"{lora_A_name}.weight"
-            lora_B_weight_name = f"{lora_B_name}.weight"
+            lora_A_weight_name = f"{name}.lora_A.weight"
+            lora_B_weight_name = f"{name}.lora_B.weight"

             if lora_A_weight_name not in state_dict.keys():
                 continue
@@ -2353,9 +2343,8 @@ def get_submodule(module, name):
             )

             has_param_with_shape_update = True
-            parent_module_name = ".".join(name_split[:-1])
-            current_module_name = name_split[-1]
-            parent_module = get_submodule(transformer, parent_module_name)
+            parent_module_name, _, current_module_name = name.rpartition(".")
+            parent_module = transformer.get_submodule(parent_module_name)

             expanded_module = torch.nn.Linear(
                 in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
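
For the last two hunks: the removed local get_submodule helper is replaced by torch.nn.Module.get_submodule, which walks a dot-separated path and raises AttributeError for unknown names, while str.rpartition splits the path into parent and leaf in one step (an empty parent path makes get_submodule return the module itself). A small sketch on a throwaway module tree whose layout is made up for illustration:

import torch

# Hypothetical module tree, just to exercise dotted-path resolution.
model = torch.nn.Sequential(torch.nn.ModuleDict({"proj": torch.nn.Linear(8, 8)}))

name = "0.proj"  # dotted path to the Linear above
parent_module_name, _, current_module_name = name.rpartition(".")  # ("0", ".", "proj")

parent_module = model.get_submodule(parent_module_name)  # the ModuleDict at index 0
leaf = model.get_submodule(name)                          # the Linear itself
assert getattr(parent_module, current_module_name) is leaf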
