ONNX tracing fixes.

main
comfyanonymous 7 months ago
parent 0a6b008117
commit 3b71f84b50

@ -9,6 +9,7 @@ import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.modules.attention import optimized_attention
import comfy.ops import comfy.ops
import comfy.ldm.common_dit
def modulate(x, shift, scale): def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@ -407,10 +408,7 @@ class MMDiT(nn.Module):
def patchify(self, x): def patchify(self, x):
B, C, H, W = x.size() B, C, H, W = x.size()
pad_h = (self.patch_size - H % self.patch_size) % self.patch_size x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular')
x = x.view( x = x.view(
B, B,
C, C,

@ -0,0 +1,8 @@
import torch
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
padding_mode = "reflect"
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)

@ -15,6 +15,7 @@ from .layers import (
) )
from einops import rearrange, repeat from einops import rearrange, repeat
import comfy.ldm.common_dit
@dataclass @dataclass
class FluxParams: class FluxParams:
@ -42,7 +43,7 @@ class Flux(nn.Module):
self.dtype = dtype self.dtype = dtype
params = FluxParams(**kwargs) params = FluxParams(**kwargs)
self.params = params self.params = params
self.in_channels = params.in_channels self.in_channels = params.in_channels * 2 * 2
self.out_channels = self.in_channels self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0: if params.hidden_size % params.num_heads != 0:
raise ValueError( raise ValueError(
@ -125,10 +126,7 @@ class Flux(nn.Module):
def forward(self, x, timestep, context, y, guidance, **kwargs): def forward(self, x, timestep, context, y, guidance, **kwargs):
bs, c, h, w = x.shape bs, c, h, w = x.shape
patch_size = 2 patch_size = 2
pad_h = (patch_size - h % 2) % patch_size x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
pad_w = (patch_size - w % 2) % patch_size
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular')
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

@ -9,6 +9,7 @@ from .. import attention
from einops import rearrange, repeat from einops import rearrange, repeat
from .util import timestep_embedding from .util import timestep_embedding
import comfy.ops import comfy.ops
import comfy.ldm.common_dit
def default(x, y): def default(x, y):
if x is not None: if x is not None:
@ -111,9 +112,7 @@ class PatchEmbed(nn.Module):
# f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})." # f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
# ) # )
if self.dynamic_img_pad: if self.dynamic_img_pad:
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode)
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode=self.padding_mode)
x = self.proj(x) x = self.proj(x)
if self.flatten: if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC x = x.flatten(2).transpose(1, 2) # NCHW -> NLC

@ -131,7 +131,7 @@ def detect_unet_config(state_dict, key_prefix):
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux
dit_config = {} dit_config = {}
dit_config["image_model"] = "flux" dit_config["image_model"] = "flux"
dit_config["in_channels"] = 64 dit_config["in_channels"] = 16
dit_config["vec_in_dim"] = 768 dit_config["vec_in_dim"] = 768
dit_config["context_in_dim"] = 4096 dit_config["context_in_dim"] = 4096
dit_config["hidden_size"] = 3072 dit_config["hidden_size"] = 3072

Loading…
Cancel
Save