Use separate variables instead of `vram_state`

main
藍+85CD 2 years ago
parent 84b9c0ac2f
commit 7cb924f684

@ -5,9 +5,9 @@ LOW_VRAM = 2
NORMAL_VRAM = 3
HIGH_VRAM = 4
MPS = 5
XPU = 6
accelerate_enabled = False
xpu_available = False
vram_state = NORMAL_VRAM
total_vram = 0
@ -22,7 +22,12 @@ set_vram_to = NORMAL_VRAM
try:
import torch
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
xpu_available = True
total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
else:
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
forced_normal_vram = "--normalvram" in sys.argv
if not forced_normal_vram and not forced_cpu:
@ -86,17 +91,10 @@ try:
except:
pass
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
vram_state = XPU
except:
pass
if forced_cpu:
vram_state = CPU
print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS", "XPU"][vram_state])
print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state])
current_loaded_model = None
@ -133,6 +131,7 @@ def load_model_gpu(model):
global current_loaded_model
global vram_state
global model_accelerated
global xpu_available
if model is current_loaded_model:
return
@ -149,19 +148,19 @@ def load_model_gpu(model):
mps_device = torch.device("mps")
real_model.to(mps_device)
pass
elif vram_state == XPU:
real_model.to("xpu")
pass
elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
model_accelerated = False
real_model.cuda()
if xpu_available:
real_model.to("xpu")
else:
real_model.cuda()
else:
if vram_state == NO_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
elif vram_state == LOW_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
accelerate.dispatch_model(real_model, device_map=device_map, main_device="xpu" if xpu_available else "cuda")
model_accelerated = True
return current_loaded_model
@ -187,8 +186,12 @@ def load_controlnet_gpu(models):
def load_if_low_vram(model):
global vram_state
global xpu_available
if vram_state == LOW_VRAM or vram_state == NO_VRAM:
return model.cuda()
if xpu_available:
return model.to("xpu")
else:
return model.cuda()
return model
def unload_if_low_vram(model):
@ -198,14 +201,16 @@ def unload_if_low_vram(model):
return model
def get_torch_device():
global xpu_available
if vram_state == MPS:
return torch.device("mps")
if vram_state == XPU:
return torch.device("xpu")
if vram_state == CPU:
return torch.device("cpu")
else:
return torch.cuda.current_device()
if xpu_available:
return torch.device("xpu")
else:
return torch.cuda.current_device()
def get_autocast_device(dev):
if hasattr(dev, 'type'):
@ -235,22 +240,24 @@ def pytorch_attention_enabled():
return ENABLE_PYTORCH_ATTENTION
def get_free_memory(dev=None, torch_free_too=False):
global xpu_available
if dev is None:
dev = get_torch_device()
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
mem_free_total = psutil.virtual_memory().available
mem_free_torch = mem_free_total
elif hasattr(dev, 'type') and (dev.type == 'xpu'):
mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
mem_free_torch = mem_free_total
else:
stats = torch.cuda.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
if xpu_available:
mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
mem_free_torch = mem_free_total
else:
stats = torch.cuda.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
if torch_free_too:
return (mem_free_total, mem_free_torch)
@ -274,12 +281,9 @@ def mps_mode():
global vram_state
return vram_state == MPS
def xpu_mode():
global vram_state
return vram_state == XPU
def should_use_fp16():
if cpu_mode() or mps_mode() or xpu_mode():
global xpu_available
if cpu_mode() or mps_mode() or xpu_available:
return False #TODO ?
if torch.cuda.is_bf16_supported():

Loading…
Cancel
Save