@@ -120,15 +120,15 @@ class SD21UNCLIP(BaseModel):
         weights = []
         noise_aug = []
         for unclip_cond in unclip_conditioning:
-            adm_cond = unclip_cond["clip_vision_output"].image_embeds
-            weight = unclip_cond["strength"]
-            noise_augment = unclip_cond["noise_augmentation"]
-            noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment)
-            c_adm, noise_level_emb = self.noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
-            adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
-            weights.append(weight)
-            noise_aug.append(noise_augment)
-            adm_inputs.append(adm_out)
+            for adm_cond in unclip_cond["clip_vision_output"].image_embeds:
+                weight = unclip_cond["strength"]
+                noise_augment = unclip_cond["noise_augmentation"]
+                noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment)
+                c_adm, noise_level_emb = self.noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
+                adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
+                weights.append(weight)
+                noise_aug.append(noise_augment)
+                adm_inputs.append(adm_out)
 
         if len(noise_aug) > 1:
             adm_out = torch.stack(adm_inputs).sum(0)
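
For context: the hunk reworks SD21UNCLIP's unCLIP conditioning so that, instead of passing the whole `image_embeds` tensor to `self.noise_augmentor` once, each embedding in the batch is noise-augmented, weighted by `strength`, and appended to `adm_inputs` individually. The sketch below only illustrates that per-embedding loop in isolation; `DummyNoiseAugmentor`, the 1024-dim embeddings, and the `unsqueeze(0)` are stand-in assumptions for this sketch, not the real `noise_augmentor` behavior.

```python
import torch

# Stand-in for the model's noise_augmentor (an assumption for this sketch only):
# it returns the embedding unchanged plus a toy noise-level embedding.
class DummyNoiseAugmentor:
    max_noise_level = 1000

    def __call__(self, embed, noise_level):
        return embed, noise_level.float().repeat(embed.shape[0], 1)

noise_augmentor = DummyNoiseAugmentor()
device = "cpu"

# A batched CLIP vision output: two image embeddings in a single tensor.
image_embeds = torch.randn(2, 1024)
unclip_cond = {"strength": 1.0, "noise_augmentation": 0.05}

adm_inputs = []
# As on the hunk's "+" side: iterate per embedding instead of using the whole batch at once.
for adm_cond in image_embeds:
    noise_level = round((noise_augmentor.max_noise_level - 1) * unclip_cond["noise_augmentation"])
    c_adm, noise_level_emb = noise_augmentor(
        adm_cond.to(device).unsqueeze(0),  # unsqueeze is only needed for this toy augmentor
        noise_level=torch.tensor([noise_level], device=device),
    )
    adm_inputs.append(torch.cat((c_adm, noise_level_emb), 1) * unclip_cond["strength"])

# With more than one embedding, the per-embedding outputs are stacked and summed,
# mirroring the "if len(noise_aug) > 1:" branch kept as context in the hunk.
print(torch.stack(adm_inputs).sum(0).shape)
```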