make it possible to use hypernetworks without opt split attention
AUTOMATIC1111 committed Oct 7, 2022
1 parent 97bc0b9 commit f7c787e
Showing 2 changed files with 38 additions and 10 deletions.
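This commit moves a plain (non-split-attention) CrossAttention forward that applies hypernetwork layers into modules/hypernetwork.py and hijacks it in as the default, so hypernetworks no longer require the split-attention optimizations. A minimal sketch of the resulting patch, assuming this commit's module layout (the install_hypernetwork_attention wrapper is illustrative; only the attribute assignment appears in the diff below):

import ldm.modules.attention
from modules import hypernetwork

def install_hypernetwork_attention():
    # Illustrative wrapper: point the stock CrossAttention at the
    # hypernetwork-aware forward added in modules/hypernetwork.py,
    # independent of any split-attention optimization flag.
    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward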
42 changes: 34 additions & 8 deletions modules/hypernetwork.py
@@ -4,7 +4,12 @@
 import traceback
 
 import torch
-from modules import devices
+
+from ldm.util import default
+from modules import devices, shared
+import torch
+from torch import einsum
+from einops import rearrange, repeat
 
 
 class HypernetworkModule(torch.nn.Module):
@@ -48,15 +53,36 @@ def load_hypernetworks(path):
 
     return res
 
-def apply(self, x, context=None, mask=None, original=None):
+
+def attention_CrossAttention_forward(self, x, context=None, mask=None):
     h = self.heads
 
     q = self.to_q(x)
     context = default(context, x)
 
-    if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork:
-        if context.shape[1] == 77 and CrossAttention.noise_cond:
-            context = context + (torch.randn_like(context) * 0.1)
-        h_k, h_v = CrossAttention.hypernetwork[context.shape[2]]
-        k = self.to_k(h_k(context))
-        v = self.to_v(h_v(context))
+    hypernetwork = shared.selected_hypernetwork()
+    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+    if hypernetwork_layers is not None:
+        k = self.to_k(hypernetwork_layers[0](context))
+        v = self.to_v(hypernetwork_layers[1](context))
+    else:
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+    if mask is not None:
+        mask = rearrange(mask, 'b ... -> b (...)')
+        max_neg_value = -torch.finfo(sim.dtype).max
+        mask = repeat(mask, 'b j -> (b h) () j', h=h)
+        sim.masked_fill_(~mask, max_neg_value)
+
+    # attention, what we cannot get enough of
+    attn = sim.softmax(dim=-1)
+
+    out = einsum('b i j, b j d -> b i d', attn, v)
+    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+    return self.to_out(out)
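The new forward keys its lookup on context.shape[2], the channel width of the conditioning tensor, and expects the selected hypernetwork to expose a layers dict mapping each width to a pair of modules applied before to_k and to_v. A rough, runnable sketch with a stand-in object (ToyHypernetwork and the Identity modules are assumptions for illustration; only the layers dict shape and the indexing mirror the diff):

import torch

class ToyHypernetwork:
    # Stand-in with the structure the new forward expects:
    # layers maps context width -> (module for k, module for v).
    def __init__(self, widths=(768,)):
        self.layers = {w: (torch.nn.Identity(), torch.nn.Identity()) for w in widths}

hn = ToyHypernetwork()
context = torch.randn(2, 77, 768)  # (batch, tokens, width)

layers = (hn.layers if hn is not None else {}).get(context.shape[2], None)
if layers is not None:
    context_k, context_v = layers[0](context), layers[1](context)  # fed to to_k / to_v
else:
    context_k = context_v = context  # plain attention path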
6 changes: 4 additions & 2 deletions modules/sd_hijack.py
@@ -8,7 +8,7 @@
 from torch.nn.functional import silu
 
 import modules.textual_inversion.textual_inversion
-from modules import prompt_parser, devices, sd_hijack_optimizations, shared
+from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork
 from modules.shared import opts, device, cmd_opts
 
 import ldm.modules.attention
@@ -20,6 +20,8 @@
 
 
 def apply_optimizations():
+    undo_optimizations()
+
     ldm.modules.diffusionmodules.model.nonlinearity = silu
 
     if cmd_opts.opt_split_attention_v1:
@@ -30,7 +32,7 @@ def apply_optimizations():
 
 
 def undo_optimizations():
-    ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward
+    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
     ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
     ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
 
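Calling undo_optimizations() at the top of apply_optimizations() makes the hypernetwork-aware forward the baseline; the optimized attention kernels then overwrite it only when their command-line flags are set. A simplified sketch of the resulting order (split_cross_attention_forward_v1 is read here as the repository's optimized variant and should be treated as an assumption; undo_optimizations is trimmed to the line relevant to this commit):

import ldm.modules.attention
import ldm.modules.diffusionmodules.model
from torch.nn.functional import silu
from modules import hypernetwork, sd_hijack_optimizations
from modules.shared import cmd_opts

def apply_optimizations():
    undo_optimizations()  # baseline: hypernetwork-aware, unoptimized forward
    ldm.modules.diffusionmodules.model.nonlinearity = silu
    if cmd_opts.opt_split_attention_v1:
        # assumed name; the optimized kernel applies hypernetwork layers internally
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1

def undo_optimizations():
    # trimmed: the real function also restores nonlinearity and AttnBlock.forward
    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward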
