From 7cb924f68469cd2481b2313f8e5fc02587279bf3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=97=8D+85CD?= <50108258+kwaa@users.noreply.github.com>
Date: Thu, 6 Apr 2023 14:24:47 +0800
Subject: [PATCH] Use separate variables instead of `vram_state`

---
 comfy/model_management.py | 70 +++++++++++++++++++++------------------
 1 file changed, 37 insertions(+), 33 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 379cc18..a841677 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -5,9 +5,9 @@ LOW_VRAM = 2
 NORMAL_VRAM = 3
 HIGH_VRAM = 4
 MPS = 5
-XPU = 6
 
 accelerate_enabled = False
+xpu_available = False
 vram_state = NORMAL_VRAM
 
 total_vram = 0
@@ -22,7 +22,12 @@ set_vram_to = NORMAL_VRAM
 
 try:
     import torch
-    total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
+    import intel_extension_for_pytorch as ipex
+    if torch.xpu.is_available():
+        xpu_available = True
+        total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
+    else:
+        total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
     total_ram = psutil.virtual_memory().total / (1024 * 1024)
     forced_normal_vram = "--normalvram" in sys.argv
     if not forced_normal_vram and not forced_cpu:
@@ -86,17 +91,10 @@ try:
 except:
     pass
 
-try:
-    import intel_extension_for_pytorch as ipex
-    if torch.xpu.is_available():
-        vram_state = XPU
-except:
-    pass
-
 if forced_cpu:
     vram_state = CPU
 
-print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS", "XPU"][vram_state])
+print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state])
 
 current_loaded_model = None
 
@@ -133,6 +131,7 @@ def load_model_gpu(model):
     global current_loaded_model
     global vram_state
     global model_accelerated
+    global xpu_available
 
     if model is current_loaded_model:
         return
@@ -149,19 +148,19 @@ def load_model_gpu(model):
         mps_device = torch.device("mps")
         real_model.to(mps_device)
         pass
-    elif vram_state == XPU:
-        real_model.to("xpu")
-        pass
     elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
         model_accelerated = False
-        real_model.cuda()
+        if xpu_available:
+            real_model.to("xpu")
+        else:
+            real_model.cuda()
     else:
         if vram_state == NO_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
         elif vram_state == LOW_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
 
-        accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
+        accelerate.dispatch_model(real_model, device_map=device_map, main_device="xpu" if xpu_available else "cuda")
         model_accelerated = True
     return current_loaded_model
 
@@ -187,8 +186,12 @@ def load_controlnet_gpu(models):
 
 def load_if_low_vram(model):
     global vram_state
+    global xpu_available
     if vram_state == LOW_VRAM or vram_state == NO_VRAM:
-        return model.cuda()
+        if xpu_available:
+            return model.to("xpu")
+        else:
+            return model.cuda()
     return model
 
 def unload_if_low_vram(model):
@@ -198,14 +201,16 @@ def unload_if_low_vram(model):
     return model
 
 def get_torch_device():
+    global xpu_available
     if vram_state == MPS:
         return torch.device("mps")
-    if vram_state == XPU:
-        return torch.device("xpu")
     if vram_state == CPU:
         return torch.device("cpu")
     else:
-        return torch.cuda.current_device()
+        if xpu_available:
+            return torch.device("xpu")
+        else:
+            return torch.cuda.current_device()
 
 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
@@ -235,22 +240,24 @@ def pytorch_attention_enabled():
     return ENABLE_PYTORCH_ATTENTION
 
 def get_free_memory(dev=None, torch_free_too=False):
+    global xpu_available
     if dev is None:
         dev = get_torch_device()
 
     if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
-    elif hasattr(dev, 'type') and (dev.type == 'xpu'):
-        mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
-        mem_free_torch = mem_free_total
     else:
-        stats = torch.cuda.memory_stats(dev)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        if xpu_available:
+            mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
+            mem_free_torch = mem_free_total
+        else:
+            stats = torch.cuda.memory_stats(dev)
+            mem_active = stats['active_bytes.all.current']
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
+            mem_free_torch = mem_reserved - mem_active
+            mem_free_total = mem_free_cuda + mem_free_torch
 
     if torch_free_too:
         return (mem_free_total, mem_free_torch)
@@ -274,12 +281,9 @@ def mps_mode():
     global vram_state
     return vram_state == MPS
 
-def xpu_mode():
-    global vram_state
-    return vram_state == XPU
-
 def should_use_fp16():
-    if cpu_mode() or mps_mode() or xpu_mode():
+    global xpu_available
+    if cpu_mode() or mps_mode() or xpu_available:
        return False #TODO ?
 
     if torch.cuda.is_bf16_supported():
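
Note for reviewers: the core idea of the patch is a single import-time probe that sets a module-level `xpu_available` flag, with every call site then choosing "xpu" or CUDA off that flag instead of a dedicated `vram_state` value. Below is a minimal standalone sketch of that pattern, not ComfyUI code; the helper name `pick_torch_device` is illustrative, and it assumes `intel_extension_for_pytorch` is installed whenever an Intel GPU is present.

import torch

xpu_available = False
try:
    # Importing IPEX registers the torch.xpu namespace; skip quietly if absent.
    import intel_extension_for_pytorch as ipex  # noqa: F401
    xpu_available = torch.xpu.is_available()
except ImportError:
    pass

def pick_torch_device() -> torch.device:
    # Mirror the patch's priority: XPU if the probe found one, else CUDA, else CPU.
    if xpu_available:
        return torch.device("xpu")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

if __name__ == "__main__":
    print("Selected device:", pick_torch_device())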