diff --git a/comfy/caching.py b/comfy/caching.py new file mode 100644 index 0000000..abcf68a --- /dev/null +++ b/comfy/caching.py @@ -0,0 +1,299 @@ +import itertools +from typing import Sequence, Mapping +from comfy.graph import DynamicPrompt + +import nodes + +from comfy.graph_utils import is_link + +class CacheKeySet: + def __init__(self, dynprompt, node_ids, is_changed_cache): + self.keys = {} + self.subcache_keys = {} + + def add_keys(self, node_ids): + raise NotImplementedError() + + def all_node_ids(self): + return set(self.keys.keys()) + + def get_used_keys(self): + return self.keys.values() + + def get_used_subcache_keys(self): + return self.subcache_keys.values() + + def get_data_key(self, node_id): + return self.keys.get(node_id, None) + + def get_subcache_key(self, node_id): + return self.subcache_keys.get(node_id, None) + +class Unhashable: + def __init__(self): + self.value = float("NaN") + +def to_hashable(obj): + # So that we don't infinitely recurse since frozenset and tuples + # are Sequences. + if isinstance(obj, (int, float, str, bool, type(None))): + return obj + elif isinstance(obj, Mapping): + return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())]) + elif isinstance(obj, Sequence): + return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj])) + else: + # TODO - Support other objects like tensors? + return Unhashable() + +class CacheKeySetID(CacheKeySet): + def __init__(self, dynprompt, node_ids, is_changed_cache): + super().__init__(dynprompt, node_ids, is_changed_cache) + self.dynprompt = dynprompt + self.add_keys(node_ids) + + def add_keys(self, node_ids): + for node_id in node_ids: + if node_id in self.keys: + continue + node = self.dynprompt.get_node(node_id) + self.keys[node_id] = (node_id, node["class_type"]) + self.subcache_keys[node_id] = (node_id, node["class_type"]) + +class CacheKeySetInputSignature(CacheKeySet): + def __init__(self, dynprompt, node_ids, is_changed_cache): + super().__init__(dynprompt, node_ids, is_changed_cache) + self.dynprompt = dynprompt + self.is_changed_cache = is_changed_cache + self.add_keys(node_ids) + + def include_node_id_in_input(self) -> bool: + return False + + def add_keys(self, node_ids): + for node_id in node_ids: + if node_id in self.keys: + continue + node = self.dynprompt.get_node(node_id) + self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id) + self.subcache_keys[node_id] = (node_id, node["class_type"]) + + def get_node_signature(self, dynprompt, node_id): + signature = [] + ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id) + signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping)) + for ancestor_id in ancestors: + signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)) + return to_hashable(signature) + + def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): + node = dynprompt.get_node(node_id) + class_type = node["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + signature = [class_type, self.is_changed_cache.get(node_id)] + if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT): + signature.append(node_id) + inputs = node["inputs"] + for key in sorted(inputs.keys()): + if is_link(inputs[key]): + (ancestor_id, ancestor_socket) = inputs[key] + ancestor_index = ancestor_order_mapping[ancestor_id] + signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket))) + else: + signature.append((key, inputs[key])) + return signature + + # This function returns a list of all ancestors of the given node. The order of the list is + # deterministic based on which specific inputs the ancestor is connected by. + def get_ordered_ancestry(self, dynprompt, node_id): + ancestors = [] + order_mapping = {} + self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping) + return ancestors, order_mapping + + def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping): + inputs = dynprompt.get_node(node_id)["inputs"] + input_keys = sorted(inputs.keys()) + for key in input_keys: + if is_link(inputs[key]): + ancestor_id = inputs[key][0] + if ancestor_id not in order_mapping: + ancestors.append(ancestor_id) + order_mapping[ancestor_id] = len(ancestors) - 1 + self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping) + +class BasicCache: + def __init__(self, key_class): + self.key_class = key_class + self.initialized = False + self.dynprompt: DynamicPrompt + self.cache_key_set: CacheKeySet + self.cache = {} + self.subcaches = {} + + def set_prompt(self, dynprompt, node_ids, is_changed_cache): + self.dynprompt = dynprompt + self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache) + self.is_changed_cache = is_changed_cache + self.initialized = True + + def all_node_ids(self): + assert self.initialized + node_ids = self.cache_key_set.all_node_ids() + for subcache in self.subcaches.values(): + node_ids = node_ids.union(subcache.all_node_ids()) + return node_ids + + def _clean_cache(self): + preserve_keys = set(self.cache_key_set.get_used_keys()) + to_remove = [] + for key in self.cache: + if key not in preserve_keys: + to_remove.append(key) + for key in to_remove: + del self.cache[key] + + def _clean_subcaches(self): + preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys()) + + to_remove = [] + for key in self.subcaches: + if key not in preserve_subcaches: + to_remove.append(key) + for key in to_remove: + del self.subcaches[key] + + def clean_unused(self): + assert self.initialized + self._clean_cache() + self._clean_subcaches() + + def _set_immediate(self, node_id, value): + assert self.initialized + cache_key = self.cache_key_set.get_data_key(node_id) + self.cache[cache_key] = value + + def _get_immediate(self, node_id): + if not self.initialized: + return None + cache_key = self.cache_key_set.get_data_key(node_id) + if cache_key in self.cache: + return self.cache[cache_key] + else: + return None + + def _ensure_subcache(self, node_id, children_ids): + subcache_key = self.cache_key_set.get_subcache_key(node_id) + subcache = self.subcaches.get(subcache_key, None) + if subcache is None: + subcache = BasicCache(self.key_class) + self.subcaches[subcache_key] = subcache + subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache) + return subcache + + def _get_subcache(self, node_id): + assert self.initialized + subcache_key = self.cache_key_set.get_subcache_key(node_id) + if subcache_key in self.subcaches: + return self.subcaches[subcache_key] + else: + return None + + def recursive_debug_dump(self): + result = [] + for key in self.cache: + result.append({"key": key, "value": self.cache[key]}) + for key in self.subcaches: + result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()}) + return result + +class HierarchicalCache(BasicCache): + def __init__(self, key_class): + super().__init__(key_class) + + def _get_cache_for(self, node_id): + assert self.dynprompt is not None + parent_id = self.dynprompt.get_parent_node_id(node_id) + if parent_id is None: + return self + + hierarchy = [] + while parent_id is not None: + hierarchy.append(parent_id) + parent_id = self.dynprompt.get_parent_node_id(parent_id) + + cache = self + for parent_id in reversed(hierarchy): + cache = cache._get_subcache(parent_id) + if cache is None: + return None + return cache + + def get(self, node_id): + cache = self._get_cache_for(node_id) + if cache is None: + return None + return cache._get_immediate(node_id) + + def set(self, node_id, value): + cache = self._get_cache_for(node_id) + assert cache is not None + cache._set_immediate(node_id, value) + + def ensure_subcache_for(self, node_id, children_ids): + cache = self._get_cache_for(node_id) + assert cache is not None + return cache._ensure_subcache(node_id, children_ids) + +class LRUCache(BasicCache): + def __init__(self, key_class, max_size=100): + super().__init__(key_class) + self.max_size = max_size + self.min_generation = 0 + self.generation = 0 + self.used_generation = {} + self.children = {} + + def set_prompt(self, dynprompt, node_ids, is_changed_cache): + super().set_prompt(dynprompt, node_ids, is_changed_cache) + self.generation += 1 + for node_id in node_ids: + self._mark_used(node_id) + + def clean_unused(self): + while len(self.cache) > self.max_size and self.min_generation < self.generation: + self.min_generation += 1 + to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation] + for key in to_remove: + del self.cache[key] + del self.used_generation[key] + if key in self.children: + del self.children[key] + self._clean_subcaches() + + def get(self, node_id): + self._mark_used(node_id) + return self._get_immediate(node_id) + + def _mark_used(self, node_id): + cache_key = self.cache_key_set.get_data_key(node_id) + if cache_key is not None: + self.used_generation[cache_key] = self.generation + + def set(self, node_id, value): + self._mark_used(node_id) + return self._set_immediate(node_id, value) + + def ensure_subcache_for(self, node_id, children_ids): + # Just uses subcaches for tracking 'live' nodes + super()._ensure_subcache(node_id, children_ids) + + self.cache_key_set.add_keys(children_ids) + self._mark_used(node_id) + cache_key = self.cache_key_set.get_data_key(node_id) + self.children[cache_key] = [] + for child_id in children_ids: + self._mark_used(child_id) + self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) + return self + diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 2397de3..a895c7e 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -92,6 +92,10 @@ class LatentPreviewMethod(enum.Enum): parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction) +cache_group = parser.add_mutually_exclusive_group() +cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") +cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") + attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.") diff --git a/comfy/graph.py b/comfy/graph.py new file mode 100644 index 0000000..8980c69 --- /dev/null +++ b/comfy/graph.py @@ -0,0 +1,237 @@ +import nodes + +from comfy.graph_utils import is_link + +class DependencyCycleError(Exception): + pass + +class NodeInputError(Exception): + pass + +class NodeNotFoundError(Exception): + pass + +class DynamicPrompt: + def __init__(self, original_prompt): + # The original prompt provided by the user + self.original_prompt = original_prompt + # Any extra pieces of the graph created during execution + self.ephemeral_prompt = {} + self.ephemeral_parents = {} + self.ephemeral_display = {} + + def get_node(self, node_id): + if node_id in self.ephemeral_prompt: + return self.ephemeral_prompt[node_id] + if node_id in self.original_prompt: + return self.original_prompt[node_id] + raise NodeNotFoundError(f"Node {node_id} not found") + + def has_node(self, node_id): + return node_id in self.original_prompt or node_id in self.ephemeral_prompt + + def add_ephemeral_node(self, node_id, node_info, parent_id, display_id): + self.ephemeral_prompt[node_id] = node_info + self.ephemeral_parents[node_id] = parent_id + self.ephemeral_display[node_id] = display_id + + def get_real_node_id(self, node_id): + while node_id in self.ephemeral_parents: + node_id = self.ephemeral_parents[node_id] + return node_id + + def get_parent_node_id(self, node_id): + return self.ephemeral_parents.get(node_id, None) + + def get_display_node_id(self, node_id): + while node_id in self.ephemeral_display: + node_id = self.ephemeral_display[node_id] + return node_id + + def all_node_ids(self): + return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys())) + + def get_original_prompt(self): + return self.original_prompt + +def get_input_info(class_def, input_name): + valid_inputs = class_def.INPUT_TYPES() + input_info = None + input_category = None + if "required" in valid_inputs and input_name in valid_inputs["required"]: + input_category = "required" + input_info = valid_inputs["required"][input_name] + elif "optional" in valid_inputs and input_name in valid_inputs["optional"]: + input_category = "optional" + input_info = valid_inputs["optional"][input_name] + elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]: + input_category = "hidden" + input_info = valid_inputs["hidden"][input_name] + if input_info is None: + return None, None, None + input_type = input_info[0] + if len(input_info) > 1: + extra_info = input_info[1] + else: + extra_info = {} + return input_type, input_category, extra_info + +class TopologicalSort: + def __init__(self, dynprompt): + self.dynprompt = dynprompt + self.pendingNodes = {} + self.blockCount = {} # Number of nodes this node is directly blocked by + self.blocking = {} # Which nodes are blocked by this node + + def get_input_info(self, unique_id, input_name): + class_type = self.dynprompt.get_node(unique_id)["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + return get_input_info(class_def, input_name) + + def make_input_strong_link(self, to_node_id, to_input): + inputs = self.dynprompt.get_node(to_node_id)["inputs"] + if to_input not in inputs: + raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but there is no input to that node at all") + value = inputs[to_input] + if not is_link(value): + raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but that value is a constant") + from_node_id, from_socket = value + self.add_strong_link(from_node_id, from_socket, to_node_id) + + def add_strong_link(self, from_node_id, from_socket, to_node_id): + self.add_node(from_node_id) + if to_node_id not in self.blocking[from_node_id]: + self.blocking[from_node_id][to_node_id] = {} + self.blockCount[to_node_id] += 1 + self.blocking[from_node_id][to_node_id][from_socket] = True + + def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None): + if unique_id in self.pendingNodes: + return + self.pendingNodes[unique_id] = True + self.blockCount[unique_id] = 0 + self.blocking[unique_id] = {} + + inputs = self.dynprompt.get_node(unique_id)["inputs"] + for input_name in inputs: + value = inputs[input_name] + if is_link(value): + from_node_id, from_socket = value + if subgraph_nodes is not None and from_node_id not in subgraph_nodes: + continue + input_type, input_category, input_info = self.get_input_info(unique_id, input_name) + is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"] + if include_lazy or not is_lazy: + self.add_strong_link(from_node_id, from_socket, unique_id) + + def get_ready_nodes(self): + return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0] + + def pop_node(self, unique_id): + del self.pendingNodes[unique_id] + for blocked_node_id in self.blocking[unique_id]: + self.blockCount[blocked_node_id] -= 1 + del self.blocking[unique_id] + + def is_empty(self): + return len(self.pendingNodes) == 0 + +class ExecutionList(TopologicalSort): + """ + ExecutionList implements a topological dissolve of the graph. After a node is staged for execution, + it can still be returned to the graph after having further dependencies added. + """ + def __init__(self, dynprompt, output_cache): + super().__init__(dynprompt) + self.output_cache = output_cache + self.staged_node_id = None + + def add_strong_link(self, from_node_id, from_socket, to_node_id): + if self.output_cache.get(from_node_id) is not None: + # Nothing to do + return + super().add_strong_link(from_node_id, from_socket, to_node_id) + + def stage_node_execution(self): + assert self.staged_node_id is None + if self.is_empty(): + return None, None, None + available = self.get_ready_nodes() + if len(available) == 0: + cycled_nodes = self.get_nodes_in_cycle() + # Because cycles composed entirely of static nodes are caught during initial validation, + # we will 'blame' the first node in the cycle that is not a static node. + blamed_node = cycled_nodes[0] + for node_id in cycled_nodes: + display_node_id = self.dynprompt.get_display_node_id(node_id) + if display_node_id != node_id: + blamed_node = display_node_id + break + ex = DependencyCycleError("Dependency cycle detected") + error_details = { + "node_id": blamed_node, + "exception_message": str(ex), + "exception_type": "graph.DependencyCycleError", + "traceback": [], + "current_inputs": [] + } + return None, error_details, ex + next_node = available[0] + # If an output node is available, do that first. + # Technically this has no effect on the overall length of execution, but it feels better as a user + # for a PreviewImage to display a result as soon as it can + # Some other heuristics could probably be used here to improve the UX further. + for node_id in available: + class_type = self.dynprompt.get_node(node_id)["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True: + next_node = node_id + break + self.staged_node_id = next_node + return self.staged_node_id, None, None + + def unstage_node_execution(self): + assert self.staged_node_id is not None + self.staged_node_id = None + + def complete_node_execution(self): + node_id = self.staged_node_id + self.pop_node(node_id) + self.staged_node_id = None + + def get_nodes_in_cycle(self): + # We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle. + # We're skipping some of the performance optimizations from the original TopologicalSort to keep + # the code simple (and because having a cycle in the first place is a catastrophic error) + blocked_by = { node_id: {} for node_id in self.pendingNodes } + for from_node_id in self.blocking: + for to_node_id in self.blocking[from_node_id]: + if True in self.blocking[from_node_id][to_node_id].values(): + blocked_by[to_node_id][from_node_id] = True + to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0] + while len(to_remove) > 0: + for node_id in to_remove: + for to_node_id in blocked_by: + if node_id in blocked_by[to_node_id]: + del blocked_by[to_node_id][node_id] + del blocked_by[node_id] + to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0] + return list(blocked_by.keys()) + +class ExecutionBlocker: + """ + Return this from a node and any users will be blocked with the given error message. + If the message is None, execution will be blocked silently instead. + Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's + possible, a lazy input will be more efficient and have a better user experience. + This functionality is useful in two cases: + 1. You want to conditionally prevent an output node from executing. (Particularly a built-in node + like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using + lazy evaluation to let it conditionally disable itself.) + 2. You have a node with multiple possible outputs, some of which are invalid and should not be used. + (I would recommend not making nodes like this in the future -- instead, make multiple nodes with + different outputs. Unfortunately, there are several popular existing nodes using this pattern.) + """ + def __init__(self, message): + self.message = message + diff --git a/comfy/graph_utils.py b/comfy/graph_utils.py new file mode 100644 index 0000000..8595e94 --- /dev/null +++ b/comfy/graph_utils.py @@ -0,0 +1,139 @@ +def is_link(obj): + if not isinstance(obj, list): + return False + if len(obj) != 2: + return False + if not isinstance(obj[0], str): + return False + if not isinstance(obj[1], int) and not isinstance(obj[1], float): + return False + return True + +# The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end +class GraphBuilder: + _default_prefix_root = "" + _default_prefix_call_index = 0 + _default_prefix_graph_index = 0 + + def __init__(self, prefix = None): + if prefix is None: + self.prefix = GraphBuilder.alloc_prefix() + else: + self.prefix = prefix + self.nodes = {} + self.id_gen = 1 + + @classmethod + def set_default_prefix(cls, prefix_root, call_index, graph_index = 0): + cls._default_prefix_root = prefix_root + cls._default_prefix_call_index = call_index + cls._default_prefix_graph_index = graph_index + + @classmethod + def alloc_prefix(cls, root=None, call_index=None, graph_index=None): + if root is None: + root = GraphBuilder._default_prefix_root + if call_index is None: + call_index = GraphBuilder._default_prefix_call_index + if graph_index is None: + graph_index = GraphBuilder._default_prefix_graph_index + result = f"{root}.{call_index}.{graph_index}." + GraphBuilder._default_prefix_graph_index += 1 + return result + + def node(self, class_type, id=None, **kwargs): + if id is None: + id = str(self.id_gen) + self.id_gen += 1 + id = self.prefix + id + if id in self.nodes: + return self.nodes[id] + + node = Node(id, class_type, kwargs) + self.nodes[id] = node + return node + + def lookup_node(self, id): + id = self.prefix + id + return self.nodes.get(id) + + def finalize(self): + output = {} + for node_id, node in self.nodes.items(): + output[node_id] = node.serialize() + return output + + def replace_node_output(self, node_id, index, new_value): + node_id = self.prefix + node_id + to_remove = [] + for node in self.nodes.values(): + for key, value in node.inputs.items(): + if is_link(value) and value[0] == node_id and value[1] == index: + if new_value is None: + to_remove.append((node, key)) + else: + node.inputs[key] = new_value + for node, key in to_remove: + del node.inputs[key] + + def remove_node(self, id): + id = self.prefix + id + del self.nodes[id] + +class Node: + def __init__(self, id, class_type, inputs): + self.id = id + self.class_type = class_type + self.inputs = inputs + self.override_display_id = None + + def out(self, index): + return [self.id, index] + + def set_input(self, key, value): + if value is None: + if key in self.inputs: + del self.inputs[key] + else: + self.inputs[key] = value + + def get_input(self, key): + return self.inputs.get(key) + + def set_override_display_id(self, override_display_id): + self.override_display_id = override_display_id + + def serialize(self): + serialized = { + "class_type": self.class_type, + "inputs": self.inputs + } + if self.override_display_id is not None: + serialized["override_display_id"] = self.override_display_id + return serialized + +def add_graph_prefix(graph, outputs, prefix): + # Change the node IDs and any internal links + new_graph = {} + for node_id, node_info in graph.items(): + # Make sure the added nodes have unique IDs + new_node_id = prefix + node_id + new_node = { "class_type": node_info["class_type"], "inputs": {} } + for input_name, input_value in node_info.get("inputs", {}).items(): + if is_link(input_value): + new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]] + else: + new_node["inputs"][input_name] = input_value + new_graph[new_node_id] = new_node + + # Change the node IDs in the outputs + new_outputs = [] + for n in range(len(outputs)): + output = outputs[n] + if is_link(output): + new_outputs.append([prefix + output[0], output[1]]) + else: + new_outputs.append(output) + + return new_graph, tuple(new_outputs) + diff --git a/custom_nodes/example_node.py.example b/custom_nodes/example_node.py.example index 72ca368..9c68ab7 100644 --- a/custom_nodes/example_node.py.example +++ b/custom_nodes/example_node.py.example @@ -54,7 +54,8 @@ class Example: "min": 0, #Minimum value "max": 4096, #Maximum value "step": 64, #Slider's step - "display": "number" # Cosmetic only: display as "number" or "slider" + "display": "number", # Cosmetic only: display as "number" or "slider" + "lazy": True # Will only be evaluated if check_lazy_status requires it }), "float_field": ("FLOAT", { "default": 1.0, @@ -62,11 +63,14 @@ class Example: "max": 10.0, "step": 0.01, "round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding. - "display": "number"}), + "display": "number", + "lazy": True + }), "print_to_screen": (["enable", "disable"],), "string_field": ("STRING", { "multiline": False, #True if you want the field to look like the one on the ClipTextEncode node - "default": "Hello World!" + "default": "Hello World!", + "lazy": True }), }, } @@ -80,6 +84,23 @@ class Example: CATEGORY = "Example" + def check_lazy_status(self, image, string_field, int_field, float_field, print_to_screen): + """ + Return a list of input names that need to be evaluated. + + This function will be called if there are any lazy inputs which have not yet been + evaluated. As long as you return at least one field which has not yet been evaluated + (and more exist), this function will be called again once the value of the requested + field is available. + + Any evaluated inputs will be passed as arguments to this function. Any unevaluated + inputs will have the value None. + """ + if print_to_screen == "enable": + return ["int_field", "float_field", "string_field"] + else: + return [] + def test(self, image, string_field, int_field, float_field, print_to_screen): if print_to_screen == "enable": print(f"""Your input contains: diff --git a/execution.py b/execution.py index d207e1b..ee67589 100644 --- a/execution.py +++ b/execution.py @@ -5,6 +5,7 @@ import threading import heapq import time import traceback +from enum import Enum import inspect from typing import List, Literal, NamedTuple, Optional @@ -12,102 +13,216 @@ import torch import nodes import comfy.model_management +import comfy.graph_utils +from comfy.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker +from comfy.graph_utils import is_link, GraphBuilder +from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID +from comfy.cli_args import args + +class ExecutionResult(Enum): + SUCCESS = 0 + FAILURE = 1 + PENDING = 2 + +class DuplicateNodeError(Exception): + pass + +class IsChangedCache: + def __init__(self, dynprompt, outputs_cache): + self.dynprompt = dynprompt + self.outputs_cache = outputs_cache + self.is_changed = {} + + def get(self, node_id): + if node_id in self.is_changed: + return self.is_changed[node_id] + + node = self.dynprompt.get_node(node_id) + class_type = node["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + if not hasattr(class_def, "IS_CHANGED"): + self.is_changed[node_id] = False + return self.is_changed[node_id] + + if "is_changed" in node: + self.is_changed[node_id] = node["is_changed"] + return self.is_changed[node_id] + + input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache) + try: + is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") + node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] + except: + node["is_changed"] = float("NaN") + finally: + self.is_changed[node_id] = node["is_changed"] + return self.is_changed[node_id] + +class CacheSet: + def __init__(self, lru_size=None): + if lru_size is None or lru_size == 0: + self.init_classic_cache() + else: + self.init_lru_cache(lru_size) + self.all = [self.outputs, self.ui, self.objects] + + # Useful for those with ample RAM/VRAM -- allows experimenting without + # blowing away the cache every time + def init_lru_cache(self, cache_size): + self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) + self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) + self.objects = HierarchicalCache(CacheKeySetID) + + # Performs like the old cache -- dump data ASAP + def init_classic_cache(self): + self.outputs = HierarchicalCache(CacheKeySetInputSignature) + self.ui = HierarchicalCache(CacheKeySetInputSignature) + self.objects = HierarchicalCache(CacheKeySetID) + + def recursive_debug_dump(self): + result = { + "outputs": self.outputs.recursive_debug_dump(), + "ui": self.ui.recursive_debug_dump(), + } + return result -def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): +def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}): valid_inputs = class_def.INPUT_TYPES() input_data_all = {} + missing_keys = {} for x in inputs: input_data = inputs[x] - if isinstance(input_data, list): + input_type, input_category, input_info = get_input_info(class_def, x) + def mark_missing(): + missing_keys[x] = True + input_data_all[x] = (None,) + if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)): input_unique_id = input_data[0] output_index = input_data[1] - if input_unique_id not in outputs: - input_data_all[x] = (None,) + if outputs is None: + mark_missing() + continue # This might be a lazily-evaluated input + cached_output = outputs.get(input_unique_id) + if cached_output is None: + mark_missing() continue - obj = outputs[input_unique_id][output_index] + if output_index >= len(cached_output): + mark_missing() + continue + obj = cached_output[output_index] input_data_all[x] = obj - else: - if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]): - input_data_all[x] = [input_data] + elif input_category is not None: + input_data_all[x] = [input_data] if "hidden" in valid_inputs: h = valid_inputs["hidden"] for x in h: if h[x] == "PROMPT": - input_data_all[x] = [prompt] + input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}] + if h[x] == "DYNPROMPT": + input_data_all[x] = [dynprompt] if h[x] == "EXTRA_PNGINFO": input_data_all[x] = [extra_data.get('extra_pnginfo', None)] if h[x] == "UNIQUE_ID": input_data_all[x] = [unique_id] - return input_data_all + return input_data_all, missing_keys -def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): +def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): # check if node wants the lists - input_is_list = False - if hasattr(obj, "INPUT_IS_LIST"): - input_is_list = obj.INPUT_IS_LIST + input_is_list = getattr(obj, "INPUT_IS_LIST", False) if len(input_data_all) == 0: max_len_input = 0 else: - max_len_input = max([len(x) for x in input_data_all.values()]) + max_len_input = max(len(x) for x in input_data_all.values()) # get a slice of inputs, repeat last input when list isn't long enough def slice_dict(d, i): - d_new = dict() - for k,v in d.items(): - d_new[k] = v[i if len(v) > i else -1] - return d_new + return {k: v[i if len(v) > i else -1] for k, v in d.items()} results = [] - if input_is_list: + def process_inputs(inputs, index=None): if allow_interrupt: nodes.before_node_execution() - results.append(getattr(obj, func)(**input_data_all)) + execution_block = None + for k, v in inputs.items(): + if isinstance(v, ExecutionBlocker): + execution_block = execution_block_cb(v) if execution_block_cb else v + break + if execution_block is None: + if pre_execute_cb is not None and index is not None: + pre_execute_cb(index) + results.append(getattr(obj, func)(**inputs)) + else: + results.append(execution_block) + + if input_is_list: + process_inputs(input_data_all, 0) elif max_len_input == 0: - if allow_interrupt: - nodes.before_node_execution() - results.append(getattr(obj, func)()) - else: + process_inputs({}) + else: for i in range(max_len_input): - if allow_interrupt: - nodes.before_node_execution() - results.append(getattr(obj, func)(**slice_dict(input_data_all, i))) + input_dict = slice_dict(input_data_all, i) + process_inputs(input_dict, i) return results -def get_output_data(obj, input_data_all): +def merge_result_data(results, obj): + # check which outputs need concatenating + output = [] + output_is_list = [False] * len(results[0]) + if hasattr(obj, "OUTPUT_IS_LIST"): + output_is_list = obj.OUTPUT_IS_LIST + + # merge node execution results + for i, is_list in zip(range(len(results[0])), output_is_list): + if is_list: + output.append([x for o in results for x in o[i]]) + else: + output.append([o[i] for o in results]) + return output + +def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None): results = [] uis = [] - return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True) - - for r in return_values: + subgraph_results = [] + return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) + has_subgraph = False + for i in range(len(return_values)): + r = return_values[i] if isinstance(r, dict): if 'ui' in r: uis.append(r['ui']) - if 'result' in r: - results.append(r['result']) + if 'expand' in r: + # Perform an expansion, but do not append results + has_subgraph = True + new_graph = r['expand'] + result = r.get("result", None) + if isinstance(result, ExecutionBlocker): + result = tuple([result] * len(obj.RETURN_TYPES)) + subgraph_results.append((new_graph, result)) + elif 'result' in r: + result = r.get("result", None) + if isinstance(result, ExecutionBlocker): + result = tuple([result] * len(obj.RETURN_TYPES)) + results.append(result) + subgraph_results.append((None, result)) else: + if isinstance(r, ExecutionBlocker): + r = tuple([r] * len(obj.RETURN_TYPES)) results.append(r) + subgraph_results.append((None, r)) - output = [] - if len(results) > 0: - # check which outputs need concatenating - output_is_list = [False] * len(results[0]) - if hasattr(obj, "OUTPUT_IS_LIST"): - output_is_list = obj.OUTPUT_IS_LIST - - # merge node execution results - for i, is_list in zip(range(len(results[0])), output_is_list): - if is_list: - output.append([x for o in results for x in o[i]]) - else: - output.append([o[i] for o in results]) - + if has_subgraph: + output = subgraph_results + elif len(results) > 0: + output = merge_result_data(results, obj) + else: + output = [] ui = dict() if len(uis) > 0: ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} - return output, ui + return output, ui, has_subgraph def format_value(x): if x is None: @@ -117,53 +232,145 @@ def format_value(x): else: return str(x) -def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage): +def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results): unique_id = current_item - inputs = prompt[unique_id]['inputs'] - class_type = prompt[unique_id]['class_type'] + real_node_id = dynprompt.get_real_node_id(unique_id) + display_node_id = dynprompt.get_display_node_id(unique_id) + parent_node_id = dynprompt.get_parent_node_id(unique_id) + inputs = dynprompt.get_node(unique_id)['inputs'] + class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - if unique_id in outputs: - return (True, None, None) - - for x in inputs: - input_data = inputs[x] - - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id not in outputs: - result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage) - if result[0] is not True: - # Another node failed further upstream - return result + if caches.outputs.get(unique_id) is not None: + if server.client_id is not None: + cached_output = caches.ui.get(unique_id) or {} + server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) + return (ExecutionResult.SUCCESS, None, None) input_data_all = None try: - input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) - if server.client_id is not None: - server.last_node_id = unique_id - server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) - - obj = object_storage.get((unique_id, class_type), None) - if obj is None: - obj = class_def() - object_storage[(unique_id, class_type)] = obj + if unique_id in pending_subgraph_results: + cached_results = pending_subgraph_results[unique_id] + resolved_outputs = [] + for is_subgraph, result in cached_results: + if not is_subgraph: + resolved_outputs.append(result) + else: + resolved_output = [] + for r in result: + if is_link(r): + source_node, source_output = r[0], r[1] + node_output = caches.outputs.get(source_node)[source_output] + for o in node_output: + resolved_output.append(o) - output_data, output_ui = get_output_data(obj, input_data_all) - outputs[unique_id] = output_data + else: + resolved_output.append(r) + resolved_outputs.append(tuple(resolved_output)) + output_data = merge_result_data(resolved_outputs, class_def) + output_ui = [] + has_subgraph = False + else: + input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) + if server.client_id is not None: + server.last_node_id = display_node_id + server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) + + obj = caches.objects.get(unique_id) + if obj is None: + obj = class_def() + caches.objects.set(unique_id, obj) + + if hasattr(obj, "check_lazy_status"): + required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True) + required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) + required_inputs = [x for x in required_inputs if isinstance(x,str) and ( + x not in input_data_all or x in missing_keys + )] + if len(required_inputs) > 0: + for i in required_inputs: + execution_list.make_input_strong_link(unique_id, i) + return (ExecutionResult.PENDING, None, None) + + def execution_block_cb(block): + if block.message is not None: + mes = { + "prompt_id": prompt_id, + "node_id": unique_id, + "node_type": class_type, + "executed": list(executed), + + "exception_message": f"Execution Blocked: {block.message}", + "exception_type": "ExecutionBlocked", + "traceback": [], + "current_inputs": [], + "current_outputs": [], + } + server.send_sync("execution_error", mes, server.client_id) + return ExecutionBlocker(None) + else: + return block + def pre_execute_cb(call_index): + GraphBuilder.set_default_prefix(unique_id, call_index, 0) + output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) if len(output_ui) > 0: - outputs_ui[unique_id] = output_ui + caches.ui.set(unique_id, { + "meta": { + "node_id": unique_id, + "display_node": display_node_id, + "parent_node": parent_node_id, + "real_node_id": real_node_id, + }, + "output": output_ui + }) if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + if has_subgraph: + cached_outputs = [] + new_node_ids = [] + new_output_ids = [] + new_output_links = [] + for i in range(len(output_data)): + new_graph, node_outputs = output_data[i] + if new_graph is None: + cached_outputs.append((False, node_outputs)) + else: + # Check for conflicts + for node_id in new_graph.keys(): + if dynprompt.has_node(node_id): + raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.") + for node_id, node_info in new_graph.items(): + new_node_ids.append(node_id) + display_id = node_info.get("override_display_id", unique_id) + dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id) + # Figure out if the newly created node is an output node + class_type = node_info["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True: + new_output_ids.append(node_id) + for i in range(len(node_outputs)): + if is_link(node_outputs[i]): + from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1] + new_output_links.append((from_node_id, from_socket)) + cached_outputs.append((True, node_outputs)) + new_node_ids = set(new_node_ids) + for cache in caches.all: + cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused() + for node_id in new_output_ids: + execution_list.add_node(node_id) + for link in new_output_links: + execution_list.add_strong_link(link[0], link[1], unique_id) + pending_subgraph_results[unique_id] = cached_outputs + return (ExecutionResult.PENDING, None, None) + caches.outputs.set(unique_id, output_data) except comfy.model_management.InterruptProcessingException as iex: logging.info("Processing interrupted") # skip formatting inputs/outputs error_details = { - "node_id": unique_id, + "node_id": real_node_id, } - return (False, error_details, iex) + return (ExecutionResult.FAILURE, error_details, iex) except Exception as ex: typ, _, tb = sys.exc_info() exception_type = full_type_name(typ) @@ -173,121 +380,36 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute for name, inputs in input_data_all.items(): input_data_formatted[name] = [format_value(x) for x in inputs] - output_data_formatted = {} - for node_id, node_outputs in outputs.items(): - output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs] - - logging.error(f"!!! Exception during processing!!! {ex}") + logging.error(f"!!! Exception during processing !!! {ex}") logging.error(traceback.format_exc()) error_details = { - "node_id": unique_id, + "node_id": real_node_id, "exception_message": str(ex), "exception_type": exception_type, "traceback": traceback.format_tb(tb), - "current_inputs": input_data_formatted, - "current_outputs": output_data_formatted + "current_inputs": input_data_formatted } - if isinstance(ex, comfy.model_management.OOM_EXCEPTION): logging.error("Got an OOM, unloading all loaded models.") comfy.model_management.unload_all_models() - return (False, error_details, ex) + return (ExecutionResult.FAILURE, error_details, ex) executed.add(unique_id) - return (True, None, None) - -def recursive_will_execute(prompt, outputs, current_item, memo={}): - unique_id = current_item - - if unique_id in memo: - return memo[unique_id] - - inputs = prompt[unique_id]['inputs'] - will_execute = [] - if unique_id in outputs: - return [] - - for x in inputs: - input_data = inputs[x] - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id not in outputs: - will_execute += recursive_will_execute(prompt, outputs, input_unique_id, memo) - - memo[unique_id] = will_execute + [unique_id] - return memo[unique_id] - -def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item): - unique_id = current_item - inputs = prompt[unique_id]['inputs'] - class_type = prompt[unique_id]['class_type'] - class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - - is_changed_old = '' - is_changed = '' - to_delete = False - if hasattr(class_def, 'IS_CHANGED'): - if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]: - is_changed_old = old_prompt[unique_id]['is_changed'] - if 'is_changed' not in prompt[unique_id]: - input_data_all = get_input_data(inputs, class_def, unique_id, outputs) - if input_data_all is not None: - try: - #is_changed = class_def.IS_CHANGED(**input_data_all) - is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") - prompt[unique_id]['is_changed'] = is_changed - except: - to_delete = True - else: - is_changed = prompt[unique_id]['is_changed'] - - if unique_id not in outputs: - return True - - if not to_delete: - if is_changed != is_changed_old: - to_delete = True - elif unique_id not in old_prompt: - to_delete = True - elif class_type != old_prompt[unique_id]['class_type']: - to_delete = True - elif inputs == old_prompt[unique_id]['inputs']: - for x in inputs: - input_data = inputs[x] - - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id in outputs: - to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id) - else: - to_delete = True - if to_delete: - break - else: - to_delete = True - - if to_delete: - d = outputs.pop(unique_id) - del d - return to_delete + return (ExecutionResult.SUCCESS, None, None) class PromptExecutor: - def __init__(self, server): + def __init__(self, server, lru_size=None): + self.lru_size = lru_size self.server = server self.reset() def reset(self): - self.outputs = {} - self.object_storage = {} - self.outputs_ui = {} + self.caches = CacheSet(self.lru_size) self.status_messages = [] self.success = True - self.old_prompt = {} def add_message(self, event, data: dict, broadcast: bool): data = { @@ -318,27 +440,14 @@ class PromptExecutor: "node_id": node_id, "node_type": class_type, "executed": list(executed), - "exception_message": error["exception_message"], "exception_type": error["exception_type"], "traceback": error["traceback"], "current_inputs": error["current_inputs"], - "current_outputs": error["current_outputs"], + "current_outputs": list(current_outputs), } self.add_message("execution_error", mes, broadcast=False) - # Next, remove the subsequent outputs since they will not be executed - to_delete = [] - for o in self.outputs: - if (o not in current_outputs) and (o not in executed): - to_delete += [o] - if o in self.old_prompt: - d = self.old_prompt.pop(o) - del d - for o in to_delete: - d = self.outputs.pop(o) - del d - def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) @@ -351,65 +460,58 @@ class PromptExecutor: self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) with torch.inference_mode(): - #delete cached outputs if nodes don't exist for them - to_delete = [] - for o in self.outputs: - if o not in prompt: - to_delete += [o] - for o in to_delete: - d = self.outputs.pop(o) - del d - to_delete = [] - for o in self.object_storage: - if o[0] not in prompt: - to_delete += [o] - else: - p = prompt[o[0]] - if o[1] != p['class_type']: - to_delete += [o] - for o in to_delete: - d = self.object_storage.pop(o) - del d - - for x in prompt: - recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) - - current_outputs = set(self.outputs.keys()) - for x in list(self.outputs_ui.keys()): - if x not in current_outputs: - d = self.outputs_ui.pop(x) - del d + dynamic_prompt = DynamicPrompt(prompt) + is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs) + for cache in self.caches.all: + cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) + cache.clean_unused() + + cached_nodes = [] + for node_id in prompt: + if self.caches.outputs.get(node_id) is not None: + cached_nodes.append(node_id) comfy.model_management.cleanup_models(keep_clone_weights_loaded=True) self.add_message("execution_cached", - { "nodes": list(current_outputs) , "prompt_id": prompt_id}, + { "nodes": cached_nodes, "prompt_id": prompt_id}, broadcast=False) + pending_subgraph_results = {} executed = set() - output_node_id = None - to_execute = [] - + execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) + current_outputs = self.caches.outputs.all_node_ids() for node_id in list(execute_outputs): - to_execute += [(0, node_id)] - - while len(to_execute) > 0: - #always execute the output that depends on the least amount of unexecuted nodes first - memo = {} - to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1], memo)), a[-1]), to_execute))) - output_node_id = to_execute.pop(0)[-1] - - # This call shouldn't raise anything if there's an error deep in - # the actual SD code, instead it will report the node where the - # error was raised - self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage) - if self.success is not True: - self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) + execution_list.add_node(node_id) + + while not execution_list.is_empty(): + node_id, error, ex = execution_list.stage_node_execution() + if error is not None: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) break + + result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) + if result == ExecutionResult.FAILURE: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + break + elif result == ExecutionResult.PENDING: + execution_list.unstage_node_execution() + else: # result == ExecutionResult.SUCCESS: + execution_list.complete_node_execution() else: # Only execute when the while-loop ends without break self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) - for x in executed: - self.old_prompt[x] = copy.deepcopy(prompt[x]) + ui_outputs = {} + meta_outputs = {} + all_node_ids = self.caches.ui.all_node_ids() + for node_id in all_node_ids: + ui_info = self.caches.ui.get(node_id) + if ui_info is not None: + ui_outputs[node_id] = ui_info["output"] + meta_outputs[node_id] = ui_info["meta"] + self.history_result = { + "outputs": ui_outputs, + "meta": meta_outputs, + } self.server.last_node_id = None if comfy.model_management.DISABLE_SMART_MEMORY: comfy.model_management.unload_all_models() @@ -426,31 +528,37 @@ def validate_inputs(prompt, item, validated): obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] class_inputs = obj_class.INPUT_TYPES() - required_inputs = class_inputs['required'] + valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{}))) errors = [] valid = True validate_function_inputs = [] + validate_has_kwargs = False if hasattr(obj_class, "VALIDATE_INPUTS"): - validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args - - for x in required_inputs: + argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS) + validate_function_inputs = argspec.args + validate_has_kwargs = argspec.varkw is not None + received_types = {} + + for x in valid_inputs: + type_input, input_category, extra_info = get_input_info(obj_class, x) + assert extra_info is not None if x not in inputs: - error = { - "type": "required_input_missing", - "message": "Required input is missing", - "details": f"{x}", - "extra_info": { - "input_name": x + if input_category == "required": + error = { + "type": "required_input_missing", + "message": "Required input is missing", + "details": f"{x}", + "extra_info": { + "input_name": x + } } - } - errors.append(error) + errors.append(error) continue val = inputs[x] - info = required_inputs[x] - type_input = info[0] + info = (type_input, extra_info) if isinstance(val, list): if len(val) != 2: error = { @@ -469,8 +577,9 @@ def validate_inputs(prompt, item, validated): o_id = val[0] o_class_type = prompt[o_id]['class_type'] r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES - if r[val[1]] != type_input: - received_type = r[val[1]] + received_type = r[val[1]] + received_types[x] = received_type + if 'input_types' not in validate_function_inputs and received_type != type_input: details = f"{x}, {received_type} != {type_input}" error = { "type": "return_type_mismatch", @@ -521,6 +630,9 @@ def validate_inputs(prompt, item, validated): if type_input == "STRING": val = str(val) inputs[x] = val + if type_input == "BOOLEAN": + val = bool(val) + inputs[x] = val except Exception as ex: error = { "type": "invalid_input_type", @@ -536,11 +648,11 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue - if len(info) > 1: - if "min" in info[1] and val < info[1]["min"]: + if x not in validate_function_inputs and not validate_has_kwargs: + if "min" in extra_info and val < extra_info["min"]: error = { "type": "value_smaller_than_min", - "message": "Value {} smaller than min of {}".format(val, info[1]["min"]), + "message": "Value {} smaller than min of {}".format(val, extra_info["min"]), "details": f"{x}", "extra_info": { "input_name": x, @@ -550,10 +662,10 @@ def validate_inputs(prompt, item, validated): } errors.append(error) continue - if "max" in info[1] and val > info[1]["max"]: + if "max" in extra_info and val > extra_info["max"]: error = { "type": "value_bigger_than_max", - "message": "Value {} bigger than max of {}".format(val, info[1]["max"]), + "message": "Value {} bigger than max of {}".format(val, extra_info["max"]), "details": f"{x}", "extra_info": { "input_name": x, @@ -564,7 +676,6 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue - if x not in validate_function_inputs: if isinstance(type_input, list): if val not in type_input: input_config = info @@ -591,18 +702,20 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue - if len(validate_function_inputs) > 0: - input_data_all = get_input_data(inputs, obj_class, unique_id) + if len(validate_function_inputs) > 0 or validate_has_kwargs: + input_data_all, _ = get_input_data(inputs, obj_class, unique_id) input_filtered = {} for x in input_data_all: - if x in validate_function_inputs: + if x in validate_function_inputs or validate_has_kwargs: input_filtered[x] = input_data_all[x] + if 'input_types' in validate_function_inputs: + input_filtered['input_types'] = [received_types] #ret = obj_class.VALIDATE_INPUTS(**input_filtered) ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS") for x in input_filtered: for i, r in enumerate(ret): - if r is not True: + if r is not True and not isinstance(r, ExecutionBlocker): details = f"{x}" if r is not False: details += f" - {str(r)}" @@ -613,8 +726,6 @@ def validate_inputs(prompt, item, validated): "details": details, "extra_info": { "input_name": x, - "input_config": info, - "received_value": val, } } errors.append(error) @@ -780,7 +891,7 @@ class PromptQueue: completed: bool messages: List[str] - def task_done(self, item_id, outputs, + def task_done(self, item_id, history_result, status: Optional['PromptQueue.ExecutionStatus']): with self.mutex: prompt = self.currently_running.pop(item_id) @@ -793,9 +904,10 @@ class PromptQueue: self.history[prompt[1]] = { "prompt": prompt, - "outputs": copy.deepcopy(outputs), + "outputs": {}, 'status': status_dict, } + self.history[prompt[1]].update(history_result) self.server.queue_updated() def get_current_queue(self): diff --git a/main.py b/main.py index b878b3e..e9d6ed2 100644 --- a/main.py +++ b/main.py @@ -101,7 +101,7 @@ def cuda_malloc_warning(): logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") def prompt_worker(q, server): - e = execution.PromptExecutor(server) + e = execution.PromptExecutor(server, lru_size=args.cache_lru) last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0 @@ -121,7 +121,7 @@ def prompt_worker(q, server): e.execute(item[2], prompt_id, item[3], item[4]) need_gc = True q.task_done(item_id, - e.outputs_ui, + e.history_result, status=execution.PromptQueue.ExecutionStatus( status_str='success' if e.success else 'error', completed=e.success, diff --git a/pytest.ini b/pytest.ini index 8b7a747..a224d8c 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,6 +1,7 @@ [pytest] markers = inference: mark as inference test (deselect with '-m "not inference"') + execution: mark as execution test (deselect with '-m "not execution"') testpaths = tests tests-unit diff --git a/server.py b/server.py index 9d8269f..a2bb376 100644 --- a/server.py +++ b/server.py @@ -423,6 +423,7 @@ class PromptServer(): obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] info = {} info['input'] = obj_class.INPUT_TYPES() + info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()} info['output'] = obj_class.RETURN_TYPES info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] @@ -714,6 +715,9 @@ class PromptServer(): site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx) await site.start() + self.address = address + self.port = port + if verbose: logging.info("Starting server\n") logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address, port)) diff --git a/tests/inference/extra_model_paths.yaml b/tests/inference/extra_model_paths.yaml new file mode 100644 index 0000000..75b2e1a --- /dev/null +++ b/tests/inference/extra_model_paths.yaml @@ -0,0 +1,4 @@ +# Config for testing nodes +testing: + custom_nodes: tests/inference/testing_nodes + diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py new file mode 100644 index 0000000..8616ca1 --- /dev/null +++ b/tests/inference/test_execution.py @@ -0,0 +1,461 @@ +from io import BytesIO +import numpy +from PIL import Image +import pytest +from pytest import fixture +import time +import torch +from typing import Union, Dict +import json +import subprocess +import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client) +import uuid +import urllib.request +import urllib.parse +import urllib.error +from comfy.graph_utils import GraphBuilder, Node + +class RunResult: + def __init__(self, prompt_id: str): + self.outputs: Dict[str,Dict] = {} + self.runs: Dict[str,bool] = {} + self.prompt_id: str = prompt_id + + def get_output(self, node: Node): + return self.outputs.get(node.id, None) + + def did_run(self, node: Node): + return self.runs.get(node.id, False) + + def get_images(self, node: Node): + output = self.get_output(node) + if output is None: + return [] + return output.get('image_objects', []) + + def get_prompt_id(self): + return self.prompt_id + +class ComfyClient: + def __init__(self): + self.test_name = "" + + def connect(self, + listen:str = '127.0.0.1', + port:Union[str,int] = 8188, + client_id: str = str(uuid.uuid4()) + ): + self.client_id = client_id + self.server_address = f"{listen}:{port}" + ws = websocket.WebSocket() + ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id)) + self.ws = ws + + def queue_prompt(self, prompt): + p = {"prompt": prompt, "client_id": self.client_id} + data = json.dumps(p).encode('utf-8') + req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data) + return json.loads(urllib.request.urlopen(req).read()) + + def get_image(self, filename, subfolder, folder_type): + data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + url_values = urllib.parse.urlencode(data) + with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response: + return response.read() + + def get_history(self, prompt_id): + with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response: + return json.loads(response.read()) + + def set_test_name(self, name): + self.test_name = name + + def run(self, graph): + prompt = graph.finalize() + for node in graph.nodes.values(): + if node.class_type == 'SaveImage': + node.inputs['filename_prefix'] = self.test_name + + prompt_id = self.queue_prompt(prompt)['prompt_id'] + result = RunResult(prompt_id) + while True: + out = self.ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message['type'] == 'executing': + data = message['data'] + if data['prompt_id'] != prompt_id: + continue + if data['node'] is None: + break + result.runs[data['node']] = True + elif message['type'] == 'execution_error': + raise Exception(message['data']) + elif message['type'] == 'execution_cached': + pass # Probably want to store this off for testing + + history = self.get_history(prompt_id)[prompt_id] + for o in history['outputs']: + for node_id in history['outputs']: + node_output = history['outputs'][node_id] + result.outputs[node_id] = node_output + if 'images' in node_output: + images_output = [] + for image in node_output['images']: + image_data = self.get_image(image['filename'], image['subfolder'], image['type']) + image_obj = Image.open(BytesIO(image_data)) + images_output.append(image_obj) + node_output['image_objects'] = images_output + + return result + +# +# Loop through these variables +# +@pytest.mark.execution +class TestExecution: + # + # Initialize server and client + # + @fixture(scope="class", autouse=True, params=[ + # (use_lru, lru_size) + (False, 0), + (True, 0), + (True, 100), + ]) + def _server(self, args_pytest, request): + # Start server + pargs = [ + 'python','main.py', + '--output-directory', args_pytest["output_dir"], + '--listen', args_pytest["listen"], + '--port', str(args_pytest["port"]), + '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml', + ] + use_lru, lru_size = request.param + if use_lru: + pargs += ['--cache-lru', str(lru_size)] + print("Running server with args:", pargs) + p = subprocess.Popen(pargs) + yield + p.kill() + torch.cuda.empty_cache() + + def start_client(self, listen:str, port:int): + # Start client + comfy_client = ComfyClient() + # Connect to server (with retries) + n_tries = 5 + for i in range(n_tries): + time.sleep(4) + try: + comfy_client.connect(listen=listen, port=port) + except ConnectionRefusedError as e: + print(e) + print(f"({i+1}/{n_tries}) Retrying...") + else: + break + return comfy_client + + @fixture(scope="class", autouse=True) + def shared_client(self, args_pytest, _server): + client = self.start_client(args_pytest["listen"], args_pytest["port"]) + yield client + del client + torch.cuda.empty_cache() + + @fixture + def client(self, shared_client, request): + shared_client.set_test_name(f"execution[{request.node.name}]") + yield shared_client + + @fixture + def builder(self, request): + yield GraphBuilder(prefix=request.node.name) + + def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1) + + lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0)) + output = g.node("SaveImage", images=lazy_mix.out(0)) + result = client.run(g) + + result_image = result.get_images(output)[0] + assert numpy.array(result_image).any() == 0, "Image should be black" + assert result.did_run(input1) + assert not result.did_run(input2) + assert result.did_run(mask) + assert result.did_run(lazy_mix) + + def test_full_cache(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + + lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0)) + g.node("SaveImage", images=lazy_mix.out(0)) + + client.run(g) + result2 = client.run(g) + for node_id, node in g.nodes.items(): + assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached" + + def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + + lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0)) + g.node("SaveImage", images=lazy_mix.out(0)) + + client.run(g) + mask.inputs['value'] = 0.4 + result2 = client.run(g) + assert not result2.did_run(input1), "Input1 should have been cached" + assert not result2.did_run(input2), "Input2 should have been cached" + + def test_error(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + # Different size of the two images + input2 = g.node("StubImage", content="NOISE", height=256, width=256, batch_size=1) + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + + lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0)) + g.node("SaveImage", images=lazy_mix.out(0)) + + try: + client.run(g) + assert False, "Should have raised an error" + except Exception as e: + assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}" + + @pytest.mark.parametrize("test_value, expect_error", [ + (5, True), + ("foo", True), + (5.0, False), + ]) + def test_validation_error_literal(self, test_value, expect_error, client: ComfyClient, builder: GraphBuilder): + g = builder + validation1 = g.node("TestCustomValidation1", input1=test_value, input2=3.0) + g.node("SaveImage", images=validation1.out(0)) + + if expect_error: + with pytest.raises(urllib.error.HTTPError): + client.run(g) + else: + client.run(g) + + @pytest.mark.parametrize("test_type, test_value", [ + ("StubInt", 5), + ("StubFloat", 5.0) + ]) + def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder): + g = builder + stub = g.node(test_type, value=test_value) + validation1 = g.node("TestCustomValidation1", input1=stub.out(0), input2=3.0) + g.node("SaveImage", images=validation1.out(0)) + + with pytest.raises(urllib.error.HTTPError): + client.run(g) + + @pytest.mark.parametrize("test_type, test_value, expect_error", [ + ("StubInt", 5, True), + ("StubFloat", 5.0, False) + ]) + def test_validation_error_edge2(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder): + g = builder + stub = g.node(test_type, value=test_value) + validation2 = g.node("TestCustomValidation2", input1=stub.out(0), input2=3.0) + g.node("SaveImage", images=validation2.out(0)) + + if expect_error: + with pytest.raises(urllib.error.HTTPError): + client.run(g) + else: + client.run(g) + + @pytest.mark.parametrize("test_type, test_value, expect_error", [ + ("StubInt", 5, True), + ("StubFloat", 5.0, False) + ]) + def test_validation_error_edge3(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder): + g = builder + stub = g.node(test_type, value=test_value) + validation3 = g.node("TestCustomValidation3", input1=stub.out(0), input2=3.0) + g.node("SaveImage", images=validation3.out(0)) + + if expect_error: + with pytest.raises(urllib.error.HTTPError): + client.run(g) + else: + client.run(g) + + @pytest.mark.parametrize("test_type, test_value, expect_error", [ + ("StubInt", 5, True), + ("StubFloat", 5.0, False) + ]) + def test_validation_error_edge4(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder): + g = builder + stub = g.node(test_type, value=test_value) + validation4 = g.node("TestCustomValidation4", input1=stub.out(0), input2=3.0) + g.node("SaveImage", images=validation4.out(0)) + + if expect_error: + with pytest.raises(urllib.error.HTTPError): + client.run(g) + else: + client.run(g) + + @pytest.mark.parametrize("test_value1, test_value2, expect_error", [ + (0.0, 0.5, False), + (0.0, 5.0, False), + (0.0, 7.0, True) + ]) + def test_validation_error_kwargs(self, test_value1, test_value2, expect_error, client: ComfyClient, builder: GraphBuilder): + g = builder + validation5 = g.node("TestCustomValidation5", input1=test_value1, input2=test_value2) + g.node("SaveImage", images=validation5.out(0)) + + if expect_error: + with pytest.raises(urllib.error.HTTPError): + client.run(g) + else: + client.run(g) + + def test_cycle_error(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + + lazy_mix1 = g.node("TestLazyMixImages", image1=input1.out(0), mask=mask.out(0)) + lazy_mix2 = g.node("TestLazyMixImages", image1=lazy_mix1.out(0), image2=input2.out(0), mask=mask.out(0)) + g.node("SaveImage", images=lazy_mix2.out(0)) + + # When the cycle exists on initial submission, it should raise a validation error + with pytest.raises(urllib.error.HTTPError): + client.run(g) + + def test_dynamic_cycle_error(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + generator = g.node("TestDynamicDependencyCycle", input1=input1.out(0), input2=input2.out(0)) + g.node("SaveImage", images=generator.out(0)) + + # When the cycle is in a graph that is generated dynamically, it should raise a runtime error + try: + client.run(g) + assert False, "Should have raised an error" + except Exception as e: + assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}" + assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node" + + def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder): + g = builder + # Creating the nodes in this specific order previously caused a bug + save = g.node("SaveImage") + is_changed = g.node("TestCustomIsChanged", should_change=False) + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + save.set_input('images', is_changed.out(0)) + is_changed.set_input('image', input1.out(0)) + + result1 = client.run(g) + result2 = client.run(g) + is_changed.set_input('should_change', True) + result3 = client.run(g) + result4 = client.run(g) + assert result1.did_run(is_changed), "is_changed should have been run" + assert not result2.did_run(is_changed), "is_changed should have been cached" + assert result3.did_run(is_changed), "is_changed should have been re-run" + assert result4.did_run(is_changed), "is_changed should not have been cached" + + def test_undeclared_inputs(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + input3 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input4 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + average = g.node("TestVariadicAverage", input1=input1.out(0), input2=input2.out(0), input3=input3.out(0), input4=input4.out(0)) + output = g.node("SaveImage", images=average.out(0)) + + result = client.run(g) + result_image = result.get_images(output)[0] + expected = 255 // 4 + assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey" + + def test_for_loop(self, client: ComfyClient, builder: GraphBuilder): + g = builder + iterations = 4 + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + is_changed = g.node("TestCustomIsChanged", should_change=True, image=input2.out(0)) + for_open = g.node("TestForLoopOpen", remaining=iterations, initial_value1=is_changed.out(0)) + average = g.node("TestVariadicAverage", input1=input1.out(0), input2=for_open.out(2)) + for_close = g.node("TestForLoopClose", flow_control=for_open.out(0), initial_value1=average.out(0)) + output = g.node("SaveImage", images=for_close.out(0)) + + for iterations in range(1, 5): + for_open.set_input('remaining', iterations) + result = client.run(g) + result_image = result.get_images(output)[0] + expected = 255 // (2 ** iterations) + assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey" + assert result.did_run(is_changed) + + def test_mixed_expansion_returns(self, client: ComfyClient, builder: GraphBuilder): + g = builder + val_list = g.node("TestMakeListNode", value1=0.1, value2=0.2, value3=0.3) + mixed = g.node("TestMixedExpansionReturns", input1=val_list.out(0)) + output_dynamic = g.node("SaveImage", images=mixed.out(0)) + output_literal = g.node("SaveImage", images=mixed.out(1)) + + result = client.run(g) + images_dynamic = result.get_images(output_dynamic) + assert len(images_dynamic) == 3, "Should have 2 images" + assert numpy.array(images_dynamic[0]).min() == 25 and numpy.array(images_dynamic[0]).max() == 25, "First image should be 0.1" + assert numpy.array(images_dynamic[1]).min() == 51 and numpy.array(images_dynamic[1]).max() == 51, "Second image should be 0.2" + assert numpy.array(images_dynamic[2]).min() == 76 and numpy.array(images_dynamic[2]).max() == 76, "Third image should be 0.3" + + images_literal = result.get_images(output_literal) + assert len(images_literal) == 3, "Should have 2 images" + for i in range(3): + assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white" + + def test_mixed_lazy_results(self, client: ComfyClient, builder: GraphBuilder): + g = builder + val_list = g.node("TestMakeListNode", value1=0.0, value2=0.5, value3=1.0) + mask = g.node("StubMask", value=val_list.out(0), height=512, width=512, batch_size=1) + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0)) + rebatch = g.node("RebatchImages", images=mix.out(0), batch_size=3) + output = g.node("SaveImage", images=rebatch.out(0)) + + result = client.run(g) + images = result.get_images(output) + assert len(images) == 3, "Should have 3 image" + assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be 0.0" + assert numpy.array(images[1]).min() == 127 and numpy.array(images[1]).max() == 127, "Second image should be 0.5" + assert numpy.array(images[2]).min() == 255 and numpy.array(images[2]).max() == 255, "Third image should be 1.0" + + def test_output_reuse(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + output1 = g.node("PreviewImage", images=input1.out(0)) + output2 = g.node("PreviewImage", images=input1.out(0)) + + result = client.run(g) + images1 = result.get_images(output1) + images2 = result.get_images(output2) + assert len(images1) == 1, "Should have 1 image" + assert len(images2) == 1, "Should have 1 image" + diff --git a/tests/inference/testing_nodes/testing-pack/__init__.py b/tests/inference/testing_nodes/testing-pack/__init__.py new file mode 100644 index 0000000..dcc7165 --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/__init__.py @@ -0,0 +1,23 @@ +from .specific_tests import TEST_NODE_CLASS_MAPPINGS, TEST_NODE_DISPLAY_NAME_MAPPINGS +from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS +from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS +from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS +from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS + +# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS) +# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS) + +NODE_CLASS_MAPPINGS = {} +NODE_CLASS_MAPPINGS.update(TEST_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS) + +NODE_DISPLAY_NAME_MAPPINGS = {} +NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS) + diff --git a/tests/inference/testing_nodes/testing-pack/conditions.py b/tests/inference/testing_nodes/testing-pack/conditions.py new file mode 100644 index 0000000..0c200ee --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/conditions.py @@ -0,0 +1,194 @@ +import re +import torch + +class TestIntConditions: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}), + "b": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}), + "operation": (["==", "!=", "<", ">", "<=", ">="],), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "int_condition" + + CATEGORY = "Testing/Logic" + + def int_condition(self, a, b, operation): + if operation == "==": + return (a == b,) + elif operation == "!=": + return (a != b,) + elif operation == "<": + return (a < b,) + elif operation == ">": + return (a > b,) + elif operation == "<=": + return (a <= b,) + elif operation == ">=": + return (a >= b,) + + +class TestFloatConditions: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("FLOAT", {"default": 0, "min": -999999999999.0, "max": 999999999999.0, "step": 1}), + "b": ("FLOAT", {"default": 0, "min": -999999999999.0, "max": 999999999999.0, "step": 1}), + "operation": (["==", "!=", "<", ">", "<=", ">="],), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "float_condition" + + CATEGORY = "Testing/Logic" + + def float_condition(self, a, b, operation): + if operation == "==": + return (a == b,) + elif operation == "!=": + return (a != b,) + elif operation == "<": + return (a < b,) + elif operation == ">": + return (a > b,) + elif operation == "<=": + return (a <= b,) + elif operation == ">=": + return (a >= b,) + +class TestStringConditions: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("STRING", {"multiline": False}), + "b": ("STRING", {"multiline": False}), + "operation": (["a == b", "a != b", "a IN b", "a MATCH REGEX(b)", "a BEGINSWITH b", "a ENDSWITH b"],), + "case_sensitive": ("BOOLEAN", {"default": True}), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "string_condition" + + CATEGORY = "Testing/Logic" + + def string_condition(self, a, b, operation, case_sensitive): + if not case_sensitive: + a = a.lower() + b = b.lower() + + if operation == "a == b": + return (a == b,) + elif operation == "a != b": + return (a != b,) + elif operation == "a IN b": + return (a in b,) + elif operation == "a MATCH REGEX(b)": + try: + return (re.match(b, a) is not None,) + except: + return (False,) + elif operation == "a BEGINSWITH b": + return (a.startswith(b),) + elif operation == "a ENDSWITH b": + return (a.endswith(b),) + +class TestToBoolNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("*",), + }, + "optional": { + "invert": ("BOOLEAN", {"default": False}), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "to_bool" + + CATEGORY = "Testing/Logic" + + def to_bool(self, value, invert = False): + if isinstance(value, torch.Tensor): + if value.max().item() == 0 and value.min().item() == 0: + result = False + else: + result = True + else: + try: + result = bool(value) + except: + # Can't convert it? Well then it's something or other. I dunno, I'm not a Python programmer. + result = True + + if invert: + result = not result + + return (result,) + +class TestBoolOperationNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("BOOLEAN",), + "b": ("BOOLEAN",), + "op": (["a AND b", "a OR b", "a XOR b", "NOT a"],), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "bool_operation" + + CATEGORY = "Testing/Logic" + + def bool_operation(self, a, b, op): + if op == "a AND b": + return (a and b,) + elif op == "a OR b": + return (a or b,) + elif op == "a XOR b": + return (a ^ b,) + elif op == "NOT a": + return (not a,) + + +CONDITION_NODE_CLASS_MAPPINGS = { + "TestIntConditions": TestIntConditions, + "TestFloatConditions": TestFloatConditions, + "TestStringConditions": TestStringConditions, + "TestToBoolNode": TestToBoolNode, + "TestBoolOperationNode": TestBoolOperationNode, +} + +CONDITION_NODE_DISPLAY_NAME_MAPPINGS = { + "TestIntConditions": "Int Condition", + "TestFloatConditions": "Float Condition", + "TestStringConditions": "String Condition", + "TestToBoolNode": "To Bool", + "TestBoolOperationNode": "Bool Operation", +} diff --git a/tests/inference/testing_nodes/testing-pack/flow_control.py b/tests/inference/testing_nodes/testing-pack/flow_control.py new file mode 100644 index 0000000..1ef1cf8 --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/flow_control.py @@ -0,0 +1,173 @@ +from comfy.graph_utils import GraphBuilder, is_link +from comfy.graph import ExecutionBlocker +from .tools import VariantSupport + +NUM_FLOW_SOCKETS = 5 +@VariantSupport() +class TestWhileLoopOpen: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + inputs = { + "required": { + "condition": ("BOOLEAN", {"default": True}), + }, + "optional": { + }, + } + for i in range(NUM_FLOW_SOCKETS): + inputs["optional"][f"initial_value{i}"] = ("*",) + return inputs + + RETURN_TYPES = tuple(["FLOW_CONTROL"] + ["*"] * NUM_FLOW_SOCKETS) + RETURN_NAMES = tuple(["FLOW_CONTROL"] + [f"value{i}" for i in range(NUM_FLOW_SOCKETS)]) + FUNCTION = "while_loop_open" + + CATEGORY = "Testing/Flow" + + def while_loop_open(self, condition, **kwargs): + values = [] + for i in range(NUM_FLOW_SOCKETS): + values.append(kwargs.get(f"initial_value{i}", None)) + return tuple(["stub"] + values) + +@VariantSupport() +class TestWhileLoopClose: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + inputs = { + "required": { + "flow_control": ("FLOW_CONTROL", {"rawLink": True}), + "condition": ("BOOLEAN", {"forceInput": True}), + }, + "optional": { + }, + "hidden": { + "dynprompt": "DYNPROMPT", + "unique_id": "UNIQUE_ID", + } + } + for i in range(NUM_FLOW_SOCKETS): + inputs["optional"][f"initial_value{i}"] = ("*",) + return inputs + + RETURN_TYPES = tuple(["*"] * NUM_FLOW_SOCKETS) + RETURN_NAMES = tuple([f"value{i}" for i in range(NUM_FLOW_SOCKETS)]) + FUNCTION = "while_loop_close" + + CATEGORY = "Testing/Flow" + + def explore_dependencies(self, node_id, dynprompt, upstream): + node_info = dynprompt.get_node(node_id) + if "inputs" not in node_info: + return + for k, v in node_info["inputs"].items(): + if is_link(v): + parent_id = v[0] + if parent_id not in upstream: + upstream[parent_id] = [] + self.explore_dependencies(parent_id, dynprompt, upstream) + upstream[parent_id].append(node_id) + + def collect_contained(self, node_id, upstream, contained): + if node_id not in upstream: + return + for child_id in upstream[node_id]: + if child_id not in contained: + contained[child_id] = True + self.collect_contained(child_id, upstream, contained) + + + def while_loop_close(self, flow_control, condition, dynprompt=None, unique_id=None, **kwargs): + assert dynprompt is not None + if not condition: + # We're done with the loop + values = [] + for i in range(NUM_FLOW_SOCKETS): + values.append(kwargs.get(f"initial_value{i}", None)) + return tuple(values) + + # We want to loop + upstream = {} + # Get the list of all nodes between the open and close nodes + self.explore_dependencies(unique_id, dynprompt, upstream) + + contained = {} + open_node = flow_control[0] + self.collect_contained(open_node, upstream, contained) + contained[unique_id] = True + contained[open_node] = True + + # We'll use the default prefix, but to avoid having node names grow exponentially in size, + # we'll use "Recurse" for the name of the recursively-generated copy of this node. + graph = GraphBuilder() + for node_id in contained: + original_node = dynprompt.get_node(node_id) + node = graph.node(original_node["class_type"], "Recurse" if node_id == unique_id else node_id) + node.set_override_display_id(node_id) + for node_id in contained: + original_node = dynprompt.get_node(node_id) + node = graph.lookup_node("Recurse" if node_id == unique_id else node_id) + assert node is not None + for k, v in original_node["inputs"].items(): + if is_link(v) and v[0] in contained: + parent = graph.lookup_node(v[0]) + assert parent is not None + node.set_input(k, parent.out(v[1])) + else: + node.set_input(k, v) + new_open = graph.lookup_node(open_node) + assert new_open is not None + for i in range(NUM_FLOW_SOCKETS): + key = f"initial_value{i}" + new_open.set_input(key, kwargs.get(key, None)) + my_clone = graph.lookup_node("Recurse") + assert my_clone is not None + result = map(lambda x: my_clone.out(x), range(NUM_FLOW_SOCKETS)) + return { + "result": tuple(result), + "expand": graph.finalize(), + } + +@VariantSupport() +class TestExecutionBlockerNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + inputs = { + "required": { + "input": ("*",), + "block": ("BOOLEAN",), + "verbose": ("BOOLEAN", {"default": False}), + }, + } + return inputs + + RETURN_TYPES = ("*",) + RETURN_NAMES = ("output",) + FUNCTION = "execution_blocker" + + CATEGORY = "Testing/Flow" + + def execution_blocker(self, input, block, verbose): + if block: + return (ExecutionBlocker("Blocked Execution" if verbose else None),) + return (input,) + +FLOW_CONTROL_NODE_CLASS_MAPPINGS = { + "TestWhileLoopOpen": TestWhileLoopOpen, + "TestWhileLoopClose": TestWhileLoopClose, + "TestExecutionBlocker": TestExecutionBlockerNode, +} +FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS = { + "TestWhileLoopOpen": "While Loop Open", + "TestWhileLoopClose": "While Loop Close", + "TestExecutionBlocker": "Execution Blocker", +} diff --git a/tests/inference/testing_nodes/testing-pack/specific_tests.py b/tests/inference/testing_nodes/testing-pack/specific_tests.py new file mode 100644 index 0000000..5884cae --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/specific_tests.py @@ -0,0 +1,335 @@ +import torch +from .tools import VariantSupport +from comfy.graph_utils import GraphBuilder + +class TestLazyMixImages: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image1": ("IMAGE",{"lazy": True}), + "image2": ("IMAGE",{"lazy": True}), + "mask": ("MASK",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "mix" + + CATEGORY = "Testing/Nodes" + + def check_lazy_status(self, mask, image1, image2): + mask_min = mask.min() + mask_max = mask.max() + needed = [] + if image1 is None and (mask_min != 1.0 or mask_max != 1.0): + needed.append("image1") + if image2 is None and (mask_min != 0.0 or mask_max != 0.0): + needed.append("image2") + return needed + + # Not trying to handle different batch sizes here just to keep the demo simple + def mix(self, mask, image1, image2): + mask_min = mask.min() + mask_max = mask.max() + if mask_min == 0.0 and mask_max == 0.0: + return (image1,) + elif mask_min == 1.0 and mask_max == 1.0: + return (image2,) + + if len(mask.shape) == 2: + mask = mask.unsqueeze(0) + if len(mask.shape) == 3: + mask = mask.unsqueeze(3) + if mask.shape[3] < image1.shape[3]: + mask = mask.repeat(1, 1, 1, image1.shape[3]) + + result = image1 * (1. - mask) + image2 * mask, + return (result[0],) + +class TestVariadicAverage: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("IMAGE",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "variadic_average" + + CATEGORY = "Testing/Nodes" + + def variadic_average(self, input1, **kwargs): + inputs = [input1] + while 'input' + str(len(inputs) + 1) in kwargs: + inputs.append(kwargs['input' + str(len(inputs) + 1)]) + return (torch.stack(inputs).mean(dim=0),) + + +class TestCustomIsChanged: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + "optional": { + "should_change": ("BOOL", {"default": False}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_is_changed" + + CATEGORY = "Testing/Nodes" + + def custom_is_changed(self, image, should_change=False): + return (image,) + + @classmethod + def IS_CHANGED(cls, should_change=False, *args, **kwargs): + if should_change: + return float("NaN") + else: + return False + +class TestCustomValidation1: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("IMAGE,FLOAT",), + "input2": ("IMAGE,FLOAT",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_validation1" + + CATEGORY = "Testing/Nodes" + + def custom_validation1(self, input1, input2): + if isinstance(input1, float) and isinstance(input2, float): + result = torch.ones([1, 512, 512, 3]) * input1 * input2 + else: + result = input1 * input2 + return (result,) + + @classmethod + def VALIDATE_INPUTS(cls, input1=None, input2=None): + if input1 is not None: + if not isinstance(input1, (torch.Tensor, float)): + return f"Invalid type of input1: {type(input1)}" + if input2 is not None: + if not isinstance(input2, (torch.Tensor, float)): + return f"Invalid type of input2: {type(input2)}" + + return True + +class TestCustomValidation2: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("IMAGE,FLOAT",), + "input2": ("IMAGE,FLOAT",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_validation2" + + CATEGORY = "Testing/Nodes" + + def custom_validation2(self, input1, input2): + if isinstance(input1, float) and isinstance(input2, float): + result = torch.ones([1, 512, 512, 3]) * input1 * input2 + else: + result = input1 * input2 + return (result,) + + @classmethod + def VALIDATE_INPUTS(cls, input_types, input1=None, input2=None): + if input1 is not None: + if not isinstance(input1, (torch.Tensor, float)): + return f"Invalid type of input1: {type(input1)}" + if input2 is not None: + if not isinstance(input2, (torch.Tensor, float)): + return f"Invalid type of input2: {type(input2)}" + + if 'input1' in input_types: + if input_types['input1'] not in ["IMAGE", "FLOAT"]: + return f"Invalid type of input1: {input_types['input1']}" + if 'input2' in input_types: + if input_types['input2'] not in ["IMAGE", "FLOAT"]: + return f"Invalid type of input2: {input_types['input2']}" + + return True + +@VariantSupport() +class TestCustomValidation3: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("IMAGE,FLOAT",), + "input2": ("IMAGE,FLOAT",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_validation3" + + CATEGORY = "Testing/Nodes" + + def custom_validation3(self, input1, input2): + if isinstance(input1, float) and isinstance(input2, float): + result = torch.ones([1, 512, 512, 3]) * input1 * input2 + else: + result = input1 * input2 + return (result,) + +class TestCustomValidation4: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("FLOAT",), + "input2": ("FLOAT",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_validation4" + + CATEGORY = "Testing/Nodes" + + def custom_validation4(self, input1, input2): + result = torch.ones([1, 512, 512, 3]) * input1 * input2 + return (result,) + + @classmethod + def VALIDATE_INPUTS(cls, input1, input2): + if input1 is not None: + if not isinstance(input1, float): + return f"Invalid type of input1: {type(input1)}" + if input2 is not None: + if not isinstance(input2, float): + return f"Invalid type of input2: {type(input2)}" + + return True + +class TestCustomValidation5: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("FLOAT", {"min": 0.0, "max": 1.0}), + "input2": ("FLOAT", {"min": 0.0, "max": 1.0}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_validation5" + + CATEGORY = "Testing/Nodes" + + def custom_validation5(self, input1, input2): + value = input1 * input2 + return (torch.ones([1, 512, 512, 3]) * value,) + + @classmethod + def VALIDATE_INPUTS(cls, **kwargs): + if kwargs['input2'] == 7.0: + return "7s are not allowed. I've never liked 7s." + return True + +class TestDynamicDependencyCycle: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("IMAGE",), + "input2": ("IMAGE",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "dynamic_dependency_cycle" + + CATEGORY = "Testing/Nodes" + + def dynamic_dependency_cycle(self, input1, input2): + g = GraphBuilder() + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + mix1 = g.node("TestLazyMixImages", image1=input1, mask=mask.out(0)) + mix2 = g.node("TestLazyMixImages", image1=mix1.out(0), image2=input2, mask=mask.out(0)) + + # Create the cyle + mix1.set_input("image2", mix2.out(0)) + + return { + "result": (mix2.out(0),), + "expand": g.finalize(), + } + +class TestMixedExpansionReturns: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("FLOAT",), + }, + } + + RETURN_TYPES = ("IMAGE","IMAGE") + FUNCTION = "mixed_expansion_returns" + + CATEGORY = "Testing/Nodes" + + def mixed_expansion_returns(self, input1): + white_image = torch.ones([1, 512, 512, 3]) + if input1 <= 0.1: + return (torch.ones([1, 512, 512, 3]) * 0.1, white_image) + elif input1 <= 0.2: + return { + "result": (torch.ones([1, 512, 512, 3]) * 0.2, white_image), + } + else: + g = GraphBuilder() + mask = g.node("StubMask", value=0.3, height=512, width=512, batch_size=1) + black = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + white = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + mix = g.node("TestLazyMixImages", image1=black.out(0), image2=white.out(0), mask=mask.out(0)) + return { + "result": (mix.out(0), white_image), + "expand": g.finalize(), + } + +TEST_NODE_CLASS_MAPPINGS = { + "TestLazyMixImages": TestLazyMixImages, + "TestVariadicAverage": TestVariadicAverage, + "TestCustomIsChanged": TestCustomIsChanged, + "TestCustomValidation1": TestCustomValidation1, + "TestCustomValidation2": TestCustomValidation2, + "TestCustomValidation3": TestCustomValidation3, + "TestCustomValidation4": TestCustomValidation4, + "TestCustomValidation5": TestCustomValidation5, + "TestDynamicDependencyCycle": TestDynamicDependencyCycle, + "TestMixedExpansionReturns": TestMixedExpansionReturns, +} + +TEST_NODE_DISPLAY_NAME_MAPPINGS = { + "TestLazyMixImages": "Lazy Mix Images", + "TestVariadicAverage": "Variadic Average", + "TestCustomIsChanged": "Custom IsChanged", + "TestCustomValidation1": "Custom Validation 1", + "TestCustomValidation2": "Custom Validation 2", + "TestCustomValidation3": "Custom Validation 3", + "TestCustomValidation4": "Custom Validation 4", + "TestCustomValidation5": "Custom Validation 5", + "TestDynamicDependencyCycle": "Dynamic Dependency Cycle", + "TestMixedExpansionReturns": "Mixed Expansion Returns", +} diff --git a/tests/inference/testing_nodes/testing-pack/stubs.py b/tests/inference/testing_nodes/testing-pack/stubs.py new file mode 100644 index 0000000..9be6eac --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/stubs.py @@ -0,0 +1,105 @@ +import torch + +class StubImage: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "content": (['WHITE', 'BLACK', 'NOISE'],), + "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}), + "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "stub_image" + + CATEGORY = "Testing/Stub Nodes" + + def stub_image(self, content, height, width, batch_size): + if content == "WHITE": + return (torch.ones(batch_size, height, width, 3),) + elif content == "BLACK": + return (torch.zeros(batch_size, height, width, 3),) + elif content == "NOISE": + return (torch.rand(batch_size, height, width, 3),) + +class StubMask: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), + "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}), + "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}), + }, + } + + RETURN_TYPES = ("MASK",) + FUNCTION = "stub_mask" + + CATEGORY = "Testing/Stub Nodes" + + def stub_mask(self, value, height, width, batch_size): + return (torch.ones(batch_size, height, width) * value,) + +class StubInt: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("INT", {"default": 0, "min": -0xffffffff, "max": 0xffffffff, "step": 1}), + }, + } + + RETURN_TYPES = ("INT",) + FUNCTION = "stub_int" + + CATEGORY = "Testing/Stub Nodes" + + def stub_int(self, value): + return (value,) + +class StubFloat: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("FLOAT", {"default": 0.0, "min": -1.0e38, "max": 1.0e38, "step": 0.01}), + }, + } + + RETURN_TYPES = ("FLOAT",) + FUNCTION = "stub_float" + + CATEGORY = "Testing/Stub Nodes" + + def stub_float(self, value): + return (value,) + +TEST_STUB_NODE_CLASS_MAPPINGS = { + "StubImage": StubImage, + "StubMask": StubMask, + "StubInt": StubInt, + "StubFloat": StubFloat, +} +TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = { + "StubImage": "Stub Image", + "StubMask": "Stub Mask", + "StubInt": "Stub Int", + "StubFloat": "Stub Float", +} diff --git a/tests/inference/testing_nodes/testing-pack/tools.py b/tests/inference/testing_nodes/testing-pack/tools.py new file mode 100644 index 0000000..34b28c0 --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/tools.py @@ -0,0 +1,53 @@ + +def MakeSmartType(t): + if isinstance(t, str): + return SmartType(t) + return t + +class SmartType(str): + def __ne__(self, other): + if self == "*" or other == "*": + return False + selfset = set(self.split(',')) + otherset = set(other.split(',')) + return not selfset.issubset(otherset) + +def VariantSupport(): + def decorator(cls): + if hasattr(cls, "INPUT_TYPES"): + old_input_types = getattr(cls, "INPUT_TYPES") + def new_input_types(*args, **kwargs): + types = old_input_types(*args, **kwargs) + for category in ["required", "optional"]: + if category not in types: + continue + for key, value in types[category].items(): + if isinstance(value, tuple): + types[category][key] = (MakeSmartType(value[0]),) + value[1:] + return types + setattr(cls, "INPUT_TYPES", new_input_types) + if hasattr(cls, "RETURN_TYPES"): + old_return_types = cls.RETURN_TYPES + setattr(cls, "RETURN_TYPES", tuple(MakeSmartType(x) for x in old_return_types)) + if hasattr(cls, "VALIDATE_INPUTS"): + # Reflection is used to determine what the function signature is, so we can't just change the function signature + raise NotImplementedError("VariantSupport does not support VALIDATE_INPUTS yet") + else: + def validate_inputs(input_types): + inputs = cls.INPUT_TYPES() + for key, value in input_types.items(): + if isinstance(value, SmartType): + continue + if "required" in inputs and key in inputs["required"]: + expected_type = inputs["required"][key][0] + elif "optional" in inputs and key in inputs["optional"]: + expected_type = inputs["optional"][key][0] + else: + expected_type = None + if expected_type is not None and MakeSmartType(value) != expected_type: + return f"Invalid type of {key}: {value} (expected {expected_type})" + return True + setattr(cls, "VALIDATE_INPUTS", validate_inputs) + return cls + return decorator + diff --git a/tests/inference/testing_nodes/testing-pack/util.py b/tests/inference/testing_nodes/testing-pack/util.py new file mode 100644 index 0000000..fea83e3 --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/util.py @@ -0,0 +1,364 @@ +from comfy.graph_utils import GraphBuilder +from .tools import VariantSupport + +@VariantSupport() +class TestAccumulateNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "to_add": ("*",), + }, + "optional": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("ACCUMULATION",) + FUNCTION = "accumulate" + + CATEGORY = "Testing/Lists" + + def accumulate(self, to_add, accumulation = None): + if accumulation is None: + value = [to_add] + else: + value = accumulation["accum"] + [to_add] + return ({"accum": value},) + +@VariantSupport() +class TestAccumulationHeadNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("ACCUMULATION", "*",) + FUNCTION = "accumulation_head" + + CATEGORY = "Testing/Lists" + + def accumulation_head(self, accumulation): + accum = accumulation["accum"] + if len(accum) == 0: + return (accumulation, None) + else: + return ({"accum": accum[1:]}, accum[0]) + +class TestAccumulationTailNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("ACCUMULATION", "*",) + FUNCTION = "accumulation_tail" + + CATEGORY = "Testing/Lists" + + def accumulation_tail(self, accumulation): + accum = accumulation["accum"] + if len(accum) == 0: + return (None, accumulation) + else: + return ({"accum": accum[:-1]}, accum[-1]) + +@VariantSupport() +class TestAccumulationToListNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("*",) + OUTPUT_IS_LIST = (True,) + + FUNCTION = "accumulation_to_list" + + CATEGORY = "Testing/Lists" + + def accumulation_to_list(self, accumulation): + return (accumulation["accum"],) + +@VariantSupport() +class TestListToAccumulationNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "list": ("*",), + }, + } + + RETURN_TYPES = ("ACCUMULATION",) + INPUT_IS_LIST = (True,) + + FUNCTION = "list_to_accumulation" + + CATEGORY = "Testing/Lists" + + def list_to_accumulation(self, list): + return ({"accum": list},) + +@VariantSupport() +class TestAccumulationGetLengthNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("INT",) + + FUNCTION = "accumlength" + + CATEGORY = "Testing/Lists" + + def accumlength(self, accumulation): + return (len(accumulation['accum']),) + +@VariantSupport() +class TestAccumulationGetItemNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + "index": ("INT", {"default":0, "step":1}) + }, + } + + RETURN_TYPES = ("*",) + + FUNCTION = "get_item" + + CATEGORY = "Testing/Lists" + + def get_item(self, accumulation, index): + return (accumulation['accum'][index],) + +@VariantSupport() +class TestAccumulationSetItemNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + "index": ("INT", {"default":0, "step":1}), + "value": ("*",), + }, + } + + RETURN_TYPES = ("ACCUMULATION",) + + FUNCTION = "set_item" + + CATEGORY = "Testing/Lists" + + def set_item(self, accumulation, index, value): + new_accum = accumulation['accum'][:] + new_accum[index] = value + return ({"accum": new_accum},) + +class TestIntMathOperation: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}), + "b": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}), + "operation": (["add", "subtract", "multiply", "divide", "modulo", "power"],), + }, + } + + RETURN_TYPES = ("INT",) + FUNCTION = "int_math_operation" + + CATEGORY = "Testing/Logic" + + def int_math_operation(self, a, b, operation): + if operation == "add": + return (a + b,) + elif operation == "subtract": + return (a - b,) + elif operation == "multiply": + return (a * b,) + elif operation == "divide": + return (a // b,) + elif operation == "modulo": + return (a % b,) + elif operation == "power": + return (a ** b,) + + +from .flow_control import NUM_FLOW_SOCKETS +@VariantSupport() +class TestForLoopOpen: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "remaining": ("INT", {"default": 1, "min": 0, "max": 100000, "step": 1}), + }, + "optional": { + f"initial_value{i}": ("*",) for i in range(1, NUM_FLOW_SOCKETS) + }, + "hidden": { + "initial_value0": ("*",) + } + } + + RETURN_TYPES = tuple(["FLOW_CONTROL", "INT",] + ["*"] * (NUM_FLOW_SOCKETS-1)) + RETURN_NAMES = tuple(["flow_control", "remaining"] + [f"value{i}" for i in range(1, NUM_FLOW_SOCKETS)]) + FUNCTION = "for_loop_open" + + CATEGORY = "Testing/Flow" + + def for_loop_open(self, remaining, **kwargs): + graph = GraphBuilder() + if "initial_value0" in kwargs: + remaining = kwargs["initial_value0"] + while_open = graph.node("TestWhileLoopOpen", condition=remaining, initial_value0=remaining, **{(f"initial_value{i}"): kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)}) + outputs = [kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)] + return { + "result": tuple(["stub", remaining] + outputs), + "expand": graph.finalize(), + } + +@VariantSupport() +class TestForLoopClose: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "flow_control": ("FLOW_CONTROL", {"rawLink": True}), + }, + "optional": { + f"initial_value{i}": ("*",{"rawLink": True}) for i in range(1, NUM_FLOW_SOCKETS) + }, + } + + RETURN_TYPES = tuple(["*"] * (NUM_FLOW_SOCKETS-1)) + RETURN_NAMES = tuple([f"value{i}" for i in range(1, NUM_FLOW_SOCKETS)]) + FUNCTION = "for_loop_close" + + CATEGORY = "Testing/Flow" + + def for_loop_close(self, flow_control, **kwargs): + graph = GraphBuilder() + while_open = flow_control[0] + sub = graph.node("TestIntMathOperation", operation="subtract", a=[while_open,1], b=1) + cond = graph.node("TestToBoolNode", value=sub.out(0)) + input_values = {f"initial_value{i}": kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)} + while_close = graph.node("TestWhileLoopClose", + flow_control=flow_control, + condition=cond.out(0), + initial_value0=sub.out(0), + **input_values) + return { + "result": tuple([while_close.out(i) for i in range(1, NUM_FLOW_SOCKETS)]), + "expand": graph.finalize(), + } + +NUM_LIST_SOCKETS = 10 +@VariantSupport() +class TestMakeListNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value1": ("*",), + }, + "optional": { + f"value{i}": ("*",) for i in range(1, NUM_LIST_SOCKETS) + }, + } + + RETURN_TYPES = ("*",) + FUNCTION = "make_list" + OUTPUT_IS_LIST = (True,) + + CATEGORY = "Testing/Lists" + + def make_list(self, **kwargs): + result = [] + for i in range(NUM_LIST_SOCKETS): + if f"value{i}" in kwargs: + result.append(kwargs[f"value{i}"]) + return (result,) + +UTILITY_NODE_CLASS_MAPPINGS = { + "TestAccumulateNode": TestAccumulateNode, + "TestAccumulationHeadNode": TestAccumulationHeadNode, + "TestAccumulationTailNode": TestAccumulationTailNode, + "TestAccumulationToListNode": TestAccumulationToListNode, + "TestListToAccumulationNode": TestListToAccumulationNode, + "TestAccumulationGetLengthNode": TestAccumulationGetLengthNode, + "TestAccumulationGetItemNode": TestAccumulationGetItemNode, + "TestAccumulationSetItemNode": TestAccumulationSetItemNode, + "TestForLoopOpen": TestForLoopOpen, + "TestForLoopClose": TestForLoopClose, + "TestIntMathOperation": TestIntMathOperation, + "TestMakeListNode": TestMakeListNode, +} +UTILITY_NODE_DISPLAY_NAME_MAPPINGS = { + "TestAccumulateNode": "Accumulate", + "TestAccumulationHeadNode": "Accumulation Head", + "TestAccumulationTailNode": "Accumulation Tail", + "TestAccumulationToListNode": "Accumulation to List", + "TestListToAccumulationNode": "List to Accumulation", + "TestAccumulationGetLengthNode": "Accumulation Get Length", + "TestAccumulationGetItemNode": "Accumulation Get Item", + "TestAccumulationSetItemNode": "Accumulation Set Item", + "TestForLoopOpen": "For Loop Open", + "TestForLoopClose": "For Loop Close", + "TestIntMathOperation": "Int Math Operation", + "TestMakeListNode": "Make List", +} diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js index 9a22389..163e42b 100644 --- a/web/extensions/core/groupNode.js +++ b/web/extensions/core/groupNode.js @@ -959,8 +959,8 @@ export class GroupNodeHandler { const executed = handleEvent.call( this, "executed", - (d) => d?.node, - (d, id, node) => ({ ...d, node: id, merge: !node.resetExecution }) + (d) => d?.display_node, + (d, id, node) => ({ ...d, node: id, display_node: id, merge: !node.resetExecution }) ); const onRemoved = node.onRemoved; diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 0815549..88f1f9d 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -3,7 +3,7 @@ import { app } from "../../scripts/app.js"; import { applyTextReplacements } from "../../scripts/utils.js"; const CONVERTED_TYPE = "converted-widget"; -const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"]; +const VALID_TYPES = ["STRING", "combo", "number", "toggle", "BOOLEAN"]; const CONFIG = Symbol(); const GET_CONFIG = Symbol(); const TARGET = Symbol(); // Used for reroutes to specify the real target widget diff --git a/web/scripts/api.js b/web/scripts/api.js index 03c3fb6..eee1c9a 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -128,7 +128,7 @@ class ComfyApi extends EventTarget { this.dispatchEvent(new CustomEvent("progress", { detail: msg.data })); break; case "executing": - this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node })); + this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.display_node })); break; case "executed": this.dispatchEvent(new CustomEvent("executed", { detail: msg.data })); diff --git a/web/scripts/app.js b/web/scripts/app.js index df92450..6e2b395 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1348,7 +1348,7 @@ export class ComfyApp { api.addEventListener("executed", ({ detail }) => { if (this.workflowManager.activePrompt ?.workflow && this.workflowManager.activePrompt.workflow !== this.workflowManager.activeWorkflow) return; - const output = this.nodeOutputs[detail.node]; + const output = this.nodeOutputs[detail.display_node]; if (detail.merge && output) { for (const k in detail.output ?? {}) { const v = output[k]; @@ -1359,9 +1359,9 @@ export class ComfyApp { } } } else { - this.nodeOutputs[detail.node] = detail.output; + this.nodeOutputs[detail.display_node] = detail.output; } - const node = this.graph.getNodeById(detail.node); + const node = this.graph.getNodeById(detail.display_node); if (node) { if (node.onExecuted) node.onExecuted(detail.output); diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 2c47412..05258e3 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -239,7 +239,11 @@ class ComfyList { onclick: async () => { await app.loadGraphData(item.prompt[3].extra_pnginfo.workflow, true, false); if (item.outputs) { - app.nodeOutputs = item.outputs; + app.nodeOutputs = {}; + for (const [key, value] of Object.entries(item.outputs)) { + const realKey = item?.meta?.[key]?.display_node ?? key; + app.nodeOutputs[realKey] = value; + } } }, }),