|
|
|
@ -127,30 +127,23 @@ def cond_cat(c_list):
|
|
|
|
|
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
|
|
|
|
out_cond = torch.zeros_like(x_in)
|
|
|
|
|
out_count = torch.ones_like(x_in) * 1e-37
|
|
|
|
|
|
|
|
|
|
out_uncond = torch.zeros_like(x_in)
|
|
|
|
|
out_uncond_count = torch.ones_like(x_in) * 1e-37
|
|
|
|
|
def calc_cond_batch(model, conds, x_in, timestep, model_options):
|
|
|
|
|
out_conds = []
|
|
|
|
|
out_counts = []
|
|
|
|
|
to_run = []
|
|
|
|
|
|
|
|
|
|
COND = 0
|
|
|
|
|
UNCOND = 1
|
|
|
|
|
for i in range(len(conds)):
|
|
|
|
|
out_conds.append(torch.zeros_like(x_in))
|
|
|
|
|
out_counts.append(torch.ones_like(x_in) * 1e-37)
|
|
|
|
|
|
|
|
|
|
to_run = []
|
|
|
|
|
cond = conds[i]
|
|
|
|
|
if cond is not None:
|
|
|
|
|
for x in cond:
|
|
|
|
|
p = get_area_and_mult(x, x_in, timestep)
|
|
|
|
|
if p is None:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
to_run += [(p, COND)]
|
|
|
|
|
if uncond is not None:
|
|
|
|
|
for x in uncond:
|
|
|
|
|
p = get_area_and_mult(x, x_in, timestep)
|
|
|
|
|
if p is None:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
to_run += [(p, UNCOND)]
|
|
|
|
|
to_run += [(p, i)]
|
|
|
|
|
|
|
|
|
|
while len(to_run) > 0:
|
|
|
|
|
first = to_run[0]
|
|
|
|
@ -222,22 +215,20 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
|
|
|
|
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
|
|
|
|
else:
|
|
|
|
|
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
|
|
|
|
del input_x
|
|
|
|
|
|
|
|
|
|
for o in range(batch_chunks):
|
|
|
|
|
if cond_or_uncond[o] == COND:
|
|
|
|
|
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
|
|
|
|
out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
|
|
|
|
else:
|
|
|
|
|
out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
|
|
|
|
out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
|
|
|
|
del mult
|
|
|
|
|
cond_index = cond_or_uncond[o]
|
|
|
|
|
out_conds[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
|
|
|
|
out_counts[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
|
|
|
|
|
|
|
|
|
for i in range(len(out_conds)):
|
|
|
|
|
out_conds[i] /= out_counts[i]
|
|
|
|
|
|
|
|
|
|
return out_conds
|
|
|
|
|
|
|
|
|
|
out_cond /= out_count
|
|
|
|
|
del out_count
|
|
|
|
|
out_uncond /= out_uncond_count
|
|
|
|
|
del out_uncond_count
|
|
|
|
|
return out_cond, out_uncond
|
|
|
|
|
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
|
|
|
|
|
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
|
|
|
|
|
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
|
|
|
|
|
|
|
|
|
|
#The main sampling function shared by all the samplers
|
|
|
|
|
#Returns denoised
|
|
|
|
@ -247,7 +238,13 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
|
|
|
|
else:
|
|
|
|
|
uncond_ = uncond
|
|
|
|
|
|
|
|
|
|
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
|
|
|
|
|
|
|
|
|
|
conds = [cond, uncond_]
|
|
|
|
|
|
|
|
|
|
out = calc_cond_batch(model, conds, x, timestep, model_options)
|
|
|
|
|
cond_pred = out[0]
|
|
|
|
|
uncond_pred = out[1]
|
|
|
|
|
|
|
|
|
|
if "sampler_cfg_function" in model_options:
|
|
|
|
|
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
|
|
|
|
|
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
|
|
|
|
|