Source code for mouffet.training.training_handler

import copy
import time
import traceback
from datetime import datetime

import pandas as pd

from ..data import DB_TYPE_TRAINING, DB_TYPE_VALIDATION
from ..options import ModelOptions
from ..utils import ModelHandler, common_utils, file_utils, config_utils


[docs]class TrainingHandler(ModelHandler): DB_TYPES = [ DB_TYPE_TRAINING, DB_TYPE_VALIDATION, ] def check_repeat(self, scenarios): res = [] for scenario in scenarios: repeat = int(scenario.get("repeat_scenario", 0)) if repeat: for i in range(1, repeat + 1): tmp = copy.deepcopy(scenario) tmp["name"] = tmp["name"] + "_rep" + str(i) res.append(tmp) else: res.append(scenario) return res def post_process_scenarios(self, scenarios): scenarios = self.check_repeat(scenarios) return scenarios def expand_training_scenarios(self): scenarios = [] if "scenarios" in self.opts: clean = dict(self.opts) scens = clean.pop("scenarios") for scenario in scens: for opts in common_utils.expand_options_dict(scenario): res = dict(clean) res = common_utils.deep_dict_update(res, opts, copy=True) scenarios.append(res) else: scenarios.append(dict(self.opts)) scenarios = self.post_process_scenarios(scenarios) return scenarios def load_scenarios(self): return self.expand_training_scenarios() def get_scenario_databases_options(self, scenario): db_opts = [] opts_update = scenario.get("databases_options", {}) databases = scenario.get("databases") if not isinstance(databases, list): databases = [databases] for db_name in databases: db_opt = self.data_handler.update_database(opts_update, db_name) if db_opt: db_opts.append(db_opt) return db_opts def is_already_trained(self, scenario, models_stats, model_opts): if model_opts.get("skip_trained", False): print("Checking if already trained") tmp = models_stats.loc[ models_stats.opts == str(scenario) # pylint: disable=no-member ] if not tmp.empty: return True return False def train_scenario(self, scenario): try: scenario_info = {} start = time.time() models_stats = None common_utils.print_title( "Training scenario with options: {}".format(scenario) ) model_opts = ModelOptions(copy.deepcopy(scenario)) # * Load model stats database models_stats_path = ( model_opts.model_dir / config_utils.MODELS_STATS_FILE_NAME ) if models_stats_path.exists(): models_stats = pd.read_csv(models_stats_path) # * Check if model has already been trained if models_stats is not None and self.is_already_trained( scenario, models_stats, model_opts ): common_utils.print_info( "Training for the model has already been completed and 'skip_trained' is True. Skipping scenario" ) return 2 # * Get model instance if model_opts.get("transfer_learning", False): model = self.load_model(model_opts) model_opts = model.opts else: model = self.get_model_instance(model_opts) if not model.check_options(): common_utils.print_error( "Skipping training because options are invalid" ) return -1 # * Check datasets databases = self.get_scenario_databases_options(scenario) self.data_handler.check_datasets( databases=databases, db_types=self.DB_TYPES ) # * Prepare data for training (e.g. preprocessing) data = [ self.data_handler.load_datasets( db_type, databases=databases, prepare=True, prepare_opts=model_opts ) for db_type in self.DB_TYPES ] # * Save databases options for training model.save_options("databases.yaml", databases) train_start = time.time() # * Perform training train_stats = model.train(*data) model.save_model() end = time.time() # * Save model training information scenario_info["date"] = datetime.now().strftime("%d/%m/%Y %H:%M:%S") scenario_info["global_duration"] = round(end - start, 2) scenario_info["training_duration"] = round(end - train_start, 2) scenario_info["n_epochs"] = scenario["n_epochs"] scenario_info["model_id"] = model_opts.model_id scenario_info["version"] = model_opts.save_version scenario_info["n_parameters"] = model.n_parameters scenario_info["opts"] = str(scenario) scenario_info.update(train_stats) df = pd.DataFrame([scenario_info]) if models_stats is not None: models_stats = pd.concat([models_stats, df]) else: models_stats = df file_utils.ensure_path_exists(models_stats_path, is_file=True) models_stats.to_csv(models_stats_path, index=False) return 1 except Exception: print(traceback.format_exc()) common_utils.print_error( "Error training the model for scenario {}".format(scenario) ) def train(self): if not self.data_handler: raise AttributeError( "An instance of class DataHandler must be provided in data_handler" + "attribute or at class initialisation" ) return [self.train_scenario(scenario) for scenario in self.scenarios]