|
|
|
@ -36,8 +36,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|
|
|
|
strength = cond[1]['strength']
|
|
|
|
|
|
|
|
|
|
adm_cond = None
|
|
|
|
|
if 'adm' in cond[1]:
|
|
|
|
|
adm_cond = cond[1]['adm']
|
|
|
|
|
if 'adm_encoded' in cond[1]:
|
|
|
|
|
adm_cond = cond[1]['adm_encoded']
|
|
|
|
|
|
|
|
|
|
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
|
|
|
|
mult = torch.ones_like(input_x) * strength
|
|
|
|
@ -405,7 +405,7 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
|
|
|
|
|
else:
|
|
|
|
|
adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
|
|
|
|
|
x[1] = x[1].copy()
|
|
|
|
|
x[1]["adm"] = torch.cat([adm_out] * batch_size)
|
|
|
|
|
x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size)
|
|
|
|
|
|
|
|
|
|
return conds
|
|
|
|
|
|
|
|
|
|