|
|
|
@ -8,6 +8,7 @@ from typing import Optional, Any
|
|
|
|
|
|
|
|
|
|
from ..attention import MemoryEfficientCrossAttention
|
|
|
|
|
from comfy import model_management
|
|
|
|
|
import comfy.ops
|
|
|
|
|
|
|
|
|
|
if model_management.xformers_enabled_vae():
|
|
|
|
|
import xformers
|
|
|
|
@ -48,7 +49,7 @@ class Upsample(nn.Module):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.with_conv = with_conv
|
|
|
|
|
if self.with_conv:
|
|
|
|
|
self.conv = torch.nn.Conv2d(in_channels,
|
|
|
|
|
self.conv = comfy.ops.Conv2d(in_channels,
|
|
|
|
|
in_channels,
|
|
|
|
|
kernel_size=3,
|
|
|
|
|
stride=1,
|
|
|
|
@ -67,7 +68,7 @@ class Downsample(nn.Module):
|
|
|
|
|
self.with_conv = with_conv
|
|
|
|
|
if self.with_conv:
|
|
|
|
|
# no asymmetric padding in torch conv, must do it ourselves
|
|
|
|
|
self.conv = torch.nn.Conv2d(in_channels,
|
|
|
|
|
self.conv = comfy.ops.Conv2d(in_channels,
|
|
|
|
|
in_channels,
|
|
|
|
|
kernel_size=3,
|
|
|
|
|
stride=2,
|
|
|
|
@ -95,30 +96,30 @@ class ResnetBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
self.swish = torch.nn.SiLU(inplace=True)
|
|
|
|
|
self.norm1 = Normalize(in_channels)
|
|
|
|
|
self.conv1 = torch.nn.Conv2d(in_channels,
|
|
|
|
|
self.conv1 = comfy.ops.Conv2d(in_channels,
|
|
|
|
|
out_channels,
|
|
|
|
|
kernel_size=3,
|
|
|
|
|
stride=1,
|
|
|
|
|
padding=1)
|
|
|
|
|
if temb_channels > 0:
|
|
|
|
|
self.temb_proj = torch.nn.Linear(temb_channels,
|
|
|
|
|
self.temb_proj = comfy.ops.Linear(temb_channels,
|
|
|
|
|
out_channels)
|
|
|
|
|
self.norm2 = Normalize(out_channels)
|
|
|
|
|
self.dropout = torch.nn.Dropout(dropout, inplace=True)
|
|
|
|
|
self.conv2 = torch.nn.Conv2d(out_channels,
|
|
|
|
|
self.conv2 = comfy.ops.Conv2d(out_channels,
|
|
|
|
|
out_channels,
|
|
|
|
|
kernel_size=3,
|
|
|
|
|
stride=1,
|
|
|
|
|
padding=1)
|
|
|
|
|
if self.in_channels != self.out_channels:
|
|
|
|
|
if self.use_conv_shortcut:
|
|
|
|
|
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
|
|
|
|
self.conv_shortcut = comfy.ops.Conv2d(in_channels,
|
|
|
|
|
out_channels,
|
|
|
|
|
kernel_size=3,
|
|
|
|
|
stride=1,
|
|
|
|
|
padding=1)
|
|
|
|
|
else:
|
|
|
|
|
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
|
|
|
|
self.nin_shortcut = comfy.ops.Conv2d(in_channels,
|
|
|
|
|
out_channels,
|
|
|
|
|
kernel_size=1,
|
|
|
|
|
stride=1,
|
|
|
|
@ -188,22 +189,22 @@ class AttnBlock(nn.Module):
|
|
|
|
|
self.in_channels = in_channels
|
|
|
|
|
|
|
|
|
|
self.norm = Normalize(in_channels)
|
|
|
|
|
self.q = torch.nn.Conv2d(in_channels,
|
|
|
|
|
self.q = comfy.ops.Conv2d(in_channels,
|
|
|
|
|
in_channels,
|
|
|
|
|
kernel_size=1,
|
|
|
|
|
stride=1,
|
|
|
|
|
padding=0)
|
|
|
|
|
self.k = torch.nn.Conv2d(in_channels,
|
|
|
|
|
self.k = comfy.ops.Conv2d(in_channels,
|
|
|
|
|
in_channels,
|
|
|
|
|
kernel_size=1,
|
|
|
|
|
stride=1,
|
|
|
|
|
padding=0)
|
|
|
|
|
self.v = torch.nn.Conv2d(in_channels,
|
|
|
|
|
self.v = comfy.ops.Conv2d(in_channels,
|
|
|
|
|
in_channels,
|
|
|
|
|
kernel_size=1,
|
|
|
|
|
stride=1,
|
|
|
|
|
padding=0)
|
|
|
|
|
self.proj_out = torch.nn.Conv2d(in_channels,
|
|
|
|
|
self.proj_out = comfy.ops.Conv2d(in_channels,
|
|
|
|
|
in_channels,
|
|
|
|
|
kernel_size=1,
|
|
|
|
|
stride=1,
|
|
|
|
@ -243,22 +244,22 @@ class MemoryEfficientAttnBlock(nn.Module):
|
|
|
|
|
self.in_channels = in_channels
|
|
|
|
|
|
|
|
|
|
self.norm = Normalize(in_channels)
|
|
|
|
|
self.q = torch.nn.Conv2d(in_channels,
|
|
|
|
|
self.q = comfy.ops.Conv2d(in_channels,
|
|
|
|
|
in_channels,
|
|
|
|
|
kernel_size=1,
|
|
|
|
|
stride=1,
|
|
|
|
|
padding=0)
|
|
|
|
|
self.k = torch.nn.Conv2d(in_channels,
|
|
|
|
|
self.k = comfy.ops.Conv2d(in_channels,
|
|
|
|
|
in_channels,
|
|
|
|
|
kernel_size=1,
|
|
|
|
|
stride=1,
|
|
|
|
|
padding=0)
|
|
|
|
|
self.v = torch.nn.Conv2d(in_channels,
|
|
|
|
|
self.v = comfy.ops.Conv2d(in_channels,
|
|
|
|
|
in_channels,
|
|
|
|
|
kernel_size=1,
|
|
|
|
|
stride=1,
|
|
|
|
|
padding=0)
|
|
|
|
|
self.proj_out = torch.nn.Conv2d(in_channels,
|
|
|
|
|
self.proj_out = comfy.ops.Conv2d(in_channels,
|
|
|
|
|
in_channels,
|
|
|
|
|
kernel_size=1,
|
|
|
|
|
stride=1,
|
|
|
|
@ -302,22 +303,22 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
|
|
|
|
|
self.in_channels = in_channels
|
|
|
|
|
|
|
|
|
|
self.norm = Normalize(in_channels)
|
|
|
|
|
self.q = torch.nn.Conv2d(in_channels,
|
|
|
|
|
self.q = comfy.ops.Conv2d(in_channels,
|
|
|
|
|
in_channels,
|
|
|
|
|
kernel_size=1,
|
|
|
|
|
stride=1,
|
|
|
|
|
padding=0)
|
|
|
|
|
self.k = torch.nn.Conv2d(in_channels,
|
|
|
|
|
self.k = comfy.ops.Conv2d(in_channels,
|
|
|
|
|
in_channels,
|
|
|
|
|
kernel_size=1,
|
|
|
|
|
stride=1,
|
|
|
|
|
padding=0)
|
|
|
|
|
self.v = torch.nn.Conv2d(in_channels,
|
|
|
|
|
self.v = comfy.ops.Conv2d(in_channels,
|
|
|
|
|
in_channels,
|
|
|
|
|
kernel_size=1,
|
|
|
|
|
stride=1,
|
|
|
|
|
padding=0)
|
|
|
|
|
self.proj_out = torch.nn.Conv2d(in_channels,
|
|
|
|
|
self.proj_out = comfy.ops.Conv2d(in_channels,
|
|
|
|
|
in_channels,
|
|
|
|
|
kernel_size=1,
|
|
|
|
|
stride=1,
|
|
|
|
@ -399,14 +400,14 @@ class Model(nn.Module):
|
|
|
|
|
# timestep embedding
|
|
|
|
|
self.temb = nn.Module()
|
|
|
|
|
self.temb.dense = nn.ModuleList([
|
|
|
|
|
torch.nn.Linear(self.ch,
|
|
|
|
|
comfy.ops.Linear(self.ch,
|
|
|
|
|
self.temb_ch),
|
|
|
|
|
torch.nn.Linear(self.temb_ch,
|
|
|
|
|
comfy.ops.Linear(self.temb_ch,
|
|
|
|
|
self.temb_ch),
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
# downsampling
|
|
|
|
|
self.conv_in = torch.nn.Conv2d(in_channels,
|
|
|
|
|
self.conv_in = comfy.ops.Conv2d(in_channels,
|
|
|
|
|
self.ch,
|
|
|
|
|
kernel_size=3,
|
|
|
|
|
stride=1,
|
|
|
|
@ -475,7 +476,7 @@ class Model(nn.Module):
|
|
|
|
|
|
|
|
|
|
# end
|
|
|
|
|
self.norm_out = Normalize(block_in)
|
|
|
|
|
self.conv_out = torch.nn.Conv2d(block_in,
|
|
|
|
|
self.conv_out = comfy.ops.Conv2d(block_in,
|
|
|
|
|
out_ch,
|
|
|
|
|
kernel_size=3,
|
|
|
|
|
stride=1,
|
|
|
|
@ -548,7 +549,7 @@ class Encoder(nn.Module):
|
|
|
|
|
self.in_channels = in_channels
|
|
|
|
|
|
|
|
|
|
# downsampling
|
|
|
|
|
self.conv_in = torch.nn.Conv2d(in_channels,
|
|
|
|
|
self.conv_in = comfy.ops.Conv2d(in_channels,
|
|
|
|
|
self.ch,
|
|
|
|
|
kernel_size=3,
|
|
|
|
|
stride=1,
|
|
|
|
@ -593,7 +594,7 @@ class Encoder(nn.Module):
|
|
|
|
|
|
|
|
|
|
# end
|
|
|
|
|
self.norm_out = Normalize(block_in)
|
|
|
|
|
self.conv_out = torch.nn.Conv2d(block_in,
|
|
|
|
|
self.conv_out = comfy.ops.Conv2d(block_in,
|
|
|
|
|
2*z_channels if double_z else z_channels,
|
|
|
|
|
kernel_size=3,
|
|
|
|
|
stride=1,
|
|
|
|
@ -653,7 +654,7 @@ class Decoder(nn.Module):
|
|
|
|
|
self.z_shape, np.prod(self.z_shape)))
|
|
|
|
|
|
|
|
|
|
# z to block_in
|
|
|
|
|
self.conv_in = torch.nn.Conv2d(z_channels,
|
|
|
|
|
self.conv_in = comfy.ops.Conv2d(z_channels,
|
|
|
|
|
block_in,
|
|
|
|
|
kernel_size=3,
|
|
|
|
|
stride=1,
|
|
|
|
@ -695,7 +696,7 @@ class Decoder(nn.Module):
|
|
|
|
|
|
|
|
|
|
# end
|
|
|
|
|
self.norm_out = Normalize(block_in)
|
|
|
|
|
self.conv_out = torch.nn.Conv2d(block_in,
|
|
|
|
|
self.conv_out = comfy.ops.Conv2d(block_in,
|
|
|
|
|
out_ch,
|
|
|
|
|
kernel_size=3,
|
|
|
|
|
stride=1,
|
|
|
|
|