Skip to content

model_schema

Module for validating YAML configuration files.

Classes:

Loss

Bases: BaseModel

Loss parameters.

Model

Bases: BaseModel

Model configuration.

Methods:

validate_data_params

validate_data_params() -> Model

Validate that data_params contains batch_size.

Source code in src/stimulus/learner/interface/model_schema.py
180
181
182
183
184
185
@pydantic.model_validator(mode="after")
def validate_data_params(self) -> "Model":
    """Ensure the mandatory ``batch_size`` key is present in ``data_params``."""
    if "batch_size" in self.data_params:
        return self
    raise ValueError("data_params must contain batch_size")

validate_input classmethod

validate_input(data: dict[str, Any]) -> dict[str, Any]

Log input data for debugging.

Source code in src/stimulus/learner/interface/model_schema.py
173
174
175
176
177
178
@pydantic.model_validator(mode="before")
@classmethod
def validate_input(cls, data: dict[str, Any]) -> dict[str, Any]:
    """Log the raw input data for debugging and pass it through unchanged.

    Args:
        data: Raw mapping handed to the model before field validation.

    Returns:
        The same mapping, untouched.
    """
    # Lazy %-formatting defers string construction until the record is
    # actually emitted (the f-string formatted eagerly on every call).
    logger.info("Input data for Model: %s", data)
    return data

Objective

Bases: BaseModel

Objective parameters.

Methods:

validate_direction

validate_direction() -> Objective

Validate that direction is supported by Optuna.

Source code in src/stimulus/learner/interface/model_schema.py
113
114
115
116
117
118
119
120
@pydantic.model_validator(mode="after")
def validate_direction(self) -> "Objective":
    """Validate that direction is supported by Optuna."""
    # Optuna studies only understand these two optimization directions.
    supported = ("minimize", "maximize")
    if self.direction in supported:
        return self
    raise NotImplementedError(
        f"Direction {self.direction} not available for Optuna, please use one of the following: minimize, maximize",
    )

Pruner

Bases: BaseModel

Pruner parameters.

Methods:

validate_pruner

validate_pruner() -> Pruner

Validate that pruner is supported by Optuna.

Source code in src/stimulus/learner/interface/model_schema.py
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
@pydantic.model_validator(mode="after")
def validate_pruner(self) -> "Pruner":
    """Validate that pruner is supported by Optuna."""
    # Public attributes of optuna.pruners are the candidate pruner classes.
    available_pruners = [
        attr_name
        for attr_name in dir(optuna.pruners)
        if not attr_name.startswith("_") and attr_name != "TYPE_CHECKING"
    ]
    logger.info(f"Available pruners in Optuna: {available_pruners}")

    # Exact-case match: nothing to fix up.
    if hasattr(optuna.pruners, self.name):
        return self

    # Fall back to a case-insensitive lookup before giving up.
    wanted = self.name.lower()
    for candidate in available_pruners:
        if candidate.lower() == wanted:
            logger.info(f"Found matching pruner with different case: {candidate}")
            self.name = candidate  # Use the correct case
            return self

    raise ValueError(
        f"Pruner '{self.name}' not available in Optuna. Available pruners: {available_pruners}",
    )

Sampler

Bases: BaseModel

Sampler parameters.

Methods:

validate_sampler

validate_sampler() -> Sampler

Validate that sampler is supported by Optuna.

Source code in src/stimulus/learner/interface/model_schema.py
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
@pydantic.model_validator(mode="after")
def validate_sampler(self) -> "Sampler":
    """Validate that sampler is supported by Optuna."""
    # Public attributes of optuna.samplers are the candidate sampler classes.
    available_samplers = [
        attr_name
        for attr_name in dir(optuna.samplers)
        if not attr_name.startswith("_") and attr_name != "TYPE_CHECKING"
    ]
    logger.info(f"Available samplers in Optuna: {available_samplers}")

    # Exact-case match: nothing to fix up.
    if hasattr(optuna.samplers, self.name):
        return self

    # Fall back to a case-insensitive lookup before giving up.
    wanted = self.name.lower()
    for candidate in available_samplers:
        if candidate.lower() == wanted:
            logger.info(f"Found matching sampler with different case: {candidate}")
            self.name = candidate  # Use the correct case
            return self

    raise ValueError(
        f"Sampler '{self.name}' not available in Optuna. Available samplers: {available_samplers}",
    )

TunableParameter

Bases: BaseModel

Tunable parameter.

Methods:

  • validate_mode

    Validate that mode is supported by Optuna or custom methods.

  • validate_params

    Validate that the params are supported by Optuna.

validate_mode

validate_mode() -> TunableParameter

Validate that mode is supported by Optuna or custom methods.

Source code in src/stimulus/learner/interface/model_schema.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
@pydantic.model_validator(mode="after")
def validate_mode(self) -> "TunableParameter":
    """Validate that mode is supported by Optuna or custom methods.

    Raises:
        NotImplementedError: If ``mode`` is not one of the supported
            suggestion modes.
    """
    supported_modes = (
        "categorical",
        "discrete_uniform",
        "float",
        "int",
        "loguniform",
        "uniform",
    )
    if self.mode not in supported_modes:
        # Build the message from the tuple so it can never drift out of sync
        # with the actual check (the old text advertised a "variable_list"
        # mode that this validator did not accept).
        raise NotImplementedError(
            f"Mode {self.mode} not available for Optuna, please use one of the following: {', '.join(supported_modes)}",
        )

    return self

validate_params

validate_params() -> TunableParameter

Validate that the params are supported by Optuna.

Source code in src/stimulus/learner/interface/model_schema.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
@pydantic.model_validator(mode="after")
def validate_params(self) -> "TunableParameter":
    """Validate that the params are supported by Optuna."""
    trial_methods: dict[str, Callable] = {
        "categorical": optuna.trial.Trial.suggest_categorical,
        "discrete_uniform": optuna.trial.Trial.suggest_discrete_uniform,
        "float": optuna.trial.Trial.suggest_float,
        "int": optuna.trial.Trial.suggest_int,
        "loguniform": optuna.trial.Trial.suggest_loguniform,
        "uniform": optuna.trial.Trial.suggest_uniform,
    }
    suggest = trial_methods.get(self.mode)
    if suggest is None:
        # Non-Optuna (custom) modes have no signature to check here.
        return self

    # A parameter is required when it has no default and is not *args/**kwargs;
    # self/trial/name are supplied by the tuning loop, not by the config.
    skipped_names = {"self", "trial", "name"}
    variadic_kinds = (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL)
    required_params: set[str] = set()
    for param in inspect.signature(suggest).parameters.values():
        if (
            param.default is inspect.Parameter.empty
            and param.kind not in variadic_kinds
            and param.name not in skipped_names
        ):
            required_params.add(param.name)

    missing_params = required_params - set(self.params.keys())
    if missing_params:
        raise ValueError(f"Missing required params for mode '{self.mode}': {missing_params}")
    return self

VariableList

Bases: BaseModel

Variable list.

Methods:

validate_length

validate_length() -> VariableList

Validate that the length parameter is tuned with integer mode.

Source code in src/stimulus/learner/interface/model_schema.py
69
70
71
72
73
74
75
def validate_length(self) -> "VariableList":
    """Validate that the ``length`` parameter is tuned with integer mode.

    The previous docstring claimed this checked Optuna support; it actually
    enforces that ``length.mode`` is exactly ``"int"``, since a list length
    must be an integer.

    Raises:
        ValueError: If ``length.mode`` is anything other than ``"int"``.
    """
    # `!=` replaces the unidiomatic single-element `not in ["int"]` test.
    if self.length.mode != "int":
        raise ValueError(
            f"length mode has to be set to int, got {self.length.mode}",
        )
    return self