# Source code for mouffet.evaluation.evaluator

from abc import ABC, abstractmethod

import pandas as pd

from ..utils import common_utils
from ..plotting import plot


class Evaluator(ABC):
    """Base class for evaluators.

    Subclasses implement :meth:`evaluate`. :meth:`run_evaluation` is the entry
    point: depending on ``options`` it filters predictions, computes a
    precision/recall curve over several scenarios, or evaluates one scenario.
    """

    # Identifier of the evaluator. Used to build the "<NAME>_databases" option
    # key read by check_database, and in its log message.
    NAME = ""

    # Default sweep description for precision/recall curves: the option to vary
    # and the start/end/step of its values. Not read by the methods of this
    # base class; presumably consumed by callers or subclasses (verify).
    DEFAULT_PR_CURVE_OPTIONS = {
        "variable": "activity_threshold",
        "values": {"start": 0, "end": 1, "step": 0.05},
    }

    # Maps a plot name (as listed in options["plots"]) to a callable invoked as
    # func(data, options, infos) by draw_plots. Shared by every subclass that
    # does not override it, so override rather than mutate in place.
    PLOTS = {}

    # Returned as-is by requires(). Same sharing caveat as PLOTS.
    REQUIRES = []

    def requires(self, options):
        """Return the requirements of this evaluator.

        The base implementation ignores ``options`` and returns the class-level
        ``REQUIRES`` list. Subclasses may override it to depend on options.
        """
        return self.REQUIRES

    def run_evaluation(self, data, options, infos):
        """Run the evaluation mode selected by ``options``.

        - ``options["filter_only"]`` truthy: ``data`` is unpacked as a 2-tuple
          whose first item is the predictions. Returns
          ``{"events": self.filter_predictions(...), "stats": {}}``.
        - otherwise, ``options["do_PR_curve"]`` truthy: returns
          ``self.get_PR_curve(data, options, infos)``.
        - otherwise: returns ``self.evaluate_scenario(data, options, infos)``.
        """
        if options.get("filter_only", False):
            predictions, _ = data
            return {
                "events": self.filter_predictions(predictions, options),
                "stats": {},
            }
        if options.get("do_PR_curve", False):
            return self.get_PR_curve(data, options, infos)
        return self.evaluate_scenario(data, options, infos)

    def evaluate_scenario(self, data, options, infos):
        """Evaluate a single scenario.

        The base implementation simply delegates to :meth:`evaluate`.
        Subclasses may override it to add per-scenario pre/post-processing.
        """
        return self.evaluate(data, options, infos)

    @abstractmethod
    def evaluate(self, data, options, infos):
        """Evaluate ``data`` with the given ``options``.

        Must be implemented by subclasses. The returned dict is expected to
        contain at least "stats" and "matches": get_PR_curve concatenates both
        across scenarios with ``pandas.concat``.
        """
        return {"stats": None, "matches": None}

    def get_PR_scenarios(self, options):
        """Build one options dict per point of the precision/recall sweep.

        Each entry produced by
        ``common_utils.expand_options_dict(options["scenarios_PR_curve"])`` is
        merged into ``options`` with
        ``common_utils.deep_dict_update(options, scenario, copy=True)``.
        A warning is printed when "scenarios_PR_curve" is missing or empty.

        Returns:
            list: one options dict per expanded scenario.
        """
        pr_scenarios = options.get("scenarios_PR_curve", {})
        if not pr_scenarios:
            common_utils.print_warning(
                "do_PR_curve is set to True but no option for scenarios_PR_curve has been found."
            )
        return [
            common_utils.deep_dict_update(options, scenario, copy=True)
            for scenario in common_utils.expand_options_dict(pr_scenarios)
        ]

    def get_PR_curve(self, data, options, infos):
        """Evaluate every PR scenario and aggregate the results.

        The per-scenario result dicts are regrouped into a dict of lists with
        ``common_utils.listdict2dictlist``. Then:

        - "matches" and "stats" are concatenated with ``pandas.concat``. Each
          scenario is therefore assumed to return pandas objects for those keys.
        - "plots" is regrouped the same way.

        When ``options["draw_plots"]`` is truthy (default True), the aggregate
        is passed through ``plot.plot_PR_curve`` and that function's return
        value is returned instead.
        """
        scenarios = self.get_PR_scenarios(options)
        results = [
            self.evaluate_scenario(data, scenario, infos) for scenario in scenarios
        ]
        res = common_utils.listdict2dictlist(results)
        res["matches"] = pd.concat(res["matches"])
        res["stats"] = pd.concat(res["stats"])
        res["plots"] = common_utils.listdict2dictlist(res.get("plots", []))
        if options.get("draw_plots", True):
            res = plot.plot_PR_curve(res, options)  # pylint: disable=no-member
        return res

    def draw_plots(self, data, options, infos):
        """Draw the plots named in ``options["plots"]``.

        Each name is looked up in :attr:`PLOTS`; names with no registered
        callable are skipped silently.

        Returns:
            dict: plot name -> value returned by its plotting callable.
        """
        res = {}
        for to_plot in options.get("plots", []):
            func = self.PLOTS.get(to_plot, None)
            if func is not None:
                res[to_plot] = func(data, options, infos)
        return res

    def filter_predictions(self, predictions, options, tags=None):
        """Filter ``predictions`` into events.

        The base implementation returns an empty list; subclasses override it.
        """
        return []

    def check_database(self, data, options, infos):
        """Tell whether ``infos["database"]`` is accepted by this evaluator.

        Accepted databases are listed in ``options["<NAME>_databases"]``
        (treated as empty when absent). ``data`` is unused.

        Returns:
            bool: True if accepted. Otherwise an info message is printed and
            False is returned.
        """
        database = infos["database"]
        if database in options.get(self.NAME + "_databases", []):
            return True
        # Report the database that was actually checked. Reading the name from
        # a separate options lookup could name a different database, and
        # crashes with KeyError when that key is absent.
        common_utils.print_info(
            (
                "Database {0} is not part of the accepted databases for the '{1}' "
                + "evaluator described in the '{1}_databases' option. Skipping."
            ).format(database, self.NAME)
        )
        return False