From e6482fbbfc83cd25add0532b2e4c51d305e8a232 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 1 Apr 2024 17:23:07 -0400 Subject: [PATCH] Refactor calc_cond_uncond_batch into calc_cond_batch. calc_cond_batch can take an arbitrary amount of cond inputs. Added a calc_cond_uncond_batch wrapper with a warning so custom nodes won't break. --- comfy/samplers.py | 67 +++++++++++++++++------------------ comfy_extras/nodes_perpneg.py | 2 +- comfy_extras/nodes_sag.py | 2 +- 3 files changed, 34 insertions(+), 37 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 3678dc8..204a98f 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -127,30 +127,23 @@ def cond_cat(c_list): return out -def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): - out_cond = torch.zeros_like(x_in) - out_count = torch.ones_like(x_in) * 1e-37 - - out_uncond = torch.zeros_like(x_in) - out_uncond_count = torch.ones_like(x_in) * 1e-37 - - COND = 0 - UNCOND = 1 - +def calc_cond_batch(model, conds, x_in, timestep, model_options): + out_conds = [] + out_counts = [] to_run = [] - for x in cond: - p = get_area_and_mult(x, x_in, timestep) - if p is None: - continue - to_run += [(p, COND)] - if uncond is not None: - for x in uncond: - p = get_area_and_mult(x, x_in, timestep) - if p is None: - continue + for i in range(len(conds)): + out_conds.append(torch.zeros_like(x_in)) + out_counts.append(torch.ones_like(x_in) * 1e-37) + + cond = conds[i] + if cond is not None: + for x in cond: + p = get_area_and_mult(x, x_in, timestep) + if p is None: + continue - to_run += [(p, UNCOND)] + to_run += [(p, i)] while len(to_run) > 0: first = to_run[0] @@ -222,22 +215,20 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) else: output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) - del input_x for o in range(batch_chunks): - if cond_or_uncond[o] == COND: - out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] - out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - else: - out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] - out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - del mult + cond_index = cond_or_uncond[o] + out_conds[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] + out_counts[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - out_cond /= out_count - del out_count - out_uncond /= out_uncond_count - del out_uncond_count - return out_cond, out_uncond + for i in range(len(out_conds)): + out_conds[i] /= out_counts[i] + + return out_conds + +def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove + logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.") + return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options)) #The main sampling function shared by all the samplers #Returns denoised @@ -247,7 +238,13 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option else: uncond_ = uncond - cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options) + + conds = [cond, uncond_] + + out = calc_cond_batch(model, conds, x, timestep, model_options) + cond_pred = out[0] + uncond_pred = out[1] + if "sampler_cfg_function" in model_options: args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} diff --git a/comfy_extras/nodes_perpneg.py b/comfy_extras/nodes_perpneg.py index dc73c55..9e8a218 100644 --- a/comfy_extras/nodes_perpneg.py +++ b/comfy_extras/nodes_perpneg.py @@ -31,7 +31,7 @@ class PerpNeg: model_options = args["model_options"] nocond_processed = comfy.samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative") - (noise_pred_nocond, _) = comfy.samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options) + (noise_pred_nocond,) = comfy.samplers.calc_cond_batch(model, [nocond_processed], x, sigma, model_options) pos = noise_pred_pos - noise_pred_nocond neg = noise_pred_neg - noise_pred_nocond diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index bbd3808..69084e9 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -150,7 +150,7 @@ class SelfAttentionGuidance: degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold) degraded_noised = degraded + x - uncond_pred # call into the UNet - (sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options) + (sag,) = comfy.samplers.calc_cond_batch(model, [uncond], degraded_noised, sigma, model_options) return cfg_result + (degraded - sag) * sag_scale m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)