import re
import shutil
from copy import deepcopy
from .options import Options
from ..utils import common_utils
[docs]class ModelOptions(Options):
DEFAULT_VALUES = {
"suffix": "",
"suffix_prepend": {"default": ""},
"intermediate_save_dir": "intermediate",
}
DICT_PATH_SEPARATOR = "--"
def __init__(self, opts):
super().__init__(opts)
self._model_id = ""
self._previous_version = None
@property
def results_dir_root(self):
return self.model_dir / self.model_id
@property
def results_save_dir(self):
return self.get_results_dir()
@property
def results_load_dir(self):
return self.get_results_dir(save=False)
def version(self, save=True):
default = self.previous_version
if save or not default:
# * Add one if we are saving or if we are loading and the previous number is 0
default += 1
return self.opts.get("version", default)
@property
def previous_version(self):
if self._previous_version is None:
self._previous_version = self.get_last_version()
return self._previous_version
@property
def load_version(self):
return self.version(save=False)
@property
def save_version(self):
return self.version(save=True)
def get_results_dir(self, save=True):
return self.results_dir_root / str(self.version(save))
def get_weights_path(self, epoch=None, version=None, as_string=True):
weight_opts = self.get("weights_opts", {})
path = weight_opts.get("path", "")
if path:
return path
name = weight_opts.get("name", "")
if name and name != self.model_id:
# * Load weights from another model
# * For that, create a new model options with the new name
tmp_opts = ModelOptions(deepcopy(self.opts))
tmp_opts.model_id = name
# * Update options from weight opts
# TODO: allow redefining?
if "model_dir" in weight_opts:
tmp_opts.add_option("model_dir", weight_opts["model_dir"])
path = tmp_opts.get_weights_path()
return path
from_epoch = weight_opts.get("from_epoch", 0)
if from_epoch:
if version is None:
version = weight_opts.get("version", -1)
path = self.get_intermediate_path(
from_epoch,
version=version,
as_string=as_string,
)
else:
path = self.results_load_dir / self.model_id
if as_string:
return str(path)
return path
@property
def model_id(self):
if not self._model_id:
conf_id = self.opts.get("model_id", "")
if conf_id:
self._model_id = conf_id
else:
self._model_id = self.name + self.add_suffix()
# self.opts["model_id"] = self._model_id
return self._model_id
@model_id.setter
def model_id(self, value):
self._model_id = value
def add_suffix(self):
pattern = self.suffix
prefixes = self.suffix_prepend
to_replace = re.findall("\\{(.+?)\\}", pattern)
res = {}
for key in to_replace:
mid = ""
if prefixes:
prefix = prefixes.get(key, prefixes.get("default", ""))
mid += str(prefix)
if self.DICT_PATH_SEPARATOR in key:
value = common_utils.get_dict_path(
self.opts, key, key, sep=self.DICT_PATH_SEPARATOR
)
else:
value = self.opts.get(key, key)
mid += common_utils.list2str(value)
res[key] = mid
mid = pattern.format(**res)
return mid
def get_last_version(self):
version = 0
path = self.results_dir_root
if path.exists():
for item in path.iterdir():
if item.is_dir():
try:
# TODO: check if empty (or only databases)
if self.get("clean_empty_models", False):
is_empty = len(list(item.iterdir())) <= 1
if is_empty:
common_utils.print_info(
"Found empty directory {} and clean_empty_models is True. Removing folder".format(
item
)
)
shutil.rmtree(item)
continue
res = int(item.name)
if res >= version:
version = res
except ValueError:
continue
return version