import shutil
from copy import deepcopy
from pathlib import Path

from .options import Options
from ..utils import common_utils
class ModelOptions(Options):
    """Options for a single model: its identifier, versioned results
    directories and the location of the weights to load.

    NOTE(review): ``opts``, ``get``, ``add_option``, ``name``, ``model_dir``,
    ``inference`` and ``get_intermediate_path`` are not defined in this class;
    presumably they come from ``Options`` (or elsewhere in the file) — confirm.
    """

    # Default option values for model options; presumably merged into
    # ``opts`` by ``Options`` — TODO confirm.
    DEFAULT_VALUES = {
        "suffix": "",
        "suffix_prepend": {"default": ""},
        "intermediate_save_dir": "intermediate",
    }
    # NOTE(review): not referenced in this class; presumably consumed by
    # ``Options`` or by helpers that address nested option dicts.
    DICT_PATH_SEPARATOR = "--"

    def __init__(self, opts):
        super().__init__(opts)
        # Lazily resolved identifier (see ``model_id``); "" means unresolved.
        self._model_id = ""
        # Cached ``get_last_version()`` result; ``None`` means not computed yet.
        self._previous_version = None

    @property
    def results_dir_root(self):
        """Directory containing every version of this model (``model_dir / model_id``)."""
        return self.model_dir / self.model_id

    @property
    def results_save_dir(self):
        """Versioned results directory used when saving."""
        return self.get_results_dir()

    @property
    def results_load_dir(self):
        """Versioned results directory used when loading."""
        return self.get_results_dir(save=False)

    def version(self, save=True):
        """Return the model version to use.

        An explicit ``version`` option always wins. Otherwise the version is
        derived from the last version directory found on disk: saving uses
        ``last + 1``; loading uses ``last``, or ``1`` when no version exists.

        Args:
            save (bool): ``True`` when the version is needed for saving,
                ``False`` when it is needed for loading.

        Returns:
            The ``version`` option when set, otherwise an ``int``.
        """
        unset = object()
        explicit = self.opts.get("version", unset)
        if explicit is not unset:
            # * Pinned version: skip the directory scan (and any cleanup it does)
            return explicit
        default = self.previous_version
        if save or not default:
            # * Add one if we are saving or if we are loading and the previous number is 0
            default += 1
        return default

    @property
    def previous_version(self):
        """Highest version found on disk, computed once and cached."""
        if self._previous_version is None:
            self._previous_version = self.get_last_version()
        return self._previous_version

    @property
    def load_version(self):
        """Version used when loading (``version(save=False)``)."""
        return self.version(save=False)

    @property
    def save_version(self):
        """Version used when saving (``version(save=True)``)."""
        return self.version(save=True)

    def get_results_dir(self, save=True):
        """Return ``results_dir_root / <version>`` for saving or loading."""
        return self.results_dir_root / str(self.version(save))

    def get_weights_path(self, version=None, as_string=True):
        """Return the location of the weights this model should load.

        Resolution order:

        1. In inference mode, ``weights_opts["name"]`` and
           ``weights_opts["path"]`` are ignored (the model uses its own weights).
        2. An explicit ``weights_opts["path"]`` is returned, as ``str`` or
           ``Path`` according to ``as_string``.
        3. If ``weights_opts["name"]`` names another model, the path is
           resolved through a copy of these options with that model id (and
           ``weights_opts["model_dir"]`` when given).
        4. If ``weights_opts["from_epoch"]`` is set, the intermediate path for
           that epoch is returned via ``get_intermediate_path``.
        5. Otherwise ``results_load_dir / model_id``.

        Args:
            version: Version forwarded to ``get_intermediate_path`` in case 4;
                defaults to ``weights_opts["version"]``, or ``-1``.
            as_string (bool): Return a ``str`` when true, a path object otherwise.

        Returns:
            str or path object: The weights location.
        """
        # * ``or {}`` also covers a ``weights_opts`` option explicitly set to None
        weight_opts = self.get("weights_opts", {}) or {}
        name = weight_opts.get("name", "")
        path = weight_opts.get("path", "")
        if self.inference:
            # * Use model own weights in inference mode
            name = ""
            path = ""
        if path:
            return str(path) if as_string else Path(path)
        if name and name != self.model_id:
            # * Load weights from another model: build options pointing to it
            tmp_opts = ModelOptions(deepcopy(self.opts))
            tmp_opts.model_id = name
            # * Update options from weight opts
            # TODO: allow redefining?
            if "model_dir" in weight_opts:
                tmp_opts.add_option("model_dir", weight_opts["model_dir"])
            # * Forward as_string so the caller gets the type it asked for
            return tmp_opts.get_weights_path(as_string=as_string)
        from_epoch = weight_opts.get("from_epoch", 0)
        if from_epoch:
            if version is None:
                version = weight_opts.get("version", -1)
            path = self.get_intermediate_path(
                from_epoch,
                version=version,
                as_string=as_string,
            )
        else:
            path = self.results_load_dir / self.model_id
        if as_string:
            return str(path)
        return path

    @property
    def model_id(self):
        """Model identifier, resolved lazily and cached.

        Uses the ``model_id`` option when it is non-empty; otherwise ``name``
        followed by the suffix returned by
        ``common_utils.resolve_dict_pattern(opts, "suffix")``.
        """
        if not self._model_id:
            conf_id = self.opts.get("model_id", "")
            if conf_id:
                self._model_id = conf_id
            else:
                self._model_id = self.name + common_utils.resolve_dict_pattern(
                    self.opts, "suffix"
                )
        return self._model_id

    @model_id.setter
    def model_id(self, value):
        self._model_id = value
        # * The cached last version belonged to the old results_dir_root
        self._previous_version = None

    def get_last_version(self):
        """Return the highest numeric version directory under ``results_dir_root``.

        Only sub-directories whose name parses as an ``int`` are considered
        versions; every other entry is ignored and never modified. When the
        ``clean_empty_models`` option is true, version directories holding at
        most one entry are deleted with ``shutil.rmtree`` and skipped.

        Returns:
            int: The largest version found, or ``0`` if ``results_dir_root``
            does not exist or contains no (non-negative) version directory.
        """
        root = self.results_dir_root
        if not root.exists():
            return 0
        clean_empty = self.get("clean_empty_models", False)
        latest = 0
        for item in root.iterdir():
            if not item.is_dir():
                continue
            try:
                number = int(item.name)
            except ValueError:
                # * Not a version directory: never clean it up
                continue
            # TODO: check if empty (or only databases)
            if clean_empty and len(list(item.iterdir())) <= 1:
                common_utils.print_info(
                    "Found empty directory {} and clean_empty_models is True. Removing folder".format(
                        item
                    )
                )
                shutil.rmtree(item)
                continue
            latest = max(latest, number)
        return latest