diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 21f1a87..e7931c1 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -124,10 +124,16 @@ class Flux(nn.Module): def forward(self, x, timestep, context, y, guidance, **kwargs): bs, c, h, w = x.shape - img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + patch_size = 2 + pad_h = (patch_size - h % 2) % patch_size + pad_w = (patch_size - w % 2) % patch_size - h_len = (h // 2) - w_len = (w // 2) + x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular') + + img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + + h_len = ((h + (patch_size // 2)) // patch_size) + w_len = ((w + (patch_size // 2)) // patch_size) img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None] img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :] @@ -135,4 +141,4 @@ class Flux(nn.Module): txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance) - return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2) + return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]