|
|
|
@ -137,10 +137,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
|
|
|
|
|
|
|
|
|
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
|
|
|
|
|
out_cond = torch.zeros_like(x_in)
|
|
|
|
|
out_count = torch.ones_like(x_in)/100000.0
|
|
|
|
|
out_count = torch.zeros_like(x_in)
|
|
|
|
|
|
|
|
|
|
out_uncond = torch.zeros_like(x_in)
|
|
|
|
|
out_uncond_count = torch.ones_like(x_in)/100000.0
|
|
|
|
|
out_uncond_count = torch.zeros_like(x_in)
|
|
|
|
|
|
|
|
|
|
COND = 0
|
|
|
|
|
UNCOND = 1
|
|
|
|
@ -241,6 +241,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
|
|
|
|
out_uncond /= out_uncond_count
|
|
|
|
|
del out_uncond_count
|
|
|
|
|
|
|
|
|
|
torch.nan_to_num(out_cond, nan=0.0, posinf=0.0, neginf=0.0, out=out_cond) #in case out_count or out_uncond_count had some zeros
|
|
|
|
|
torch.nan_to_num(out_uncond, nan=0.0, posinf=0.0, neginf=0.0, out=out_uncond)
|
|
|
|
|
return out_cond, out_uncond
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|