@@ -20,31 +20,36 @@ import torch
 import comfy.model_management
 from comfy.cli_args import args

-def cast_to(weight, dtype=None, device=None, non_blocking=False):
-    if (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
+def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=True):
+    if not copy and (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
         return weight
     r = torch.empty_like(weight, dtype=dtype, device=device)
     r.copy_(weight, non_blocking=non_blocking)
     return r

-def cast_to_input(weight, input, non_blocking=False):
-    return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)
+def cast_to_input(weight, input, non_blocking=False, copy=True):
+    return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)

-def cast_bias_weight(s, input=None, dtype=None, device=None):
+def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
     if input is not None:
         if dtype is None:
             dtype = input.dtype
+        if bias_dtype is None:
+            bias_dtype = dtype
         if device is None:
             device = input.device

     bias = None
     non_blocking = comfy.model_management.device_supports_non_blocking(device)
     if s.bias is not None:
-        bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
-        if s.bias_function is not None:
+        has_function = s.bias_function is not None
+        bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
+        if has_function:
             bias = s.bias_function(bias)
-    weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
-    if s.weight_function is not None:
+
+    has_function = s.weight_function is not None
+    weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
+    if has_function:
         weight = s.weight_function(weight)
     return weight, bias

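Note on the new copy flag above (a minimal illustration, not part of the diff): when a layer has no weight_function/bias_function attached, cast_to can now hand back the stored tensor untouched instead of allocating a fresh copy, while copy=has_function forces a private copy only when a patch function is about to modify the tensor. Assuming the cast_to defined above is in scope:

    # Hypothetical sketch of the two code paths; `w` stands in for a module weight.
    import torch

    w = torch.randn(4, 4, dtype=torch.float16)

    same = cast_to(w, dtype=torch.float16, copy=False)   # dtype/device already match -> returns w itself, no allocation
    fresh = cast_to(w, dtype=torch.float16, copy=True)   # always materializes a new tensor via empty_like + copy_

    assert same is w        # safe only because nothing mutates it afterwards
    assert fresh is not w   # safe to hand to a weight_function that edits in place
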
@@ -252,7 +257,8 @@ def fp8_linear(self, input):

     if len(input.shape) == 3:
         inn = input.reshape(-1, input.shape[2]).to(dtype)
-        non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
-        w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t()
+        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
+        w = w.t()
+
         scale_weight = self.scale_weight
         scale_input = self.scale_input
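The two removed lines above cast only the weight by hand and left the bias to be converted later; the replacement fetches both through cast_bias_weight, keeping the bias at the activation dtype via the new bias_dtype argument. A hedged sketch of the resulting dtypes (linear and x are illustrative stand-ins for an fp8 Linear layer and a half-precision activation, not names from the diff):

    # Illustrative only, assuming linear.weight is stored as torch.float8_e4m3fn.
    w, b = cast_bias_weight(linear, x, dtype=torch.float8_e4m3fn, bias_dtype=x.dtype)
    # w.dtype -> torch.float8_e4m3fn   (fp8 matmul operand; transposed next via w.t())
    # b.dtype -> x.dtype               (passed straight to torch._scaled_mm as the bias)
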
@@ -263,8 +269,8 @@ def fp8_linear(self, input):
         if scale_input is None:
             scale_input = torch.ones((1), device=input.device, dtype=torch.float32)

-        if self.bias is not None:
-            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking), scale_a=scale_input, scale_b=scale_weight)
+        if bias is not None:
+            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
         else:
             o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)

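For context, a rough standalone sketch of the torch._scaled_mm call shape used here. _scaled_mm is a private PyTorch op: its exact signature differs between releases (some builds return an (out, amax) tuple), it needs an fp8-capable GPU, and dimensions must be multiples of 16, so treat this as an assumption-laden illustration rather than a guaranteed recipe:

    import torch

    if torch.cuda.is_available():
        a = torch.randn(32, 64, device="cuda", dtype=torch.half).to(torch.float8_e4m3fn)      # activations, row-major
        w = torch.randn(16, 64, device="cuda", dtype=torch.half).to(torch.float8_e4m3fn).t()  # weight transposed, as in fp8_linear
        bias = torch.randn(16, device="cuda", dtype=torch.half)                               # bias kept at the activation dtype
        scale = torch.ones((1), device="cuda", dtype=torch.float32)
        o = torch._scaled_mm(a, w, out_dtype=torch.half, bias=bias, scale_a=scale, scale_b=scale)
        if isinstance(o, tuple):  # some PyTorch versions return (out, amax)
            o = o[0]
        print(o.shape)  # torch.Size([32, 16])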