
RuntimeError: The size of tensor a (640) must match the size of tensor b (320) at non-singleton dimension 1 #136

Open
nancygd opened this issue Aug 10, 2024 · 15 comments


@nancygd

nancygd commented Aug 10, 2024

Hello, when I run the workflow up to the KSampler node I get an error. Do you know how to deal with it? Thank you!

@nullquant
Owner

Could you please post the full output from ComfyUI?

@Thater

Thater commented Aug 16, 2024

I'm having the same issue.

!!! Exception during processing !!! The size of tensor a (640) must match the size of tensor b (320) at non-singleton dimension 1
Traceback (most recent call last):
File "/ComfyUI/execution.py", line 313, in execute
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/execution.py", line 188, in get_output_data
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/execution.py", line 165, in map_node_over_list
process_inputs(input_dict, i)
File "/ComfyUI/execution.py", line 154, in process_inputs
results.append(getattr(obj, func)(**inputs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/custom_nodes/efficiency-nodes-comfyui/efficiency_nodes.py", line 2225, in sample_adv
return super().sample(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/custom_nodes/efficiency-nodes-comfyui/efficiency_nodes.py", line 732, in sample
samples, images, gifs, preview = process_latent_image(model, seed, steps, cfg, sampler_name, scheduler,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/custom_nodes/efficiency-nodes-comfyui/efficiency_nodes.py", line 554, in process_latent_image
samples = KSamplerAdvanced().sample(model, add_noise, seed, steps, cfg, sampler_name, scheduler,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/nodes.py", line 1452, in sample
return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/nodes.py", line 1385, in common_ksampler
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/custom_nodes/ComfyUI-Impact-Pack/modules/impact/sample_error_enhancer.py", line 22, in informative_sample
raise e
File "/ComfyUI/custom_nodes/ComfyUI-Impact-Pack/modules/impact/sample_error_enhancer.py", line 9, in informative_sample
return original_sample(*args, **kwargs) # This code helps interpret error messages that occur within exceptions but does not have any impact on other operations.
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/comfy/sample.py", line 43, in sample
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/comfy/samplers.py", line 829, in sample
return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/custom_nodes/ComfyUI-BrushNet/model_patch.py", line 120, in modified_sample
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/comfy/samplers.py", line 716, in sample
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/comfy/samplers.py", line 695, in inner_sample
samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/comfy/samplers.py", line 600, in sample
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/venv/lib/python3.12/site-packages/torch/utils/contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/comfy/k_diffusion/sampling.py", line 635, in sample_dpmpp_2m_sde
denoised = model(x, sigmas[i] * s_in, **extra_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/comfy/samplers.py", line 299, in call
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/comfy/samplers.py", line 682, in call
return self.predict_noise(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/comfy/samplers.py", line 685, in predict_noise
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/comfy/samplers.py", line 279, in sampling_function
out = calc_cond_batch(model, conds, x, timestep, model_options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/comfy/samplers.py", line 226, in calc_cond_batch
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/custom_nodes/ComfyUI-BrushNet/model_patch.py", line 52, in brushnet_model_function_wrapper
return apply_model_method(x, timestep, **options_dict['c'])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/comfy/model_base.py", line 145, in apply_model
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/comfy/ldm/modules/diffusionmodules/openaimodel.py", line 852, in forward
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/comfy/ldm/modules/diffusionmodules/openaimodel.py", line 44, in forward_timestep_embed
x = layer(x, context, transformer_options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ComfyUI/custom_nodes/ComfyUI-BrushNet/brushnet_nodes.py", line 1061, in forward_patched_by_brushnet
h += to_add.to(h.dtype).to(h.device)
RuntimeError: The size of tensor a (640) must match the size of tensor b (320) at non-singleton dimension 1
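
For context, the failing line is BrushNet's patched UNet forward adding a residual onto the hidden state, and the addition fails when the channel counts at dimension 1 disagree. A minimal sketch that reproduces the same RuntimeError (the tensor names and shapes here are illustrative, not BrushNet's actual values):

import torch

# UNet hidden state with 640 channels vs. a residual computed for 320 channels
h = torch.zeros(1, 640, 64, 64)       # "tensor a (640)" at dimension 1
to_add = torch.zeros(1, 320, 64, 64)  # "tensor b (320)" at dimension 1

# Raises: The size of tensor a (640) must match the size of tensor b (320)
# at non-singleton dimension 1
h += to_add.to(h.dtype).to(h.device)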

@Thater

Thater commented Aug 18, 2024

I don't know what changed, but it works for me now.

@nullquant
Owner

I can't reproduce the error either. Maybe some recent ComfyUI commits are the reason.

@Thater

Thater commented Aug 18, 2024

I figured it out: my ComfyUI launcher script was running with the "--fp8_e4m3fn-unet" argument for Flux.

@nancygd
Author

nancygd commented Aug 20, 2024

I can't reproduce the error either. Maybe some recent ComfyUI commits are the reason.

I don't know what's happening here. My friend sometimes hit the same problem and could fix it by changing the checkpoint model, but that method doesn't work for me, so I don't know what's going on. Thank you.

@nullquant
Owner

What checkpoint do you use? It should be float16, bfloat16, float32, or float64. Also check your ComfyUI startup options.
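
If you want to verify what dtype your checkpoint actually stores, here is a quick sketch using safetensors (the file path is a placeholder for your own checkpoint):

from collections import Counter
from safetensors import safe_open

path = "checkpoint.safetensors"  # placeholder: point this at your checkpoint
counts = Counter()
with safe_open(path, framework="pt", device="cpu") as f:
    for name in f.keys():
        counts[str(f.get_tensor(name).dtype)] += 1

# Expect torch.float16/bfloat16/float32 here; float8 dtypes would explain the mismatch
print(counts)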

@LiJT

LiJT commented Oct 10, 2024

I have this exact same error! As soon as I removed the --fast flag from my launch arguments, the error was gone... but I wish I could have both. --fast is an incredible speed-up: it makes Flux generation on 40-series cards about 40% faster.
comfyanonymous/ComfyUI@904bf58

@cjc999

cjc999 commented Oct 11, 2024

After upgrading to the latest version of ComfyUI on October 11th, I get this error message. Reverting to the October 9th version works fine. How can this be solved?

@zhiyulee3

I also encountered this.

@cjc999

cjc999 commented Oct 13, 2024

I upgraded CUDA to version 12.4, and now the latest version of ComfyUI works properly. Everyone can give it a try.

@Anson2048

This issue appeared after upgrading ComfyUI (comfyanonymous/ComfyUI@e38c942#diff-83920b72a497ff05a33ecf5ac3d19df7911f228f9921fa21e7b64c3b24781fafR101). Since that change, loading the Flux model no longer requires the --fast parameter on the command line; you can select fp8_e4m3fn_fast directly as the weight dtype. Removing the --fast parameter resolves this problem.

If you're using the Huishi (绘世) launcher, you can uncheck this option.
[screenshot of the launcher checkbox]
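
For what it's worth, my reading of the linked commit is that selecting fp8_e4m3fn_fast in the loader enables the same fp8 optimization path per model that --fast used to force globally. A rough, illustrative mapping of how such a weight_dtype choice could translate into model options (this is a sketch, not the exact upstream code):

import torch

def weight_dtype_to_options(weight_dtype):
    # Hypothetical helper: maps the loader's dropdown value to load options
    model_options = {}
    if weight_dtype == "fp8_e4m3fn":
        model_options["dtype"] = torch.float8_e4m3fn
    elif weight_dtype == "fp8_e4m3fn_fast":
        model_options["dtype"] = torch.float8_e4m3fn
        model_options["fp8_optimizations"] = True  # per-model opt-in instead of global --fast
    return model_options

print(weight_dtype_to_options("fp8_e4m3fn_fast"))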

@Orenji-Tangerine

This issue appeared after upgrading ComfyUI (comfyanonymous/ComfyUI@e38c942#diff-83920b72a497ff05a33ecf5ac3d19df7911f228f9921fa21e7b64c3b24781fafR101). Since that change, loading the Flux model no longer requires the --fast parameter on the command line; you can select fp8_e4m3fn_fast directly as the weight dtype. Removing the --fast parameter resolves this problem.

If you're using the Huishi (绘世) launcher, you can uncheck this option. [screenshot]

This seems to solve the issue, but "--fast" is still faster than setting the weight dtype to "fp8_e4m3fn_fast" in the Load Diffusion Model node. Maybe @nullquant can work something out so we can have both the --fast argument and BrushNet at the same time. I'd appreciate that! Thanks for the hard work.

@kanxun88

"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.

"""

import torch
import comfy.model_management
from comfy.cli_args import args
import comfy.float

cast_to = comfy.model_management.cast_to #TODO: remove once no more references

def cast_to_input(weight, input, non_blocking=False, copy=True):
    return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)

def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
    if input is not None:
        if dtype is None:
            dtype = input.dtype
        if bias_dtype is None:
            bias_dtype = dtype
        if device is None:
            device = input.device

    bias = None
    non_blocking = comfy.model_management.device_supports_non_blocking(device)
    if s.bias is not None:
        has_function = s.bias_function is not None
        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
        if has_function:
            bias = s.bias_function(bias)

    has_function = s.weight_function is not None
    weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
    if has_function:
        weight = s.weight_function(weight)
    return weight, bias

class CastWeightBiasOp:
    comfy_cast_weights = False
    weight_function = None
    bias_function = None

class disable_weight_init:
    class Linear(torch.nn.Linear, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.linear(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv1d(torch.nn.Conv1d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            if self.weight is not None:
                weight, bias = cast_bias_weight(self, input)
            else:
                weight = None
                bias = None
            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input, output_size=None):
            num_spatial_dims = 2
            output_padding = self._output_padding(
                input, output_size, self.stride, self.padding, self.kernel_size,
                num_spatial_dims, self.dilation)

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.conv_transpose2d(
                input, weight, bias, self.stride, self.padding,
                output_padding, self.groups, self.dilation)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class ConvTranspose1d(torch.nn.ConvTranspose1d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input, output_size=None):
            num_spatial_dims = 1
            output_padding = self._output_padding(
                input, output_size, self.stride, self.padding, self.kernel_size,
                num_spatial_dims, self.dilation)

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.conv_transpose1d(
                input, weight, bias, self.stride, self.padding,
                output_padding, self.groups, self.dilation)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Embedding(torch.nn.Embedding, CastWeightBiasOp):
        def reset_parameters(self):
            self.bias = None
            return None

        def forward_comfy_cast_weights(self, input, out_dtype=None):
            output_dtype = out_dtype
            if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                out_dtype = None
            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                if "out_dtype" in kwargs:
                    kwargs.pop("out_dtype")
                return super().forward(*args, **kwargs)

    @classmethod
    def conv_nd(s, dims, *args, **kwargs):
        if dims == 2:
            return s.Conv2d(*args, **kwargs)
        elif dims == 3:
            return s.Conv3d(*args, **kwargs)
        else:
            raise ValueError(f"unsupported dimensions: {dims}")

class manual_cast(disable_weight_init):
    class Linear(disable_weight_init.Linear):
        comfy_cast_weights = True

    class Conv1d(disable_weight_init.Conv1d):
        comfy_cast_weights = True

    class Conv2d(disable_weight_init.Conv2d):
        comfy_cast_weights = True

    class Conv3d(disable_weight_init.Conv3d):
        comfy_cast_weights = True

    class GroupNorm(disable_weight_init.GroupNorm):
        comfy_cast_weights = True

    class LayerNorm(disable_weight_init.LayerNorm):
        comfy_cast_weights = True

    class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
        comfy_cast_weights = True

    class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
        comfy_cast_weights = True

    class Embedding(disable_weight_init.Embedding):
        comfy_cast_weights = True

def fp8_linear(self, input):
    dtype = self.weight.dtype
    if dtype not in [torch.float8_e4m3fn]:
        return None

    tensor_2d = False
    if len(input.shape) == 2:
        tensor_2d = True
        input = input.unsqueeze(1)

    if len(input.shape) == 3:
        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
        w = w.t()

        scale_weight = self.scale_weight
        scale_input = self.scale_input
        if scale_weight is None:
            scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
        else:
            scale_weight = scale_weight.to(input.device)

        if scale_input is None:
            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
            inn = input.reshape(-1, input.shape[2]).to(dtype)
        else:
            scale_input = scale_input.to(input.device)
            inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)

        if bias is not None:
            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
        else:
            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)

        if isinstance(o, tuple):
            o = o[0]

        if tensor_2d:
            return o.reshape(input.shape[0], -1)

        return o.reshape((-1, input.shape[1], self.weight.shape[0]))

    return None

class fp8_ops(manual_cast):
    class Linear(manual_cast.Linear):
        def reset_parameters(self):
            self.scale_weight = None
            self.scale_input = None
            return None

        def forward_comfy_cast_weights(self, input):
            out = fp8_linear(self, input)
            if out is not None:
                return out

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.linear(input, weight, bias)

def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
    class scaled_fp8_op(manual_cast):
        class Linear(manual_cast.Linear):
            def __init__(self, *args, **kwargs):
                if override_dtype is not None:
                    kwargs['dtype'] = override_dtype
                super().__init__(*args, **kwargs)

            def reset_parameters(self):
                if not hasattr(self, 'scale_weight'):
                    self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)

                if not scale_input:
                    self.scale_input = None

                if not hasattr(self, 'scale_input'):
                    self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
                return None

            def forward_comfy_cast_weights(self, input):
                if fp8_matrix_mult:
                    out = fp8_linear(self, input)
                    if out is not None:
                        return out

                weight, bias = cast_bias_weight(self, input)

                if weight.numel() < input.numel(): #TODO: optimize
                    return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
                else:
                    return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)

            def convert_weight(self, weight, inplace=False, **kwargs):
                if inplace:
                    weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
                    return weight
                else:
                    return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)

            def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
                weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
                if inplace_update:
                    self.weight.data.copy_(weight)
                else:
                    self.weight = torch.nn.Parameter(weight, requires_grad=False)

    return scaled_fp8_op

def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
    fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
    if scaled_fp8 is not None:
        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)

    if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
        return fp8_ops

    if compute_dtype is None or weight_dtype == compute_dtype:
        return disable_weight_init

    return manual_cast
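
The part of this file that matters for the thread is pick_operations: when fp8 compute is supported and either --fast or the per-model fp8_optimizations flag is set, ComfyUI picks fp8_ops, whose Linear goes through torch._scaled_mm instead of a plain cast. A standalone sketch of that decision order, re-implemented here just for illustration:

def pick_operations_sketch(weight_dtype, compute_dtype, fp8_compute, fast_or_fp8_opts, disable_fast_fp8=False, scaled_fp8=None):
    # Mirrors the branch order of pick_operations() above
    if scaled_fp8 is not None:
        return "scaled_fp8_ops"
    if fp8_compute and fast_or_fp8_opts and not disable_fast_fp8:
        return "fp8_ops"  # the path that --fast (or fp8_e4m3fn_fast) enables
    if compute_dtype is None or weight_dtype == compute_dtype:
        return "disable_weight_init"
    return "manual_cast"

print(pick_operations_sketch("fp8_e4m3fn", "fp16", fp8_compute=True, fast_or_fp8_opts=True))   # fp8_ops
print(pick_operations_sketch("fp8_e4m3fn", "fp16", fp8_compute=True, fast_or_fp8_opts=False))  # manual_cast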

@bdp2024

bdp2024 commented Nov 22, 2024

This issue appeared after upgrading ComfyUI (comfyanonymous/ComfyUI@e38c942#diff-83920b72a497ff05a33ecf5ac3d19df7911f228f9921fa21e7b64c3b24781fafR101). Since that change, loading the Flux model no longer requires the --fast parameter on the command line; you can select fp8_e4m3fn_fast directly as the weight dtype. Removing the --fast parameter resolves this problem.
If you're using the Huishi (绘世) launcher, you can uncheck this option. [screenshot]

This seems to solve the issue, but "--fast" is still faster than setting the weight dtype to "fp8_e4m3fn_fast" in the Load Diffusion Model node. Maybe @nullquant can work something out so we can have both the --fast argument and BrushNet at the same time. I'd appreciate that! Thanks for the hard work.

@Orenji-Tangerine Modify two lines of code in brushnet_nodes.py as shown below.
[screenshot of the two modified lines in brushnet_nodes.py]
