diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 00a2078..59683f6 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -9,6 +9,8 @@ from typing import Optional, Any
 from ldm.modules.diffusionmodules.util import checkpoint
 from .sub_quadratic_attention import efficient_dot_product_attention
 
+import model_management
+
 try:
     import xformers
     import xformers.ops
@@ -189,12 +191,8 @@ class CrossAttentionBirchSan(nn.Module):
         _, _, k_tokens = key_t.shape
         qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
 
-        stats = torch.cuda.memory_stats(query.device)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
+
         chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
 
         kv_chunk_size_min = None
@@ -276,12 +274,7 @@ class CrossAttentionDoggettx(nn.Module):
 
         r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
 
-        stats = torch.cuda.memory_stats(q.device)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        mem_free_total = model_management.get_free_memory(q.device)
 
         gb = 1024 ** 3
         tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 36f925c..14acb0a 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -145,14 +145,25 @@ def unload_if_low_vram(model):
     return model
 
 
-def get_free_memory():
-    dev = torch.cuda.current_device()
-    stats = torch.cuda.memory_stats(dev)
-    mem_active = stats['active_bytes.all.current']
-    mem_reserved = stats['reserved_bytes.all.current']
-    mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
-    mem_free_torch = mem_reserved - mem_active
-    return mem_free_cuda + mem_free_torch
+def get_free_memory(dev=None, torch_free_too=False):
+    if dev is None:
+        dev = torch.cuda.current_device()
+
+    if hasattr(dev, 'type') and dev.type == 'cpu':
+        mem_free_total = psutil.virtual_memory().available
+        mem_free_torch = mem_free_total
+    else:
+        stats = torch.cuda.memory_stats(dev)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
+
+    if torch_free_too:
+        return (mem_free_total, mem_free_torch)
+    else:
+        return mem_free_total
 
 def maximum_batch_area():
     global vram_state
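
Outside the patch itself, a minimal usage sketch of the reworked `get_free_memory` helper, mirroring the two call patterns in attention.py. The `sys.path` tweak and the printed values are illustrative assumptions, not part of the change; the `comfy` directory is importable like this only when run from a ComfyUI checkout.

```python
import sys
import torch

# Assumption: running from a ComfyUI checkout so the comfy package dir is importable.
sys.path.append("comfy")
import model_management

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Total free memory for the device: CUDA free bytes plus memory PyTorch has
# reserved but is not actively using, or psutil's available RAM on CPU.
mem_free_total = model_management.get_free_memory(device)

# torch_free_too=True also returns the torch-managed slack, which the
# sub-quadratic (BirchSan) attention path uses to pick its chunk threshold.
mem_free_total, mem_free_torch = model_management.get_free_memory(device, True)
chunk_threshold_bytes = mem_free_torch * 0.5

print(mem_free_total, mem_free_torch, chunk_threshold_bytes)
```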