From 1de86851b12f6b4e2f9d64da277e3679d3d5e7a1 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 11 Mar 2023 15:15:13 -0500
Subject: [PATCH] Try to fix memory issue.

---
 comfy/ldm/modules/diffusionmodules/model.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 01ab2ed..18f7a8b 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -82,10 +82,11 @@ class Downsample(nn.Module):
                                         stride=2,
                                         padding=0)
 
-    def forward(self, x):
+    def forward(self, x, already_padded=False):
         if self.with_conv:
-            pad = (0,1,0,1)
-            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+            if not already_padded:
+                pad = (0,1,0,1)
+                x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
             x = self.conv(x)
         else:
             x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
@@ -553,7 +554,9 @@ class Encoder(nn.Module):
     def forward(self, x):
         # timestep embedding
         temb = None
-
+        pad = (0,1,0,1)
+        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+        already_padded = True
         # downsampling
         hs = [self.conv_in(x)]
         for i_level in range(self.num_resolutions):
@@ -563,7 +566,8 @@ class Encoder(nn.Module):
                 h = self.down[i_level].attn[i_block](h)
                 hs.append(h)
             if i_level != self.num_resolutions-1:
-                hs.append(self.down[i_level].downsample(hs[-1]))
+                hs.append(self.down[i_level].downsample(hs[-1], already_padded))
+                already_padded = False
 
         # middle
         h = hs[-1]
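
Why this helps with memory, as a minimal sketch (not part of the commit):
torch.nn.functional.pad allocates a fresh padded copy of its input, so the
original code materialized a padded copy of the widest full-resolution
feature map inside the first Downsample. The patch hoists that first
(0,1,0,1) pad into Encoder.forward, where it runs on the raw input before
conv_in expands the channel count. The tensor sizes below are illustrative
assumptions, not values taken from the patch:

    import torch
    import torch.nn.functional as F

    pad = (0, 1, 0, 1)  # pad right and bottom by one, as in the patch

    # Padding the raw input image: a small temporary copy.
    x = torch.randn(1, 3, 512, 512)
    x_padded = F.pad(x, pad, mode="constant", value=0)

    # Padding after channel expansion (assuming 128 channels at full
    # resolution): the temporary copy is ~43x larger.
    feat = torch.randn(1, 128, 512, 512)
    feat_padded = F.pad(feat, pad, mode="constant", value=0)

    print(x_padded.numel() * x_padded.element_size())        # ~3.2 MB
    print(feat_padded.numel() * feat_padded.element_size())  # ~134.7 MB

Subsequent Downsample calls still pad internally, since already_padded is
set back to False after the first level, so only the largest of the padded
copies is avoided.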