data_handlers

This module provides classes for handling CSV data files in the STIMULUS format.

The module contains three main classes:

  • DatasetHandler: Base class for loading and managing CSV data
  • DatasetProcessor: Class for preprocessing data with transformations and splits
  • DatasetLoader: Class for loading processed data for model training

The data format consists of:

  1. A CSV file containing the raw data
  2. A YAML configuration file that defines:
       • Column names and their roles (input/label/meta)
       • Data types and encoders for each column
       • Transformations to apply (noise, augmentation, etc.)
       • Split configuration for train/val/test sets

The data handling pipeline consists of:

  1. Loading raw CSV data according to the YAML config
  2. Applying configured transformations
  3. Splitting into train/val/test sets based on config
  4. Encoding data for model training using specified encoders

See titanic.yaml in tests/test_data/titanic/ for an example configuration file format.
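
As a minimal sketch of wiring these steps together (assuming the EncoderLoader, TransformLoader and SplitLoader from stimulus.data.loaders have already been configured; their construction is not covered on this page), the pipeline looks roughly like this:

from stimulus.data import loaders
from stimulus.data.data_handlers import (
    DatasetLoader,
    DatasetProcessor,
    SplitManager,
    TransformManager,
)


def run_pipeline(
    config_path: str,
    csv_path: str,
    processed_csv_path: str,
    transform_loader: loaders.TransformLoader,
    split_loader: loaders.SplitLoader,
    encoder_loader: loaders.EncoderLoader,
) -> tuple[dict, dict, dict]:
    # Steps 1-3: load the raw CSV, apply the configured transformations,
    # add the train/validation/test split column and write the result to disk.
    processor = DatasetProcessor(config_path=config_path, csv_path=csv_path)
    processor.apply_transformation_group(TransformManager(transform_loader))
    processor.add_split(SplitManager(split_loader))
    processor.save(processed_csv_path)

    # Step 4: reload only the train split (split=0) and encode it for training.
    loader = DatasetLoader(
        config_path=config_path,
        csv_path=processed_csv_path,
        encoder_loader=encoder_loader,
        split=0,
    )
    return loader.get_all_items()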

Classes:

  • DatasetHandler

    Main class for handling dataset loading, encoding, transformation and splitting.

  • DatasetLoader

    Class for loading a dataset and passing it to the deep learning model.

  • DatasetManager

    Class for managing the dataset.

  • DatasetProcessor

    Class for loading a dataset, applying transformations and splitting.

  • EncodeManager

    Manages the encoding of data columns using configured encoders.

  • SplitManager

    Class for managing the splitting.

  • TransformManager

    Class for managing the transformations.

DatasetHandler

DatasetHandler(config_path: str, csv_path: str)

Main class for handling dataset loading, encoding, transformation and splitting.

This class coordinates the interaction between different managers to process CSV datasets according to the provided configuration.

Attributes:

  • encoder_manager (EncodeManager) –

    Manager for handling data encoding operations.

  • transform_manager (TransformManager) –

    Manager for handling data transformations.

  • split_manager (SplitManager) –

    Manager for handling dataset splitting.

  • dataset_manager (DatasetManager) –

    Manager for organizing dataset columns and config.

Parameters:

  • config_path (str) –

    Path to the dataset configuration file.

  • csv_path (str) –

    Path to the CSV data file.

Methods:

  • load_csv

    Load the CSV file into a polars DataFrame.

  • read_csv_header

    Get the column names from the header of the CSV file.

  • save

    Saves the data to a csv file.

  • select_columns

    Select specific columns from the DataFrame and return as a dictionary.
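
A short usage sketch (the csv_path below is a placeholder; only the YAML example shipped with the tests is a real file):

from stimulus.data.data_handlers import DatasetHandler

handler = DatasetHandler(
    config_path="tests/test_data/titanic/titanic.yaml",
    csv_path="titanic.csv",  # placeholder path to the raw data
)
age_and_fare = handler.select_columns(["age", "fare"])  # {"age": [...], "fare": [...]}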

Source code in src/stimulus/data/data_handlers.py
def __init__(
    self,
    config_path: str,
    csv_path: str,
) -> None:
    """Initialize the DatasetHandler with required config.

    Args:
        config_path (str): Path to the dataset configuration file.
        csv_path (str): Path to the CSV data file.
    """
    self.dataset_manager = DatasetManager(config_path)
    self.columns = self.read_csv_header(csv_path)
    self.data = self.load_csv(csv_path)

load_csv

load_csv(csv_path: str) -> DataFrame

Load the CSV file into a polars DataFrame.

Parameters:

  • csv_path (str) –

    Path to the CSV file to load.

Returns:

  • DataFrame

    pl.DataFrame: Polars DataFrame containing the loaded CSV data.

Source code in src/stimulus/data/data_handlers.py
def load_csv(self, csv_path: str) -> pl.DataFrame:
    """Load the CSV file into a polars DataFrame.

    Args:
        csv_path (str): Path to the CSV file to load.

    Returns:
        pl.DataFrame: Polars DataFrame containing the loaded CSV data.
    """
    return pl.read_csv(csv_path)

read_csv_header

read_csv_header(csv_path: str) -> list

Get the column names from the header of the CSV file.

Parameters:

  • csv_path (str) –

    Path to the CSV file to read headers from.

Returns:

  • list ( list ) –

    List of column names from the CSV header.

Source code in src/stimulus/data/data_handlers.py
def read_csv_header(self, csv_path: str) -> list:
    """Get the column names from the header of the CSV file.

    Args:
        csv_path (str): Path to the CSV file to read headers from.

    Returns:
        list: List of column names from the CSV header.
    """
    with open(csv_path) as f:
        return f.readline().strip().split(",")

save

save(path: str) -> None

Saves the data to a csv file.

Source code in src/stimulus/data/data_handlers.py
def save(self, path: str) -> None:
    """Saves the data to a csv file."""
    self.data.write_csv(path)

select_columns

select_columns(columns: list) -> dict

Select specific columns from the DataFrame and return as a dictionary.

Parameters:

  • columns (list) –

    List of column names to select.

Returns:

  • dict ( dict ) –

    A dictionary where keys are column names and values are lists containing the column data.

Example

>>> handler = DatasetHandler(...)
>>> data_dict = handler.select_columns(["col1", "col2"])
>>> # Returns {'col1': [1, 2, 3], 'col2': [4, 5, 6]}

Source code in src/stimulus/data/data_handlers.py
def select_columns(self, columns: list) -> dict:
    """Select specific columns from the DataFrame and return as a dictionary.

    Args:
        columns (list): List of column names to select.

    Returns:
        dict: A dictionary where keys are column names and values are lists containing the column data.

    Example:
        >>> handler = DatasetHandler(...)
        >>> data_dict = handler.select_columns(["col1", "col2"])
        >>> # Returns {'col1': [1, 2, 3], 'col2': [4, 5, 6]}
    """
    df = self.data.select(columns)
    return {col: df[col].to_list() for col in columns}

DatasetLoader

DatasetLoader(
    config_path: str,
    csv_path: str,
    encoder_loader: EncoderLoader,
    split: Union[int, None] = None,
)

Bases: DatasetHandler

Class for loading a dataset and passing it to the deep learning model.

Methods:

  • get_all_items

    Get the full dataset as three separate dictionaries for inputs, labels and metadata.

  • get_all_items_and_length

    Get the full dataset as three separate dictionaries for inputs, labels and metadata, and the length of the data.

  • load_csv

    Load the CSV file into a polars DataFrame.

  • load_csv_per_split

    Load the part of the CSV file that has the specified split value.

  • read_csv_header

    Get the column names from the header of the CSV file.

  • save

    Saves the data to a csv file.

  • select_columns

    Select specific columns from the DataFrame and return as a dictionary.

Source code in src/stimulus/data/data_handlers.py
def __init__(
    self,
    config_path: str,
    csv_path: str,
    encoder_loader: loaders.EncoderLoader,
    split: Union[int, None] = None,
) -> None:
    """Initialize the DatasetLoader."""
    super().__init__(config_path, csv_path)
    self.encoder_manager = EncodeManager(encoder_loader)
    self.data = self.load_csv_per_split(csv_path, split) if split is not None else self.load_csv(csv_path)

get_all_items

get_all_items() -> tuple[dict, dict, dict]

Get the full dataset as three separate dictionaries for inputs, labels and metadata.

Returns:

  • tuple[dict, dict, dict]

    tuple[dict, dict, dict]: Three dictionaries containing: - Input dictionary mapping input column names to encoded input data - Label dictionary mapping label column names to encoded label data - Meta dictionary mapping meta column names to meta data

Example

>>> handler = DatasetHandler(...)
>>> input_dict, label_dict, meta_dict = handler.get_all_items()
>>> print(input_dict.keys())
dict_keys(['age', 'fare'])
>>> print(label_dict.keys())
dict_keys(['survived'])
>>> print(meta_dict.keys())
dict_keys(['passenger_id'])

Source code in src/stimulus/data/data_handlers.py
def get_all_items(self) -> tuple[dict, dict, dict]:
    """Get the full dataset as three separate dictionaries for inputs, labels and metadata.

    Returns:
        tuple[dict, dict, dict]: Three dictionaries containing:
            - Input dictionary mapping input column names to encoded input data
            - Label dictionary mapping label column names to encoded label data
            - Meta dictionary mapping meta column names to meta data

    Example:
        >>> handler = DatasetHandler(...)
        >>> input_dict, label_dict, meta_dict = handler.get_dataset()
        >>> print(input_dict.keys())
        dict_keys(['age', 'fare'])
        >>> print(label_dict.keys())
        dict_keys(['survived'])
        >>> print(meta_dict.keys())
        dict_keys(['passenger_id'])
    """
    input_columns, label_columns, meta_columns = (
        self.dataset_manager.column_categories["input"],
        self.dataset_manager.column_categories["label"],
        self.dataset_manager.column_categories["meta"],
    )
    input_data = self.encoder_manager.encode_dataframe(self.data[input_columns])
    label_data = self.encoder_manager.encode_dataframe(self.data[label_columns])
    meta_data = {key: self.data[key].to_list() for key in meta_columns}
    return input_data, label_data, meta_data

get_all_items_and_length

get_all_items_and_length() -> (
    tuple[tuple[dict, dict, dict], int]
)

Get the full dataset as three separate dictionaries for inputs, labels and metadata, and the length of the data.

Source code in src/stimulus/data/data_handlers.py
def get_all_items_and_length(self) -> tuple[tuple[dict, dict, dict], int]:
    """Get the full dataset as three separate dictionaries for inputs, labels and metadata, and the length of the data."""
    return self.get_all_items(), len(self.data)

load_csv

load_csv(csv_path: str) -> DataFrame

Load the CSV file into a polars DataFrame.

Parameters:

  • csv_path (str) –

    Path to the CSV file to load.

Returns:

  • DataFrame

    pl.DataFrame: Polars DataFrame containing the loaded CSV data.

Source code in src/stimulus/data/data_handlers.py
def load_csv(self, csv_path: str) -> pl.DataFrame:
    """Load the CSV file into a polars DataFrame.

    Args:
        csv_path (str): Path to the CSV file to load.

    Returns:
        pl.DataFrame: Polars DataFrame containing the loaded CSV data.
    """
    return pl.read_csv(csv_path)

load_csv_per_split

load_csv_per_split(csv_path: str, split: int) -> DataFrame

Load the part of the CSV file that has the specified split value.

Split is an integer: 0 for train, 1 for validation, 2 for test. It is read from the column with category split. An example column name is split:split:int.

Note that this method exists so that, depending on the training, validation, or test scenario, only the relevant portion of the data is loaded.
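
For example (a sketch; encoder_loader is assumed to be an already configured loaders.EncoderLoader, and the CSV must already contain a split column):

from stimulus.data import loaders
from stimulus.data.data_handlers import DatasetLoader


def load_validation_split(
    config_path: str,
    csv_path: str,
    encoder_loader: loaders.EncoderLoader,
) -> DatasetLoader:
    # split=1 keeps only the validation rows; 0 would keep train, 2 test.
    return DatasetLoader(config_path, csv_path, encoder_loader, split=1)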

Source code in src/stimulus/data/data_handlers.py
def load_csv_per_split(self, csv_path: str, split: int) -> pl.DataFrame:
    """Load the part of csv file that has the specified split value.

    Split is a number that for 0 is train, 1 is validation, 2 is test.
    This is accessed through the column with category `split`. Example column name could be `split:split:int`.

    NOTE that the aim of having this function is that depending on the training, validation and test scenarios,
    we are gonna load only the relevant data for it.
    """
    if "split" not in self.columns:
        raise ValueError("The category split is not present in the csv file")
    if split not in [0, 1, 2]:
        raise ValueError(f"The split value should be 0, 1 or 2. The specified split value is {split}")
    return pl.scan_csv(csv_path).filter(pl.col("split") == split).collect()

read_csv_header

read_csv_header(csv_path: str) -> list

Get the column names from the header of the CSV file.

Parameters:

  • csv_path (str) –

    Path to the CSV file to read headers from.

Returns:

  • list ( list ) –

    List of column names from the CSV header.

Source code in src/stimulus/data/data_handlers.py
def read_csv_header(self, csv_path: str) -> list:
    """Get the column names from the header of the CSV file.

    Args:
        csv_path (str): Path to the CSV file to read headers from.

    Returns:
        list: List of column names from the CSV header.
    """
    with open(csv_path) as f:
        return f.readline().strip().split(",")

save

save(path: str) -> None

Saves the data to a csv file.

Source code in src/stimulus/data/data_handlers.py
def save(self, path: str) -> None:
    """Saves the data to a csv file."""
    self.data.write_csv(path)

select_columns

select_columns(columns: list) -> dict

Select specific columns from the DataFrame and return as a dictionary.

Parameters:

  • columns (list) –

    List of column names to select.

Returns:

  • dict ( dict ) –

    A dictionary where keys are column names and values are lists containing the column data.

Example

>>> handler = DatasetHandler(...)
>>> data_dict = handler.select_columns(["col1", "col2"])
>>> # Returns {'col1': [1, 2, 3], 'col2': [4, 5, 6]}

Source code in src/stimulus/data/data_handlers.py
def select_columns(self, columns: list) -> dict:
    """Select specific columns from the DataFrame and return as a dictionary.

    Args:
        columns (list): List of column names to select.

    Returns:
        dict: A dictionary where keys are column names and values are lists containing the column data.

    Example:
        >>> handler = DatasetHandler(...)
        >>> data_dict = handler.select_columns(["col1", "col2"])
        >>> # Returns {'col1': [1, 2, 3], 'col2': [4, 5, 6]}
    """
    df = self.data.select(columns)
    return {col: df[col].to_list() for col in columns}

DatasetManager

DatasetManager(config_path: str)

Class for managing the dataset.

This class handles loading and organizing dataset configuration from YAML files. It manages column categorization into input, label and meta types based on the config.

Attributes:

  • config (dict) –

    The loaded configuration dictionary from YAML

  • column_categories (dict) –

    Dictionary mapping column types to lists of column names

Methods:

  • _load_config

    Loads the config from a YAML file.

  • categorize_columns_by_type

    Organizes the columns into input, label, meta based on the config.

Source code in src/stimulus/data/data_handlers.py
def __init__(
    self,
    config_path: str,
) -> None:
    """Initialize the DatasetManager."""
    self.config = self._load_config(config_path)
    self.column_categories = self.categorize_columns_by_type()

categorize_columns_by_type

categorize_columns_by_type() -> dict

Organizes columns from config into input, label, and meta categories.

Reads the column definitions from the config and sorts them into categories based on their column_type field.

Returns:

  • dict ( dict ) –

    Dictionary containing lists of column names for each category:

        {
            "input": ["col1", "col2"],  # Input columns
            "label": ["target"],        # Label/output columns
            "meta": ["id"]              # Metadata columns
        }

Example

>>> manager = DatasetManager("config.yaml")
>>> categories = manager.categorize_columns_by_type()
>>> print(categories)
{
    'input': ['hello', 'bonjour'],
    'label': ['ciao'],
    'meta': ["id"]
}

Source code in src/stimulus/data/data_handlers.py
def categorize_columns_by_type(self) -> dict:
    """Organizes columns from config into input, label, and meta categories.

    Reads the column definitions from the config and sorts them into categories
    based on their column_type field.

    Returns:
        dict: Dictionary containing lists of column names for each category:
            {
                "input": ["col1", "col2"],  # Input columns
                "label": ["target"],        # Label/output columns
                "meta": ["id"]     # Metadata columns
            }

    Example:
        >>> manager = DatasetManager("config.yaml")
        >>> categories = manager.categorize_columns_by_type()
        >>> print(categories)
        {
            'input': ['hello', 'bonjour'],
            'label': ['ciao'],
            'meta': ["id"]
        }
    """
    input_columns = []
    label_columns = []
    meta_columns = []
    for column in self.config.columns:
        if column.column_type == "input":
            input_columns.append(column.column_name)
        elif column.column_type == "label":
            label_columns.append(column.column_name)
        elif column.column_type == "meta":
            meta_columns.append(column.column_name)

    return {"input": input_columns, "label": label_columns, "meta": meta_columns}

get_split_columns

get_split_columns() -> list[str]

Get the columns that are used for splitting.

Source code in src/stimulus/data/data_handlers.py
def get_split_columns(self) -> list[str]:
    """Get the columns that are used for splitting."""
    return self.config.split.split_input_columns

get_transform_logic

get_transform_logic() -> dict

Get the transformation logic.

Returns a dictionary in the following structure:

    {
        "transformation_name": str,
        "transformations": list[tuple[str, str, dict]]
    }
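
For illustration only (the column and transformation names below are hypothetical), a returned value might look like:

{
    "transformation_name": "noise_and_augmentation",
    "transformations": [
        # (column_name, transformation.name, transformation.params)
        ("age", "GaussianNoise", {"std": 0.1}),
        ("fare", "GaussianNoise", {"std": 1.0}),
    ],
}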

Source code in src/stimulus/data/data_handlers.py
def get_transform_logic(self) -> dict:
    """Get the transformation logic.

    Returns a dictionary in the following structure :
    {
        "transformation_name": str,
        "transformations": list[tuple[str, str, dict]]
    }
    """
    transformation_logic = {
        "transformation_name": self.config.transforms.transformation_name,
        "transformations": [],
    }
    for column in self.config.transforms.columns:
        for transformation in column.transformations:
            transformation_logic["transformations"].append(
                (column.column_name, transformation.name, transformation.params),
            )
    return transformation_logic

DatasetProcessor

DatasetProcessor(config_path: str, csv_path: str)

Bases: DatasetHandler

Class for loading a dataset, applying transformations and splitting.

Methods:

  • add_split

    Add a column specifying the train, validation, test splits of the data.

  • apply_transformation_group

    Apply the transformation group to the data.

  • load_csv

    Load the CSV file into a polars DataFrame.

  • read_csv_header

    Get the column names from the header of the CSV file.

  • save

    Saves the data to a csv file.

  • select_columns

    Select specific columns from the DataFrame and return as a dictionary.

  • shuffle_labels

    Shuffles the labels in the data.

Source code in src/stimulus/data/data_handlers.py
def __init__(self, config_path: str, csv_path: str) -> None:
    """Initialize the DatasetProcessor."""
    super().__init__(config_path, csv_path)

add_split

add_split(
    split_manager: SplitManager, *, force: bool = False
) -> None

Add a column specifying the train, validation, test splits of the data.

An exception is raised if the split column is already present in the CSV file. This behaviour can be overridden by setting force=True.

Parameters:

  • split_manager (SplitManager) –

    Manager for handling dataset splitting

  • force (bool, default: False ) –

    If True, the split column present in the csv file will be overwritten.
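
A short sketch (assuming processor is an existing DatasetProcessor and split_loader an already configured loaders.SplitLoader):

from stimulus.data.data_handlers import SplitManager

# Raises a ValueError if the CSV already contains a "split" column.
processor.add_split(SplitManager(split_loader))

# Recomputes and overwrites an existing "split" column.
processor.add_split(SplitManager(split_loader), force=True)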

Source code in src/stimulus/data/data_handlers.py
def add_split(self, split_manager: SplitManager, *, force: bool = False) -> None:
    """Add a column specifying the train, validation, test splits of the data.

    An error exception is raised if the split column is already present in the csv file. This behaviour can be overriden by setting force=True.

    Args:
        split_manager (SplitManager): Manager for handling dataset splitting
        force (bool): If True, the split column present in the csv file will be overwritten.
    """
    if ("split" in self.columns) and (not force):
        raise ValueError(
            "The category split is already present in the csv file. If you want to still use this function, set force=True",
        )
    # get relevant split columns from the dataset_manager
    split_columns = self.dataset_manager.get_split_columns()
    split_input_data = self.select_columns(split_columns)

    # get the split indices
    train, validation, test = split_manager.get_split_indices(split_input_data)

    # add the split column to the data
    split_column = np.full(len(self.data), -1).astype(int)
    split_column[train] = 0
    split_column[validation] = 1
    split_column[test] = 2
    self.data = self.data.with_columns(pl.Series("split", split_column))

    if "split" not in self.columns:
        self.columns.append("split")

apply_transformation_group

apply_transformation_group(
    transform_manager: TransformManager,
) -> None

Apply the transformation group to the data.

Source code in src/stimulus/data/data_handlers.py
def apply_transformation_group(self, transform_manager: TransformManager) -> None:
    """Apply the transformation group to the data."""
    for column_name, transform_name, _params in self.dataset_manager.get_transform_logic()["transformations"]:
        transformed_data, add_row = transform_manager.transform_column(
            column_name,
            transform_name,
            self.data[column_name],
        )
        if add_row:
            new_rows = self.data.with_columns(pl.Series(column_name, transformed_data))
            self.data = self.data.vstack(new_rows)
        else:
            self.data = self.data.with_columns(pl.Series(column_name, transformed_data))

load_csv

load_csv(csv_path: str) -> DataFrame

Load the CSV file into a polars DataFrame.

Parameters:

  • csv_path (str) –

    Path to the CSV file to load.

Returns:

  • DataFrame

    pl.DataFrame: Polars DataFrame containing the loaded CSV data.

Source code in src/stimulus/data/data_handlers.py
def load_csv(self, csv_path: str) -> pl.DataFrame:
    """Load the CSV file into a polars DataFrame.

    Args:
        csv_path (str): Path to the CSV file to load.

    Returns:
        pl.DataFrame: Polars DataFrame containing the loaded CSV data.
    """
    return pl.read_csv(csv_path)

read_csv_header

read_csv_header(csv_path: str) -> list

Get the column names from the header of the CSV file.

Parameters:

  • csv_path (str) –

    Path to the CSV file to read headers from.

Returns:

  • list ( list ) –

    List of column names from the CSV header.

Source code in src/stimulus/data/data_handlers.py
def read_csv_header(self, csv_path: str) -> list:
    """Get the column names from the header of the CSV file.

    Args:
        csv_path (str): Path to the CSV file to read headers from.

    Returns:
        list: List of column names from the CSV header.
    """
    with open(csv_path) as f:
        return f.readline().strip().split(",")

save

save(path: str) -> None

Saves the data to a csv file.

Source code in src/stimulus/data/data_handlers.py
def save(self, path: str) -> None:
    """Saves the data to a csv file."""
    self.data.write_csv(path)

select_columns

select_columns(columns: list) -> dict

Select specific columns from the DataFrame and return as a dictionary.

Parameters:

  • columns (list) –

    List of column names to select.

Returns:

  • dict ( dict ) –

    A dictionary where keys are column names and values are lists containing the column data.

Example

>>> handler = DatasetHandler(...)
>>> data_dict = handler.select_columns(["col1", "col2"])
>>> # Returns {'col1': [1, 2, 3], 'col2': [4, 5, 6]}

Source code in src/stimulus/data/data_handlers.py
def select_columns(self, columns: list) -> dict:
    """Select specific columns from the DataFrame and return as a dictionary.

    Args:
        columns (list): List of column names to select.

    Returns:
        dict: A dictionary where keys are column names and values are lists containing the column data.

    Example:
        >>> handler = DatasetHandler(...)
        >>> data_dict = handler.select_columns(["col1", "col2"])
        >>> # Returns {'col1': [1, 2, 3], 'col2': [4, 5, 6]}
    """
    df = self.data.select(columns)
    return {col: df[col].to_list() for col in columns}

shuffle_labels

shuffle_labels(seed: Optional[float] = None) -> None

Shuffles the labels in the data.

Source code in src/stimulus/data/data_handlers.py
def shuffle_labels(self, seed: Optional[float] = None) -> None:
    """Shuffles the labels in the data."""
    # set the np seed
    np.random.seed(seed)

    label_keys = self.dataset_manager.column_categories["label"]
    for key in label_keys:
        self.data = self.data.with_columns(pl.Series(key, np.random.permutation(list(self.data[key]))))

EncodeManager

EncodeManager(encoder_loader: EncoderLoader)

Manages the encoding of data columns using configured encoders.

This class handles encoding of data columns based on the encoders specified in the configuration. It uses an EncoderLoader to get the appropriate encoder for each column and applies the encoding.

Attributes:

  • encoder_loader (EncoderLoader) –

    Loader that provides encoders based on config.

Example

>>> encoder_loader = EncoderLoader(config)
>>> encode_manager = EncodeManager(encoder_loader)
>>> data = ["ACGT", "TGCA", "GCTA"]
>>> encoded = encode_manager.encode_column("dna_seq", data)
>>> print(encoded.shape)
torch.Size([3, 4, 4])  # 3 sequences, length 4, one-hot encoded

Parameters:

  • encoder_loader (EncoderLoader) –

    Loader that provides encoders based on configuration.

Methods:

  • encode_column

    Encodes a column of data using the configured encoder.

  • encode_columns

    Encodes multiple columns of data using the configured encoders.

  • encode_dataframe

    Encode the dataframe using the encoders.

Source code in src/stimulus/data/data_handlers.py
def __init__(
    self,
    encoder_loader: loaders.EncoderLoader,
) -> None:
    """Initialize the EncodeManager.

    Args:
        encoder_loader: Loader that provides encoders based on configuration.
    """
    self.encoder_loader = encoder_loader

encode_column

encode_column(
    column_name: str, column_data: list
) -> Tensor

Encodes a column of data using the configured encoder.

Gets the appropriate encoder for the column from the encoder_loader and uses it to encode all the data in the column.

Parameters:

  • column_name (str) –

    Name of the column to encode.

  • column_data (list) –

    List of data values from the column to encode.

Returns:

  • Tensor

    Encoded data as a torch.Tensor. The exact shape depends on the encoder used.

Example

>>> data = ["ACGT", "TGCA"]
>>> encoded = encode_manager.encode_column("dna_seq", data)
>>> print(encoded.shape)
torch.Size([2, 4, 4])  # 2 sequences, length 4, one-hot encoded

Source code in src/stimulus/data/data_handlers.py
def encode_column(self, column_name: str, column_data: list) -> torch.Tensor:
    """Encodes a column of data using the configured encoder.

    Gets the appropriate encoder for the column from the encoder_loader and uses it
    to encode all the data in the column.

    Args:
        column_name: Name of the column to encode.
        column_data: List of data values from the column to encode.

    Returns:
        Encoded data as a torch.Tensor. The exact shape depends on the encoder used.

    Example:
        >>> data = ["ACGT", "TGCA"]
        >>> encoded = encode_manager.encode_column("dna_seq", data)
        >>> print(encoded.shape)
        torch.Size([2, 4, 4])  # 2 sequences, length 4, one-hot encoded
    """
    encode_all_function = self.encoder_loader.get_function_encode_all(column_name)
    return encode_all_function(column_data)

encode_columns

encode_columns(column_data: dict) -> dict

Encodes multiple columns of data using the configured encoders.

Gets the appropriate encoder for each column from the encoder_loader and encodes all data values in those columns.

Parameters:

  • column_data (dict) –

    Dict mapping column names to lists of data values to encode.

Returns:

  • dict

    Dict mapping column names to their encoded tensors. The exact shape of each

  • dict

    tensor depends on the encoder used for that column.

Example

>>> data = {"dna_seq": ["ACGT", "TGCA"], "labels": ["1", "2"]}
>>> encoded = encode_manager.encode_columns(data)
>>> print(encoded["dna_seq"].shape)
torch.Size([2, 4, 4])  # 2 sequences, length 4, one-hot encoded

Source code in src/stimulus/data/data_handlers.py
def encode_columns(self, column_data: dict) -> dict:
    """Encodes multiple columns of data using the configured encoders.

    Gets the appropriate encoder for each column from the encoder_loader and encodes
    all data values in those columns.

    Args:
        column_data: Dict mapping column names to lists of data values to encode.

    Returns:
        Dict mapping column names to their encoded tensors. The exact shape of each
        tensor depends on the encoder used for that column.

    Example:
        >>> data = {"dna_seq": ["ACGT", "TGCA"], "labels": ["1", "2"]}
        >>> encoded = encode_manager.encode_columns(data)
        >>> print(encoded["dna_seq"].shape)
        torch.Size([2, 4, 4])  # 2 sequences, length 4, one-hot encoded
    """
    return {col: self.encode_column(col, values) for col, values in column_data.items()}

encode_dataframe

encode_dataframe(dataframe: DataFrame) -> dict[str, Tensor]

Encode the dataframe using the encoders.
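
A minimal sketch (encode_manager is assumed to be an EncodeManager built from an already configured loaders.EncoderLoader, and the column names are placeholders):

import polars as pl

df = pl.DataFrame({"age": [22.0, 38.0], "fare": [7.25, 71.28]})
encoded = encode_manager.encode_dataframe(df)
# encoded maps each column name to the tensor produced by that column's
# configured encoder, e.g. {"age": tensor(...), "fare": tensor(...)}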

Source code in src/stimulus/data/data_handlers.py
def encode_dataframe(self, dataframe: pl.DataFrame) -> dict[str, torch.Tensor]:
    """Encode the dataframe using the encoders."""
    return {col: self.encode_column(col, dataframe[col].to_list()) for col in dataframe.columns}

SplitManager

SplitManager(split_loader: SplitLoader)

Class for managing the splitting.

Methods:

Source code in src/stimulus/data/data_handlers.py
def __init__(
    self,
    split_loader: loaders.SplitLoader,
) -> None:
    """Initialize the SplitManager."""
    self.split_loader = split_loader

get_split_indices

get_split_indices(
    data: dict,
) -> tuple[ndarray, ndarray, ndarray]

Get the indices for train, validation, and test splits.

Source code in src/stimulus/data/data_handlers.py
def get_split_indices(self, data: dict) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Get the indices for train, validation, and test splits."""
    return self.split_loader.get_function_split()(data)

TransformManager

TransformManager(transform_loader: TransformLoader)

Class for managing the transformations.

Methods:

  • transform_column

    Transform a column of data using the specified transformation.

Source code in src/stimulus/data/data_handlers.py
def __init__(
    self,
    transform_loader: loaders.TransformLoader,
) -> None:
    """Initialize the TransformManager."""
    self.transform_loader = transform_loader

transform_column

transform_column(
    column_name: str, transform_name: str, column_data: list
) -> tuple[list, bool]

Transform a column of data using the specified transformation.

Parameters:

  • column_name (str) –

    The name of the column to transform.

  • transform_name (str) –

    The name of the transformation to use.

  • column_data (list) –

    The data to transform.

Returns:

  • list ( list ) –

    The transformed data.

  • bool ( bool ) –

    Whether the transformation added new rows to the data.

Source code in src/stimulus/data/data_handlers.py
def transform_column(self, column_name: str, transform_name: str, column_data: list) -> tuple[list, bool]:
    """Transform a column of data using the specified transformation.

    Args:
        column_name (str): The name of the column to transform.
        transform_name (str): The name of the transformation to use.
        column_data (list): The data to transform.

    Returns:
        list: The transformed data.
        bool: Whether the transformation added new rows to the data.
    """
    transformer = self.transform_loader.__getattribute__(column_name)[transform_name]
    return transformer.transform_all(column_data), transformer.add_row