|
|
|
@ -30,7 +30,7 @@ lowvram_available = True
|
|
|
|
|
xpu_available = False
|
|
|
|
|
|
|
|
|
|
if args.deterministic:
|
|
|
|
|
logging.warning("Using deterministic algorithms for pytorch")
|
|
|
|
|
logging.info("Using deterministic algorithms for pytorch")
|
|
|
|
|
torch.use_deterministic_algorithms(True, warn_only=True)
|
|
|
|
|
|
|
|
|
|
directml_enabled = False
|
|
|
|
@ -42,7 +42,7 @@ if args.directml is not None:
|
|
|
|
|
directml_device = torch_directml.device()
|
|
|
|
|
else:
|
|
|
|
|
directml_device = torch_directml.device(device_index)
|
|
|
|
|
logging.warning("Using directml with device: {}".format(torch_directml.device_name(device_index)))
|
|
|
|
|
logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index)))
|
|
|
|
|
# torch_directml.disable_tiled_resources(True)
|
|
|
|
|
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
|
|
|
|
|
|
|
|
|
@ -118,7 +118,7 @@ def get_total_memory(dev=None, torch_total_too=False):
|
|
|
|
|
|
|
|
|
|
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
|
|
|
|
|
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
|
|
|
|
logging.warning("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
|
|
|
|
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
|
|
|
|
if not args.normalvram and not args.cpu:
|
|
|
|
|
if lowvram_available and total_vram <= 4096:
|
|
|
|
|
logging.warning("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
|
|
|
|
@ -144,7 +144,7 @@ else:
|
|
|
|
|
pass
|
|
|
|
|
try:
|
|
|
|
|
XFORMERS_VERSION = xformers.version.__version__
|
|
|
|
|
logging.warning("xformers version: {}".format(XFORMERS_VERSION))
|
|
|
|
|
logging.info("xformers version: {}".format(XFORMERS_VERSION))
|
|
|
|
|
if XFORMERS_VERSION.startswith("0.0.18"):
|
|
|
|
|
logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
|
|
|
|
|
logging.warning("Please downgrade or upgrade xformers to a different version.\n")
|
|
|
|
@ -212,11 +212,11 @@ elif args.highvram or args.gpu_only:
|
|
|
|
|
FORCE_FP32 = False
|
|
|
|
|
FORCE_FP16 = False
|
|
|
|
|
if args.force_fp32:
|
|
|
|
|
logging.warning("Forcing FP32, if this improves things please report it.")
|
|
|
|
|
logging.info("Forcing FP32, if this improves things please report it.")
|
|
|
|
|
FORCE_FP32 = True
|
|
|
|
|
|
|
|
|
|
if args.force_fp16:
|
|
|
|
|
logging.warning("Forcing FP16.")
|
|
|
|
|
logging.info("Forcing FP16.")
|
|
|
|
|
FORCE_FP16 = True
|
|
|
|
|
|
|
|
|
|
if lowvram_available:
|
|
|
|
@ -230,12 +230,12 @@ if cpu_state != CPUState.GPU:
|
|
|
|
|
if cpu_state == CPUState.MPS:
|
|
|
|
|
vram_state = VRAMState.SHARED
|
|
|
|
|
|
|
|
|
|
logging.warning(f"Set vram state to: {vram_state.name}")
|
|
|
|
|
logging.info(f"Set vram state to: {vram_state.name}")
|
|
|
|
|
|
|
|
|
|
DISABLE_SMART_MEMORY = args.disable_smart_memory
|
|
|
|
|
|
|
|
|
|
if DISABLE_SMART_MEMORY:
|
|
|
|
|
logging.warning("Disabling smart memory management")
|
|
|
|
|
logging.info("Disabling smart memory management")
|
|
|
|
|
|
|
|
|
|
def get_torch_device_name(device):
|
|
|
|
|
if hasattr(device, 'type'):
|
|
|
|
@ -253,11 +253,11 @@ def get_torch_device_name(device):
|
|
|
|
|
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
logging.warning("Device: {}".format(get_torch_device_name(get_torch_device())))
|
|
|
|
|
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
|
|
|
|
|
except:
|
|
|
|
|
logging.warning("Could not pick default device.")
|
|
|
|
|
|
|
|
|
|
logging.warning("VAE dtype: {}".format(VAE_DTYPE))
|
|
|
|
|
logging.info("VAE dtype: {}".format(VAE_DTYPE))
|
|
|
|
|
|
|
|
|
|
current_loaded_models = []
|
|
|
|
|
|
|
|
|
@ -300,7 +300,7 @@ class LoadedModel:
|
|
|
|
|
raise e
|
|
|
|
|
|
|
|
|
|
if lowvram_model_memory > 0:
|
|
|
|
|
logging.warning("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
|
|
|
|
|
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
|
|
|
|
|
mem_counter = 0
|
|
|
|
|
for m in self.real_model.modules():
|
|
|
|
|
if hasattr(m, "comfy_cast_weights"):
|
|
|
|
@ -347,7 +347,7 @@ def unload_model_clones(model):
|
|
|
|
|
to_unload = [i] + to_unload
|
|
|
|
|
|
|
|
|
|
for i in to_unload:
|
|
|
|
|
logging.warning("unload clone {}".format(i))
|
|
|
|
|
logging.debug("unload clone {}".format(i))
|
|
|
|
|
current_loaded_models.pop(i).model_unload()
|
|
|
|
|
|
|
|
|
|
def free_memory(memory_required, device, keep_loaded=[]):
|
|
|
|
@ -389,7 +389,7 @@ def load_models_gpu(models, memory_required=0):
|
|
|
|
|
models_already_loaded.append(loaded_model)
|
|
|
|
|
else:
|
|
|
|
|
if hasattr(x, "model"):
|
|
|
|
|
logging.warning(f"Requested to load {x.model.__class__.__name__}")
|
|
|
|
|
logging.info(f"Requested to load {x.model.__class__.__name__}")
|
|
|
|
|
models_to_load.append(loaded_model)
|
|
|
|
|
|
|
|
|
|
if len(models_to_load) == 0:
|
|
|
|
@ -399,7 +399,7 @@ def load_models_gpu(models, memory_required=0):
|
|
|
|
|
free_memory(extra_mem, d, models_already_loaded)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
logging.warning(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
|
|
|
|
|
logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
|
|
|
|
|
|
|
|
|
|
total_memory_required = {}
|
|
|
|
|
for loaded_model in models_to_load:
|
|
|
|
|