|
|
|
@ -61,12 +61,38 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider):
|
|
|
|
|
self.neg_scale = neg_scale
|
|
|
|
|
|
|
|
|
|
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
|
|
|
|
# in CFGGuider.predict_noise, we call sampling_function(), which uses cfg_function() to compute pos & neg
|
|
|
|
|
# but we'd rather do a single batch of sampling pos, neg, and empty, so we call calc_cond_batch([pos,neg,empty]) directly
|
|
|
|
|
|
|
|
|
|
positive_cond = self.conds.get("positive", None)
|
|
|
|
|
negative_cond = self.conds.get("negative", None)
|
|
|
|
|
empty_cond = self.conds.get("empty_negative_prompt", None)
|
|
|
|
|
|
|
|
|
|
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, positive_cond, empty_cond], x, timestep, model_options)
|
|
|
|
|
return perp_neg(x, out[1], out[0], out[2], self.neg_scale, self.cfg)
|
|
|
|
|
(noise_pred_pos, noise_pred_neg, noise_pred_empty) = \
|
|
|
|
|
comfy.samplers.calc_cond_batch(self.inner_model, [positive_cond, negative_cond, empty_cond], x, timestep, model_options)
|
|
|
|
|
cfg_result = perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_empty, self.neg_scale, self.cfg)
|
|
|
|
|
|
|
|
|
|
# normally this would be done in cfg_function, but we skipped
|
|
|
|
|
# that for efficiency: we can compute the noise predictions in
|
|
|
|
|
# a single call to calc_cond_batch() (rather than two)
|
|
|
|
|
# so we replicate the hook here
|
|
|
|
|
for fn in model_options.get("sampler_post_cfg_function", []):
|
|
|
|
|
args = {
|
|
|
|
|
"denoised": cfg_result,
|
|
|
|
|
"cond": positive_cond,
|
|
|
|
|
"uncond": negative_cond,
|
|
|
|
|
"model": self.inner_model,
|
|
|
|
|
"uncond_denoised": noise_pred_neg,
|
|
|
|
|
"cond_denoised": noise_pred_pos,
|
|
|
|
|
"sigma": timestep,
|
|
|
|
|
"model_options": model_options,
|
|
|
|
|
"input": x,
|
|
|
|
|
# not in the original call in samplers.py:cfg_function, but made available for future hooks
|
|
|
|
|
"empty_cond": empty_cond,
|
|
|
|
|
"empty_cond_denoised": noise_pred_empty,}
|
|
|
|
|
cfg_result = fn(args)
|
|
|
|
|
|
|
|
|
|
return cfg_result
|
|
|
|
|
|
|
|
|
|
class PerpNegGuider:
|
|
|
|
|
@classmethod
|
|
|
|
|