
splitters

This module contains the splitter classes used to split data into train, validation, and test sets.

Classes:

AbstractSplitter

AbstractSplitter(seed: float = 42)

Bases: ABC

Abstract class for splitters.

A splitter splits the data into train, validation, and test sets.

Methods:

  • get_split_indexes

    calculates split indices for the data

  • distance

    calculates the distance between two elements of the data

Parameters:

  • seed (float, default: 42 ) –

    Random seed for reproducibility

Source code in src/stimulus/data/splitting/splitters.py
def __init__(self, seed: float = 42) -> None:
    """Initialize the splitter.

    Args:
        seed: Random seed for reproducibility
    """
    self.seed = seed

distance abstractmethod

distance(data_one: Any, data_two: Any) -> float

Calculates the distance between two elements of the data.

This is an abstract method that should be implemented by the child class.

Parameters:

  • data_one (Any) –

    the first data point

  • data_two (Any) –

    the second data point

Returns:

  • distance ( float ) –

    the distance between the two data points

Source code in src/stimulus/data/splitting/splitters.py
@abstractmethod
def distance(self, data_one: Any, data_two: Any) -> float:
    """Calculates the distance between two elements of the data.

    This is an abstract method that should be implemented by the child class.

    Args:
        data_one (Any): the first data point
        data_two (Any): the second data point

    Returns:
        distance (float): the distance between the two data points
    """
    raise NotImplementedError

get_split_indexes abstractmethod

get_split_indexes(data: dict) -> tuple[list, list, list]

Splits the data. Always returns indices mapping to the original list.

This is an abstract method that should be implemented by the child class.

Parameters:

  • data (dict) –

    the data to be split

Returns:

  • split_indices ( tuple[list, list, list] ) –

    the indices for train, validation, and test sets

Source code in src/stimulus/data/splitting/splitters.py
@abstractmethod
def get_split_indexes(self, data: dict) -> tuple[list, list, list]:
    """Splits the data. Always returns indices mapping to the original list.

    This is an abstract method that should be implemented by the child class.

    Args:
        data (dict): the data to be split

    Returns:
        split_indices (tuple[list, list, list]): the indices for train, validation, and test sets
    """
    raise NotImplementedError
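
As a rough illustration of the contract described above, the sketch below defines a stand-in base class that mirrors the documented interface (it is not the library class itself) and a concrete child. The names `OrderedSplit`, the fixed 80/10/10 proportions, and the numeric `distance` are hypothetical choices made only for this example.

```python
from abc import ABC, abstractmethod
from typing import Any


class AbstractSplitter(ABC):
    """Stand-in mirroring the interface documented above (not the library class)."""

    def __init__(self, seed: float = 42) -> None:
        self.seed = seed

    @abstractmethod
    def get_split_indexes(self, data: dict) -> tuple[list, list, list]:
        raise NotImplementedError

    @abstractmethod
    def distance(self, data_one: Any, data_two: Any) -> float:
        raise NotImplementedError


class OrderedSplit(AbstractSplitter):
    """Hypothetical splitter: first 80% train, next 10% validation, rest test."""

    def get_split_indexes(self, data: dict) -> tuple[list, list, list]:
        # The row count is taken from the first column
        n_rows = len(next(iter(data.values())))
        train_end = int(0.8 * n_rows)
        validation_end = int(0.9 * n_rows)
        indices = list(range(n_rows))
        return indices[:train_end], indices[train_end:validation_end], indices[validation_end:]

    def distance(self, data_one: Any, data_two: Any) -> float:
        # Illustrative choice: absolute difference between two numeric values
        return float(abs(data_one - data_two))


splitter = OrderedSplit()
train, validation, test = splitter.get_split_indexes({"x": list(range(10))})
```

Because both methods are marked abstract, a child class that omits either one cannot be instantiated, so every concrete splitter has to provide both.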

RandomSplit

RandomSplit(split: Optional[list] = None, seed: int = 42)

Bases: AbstractSplitter

This splitter randomly splits the data.

Parameters:

  • split (Optional[list], default: None ) –

    List of proportions for train/val/test splits

  • seed (int, default: 42 ) –

    Random seed for reproducibility

Methods:

  • distance

    Calculate distance between two data points.

  • get_split_indexes

    Splits the data indices into train, validation, and test sets.

Source code in src/stimulus/data/splitting/splitters.py
def __init__(self, split: Optional[list] = None, seed: int = 42) -> None:
    """Initialize the random splitter.

    Args:
        split: List of proportions for train/val/test splits
        seed: Random seed for reproducibility
    """
    super().__init__(seed=seed)
    self.split = [0.7, 0.2, 0.1] if split is None else split
    # SPLIT_SIZE is a module-level constant; the error message below implies it is 3
    if len(self.split) != SPLIT_SIZE:
        raise ValueError(
            "The split argument should be a list with length 3 that contains the proportions for [train, validation, test] splits.",
        )

distance

distance(data_one: Any, data_two: Any) -> float

Calculate distance between two data points.

RandomSplit does not define a distance; calling this method raises NotImplementedError.

Parameters:

  • data_one (Any) –

    First data point

  • data_two (Any) –

    Second data point

Returns:

  • float

    Distance between the points

Source code in src/stimulus/data/splitting/splitters.py
def distance(self, data_one: Any, data_two: Any) -> float:
    """Calculate distance between two data points.

    Args:
        data_one: First data point
        data_two: Second data point

    Returns:
        Distance between the points
    """
    raise NotImplementedError

get_split_indexes

get_split_indexes(data: dict) -> tuple[list, list, list]

Splits the data indices into train, validation, and test sets.

The returned lists of indices can then be used to select the corresponding rows from the data.

Parameters:

  • data (dict) –

    Dictionary mapping column names to lists of data values.

Returns:

  • train ( list ) –

    The indices for the training set.

  • validation ( list ) –

    The indices for the validation set.

  • test ( list ) –

    The indices for the test set.

Raises:

  • ValueError

    If the sum of the split proportions is not 1.

  • ValueError

    If no data is provided.
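
The check on the sum of the proportions is done on a rounded value, since adding floats such as 0.7, 0.2, and 0.1 may not give exactly 1.0. The snippet below is a minimal sketch of that idea only; the variable names and the example of an invalid split are illustrative and not taken from the library.

```python
split = [0.7, 0.2, 0.1]
total = sum(split)

# The raw float sum may differ from 1.0 by a tiny rounding error
difference = abs(total - 1.0)

# Rounding to three decimals absorbs that error before the comparison
is_valid = round(total, 3) == 1.0

# A split that is genuinely short of 1 is still detected after rounding
bad_total = sum([0.7, 0.2, 0.05])
bad_is_valid = round(bad_total, 3) == 1.0
```

Rounding to three decimals means that proportions differing from 1 by less than about 0.0005 are accepted.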

Source code in src/stimulus/data/splitting/splitters.py
def get_split_indexes(
    self,
    data: dict,
) -> tuple[list, list, list]:
    """Splits the data indices into train, validation, and test sets.

    The returned lists of indices can then be used to select the corresponding rows from the data.

    Args:
        data (dict): Dictionary mapping column names to lists of data values.

    Returns:
        train (list): The indices for the training set.
        validation (list): The indices for the validation set.
        test (list): The indices for the test set.

    Raises:
        ValueError: If the sum of the split proportions is not 1.
        ValueError: If no data is provided.
    """
    # Use round to avoid errors due to floating point imprecisions
    if round(sum(self.split), 3) != 1.0:
        raise ValueError(f"The sum of the split proportions should be 1. Instead, it is {sum(self.split)}.")

    if not data:
        raise ValueError("No data provided for splitting")
    # Get length from first column's data list
    length_of_data = len(next(iter(data.values())))

    # Generate a list of indices and shuffle it
    indices = np.arange(length_of_data)
    np.random.seed(self.seed)
    np.random.shuffle(indices)

    # Calculate the sizes of the train, validation, and test sets
    train_size = int(self.split[0] * length_of_data)
    validation_size = int(self.split[1] * length_of_data)

    # Split the shuffled indices according to the calculated sizes
    train = indices[:train_size].tolist()
    validation = indices[train_size : train_size + validation_size].tolist()
    test = indices[train_size + validation_size :].tolist()

    return train, validation, test
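
To make the size arithmetic and the use of the returned indices concrete, here is a small sketch. The helper `random_split_indices` is a hypothetical standalone function that repeats the shuffle-and-slice steps shown above, and the column names `sequence` and `label` are made up for the example.

```python
import numpy as np


def random_split_indices(n_rows: int, split: list, seed: int = 42) -> tuple[list, list, list]:
    """Hypothetical helper mirroring the shuffle-and-slice steps shown above."""
    indices = np.arange(n_rows)
    np.random.seed(seed)
    np.random.shuffle(indices)
    train_size = int(split[0] * n_rows)
    validation_size = int(split[1] * n_rows)
    train = indices[:train_size].tolist()
    validation = indices[train_size : train_size + validation_size].tolist()
    test = indices[train_size + validation_size :].tolist()
    return train, validation, test


# Nine rows of illustrative data, stored as column name -> list of values
data = {"sequence": list("ACGTACGTA"), "label": [0, 1, 0, 1, 0, 1, 0, 1, 0]}
train, validation, test = random_split_indices(9, [0.7, 0.2, 0.1])

# int() truncates: 0.7 * 9 is about 6.3 -> 6 and 0.2 * 9 is about 1.8 -> 1,
# so the test set receives the remaining 2 rows
sizes = (len(train), len(validation), len(test))

# Select the training rows of every column using the returned indices
train_data = {name: [values[i] for i in train] for name, values in data.items()}
```

Because the sizes are truncated, any rows left over after the train and validation slices end up in the test set, so its share can be slightly larger than the requested proportion. Calling the helper again with the same seed gives the same three lists.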