|
|
@ -826,9 +826,14 @@ class UNETLoader:
|
|
|
|
CATEGORY = "advanced/loaders"
|
|
|
|
CATEGORY = "advanced/loaders"
|
|
|
|
|
|
|
|
|
|
|
|
def load_unet(self, unet_name, weight_dtype):
|
|
|
|
def load_unet(self, unet_name, weight_dtype):
|
|
|
|
weight_dtype = {"default":None, "fp8_e4m3fn":torch.float8_e4m3fn, "fp8_e5m2":torch.float8_e4m3fn}[weight_dtype]
|
|
|
|
dtype = None
|
|
|
|
|
|
|
|
if weight_dtype == "fp8_e4m3fn":
|
|
|
|
|
|
|
|
dtype = torch.float8_e4m3fn
|
|
|
|
|
|
|
|
elif weight_dtype == "fp8_e5m2":
|
|
|
|
|
|
|
|
dtype = torch.float8_e5m2
|
|
|
|
|
|
|
|
|
|
|
|
unet_path = folder_paths.get_full_path("unet", unet_name)
|
|
|
|
unet_path = folder_paths.get_full_path("unet", unet_name)
|
|
|
|
model = comfy.sd.load_unet(unet_path, dtype=weight_dtype)
|
|
|
|
model = comfy.sd.load_unet(unet_path, dtype=dtype)
|
|
|
|
return (model,)
|
|
|
|
return (model,)
|
|
|
|
|
|
|
|
|
|
|
|
class CLIPLoader:
|
|
|
|
class CLIPLoader:
|
|
|
|