Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add scaled_dot_product_attention fallback to controlmodel_ipadapter.py for PyTorch v1 compatibility #2707

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions scripts/ipadapter/plugable_ipadapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,35 @@
from .ipadapter_model import ImageEmbed, IPAdapterModel


def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
# Fallback implementation for PyTorch v1 compatibility (less efficient)
# Slightly modified from: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
if is_causal:
assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)

if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value

try:
scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
except AttributeError:
pass


def get_block(model, flag):
return {
"input": model.input_blocks,
Expand Down Expand Up @@ -30,7 +59,7 @@ def attn_forward_hacked(self, x, context=None, **kwargs):
(q, k, v),
)

out = torch.nn.functional.scaled_dot_product_attention(
out = scaled_dot_product_attention(
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
)
out = out.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
Expand Down Expand Up @@ -227,7 +256,7 @@ def forward(attn_blk, x, q):
ip_k = ip_k.to(dtype=q.dtype)
ip_v = ip_v.to(dtype=q.dtype)

ip_out = torch.nn.functional.scaled_dot_product_attention(
ip_out = scaled_dot_product_attention(
q, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
Expand Down