
encode_csv

CLI module for encoding CSV data files.

Functions:

encode_batch

encode_batch(
    batch: LazyBatch, encoders_config: dict[str, Any]
) -> dict[str, list]

Encode a batch of data.

This function applies configured encoders to specified columns within a batch. Each encoder's batch_encode method is called to transform the column data.

Parameters:

  • batch (LazyBatch) –

    The input batch of data (a Hugging Face LazyBatch).

  • encoders_config (dict[str, Any]) –

    A dictionary where keys are column names and values are encoder objects to be applied to that column.

Returns:

  • dict[str, list]

    A dictionary representing the encoded batch, with all original columns
    present and encoded columns updated according to the encoders.

Source code in src/stimulus/cli/encode_csv.py
def encode_batch(
    batch: datasets.formatting.formatting.LazyBatch,
    encoders_config: dict[str, Any],
) -> dict[str, list]:
    """Encode a batch of data.

    This function applies configured encoders to specified columns within a batch.
    Each encoder's `batch_encode` method is called to transform the column data.

    Args:
        batch: The input batch of data (a Hugging Face LazyBatch).
        encoders_config: A dictionary where keys are column names and values are
                        encoder objects to be applied to that column.

    Returns:
        A dictionary representing the encoded batch, with all original columns
        present and encoded columns updated according to the encoders.
    """
    result_dict = dict(batch)

    for column_name, encoder in encoders_config.items():
        if column_name not in batch:
            logger.warning(
                f"Column '{column_name}' specified in encoders_config was not found "
                f"in the batch columns (columns: {list(batch.keys())}). Skipping encoding for this column.",
            )
            continue

        # Get the column data as numpy array
        column_data = np.array(batch[column_name])

        # Apply the encoder
        try:
            encoded_data = encoder.batch_encode(column_data)
            result_dict[column_name] = encoded_data.tolist() if isinstance(encoded_data, np.ndarray) else encoded_data
        except Exception:
            logger.exception(f"Failed to encode column '{column_name}'")
            raise

    return result_dict
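
Any object exposing a batch_encode method can act as an encoder here, and because encode_batch only needs dict-style access to the batch, a plain dictionary can stand in for a LazyBatch in a quick standalone check. The sketch below is illustrative only; UppercaseEncoder is a hypothetical encoder, not part of stimulus:

import numpy as np

class UppercaseEncoder:
    """Hypothetical encoder: upper-cases string values in a column."""

    def batch_encode(self, data: np.ndarray) -> np.ndarray:
        return np.char.upper(data.astype(str))

# A plain dict stands in for a Hugging Face LazyBatch.
batch = {"sequence": ["acgt", "ttag"], "label": [0, 1]}
encoded = encode_batch(batch, {"sequence": UppercaseEncoder()})
# encoded == {"sequence": ["ACGT", "TTAG"], "label": [0, 1]}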

load_encoders_from_config

load_encoders_from_config(
    data_config_path: str,
) -> dict[str, Any]

Load the encoders from the data config.

Parameters:

  • data_config_path (str) –

    Path to the data config file.

Returns:

  • dict[str, Any]

    A dictionary mapping column names to encoder instances.

Source code in src/stimulus/cli/encode_csv.py
def load_encoders_from_config(data_config_path: str) -> dict[str, Any]:
    """Load the encoders from the data config.

    Args:
        data_config_path: Path to the data config file.

    Returns:
        A dictionary mapping column names to encoder instances.
    """
    with open(data_config_path) as file:
        data_config_dict = yaml.safe_load(file)
        data_config_obj = data_config_parser.EncodingConfigDict(**data_config_dict)

    encoders, _input_columns, _label_columns, _meta_columns = data_config_parser.parse_encoding_config(
        data_config_obj,
    )

    # Return all encoders for all column types
    return encoders
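
A minimal usage sketch, assuming a valid encoding config file exists at the given (hypothetical) path; the YAML schema itself is defined by data_config_parser.EncodingConfigDict and is not reproduced here:

encoders = load_encoders_from_config("data_config.yaml")  # hypothetical path

# Keys are column names, values are ready-to-use encoder instances.
for column_name, encoder in encoders.items():
    print(column_name, type(encoder).__name__)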

main

main(
    data_path: str,
    config_yaml: str,
    out_path: str,
    num_proc: Optional[int] = None,
) -> None

Encode the data according to the configuration.

Parameters:

  • data_path (str) –

    Path to input data (CSV, parquet, or HuggingFace dataset directory).

  • config_yaml (str) –

    Path to config YAML file.

  • out_path (str) –

    Path to output encoded dataset directory.

  • num_proc (Optional[int], default: None) –

    Number of processes to use for encoding.

Source code in src/stimulus/cli/encode_csv.py
def main(data_path: str, config_yaml: str, out_path: str, num_proc: Optional[int] = None) -> None:
    """Encode the data according to the configuration.

    Args:
        data_path: Path to input data (CSV, parquet, or HuggingFace dataset directory).
        config_yaml: Path to config YAML file.
        out_path: Path to output encoded dataset directory.
        num_proc: Number of processes to use for encoding.
    """
    # Load the dataset
    dataset = load_dataset_from_path(data_path)

    # Set format to numpy for processing
    dataset.set_format(type="numpy")

    # Load encoders from config
    encoders = load_encoders_from_config(config_yaml)
    logger.info("Encoders initialized successfully.")
    logger.info(f"Loaded encoders for columns: {list(encoders.keys())}")

    # Identify and remove columns that aren't in the encoder configuration
    encoder_columns = set(encoders.keys())
    columns_to_remove = set()

    for split_name, split_dataset in dataset.items():
        dataset_columns = set(split_dataset.column_names)
        split_columns_to_remove = dataset_columns - encoder_columns
        columns_to_remove.update(split_columns_to_remove)
        logger.info(f"Split '{split_name}' columns to remove: {list(split_columns_to_remove)}")

    if columns_to_remove:
        logger.info(f"Removing columns not in encoder configuration: {list(columns_to_remove)}")
        dataset = dataset.remove_columns(list(columns_to_remove))

    # Apply the encoders to the data
    dataset = dataset.map(
        encode_batch,
        batched=True,
        fn_kwargs={"encoders_config": encoders},
        num_proc=num_proc,
    )

    logger.info(f"Dataset encoded successfully. Saving to: {out_path}")

    # Save the encoded dataset to disk
    dataset.save_to_disk(out_path)

    logger.info(f"Encoded dataset saved to: {out_path}")