You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
27 lines
590 B
Python
27 lines
590 B
Python
2 years ago
|
|
||
|
|
||
|
# Module-level record of the model that load_model_gpu() last loaded,
# or None when no model is loaded. Cleared again by unload_model().
current_loaded_model = None
|
||
|
|
||
|
|
||
|
def unload_model():
    """Unload the model recorded in ``current_loaded_model``, if there is one.

    Calls ``.model.cpu()`` on the recorded model wrapper, then its
    ``unpatch_model()``, and finally resets ``current_loaded_model`` to
    ``None``. Does nothing when no model is recorded.
    """
    global current_loaded_model

    loaded = current_loaded_model
    if loaded is None:
        # Nothing is loaded; leave the global untouched.
        return

    # Presumably moves the underlying weights back to host memory
    # (torch-style ``.cpu()``) -- TODO confirm against the wrapper class.
    loaded.model.cpu()
    # Undo whatever patch_model() applied when the model was loaded.
    loaded.unpatch_model()
    current_loaded_model = None
|
||
|
|
||
|
|
||
|
def load_model_gpu(model):
    """Make ``model`` the currently loaded model and call ``.cuda()`` on it.

    If ``model`` is already the recorded model, nothing is reloaded.
    Otherwise any previously loaded model is unloaded first, ``model`` is
    patched via ``model.patch_model()``, and ``.cuda()`` is called on the
    object that ``patch_model()`` returns.

    Args:
        model: A wrapper object exposing ``patch_model()`` and
            ``unpatch_model()``. ``unload_model()`` additionally reads its
            ``.model`` attribute.

    Returns:
        The model now recorded in ``current_loaded_model`` (i.e. ``model``),
        on both the fresh-load and the already-loaded path.

    Raises:
        Exception: Whatever ``patch_model()`` or ``.cuda()`` raises. In both
            cases the model is unpatched and is not left recorded as loaded.
    """
    global current_loaded_model

    if model is current_loaded_model:
        # Already loaded: return the same value as the normal path so that
        # callers get a consistent result.
        return current_loaded_model

    unload_model()

    try:
        real_model = model.patch_model()
    except Exception:
        # A failed patch may have been partially applied; revert it.
        model.unpatch_model()
        raise

    # Record the model before the transfer so that unload_model() can be
    # used to roll everything back if the transfer fails.
    current_loaded_model = model
    try:
        real_model.cuda()
    except Exception:
        # Without this rollback the model would stay recorded as loaded and
        # a later call would return early for a model that never made it
        # onto the GPU.
        unload_model()
        raise

    return current_loaded_model
|