From 9230f658232fd94d0beeddb94aed093a1eca82b5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 25 Aug 2024 05:43:55 -0400 Subject: [PATCH] Fix some controlnets OOMing when loading. --- comfy/controlnet.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index dcfe492..d447958 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -391,7 +391,8 @@ def controlnet_config(sd): else: operations = comfy.ops.disable_weight_init - return model_config, operations, load_device, unet_dtype, manual_cast_dtype + offload_device = comfy.model_management.unet_offload_device() + return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device def controlnet_load_state_dict(control_model, sd): missing, unexpected = control_model.load_state_dict(sd, strict=False) @@ -405,12 +406,12 @@ def controlnet_load_state_dict(control_model, sd): def load_controlnet_mmdit(sd): new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "") - model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd) + model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd) num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.') for k in sd: new_sd[k] = sd[k] - control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config) + control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) control_model = controlnet_load_state_dict(control_model, new_sd) latent_format = comfy.latent_formats.SD3() @@ -420,9 +421,9 @@ def load_controlnet_mmdit(sd): def load_controlnet_hunyuandit(controlnet_data): - model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data) + model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data) - control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype) + control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype) control_model = controlnet_load_state_dict(control_model, controlnet_data) latent_format = comfy.latent_formats.SDXL() @@ -431,8 +432,8 @@ def load_controlnet_hunyuandit(controlnet_data): return control def load_controlnet_flux_xlabs(sd): - model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd) - control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config) + model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd) + control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) control_model = controlnet_load_state_dict(control_model, sd) extra_conds = ['y', 'guidance'] control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) @@ -536,6 +537,7 @@ def load_controlnet(ckpt_path, model=None): if manual_cast_dtype is not None: controlnet_config["operations"] = comfy.ops.manual_cast controlnet_config["dtype"] = unet_dtype + controlnet_config["device"] = comfy.model_management.unet_offload_device() controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)