import time
import traceback
from abc import abstractmethod
from datetime import datetime
from itertools import product
from pathlib import Path
import feather
import pandas as pd
from ..options import ModelOptions
from ..plotting import plot
from ..utils import ModelHandler, common_utils, file_utils
from . import EVALUATORS
class EvaluationHandler(ModelHandler):
    """Base class for evaluating models. Inherits ModelHandler.

    An evaluation is a list of scenarios, each scenario being one combination of a
    database, a model and an evaluator (see `load_scenarios`). For each scenario the
    predictions of the model on the database are loaded (or computed and cached) and
    handed to the evaluator; the results of all scenarios are then consolidated and
    optionally saved.

    Relevant options:
        <Global>
            data_config: Path to the data configuration file used to initialize the
                data handler.
            evaluation_dir: Directory where to save results. Defaults to ".".
            id: Optional identifier inserted in the names of the saved files.
            save_results: Save the consolidated results. Defaults to True.
            save_use_date_subfolder: Save results in a "YYYYMMDD" subfolder of
                evaluation_dir. Defaults to True.
            save_use_time_prefix: Prefix saved files with "HHMMSS". Defaults to True.
            draw_global_plots: Draw the plots listed in `global_plots`. Defaults to False.
            global_plots: Names of the global plots; each name `x` is drawn by the
                method `plot_x` of the handler when it exists.
            PR_curve_save_file: Name of the file accumulating the PR curve data.
                Defaults to "PR_curves.feather".
        <Models>
            model_dir: Directory where to load models.
            predictions_dir: Directory where to load/save predictions.
            repredict: Run the model again even if a prediction file is found.
        <Evaluators>
            type: The type of evaluator, corresponding to one of the keys of the
                EVALUATORS registry imported by this module.
            databases: If not empty, only these databases are evaluated.
            exclude_databases: Databases that must not be evaluated.
            do_PR_curve: Flags the resulting stats as precision/recall curve data.
            filter_only: Do not load the tags of the database.
    """

    PREDICTIONS_STATS_FILE_NAME = "predictions_stats.csv"
    # Columns identifying a row of the predictions stats file: a new entry replaces
    # the previous entry having the same values for these columns.
    PREDICTIONS_STATS_DUPLICATE_COLUMNS = ["database", "model_id"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @abstractmethod
    def predict_database(self, model, database, db_type="test"):
        """Get the predictions of a model for an entire database.

        The data to be classified is usually loaded here, since predictions can be
        saved to avoid the reclassification. This avoids loading the data for nothing.
        This function also logs general information about the classification that is
        stored in the infos dict.

        Args:
            model: The model used to predict, as returned by `load_model`.
            database: The database to predict on.
            db_type (str, optional): Type of dataset to predict on. Defaults to "test".

        Returns:
            tuple: a tuple containing an instance of pandas DataFrame with the
            predictions given by the model and a dict with general information about
            the classification process (e.g. number of items processed, time spent
            during the classification, etc.). The dict is merged into the row written
            to the predictions stats file by `get_predictions`.
        """
        preds = pd.DataFrame()
        infos = {}
        return (preds, infos)

    def get_predictions_dir(self, model_opts, database):
        """Return the directory where predictions are loaded from and saved to.

        Args:
            model_opts: Options of the model, used to resolve `predictions_dir`.
            database: The database being predicted. Unused here; available to
                subclasses overriding this method.

        Raises:
            AttributeError: If the `predictions_dir` option is missing or empty.

        Returns:
            Path: The predictions directory.
        """
        preds_dir = self.get_option("predictions_dir", model_opts)
        if not preds_dir:
            raise AttributeError(
                "Please provide a directory where to save the predictions using"
                + " the predictions_dir option in the config file"
            )
        return Path(preds_dir)

    def get_predictions_file_name(self, model_opts, database):
        """Return the file name `<database>_<model_id>_v<load_version>.feather`."""
        return (
            database.name
            + "_"
            + model_opts.model_id
            + "_v"
            + str(model_opts.load_version)
            + ".feather"
        )

    def on_get_predictions_end(self, preds, model_opts):
        """Hook called on the predictions before `get_predictions` returns them.

        Called for cached and freshly computed predictions alike. The default
        implementation returns the predictions unchanged.
        """
        return preds

    def get_predictions(self, model_opts, database):
        """Load the predictions of a model on a database, computing them if needed.

        Predictions are read from the predictions file when it exists and the
        `repredict` model option is not set. Otherwise the model is loaded, the
        database is predicted, a row describing the run is added to the predictions
        stats file and the predictions are saved.

        Args:
            model_opts: Options of the model. `data_config`, `model_dir` and
                `inference` are set on it before loading the model.
            database: The database to predict on.

        Returns:
            The predictions, after going through `on_get_predictions_end`.
        """
        preds_dir = self.get_predictions_dir(model_opts, database)
        pred_file = preds_dir / self.get_predictions_file_name(model_opts, database)
        if not model_opts.get("repredict", False) and pred_file.exists():
            preds = feather.read_dataframe(pred_file)
        else:
            # * Load predictions stats database
            preds_stats = None
            preds_stats_path = (
                Path(self.get_option("predictions_dir", model_opts))
                / self.PREDICTIONS_STATS_FILE_NAME
            )
            if preds_stats_path.exists():
                preds_stats = pd.read_csv(preds_stats_path)

            model_opts.opts["data_config"] = self.opts["data_config"]
            model_opts.opts["model_dir"] = self.get_option("model_dir", model_opts)
            model_opts.opts["inference"] = True
            common_utils.print_info("Loading model with options: " + str(model_opts))
            model = self.load_model(model_opts)
            preds, infos = self.predict_database(model, database, db_type="test")

            # * save classification stats
            # "database" is one of the deduplication keys below, so it must always be
            # present in the row; values reported in infos take precedence.
            scenario_info = {
                "date": datetime.now().strftime("%d/%m/%Y %H:%M:%S"),
                "model_id": model_opts.model_id,
                "database": database.name,
            }
            scenario_info.update(infos)
            new_stats = pd.DataFrame([scenario_info])
            if preds_stats is not None:
                # Keep only the most recent entry for each database/model pair
                preds_stats = pd.concat([preds_stats, new_stats]).drop_duplicates(
                    subset=self.PREDICTIONS_STATS_DUPLICATE_COLUMNS, keep="last"
                )
            else:
                preds_stats = new_stats
            file_utils.ensure_path_exists(preds_stats_path, is_file=True)
            preds_stats.to_csv(preds_stats_path, index=False)

            pred_file.parent.mkdir(parents=True, exist_ok=True)
            feather.write_dataframe(preds, pred_file)
        return self.on_get_predictions_end(preds, model_opts)

    def consolidate_results(self, results):
        """Merge the results of all scenarios into a single result.

        Args:
            results (list): One result dict per scenario.

        Returns:
            The results regrouped by key, where the "matches" and "stats" entries are
            concatenated into a single DataFrame and the "plots" entries are regrouped
            by plot name.
        """
        res = common_utils.listdict2dictlist(results)
        if res:
            for key in ("matches", "stats"):
                if key in res:
                    res[key] = pd.concat(res[key])
            if "plots" in res:
                res["plots"] = common_utils.listdict2dictlist(
                    res["plots"], flatten=True
                )
        return res

    def save_pr_curve_data(self, pr_df):
        """Add precision/recall curve data to the PR curves file.

        The file is `PR_curve_save_file` inside `evaluation_dir`. When it already
        exists its content is merged with the new data and exact duplicate rows are
        dropped.

        Args:
            pr_df (pd.DataFrame): The stats flagged as PR curve data.
        """
        print("saving_pr_curve data")
        res_dir = Path(self.opts.get("evaluation_dir", "."))
        pr_file = res_dir / self.opts.get("PR_curve_save_file", "PR_curves.feather")
        if pr_file.exists():
            pr_curves = pd.read_feather(pr_file)
            res = pd.concat([pr_df, pr_curves]).drop_duplicates()
        else:
            res = pr_df
        # Not done in place: pr_df belongs to the caller and may be a slice of a
        # larger DataFrame.
        res = res.reset_index(drop=True)
        pr_file.parent.mkdir(parents=True, exist_ok=True)
        res.to_feather(pr_file)

    def check_plotting_package(self, key):
        """Activate the plotting package configured for a plot.

        The package is read from `plot_options.<key>.package`. When no package is
        configured for the plot, the default plotting package is restored.

        Args:
            key (str): Name of the plot.
        """
        plt_pkg = self.opts.get("plot_options", {}).get(key, {}).get("package", "")
        if plt_pkg:
            # Nothing to do when the requested package is already the active one
            if plt_pkg != plot.pkg:  # pylint: disable=no-member
                plot.set_plotting_package(pkg=plt_pkg)  # pylint: disable=no-member
        elif plot.pkg != plot.DEFAULT_PLOTTING_PACKAGE:  # pylint: disable=no-member
            # presumably restores the default package when called without argument
            plot.set_plotting_package()  # pylint: disable=no-member

    def save_results(self, res):
        """Save the stats and the plots of consolidated results.

        Files are saved in `evaluation_dir`, in a "YYYYMMDD" subfolder unless
        `save_use_date_subfolder` is False. File names are made of the optional time
        prefix, the optional evaluation id and the name of the saved element joined
        by "_".

        Args:
            res (dict): Consolidated results. The "stats" entry is saved as csv and
                the "plots" and "global_plots" entries as pdf. Missing entries are
                skipped.

        Returns:
            dict: Paths of the saved files. The stats path is a str stored under
            "stats"; plot paths are Path objects stored under "plot_<name>" or
            "global_plot_<name>".
        """
        file_names = {}
        if not res:
            return file_names

        cur_time = datetime.now()
        res_dir = Path(self.opts.get("evaluation_dir", "."))
        if self.opts.get("save_use_date_subfolder", True):
            res_dir /= cur_time.strftime("%Y%m%d")
        prefix = ""
        if self.opts.get("save_use_time_prefix", True):
            prefix = cur_time.strftime("%H%M%S")
        eval_id = self.opts.get("id", "")

        def build_path(suffix):
            # Empty prefix or id are left out of the file name
            return res_dir / "_".join(filter(None, [prefix, eval_id, suffix]))

        # Stats are missing when no scenario produced any result
        stats = res.get("stats")
        if stats is not None:
            stats_file_path = str(
                file_utils.ensure_path_exists(build_path("stats.csv"), is_file=True)
            )
            stats.to_csv(stats_file_path, index=False)
            file_names["stats"] = stats_file_path
            if "PR_curve" in stats.columns:
                pr_df = stats.loc[
                    stats["PR_curve"] == True  # pylint: disable=singleton-comparison
                ]
                if not pr_df.empty:
                    self.save_pr_curve_data(pr_df)

        for plot_type in ["", "global_"]:
            plots = res.get(plot_type + "plots", {})
            if not plots:
                continue
            for key, values in plots.items():
                self.check_plotting_package(key)
                if not values:
                    continue
                plot_file_path = build_path("{}.pdf".format(key))
                # The directory is not created yet when there are no stats to save
                plot_file_path.parent.mkdir(parents=True, exist_ok=True)
                plot.save_as_pdf(
                    values,
                    plot_file_path,
                )
                file_names[plot_type + "plot_" + key] = plot_file_path
        return file_names

    def draw_global_plots(self, results):
        """Draw the plots listed in the `global_plots` option.

        For each name `x` in the list, the method `plot_x` of the handler is called
        with the consolidated results. Names without a matching callable are ignored.

        Args:
            results: The consolidated results.

        Returns:
            dict: The value returned by each plotting method, by plot name.
        """
        plts = {}
        for to_plot in self.opts.get("global_plots", []):
            plot_func = getattr(self, "plot_" + to_plot.strip(), None)
            if callable(plot_func):
                plts[to_plot] = plot_func(results)
        return plts

    def expand_scenarios(self, element_type):
        """Expand the elements of an option into a list of scenarios.

        Each element of the `<element_type>` option is merged on top of the defaults
        found in `<element_type>_options`. If the result has a "scenarios" entry, it is
        removed and expanded with `common_utils.expand_options_dict`, giving one
        scenario per expanded set of options.

        Args:
            element_type (str): Name of the option to expand, e.g. "models".

        Returns:
            list: One options dict per scenario. Empty if the option is missing.
        """
        elements = self.opts[element_type] if element_type in self.opts else []
        default = self.opts.get(element_type + "_options", {})
        scenarios = []
        for element in elements:
            # * Add default options to scenario
            tmp = common_utils.deep_dict_update(default, element, copy=True)
            if "scenarios" in tmp:
                scenario = tmp.pop("scenarios")
                # * Add expanded options
                scenarios.extend(
                    common_utils.deep_dict_update(tmp, opts, copy=True)
                    for opts in common_utils.expand_options_dict(scenario)
                )
            else:
                scenarios.append(tmp)
        return scenarios

    def get_models_by_id(self):
        """Placeholder for selecting models with the `model_ids` option.

        Not implemented: the option is ignored and an empty list is always returned.
        """
        # TODO: resolve the dict and str entries of the "model_ids" option
        return []

    def get_model_scenarios(self):
        """Return the scenarios expanded from the `models` option."""
        return self.expand_scenarios("models")

    def load_scenarios(self):
        """Return all (database, model, evaluator) combinations to evaluate.

        Returns:
            list: Tuples of (database options, model options, evaluator options), as
            expected by `evaluate_scenario`.
        """
        db_scenarios = self.expand_scenarios("databases")
        model_scenarios = self.get_model_scenarios()
        evaluator_scenarios = self.expand_scenarios("evaluators")
        return list(product(db_scenarios, model_scenarios, evaluator_scenarios))

    def load_tags(self, database, types):
        """Load the given file types of the test dataset of a database."""
        # NOTE(review): get_evaluation_data calls load_dataset with ("test", database)
        # while this method passes (database, "test") - confirm which order the data
        # handler expects.
        return self.data_handler.load_dataset(
            database, "test", load_opts={"file_types": types}
        )

    def add_global_options(self, opts):
        """Add the global `models_options` and `databases_options` to some options.

        Options already defined in `opts` are not overwritten.

        Args:
            opts: The options to complete. Modified through `add_option`.

        Returns:
            The completed options.
        """
        for key in ("models_options", "databases_options"):
            if key in self.opts:
                opts.add_option(key, self.opts[key], overwrite=False)
        return opts

    def skip_database(self, db, evaluator_opts):
        """Check if a database must be skipped for an evaluator.

        Args:
            db (str): Name of the database.
            evaluator_opts (dict): Options of the evaluator.

        Returns:
            bool: True if the `databases` option of the evaluator is not empty and
            does not contain the database, or if the database is listed in the
            `exclude_databases` option. False otherwise.
        """
        include = evaluator_opts.get("databases", [])
        if include and db not in include:
            common_utils.print_info(
                (
                    "Database {} is not in the accepted databases list of evaluator {}. "
                    + "Skipping."
                ).format(db, evaluator_opts["type"])
            )
            return True
        exclude = evaluator_opts.get("exclude_databases", [])
        if exclude and db in exclude:
            common_utils.print_info(
                "Database {} is in the excluded databases of evaluator {}. Skipping.".format(
                    db, evaluator_opts["type"]
                )
            )
            return True
        return False

    def perform_evaluation(
        self, evaluator, evaluation_data, scenario_infos, scenario_opts
    ):
        """Run an evaluator on a scenario and describe the scenario in its stats.

        Args:
            evaluator: The evaluator to run.
            evaluation_data: The data to evaluate, see `get_evaluation_data`.
            scenario_infos (dict): Description of the scenario, added as columns of
                the stats.
            scenario_opts (dict): Options of the scenario. Each entry is added to the
                stats as a column holding its string representation.

        Returns:
            The result of the evaluator. When it is not empty, its "stats" entry also
            holds the scenario description, the scenario options, a "PR_curve" flag
            (`do_PR_curve` evaluator option) and the "duration" of the evaluation in
            seconds.
        """
        print(
            "\033[92m"
            + "Processing model {0} on dataset {1} with evaluator {2}".format(
                scenario_infos["model"],
                scenario_infos["database"],
                scenario_infos["evaluator"],
            )
            + "\033[0m"
        )
        start = time.time()
        eval_result = evaluator.run_evaluation(
            evaluation_data, scenario_opts["evaluator_opts"], scenario_infos
        )
        end = time.time()
        if eval_result:
            # NOTE(review): "stats" is used as a DataFrame below (.assign), so the
            # dict default only avoids a KeyError here - confirm that evaluators
            # always return a DataFrame under "stats".
            stats = eval_result.get("stats", {})
            stats["PR_curve"] = scenario_opts["evaluator_opts"].get(
                "do_PR_curve", False
            )
            stats["duration"] = round(end - start, 2)
            eval_result["stats"] = stats
            opts_columns = {key: str(value) for key, value in scenario_opts.items()}
            # Columns are aligned on the index: the scenario description is a single
            # row with index 0.
            eval_result["stats"] = pd.concat(
                [pd.DataFrame([scenario_infos]), stats.assign(**opts_columns)],
                axis=1,
            )
        return eval_result

    def get_evaluation_data(self, evaluator, database, model_opts, evaluator_opts):
        """Load the data needed by an evaluator for a scenario.

        Args:
            evaluator: The evaluator; its `requires` method gives the file types to load.
            database: The database to evaluate.
            model_opts: Options of the model to evaluate.
            evaluator_opts (dict): Options of the evaluator.

        Returns:
            tuple: The predictions of the model and the tags of the test dataset.
            Tags are None when the `filter_only` evaluator option is set.
        """
        eval_requires = evaluator.requires(evaluator_opts)
        database.check_dataset("test", file_types=eval_requires)
        preds = self.get_predictions(model_opts, database)
        if evaluator_opts.get("filter_only", False):
            tags = None
        else:
            # NOTE(review): load_tags calls load_dataset with (database, "test") while
            # this call passes ("test", database) - confirm which order the data
            # handler expects.
            tags = self.data_handler.load_dataset(
                "test",
                database,
                load_opts={"file_types": eval_requires},
            )
        return preds, tags

    def evaluate_scenario(self, opts):
        """Evaluate one scenario.

        Any exception raised during the evaluation is printed and turned into an
        empty result, so that one failing scenario does not stop the others.

        Args:
            opts (tuple): The (database options, model options, evaluator options) of
                the scenario, see `load_scenarios`.

        Returns:
            dict: The result of the evaluation. Empty if the scenario is skipped, if
            the database is unknown or has no test dataset, if no evaluator is found
            or if an error occurred.
        """
        try:
            db_opts, model_opts, evaluator_opts = opts
            if self.skip_database(db_opts["name"], evaluator_opts):
                return {}
            model_opts = ModelOptions(model_opts)
            # * Add global option to model options for id resolution
            model_opts = self.add_global_options(model_opts)
            if "databases_options" in model_opts:
                # Work on a copy: the same db_opts dict is shared by all the scenarios
                # using this database, it must not keep the options of this model.
                db_opts = common_utils.deep_dict_update(
                    db_opts, model_opts.databases_options, copy=True
                )
            # * Duplicate database options
            try:
                database = self.data_handler.duplicate_database(db_opts)
            except KeyError:
                common_utils.print_error(
                    (
                        "Database '{}' does not exists. Please check that the "
                        + " database is properly defined in the data configuration file"
                    ).format(db_opts["name"])
                )
                return {}

            eval_result = {}
            if database and database.has_type("test"):
                scenario_infos = {
                    "database": database.name,
                    "model": model_opts.model_id,
                    "class": database.class_type,
                    "evaluator": evaluator_opts.get("type", None),
                    "evaluation_id": self.opts.get("id", ""),
                }
                scenario_opts = {
                    "evaluator_opts": evaluator_opts,
                    "database_opts": database.updated_opts,
                    "model_opts": model_opts,
                }
                evaluator_opts["scenario_info"] = scenario_infos
                evaluator = EVALUATORS[evaluator_opts.get("type", None)]
                if evaluator:
                    evaluation_data = self.get_evaluation_data(
                        evaluator, database, model_opts, evaluator_opts
                    )
                    eval_result = self.perform_evaluation(
                        evaluator, evaluation_data, scenario_infos, scenario_opts
                    )
            return eval_result
        except Exception:  # pylint: disable=broad-except
            print(traceback.format_exc())
            common_utils.print_error(
                "Error evaluating the model for scenario {}".format(opts)
            )
            return {}

    def evaluate(self):
        """Evaluate all the scenarios and save the results.

        Global plots are drawn if the `draw_global_plots` option is set and results
        are saved unless the `save_results` option is False.

        Returns:
            The consolidated results of all scenarios, or an empty list if there is no
            scenario to evaluate.
        """
        if not self.scenarios:
            common_utils.print_warning("No scenarios found for this evaluator")
            return []
        res = [self.evaluate_scenario(scenario) for scenario in self.scenarios]
        results = self.consolidate_results(res)
        if self.opts.get("draw_global_plots", False):
            results["global_plots"] = self.draw_global_plots(results)
        if self.opts.get("save_results", True):
            self.save_results(results)
        return results