|
|
|
@ -66,7 +66,8 @@ class BaseModel(torch.nn.Module):
|
|
|
|
|
self.adm_channels = unet_config.get("adm_in_channels", None)
|
|
|
|
|
if self.adm_channels is None:
|
|
|
|
|
self.adm_channels = 0
|
|
|
|
|
self.inpaint_model = False
|
|
|
|
|
|
|
|
|
|
self.concat_keys = ()
|
|
|
|
|
logging.info("model_type {}".format(model_type.name))
|
|
|
|
|
logging.debug("adm {}".format(self.adm_channels))
|
|
|
|
|
|
|
|
|
@ -107,8 +108,7 @@ class BaseModel(torch.nn.Module):
|
|
|
|
|
|
|
|
|
|
def extra_conds(self, **kwargs):
|
|
|
|
|
out = {}
|
|
|
|
|
if self.inpaint_model:
|
|
|
|
|
concat_keys = ("mask", "masked_image")
|
|
|
|
|
if len(self.concat_keys) > 0:
|
|
|
|
|
cond_concat = []
|
|
|
|
|
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
|
|
|
|
concat_latent_image = kwargs.get("concat_latent_image", None)
|
|
|
|
@ -125,24 +125,16 @@ class BaseModel(torch.nn.Module):
|
|
|
|
|
|
|
|
|
|
concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])
|
|
|
|
|
|
|
|
|
|
if len(denoise_mask.shape) == len(noise.shape):
|
|
|
|
|
denoise_mask = denoise_mask[:,:1]
|
|
|
|
|
|
|
|
|
|
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
|
|
|
|
|
if denoise_mask.shape[-2:] != noise.shape[-2:]:
|
|
|
|
|
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
|
|
|
|
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
|
|
|
|
|
if denoise_mask is not None:
|
|
|
|
|
if len(denoise_mask.shape) == len(noise.shape):
|
|
|
|
|
denoise_mask = denoise_mask[:,:1]
|
|
|
|
|
|
|
|
|
|
def blank_inpaint_image_like(latent_image):
|
|
|
|
|
blank_image = torch.ones_like(latent_image)
|
|
|
|
|
# these are the values for "zero" in pixel space translated to latent space
|
|
|
|
|
blank_image[:,0] *= 0.8223
|
|
|
|
|
blank_image[:,1] *= -0.6876
|
|
|
|
|
blank_image[:,2] *= 0.6364
|
|
|
|
|
blank_image[:,3] *= 0.1380
|
|
|
|
|
return blank_image
|
|
|
|
|
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
|
|
|
|
|
if denoise_mask.shape[-2:] != noise.shape[-2:]:
|
|
|
|
|
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
|
|
|
|
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
|
|
|
|
|
|
|
|
|
|
for ck in concat_keys:
|
|
|
|
|
for ck in self.concat_keys:
|
|
|
|
|
if denoise_mask is not None:
|
|
|
|
|
if ck == "mask":
|
|
|
|
|
cond_concat.append(denoise_mask.to(device))
|
|
|
|
@ -152,7 +144,7 @@ class BaseModel(torch.nn.Module):
|
|
|
|
|
if ck == "mask":
|
|
|
|
|
cond_concat.append(torch.ones_like(noise)[:,:1])
|
|
|
|
|
elif ck == "masked_image":
|
|
|
|
|
cond_concat.append(blank_inpaint_image_like(noise))
|
|
|
|
|
cond_concat.append(self.blank_inpaint_image_like(noise))
|
|
|
|
|
data = torch.cat(cond_concat, dim=1)
|
|
|
|
|
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
|
|
|
|
|
|
|
|
|
@ -221,7 +213,16 @@ class BaseModel(torch.nn.Module):
|
|
|
|
|
return unet_state_dict
|
|
|
|
|
|
|
|
|
|
def set_inpaint(self):
|
|
|
|
|
self.inpaint_model = True
|
|
|
|
|
self.concat_keys = ("mask", "masked_image")
|
|
|
|
|
def blank_inpaint_image_like(latent_image):
|
|
|
|
|
blank_image = torch.ones_like(latent_image)
|
|
|
|
|
# these are the values for "zero" in pixel space translated to latent space
|
|
|
|
|
blank_image[:,0] *= 0.8223
|
|
|
|
|
blank_image[:,1] *= -0.6876
|
|
|
|
|
blank_image[:,2] *= 0.6364
|
|
|
|
|
blank_image[:,3] *= 0.1380
|
|
|
|
|
return blank_image
|
|
|
|
|
self.blank_inpaint_image_like = blank_inpaint_image_like
|
|
|
|
|
|
|
|
|
|
def memory_required(self, input_shape):
|
|
|
|
|
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
|
|
|
|