predict

CLI module for model prediction on datasets.

Functions:

  • add_meta_info

    Add metadata columns to predictions/labels dictionary.

  • get_args

    Parse command line arguments.

  • get_batch_size

    Get batch size from model config.

  • get_meta_keys

    Extract metadata column keys.

  • load_model

    Load model with hyperparameters and weights.

  • main

    Run model prediction pipeline.

  • parse_y_keys

    Parse dictionary keys to match input data format.

  • run

    Execute model prediction pipeline.

add_meta_info

add_meta_info(
    data: DataFrame, y: dict[str, Any]
) -> dict[str, Any]

Add metadata columns to predictions/labels dictionary.

Parameters:

  • data (DataFrame) –

    Input DataFrame with metadata.

  • y (dict[str, Any]) –

    Dictionary of predictions/labels.

Returns:

  • dict[str, Any]

    Updated dictionary with metadata.

Source code in src/stimulus/cli/predict.py, lines 107–120
def add_meta_info(data: pl.DataFrame, y: dict[str, Any]) -> dict[str, Any]:
    """Add metadata columns to predictions/labels dictionary.

    Args:
        data: Input DataFrame with metadata.
        y: Dictionary of predictions/labels.

    Returns:
        Updated dictionary with metadata.
    """
    keys = get_meta_keys(data.columns)
    for key in keys:
        y[key] = data[key].to_list()
    return y
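
A minimal usage sketch (the column names below are hypothetical, following the "name:category:type" convention used throughout this module):

import polars as pl

data = pl.DataFrame({"hello:input:dna": ["ACGT", "TTTT"], "mouse:meta:str": ["m1", "m2"]})
y = {"binding:pred:float": [0.9, 0.1]}
y = add_meta_info(data, y)
# y now also carries {"mouse:meta:str": ["m1", "m2"]}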

get_args

get_args() -> Namespace

Parse command line arguments.

Returns:

  • Namespace

    Parsed command line arguments.

Source code in src/stimulus/cli/predict.py, lines 18–48
def get_args() -> argparse.Namespace:
    """Parse command line arguments.

    Returns:
        Parsed command line arguments.
    """
    parser = argparse.ArgumentParser(description="Predict model outputs on a dataset.")
    parser.add_argument("-m", "--model", type=str, required=True, metavar="FILE", help="Path to model .py file.")
    parser.add_argument("-w", "--weight", type=str, required=True, metavar="FILE", help="Path to model weights file.")
    parser.add_argument(
        "-mc",
        "--model_config",
        type=str,
        required=True,
        metavar="FILE",
        help="Path to tune config file with model hyperparameters.",
    )
    parser.add_argument(
        "-ec",
        "--experiment_config",
        type=str,
        required=True,
        metavar="FILE",
        help="Path to experiment config for data modification.",
    )
    parser.add_argument("-d", "--data", type=str, required=True, metavar="FILE", help="Path to input data.")
    parser.add_argument("-o", "--output", type=str, required=True, metavar="FILE", help="Path for output predictions.")
    parser.add_argument("--split", type=int, help="Data split to use (default: None).")
    parser.add_argument("--return_labels", action="store_true", help="Include labels with predictions.")

    return parser.parse_args()
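
A quick parsing sketch (the sys.argv patch stands in for a real command line; all paths are hypothetical):

import sys

sys.argv = [
    "predict",
    "-m", "model.py",
    "-w", "weights.pt",
    "-mc", "model_config.json",
    "-ec", "experiment_config.json",
    "-d", "test.csv",
    "-o", "predictions.csv",
    "--return_labels",
]
args = get_args()
# args.model == "model.py", args.return_labels is True, args.split is None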

get_batch_size

get_batch_size(mconfig: dict[str, Any]) -> int

Get batch size from model config.

Parameters:

  • mconfig (dict[str, Any]) –

    Model configuration dictionary.

Returns:

  • int

    Batch size to use for predictions.

Source code in src/stimulus/cli/predict.py, lines 68–80
def get_batch_size(mconfig: dict[str, Any]) -> int:
    """Get batch size from model config.

    Args:
        mconfig: Model configuration dictionary.

    Returns:
        Batch size to use for predictions.
    """
    default_batch_size = 256
    if "data_params" in mconfig and "batch_size" in mconfig["data_params"]:
        return mconfig["data_params"]["batch_size"]
    return default_batch_size
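
For example:

get_batch_size({"data_params": {"batch_size": 64}})  # 64
get_batch_size({"model_params": {"hidden": 8}})      # no data_params.batch_size, falls back to 256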

get_meta_keys

get_meta_keys(names: Sequence[str]) -> list[str]

Extract metadata column keys.

Parameters:

  • names (Sequence[str]) –

    List of column names.

Returns:

  • list[str]

    List of metadata column keys.

Source code in src/stimulus/cli/predict.py, lines 123–132
def get_meta_keys(names: Sequence[str]) -> list[str]:
    """Extract metadata column keys.

    Args:
        names: List of column names.

    Returns:
        List of metadata column keys.
    """
    return [name for name in names if name.split(":")[1] == "meta"]
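
For example, with hypothetical column names:

get_meta_keys(["hello:input:dna", "binding:label:float", "mouse:meta:str"])
# ["mouse:meta:str"]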

load_model

load_model(
    model_class: Any,
    weight_path: str,
    mconfig: dict[str, Any],
) -> Any

Load model with hyperparameters and weights.

Parameters:

  • model_class (Any) –

    Model class to instantiate.

  • weight_path (str) –

    Path to model weights.

  • mconfig (dict[str, Any]) –

    Model configuration dictionary.

Returns:

  • Any

    Loaded model instance.

Source code in src/stimulus/cli/predict.py, lines 51–65
def load_model(model_class: Any, weight_path: str, mconfig: dict[str, Any]) -> Any:
    """Load model with hyperparameters and weights.

    Args:
        model_class: Model class to instantiate.
        weight_path: Path to model weights.
        mconfig: Model configuration dictionary.

    Returns:
        Loaded model instance.
    """
    hyperparameters = mconfig["model_params"]
    model = model_class(**hyperparameters)
    model.load_state_dict(torch.load(weight_path))
    return model
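
A minimal sketch (TinyModel and the weight file are hypothetical; in practice the class comes from the user's model .py file):

import torch
from torch import nn

class TinyModel(nn.Module):
    def __init__(self, hidden: int) -> None:
        super().__init__()
        self.linear = nn.Linear(hidden, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

mconfig = {"model_params": {"hidden": 8}}
torch.save(TinyModel(hidden=8).state_dict(), "weights.pt")  # illustration only
model = load_model(TinyModel, "weights.pt", mconfig)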

main

main(
    model_path: str,
    weight_path: str,
    mconfig_path: str,
    econfig_path: str,
    data_path: str,
    output: str,
    *,
    return_labels: bool = False,
    split: int | None = None
) -> None

Run model prediction pipeline.

Parameters:

  • model_path (str) –

    Path to model file.

  • weight_path (str) –

    Path to model weights.

  • mconfig_path (str) –

    Path to model config.

  • econfig_path (str) –

    Path to experiment config.

  • data_path (str) –

    Path to input data.

  • output (str) –

    Path for output predictions.

  • return_labels (bool, default: False) –

    Whether to include labels.

  • split (int | None, default: None) –

    Data split to use.

Source code in src/stimulus/cli/predict.py, lines 135–188
def main(
    model_path: str,
    weight_path: str,
    mconfig_path: str,
    econfig_path: str,
    data_path: str,
    output: str,
    *,
    return_labels: bool = False,
    split: int | None = None,
) -> None:
    """Run model prediction pipeline.

    Args:
        model_path: Path to model file.
        weight_path: Path to model weights.
        mconfig_path: Path to model config.
        econfig_path: Path to experiment config.
        data_path: Path to input data.
        output: Path for output predictions.
        return_labels: Whether to include labels.
        split: Data split to use.
    """
    with open(mconfig_path) as in_json:
        mconfig = json.load(in_json)

    model_class = import_class_from_file(model_path)
    model = load_model(model_class, weight_path, mconfig)

    with open(econfig_path) as in_json:
        experiment_name = json.load(in_json)["experiment"]
    initialized_experiment_class = get_experiment(experiment_name)

    dataloader = DataLoader(
        TorchDataset(data_path, initialized_experiment_class, split=split),
        batch_size=get_batch_size(mconfig),
        shuffle=False,
    )

    predictor = PredictWrapper(model, dataloader)
    out = predictor.predict(return_labels=return_labels)
    y_pred, y_true = out if return_labels else (out, {})

    y_pred = {k: v.tolist() for k, v in y_pred.items()}
    y_true = {k: v.tolist() for k, v in y_true.items()}

    data = pl.read_csv(data_path)
    y_pred = parse_y_keys(y_pred, data, y_type="pred")
    y_true = parse_y_keys(y_true, data, y_type="label")

    y = {**y_pred, **y_true}
    y = add_meta_info(data, y)
    df = pl.from_dict(y)
    df.write_csv(output)
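
The pipeline can also be driven directly from Python (all paths below are hypothetical):

main(
    model_path="model.py",
    weight_path="weights.pt",
    mconfig_path="model_config.json",
    econfig_path="experiment_config.json",
    data_path="test.csv",
    output="predictions.csv",
    return_labels=True,
    split=2,
)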

parse_y_keys

parse_y_keys(
    y: dict[str, Any], data: DataFrame, y_type: str = "pred"
) -> dict[str, Any]

Parse dictionary keys to match input data format.

Parameters:

  • y (dict[str, Any]) –

    Dictionary of predictions or labels.

  • data (DataFrame) –

    Input DataFrame.

  • y_type (str, default: 'pred') –

    Type of values ('pred' or 'label').

Returns:

  • dict[str, Any]

    Dictionary with updated keys.

Source code in src/stimulus/cli/predict.py, lines 83–104
def parse_y_keys(y: dict[str, Any], data: pl.DataFrame, y_type: str = "pred") -> dict[str, Any]:
    """Parse dictionary keys to match input data format.

    Args:
        y: Dictionary of predictions or labels.
        data: Input DataFrame.
        y_type: Type of values ('pred' or 'label').

    Returns:
        Dictionary with updated keys.
    """
    if not y:
        return y

    parsed_y = {}
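    # Match each bare key (e.g. "binding") to a "name:category:type" column
    # of the input data, and rebuild it as "name:<y_type>:type".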
    for k1, v1 in y.items():
        for k2 in data.columns:
            if k1 == k2.split(":")[0]:
                new_key = f"{k1}:{y_type}:{k2.split(':')[2]}"
                parsed_y[new_key] = v1

    return parsed_y
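
For example, with a hypothetical "binding:label:float" column in the input data:

import polars as pl

data = pl.DataFrame({"hello:input:dna": ["ACGT"], "binding:label:float": [1.0]})
parse_y_keys({"binding": [0.87]}, data, y_type="pred")
# {"binding:pred:float": [0.87]}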

run

run() -> None

Execute model prediction pipeline.

Source code in src/stimulus/cli/predict.py, lines 191–203
def run() -> None:
    """Execute model prediction pipeline."""
    args = get_args()
    main(
        args.model,
        args.weight,
        args.model_config,
        args.experiment_config,
        args.data,
        args.output,
        return_labels=args.return_labels,
        split=args.split,
    )