from abc import ABC, abstractmethod
import yaml
from ..utils import file as file_utils
class Model(ABC):
    """Base abstract class for defining models. Must be subclassed.

    Subclasses must implement :meth:`create_model`, :meth:`predict` and
    :meth:`save_model`; the remaining methods provide overridable defaults.

    Attributes:
        NAME: Name of the model class. Will be used to identify and save classes.
        STEP_TRAINING: The name of the training step. Will be used in the configuration file.
        STEP_VALIDATION: The name of the validation step. Will be used in the configuration file.
    """

    NAME = "MODEL"
    STEP_TRAINING = "train"
    STEP_VALIDATION = "validation"

    def __init__(self, opts=None):
        """Initialize an empty model and optionally attach its options.

        No network is built here: ``self.model`` starts as ``None`` and is
        expected to be populated later (e.g. by :meth:`create_model`).

        Args:
            opts (dict or options object, optional): Model options. When truthy,
                they are assigned through the :attr:`opts` setter, which fills
                in a missing model name. Defaults to None.
        """
        self.model = None
        self._opts = None
        if opts:
            self.opts = opts

    def check_options(self):
        """Validate the model options.

        The base implementation accepts any options; subclasses may override it
        to perform real checks.

        Returns:
            bool: Always ``True`` in the base class.
        """
        return True

    @property
    def n_parameters(self):
        """Number of parameters of the model.

        Returns:
            int: ``-1`` in the base class (presumably a "not available" sentinel);
            subclasses may override it.
        """
        return -1

    @property
    def opts(self):
        """Property that contains the options related to the model as read in the configuration file"""
        return self._opts

    @opts.setter
    def opts(self, opts):
        """Store the model options, filling in the model name when it is missing.

        If ``opts["name"]`` is missing or empty, it is set to :attr:`NAME`.
        Note: this setter does NOT call :meth:`create_model`.

        Args:
            opts (dict or options object): Model options. Either a plain mapping,
                or an options object that supports ``opts["name"]`` lookups and
                keeps its underlying mapping in an ``opts`` attribute.
        """
        self._opts = opts
        try:
            current_name = opts["name"]
        except KeyError:
            # A missing name is treated the same as an empty one.
            current_name = None
        if not current_name:
            # Options objects keep their raw mapping in `.opts`; a plain dict has
            # no such attribute and is written to directly.
            store = getattr(opts, "opts", opts)
            store["name"] = self.NAME

    @abstractmethod
    def create_model(self):
        """Abstract method where model creation / network initialization should take place

        Raises:
            NotImplementedError: If a subclass does not override this method.
        """
        raise NotImplementedError(
            "create_model function not implemented for this class"
        )

    @abstractmethod
    def predict(self, x):
        """Predict data using the model

        Args:
            x (data): the data on which we want to make a prediction. This function should make only
                one prediction and be called numerous times if needed

        Raises:
            NotImplementedError: If a subclass does not override this method.
        """
        raise NotImplementedError("predict function not implemented for this class")

    @abstractmethod
    def save_model(self, path=None):
        """Save the model to disk

        Args:
            path (Object, optional): Path where the model should be saved. Defaults to None.

        Raises:
            NotImplementedError: Should be inherited by the final model
        """
        raise NotImplementedError("save_model function not implemented for this class")

    def prepare_data(self, data):
        """Prepare the data before training. Allows to perform last minute changes to the data just before
        training.

        The base implementation returns ``data`` unchanged.

        Args:
            data (Object): The data to be prepared

        Returns:
            Object: The modified data
        """
        return data

    def save_options(self, file_name, options):
        """Save the options related to the model for logging purposes as a yaml file. By default, uses the
        "results_save_dir" property from the ModelOptions class associated to the model.

        By default, this will be the following combination:
            model_dir/model_id/version
        where model_dir and model_id can be found in the model configuration file and version is
        calculated automatically.

        Args:
            file_name (string or pathlib.Path): The name of the file to be saved
            options (Object): The options to save that can be transcribed as a yaml file
        """
        # Read the directory once so the directory created and the file written are
        # guaranteed to match, even if `results_save_dir` is computed on access.
        save_dir = self.opts.results_save_dir
        file_utils.ensure_path_exists(save_dir)
        with open(save_dir / file_name, "w", encoding="utf-8") as f:
            yaml.dump(options, f, default_flow_style=False)