predict

A module for making predictions with PyTorch models using DataLoaders.

Classes:

  • PredictWrapper

    A wrapper to predict the output of a model on a dataset loaded into a torch DataLoader.

PredictWrapper

PredictWrapper(
    model: Module,
    dataloader: DataLoader,
    loss_dict: Optional[dict[str, Any]] = None,
    device: device | None = None,
)

A wrapper to predict the output of a model on a dataset loaded into a torch DataLoader.

It also provides functionality to measure the model's performance.

Parameters:

  • model (Module) –

    The PyTorch model to make predictions with

  • dataloader (DataLoader) –

    DataLoader containing the evaluation data

  • loss_dict (Optional[dict[str, Any]], default: None ) –

    Optional dictionary of loss functions

  • device (device | None, default: None ) –

    The device to run the model on

Methods:

  • compute_loss –

    Compute the loss.

  • compute_metric –

    Wrapper to compute the performance metric.

  • compute_metrics –

    Wrapper to compute the performance metrics.

  • compute_other_metric –

    Compute the performance metric.

  • handle_predictions –

    Convert the model outputs from the forward pass into a dictionary of tensors keyed like y.

  • predict –

    Get the model predictions.

Source code in src/stimulus/learner/predict.py
def __init__(
    self,
    model: nn.Module,
    dataloader: DataLoader,
    loss_dict: Optional[dict[str, Any]] = None,
    device: torch.device | None = None,
) -> None:
    """Initialize the PredictWrapper.

    Args:
        model: The PyTorch model to make predictions with
        dataloader: DataLoader containing the evaluation data
        loss_dict: Optional dictionary of loss functions
        device: The device to run the model on
    """
    if device is None:
        self.device = torch.device("cpu")
    else:
        self.device = device

    try:
        self.model = model.to(self.device)
    except RuntimeError as e:
        if self.device.type in ["cuda", "mps"]:
            logger.warning(f"Failed to move model to {self.device.type.upper()}: {e}")
            logger.warning("Falling back to CPU")
            self.device = torch.device("cpu")
            self.model = model.to(self.device)
        else:
            raise

    self.dataloader = dataloader
    self.loss_dict = loss_dict

    try:
        self.model.eval()
    except RuntimeError as e:
        logger.warning("Not able to run model.eval: %s", str(e))
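
A minimal usage sketch, assuming a model whose forward accepts the DataLoader's input keys as keyword arguments and a DataLoader yielding (x, y, meta) batches of dictionaries, as this wrapper expects. DictDataset, DemoModel, and the "features"/"target" keys are hypothetical stand-ins, not part of the library.

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

from stimulus.learner.predict import PredictWrapper


class DictDataset(Dataset):
    """Toy dataset yielding (x, y, meta) with dict-of-tensor x and y."""

    def __init__(self, n: int = 32) -> None:
        self.inputs = torch.randn(n, 4)
        self.targets = torch.randn(n, 1)

    def __len__(self) -> int:
        return len(self.inputs)

    def __getitem__(self, idx: int):
        return {"features": self.inputs[idx]}, {"target": self.targets[idx]}, {}


class DemoModel(nn.Module):
    """Hypothetical model following the conventions PredictWrapper assumes."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.linear(features)

    def batch(self, x: dict, y: dict, loss_fn: nn.Module) -> tuple:
        # Convention assumed by compute_loss: return a tuple whose first
        # element is the loss tensor for this batch. The loss_fn keyword is
        # an assumption; the real contract is whatever the model defines.
        return (loss_fn(self(**x), y["target"]),)


loader = DataLoader(DictDataset(), batch_size=8)
wrapper = PredictWrapper(DemoModel(), loader, loss_dict={"loss_fn": nn.MSELoss()})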

compute_loss

compute_loss() -> float

Compute the loss.

The current implementation computes the loss for each batch and then averages the results. TODO: we could summarize the loss across batches differently, or there may sometimes even be more than one loss.

Source code in src/stimulus/learner/predict.py
def compute_loss(self) -> float:
    """Compute the loss.

    The current implementation computes the loss for each batch and then averages the results.
    TODO: we could summarize the loss across batches differently.
    Or there may sometimes even be more than one loss.
    """
    if self.loss_dict is None:
        raise ValueError("Loss function is not provided.")
    loss = 0.0
    with torch.no_grad():
        for x, y, _ in self.dataloader:
            try:
                # Move input tensors to the same device as the model
                device_x = {key: value.to(self.device) for key, value in x.items()}
                device_y = {key: value.to(self.device) for key, value in y.items()}
                # loss_dict is unpacked with **; alternatively the model's `batch` signature could accept **kwargs. To be decided; this approach reads as cleaner and more understandable.
                current_loss = self.model.batch(x=device_x, y=device_y, **self.loss_dict)[0]
            except RuntimeError as e:
                if ("CUDA out of memory" in str(e) and self.device.type == "cuda") or (
                    "MPS backend out of memory" in str(e) and self.device.type == "mps"
                ):
                    logger.warning(f"{self.device.type.upper()} out of memory during loss computation: {e}")
                    logger.warning("Falling back to CPU for this batch")
                    temp_device = torch.device("cpu")
                    # Use CPU for this batch
                    x_cpu = {key: value.to(temp_device) for key, value in x.items()}
                    y_cpu = {key: value.to(temp_device) for key, value in y.items()}
                    # Move model to CPU temporarily
                    model_on_cpu = self.model.to(temp_device)
                    current_loss = model_on_cpu.batch(x=x_cpu, y=y_cpu, **self.loss_dict)[0]
                    # Move model back to original device for next batches
                    try:
                        self.model = self.model.to(self.device)
                    except RuntimeError:
                        logger.warning(f"Failed to move model back to {self.device.type}. Staying on CPU.")
                        self.device = temp_device
                else:
                    raise

            loss += current_loss.item()
    return loss / len(self.dataloader)
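
Continuing the construction sketch above: compute_loss delegates each batch to the model's own batch method, unpacking loss_dict as keyword arguments, and returns the plain average of the per-batch losses.

avg_loss = wrapper.compute_loss()  # float: mean of the per-batch losses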

compute_metric

compute_metric(metric: str = 'loss') -> float

Wrapper to compute the performance metric.

Source code in src/stimulus/learner/predict.py
def compute_metric(self, metric: str = "loss") -> float:
    """Wrapper to compute the performance metric."""
    if metric == "loss":
        return self.compute_loss()
    return self.compute_other_metric(metric)

compute_metrics

compute_metrics(metrics: list[str]) -> dict[str, float]

Wrapper to compute the performance metrics.

Source code in src/stimulus/learner/predict.py
def compute_metrics(self, metrics: list[str]) -> dict[str, float]:
    """Wrapper to compute the performance metrics."""
    return {m: self.compute_metric(m) for m in metrics}
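
Continuing the sketch above: "loss" is routed to compute_loss, while any other metric name goes through compute_other_metric and must be one the Performance helper understands.

scores = wrapper.compute_metrics(["loss"])
# {"loss": <average loss as a float>}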

compute_other_metric

compute_other_metric(metric: str) -> float

Compute the performance metric.

TODO: currently we compute the average performance metric across the target y keys, but in the future we may want something different.

Source code in src/stimulus/learner/predict.py
def compute_other_metric(self, metric: str) -> float:
    """Compute the performance metric.

    # TODO: currently we compute the average performance metric across target y, but in the future we may want something different
    """
    if not hasattr(self, "predictions") or not hasattr(self, "labels"):
        predictions, labels = self.predict(return_labels=True)
        self.predictions = predictions
        self.labels = labels

    # Explicitly type the labels and predictions as dictionaries with str keys
    labels_dict: dict[str, Tensor] = self.labels if isinstance(self.labels, dict) else {}
    predictions_dict: dict[str, Tensor] = self.predictions if isinstance(self.predictions, dict) else {}

    return sum(
        Performance(labels=labels_dict[k], predictions=predictions_dict[k], metric=metric).val for k in labels_dict
    ) / len(labels_dict)
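
Note that the first metric call runs predict(return_labels=True) once and caches the results on self.predictions and self.labels, so later calls reuse them. The metric name below is a hypothetical placeholder; the supported names are defined by the Performance helper.

score = wrapper.compute_other_metric("some_metric")  # name is an assumption; must be supported by Performance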

handle_predictions

handle_predictions(
    predictions: Any, y: dict[str, Tensor]
) -> dict[str, Tensor]

Convert the model outputs from the forward pass into a dictionary of tensors keyed like y.

Source code in src/stimulus/learner/predict.py
def handle_predictions(self, predictions: Any, y: dict[str, Tensor]) -> dict[str, Tensor]:
    """Handle the model outputs from forward pass, into a dictionary of tensors, just like y."""
    if len(y) == 1:
        return {next(iter(y.keys())): predictions}
    return dict(zip(y.keys(), predictions))
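
Continuing the sketch above, an illustration of the two branches: with a single-target y the raw output tensor is keyed by y's only key, while with multiple targets the model output is assumed to be a sequence with one tensor per target, zipped with y's keys in order.

y_single = {"target": torch.zeros(8, 1)}
wrapper.handle_predictions(torch.ones(8, 1), y_single)
# -> {"target": tensor of ones with shape (8, 1)}

y_multi = {"a": torch.zeros(8), "b": torch.zeros(8)}
wrapper.handle_predictions((torch.ones(8), torch.ones(8)), y_multi)
# -> {"a": tensor of ones, "b": tensor of ones}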

predict

predict(*, return_labels: bool = False) -> Union[
    dict[str, Tensor],
    tuple[dict[str, Tensor], dict[str, Tensor]],
]

Get the model predictions.

It runs a forward pass on the model for each batch, collects the predictions, and concatenates them across batches. Since the returned current_predictions are tensors computed for a single batch, the final predictions are obtained by concatenating them.

At the end it returns predictions as a dictionary of tensors with the same keys as y.

If return_labels is True, the labels will be returned as well, also as a dictionary of tensors.

Parameters:

  • return_labels (bool, default: False ) –

    Whether to also return the labels

Returns:

  • Union[dict[str, Tensor], tuple[dict[str, Tensor], dict[str, Tensor]]] –

    Dictionary of predictions, and optionally the labels

Source code in src/stimulus/learner/predict.py
def predict(
    self,
    *,
    return_labels: bool = False,
) -> Union[dict[str, Tensor], tuple[dict[str, Tensor], dict[str, Tensor]]]:
    """Get the model predictions.

    It runs a forward pass on the model for each batch,
    collects the predictions, and concatenates them across batches.
    Since the returned `current_predictions` are tensors computed for a single batch,
    the final `predictions` are obtained by concatenating them.

    At the end it returns `predictions` as a dictionary of tensors with the same keys as `y`.

    If return_labels is True, the `labels` will be returned as well, also as a dictionary of tensors.

    Args:
        return_labels: Whether to also return the labels

    Returns:
        Dictionary of predictions, and optionally labels
    """
    # create empty dictionaries with the column names
    first_batch = next(iter(self.dataloader))
    keys = first_batch[1].keys()
    predictions: dict[str, list[Tensor]] = {k: [] for k in keys}
    labels: dict[str, list[Tensor]] = {k: [] for k in keys}

    # get the predictions (and labels) for each batch
    with torch.no_grad():
        for x, y, _ in self.dataloader:
            try:
                x_device = {key: value.to(self.device) for key, value in x.items()}
                current_predictions = self.model(**x_device).detach().cpu()
                current_predictions = self.handle_predictions(current_predictions, y)
            except RuntimeError as e:
                if ("CUDA out of memory" in str(e) and self.device.type == "cuda") or (
                    "MPS backend out of memory" in str(e) and self.device.type == "mps"
                ):
                    logger.warning(f"{self.device.type.upper()} out of memory during prediction: {e}")
                    logger.warning("Falling back to CPU for this batch")
                    temp_device = torch.device("cpu")
                    # Use CPU for this batch
                    x_cpu = {key: value.to(temp_device) for key, value in x.items()}
                    # Move model to CPU temporarily
                    model_on_cpu = self.model.to(temp_device)
                    current_predictions = model_on_cpu(**x_cpu).detach().cpu()
                    current_predictions = self.handle_predictions(current_predictions, y)
                    # Move model back to original device for next batches
                    try:
                        self.model = self.model.to(self.device)
                    except RuntimeError:
                        logger.warning(f"Failed to move model back to {self.device.type}. Staying on CPU.")
                        self.device = temp_device
                else:
                    raise

            for k in keys:
                # A batch may contain a single element, in which case torch.cat would fail on 0-d tensors; ensure_at_least_1d guards against this.
                predictions[k].append(ensure_at_least_1d(current_predictions[k]))
                if return_labels:
                    labels[k].append(ensure_at_least_1d(y[k]))

    # return the predictions (and labels) as a dictionary of tensors for the entire dataset.
    if not return_labels:
        return {k: torch.cat(v) for k, v in predictions.items()}
    return {k: torch.cat(v) for k, v in predictions.items()}, {k: torch.cat(v) for k, v in labels.items()}
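
Continuing the sketch above:

preds = wrapper.predict()
# {"target": Tensor of shape (32, 1)}, all batches concatenated

preds, labels = wrapper.predict(return_labels=True)
assert preds["target"].shape == labels["target"].shape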