import pickle
import traceback
from pathlib import Path
import feather
import pandas as pd
from ..utils import common_utils
class DataLoader:
    """Basic class for loading raw data into the dataset.

    :meth:`generate_dataset` iterates over a list of raw files and relies on
    two hooks, :meth:`dataset_options` and :meth:`load_file_data`, whose
    default implementations only emit a warning and return empty values.
    Subclasses are expected to override both. :meth:`finalize_dataset` is an
    optional hook called once every file has been processed.

    :meth:`load_dataset` reads previously saved dataset files back into
    :attr:`data`, optionally post-processing each one with a callback
    registered in :attr:`CALLBACKS`.

    NOTE(review): this class is presumably driven by
    :class:`.data_handler.DataHandler`, which is not visible from this
    module -- confirm which entry points it actually calls.
    """

    # Optional hooks, looked up as CALLBACKS["onload"][<data key>] and applied
    # to each file right after it is loaded by load_dataset().
    CALLBACKS = {}

    def __init__(self, structure):
        """Store ``structure`` as the container for loaded data.

        Args:
            structure: Mapping whose keys are the known data keys. Its values
                are replaced by the loaded content in :meth:`load_dataset`.
        """
        self.data = structure

    def dataset_options(self, *args, **kwargs):
        """Return the options used while loading the files of a dataset.

        The default implementation only warns and returns an empty dict.
        :meth:`generate_dataset` calls it with the database as its single
        positional argument and forwards the result to
        :meth:`load_file_data` as ``opts``.

        Returns:
            dict: Empty by default.
        """
        # The warning names the method that is really called so that users
        # know which one to override.
        common_utils.print_warning(
            "WARNING! Calling dataset_options() method from the default DataLoader class which "
            "does nothing. Please inherit this class and override this method for loading "
            "any options relevant to the loading of the files of the dataset."
        )
        return {}

    def load_file_data(self, *args, **kwargs):
        """Load the data associated with a single file.

        The default implementation only warns and returns two empty lists.
        :meth:`generate_dataset` calls it with keyword arguments only:

        Args:
            file_path (pathlib.Path): File to load.
            tags_dir: Entry of ``paths["tags"]`` selected for the dataset.
            opts: Value returned by :meth:`dataset_options`.
            missing: Value forwarded unchanged from :meth:`generate_dataset`.

        Returns:
            tuple: ``(data, tags)``, two empty lists by default. The returned
            value is what gets pickled when intermediates are saved.
        """
        common_utils.print_warning(
            "Calling load_file_data() method from the default DataLoader class which "
            "does nothing. Please inherit this class and override this method for loading "
            "the files and tags of the dataset."
        )
        data, tags = [], []
        return data, tags

    def finalize_dataset(self, missing):
        """Callback function called after data generation is finished.

        It runs once all files have been processed, in case some further
        action must be done at that point (e.g. dataframe concatenation).
        The default implementation does nothing.

        Args:
            missing: Value forwarded unchanged from :meth:`generate_dataset`.
        """

    def generate_dataset(
        self, database, paths, file_list, db_type, missing=None, overwrite=False
    ):
        """Load every file of ``file_list`` and optionally save intermediates.

        A file that raises during loading or saving is reported on stdout
        with its traceback and skipped; the remaining files are still
        processed. :meth:`finalize_dataset` is always called at the end.

        Args:
            database: Database description. Must support
                ``database.get("split", {})`` and expose a
                ``save_intermediates`` attribute.
            paths: Nested mapping providing ``paths["tags"][...]`` and, when
                intermediates are saved, ``paths["dest"][db_type]`` (used as
                a :class:`pathlib.Path`).
            file_list: Iterable of file paths (``str`` or ``Path``).
            db_type: Name of the dataset type being generated.
            missing: Opaque value forwarded to :meth:`load_file_data` and
                :meth:`finalize_dataset`.
            overwrite (bool): Rewrite intermediate files that already exist.
        """
        db_opts = self.dataset_options(database)

        # Dataset types listed in the "split" option read their tags from the
        # training tags directory (presumably because they are carved out of
        # the training data -- confirm against the caller).
        split = database.get("split", {})
        if split and db_type in split:
            tags_dir = paths["tags"]["training"]
        else:
            tags_dir = paths["tags"][db_type]

        for file_path in file_list:
            try:
                if not isinstance(file_path, Path):
                    file_path = Path(file_path)
                intermediate = self.load_file_data(
                    file_path=file_path,
                    tags_dir=tags_dir,
                    opts=db_opts,
                    missing=missing,
                )
                if database.save_intermediates:
                    savename = (
                        paths["dest"][db_type] / "intermediate" / file_path.name
                    ).with_suffix(".pkl")
                    if overwrite or not savename.exists():
                        with open(savename, "wb") as f:
                            # -1 selects the highest pickle protocol available.
                            pickle.dump(intermediate, f, -1)
            except Exception:
                # Deliberate best effort: one bad file must not abort the
                # whole generation, so report it and move on.
                print("Error loading: " + str(file_path) + ", skipping.")
                print(traceback.format_exc())

        self.finalize_dataset(missing)

    def get_file_types(self, load_opts):
        """Return the data keys that should be loaded.

        Args:
            load_opts (dict): Loading options. ``load_opts["file_types"]`` is
                either ``"all"`` (the default when absent), a single key, or
                an iterable of keys.

        Returns:
            list: Requested keys that exist in :attr:`data`, in requested
            order; unknown keys are dropped. For ``"all"``, every key of
            :attr:`data`.
        """
        file_types = load_opts.get("file_types", "all")
        if file_types == "all":
            # Return a list (not a live dict view) so that both branches give
            # the same type and callers may safely mutate self.data.
            return list(self.data)
        if isinstance(file_types, str):
            file_types = [file_types]
        # * Make sure we only have valid keys
        return [ft for ft in file_types if ft in self.data]

    def load_dataset(self, paths, db_type, load_opts=None):
        """Load the saved dataset files into :attr:`data`.

        Args:
            paths: Nested mapping; ``paths["save_dests"][db_type][key]`` must
                be the :class:`pathlib.Path` of the file saved for ``key``.
            db_type: Name of the dataset type to load.
            load_opts (dict, optional): Loading options, see
                :meth:`get_file_types`.

        Raises:
            ValueError: If one of the expected files does not exist. Keys
                loaded before the failure remain stored in :attr:`data`.
        """
        load_opts = load_opts or {}
        file_types = self.get_file_types(load_opts)
        onload_callbacks = self.CALLBACKS.get("onload", {})

        for key in file_types:
            path = paths["save_dests"][db_type][key]
            if not path.exists():
                raise ValueError(
                    "Database file {} not found. Please run check_datasets() before".format(
                        str(path)
                    )
                )
            loaded = self.load_dataset_file(path)
            callback = onload_callbacks.get(key, None)
            if callback:
                loaded = callback(loaded)
            self.data[key] = loaded

    def load_dataset_file(self, file_name):
        """Load a single dataset file from disk.

        Files with a ``.feather`` suffix are read as dataframes (a warning is
        emitted when the dataframe is empty); any other file is unpickled.

        Args:
            file_name (pathlib.Path): File to load.

        Returns:
            The loaded dataframe, or the unpickled object.
        """
        print("Loading file: ", file_name)
        if file_name.suffix == ".feather":
            df = feather.read_dataframe(str(file_name))
            if df.empty:
                common_utils.print_warning(
                    "Warning, loaded dataset file {} is empty".format(file_name)
                )
            return df
        # NOTE(security): unpickling can execute arbitrary code; only load
        # files produced by this project, never untrusted ones.
        # The context manager guarantees the handle is closed even when
        # unpickling fails.
        with open(file_name, "rb") as f:
            return pickle.load(f)