def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
    """Pad the last two dimensions of ``img`` up to multiples of ``patch_size``.

    Padding is only ever added on the bottom and right edges, so existing
    values keep their position; a dimension that is already a multiple of
    its patch size is left untouched.

    Args:
        img: tensor whose last two dimensions are padded (``shape[-2]`` is
            matched against ``patch_size[0]``, ``shape[-1]`` against
            ``patch_size[1]``).
        patch_size: pair ``(patch_h, patch_w)`` of patch sizes.
        padding_mode: ``mode`` forwarded to ``torch.nn.functional.pad``.

    Returns:
        The padded tensor produced by ``torch.nn.functional.pad``.
    """
    # The parentheses matter: ``and`` binds tighter than ``or``, so without
    # them the test reads ``(circular and tracing) or scripting`` and every
    # padding mode would be rewritten to "reflect" while scripting.  Only
    # the "circular" mode is meant to be swapped out under tracing/scripting.
    if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
        padding_mode = "reflect"
    # ``(p - n % p) % p`` is the distance to the next multiple of p, and 0
    # when n is already a multiple.
    pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
    pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
    # F.pad takes pads for the last dimension first: (left, right, top, bottom).
    return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
+126,7 @@ class Flux(nn.Module): def forward(self, x, timestep, context, y, guidance, **kwargs): bs, c, h, w = x.shape patch_size = 2 - pad_h = (patch_size - h % 2) % patch_size - pad_w = (patch_size - w % 2) % patch_size - - x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular') + x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index ea1b5aa..491a58a 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -9,6 +9,7 @@ from .. import attention from einops import rearrange, repeat from .util import timestep_embedding import comfy.ops +import comfy.ldm.common_dit def default(x, y): if x is not None: @@ -111,9 +112,7 @@ class PatchEmbed(nn.Module): # f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})." # ) if self.dynamic_img_pad: - pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] - pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] - x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode=self.padding_mode) + x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode) x = self.proj(x) if self.flatten: x = x.flatten(2).transpose(1, 2) # NCHW -> NLC diff --git a/comfy/model_detection.py b/comfy/model_detection.py index dda9797..c471196 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -131,7 +131,7 @@ def detect_unet_config(state_dict, key_prefix): if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux dit_config = {} dit_config["image_model"] = "flux" - dit_config["in_channels"] = 64 + dit_config["in_channels"] = 16 dit_config["vec_in_dim"] = 768 dit_config["context_in_dim"] = 4096 dit_config["hidden_size"] = 3072