Skip to content

check_model

CLI module for checking model configuration and running initial tests.

Functions:

  • check_model

    Run the main model checking pipeline.

check_model

check_model(
    data_path: str,
    model_path: str,
    model_config_path: str,
    optuna_results_dirpath: str = "./optuna_results",
    force_device: Optional[str] = None,
) -> tuple[str, str]

Run the main model checking pipeline.

Parameters:

  • data_path (str) –

    Path to input data file.

  • model_path (str) –

    Path to model file.

  • model_config_path (str) –

    Path to model config file.

  • optuna_results_dirpath (str, default: './optuna_results' ) –

    Directory for optuna results.

  • force_device (Optional[str], default: None ) –

    Force the device to use.

Source code in src/stimulus/cli/check_model.py
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
def check_model(
    data_path: str,
    model_path: str,
    model_config_path: str,
    optuna_results_dirpath: str = "./optuna_results",
    force_device: Optional[str] = None,
) -> tuple[str, str]:
    """Run the main model checking pipeline.

    Loads the dataset, the model class and the model config, runs an Optuna
    tuning loop of ``N_TRIALS`` trials, logs a summary of the resulting study
    and downloads the model artifact recorded on the best trial.

    Args:
        data_path: Path to input data file.
        model_path: Path to model file.
        model_config_path: Path to model config file.
        optuna_results_dirpath: Directory for optuna results.
        force_device: Force the device to use.

    Returns:
        A tuple ``(base_path, file_path)``: ``base_path`` is
        ``optuna_results_dirpath`` and ``file_path`` is the ``model_path``
        user attribute of the best trial, the location the best model
        artifact is downloaded to.

    Raises:
        ValueError: If the tuning loop returns ``None`` instead of a study.
    """
    dataset_dict = datasets.load_from_disk(data_path)
    dataset_dict.set_format("torch")
    train_dataset = dataset_dict["train"]
    # The "test" split is what gets handed to the objective as validation data.
    validation_dataset = dataset_dict["test"]
    logger.info("Dataset loaded successfully.")

    model_class = model_file_interface.import_class_from_file(model_path)

    logger.info("Model class loaded successfully.")

    # Read the config with an explicit encoding so parsing does not depend on
    # the platform's locale default.
    with open(model_config_path, encoding="utf-8") as file:
        model_config_content = yaml.safe_load(file)
    # Validation only needs the parsed content, so the file is already closed.
    model_config = model_schema.Model(**model_config_content)

    logger.info("Model config loaded successfully.")

    # Results directory layout: the journal log sits directly in base_path,
    # model artifacts go into the "artifacts" sub-directory.
    base_path = optuna_results_dirpath
    artifact_path = f"{base_path}/artifacts"
    os.makedirs(base_path, exist_ok=True)
    os.makedirs(artifact_path, exist_ok=True)
    artifact_store = optuna.artifacts.FileSystemArtifactStore(base_path=artifact_path)
    storage = optuna.storages.JournalStorage(
        optuna.storages.journal.JournalFileBackend(f"{base_path}/optuna_journal_storage.log"),
    )

    device = resolve_device(force_device=force_device, config_device=model_config.device)
    objective = optuna_tune.Objective(
        model_class=model_class,
        network_params=model_config.network_params,
        optimizer_params=model_config.optimizer_params,
        data_params=model_config.data_params,
        loss_params=model_config.loss_params,
        train_torch_dataset=train_dataset,
        val_torch_dataset=validation_dataset,
        artifact_store=artifact_store,
        max_samples=model_config.max_samples,
        compute_objective_every_n_samples=model_config.compute_objective_every_n_samples,
        target_metric=model_config.objective.metric,
        device=device,
    )

    logger.info(f"Objective: {objective}")
    study = optuna_tune.tune_loop(
        objective=objective,
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=2, n_startup_trials=2),
        sampler=optuna.samplers.TPESampler(),
        n_trials=N_TRIALS,
        direction=model_config.objective.direction,
        storage=storage,
    )
    if study is None:
        raise ValueError("Study is None")
    logger.info(f"Study: {study}")
    logger.info(f"Study best trial: {study.best_trial}")
    logger.info(f"Study direction: {study.direction}")
    logger.info(f"Study best value: {study.best_value}")
    logger.info(f"Study best params: {study.best_params}")
    logger.info(f"Study trials count: {len(study.trials)}")

    for artifact_meta in optuna.artifacts.get_all_artifact_meta(study_or_trial=study):
        logger.info(artifact_meta)

    # Download the best model: the best trial carries both the artifact id and
    # the destination path in its user attributes.
    trial = study.best_trial
    best_artifact_id = trial.user_attrs["model_id"]
    file_path = trial.user_attrs["model_path"]
    optuna.artifacts.download_artifact(
        artifact_store=artifact_store,
        file_path=file_path,
        artifact_id=best_artifact_id,
    )

    logger.info(f"Best model downloaded successfully to {file_path}")

    return base_path, file_path