diff --git a/comfy/ldm/aura/mmdit.py b/comfy/ldm/aura/mmdit.py index c465619..2564166 100644 --- a/comfy/ldm/aura/mmdit.py +++ b/comfy/ldm/aura/mmdit.py @@ -409,7 +409,7 @@ class MMDiT(nn.Module): pad_h = (self.patch_size - H % self.patch_size) % self.patch_size pad_w = (self.patch_size - W % self.patch_size) % self.patch_size - x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect') + x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular') x = x.view( B, C, diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index f37f7ff..aac48a7 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -69,12 +69,14 @@ class PatchEmbed(nn.Module): bias: bool = True, strict_img_size: bool = True, dynamic_img_pad: bool = True, + padding_mode='circular', dtype=None, device=None, operations=None, ): super().__init__() self.patch_size = (patch_size, patch_size) + self.padding_mode = padding_mode if img_size is not None: self.img_size = (img_size, img_size) self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)]) @@ -110,7 +112,7 @@ class PatchEmbed(nn.Module): if self.dynamic_img_pad: pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] - x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect') + x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode=self.padding_mode) x = self.proj(x) if self.flatten: x = x.flatten(2).transpose(1, 2) # NCHW -> NLC