@@ -389,7 +389,10 @@ def should_use_fp16():
 def soft_empty_cache():
     global xpu_available
-    if xpu_available:
+    global vram_state
+    if vram_state == VRAMState.MPS:
+        torch.mps.empty_cache()
+    elif xpu_available:
         torch.xpu.empty_cache()
     elif torch.cuda.is_available():
         if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
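
For readability, below is a sketch of how soft_empty_cache() reads once this hunk is applied. It assumes that xpu_available, vram_state, and VRAMState are module-level globals defined elsewhere in model_management.py, and that the body of the CUDA branch (torch.cuda.empty_cache() / torch.cuda.ipc_collect()) continues as in the surrounding file; those trailing lines are not shown in the hunk above.

import torch

# Sketch only: xpu_available, vram_state, and VRAMState are module globals
# assumed to be defined elsewhere in model_management.py.
def soft_empty_cache():
    global xpu_available
    global vram_state
    if vram_state == VRAMState.MPS:
        # Added by this hunk: release cached allocations held by the Apple MPS backend.
        torch.mps.empty_cache()
    elif xpu_available:
        # Intel XPU path, unchanged.
        torch.xpu.empty_cache()
    elif torch.cuda.is_available():
        if torch.version.cuda:  # This seems to make things worse on ROCm so I only do it for cuda
            # Assumed continuation beyond the lines shown in the hunk.
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

The MPS check is placed first so that Apple-silicon runs never fall through to the XPU or CUDA branches; torch.mps.empty_cache() requires a PyTorch build with MPS support (2.0 or later).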