diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 69ab21c..b596408 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -8,6 +8,7 @@ from typing import Optional, Any
 
 from ..attention import MemoryEfficientCrossAttention
 from comfy import model_management
+import comfy.ops
 
 if model_management.xformers_enabled_vae():
     import xformers
@@ -48,7 +49,7 @@ class Upsample(nn.Module):
         super().__init__()
         self.with_conv = with_conv
         if self.with_conv:
-            self.conv = torch.nn.Conv2d(in_channels,
+            self.conv = comfy.ops.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=3,
                                         stride=1,
@@ -67,7 +68,7 @@ class Downsample(nn.Module):
         self.with_conv = with_conv
         if self.with_conv:
             # no asymmetric padding in torch conv, must do it ourselves
-            self.conv = torch.nn.Conv2d(in_channels,
+            self.conv = comfy.ops.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=3,
                                         stride=2,
@@ -95,30 +96,30 @@ class ResnetBlock(nn.Module):
 
         self.swish = torch.nn.SiLU(inplace=True)
         self.norm1 = Normalize(in_channels)
-        self.conv1 = torch.nn.Conv2d(in_channels,
+        self.conv1 = comfy.ops.Conv2d(in_channels,
                                      out_channels,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1)
         if temb_channels > 0:
-            self.temb_proj = torch.nn.Linear(temb_channels,
+            self.temb_proj = comfy.ops.Linear(temb_channels,
                                              out_channels)
         self.norm2 = Normalize(out_channels)
         self.dropout = torch.nn.Dropout(dropout, inplace=True)
-        self.conv2 = torch.nn.Conv2d(out_channels,
+        self.conv2 = comfy.ops.Conv2d(out_channels,
                                      out_channels,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1)
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
-                self.conv_shortcut = torch.nn.Conv2d(in_channels,
+                self.conv_shortcut = comfy.ops.Conv2d(in_channels,
                                                      out_channels,
                                                      kernel_size=3,
                                                      stride=1,
                                                      padding=1)
             else:
-                self.nin_shortcut = torch.nn.Conv2d(in_channels,
+                self.nin_shortcut = comfy.ops.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=1,
                                                     stride=1,
@@ -188,22 +189,22 @@ class AttnBlock(nn.Module):
         self.in_channels = in_channels
 
         self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(in_channels,
+        self.q = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.k = torch.nn.Conv2d(in_channels,
+        self.k = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.v = torch.nn.Conv2d(in_channels,
+        self.v = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels,
+        self.proj_out = comfy.ops.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=1,
                                         stride=1,
@@ -243,22 +244,22 @@ class MemoryEfficientAttnBlock(nn.Module):
         self.in_channels = in_channels
 
         self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(in_channels,
+        self.q = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.k = torch.nn.Conv2d(in_channels,
+        self.k = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.v = torch.nn.Conv2d(in_channels,
+        self.v = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels,
+        self.proj_out = comfy.ops.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=1,
                                         stride=1,
@@ -302,22 +303,22 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
         self.in_channels = in_channels
 
         self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(in_channels,
+        self.q = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.k = torch.nn.Conv2d(in_channels,
+        self.k = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.v = torch.nn.Conv2d(in_channels,
+        self.v = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels,
+        self.proj_out = comfy.ops.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=1,
                                         stride=1,
@@ -399,14 +400,14 @@ class Model(nn.Module):
             # timestep embedding
             self.temb = nn.Module()
             self.temb.dense = nn.ModuleList([
-                torch.nn.Linear(self.ch,
+                comfy.ops.Linear(self.ch,
                                 self.temb_ch),
-                torch.nn.Linear(self.temb_ch,
+                comfy.ops.Linear(self.temb_ch,
                                 self.temb_ch),
             ])
 
         # downsampling
-        self.conv_in = torch.nn.Conv2d(in_channels,
+        self.conv_in = comfy.ops.Conv2d(in_channels,
                                        self.ch,
                                        kernel_size=3,
                                        stride=1,
@@ -475,7 +476,7 @@ class Model(nn.Module):
 
         # end
         self.norm_out = Normalize(block_in)
-        self.conv_out = torch.nn.Conv2d(block_in,
+        self.conv_out = comfy.ops.Conv2d(block_in,
                                         out_ch,
                                         kernel_size=3,
                                         stride=1,
@@ -548,7 +549,7 @@ class Encoder(nn.Module):
         self.in_channels = in_channels
 
         # downsampling
-        self.conv_in = torch.nn.Conv2d(in_channels,
+        self.conv_in = comfy.ops.Conv2d(in_channels,
                                        self.ch,
                                        kernel_size=3,
                                        stride=1,
@@ -593,7 +594,7 @@ class Encoder(nn.Module):
 
         # end
         self.norm_out = Normalize(block_in)
-        self.conv_out = torch.nn.Conv2d(block_in,
+        self.conv_out = comfy.ops.Conv2d(block_in,
                                         2*z_channels if double_z else z_channels,
                                         kernel_size=3,
                                         stride=1,
@@ -653,7 +654,7 @@ class Decoder(nn.Module):
                            self.z_shape, np.prod(self.z_shape)))
 
         # z to block_in
-        self.conv_in = torch.nn.Conv2d(z_channels,
+        self.conv_in = comfy.ops.Conv2d(z_channels,
                                        block_in,
                                        kernel_size=3,
                                        stride=1,
@@ -695,7 +696,7 @@ class Decoder(nn.Module):
 
         # end
         self.norm_out = Normalize(block_in)
-        self.conv_out = torch.nn.Conv2d(block_in,
+        self.conv_out = comfy.ops.Conv2d(block_in,
                                         out_ch,
                                         kernel_size=3,
                                         stride=1,
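Note on the change: every torch.nn.Conv2d / torch.nn.Linear constructed in these VAE modules is swapped for a comfy.ops equivalent with identical call sites (same arguments, same shapes). The sketch below shows one way such drop-in classes could be defined; it is an assumption about what comfy/ops.py contains, not a copy of it. The idea is to make reset_parameters() a no-op so PyTorch skips the default weight initialization it runs in each layer's __init__, which is wasted work when the parameters are immediately overwritten by a checkpoint load.

# Hypothetical sketch of comfy/ops.py (assumed, not taken from the repo).
# Subclassing the torch.nn layers and disabling reset_parameters() keeps
# every call site unchanged while skipping the default (Kaiming/uniform)
# initialization that a state_dict load would overwrite anyway.
import torch

class Linear(torch.nn.Linear):
    def reset_parameters(self):
        # no-op: leave the freshly allocated tensors as-is; the
        # checkpoint's state_dict fills them in afterwards
        return None

class Conv2d(torch.nn.Conv2d):
    def reset_parameters(self):
        # no-op, same reasoning as Linear above
        return None

Because the subclasses change nothing else, a layer such as comfy.ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) behaves exactly like its torch.nn counterpart once weights are loaded; only model construction gets cheaper.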