From 04d9bc13afd684a5bd4cb637e26972bb5aee43d1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 14 Apr 2023 15:33:43 -0400 Subject: [PATCH] Safely load pickled embeds that don't load with weights_only=True. --- comfy/sd1_clip.py | 40 ++++++++++++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 1f057f7..42c9b4c 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -3,6 +3,7 @@ import os from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig import torch import traceback +import zipfile class ClipTokenWeightEncoder: def encode_token_weights(self, token_weight_pairs): @@ -171,6 +172,26 @@ def unescape_important(text): text = text.replace("\0\2", "(") return text +def safe_load_embed_zip(embed_path): + with zipfile.ZipFile(embed_path) as myzip: + names = list(filter(lambda a: "data/" in a, myzip.namelist())) + names.reverse() + for n in names: + with myzip.open(n) as myfile: + data = myfile.read() + number = len(data) // 4 + length_embed = 1024 #sd2.x + if number < 768: + continue + if number % 768 == 0: + length_embed = 768 #sd1.x + num_embeds = number // length_embed + embed = torch.frombuffer(data, dtype=torch.float) + out = embed.reshape((num_embeds, length_embed)).clone() + del embed + return out + + def load_embed(embedding_name, embedding_directory): if isinstance(embedding_directory, str): embedding_directory = [embedding_directory] @@ -195,13 +216,18 @@ def load_embed(embedding_name, embedding_directory): embed_path = valid_file + embed_out = None + try: if embed_path.lower().endswith(".safetensors"): import safetensors.torch embed = safetensors.torch.load_file(embed_path, device="cpu") else: if 'weights_only' in torch.load.__code__.co_varnames: - embed = torch.load(embed_path, weights_only=True, map_location="cpu") + try: + embed = torch.load(embed_path, weights_only=True, map_location="cpu") + except: + embed_out = safe_load_embed_zip(embed_path) else: embed = torch.load(embed_path, map_location="cpu") except Exception as e: @@ -210,11 +236,13 @@ def load_embed(embedding_name, embedding_directory): print("error loading embedding, skipping loading:", embedding_name) return None - if 'string_to_param' in embed: - values = embed['string_to_param'].values() - else: - values = embed.values() - return next(iter(values)) + if embed_out is None: + if 'string_to_param' in embed: + values = embed['string_to_param'].values() + else: + values = embed.values() + embed_out = next(iter(values)) + return embed_out class SD1Tokenizer: def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):