@@ -7,9 +7,10 @@ import torch
 import torch.nn as nn
 
 import comfy.utils
+import comfy.ops
 
 def conv(n_in, n_out, **kwargs):
-    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
+    return comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
 
 class Clamp(nn.Module):
     def forward(self, x):
@@ -19,7 +20,7 @@ class Block(nn.Module):
     def __init__(self, n_in, n_out):
         super().__init__()
         self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
-        self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
+        self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
         self.fuse = nn.ReLU()
     def forward(self, x):
         return self.fuse(self.conv(x) + self.skip(x))
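For context on the change: `comfy.ops.disable_weight_init` exposes drop-in replacements for the torch layers whose `reset_parameters` is a no-op, so constructing the model skips the random weight initialization that the checkpoint load would overwrite anyway. A minimal self-contained sketch of that pattern, not ComfyUI's exact `comfy.ops` code:

```python
import torch.nn as nn

# Sketch of the disable_weight_init pattern assumed by this diff: a drop-in
# Conv2d subclass whose reset_parameters() is a no-op. nn.Conv2d.__init__
# allocates weight/bias with torch.empty and then calls reset_parameters(),
# so the override leaves the tensors unfilled until load_state_dict() runs.
class disable_weight_init:
    class Conv2d(nn.Conv2d):
        def reset_parameters(self):
            return None  # skip Kaiming init; weights come from the checkpoint

# Behaves exactly like nn.Conv2d at call time, minus the init cost:
layer = disable_weight_init.Conv2d(4, 64, 3, padding=1)
```

Because the subclass changes nothing but initialization, the swap in `conv()` and in `Block.skip` is behavior-preserving at inference time; it only avoids wasted work when the TAESD weights are loaded immediately afterwards.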