lint fixes
Signed-off-by: Vladimir Mandic <mandic00@live.com>
vladmandic committed Nov 29, 2024
1 parent d6c1487 commit a635421
Showing 7 changed files with 26 additions and 25 deletions.
8 changes: 0 additions & 8 deletions modules/lora/lora.py

This file was deleted.

17 changes: 9 additions & 8 deletions modules/lora/lora_convert.py
@@ -107,14 +107,14 @@ def make_unet_conversion_map() -> Dict[str, str]:

class KeyConvert:
def __init__(self):
self.is_sdxl = True if shared.sd_model_type == "sdxl" else False
self.UNET_CONVERSION_MAP = make_unet_conversion_map() if self.is_sdxl else None
self.LORA_PREFIX_UNET = "lora_unet_"
self.LORA_PREFIX_TEXT_ENCODER = "lora_te_"
self.OFT_PREFIX_UNET = "oft_unet_"
# SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
self.LORA_PREFIX_TEXT_ENCODER1 = "lora_te1_"
self.LORA_PREFIX_TEXT_ENCODER2 = "lora_te2_"
self.is_sdxl = True if shared.sd_model_type == "sdxl" else False
self.UNET_CONVERSION_MAP = make_unet_conversion_map() if self.is_sdxl else None
self.LORA_PREFIX_UNET = "lora_unet_"
self.LORA_PREFIX_TEXT_ENCODER = "lora_te_"
self.OFT_PREFIX_UNET = "oft_unet_"
# SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
self.LORA_PREFIX_TEXT_ENCODER1 = "lora_te1_"
self.LORA_PREFIX_TEXT_ENCODER2 = "lora_te2_"

def __call__(self, key):
if self.is_sdxl:
@@ -446,6 +446,7 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
lora_name_alpha = f"{lora_name}.alpha"
diffusers_name = _convert_text_encoder_lora_key(key, lora_name)

sd_lora_rank = 1
if lora_name.startswith(("lora_te_", "lora_te1_")):
down_weight = sds_sd.pop(key)
sd_lora_rank = down_weight.shape[0]
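In the first hunk, the __init__ body appears twice because the diff pairs the eight removed lines with their reformatted replacements; the visible text is identical, so the change looks like whitespace only. The second hunk adds sd_lora_rank = 1 ahead of the prefix check, presumably so the name is bound on every code path rather than only inside the lora_te_/lora_te1_ branch (the pattern linters flag as used-before-assignment). A minimal sketch of that pattern, using a hypothetical helper name and a toy state dict:

def lora_rank_for(sds_sd: dict, key: str, lora_name: str) -> int:
    # Hypothetical helper, not the project's code: bind the rank to a safe
    # default before the branch so every path returns a defined value.
    sd_lora_rank = 1
    if lora_name.startswith(("lora_te_", "lora_te1_")):
        down_weight = sds_sd.pop(key)
        sd_lora_rank = down_weight.shape[0]  # rank = rows of the LoRA down projection
    return sd_lora_rank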
6 changes: 4 additions & 2 deletions modules/lora/network.py
@@ -1,9 +1,11 @@
import os
from collections import namedtuple
import enum
from typing import Union
from collections import namedtuple

from modules import sd_models, hashes, shared


NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}

@@ -105,7 +107,7 @@ def __init__(self, name, network_on_disk: NetworkOnDisk):


class ModuleType:
def create_module(self, net: Network, weights: NetworkWeights) -> Network | None: # pylint: disable=W0613
def create_module(self, net: Network, weights: NetworkWeights) -> Union[Network, None]: # pylint: disable=W0613
return None


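The return annotations change from Network | None to Union[Network, None], with Union newly imported from typing. This is presumably a compatibility fix: the PEP 604 X | Y form is evaluated when the function is defined and raises TypeError on Python 3.9 and earlier, whereas typing.Union (or Optional) works across versions. A toy illustration with a placeholder class instead of the project's Network type:

from typing import Optional, Union


class Widget:
    pass


def find_widget(name: str) -> Union[Widget, None]:
    # Version-safe spelling: Union is an ordinary typing object on Python 3.7+.
    return Widget() if name else None


def find_widget_opt(name: str) -> Optional[Widget]:
    # Equivalent spelling, also version-safe.
    return Widget() if name else None


# def find_widget_pep604(name: str) -> Widget | None: ...
# The commented-out form above fails at import time on Python 3.9 and older
# ("TypeError: unsupported operand type(s) for |") unless
# `from __future__ import annotations` defers annotation evaluation.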
1 change: 1 addition & 0 deletions modules/lora/network_norm.py
@@ -1,5 +1,6 @@
import modules.lora.network as network


class ModuleTypeNorm(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
if all(x in weights.w for x in ["w_norm", "b_norm"]):
3 changes: 2 additions & 1 deletion modules/lora/network_oft.py
@@ -1,7 +1,7 @@
import torch
from einops import rearrange
import modules.lora.network as network
from modules.lora.lyco_helpers import factorization
from einops import rearrange


class ModuleTypeOFT(network.ModuleType):
@@ -10,6 +10,7 @@ def create_module(self, net: network.Network, weights: network.NetworkWeights):
return NetworkModuleOFT(net, weights)
return None


# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
class NetworkModuleOFT(network.NetworkModule): # pylint: disable=abstract-method
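The edits to network_norm.py and network_oft.py look like routine PEP 8 housekeeping: the third-party einops import is grouped with the other external imports ahead of the project's modules.lora imports, and a second blank line is added before top-level class definitions (pycodestyle E302). A small layout sketch with placeholder names standing in for the local modules:

# Third-party imports grouped together, ahead of any local-package imports.
import torch
from einops import rearrange

# from mypackage import helpers  # a local import would come after the third-party group


class ExampleModule:
    # Placeholder class; the two blank lines above satisfy pycodestyle E302.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # flatten everything except the batch dimension
        return rearrange(x, "b c h w -> b (c h w)")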
14 changes: 9 additions & 5 deletions modules/lora/networks.py
@@ -3,6 +3,9 @@
import re
import time
import concurrent
import torch
import diffusers.models.lora

import modules.lora.network as network
import modules.lora.network_lora as network_lora
import modules.lora.network_hada as network_hada
@@ -14,8 +17,6 @@
import modules.lora.network_glora as network_glora
import modules.lora.network_overrides as network_overrides
import modules.lora.lora_convert as lora_convert
import torch
import diffusers.models.lora
from modules import shared, devices, sd_models, sd_models_compile, errors, scripts, files_cache, model_quant


@@ -74,7 +75,7 @@ def assign_network_names_to_compvis_modules(sd_model):
shared.sd_model.network_layer_mapping = network_layer_mapping


def load_diffusers(name, network_on_disk, lora_scale=shared.opts.extra_networks_default_multiplier) -> network.Network | None:
def load_diffusers(name, network_on_disk, lora_scale=shared.opts.extra_networks_default_multiplier) -> Union[network.Network, None]:
name = name.replace(".", "_")
shared.log.debug(f'Load network: type=LoRA name="{name}" file="{network_on_disk.filename}" detected={network_on_disk.sd_version} method=diffusers scale={lora_scale} fuse={shared.opts.lora_fuse_diffusers}')
if not shared.native:
@@ -103,7 +104,7 @@ def load_diffusers(name, network_on_disk, lora_scale=shared.opts.extra_networks_
return net


def load_network(name, network_on_disk) -> network.Network | None:
def load_network(name, network_on_disk) -> Union[network.Network, None]:
if not shared.sd_loaded:
return None

@@ -173,6 +174,7 @@ def load_network(name, network_on_disk) -> network.Network | None:
net.bundle_embeddings = bundle_embeddings
return net


def maybe_recompile_model(names, te_multipliers):
recompile_model = False
if shared.compiled_model_state is not None and shared.compiled_model_state.is_compiled:
@@ -186,7 +188,7 @@ def maybe_recompile_model(names, te_multipliers):
if not recompile_model:
if len(loaded_networks) > 0 and debug:
shared.log.debug('Model Compile: Skipping LoRa loading')
return
return recompile_model
else:
recompile_model = True
shared.compiled_model_state.lora_model = []
@@ -277,6 +279,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
t1 = time.time()
timer['load'] += t1 - t0


def set_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], updown, ex_bias):
weights_backup = getattr(self, "network_weights_backup", None)
bias_backup = getattr(self, "network_bias_backup", None)
@@ -389,6 +392,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
t1 = time.time()
timer['apply'] += t1 - t0


def network_load():
sd_model = getattr(shared.sd_model, "pipe", shared.sd_model) # wrapped model compatiblility
for component_name in ['text_encoder','text_encoder_2', 'unet', 'transformer']:
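Two changes in networks.py are worth calling out: the torch and diffusers.models.lora imports move above the project's own modules.lora imports (the same import-ordering cleanup as in the other files), and the bare return in maybe_recompile_model becomes return recompile_model, so the function returns a boolean on every path instead of an implicit None (pylint's inconsistent-return-statements check, R1710). A toy stand-in showing the shape of the return fix, not the project's actual logic:

def maybe_recompile(is_compiled: bool, already_loaded: int) -> bool:
    # Hypothetical simplification of maybe_recompile_model.
    recompile_model = False
    if is_compiled:
        # ... compare the requested LoRA set with what the compiled model holds ...
        if not recompile_model:
            if already_loaded > 0:
                print("Model Compile: skipping LoRA loading")
            # A bare `return` here would make some paths yield None while others
            # yield a bool; returning the flag keeps the signature honest.
            return recompile_model
        recompile_model = True
    return recompile_model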
