"""Source code for ``mouffet.data.data_loader``: the base DataLoader class used to load raw data into datasets."""

import pickle
import traceback
from pathlib import Path
import feather

import pandas as pd

from ..utils import common_utils


class DataLoader:
    """Basic class for loading raw data into the dataset.

    Two workflows are supported:

    * **Generation** - :meth:`generate_dataset` builds a dataset from raw
      files (presumably invoked from
      :meth:`.data_handler.DataHandler.generate_datasets`). It calls
      :meth:`dataset_options` once, then :meth:`load_file_data` for every
      file, and finally :meth:`finalize_dataset`. The default
      :meth:`dataset_options` and :meth:`load_file_data` only print a
      warning, so they should be overridden in subclasses.
    * **Loading** - :meth:`load_dataset` reads previously saved dataset files
      back into :attr:`data`. Each loaded object can optionally be
      post-processed by an ``"onload"`` callback registered in
      :attr:`CALLBACKS`.
    """

    # Optional hooks. ``CALLBACKS["onload"][file_type]`` is called with the
    # object loaded for that file type in load_dataset(), and its return
    # value is stored instead of the loaded object.
    CALLBACKS = {}

    def __init__(self, structure):
        # Stored as-is (not copied). get_file_types() uses its keys as the
        # available file types, and load_dataset() fills its values in place.
        self.data = structure

    def dataset_options(self, *args, **kwargs):
        """Return the options used to load the files of a dataset.

        :meth:`generate_dataset` calls this once with the database
        configuration and forwards the result to :meth:`load_file_data` as
        ``opts``. This default implementation only prints a warning;
        subclasses should override it.

        Returns:
            dict: Loading options (empty by default).
        """
        common_utils.print_warning(
            (
                "WARNING! Calling dataset_options() method from the default DataLoader class which "
                + "does nothing. Please inherit this class and override this method for loading "
                + "any options relevant to the loading of the files of the dataset."
            )
        )
        return {}

    def load_file_data(self, *args, **kwargs):
        """Load data for a single file of the dataset.

        :meth:`generate_dataset` calls this with the keyword arguments
        ``file_path``, ``tags_dir`` and ``opts``. Overrides usually load the
        raw data and the tags associated with the file and accumulate them on
        the instance. NOTE(review): confirm which attribute subclasses are
        expected to fill. The return value is treated as the file's
        intermediate result and is pickled when
        ``database.save_intermediates`` is set.

        Args:
            file_path (pathlib.Path): Path of the file to load.
            tags_dir: Tags directory for this dataset type, taken from
                ``paths["tags"]``.
            opts (dict): Options returned by :meth:`dataset_options`.

        Returns:
            tuple: ``(data, tags)``. Both are empty lists in this default
            implementation, which only prints a warning.
        """
        common_utils.print_warning(
            (
                "Calling load_file_data() method from the default DataLoader class which "
                + "does nothing. Please inherit this class and override this method for loading "
                + "the files and tags of the dataset."
            )
        )
        data, tags = [], []
        return data, tags

    def finalize_dataset(self):
        """Hook called once all files have been loaded by :meth:`generate_dataset`.

        Override it when further work is needed after every file has been
        loaded but before the dataset is saved (e.g. dataframe
        concatenation). The default implementation does nothing.
        """

    def load_classes(self, database):
        """Return the lower-cased tags of the classes of the database's class type.

        Args:
            database: Database configuration exposing ``class_type`` and
                ``classes_file``. ``classes_file`` is a CSV file with at
                least ``class_type`` and ``tag`` columns.

        Returns:
            The ``Series.values`` of the lower-cased ``tag`` column, for rows
            whose ``class_type`` equals ``database.class_type``.
        """
        class_type = database.class_type
        classes_df = pd.read_csv(database.classes_file, skip_blank_lines=True)
        is_match = (
            classes_df["class_type"]  # pylint: disable=unsubscriptable-object
            == class_type
        )
        return classes_df.loc[is_match].tag.str.lower().values

    def generate_dataset(self, database, paths, file_list, db_type, overwrite):
        """Load every file of ``file_list`` to build the dataset of type ``db_type``.

        :meth:`load_file_data` is called for each file. When
        ``database.save_intermediates`` is true, its return value is pickled
        to ``paths["dest"][db_type] / "intermediate" / "<file stem>.pkl"``.
        That directory is created if needed. Files that raise are reported
        and skipped. :meth:`finalize_dataset` is called at the end.

        Args:
            database: Database configuration. It must support
                ``get("split", default)`` and expose ``save_intermediates``.
            paths (dict): Path configuration. ``paths["tags"]`` maps a dataset
                type (or ``"training"``) to its tags directory.
                ``paths["dest"]`` maps a dataset type to a
                :class:`pathlib.Path` destination.
            file_list (iterable): Paths (``str`` or :class:`pathlib.Path`) of
                the files to load.
            db_type (str): Type of the dataset being generated.
            overwrite (bool): Whether to overwrite intermediate files that
                already exist.
        """
        db_opts = self.dataset_options(database)
        split = database.get("split", {})
        # Dataset types listed in the "split" option use the training tags
        # (presumably because they are split off the training data).
        tags_key = "training" if split and db_type in split else db_type
        tags_dir = paths["tags"][tags_key]
        for file_path in file_list:
            try:
                if not isinstance(file_path, Path):
                    file_path = Path(file_path)
                intermediate = self.load_file_data(
                    file_path=file_path, tags_dir=tags_dir, opts=db_opts
                )
                if database.save_intermediates:
                    savename = (
                        paths["dest"][db_type] / "intermediate" / file_path.name
                    ).with_suffix(".pkl")
                    if overwrite or not savename.exists():
                        # Create the directory, otherwise open() fails and a
                        # successfully loaded file is reported as an error.
                        savename.parent.mkdir(parents=True, exist_ok=True)
                        with open(savename, "wb") as f:
                            pickle.dump(intermediate, f, -1)
            except Exception:  # pylint: disable=broad-except
                # Best effort: report the failing file and go on with the next.
                print("Error loading: " + str(file_path) + ", skipping.")
                print(traceback.format_exc())
                # NOTE(review): this discards everything accumulated in
                # self.data so far, and later files may then fail as well.
                # Confirm this is intended.
                self.data = None
        self.finalize_dataset()

    def get_file_types(self, load_opts):
        """Return the file types (keys of :attr:`data`) to load.

        Args:
            load_opts (dict): Loading options. Its ``"file_types"`` entry may
                be ``"all"`` (the default), a single file type, or a list of
                file types.

        Returns:
            list: The requested file types that are keys of :attr:`data`.
            Unknown file types are silently dropped.
        """
        file_types = load_opts.get("file_types", "all")
        if file_types == "all":
            return list(self.data)
        if isinstance(file_types, str):
            file_types = [file_types]
        # Make sure we only keep valid keys.
        return [ft for ft in file_types if ft in self.data]

    def load_dataset(self, paths, db_type, load_opts=None):
        """Load the saved dataset files of type ``db_type`` into :attr:`data`.

        Args:
            paths (dict): Path configuration.
                ``paths["save_dests"][db_type][file_type]`` must be the
                :class:`pathlib.Path` of the saved file for each file type.
            db_type (str): Type of the dataset to load.
            load_opts (dict, optional): Loading options; see
                :meth:`get_file_types`.

        Raises:
            ValueError: If a saved dataset file does not exist.
        """
        load_opts = load_opts or {}
        onload_callbacks = self.CALLBACKS.get("onload", {})
        for key in self.get_file_types(load_opts):
            path = paths["save_dests"][db_type][key]
            if not path.exists():
                raise ValueError(
                    "Database file {} not found. Please run check_datasets() before".format(
                        str(path)
                    )
                )
            tmp = self.load_dataset_file(path)
            callback = onload_callbacks.get(key, None)
            if callback:
                tmp = callback(tmp)
            self.data[key] = tmp

    def load_dataset_file(self, file_name):
        """Load a single saved dataset file.

        ``.feather`` files are read as DataFrames, and a warning is printed
        if the result is empty. Any other file is unpickled.

        Args:
            file_name (pathlib.Path): Path of the file to load.

        Returns:
            The loaded DataFrame or the unpickled object.
        """
        print("Loading file: ", file_name)
        if file_name.suffix == ".feather":
            df = feather.read_dataframe(str(file_name))
            if df.empty:
                common_utils.print_warning(
                    "Warning, loaded dataset file {} is empty".format(file_name)
                )
            return df
        # Use a context manager so the file handle is always closed.
        # Security: unpickling can execute arbitrary code; only load trusted files.
        with open(file_name, "rb") as f:
            return pickle.load(f)