Skip to content

optuna_tune

Optuna tuning module.

Classes:

  • Objective

    Objective class for Optuna tuning.

Functions:

  • get_device

    Get the appropriate device (MPS, CUDA GPU, or CPU) for computation.

  • tune_loop

    Run the tuning loop.

Objective

Objective(
    model_class: Module,
    network_params: dict[
        str, TunableParameter | VariableList
    ],
    optimizer_params: dict[str, TunableParameter],
    data_params: dict[str, TunableParameter],
    loss_params: dict[str, TunableParameter],
    train_torch_dataset: Dataset,
    val_torch_dataset: Dataset,
    artifact_store: Any,
    max_batches: int = 1000,
    compute_objective_every_n_batches: int = 50,
    target_metric: str = "val_loss",
    device: device | None = None,
)

Objective class for Optuna tuning.

Parameters:

  • model_class (Module) –

    The model class to be tuned.

  • network_params (dict[str, TunableParameter | VariableList]) –

    The network parameters to be tuned.

  • optimizer_params (dict[str, TunableParameter]) –

    The optimizer parameters to be tuned.

  • data_params (dict[str, TunableParameter]) –

    The data parameters to be tuned.

  • loss_params (dict[str, TunableParameter]) –

    The loss parameters to be tuned.

  • train_torch_dataset (Dataset) –

    The training dataset.

  • val_torch_dataset (Dataset) –

    The validation dataset.

  • artifact_store (Any) –

    The artifact store in which model and optimizer checkpoints are saved.

  • max_batches (int, default: 1000 ) –

    The maximum number of batches to train.

  • compute_objective_every_n_batches (int, default: 50 ) –

    How often, in batches, the objective is computed.

  • target_metric (str, default: 'val_loss' ) –

    The target metric to optimize.

  • device (device | None, default: None ) –

    The device to run the training on.

Methods:

  • objective

    Compute the objective metric(s) for the tuning process.

  • save_checkpoint

    Save the model and optimizer state, upload them to the artifact store, and record the artifact ids on the trial.

Source code in src/stimulus/learner/optuna_tune.py (lines 22–67):
def __init__(
    self,
    model_class: torch.nn.Module,
    network_params: dict[str, model_schema.TunableParameter | model_schema.VariableList],
    optimizer_params: dict[str, model_schema.TunableParameter],
    data_params: dict[str, model_schema.TunableParameter],
    loss_params: dict[str, model_schema.TunableParameter],
    train_torch_dataset: torch.utils.data.Dataset,
    val_torch_dataset: torch.utils.data.Dataset,
    artifact_store: Any,
    max_batches: int = 1000,
    compute_objective_every_n_batches: int = 50,
    target_metric: str = "val_loss",
    device: torch.device | None = None,
):
    """Store everything a tuning run needs.

    Args:
        model_class: The model class to be tuned.
        network_params: Search space for the network parameters.
        optimizer_params: Search space for the optimizer parameters.
        data_params: Search space for the data parameters.
        loss_params: Search space for the loss parameters.
        train_torch_dataset: The training dataset.
        val_torch_dataset: The validation dataset.
        artifact_store: Artifact store that model and optimizer checkpoints are uploaded to.
        max_batches: The maximum number of batches to train.
        compute_objective_every_n_batches: How often, in batches, the objective is computed.
        target_metric: Name of the metric to optimize (defaults to "val_loss").
        device: Device to run training on; CPU is used when None.
    """
    # The model and the search spaces for each tunable component.
    self.model_class = model_class
    self.network_params = network_params
    self.optimizer_params = optimizer_params
    self.data_params = data_params
    self.loss_params = loss_params

    # Data used for fitting and for scoring.
    self.train_torch_dataset = train_torch_dataset
    self.val_torch_dataset = val_torch_dataset

    # Destination for uploaded checkpoints.
    self.artifact_store = artifact_store

    # Training budget and evaluation settings.
    self.max_batches = max_batches
    self.compute_objective_every_n_batches = compute_objective_every_n_batches
    self.target_metric = target_metric

    # Fall back to CPU when the caller did not choose a device.
    self.device = torch.device("cpu") if device is None else device

objective

objective(
    model: Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss_dict: dict[str, Module],
) -> dict[str, float]

Compute the objective metric(s) for the tuning process.

Source code in src/stimulus/learner/optuna_tune.py (lines 233–266):
def objective(
    self,
    model: torch.nn.Module,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    loss_dict: dict[str, torch.nn.Module],
) -> dict[str, float]:
    """Compute the objective metric(s) for the tuning process.

    ``model`` is wrapped in a ``PredictWrapper`` once for the validation
    loader and once for the training loader, and the same list of metrics is
    requested from each wrapper.

    Args:
        model: The model to evaluate.
        train_loader: Loader over the training data.
        val_loader: Loader over the validation data.
        loss_dict: Loss modules forwarded to ``PredictWrapper``.

    Returns:
        A flat mapping of metric values whose keys are the names returned by
        ``compute_metrics`` prefixed with ``"val_"`` or ``"train_"``;
        validation entries come first. Assuming ``compute_metrics`` returns one
        entry per requested name, this includes ``"val_loss"`` (the default
        ``target_metric``).
    """
    # TODO maybe we report only a subset of metrics, given certain criteria (eg. if classification or regression)
    requested = ["loss", "rocauc", "prauc", "mcc", "f1score", "precision", "recall", "spearmanr"]

    # Both wrappers are built before any metric is computed, validation first.
    evaluators = (
        ("val_", PredictWrapper(model, val_loader, loss_dict=loss_dict, device=self.device)),
        ("train_", PredictWrapper(model, train_loader, loss_dict=loss_dict, device=self.device)),
    )

    scores: dict[str, float] = {}
    for prefix, evaluator in evaluators:
        for name, value in evaluator.compute_metrics(requested).items():
            scores[prefix + name] = value
    return scores

save_checkpoint

save_checkpoint(
    trial: Trial,
    model_instance: Module,
    optimizer: Optimizer,
) -> None

Save the model and optimizer to the trial.

Source code in src/stimulus/learner/optuna_tune.py (lines 200–231):
def save_checkpoint(
    self,
    trial: optuna.Trial,
    model_instance: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
) -> None:
    """Save the model and optimizer state and upload them as Optuna artifacts.

    Both files are first written to the current working directory, named from
    the trial number plus a short random suffix. They are then uploaded to
    ``self.artifact_store`` (attached to the trial's study) and removed
    locally. The artifact ids and file names are recorded on ``trial`` as the
    user attributes ``model_id``, ``model_path``, ``optimizer_id`` and
    ``optimizer_path``.

    Args:
        trial: The trial the checkpoint belongs to.
        model_instance: The model whose weights are written with ``safe_save_model``.
        optimizer: The optimizer whose ``state_dict`` is written with ``torch.save``.
    """
    # Random suffix avoids filename collisions between repeated saves of the same trial.
    unique_id = str(uuid.uuid4())[:8]
    model_path = f"{trial.number}_{unique_id}_model.safetensors"
    optimizer_path = f"{trial.number}_{unique_id}_optimizer.pt"
    try:
        safe_save_model(model_instance, model_path)
        torch.save(optimizer.state_dict(), optimizer_path)
        artifact_id_model = optuna.artifacts.upload_artifact(
            artifact_store=self.artifact_store,
            file_path=model_path,
            study_or_trial=trial.study,
        )
        artifact_id_optimizer = optuna.artifacts.upload_artifact(
            artifact_store=self.artifact_store,
            file_path=optimizer_path,
            study_or_trial=trial.study,
        )
    finally:
        # Delete the local copies even if saving/uploading failed. Each file is
        # removed independently so a missing model file does not leave the
        # optimizer file behind (and vice versa).
        for path in (model_path, optimizer_path):
            try:
                os.remove(path)
            except FileNotFoundError:
                logger.info(f"File was already deleted: {path}, most likely due to pruning")
    trial.set_user_attr("model_id", artifact_id_model)
    trial.set_user_attr("model_path", model_path)
    trial.set_user_attr("optimizer_id", artifact_id_optimizer)
    trial.set_user_attr("optimizer_path", optimizer_path)

get_device

get_device() -> device

Get the appropriate device (MPS, CUDA GPU, or CPU) for computation.

Returns:

  • device

    torch.device: The selected computation device

Source code in src/stimulus/learner/optuna_tune.py (lines 269–298):
def get_device() -> torch.device:
    """Pick the computation device, preferring MPS, then CUDA, then CPU.

    MPS is chosen only if a tiny test tensor can actually be moved onto it.
    If that probe raises ``RuntimeError``, CPU is returned directly and CUDA
    is not checked.

    Returns:
        torch.device: The selected computation device
    """
    if torch.backends.mps.is_available():
        try:
            mps_device = torch.device("mps")
            # Moving a 1x1 tensor verifies the MPS backend really works here.
            probe = torch.ones((1, 1)).to(mps_device)
            del probe  # release the probe allocation right away
            logger.info("Using MPS (Metal Performance Shaders) device")
            return mps_device
        except RuntimeError as exc:
            logger.warning(f"MPS available but failed to initialize: {exc}")
            logger.warning("Falling back to CPU")
            return torch.device("cpu")

    if not torch.cuda.is_available():
        logger.info("Using CPU (GPU not available)")
        return torch.device("cpu")

    cuda_device = torch.device("cuda")
    gpu_name = torch.cuda.get_device_name(0)
    memory_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    logger.info(f"Using GPU: {gpu_name} with {memory_gb:.2f} GB memory")
    return cuda_device

tune_loop

tune_loop(
    objective: Objective,
    pruner: BasePruner,
    sampler: BaseSampler,
    n_trials: int,
    direction: str,
    storage: BaseStorage | None = None,
) -> Study

Run the tuning loop.

Parameters:

  • objective (Objective) –

    The objective function to optimize.

  • pruner (BasePruner) –

    The pruner to use.

  • sampler (BaseSampler) –

    The sampler to use.

  • n_trials (int) –

    The number of trials to run.

  • direction (str) –

    The optimization direction (e.g. "minimize" or "maximize").

  • storage (BaseStorage | None, default: None ) –

    The storage to use.

Returns:

  • Study

    The study object.

Source code in src/stimulus/learner/optuna_tune.py (lines 301–327):
def tune_loop(
    objective: Objective,
    pruner: optuna.pruners.BasePruner,
    sampler: optuna.samplers.BaseSampler,
    n_trials: int,
    direction: str,
    storage: optuna.storages.BaseStorage | None = None,
) -> optuna.Study:
    """Create an Optuna study and run ``n_trials`` trials of ``objective`` on it.

    Args:
        objective: The objective function to optimize.
        pruner: The pruner to use.
        sampler: The sampler to use.
        n_trials: The number of trials to run.
        direction: The optimization direction passed to ``optuna.create_study``.
        storage: The storage to use; when None, Optuna's default storage is used.

    Returns:
        The study object, after optimization has finished.
    """
    study_kwargs = {"direction": direction, "sampler": sampler, "pruner": pruner}
    # Forward storage only when provided so create_study keeps its own default otherwise.
    if storage is not None:
        study_kwargs["storage"] = storage
    study = optuna.create_study(**study_kwargs)
    study.optimize(objective, n_trials=n_trials)
    return study