-
Notifications
You must be signed in to change notification settings - Fork 25
/
model_patch.py
144 lines (112 loc) · 5.77 KB
/
model_patch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import torch
import comfy
# Check and add 'model_patch' to model.model_options['transformer_options']
def add_model_patch_option(model):
    """Ensure ``model.model_options['transformer_options']['model_patch']`` exists.

    Missing dictionaries are created in place; dictionaries that already
    exist are left untouched.

    Args:
        model: object exposing a mutable ``model_options`` dict.

    Returns:
        The (possibly newly created) ``transformer_options`` dict, which is
        the same object stored in ``model.model_options``.
    """
    transformer_options = model.model_options.setdefault('transformer_options', {})
    transformer_options.setdefault("model_patch", {})
    return transformer_options
# Patch model with model_function_wrapper
def patch_model_function_wrapper(model, forward_patch, remove=False):
    """Register (or unregister) a BrushNet forward patch on ``model``.

    Installs ``brushnet_model_function_wrapper`` as the model's UNet function
    wrapper. It stores the base-model family in
    ``transformer_options['model_patch']['SDXL']`` and adds ``forward_patch``
    to (or removes it from) ``model_patch['forward']``. Once per process it
    also replaces ``comfy.samplers.sample`` and ``openaimodel.apply_control``
    with this module's ``modified_sample`` / ``modified_apply_control``.

    Args:
        model: ComfyUI model patcher whose ``model_options`` are mutated.
        forward_patch: callable invoked before every UNet call as
            ``forward_patch(unet, xc, t, transformer_options, control)``.
        remove: when True, remove ``forward_patch`` instead of adding it.

    Raises:
        Exception: if the base model config is neither SD1.5 nor SDXL.
    """
    # Validate the base model before touching model_options, so an
    # unsupported model is rejected without being left half-patched.
    model_config = model.model.model_config
    if isinstance(model_config, comfy.supported_models.SD15):
        is_sdxl = False
    elif isinstance(model_config, comfy.supported_models.SDXL):
        is_sdxl = True
    else:
        print('Base model type: ', type(model_config))
        # Build one message string; extra positional args to Exception would
        # make str(exc) render as a tuple.
        raise Exception(f"Unsupported model type: {type(model_config)}")

    def brushnet_model_function_wrapper(apply_model_method, options_dict):
        # Run every registered forward patch, then the regular UNet call.
        c = options_dict['c']
        to = c['transformer_options']
        control = c.get('control')
        x = options_dict['input']
        timestep = options_dict['timestep']

        # Nothing registered: behave exactly like the unwrapped model.
        if 'model_patch' not in to or 'forward' not in to['model_patch']:
            return apply_model_method(x, timestep, **c)

        mp = to['model_patch']
        unet = mp['unet']

        # 'all_sigmas' is stored by modified_sample. The current step is the
        # index of the schedule entry closest to the sigma being evaluated.
        all_sigmas = mp['all_sigmas']
        sigma = to['sigmas'][0].item()
        mp['step'] = torch.argmin((all_sigmas - sigma).abs()).item()
        mp['total_steps'] = all_sigmas.shape[0] - 1

        # Mirrors the input preparation done in comfy.model_base.apply_model.
        model_sampling = model.model.model_sampling
        xc = model_sampling.calculate_input(timestep, x)
        if c.get('c_concat') is not None:
            xc = torch.cat([xc, c['c_concat']], dim=1)
        t = model_sampling.timestep(timestep).float()

        # Execute all registered patches in registration order.
        for method in mp['forward']:
            method(unet, xc, t, to, control)

        return apply_model_method(x, timestep, **c)

    existing_wrapper = model.model_options.get("model_function_wrapper")
    if existing_wrapper:
        print('BrushNet is going to replace existing model_function_wrapper:', existing_wrapper)
    model.set_model_unet_function_wrapper(brushnet_model_function_wrapper)

    to = add_model_patch_option(model)
    mp = to['model_patch']
    mp['SDXL'] = is_sdxl

    forward_patches = mp.setdefault('forward', [])
    if remove:
        if forward_patch in forward_patches:
            forward_patches.remove(forward_patch)
    else:
        forward_patches.append(forward_patch)

    mp['unet'] = model.model.diffusion_model
    mp['step'] = 0
    mp['total_steps'] = 1

    # Install the module-level replacements once. The 'BrushNet' marker in
    # their docstrings identifies functions that are already patched.
    if 'BrushNet' not in (comfy.samplers.sample.__doc__ or ''):
        comfy.samplers.original_sample = comfy.samplers.sample
        comfy.samplers.sample = modified_sample

    openaimodel = comfy.ldm.modules.diffusionmodules.openaimodel
    if 'BrushNet' not in (openaimodel.apply_control.__doc__ or ''):
        openaimodel.original_apply_control = openaimodel.apply_control
        openaimodel.apply_control = modified_apply_control
# Model needs current step number and cfg at inference step. It is possible to write a custom KSampler but I'd like to use ComfyUI's one.
# The first versions had modified_common_ksampler, but it broke custom KSampler nodes
def modified_sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options=None,
                    latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
    """Modified by BrushNet nodes.

    Drop-in replacement for ``comfy.samplers.sample``. It records the full
    sigma schedule in ``transformer_options['model_patch']['all_sigmas']``
    and then samples through ``comfy.samplers.CFGGuider``.

    The word 'BrushNet' in this docstring is used as a "already patched"
    marker by patch_model_function_wrapper, so it must stay.

    Args:
        model: ComfyUI model patcher; its ``model_options`` are mutated.
        noise, positive, negative, cfg, sampler, latent_image, denoise_mask,
        callback, disable_pbar, seed: forwarded to the CFG guider.
        sigmas: sigma schedule; also stored as ``all_sigmas``.
        device: accepted for signature compatibility; unused.
        model_options: accepted for signature compatibility; unused
            (defaults to None to avoid a shared mutable default).

    Returns:
        Whatever ``CFGGuider.sample`` returns.
    """
    # Publish the sigma schedule before the guider is built, so the model
    # wrapper can map each evaluated sigma back to a step index.
    to = add_model_patch_option(model)
    to['model_patch']['all_sigmas'] = sigmas

    cfg_guider = comfy.samplers.CFGGuider(model)
    cfg_guider.set_conds(positive, negative)
    cfg_guider.set_cfg(cfg)
    return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
# To use Controlnet with RAUNet it is much easier to modify apply_control a little
def modified_apply_control(h, control, name):
    """Modified by BrushNet nodes.

    Replacement for ``openaimodel.apply_control``. It pops the last residual
    from ``control[name]`` and adds it to ``h`` in place. When the residual's
    dims 2 and 3 differ from ``h``'s, it is first resized with bicubic
    interpolation and cast to ``h``'s dtype and device. If the addition fails
    (e.g. channel mismatch), a warning is printed and ``h`` is returned
    unchanged.

    The word 'BrushNet' in this docstring is used as a "already patched"
    marker by patch_model_function_wrapper, so it must stay.

    Args:
        h: tensor to add the control residual to (modified in place).
            Indexing dims 2 and 3 presumes a 4-D (N, C, H, W) layout; TODO confirm.
        control: mapping from block name to a list of residuals, or None.
        name: key of ``control`` to consume from.

    Returns:
        ``h`` (the same object that was passed in).
    """
    if control is None or name not in control or len(control[name]) == 0:
        return h

    # Consumes the residual: the list is shortened by one entry.
    ctrl = control[name].pop()
    if ctrl is None:
        return h

    if h.shape[2] != ctrl.shape[2] or h.shape[3] != ctrl.shape[3]:
        ctrl = torch.nn.functional.interpolate(ctrl, size=(h.shape[2], h.shape[3]), mode='bicubic')
        ctrl = ctrl.to(device=h.device, dtype=h.dtype)

    try:
        h += ctrl
    except Exception:
        # The old code called the non-existent `print.warning`, which turned
        # this best-effort warning into an AttributeError.
        print("warning control could not be applied {} {}".format(h.shape, ctrl.shape))
    return h