import pickle
import traceback
from pathlib import Path
import feather
import pandas as pd
from ..utils import common_utils
class DataLoader:
    """Basic class for loading raw data into a dataset.

    Two workflows are implemented here:

    * :meth:`generate_dataset` iterates over raw files, delegating the actual
      work to :meth:`dataset_options` and :meth:`load_file_data`. Both of these
      do nothing by default (besides printing a warning) and should be
      overridden by subclasses.
    * :meth:`load_dataset` reads previously saved dataset files back into
      ``self.data``, optionally post-processing them with the callbacks
      registered in :attr:`CALLBACKS`.

    NOTE(review): this class is presumably driven by
    :class:`.data_handler.DataHandler`; that caller is not visible from this
    file, so confirm there which methods it invokes.
    """

    # Optional hooks. ``CALLBACKS["onload"][key]`` is called by load_dataset()
    # with the freshly loaded object for ``key`` and must return its replacement.
    CALLBACKS = {}

    def __init__(self, structure):
        """Store ``structure`` as the container that loaded data is put into.

        Args:
            structure: Object supporting ``keys()`` and item assignment (see
                :meth:`get_file_types` and :meth:`load_dataset`).
        """
        self.data = structure

    def dataset_options(self, *args, **kwargs):
        """Return the options forwarded to :meth:`load_file_data`.

        The default implementation only prints a warning and returns an empty
        dict. :meth:`generate_dataset` calls it with the database object as
        its single positional argument.

        Returns:
            dict: Always empty in this base implementation.
        """
        common_utils.print_warning(
            (
                "WARNING! Calling dataset_options() method from the default DataLoader class which "
                + "does nothing. Please inherit this class and override this method for loading "
                + "any options relevant to the loading of the files of the dataset."
            )
        )
        return {}

    def load_file_data(self, *args, **kwargs):
        """Load the data for a single file.

        The default implementation only prints a warning and returns two empty
        lists. :meth:`generate_dataset` calls it with the keyword arguments
        below and pickles whatever is returned when intermediates are saved.

        Args:
            file_path (pathlib.Path): File to load.
            tags_dir: Entry of ``paths["tags"]`` selected by
                :meth:`generate_dataset`.
            opts: Value returned by :meth:`dataset_options`.

        Returns:
            tuple: ``(data, tags)``; both are empty lists here.
        """
        common_utils.print_warning(
            (
                "Calling load_file_data() method from the default DataLoader class which "
                + "does nothing. Please inherit this class and override this method for loading "
                + "the files and tags of the dataset."
            )
        )
        data, tags = [], []
        return data, tags

    def finalize_dataset(self):
        """Hook called by :meth:`generate_dataset` once every file was processed.

        Does nothing by default; override it when some further action is needed
        after all files are loaded (e.g. dataframe concatenation).
        """

    def load_classes(self, database):
        """Return the lower-cased tags of the database's class type.

        Reads the CSV file at ``database.classes_file`` and keeps the rows
        whose ``class_type`` column equals ``database.class_type``.

        Args:
            database: Object exposing ``class_type`` and ``classes_file``.

        Returns:
            Array of the lower-cased values of the ``tag`` column for the
            matching rows.
        """
        class_type = database.class_type
        classes_file = database.classes_file
        classes_df = pd.read_csv(classes_file, skip_blank_lines=True)
        classes = (
            classes_df.loc[
                classes_df["class_type"]  # pylint: disable=unsubscriptable-object
                == class_type
            ]
            .tag.str.lower()
            .values
        )
        return classes

    def generate_dataset(self, database, paths, file_list, db_type, overwrite):
        """Load every file of ``file_list`` and optionally save intermediates.

        A failure on one file is printed together with its traceback and the
        loop moves on to the next file. :meth:`finalize_dataset` is called at
        the end.

        Args:
            database: Object exposing ``get()`` and ``save_intermediates``.
            paths: Mapping with a ``"tags"`` entry and, when intermediates are
                saved, a ``"dest"`` entry; both are indexed by database type.
            file_list: Iterable of files to load (``str`` or ``Path``).
            db_type: Database type used to index ``paths``.
            overwrite: When true, existing intermediate files are rewritten.
        """
        db_opts = self.dataset_options(database)
        split = database.get("split", {})
        # When this database type is listed in the split options, the tags
        # registered under "training" are used instead of its own.
        if split and db_type in split:
            tags_dir = paths["tags"]["training"]
        else:
            tags_dir = paths["tags"][db_type]
        for file_path in file_list:
            try:
                if not isinstance(file_path, Path):
                    file_path = Path(file_path)
                intermediate = self.load_file_data(
                    file_path=file_path, tags_dir=tags_dir, opts=db_opts
                )
                if database.save_intermediates:
                    savename = (
                        paths["dest"][db_type] / "intermediate" / file_path.name
                    ).with_suffix(".pkl")
                    if not savename.exists() or overwrite:
                        with open(savename, "wb") as f:
                            # -1 selects the highest pickle protocol available.
                            pickle.dump(intermediate, f, -1)
            except Exception:
                # Deliberate best effort: report the failure and keep going.
                print("Error loading: " + str(file_path) + ", skipping.")
                print(traceback.format_exc())
                # NOTE(review): the reviewed source had lost its indentation;
                # resetting self.data inside this handler is the assumed
                # placement - confirm against the upstream file.
                self.data = None
        self.finalize_dataset()

    def get_file_types(self, load_opts):
        """Return the keys of ``self.data`` selected by ``load_opts``.

        Args:
            load_opts (dict): Loading options. ``"file_types"`` may be
                ``"all"`` (the default), a single key or a list of keys.

        Returns:
            list: Every key of ``self.data`` for ``"all"``; otherwise the
            requested keys that exist in ``self.data``, in requested order.
        """
        file_types = load_opts.get("file_types", "all")
        if file_types == "all":
            # Snapshot the keys so both branches return a plain list.
            return list(self.data.keys())
        if isinstance(file_types, str):
            file_types = [file_types]
        # * Make sure we only have valid keys
        valid_keys = self.data.keys()
        return [ft for ft in file_types if ft in valid_keys]

    def load_dataset(self, paths, db_type, load_opts=None):
        """Load the saved dataset files into ``self.data``.

        Args:
            paths: Mapping where ``paths["save_dests"][db_type][key]`` is the
                ``Path`` of the saved file for ``key``.
            db_type: Database type used to index ``paths``.
            load_opts (dict, optional): Loading options, see
                :meth:`get_file_types`.

        Raises:
            ValueError: If one of the dataset files does not exist.
        """
        load_opts = load_opts or {}
        file_types = self.get_file_types(load_opts)
        # * Get paths
        for key in file_types:
            path = paths["save_dests"][db_type][key]
            if not path.exists():
                raise ValueError(
                    "Database file {} not found. Please run check_datasets() before".format(
                        str(path)
                    )
                )
            tmp = self.load_dataset_file(path)
            callback = self.CALLBACKS.get("onload", {}).get(key, None)
            if callback:
                tmp = callback(tmp)
            self.data[key] = tmp

    def load_dataset_file(self, file_name):
        """Load one dataset file from disk.

        Files with a ``.feather`` suffix are read as dataframes (a warning is
        printed when the dataframe is empty); anything else is unpickled.

        Args:
            file_name (pathlib.Path): File to read.

        Returns:
            The dataframe or the unpickled object.
        """
        print("Loading file: ", file_name)
        if file_name.suffix == ".feather":
            df = feather.read_dataframe(str(file_name))
            if df.empty:
                common_utils.print_warning(
                    "Warning, loaded dataset file {} is empty".format(file_name)
                )
            return df
        # SECURITY: pickle can execute arbitrary code while loading; only use
        # this on dataset files produced by a trusted source.
        with open(file_name, "rb") as pickled:
            return pickle.load(pickled)