21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
def _move_creating_parent(src: str, dst: str) -> None:
    """Move ``src`` to ``dst``, creating ``dst``'s parent directory if needed.

    Args:
        src: Existing file to move.
        dst: Destination path; its parent directory is created when missing.
    """
    parent = os.path.dirname(dst)
    # dirname() is "" for a bare filename (e.g. the default output paths of
    # ``tune``); os.makedirs("") would raise FileNotFoundError, so skip it.
    if parent and not os.path.isdir(parent):
        logger.info("Creating output directory %s", parent)
        os.makedirs(parent, exist_ok=True)
    shutil.move(src, dst)


def tune(
    data_path: str,
    model_path: str,
    model_config_path: str,
    optuna_results_dirpath: str = "./optuna_results",
    best_model_path: str = "best_model.safetensors",
    best_optimizer_path: str = "best_optimizer.pt",
    best_config_path: str = "best_config.json",
    force_device: Optional[str] = None,
) -> None:
    """Run model hyperparameter tuning.

    Loads the ``"train"`` and ``"test"`` splits of a dataset dict stored on
    disk, builds an Optuna objective from the model class and YAML config,
    runs the study with journal-file storage, then downloads the best trial's
    model, optimizer and suggested-config artifacts and moves them to the
    requested output paths.

    Args:
        data_path: Path to the dataset dict loaded with
            ``datasets.load_from_disk``; must contain ``"train"`` and
            ``"test"`` splits (``"test"`` is used for validation).
        model_path: Path to the file the model class is imported from.
        model_config_path: Path to the YAML model config file.
        optuna_results_dirpath: Directory for Optuna results (journal log and
            an ``artifacts`` sub-directory). Created if missing.
        best_model_path: Path to write the best model to.
        best_optimizer_path: Path to write the best optimizer to.
        best_config_path: Path to write the best trial's suggested config to.
        force_device: Optional device override, passed to ``resolve_device``
            together with the device from the config.
    """
    # Load train and validation datasets.
    dataset_dict = datasets.load_from_disk(data_path)
    dataset_dict.set_format("torch")
    train_dataset = dataset_dict["train"]
    validation_dataset = dataset_dict["test"]

    # Load the model class and its config.
    model_class = model_file_interface.import_class_from_file(model_path)
    with open(model_config_path, encoding="utf-8") as file:
        model_config_dict: dict[str, Any] = yaml.safe_load(file)
    model_config: model_schema.Model = model_schema.Model(**model_config_dict)

    pruner = model_config_parser.get_pruner(model_config.pruner)
    sampler = model_config_parser.get_sampler(model_config.sampler)

    # Storage: a journal log for the study plus a file-system artifact store.
    base_path = optuna_results_dirpath
    artifact_path = os.path.join(base_path, "artifacts")
    os.makedirs(artifact_path, exist_ok=True)  # also creates base_path
    artifact_store = optuna.artifacts.FileSystemArtifactStore(base_path=artifact_path)
    journal_path = os.path.join(base_path, "optuna_journal_storage.log")
    storage = optuna.storages.JournalStorage(
        optuna.storages.journal.JournalFileBackend(journal_path),
    )

    device = resolve_device(force_device=force_device, config_device=model_config.device)
    objective = optuna_tune.Objective(
        model_class=model_class,
        network_params=model_config.network_params,
        optimizer_params=model_config.optimizer_params,
        data_params=model_config.data_params,
        loss_params=model_config.loss_params,
        train_torch_dataset=train_dataset,
        val_torch_dataset=validation_dataset,
        artifact_store=artifact_store,
        max_samples=model_config.max_samples,
        compute_objective_every_n_samples=model_config.compute_objective_every_n_samples,
        target_metric=model_config.objective.metric,
        device=device,
    )
    study = optuna_tune.tune_loop(
        objective=objective,
        storage=storage,
        sampler=sampler,
        pruner=pruner,
        n_trials=model_config.n_trials,
        direction=model_config.objective.direction,
    )

    # (artifact id, local download path, final output path) for each file the
    # best trial recorded in its user_attrs. All lookups happen up front so a
    # missing key fails before anything is downloaded.
    best_attrs = study.best_trial.user_attrs
    exports = [
        (best_attrs["model_id"], best_attrs["model_path"], best_model_path),
        (best_attrs["optimizer_id"], best_attrs["optimizer_path"], best_optimizer_path),
        (
            best_attrs["model_suggestions_id"],
            best_attrs["model_suggestions_path"],
            best_config_path,
        ),
    ]
    for artifact_id, file_path, _ in exports:
        optuna.artifacts.download_artifact(
            artifact_store=artifact_store,
            file_path=file_path,
            artifact_id=artifact_id,
        )
    # Create each destination's parent directory before moving. The previous
    # "move, and on failure mkdir + move everything again" approach re-moved
    # already-moved files and called os.makedirs("") for bare filenames.
    for _, file_path, destination in exports:
        _move_creating_parent(file_path, destination)