Stimulus refactor
Context
stimulus separates data processing into three components:
- split (into train/validation/test)
- transform (modifications to raw data - downsampling for example)
- encode (raw data to pytorch tensors)
Some key observations about the input data (example adapted from Kaggle):
passenger_id | survived | pclass | sex | age | sibsp | parch | fare | embarked |
---|---|---|---|---|---|---|---|---|
1 | 0 | 3 | male | 22.0 | 1 | 0 | 7.25 | S |
2 | 1 | 1 | female | 38.0 | 1 | 0 | 71.2833 | C |
3 | 1 | 3 | female | 26.0 | 0 | 0 | 7.925 | S |
4 | 1 | 1 | female | 35.0 | 1 | 0 | 53.1 | S |
- Split needs to be called on a specific (group of) column(s).
- Column types vary, thus transforms are column-specific.
- Likewise for encoding.
- Columns serve different purposes:
  - input (e.g., `age`)
  - target (`survived`)
  - meta-data (`passenger_id`)
Since stimulus is first and foremost a command-line tool, these pieces are defined in an external .yaml configuration file. For encoders, it looks like this:
```yaml
columns:
  - column_name: fare
    column_type: input
    data_type: float
    encoder:
      - name: NumericEncoder
        params:
```
Here, we are defining that `fare` is an input column and needs to be encoded using `NumericEncoder`.
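As a rough intuition for what "encode" means here, this is a minimal sketch of what a numeric encoder could boil down to; the `encode_all` method name is borrowed from the loaders discussed later in this post, but the class name and exact signature are assumptions:

```python
import torch


class NumericEncoderSketch:
    """Hypothetical numeric encoder: turns a raw column into a float tensor."""

    def encode_all(self, values: list[float]) -> torch.Tensor:
        # Convert the raw column values (e.g. the `fare` column) into a 1-D float32 tensor.
        return torch.tensor(values, dtype=torch.float32)


# e.g. NumericEncoderSketch().encode_all([7.25, 71.2833, 7.925, 53.1])
```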
We do the same for transforms:
```yaml
transforms:
  - transformation_name: noise
    columns:
      - column_name: age
        transformations:
          - name: GaussianNoise
            params:
              std: [0.1, 0.2, 0.3]
      - column_name: fare
        transformations:
          - name: GaussianNoise
            params:
              std: [0.1, 0.2, 0.3]
```
Here we apply GaussianNoise to two different columns with three different standard deviation parameters.
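For illustration only (this is not the actual stimulus implementation), such a transform could boil down to adding zero-mean Gaussian noise to a numeric column, with one transformed variant per `std` value:

```python
import numpy as np


def gaussian_noise_sketch(values: np.ndarray, std: float, seed: int | None = None) -> np.ndarray:
    """Hypothetical transform: add zero-mean Gaussian noise to a numeric column."""
    rng = np.random.default_rng(seed)
    return values + rng.normal(loc=0.0, scale=std, size=values.shape)


# One noisy variant of the `age` column per std value in the config above.
ages = np.array([22.0, 38.0, 26.0, 35.0])
noisy_ages = [gaussian_noise_sketch(ages, std) for std in (0.1, 0.2, 0.3)]
```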
We define the split parameters like this:
```yaml
split:
  - split_method: RandomSplit
    split_input_columns: [age]
    params:
      split: [0.7, 0.15, 0.15]
```
We use a `RandomSplit` splitter on the `age` column, separating the data into 70% for training, 15% for validation, and 15% for testing.
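Conceptually, and glossing over the details of the real `RandomSplit`, such a splitter just shuffles row indices and cuts them according to the requested proportions:

```python
import numpy as np


def random_split_sketch(
    n_rows: int,
    split: tuple[float, float, float] = (0.7, 0.15, 0.15),
    seed: int = 42,
) -> dict[str, np.ndarray]:
    """Hypothetical splitter: assign each row index to train/validation/test at random."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(n_rows)
    n_train = int(split[0] * n_rows)
    n_val = int(split[1] * n_rows)
    return {
        "train": indices[:n_train],
        "validation": indices[n_train : n_train + n_val],
        "test": indices[n_train + n_val :],
    }
```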
Interfacing configuration files with the data, the old way
The configuration file serves as an interface between the data and the code. For this to happen correctly, we need to:
- General
  - understand which columns are input, meta-data, or target (label).
- Encoders
  - link the encoder with its column.
  - fetch the proper encoder from the codebase.
  - use the right encoder on the right column when considering the full dataset, taking care of input, meta-data, or target considerations.
- Transforms
  - link transforms with their columns.
  - fetch the proper transforms from the codebase.
  - remember the order of transforms within a group (e.g. the `noise` group defined above).
  - use the right transform group on the right columns when considering the full dataset.
- Split
  - fetch the proper splitter from the codebase.
  - use it on the right column or set of columns.
Until now, we dealt with the problem by providing three levels of abstraction.
- Loaders, for example `EncoderLoader`, include boilerplate code to load from the configuration file, fetch the proper encoder from the codebase, and get its `encode_all` method.
- Managers, for example `EncodeManager`, hold the logic of which encoder to use on which column and boilerplate code to apply encoders to a subset of the dataframe.
- Finally, the dataset handler main class, split into `DatasetLoader` and `DatasetProcessor`, contains boilerplate code to load the data from disk, apply transformations and the split, and feed the data to the network as a dictionary of PyTorch tensors.
If you were to visualize it, it would look like this:
This outlines two main problems:
- Three levels of abstraction: the codebase is too hard to understand.
- Tight coupling between the various classes: if you change one module, everything breaks.
Point 1 is fairly easy to understand by looking at the above diagram, so let's discuss point 2.
Point 2 is best understood by looking at the current `TorchDataset` class, which interfaces the data with the neural network.
```python
class TorchDataset(torch.utils.data.Dataset):
    """Class for creating a torch dataset."""

    def __init__(
        self,
        config_path: str,
        csv_path: str,
        encoder_loader: loaders.EncoderLoader,
        split: Optional[int] = None,
    ) -> None:
        """Initialize the TorchDataset.

        Args:
            config_path: Path to the configuration file
            csv_path: Path to the CSV data file
            encoder_loader: Encoder loader instance
            split: Optional tuple containing split information
        """
        self.loader = data_handlers.DatasetLoader(
            config_path=config_path,
            csv_path=csv_path,
            encoder_loader=encoder_loader,
            split=split,
        )

    def __len__(self) -> int:
        return len(self.loader)

    def __getitem__(self, idx: int) -> tuple[dict, dict, dict]:
        return self.loader[idx]
```
This class, which should be simple in its implementation, performs a lot of things under the hood:
- load the data from disk using the proper split (train/validation/test) if needed
- load the configuration file defined above
- make sure the encoder loader is consistent with the configuration
- return the data in a format that can be used by the neural network
If any of those steps changes somewhere else in the code (for instance, if the `config_path` argument of `DatasetLoader` is renamed), `TorchDataset` will break. Every little change requires hours of debugging.
Since scientific code needs to be extremely flexible (to try the cool new things), we want to minimize the time it takes to implement a change in the codebase.
Interfacing configuration files with the data, the new way
The first idea here is to remove as many abstractions as possible. `Encoders`, `Transforms`, and `Splitters` are non-compressible core functionality, and the same goes for the `DatasetHandler` classes.
Focusing on native Python data structures (such as dictionaries, lists, etc.) will make the codebase readable (our contributors understand dictionaries, not necessarily `DatasetManagers`).
Think about it this way: the more concepts contributors need to learn, the more likely they are to quit before their first PR.
For this, we need to rethink the way we do I/O management between our components. Config parsing has to be outsourced to a dedicated module, which outputs simple, native Python objects. The `DatasetHandler` classes expect those objects as input and do not do the parsing themselves (one class should do one thing):
```python
class DatasetLoader(DatasetHandler):
    """Class for loading dataset and passing it to the deep learning model."""

    def __init__(
        self,
        encoders: dict[str, encoders_module.AbstractEncoder],
        input_columns: list[str],
        label_columns: list[str],
        meta_columns: list[str],
        csv_path: str,
        split: Optional[int] = None,
    ) -> None:
        ...
```
Here, instead of having loaders and managers, we pass the needed information directly to the `DatasetHandler` class as simple, native Python objects. For example, the class itself is not expected to figure out which columns are inputs, labels, or meta-data; this was done beforehand.
Notice as well that the encoders are expected as a simple dictionary of the form `{column_name: encoder_instance}`. When we need the encoder for a specific column, we only have to access the dictionary with the name of the column we want to encode as the key.
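As an illustration of what that dedicated parsing step could look like (the helper name and the registry argument are hypothetical, not the actual stimulus API), the YAML shown earlier can be turned into exactly that dictionary:

```python
import yaml


def build_encoders(config_path: str, registry: dict[str, type]) -> dict[str, object]:
    """Parse the encoder section of the YAML config into {column_name: encoder_instance}.

    `registry` maps encoder names used in the YAML (e.g. "NumericEncoder")
    to the corresponding encoder classes from the codebase.
    """
    with open(config_path) as handle:
        config = yaml.safe_load(handle)

    encoders = {}
    for column in config["columns"]:
        spec = column["encoder"][0]  # first (and usually only) encoder entry
        params = spec.get("params") or {}  # an empty `params:` parses to None
        encoders[column["column_name"]] = registry[spec["name"]](**params)
    return encoders
```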
This allows us to completely decouple the `DatasetHandler` class from the configuration file parsing. If the configuration file format changes, `DatasetHandler` does not care (it always expects already-parsed objects), which addresses point 2!
To further explain how we decouple the codebase, let's rewrite the `TorchDataset` class:
```python
class TorchDataset(torch.utils.data.Dataset):
    """Class for creating a torch dataset."""

    def __init__(
        self,
        loader: DatasetLoader,
        # loader is initialized outside of TorchDataset
    ) -> None:
        """Initialize the TorchDataset.

        Args:
            loader: A DatasetLoader instance
        """
        self.loader = loader

    def __len__(self) -> int:
        return len(self.loader)

    def __getitem__(self, idx: int) -> tuple[dict, dict, dict]:
        return self.loader[idx]
```
Here, as long as `DatasetLoader` implements the `__getitem__` and `__len__` methods, `TorchDataset` will work. Changing the inner workings of `DatasetLoader` will not affect `TorchDataset`!
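To make the new data flow concrete, here is a sketch of how the pieces could be wired together; the file names, column lists, and the meaning of `split` are assumptions, and `build_encoders`/`NumericEncoderSketch` refer to the illustrative sketches earlier in this post, not to the actual stimulus API:

```python
from torch.utils.data import DataLoader

# Config parsed beforehand by the dedicated parsing module (illustrative values).
encoders = build_encoders("titanic.yaml", registry={"NumericEncoder": NumericEncoderSketch})

loader = DatasetLoader(
    encoders=encoders,  # {column_name: encoder_instance}
    input_columns=["pclass", "sex", "age", "sibsp", "parch", "fare", "embarked"],
    label_columns=["survived"],
    meta_columns=["passenger_id"],
    csv_path="titanic.csv",  # hypothetical file name
    split=0,  # which split to expose; the exact convention is an assumption
)

dataset = TorchDataset(loader=loader)
batches = DataLoader(dataset, batch_size=32)
```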
Altogether, these changes make the codebase intuitive and the data flow streamlined. If we rebuild the class diagram, it now looks like this:
Way better, right?
You can follow the refactoring effort on the project board.