Skip to content

tuning

CLI module for running Optuna hyperparameter tuning experiments.

Functions:

load_data_config_from_path

load_data_config_from_path(
    data_path: str, data_config_path: str, split: int
) -> Dataset

Load the data config from a path.

Parameters:

  • data_path (str) –

    Path to the input data file.

  • data_config_path (str) –

    Path to the data config file.

  • split (int) –

    Split index to use (0=train, 1=validation, 2=test).

Returns:

  • Dataset

    A TorchDataset with the configured data.

Source code in src/stimulus/cli/tuning.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def load_data_config_from_path(data_path: str, data_config_path: str, split: int) -> torch.utils.data.Dataset:
    """Build a torch dataset for one split from a data file and its YAML config.

    The YAML file at ``data_config_path`` is parsed into a
    ``SplitTransformDict``; the encoders and column groups derived from it
    are handed to a ``DatasetLoader`` that reads ``data_path``.

    Args:
        data_path: Path to the input data file (passed to the loader as ``csv_path``).
        data_config_path: Path to the data config file.
        split: Split index to use (0=train, 1=validation, 2=test).

    Returns:
        A TorchDataset with the configured data.
    """
    # Only the raw YAML read needs the file handle; validation happens after it is closed.
    with open(data_config_path) as config_file:
        raw_config = yaml.safe_load(config_file)
    split_transform_config = data_config_parser.SplitTransformDict(**raw_config)

    parsed_config = data_config_parser.parse_split_transform_config(split_transform_config)
    encoders, input_columns, label_columns, meta_columns = parsed_config

    dataset_loader = data_handlers.DatasetLoader(
        encoders=encoders,
        input_columns=input_columns,
        label_columns=label_columns,
        meta_columns=meta_columns,
        csv_path=data_path,
        split=split,
    )
    return data_handlers.TorchDataset(loader=dataset_loader)

tune

tune(
    data_path: str,
    model_path: str,
    data_config_path: str,
    model_config_path: str,
    optuna_results_dirpath: str = "./optuna_results",
    best_model_path: str = "best_model.safetensors",
    best_optimizer_path: str = "best_optimizer.pt",
) -> None

Run model hyperparameter tuning.

Parameters:

  • data_path (str) –

    Path to input data file.

  • model_path (str) –

    Path to model file.

  • data_config_path (str) –

    Path to data config file.

  • model_config_path (str) –

    Path to model config file.

  • optuna_results_dirpath (str, default: './optuna_results' ) –

    Directory for optuna results.

  • best_model_path (str, default: 'best_model.safetensors' ) –

    Path to write the best model to.

  • best_optimizer_path (str, default: 'best_optimizer.pt' ) –

    Path to write the best optimizer to.

Source code in src/stimulus/cli/tuning.py
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def tune(
    data_path: str,
    model_path: str,
    data_config_path: str,
    model_config_path: str,
    optuna_results_dirpath: str = "./optuna_results",
    best_model_path: str = "best_model.safetensors",
    best_optimizer_path: str = "best_optimizer.pt",
) -> None:
    """Run model hyperparameter tuning.

    Loads the train (split 0) and validation (split 1) datasets, the model
    class and the model config, runs the optuna tuning loop, then downloads
    the best trial's model and optimizer artifacts and moves them to
    ``best_model_path`` and ``best_optimizer_path``.

    Args:
        data_path: Path to input data file.
        model_path: Path to model file.
        data_config_path: Path to data config file.
        model_config_path: Path to model config file.
        optuna_results_dirpath: Directory for optuna results.
        best_model_path: Path to write the best model to.
        best_optimizer_path: Path to write the best optimizer to.

    Raises:
        KeyError: If the best trial lacks one of the expected user attributes
            (``model_id``, ``optimizer_id``, ``model_path``, ``optimizer_path``).
        FileNotFoundError: If a downloaded artifact file is missing when it
            is moved to its destination.
    """
    # Load train and validation datasets
    train_dataset = load_data_config_from_path(data_path, data_config_path, split=0)
    validation_dataset = load_data_config_from_path(data_path, data_config_path, split=1)

    # Load model class
    model_class = model_file_interface.import_class_from_file(model_path)

    # Load model config
    with open(model_config_path) as file:
        model_config_dict: dict[str, Any] = yaml.safe_load(file)
    model_config: model_schema.Model = model_schema.Model(**model_config_dict)

    # get the pruner
    pruner = model_config_parser.get_pruner(model_config.pruner)

    # get the sampler
    sampler = model_config_parser.get_sampler(model_config.sampler)

    # storage setups: the journal log lives in the results directory and
    # trial artifacts go into its "artifacts" subdirectory.
    base_path = optuna_results_dirpath
    artifact_path = optuna_results_dirpath + "/artifacts"
    os.makedirs(base_path, exist_ok=True)
    os.makedirs(artifact_path, exist_ok=True)
    artifact_store = optuna.artifacts.FileSystemArtifactStore(base_path=artifact_path)
    storage = optuna.storages.JournalStorage(
        optuna.storages.journal.JournalFileBackend(f"{base_path}/optuna_journal_storage.log"),
    )

    device = optuna_tune.get_device()

    objective = optuna_tune.Objective(
        model_class=model_class,
        network_params=model_config.network_params,
        optimizer_params=model_config.optimizer_params,
        data_params=model_config.data_params,
        loss_params=model_config.loss_params,
        train_torch_dataset=train_dataset,
        val_torch_dataset=validation_dataset,
        artifact_store=artifact_store,
        max_batches=model_config.max_batches,
        compute_objective_every_n_batches=model_config.compute_objective_every_n_batches,
        target_metric=model_config.objective.metric,
        device=device,
    )

    study = optuna_tune.tune_loop(
        objective=objective,
        storage=storage,
        sampler=sampler,
        pruner=pruner,
        n_trials=model_config.n_trials,
        direction=model_config.objective.direction,
    )

    # The best trial records both the artifact ids and the file paths the
    # artifacts are downloaded to.
    best_trial = study.best_trial
    best_model_artifact_id = best_trial.user_attrs["model_id"]
    best_optimizer_artifact_id = best_trial.user_attrs["optimizer_id"]
    best_model_file_path = best_trial.user_attrs["model_path"]
    best_optimizer_file_path = best_trial.user_attrs["optimizer_path"]

    optuna.artifacts.download_artifact(
        artifact_store=artifact_store,
        file_path=best_model_file_path,
        artifact_id=best_model_artifact_id,
    )
    optuna.artifacts.download_artifact(
        artifact_store=artifact_store,
        file_path=best_optimizer_file_path,
        artifact_id=best_optimizer_artifact_id,
    )

    # Create missing output directories BEFORE moving anything, so each file
    # is moved exactly once. Retrying both moves after a failure would break
    # when the first move had already succeeded (its source is gone), and
    # os.makedirs("") raises for bare filenames such as the default outputs,
    # hence the empty-dirname filter.
    output_dirs = [os.path.dirname(path) for path in (best_model_path, best_optimizer_path)]
    missing_dirs = [directory for directory in output_dirs if directory and not os.path.isdir(directory)]
    if missing_dirs:
        logger.info("Best model or optimizer file_path not found, creating output directories")
        for directory in missing_dirs:
            os.makedirs(directory, exist_ok=True)

    shutil.move(best_model_file_path, best_model_path)
    shutil.move(best_optimizer_file_path, best_optimizer_path)