|
|
|
@ -20,9 +20,13 @@ import torch
|
|
|
|
|
import comfy.model_management
|
|
|
|
|
from comfy.cli_args import args
|
|
|
|
|
|
|
|
|
|
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=True):
|
|
|
|
|
if not copy and (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
|
|
|
|
|
return weight
|
|
|
|
|
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
|
|
|
|
|
if device is None or weight.device == device:
|
|
|
|
|
if not copy:
|
|
|
|
|
if dtype is None or weight.dtype == dtype:
|
|
|
|
|
return weight
|
|
|
|
|
return weight.to(dtype=dtype, copy=copy)
|
|
|
|
|
|
|
|
|
|
r = torch.empty_like(weight, dtype=dtype, device=device)
|
|
|
|
|
r.copy_(weight, non_blocking=non_blocking)
|
|
|
|
|
return r
|
|
|
|
|