@ -1,8 +1,27 @@
"""
This file is part of ComfyUI .
Copyright ( C ) 2024 Comfy
This program is free software : you can redistribute it and / or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation , either version 3 of the License , or
( at your option ) any later version .
This program is distributed in the hope that it will be useful ,
but WITHOUT ANY WARRANTY ; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE . See the
GNU General Public License for more details .
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
import copy
import inspect
import logging
import uuid
import collections
import comfy . utils
import comfy . model_management
@ -63,6 +82,21 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
model_options [ " disable_cfg1_optimization " ] = True
return model_options
def wipe_lowvram_weight(m):
    """Strip low-vram lazy-patching state from module *m*.

    Restores the saved ``comfy_cast_weights`` flag (if one was stashed in
    ``prev_comfy_cast_weights``) and clears any lazy weight/bias patch
    callables installed for low-vram operation.
    """
    try:
        # Restore the pre-lowvram cast flag and drop the stash.
        m.comfy_cast_weights = m.prev_comfy_cast_weights
        del m.prev_comfy_cast_weights
    except AttributeError:
        pass  # module was never switched into lowvram casting
    m.weight_function = None
    m.bias_function = None
class LowVramPatch:
    """Lazily applied weight patch for low-vram mode.

    Instead of materializing patched weights up front, an instance of this
    class is installed as a module's ``weight_function``/``bias_function``
    and invoked at cast time: it asks the owning patcher to compute the
    patched weight from the stored patch list for ``key``.
    """

    def __init__(self, key, model_patcher):
        self.key = key
        self.model_patcher = model_patcher

    def __call__(self, weight):
        patcher = self.model_patcher
        return patcher.calculate_weight(patcher.patches[self.key], weight, self.key)
class ModelPatcher :
def __init__ ( self , model , load_device , offload_device , size = 0 , weight_inplace_update = False ) :
self . size = size
@ -82,16 +116,29 @@ class ModelPatcher:
self . load_device = load_device
self . offload_device = offload_device
self . weight_inplace_update = weight_inplace_update
self . model_lowvram = False
self . lowvram_patch_counter = 0
self . patches_uuid = uuid . uuid4 ( )
if not hasattr ( self . model , ' model_loaded_weight_memory ' ) :
self . model . model_loaded_weight_memory = 0
if not hasattr ( self . model , ' lowvram_patch_counter ' ) :
self . model . lowvram_patch_counter = 0
if not hasattr ( self . model , ' model_lowvram ' ) :
self . model . model_lowvram = False
def model_size(self):
    """Return the model's weight size, memoized in ``self.size``.

    A positive ``size`` supplied at construction (or computed earlier)
    short-circuits the calculation; otherwise the value comes from
    ``comfy.model_management.module_size`` (presumed bytes -- confirm
    against that helper).
    """
    if self.size <= 0:
        self.size = comfy.model_management.module_size(self.model)
    return self.size
def loaded_size(self):
    """Return the bookkeeping count of weight memory currently loaded.

    Reads ``model.model_loaded_weight_memory``, which is maintained by the
    lowvram load/unload paths (presumed bytes, matching module_size usage).
    """
    return self.model.model_loaded_weight_memory
def lowvram_patch_counter(self):
    """Return how many lazy LowVramPatch functions are installed on the model.

    The counter lives on the model itself (``model.lowvram_patch_counter``)
    so it survives across ModelPatcher clones sharing the same model.
    """
    return self.model.lowvram_patch_counter
def clone ( self ) :
n = ModelPatcher ( self . model , self . load_device , self . offload_device , self . size , weight_inplace_update = self . weight_inplace_update )
n . patches = { }
@ -265,16 +312,16 @@ class ModelPatcher:
sd . pop ( k )
return sd
def patch_weight_to_device ( self , key , device_to = None ):
def patch_weight_to_device ( self , key , device_to = None , inplace_update = False ):
if key not in self . patches :
return
weight = comfy . utils . get_attr ( self . model , key )
inplace_update = self . weight_inplace_update
inplace_update = self . weight_inplace_update or inplace_update
if key not in self . backup :
self . backup [ key ] = weight. to ( device = self . offload_device , copy = inplace_update )
self . backup [ key ] = collections. namedtuple ( ' Dimension ' , [ ' weight ' , ' inplace_update ' ] ) ( weight. to ( device = self . offload_device , copy = inplace_update ) , inplace_update )
if device_to is not None :
temp_weight = comfy . model_management . cast_to_device ( weight , device_to , torch . float32 , copy = True )
@ -304,28 +351,24 @@ class ModelPatcher:
if device_to is not None :
self . model . to ( device_to )
self . model . device = device_to
self . model . model_loaded_weight_memory = self . model_size ( )
return self . model
def patch_model_lowvram ( self , device_to = None , lowvram_model_memory = 0 , force_patch_weights = False ) :
self . patch_model ( device_to , patch_weights = False )
def lowvram_load ( self , device_to = None , lowvram_model_memory = 0 , force_patch_weights = False ) :
logging . info ( " loading in lowvram mode {} " . format ( lowvram_model_memory / ( 1024 * 1024 ) ) )
class LowVramPatch :
def __init__ ( self , key , model_patcher ) :
self . key = key
self . model_patcher = model_patcher
def __call__ ( self , weight ) :
return self . model_patcher . calculate_weight ( self . model_patcher . patches [ self . key ] , weight , self . key )
mem_counter = 0
patch_counter = 0
lowvram_counter = 0
for n , m in self . model . named_modules ( ) :
lowvram_weight = False
if hasattr ( m , " comfy_cast_weights " ) :
module_mem = comfy . model_management . module_size ( m )
if mem_counter + module_mem > = lowvram_model_memory :
lowvram_weight = True
lowvram_counter + = 1
if m . comfy_cast_weights :
continue
weight_key = " {} .weight " . format ( n )
bias_key = " {} .bias " . format ( n )
@ -347,16 +390,31 @@ class ModelPatcher:
m . prev_comfy_cast_weights = m . comfy_cast_weights
m . comfy_cast_weights = True
else :
if hasattr ( m , " comfy_cast_weights " ) :
if m . comfy_cast_weights :
wipe_lowvram_weight ( m )
if hasattr ( m , " weight " ) :
self . patch_weight_to_device ( weight_key ) #TODO: speed this up without causing OOM
mem_counter + = comfy . model_management . module_size ( m )
if m . weight is not None and m . weight . device == device_to :
continue
self . patch_weight_to_device ( weight_key ) #TODO: speed this up without OOM
self . patch_weight_to_device ( bias_key )
m . to ( device_to )
mem_counter + = comfy . model_management . module_size ( m )
logging . debug ( " lowvram: loaded module regularly {} {} " . format ( n , m ) )
self . model_lowvram = True
self . lowvram_patch_counter = patch_counter
if lowvram_counter > 0 :
self . model . model_lowvram = True
else :
self . model . model_lowvram = False
self . model . lowvram_patch_counter + = patch_counter
self . model . device = device_to
self . model . model_loaded_weight_memory = mem_counter
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
    """Apply object patches, then load the model in low-vram mode.

    First runs ``patch_model`` without touching weights
    (``patch_weights=False``), then hands off to ``lowvram_load`` with the
    given memory budget and patch-forcing flag. Returns the underlying
    model object.
    """
    self.patch_model(device_to, patch_weights=False)
    self.lowvram_load(
        device_to,
        lowvram_model_memory=lowvram_model_memory,
        force_patch_weights=force_patch_weights,
    )
    return self.model
def calculate_weight ( self , patches , weight , key ) :
@ -529,31 +587,28 @@ class ModelPatcher:
def unpatch_model ( self , device_to = None , unpatch_weights = True ) :
if unpatch_weights :
if self . model _lowvram:
if self . model . model _lowvram:
for m in self . model . modules ( ) :
if hasattr ( m , " prev_comfy_cast_weights " ) :
m . comfy_cast_weights = m . prev_comfy_cast_weights
del m . prev_comfy_cast_weights
m . weight_function = None
m . bias_function = None
wipe_lowvram_weight ( m )
self . model _lowvram = False
self . lowvram_patch_counter = 0
self . model . model_lowvram = False
self . model . lowvram_patch_counter = 0
keys = list ( self . backup . keys ( ) )
if self . weight_inplace_update :
for k in keys :
comfy . utils . copy_to_param ( self . model , k , self . backup [ k ] )
else :
for k in keys :
comfy . utils . set_attr_param ( self . model , k , self . backup [ k ] )
for k in keys :
bk = self . backup [ k ]
if bk . inplace_update :
comfy . utils . copy_to_param ( self . model , k , bk . weight )
else :
comfy . utils . set_attr_param ( self . model , k , bk . weight )
self . backup . clear ( )
if device_to is not None :
self . model . to ( device_to )
self . model . device = device_to
self . model . model_loaded_weight_memory = 0
keys = list ( self . object_patches_backup . keys ( ) )
for k in keys :
@ -561,5 +616,57 @@ class ModelPatcher:
self . object_patches_backup . clear ( )
def partially_unload(self, device_to, memory_to_free=0):
    """Offload trailing modules to *device_to* until about
    ``memory_to_free`` has been freed.

    Walks the model's modules in reverse registration order. For each
    castable module whose weight is not already on *device_to*:
    restores any backed-up original weights (so lazy patches start from
    clean tensors), moves the module, installs LowVramPatch weight/bias
    functions for any pending patches, and marks the module for lowvram
    casting. Updates the model-level lowvram bookkeeping.

    Returns the amount freed, in the units of
    ``comfy.model_management.module_size`` (presumed bytes -- confirm).
    """
    memory_freed = 0
    patch_counter = 0
    # Fix: dropped the unused local `shift_lowvram` that was assigned but
    # never read in the original loop body.
    for n, m in list(self.model.named_modules())[::-1]:
        if memory_to_free < memory_freed:
            break

        if hasattr(m, "comfy_cast_weights"):
            module_mem = comfy.model_management.module_size(m)
            weight_key = "{}.weight".format(n)
            bias_key = "{}.bias".format(n)

            if m.weight is not None and m.weight.device != device_to:
                # Put the original (unpatched) tensors back before moving
                # the module off-device; patches will be re-applied lazily.
                for key in [weight_key, bias_key]:
                    bk = self.backup.get(key, None)
                    if bk is not None:
                        if bk.inplace_update:
                            comfy.utils.copy_to_param(self.model, key, bk.weight)
                        else:
                            comfy.utils.set_attr_param(self.model, key, bk.weight)
                        self.backup.pop(key)

                m.to(device_to)
                if weight_key in self.patches:
                    m.weight_function = LowVramPatch(weight_key, self)
                    patch_counter += 1
                if bias_key in self.patches:
                    m.bias_function = LowVramPatch(bias_key, self)
                    patch_counter += 1

                m.prev_comfy_cast_weights = m.comfy_cast_weights
                m.comfy_cast_weights = True
                memory_freed += module_mem
                logging.debug("freed {}".format(n))

    self.model.model_lowvram = True
    self.model.lowvram_patch_counter += patch_counter
    self.model.model_loaded_weight_memory -= memory_freed
    return memory_freed
def partially_load(self, device_to, extra_memory=0):
    """Grow a partially (lowvram) loaded model by ``extra_memory``.

    Returns 0 immediately when the model is not in lowvram mode.
    Otherwise re-runs ``lowvram_load`` with a budget of the currently
    resident weight memory plus ``extra_memory`` and returns how much
    additional memory became resident.
    """
    # Fix: idiomatic truthiness test instead of `== False`; model_lowvram
    # is only ever assigned True/False in this file, so behavior is the same.
    if not self.model.model_lowvram:
        return 0
    current_used = self.model.model_loaded_weight_memory
    if current_used + extra_memory > self.model_size():
        pass  # TODO: Full load
    self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory)
    return self.model.model_loaded_weight_memory - current_used
def current_loaded_device(self):
    """Return the device the model currently resides on.

    Reads the ``device`` attribute that the load/unload paths keep in sync
    whenever the model is moved.
    """
    return self.model.device