data_handlers

This module provides classes for handling CSV data files in the STIMULUS format.

The module contains three main classes:

  • DatasetHandler: Base class for loading and managing CSV data
  • DatasetProcessor: Class for preprocessing data with transformations and splits
  • DatasetLoader: Class for loading processed data for model training

The data format consists of:

  1. A CSV file containing the raw data
  2. A YAML configuration file that defines:
     • Column names and their roles (input/label/meta)
     • Data types and encoders for each column
     • Transformations to apply (noise, augmentation, etc.)
     • Split configuration for train/val/test sets

The data handling pipeline consists of:

  1. Loading raw CSV data according to the YAML config
  2. Applying configured transformations
  3. Splitting into train/val/test sets based on config
  4. Encoding data for model training using specified encoders

See titanic.yaml in tests/test_data/titanic/ for an example configuration file format.
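
For orientation, here is a minimal end-to-end sketch of the pipeline using the classes documented below. The encoder, transform, and splitter instances are placeholders for whatever the YAML configuration resolves to; treat this as an assumed usage pattern, not verbatim project code.

from stimulus.data.data_handlers import DatasetLoader, DatasetProcessor

# Preprocess: transform columns, assign train/val/test splits, save.
processor = DatasetProcessor(
    csv_path="data.csv",
    transforms={"dna_seq": [noise_transform]},  # placeholder AbstractTransform
    split_columns=["dna_seq"],                  # columns handed to the splitter
    splitter=my_splitter,                       # placeholder AbstractSplitter
)
processor.add_split()             # adds a "split" column: 0=train, 1=val, 2=test
processor.apply_transformations()
processor.save("data_processed.csv")

# Load: encode the training split for the model.
loader = DatasetLoader(
    encoders={"dna_seq": my_encoder, "labels": my_label_encoder},  # placeholders
    input_columns=["dna_seq"],
    label_columns=["labels"],
    meta_columns=["id"],
    csv_path="data_processed.csv",
    split=0,
)
input_dict, label_dict, meta_dict = loader.get_all_items()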

Classes:

  • DatasetHandler

    Main class for handling dataset loading, encoding, transformation and splitting.

  • DatasetLoader

Class for loading a dataset and passing it to the deep learning model.

  • DatasetProcessor

Class for loading a dataset, applying transformations, and splitting it.

  • TorchDataset

    Class for creating a torch dataset.

DatasetHandler

DatasetHandler(csv_path: str)

Main class for handling dataset loading, encoding, transformation and splitting.

This class coordinates the interaction between different managers to process CSV datasets according to the provided configuration.

Attributes:

  • encoder_manager (EncodeManager) –

    Manager for handling data encoding operations.

  • transform_manager (TransformManager) –

    Manager for handling data transformations.

  • split_manager (SplitManager) –

    Manager for handling dataset splitting.

Parameters:

  • csv_path (str) –

    Path to the CSV data file.

Methods:

  • load_csv

    Load the CSV file into a polars DataFrame.

  • read_csv_header

    Get the column names from the header of the CSV file.

  • save

Saves the data to a CSV file.

  • select_columns

    Select specific columns from the DataFrame and return as a dictionary.

Source code in src/stimulus/data/data_handlers.py
def __init__(
    self,
    csv_path: str,
) -> None:
    """Initialize the DatasetHandler with required config.

    Args:
        csv_path (str): Path to the CSV data file.
    """
    self.columns = self.read_csv_header(csv_path)
    self.data = self.load_csv(csv_path)
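
A minimal usage sketch of the base class on its own (the file name is illustrative):

handler = DatasetHandler("data.csv")
print(handler.columns)     # column names read from the CSV header
print(handler.data.shape)  # the underlying polars DataFrame: (rows, columns)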

load_csv

load_csv(csv_path: str) -> DataFrame

Load the CSV file into a polars DataFrame.

Parameters:

  • csv_path (str) –

    Path to the CSV file to load.

Returns:

  • DataFrame

    pl.DataFrame: Polars DataFrame containing the loaded CSV data.

Source code in src/stimulus/data/data_handlers.py
def load_csv(self, csv_path: str) -> pl.DataFrame:
    """Load the CSV file into a polars DataFrame.

    Args:
        csv_path (str): Path to the CSV file to load.

    Returns:
        pl.DataFrame: Polars DataFrame containing the loaded CSV data.
    """
    return pl.read_csv(csv_path)

read_csv_header

read_csv_header(csv_path: str) -> list

Get the column names from the header of the CSV file.

Parameters:

  • csv_path (str) –

    Path to the CSV file to read headers from.

Returns:

  • list ( list ) –

    List of column names from the CSV header.

Source code in src/stimulus/data/data_handlers.py
def read_csv_header(self, csv_path: str) -> list:
    """Get the column names from the header of the CSV file.

    Args:
        csv_path (str): Path to the CSV file to read headers from.

    Returns:
        list: List of column names from the CSV header.
    """
    with open(csv_path) as f:
        return f.readline().strip().split(",")

save

save(path: str) -> None

Saves the data to a CSV file.

Source code in src/stimulus/data/data_handlers.py
def save(self, path: str) -> None:
    """Saves the data to a csv file."""
    self.data.write_csv(path)

select_columns

select_columns(columns: list) -> dict

Select specific columns from the DataFrame and return as a dictionary.

Parameters:

  • columns (list) –

    List of column names to select.

Returns:

  • dict ( dict ) –

    A dictionary where keys are column names and values are lists containing the column data.

Example

>>> handler = DatasetHandler(...)
>>> data_dict = handler.select_columns(["col1", "col2"])
>>> # Returns {'col1': [1, 2, 3], 'col2': [4, 5, 6]}

Source code in src/stimulus/data/data_handlers.py
def select_columns(self, columns: list) -> dict:
    """Select specific columns from the DataFrame and return as a dictionary.

    Args:
        columns (list): List of column names to select.

    Returns:
        dict: A dictionary where keys are column names and values are lists containing the column data.

    Example:
        >>> handler = DatasetHandler(...)
        >>> data_dict = handler.select_columns(["col1", "col2"])
        >>> # Returns {'col1': [1, 2, 3], 'col2': [4, 5, 6]}
    """
    df = self.data.select(columns)
    return {col: df[col].to_list() for col in columns}

DatasetLoader

DatasetLoader(
    encoders: dict[str, AbstractEncoder],
    input_columns: list[str],
    label_columns: list[str],
    meta_columns: list[str],
    csv_path: str,
    split: Optional[int] = None,
)

Bases: DatasetHandler

Class for loading a dataset and passing it to the deep learning model.

Methods:

  • encode_dataframe

    Encode the dataframe columns using the configured encoders.

  • get_all_items

    Get the full dataset as three separate dictionaries for inputs, labels and metadata.

  • get_all_items_and_length

    Get the full dataset as three separate dictionaries for inputs, labels and metadata, and the length of the data.

  • load_csv

    Load the CSV file into a polars DataFrame.

  • load_csv_per_split

    Load the part of csv file that has the specified split value.

  • read_csv_header

    Get the column names from the header of the CSV file.

  • save

Saves the data to a CSV file.

  • select_columns

    Select specific columns from the DataFrame and return as a dictionary.

Source code in src/stimulus/data/data_handlers.py
def __init__(
    self,
    encoders: dict[str, encoders_module.AbstractEncoder],
    input_columns: list[str],
    label_columns: list[str],
    meta_columns: list[str],
    csv_path: str,
    split: Optional[int] = None,
) -> None:
    """Initialize the DatasetLoader."""
    super().__init__(csv_path)
    self.encoders = encoders
    self.data = self.load_csv_per_split(csv_path, split) if split is not None else self.load_csv(csv_path)
    self.input_columns = input_columns
    self.label_columns = label_columns
    self.meta_columns = meta_columns
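
A hedged construction sketch: every input and label column needs an entry in encoders (keyed by column name), since get_all_items encodes both groups via encode_dataframe; meta columns are returned as plain lists and need no encoder. The encoder instances below are placeholders.

loader = DatasetLoader(
    encoders={
        "age": numeric_encoder,        # placeholder AbstractEncoder instances
        "fare": numeric_encoder,
        "survived": label_encoder,
    },
    input_columns=["age", "fare"],
    label_columns=["survived"],
    meta_columns=["passenger_id"],     # not encoded
    csv_path="titanic_processed.csv",
    split=1,                           # load only the validation rows
)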

encode_dataframe

encode_dataframe(dataframe: DataFrame) -> dict[str, Tensor]

Encode the dataframe columns using the configured encoders.

Takes a polars DataFrame and encodes each column using its corresponding encoder from self.encoders.

Parameters:

  • dataframe (DataFrame) –

    Polars DataFrame containing the columns to encode

Returns:

  • dict[str, Tensor] –

    Dict mapping column names to their encoded tensors. The exact shape of each tensor depends on the encoder used for that column.

Example

>>> df = pl.DataFrame({"dna_seq": ["ACGT", "TGCA"], "labels": [1, 2]})
>>> encoded = dataset_loader.encode_dataframe(df)
>>> print(encoded["dna_seq"].shape)
torch.Size([2, 4, 4])  # 2 sequences, length 4, one-hot encoded

Source code in src/stimulus/data/data_handlers.py
def encode_dataframe(self, dataframe: pl.DataFrame) -> dict[str, torch.Tensor]:
    """Encode the dataframe columns using the configured encoders.

    Takes a polars DataFrame and encodes each column using its corresponding encoder
    from self.encoders.

    Args:
        dataframe: Polars DataFrame containing the columns to encode

    Returns:
        Dict mapping column names to their encoded tensors. The exact shape of each
        tensor depends on the encoder used for that column.

    Example:
        >>> df = pl.DataFrame({"dna_seq": ["ACGT", "TGCA"], "labels": [1, 2]})
        >>> encoded = dataset_loader.encode_dataframe(df)
        >>> print(encoded["dna_seq"].shape)
        torch.Size([2, 4, 4])  # 2 sequences, length 4, one-hot encoded
    """
    return {col: self.encoders[col].encode_all(dataframe[col].to_list()) for col in dataframe.columns}

get_all_items

get_all_items() -> tuple[dict, dict, dict]

Get the full dataset as three separate dictionaries for inputs, labels and metadata.

Returns:

  • tuple[dict, dict, dict]

    tuple[dict, dict, dict]: Three dictionaries containing:
      • Input dictionary mapping input column names to encoded input data
      • Label dictionary mapping label column names to encoded label data
      • Meta dictionary mapping meta column names to meta data

Example

>>> loader = DatasetLoader(...)
>>> input_dict, label_dict, meta_dict = loader.get_all_items()
>>> print(input_dict.keys())
dict_keys(['age', 'fare'])
>>> print(label_dict.keys())
dict_keys(['survived'])
>>> print(meta_dict.keys())
dict_keys(['passenger_id'])

Source code in src/stimulus/data/data_handlers.py
def get_all_items(self) -> tuple[dict, dict, dict]:
    """Get the full dataset as three separate dictionaries for inputs, labels and metadata.

    Returns:
        tuple[dict, dict, dict]: Three dictionaries containing:
            - Input dictionary mapping input column names to encoded input data
            - Label dictionary mapping label column names to encoded label data
            - Meta dictionary mapping meta column names to meta data

    Example:
        >>> handler = DatasetHandler(...)
        >>> input_dict, label_dict, meta_dict = handler.get_dataset()
        >>> print(input_dict.keys())
        dict_keys(['age', 'fare'])
        >>> print(label_dict.keys())
        dict_keys(['survived'])
        >>> print(meta_dict.keys())
        dict_keys(['passenger_id'])
    """
    input_data = self.encode_dataframe(self.data[self.input_columns])
    label_data = self.encode_dataframe(self.data[self.label_columns])
    meta_data = {key: self.data[key].to_list() for key in self.meta_columns}
    return input_data, label_data, meta_data

get_all_items_and_length

get_all_items_and_length() -> (
    tuple[tuple[dict, dict, dict], int]
)

Get the full dataset as three separate dictionaries for inputs, labels and metadata, and the length of the data.

Source code in src/stimulus/data/data_handlers.py
def get_all_items_and_length(self) -> tuple[tuple[dict, dict, dict], int]:
    """Get the full dataset as three separate dictionaries for inputs, labels and metadata, and the length of the data."""
    return self.get_all_items(), len(self.data)

load_csv_per_split

load_csv_per_split(csv_path: str, split: int) -> DataFrame

Load the part of the CSV file that has the specified split value.

The split value selects the subset: 0 is train, 1 is validation, 2 is test.

Source code in src/stimulus/data/data_handlers.py
def load_csv_per_split(self, csv_path: str, split: int) -> pl.DataFrame:
    """Load the part of csv file that has the specified split value.

    Split is a number that for 0 is train, 1 is validation, 2 is test.
    """
    if "split" not in self.columns:
        raise ValueError("The category split is not present in the csv file")
    if split not in [0, 1, 2]:
        raise ValueError(
            f"The split value should be 0, 1 or 2. The specified split value is {split}",
        )
    return pl.scan_csv(csv_path).filter(pl.col("split") == split).collect()
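
For instance, to load only the test rows; this mirrors what the constructor does when a split argument is passed:

# Keeps only rows whose "split" column equals 2 (the test set).
test_df = loader.load_csv_per_split("data_processed.csv", split=2)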

load_csv, read_csv_header, save, select_columns

These methods are inherited unchanged from DatasetHandler; see their documentation above.

DatasetProcessor

DatasetProcessor(
    csv_path: str,
    transforms: dict[str, list[AbstractTransform]],
    split_columns: list[str],
    splitter: AbstractSplitter,
)

Bases: DatasetHandler

Class for loading a dataset, applying transformations, and splitting it.

Methods:

  • add_split

    Add a column specifying the train, validation, test splits of the data.

  • apply_transformations

    Apply the transformation group to the data.

  • load_csv

    Load the CSV file into a polars DataFrame.

  • read_csv_header

    Get the column names from the header of the CSV file.

  • save

Saves the data to a CSV file.

  • select_columns

    Select specific columns from the DataFrame and return as a dictionary.

  • shuffle_labels

    Shuffles the labels in the data.

Source code in src/stimulus/data/data_handlers.py
def __init__(
    self,
    csv_path: str,
    transforms: dict[str, list[transforms_module.AbstractTransform]],
    split_columns: list[str],
    splitter: splitters_module.AbstractSplitter,
) -> None:
    """Initialize the DatasetProcessor."""
    super().__init__(csv_path)
    self.transforms = transforms
    self.split_columns = split_columns
    self.splitter = splitter

add_split

add_split(*, force: bool = False) -> None

Add a column specifying the train, validation, test splits of the data.

A ValueError is raised if the split column is already present in the CSV file. This behaviour can be overridden by setting force=True.

Parameters:

  • force (bool, default: False ) –

If True, any split column already present in the CSV file will be overwritten.

Source code in src/stimulus/data/data_handlers.py
def add_split(self, *, force: bool = False) -> None:
    """Add a column specifying the train, validation, test splits of the data.

    An error exception is raised if the split column is already present in the csv file. This behaviour can be overridden by setting force=True.

    Args:
        force (bool): If True, the split column present in the csv file will be overwritten.
    """
    if ("split" in self.columns) and (not force):
        raise ValueError(
            "The category split is already present in the csv file. If you want to still use this function, set force=True",
        )
    # get relevant split columns from the dataset_manager
    split_input_data = self.select_columns(self.split_columns)

    # get the split indices
    train, validation, test = self.splitter.get_split_indexes(split_input_data)

    # add the split column to the data
    split_column = np.full(len(self.data), -1).astype(int)
    split_column[train] = 0
    split_column[validation] = 1
    split_column[test] = 2
    self.data = self.data.with_columns(pl.Series("split", split_column))

    if "split" not in self.columns:
        self.columns.append("split")
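
A hedged usage sketch; the splitter instance is a placeholder for any AbstractSplitter implementation:

processor = DatasetProcessor(
    csv_path="data.csv",
    transforms={},
    split_columns=["sequence"],   # columns handed to the splitter
    splitter=my_splitter,         # placeholder AbstractSplitter instance
)
processor.add_split()             # data now carries a "split" column of 0/1/2
# Calling add_split() again on data that already has a "split" column
# raises ValueError unless forced:
processor.add_split(force=True)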

apply_transformations

apply_transformations() -> None

Apply the transformation group to the data.

Applies all transformations defined in self.transforms to their corresponding columns. Each column can have multiple transformations that are applied sequentially.

Source code in src/stimulus/data/data_handlers.py
def apply_transformations(self) -> None:
    """Apply the transformation group to the data.

    Applies all transformations defined in self.transforms to their corresponding columns.
    Each column can have multiple transformations that are applied sequentially.
    """
    for column_name, transforms_list in self.transforms.items():
        for transform in transforms_list:
            transformed_data = transform.transform_all(self.data[column_name].to_list())

            if transform.add_row:
                new_rows = self.data.with_columns(
                    pl.Series(column_name, transformed_data),
                )
                self.data = pl.concat([self.data, new_rows], how="vertical")
            else:
                self.data = self.data.with_columns(
                    pl.Series(column_name, transformed_data),
                )
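
The transforms mapping is keyed by column name, and each column's list is applied in order. A sketch with placeholder transforms (normally passed to the constructor); note from the source above that a transform with add_row set appends the transformed values as new rows rather than replacing the column in place:

processor.transforms = {
    "dna_seq": [noise_transform, reverse_complement],  # placeholder transforms
}
n_before = len(processor.data)
processor.apply_transformations()
# If any transform has add_row=True, len(processor.data) > n_before, because
# the transformed rows are concatenated to the original data.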

load_csv, read_csv_header, save, select_columns

These methods are inherited unchanged from DatasetHandler; see their documentation above.

shuffle_labels

shuffle_labels(
    label_columns: list[str], seed: Optional[float] = None
) -> None

Shuffles the labels in the data.

Source code in src/stimulus/data/data_handlers.py
def shuffle_labels(self, label_columns: list[str], seed: Optional[float] = None) -> None:
    """Shuffles the labels in the data."""
    # set the np seed
    np.random.seed(seed)

    for key in label_columns:
        self.data = self.data.with_columns(
            pl.Series(key, np.random.permutation(list(self.data[key]))),
        )
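
Label shuffling is typically used as a negative control: a model trained on permuted labels should perform no better than chance. A short sketch:

# Permute the label column in place, leaving inputs untouched; pass a seed
# for a reproducible permutation.
processor.shuffle_labels(label_columns=["survived"], seed=42)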

TorchDataset

TorchDataset(loader: DatasetLoader)

Bases: Dataset

Class for creating a torch dataset.

Parameters:

  • loader (DatasetLoader) –

    A DatasetLoader instance.

Source code in src/stimulus/data/data_handlers.py
def __init__(
    self,
    loader: DatasetLoader,
) -> None:
    """Initialize the TorchDataset.

    Args:
        loader: A DatasetLoader instance
    """
    self.loader = loader
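
A hedged sketch of feeding the dataset to PyTorch. Only __init__ is shown above, so this assumes TorchDataset also implements __len__ and __getitem__ as torch.utils.data.Dataset requires, and that items arrive as (inputs, labels, meta):

from torch.utils.data import DataLoader

dataset = TorchDataset(loader)        # loader: a DatasetLoader instance
batches = DataLoader(dataset, batch_size=32, shuffle=True)
for inputs, labels, meta in batches:  # assumed item structure
    ...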