From 64ccb3c7e3e1f3e4075669b71b870aae66cec077 Mon Sep 17 00:00:00 2001 From: Simon Lui <502929+simonlui@users.noreply.github.com> Date: Fri, 23 Aug 2024 00:59:57 -0700 Subject: [PATCH] Rework IPEX check for future inclusion of XPU into Pytorch upstream and do a bit more optimization of ipex.optimize(). (#4562) --- comfy/model_management.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 44e3541..c86b67e 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -44,8 +44,10 @@ cpu_state = CPUState.GPU total_vram = 0 +torch_version = torch.version.__version__ + lowvram_available = True -xpu_available = False +xpu_available = int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4) if args.deterministic: logging.info("Using deterministic algorithms for pytorch") @@ -66,10 +68,10 @@ if args.directml is not None: try: import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - xpu_available = True + _ = torch.xpu.device_count() + xpu_available = torch.xpu.is_available() except: - pass + xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available()) try: if torch.backends.mps.is_available(): @@ -189,7 +191,6 @@ VAE_DTYPES = [torch.float32] try: if is_nvidia(): - torch_version = torch.version.__version__ if int(torch_version[0]) >= 2: if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True @@ -321,8 +322,9 @@ class LoadedModel: self.model_unload() raise e - if is_intel_xpu() and not args.disable_ipex_optimize: - self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True) + if is_intel_xpu() and not args.disable_ipex_optimize and self.real_model is not None: + with torch.no_grad(): + self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True) self.weights_loaded = True return self.real_model