|
|
|
@ -7,6 +7,7 @@ import threading
|
|
|
|
|
import heapq
|
|
|
|
|
import traceback
|
|
|
|
|
import gc
|
|
|
|
|
import inspect
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import nodes
|
|
|
|
@ -402,6 +403,10 @@ def validate_inputs(prompt, item, validated):
|
|
|
|
|
errors = []
|
|
|
|
|
valid = True
|
|
|
|
|
|
|
|
|
|
validate_function_inputs = []
|
|
|
|
|
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
|
|
|
|
validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
|
|
|
|
|
|
|
|
|
|
for x in required_inputs:
|
|
|
|
|
if x not in inputs:
|
|
|
|
|
error = {
|
|
|
|
@ -531,29 +536,7 @@ def validate_inputs(prompt, item, validated):
|
|
|
|
|
errors.append(error)
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
|
|
|
|
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
|
|
|
|
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
|
|
|
|
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
|
|
|
|
|
for i, r in enumerate(ret):
|
|
|
|
|
if r is not True:
|
|
|
|
|
details = f"{x}"
|
|
|
|
|
if r is not False:
|
|
|
|
|
details += f" - {str(r)}"
|
|
|
|
|
|
|
|
|
|
error = {
|
|
|
|
|
"type": "custom_validation_failed",
|
|
|
|
|
"message": "Custom validation failed for node",
|
|
|
|
|
"details": details,
|
|
|
|
|
"extra_info": {
|
|
|
|
|
"input_name": x,
|
|
|
|
|
"input_config": info,
|
|
|
|
|
"received_value": val,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
errors.append(error)
|
|
|
|
|
continue
|
|
|
|
|
else:
|
|
|
|
|
if x not in validate_function_inputs:
|
|
|
|
|
if isinstance(type_input, list):
|
|
|
|
|
if val not in type_input:
|
|
|
|
|
input_config = info
|
|
|
|
@ -580,6 +563,35 @@ def validate_inputs(prompt, item, validated):
|
|
|
|
|
errors.append(error)
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if len(validate_function_inputs) > 0:
|
|
|
|
|
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
|
|
|
|
input_filtered = {}
|
|
|
|
|
for x in input_data_all:
|
|
|
|
|
if x in validate_function_inputs:
|
|
|
|
|
input_filtered[x] = input_data_all[x]
|
|
|
|
|
|
|
|
|
|
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
|
|
|
|
|
ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
|
|
|
|
|
for x in input_filtered:
|
|
|
|
|
for i, r in enumerate(ret):
|
|
|
|
|
if r is not True:
|
|
|
|
|
details = f"{x}"
|
|
|
|
|
if r is not False:
|
|
|
|
|
details += f" - {str(r)}"
|
|
|
|
|
|
|
|
|
|
error = {
|
|
|
|
|
"type": "custom_validation_failed",
|
|
|
|
|
"message": "Custom validation failed for node",
|
|
|
|
|
"details": details,
|
|
|
|
|
"extra_info": {
|
|
|
|
|
"input_name": x,
|
|
|
|
|
"input_config": info,
|
|
|
|
|
"received_value": val,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
errors.append(error)
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if len(errors) > 0 or valid is not True:
|
|
|
|
|
ret = (False, errors, unique_id)
|
|
|
|
|
else:
|
|
|
|
|