Skip to content

transform_csv

CLI module for transforming CSV data files.

Functions:

load_transforms_from_config

load_transforms_from_config(
    data_config_path: str,
) -> dict[str, list[Any]]

Load the data config from a path.

Parameters:

  • data_config_path (str) –

    Path to the data config file.

Returns:

  • dict[str, list[Any]]

    A dictionary mapping column names to lists of transform objects.

Source code in src/stimulus/cli/transform_csv.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def load_transforms_from_config(data_config_path: str) -> dict[str, list[Any]]:
    """Read a YAML data config and build the per-column transform objects.

    Args:
        data_config_path: Path to the data config file.

    Returns:
        A dictionary mapping column names to lists of transform objects.
    """
    # Load the raw YAML, then validate/structure it through the config schema.
    with open(data_config_path) as config_file:
        raw_config = yaml.safe_load(config_file)
    config_obj = data_config_parser.IndividualTransformConfigDict(**raw_config)
    return data_config_parser.parse_individual_transform_config(config_obj)

main

main(
    data_csv: str, config_yaml: str, out_path: str
) -> None

Transform the data according to the configuration.

Parameters:

  • data_csv (str) –

    Path to input CSV file.

  • config_yaml (str) –

    Path to config YAML file.

  • out_path (str) –

    Path to output transformed CSV.

Source code in src/stimulus/cli/transform_csv.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
def main(data_csv: str, config_yaml: str, out_path: str) -> None:
    """Transform the data according to the configuration.

    Args:
        data_csv: Path to input CSV file.
        config_yaml: Path to config YAML file.
        out_path: Path to output transformed CSV.
    """

    def keep_complete_rows(example: dict) -> bool:
        # True when no value in the example is NaN/None (pandas semantics).
        return not any(pd.isna(field) for field in example.values())

    dataset = load_dataset_from_path(data_csv)
    dataset.set_format(type="numpy")

    # Build the transform objects declared in the config.
    transforms = load_transforms_from_config(config_yaml)
    logger.info("Transforms initialized successfully.")

    # Apply the transformations batch-wise across the dataset.
    dataset = dataset.map(
        transform_batch,
        batched=True,
        fn_kwargs={"transforms_config": transforms},
    )
    logger.debug(f"Dataset type: {type(dataset)}")

    # Drop rows that a 'remove_row' transform marked with NaN.
    dataset["train"] = dataset["train"].filter(keep_complete_rows)
    if "test" in dataset:
        dataset["test"] = dataset["test"].filter(keep_complete_rows)

    dataset.save_to_disk(out_path)

transform_batch

transform_batch(
    batch: LazyBatch,
    transforms_config: dict[str, list[Any]],
) -> dict[str, list]

Transform a batch of data.

This function applies a series of configured transformations to specified columns within a batch. It assumes that each transformation's transform_all method returns a list of the same length as its input.

For 'remove_row' transforms, np.nan is expected in the output list for removed items. The 'add_row' flag's effect on overall dataset structure (like row duplication) is handled outside this function, based on its output.

Parameters:

  • batch (LazyBatch) –

    The input batch of data (a Hugging Face LazyBatch).

  • transforms_config (dict[str, list[Any]]) –

    A dictionary where keys are column names and values are lists of transform objects to be applied to that column.

Returns:

  • dict[str, list]

    A dictionary representing the transformed batch, with all original columns
    present and modified columns updated according to the transforms.

Source code in src/stimulus/cli/transform_csv.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def transform_batch(
    batch: datasets.formatting.formatting.LazyBatch,
    transforms_config: dict[str, list[Any]],
) -> dict[str, list]:
    """Transform a batch of data.

    This function applies a series of configured transformations to specified columns
    within a batch. It assumes that each transformation's `transform_all` method
    returns a list of the same length as its input.

    For 'remove_row' transforms, `np.nan` is expected in the output list for removed items.
    For 'add_row' transforms, the batch is duplicated: untouched columns are repeated
    verbatim and the target column is extended with its transformed values, so every
    column ends up twice its original length.

    Args:
        batch: The input batch of data (a Hugging Face LazyBatch).
        transforms_config: A dictionary where keys are column names and values are
                           lists of transform objects to be applied to that column.

    Returns:
        A dictionary representing the transformed batch, with all original columns
        present and modified columns updated according to the transforms.
    """
    # Work on a plain dict copy of the (lazy) batch.
    result_dict = dict(batch)
    for column_name, list_of_transforms in transforms_config.items():
        if column_name not in batch:
            logger.warning(
                f"Column '{column_name}' specified in transforms_config was not found "
                f"in the batch columns (columns: {list(batch.keys())}). Skipping transforms for this column.",
            )
            continue

        for transform_obj in list_of_transforms:
            if transform_obj.add_row:
                # Duplicate the batch: original rows first, then the augmented rows.
                original_values = result_dict[column_name]
                processed_values = transform_obj.transform_all(original_values)
                for key, value in result_dict.items():
                    if key != column_name:
                        if isinstance(value, np.ndarray):
                            # BUG FIX: np.char.add(value, value) concatenated strings
                            # element-wise (and raised TypeError on numeric dtypes);
                            # row duplication requires concatenation along axis 0,
                            # matching the `value + value` list branch below.
                            result_dict[key] = np.concatenate([value, value])
                        else:
                            result_dict[key] = value + value
                    elif isinstance(value, np.ndarray):
                        # Same fix for the transformed column: append the processed
                        # rows after the original ones instead of string-adding.
                        result_dict[key] = np.concatenate([value, np.asarray(processed_values)])
                    else:
                        result_dict[key] = value + processed_values
            else:
                result_dict[column_name] = transform_obj.transform_all(result_dict[column_name])

    return result_dict