|
|
@ -171,7 +171,7 @@ def calc_cond_batch(model, conds, x_in, timestep, model_options):
|
|
|
|
for i in range(1, len(to_batch_temp) + 1):
|
|
|
|
for i in range(1, len(to_batch_temp) + 1):
|
|
|
|
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
|
|
|
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
|
|
|
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
|
|
|
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
|
|
|
if model.memory_required(input_shape) < free_memory:
|
|
|
|
if model.memory_required(input_shape) * 1.5 < free_memory:
|
|
|
|
to_batch = batch_amount
|
|
|
|
to_batch = batch_amount
|
|
|
|
break
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
|
|