Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert old LoRA format to the new format #3595

Open
wants to merge 1 commit into
base: release/9.2
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions demo/Diffusion/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,23 @@ def get_dicts(self,
}

else:
# Otherwise, we're dealing with the old format.
warn_message = "You have saved the LoRA weights using the old format. To convert LoRA weights to the new format, first load them in a dictionary and then create a new dictionary as follows: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
print(warn_message)
# Otherwise, convert old LoRA format to the new format.
self.state_dict[path] = {f'unet.{module_name}': params for module_name, params in self.state_dict[path].items()}
keys = list(self.state_dict[path].keys())
if all(key.startswith(('unet', 'text_encoder')) for key in keys):
keys = [k for k in keys if k.startswith(prefix)]
if keys:
print(f"Processing {prefix} LoRA: {path}")
state_dict[path] = {k.replace(f"{prefix}.", ""): v for k, v in self.state_dict[path].items() if k in keys}

if path in self.network_alphas:
if self.network_alphas[path]:
alpha_keys = [k for k in self.network_alphas[path].keys() if k.startswith(prefix)]
network_alphas[path] = {
k.replace(f"{prefix}.", ""): v for k, v in self.network_alphas[path].items() if k in alpha_keys
}
else:
network_alphas[path] = None

return state_dict, network_alphas

Expand Down