api

Stimulus Python API module.

This module provides Python functions that wrap CLI functionality for direct use in Python scripts. All functions work with in-memory objects (HuggingFace datasets, PyTorch models, configuration dictionaries) instead of requiring file I/O operations.

Usage Examples:

Basic Data Processing

import datasets
import stimulus
from stimulus.api import (
    create_encoders_from_config,
    create_splitter_from_config,
)

# Load your data
dataset = datasets.load_dataset("csv", data_files="data.csv")

# Create encoders from config dict
encoder_config = {
    "columns": [
        {
            "column_name": "category",
            "column_type": "input",
            "encoder": [{"name": "LabelEncoder", "params": {"dtype": "int64"}}],
        }
    ]
}
encoders = create_encoders_from_config(encoder_config)

# Encode the dataset
encoded_dataset = stimulus.encode(dataset, encoders)

# Split the dataset
splitter_config = {
    "split": {
        "split_method": "RandomSplitter",
        "params": {"test_ratio": 0.2, "random_state": 42},
        "split_input_columns": ["category"],
    }
}
splitter, split_columns = create_splitter_from_config(splitter_config)
split_dataset = stimulus.split(encoded_dataset, splitter, split_columns)
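
The split dataset can optionally be augmented before training. A minimal sketch, reusing the NoiseTransform shown in the transform() entry below and assuming transform is exposed at the package top level like encode and split; the "feature" column name is a placeholder:

# Optional: apply transforms to the split dataset ("feature" is an illustrative column).
from stimulus.data.transforming.transforms import NoiseTransform

transforms_config = {"feature": [NoiseTransform(noise_level=0.1)]}
transformed_dataset = stimulus.transform(split_dataset, transforms_config)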

Model Training and Prediction

import torch

# Define your model class
class MyModel(torch.nn.Module):
    def __init__(self, hidden_size=64):
        super().__init__()
        self.layer = torch.nn.Linear(10, hidden_size)
        # ... rest of model definition

    def batch(self, batch, optimizer=None, **loss_dict):
        # ... implement forward pass and training logic
        return loss, metrics


# Create model config
model_config = model_schema.Model(
    model_params={
        "hidden_size": model_schema.TunableParameter(
            mode="int", params={"low": 32, "high": 128}
        )
    },
    optimizer={
        "method": model_schema.TunableParameter(
            mode="categorical", params={"choices": ["Adam", "SGD"]}
        )
    },
    # ... other config
)

# Tune hyperparameters
best_config, best_model, metrics = stimulus.tune(
    dataset=split_dataset,
    model_class=MyModel,
    model_config=model_config,
    n_trials=20,
)

# Make predictions
predictions = stimulus.predict(split_dataset, best_model)
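
Before committing to a long tuning run, the same objects can be passed through check_model for a quick validation pass. A minimal sketch, assuming check_model is exposed at the package top level like tune and predict:

# Quick sanity check: a few short trials on a capped number of samples.
config, model = stimulus.check_model(
    split_dataset,
    MyModel,
    model_config,
    n_trials=3,
    max_samples=100,
)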

Modules:

  • api

    Python API for Stimulus CLI functions.

Functions:

check_model

check_model(
    dataset: DatasetDict,
    model_class: type[Module],
    model_config: Model,
    n_trials: int = 3,
    max_samples: int = 100,
    force_device: Optional[str] = None,
) -> tuple[dict[str, Any], Module]

Check model configuration and run initial tests.

Validates that a model can be loaded and trained with the given configuration. Performs a small-scale hyperparameter tuning run to verify everything works.

Parameters:

  • dataset (DatasetDict) –

    HuggingFace dataset containing train/test splits.

  • model_class (type[Module]) –

    PyTorch model class to check.

  • model_config (Model) –

    Model configuration with tunable parameters.

  • n_trials (int, default: 3 ) –

    Number of trials for validation (default: 3).

  • max_samples (int, default: 100 ) –

    Maximum samples per trial (default: 100).

  • force_device (Optional[str], default: None ) –

    Force specific device ("cpu", "cuda", "mps") (default: None for auto).

Returns:

  • tuple[dict[str, Any], Module]

    Tuple of (best_config, best_model).

Example

config, model = check_model(dataset, MyModel, model_config)
print("Model validation successful!")

Source code in src/stimulus/api/api.py
def check_model(
    dataset: datasets.DatasetDict,
    model_class: type[torch.nn.Module],
    model_config: model_schema.Model,
    n_trials: int = 3,
    max_samples: int = 100,
    force_device: Optional[str] = None,
) -> tuple[dict[str, Any], torch.nn.Module]:
    """Check model configuration and run initial tests.

    Validates that a model can be loaded and trained with the given configuration.
    Performs a small-scale hyperparameter tuning run to verify everything works.

    Args:
        dataset: HuggingFace dataset containing train/test splits.
        model_class: PyTorch model class to check.
        model_config: Model configuration with tunable parameters.
        n_trials: Number of trials for validation (default: 3).
        max_samples: Maximum samples per trial (default: 100).
        force_device: Force specific device ("cpu", "cuda", "mps") (default: None for auto).

    Returns:
        Tuple of (best_config, best_model).

    Example:
        >>> config, model = check_model(dataset, MyModel, model_config)
        >>> print("Model validation successful!")
    """
    best_config, best_model, _metrics = tune(
        dataset=dataset,
        model_class=model_class,
        model_config=model_config,
        n_trials=n_trials,
        max_samples=max_samples,
        target_metric="val_loss",
        direction="minimize",
        force_device=force_device,
    )

    logger.info("Model check completed successfully!")
    return best_config, best_model

compare_tensors

compare_tensors(
    tensor_dicts: list[dict[str, Tensor]],
    mode: str = "cosine_similarity",
) -> dict[str, list[float]]

Compare prediction tensors using various similarity metrics.

Parameters:

  • tensor_dicts (list[dict[str, Tensor]]) –

    List of tensor dictionaries to compare.

  • mode (str, default: 'cosine_similarity' ) –

    Comparison mode ("cosine_similarity" or "discrete_comparison", default: "cosine_similarity").

Returns:

  • dict[str, list[float]]

    Dictionary containing comparison results.

Example

results = compare_tensors([pred1, pred2], mode="cosine_similarity")
print(results["cosine_similarity"])
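
The tensor dictionaries are typically the outputs of predict. A short sketch comparing the predictions of two already-loaded models with the discrete comparison mode (model_a and model_b are placeholders):

pred_a = predict(test_dataset, model_a)
pred_b = predict(test_dataset, model_b)
results = compare_tensors([pred_a, pred_b], mode="discrete_comparison")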

Source code in src/stimulus/api/api.py
def compare_tensors(
    tensor_dicts: list[dict[str, torch.Tensor]],
    mode: str = "cosine_similarity",
) -> dict[str, list[float]]:
    """Compare prediction tensors using various similarity metrics.

    Args:
        tensor_dicts: List of tensor dictionaries to compare.
        mode: Comparison mode ("cosine_similarity" or "discrete_comparison", default: "cosine_similarity").

    Returns:
        Dictionary containing comparison results.

    Example:
        >>> results = compare_tensors([pred1, pred2], mode="cosine_similarity")
        >>> print(results["cosine_similarity"])
    """
    results: dict[str, list[float]] = defaultdict(list)

    for i in range(len(tensor_dicts)):
        for j in range(i + 1, len(tensor_dicts)):
            tensor1 = tensor_dicts[i]
            tensor2 = tensor_dicts[j]
            tensor_comparison = compare_tensors_impl(tensor1, tensor2, mode)

            for key, tensor in tensor_comparison.items():
                if tensor.ndim == 0:
                    results[key].append(tensor.item())
                else:
                    results[key].append(tensor.mean().item())

    return dict(results)

create_encoders_from_config

create_encoders_from_config(
    config_dict: dict,
) -> dict[str, Any]

Create encoders from a configuration dictionary.

Parameters:

  • config_dict (dict) –

    Configuration dictionary matching SplitTransformDict schema.

Returns:

  • dict[str, Any]

    Dictionary mapping column names to encoder instances.
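
Example

A minimal sketch, reusing the column configuration from the usage example at the top of this page:

config = {
    "columns": [
        {
            "column_name": "category",
            "column_type": "input",
            "encoder": [{"name": "LabelEncoder", "params": {"dtype": "int64"}}],
        }
    ]
}
encoders = create_encoders_from_config(config)
encoded_dataset = encode(dataset, encoders)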

Source code in src/stimulus/api/api.py
def create_encoders_from_config(config_dict: dict) -> dict[str, Any]:
    """Create encoders from a configuration dictionary.

    Args:
        config_dict: Configuration dictionary matching SplitTransformDict schema.

    Returns:
        Dictionary mapping column names to encoder instances.
    """
    data_config_obj = data_config_parser.SplitTransformDict(**config_dict)
    encoders, _input_columns, _label_columns, _meta_columns = data_config_parser.parse_split_transform_config(
        data_config_obj,
    )
    return encoders

create_splitter_from_config

create_splitter_from_config(
    config_dict: dict,
) -> tuple[AbstractSplitter, list[str]]

Create a splitter from a configuration dictionary.

Parameters:

  • config_dict (dict) –

    Configuration dictionary matching SplitConfigDict schema.

Returns:

  • tuple[AbstractSplitter, list[str]]

    Tuple of (splitter_instance, split_columns).
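
Example

A minimal sketch, reusing the split configuration from the usage example at the top of this page:

config = {
    "split": {
        "split_method": "RandomSplitter",
        "params": {"test_ratio": 0.2, "random_state": 42},
        "split_input_columns": ["category"],
    }
}
splitter, split_columns = create_splitter_from_config(config)
split_dataset = split(encoded_dataset, splitter, split_columns)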

Source code in src/stimulus/api/api.py
def create_splitter_from_config(config_dict: dict) -> tuple[splitters.AbstractSplitter, list[str]]:
    """Create a splitter from a configuration dictionary.

    Args:
        config_dict: Configuration dictionary matching SplitConfigDict schema.

    Returns:
        Tuple of (splitter_instance, split_columns).
    """
    data_config_obj = data_config_parser.SplitConfigDict(**config_dict)
    splitter = data_config_parser.create_splitter(data_config_obj.split)
    return splitter, data_config_obj.split.split_input_columns

create_transforms_from_config

create_transforms_from_config(
    config_dict: dict,
) -> dict[str, list[Any]]

Create transforms from a configuration dictionary.

Parameters:

  • config_dict (dict) –

    Configuration dictionary matching SplitTransformDict schema.

Returns:

  • dict[str, list[Any]]

    Dictionary mapping column names to lists of transform instances.
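
Example

A hedged sketch: the transform fields of the SplitTransformDict schema are not shown on this page, so config_dict is left abstract here; the returned mapping feeds directly into transform():

# config_dict must follow the SplitTransformDict schema, including its transform section.
transforms_config = create_transforms_from_config(config_dict)
transformed_dataset = transform(split_dataset, transforms_config)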

Source code in src/stimulus/api/api.py
def create_transforms_from_config(config_dict: dict) -> dict[str, list[Any]]:
    """Create transforms from a configuration dictionary.

    Args:
        config_dict: Configuration dictionary matching SplitTransformDict schema.

    Returns:
        Dictionary mapping column names to lists of transform instances.
    """
    data_config_obj = data_config_parser.SplitTransformDict(**config_dict)
    return data_config_parser.create_transforms([data_config_obj.transforms])

encode

encode(
    dataset: DatasetDict,
    encoders: dict[str, Any],
    num_proc: Optional[int] = None,
    *,
    remove_unencoded_columns: bool = True
) -> DatasetDict

Encode a dataset using the provided encoders.

Parameters:

  • dataset (DatasetDict) –

    HuggingFace dataset to encode.

  • encoders (dict[str, Any]) –

    Dictionary mapping column names to encoder instances.

  • num_proc (Optional[int], default: None ) –

    Number of processes to use for encoding (default: None for single process).

  • remove_unencoded_columns (bool, default: True ) –

    Whether to remove columns not in encoders config (default: True).

Returns:

  • DatasetDict

    The encoded HuggingFace dataset.

Example

from stimulus.data.encoding.encoders import LabelEncoder
encoders = {"category": LabelEncoder(dtype="int64")}
encoded_dataset = encode(dataset, encoders)
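
A short variant illustrating the optional keyword arguments (the values are examples only):

# Encode in four worker processes and keep columns that have no encoder.
encoded_dataset = encode(
    dataset,
    encoders,
    num_proc=4,
    remove_unencoded_columns=False,
)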

Source code in src/stimulus/api/api.py
def encode(
    dataset: datasets.DatasetDict,
    encoders: dict[str, Any],
    num_proc: Optional[int] = None,
    *,
    remove_unencoded_columns: bool = True,
) -> datasets.DatasetDict:
    """Encode a dataset using the provided encoders.

    Args:
        dataset: HuggingFace dataset to encode.
        encoders: Dictionary mapping column names to encoder instances.
        num_proc: Number of processes to use for encoding (default: None for single process).
        remove_unencoded_columns: Whether to remove columns not in encoders config (default: True).

    Returns:
        The encoded HuggingFace dataset.

    Example:
        >>> from stimulus.data.encoding.encoders import LabelEncoder
        >>> encoders = {"category": LabelEncoder(dtype="int64")}
        >>> encoded_dataset = encode(dataset, encoders)
    """
    # Set format to numpy for processing
    dataset.set_format(type="numpy")

    logger.info(f"Loaded encoders for columns: {list(encoders.keys())}")

    # Identify columns that aren't in the encoder configuration
    columns_to_remove = set()
    if remove_unencoded_columns:
        for split_name, split_dataset in dataset.items():
            dataset_columns = set(split_dataset.column_names)
            encoder_columns = set(encoders.keys())
            columns_to_remove.update(dataset_columns - encoder_columns)
            logger.info(
                f"Removing columns not in encoder configuration from {split_name} split: {list(columns_to_remove)}",
            )

    # Apply the encoders to the data
    dataset = dataset.map(
        encode_csv_cli.encode_batch,
        batched=True,
        fn_kwargs={"encoders_config": encoders},
        remove_columns=list(columns_to_remove),
        num_proc=num_proc,
    )

    logger.info("Dataset encoded successfully.")
    return dataset

load_model_from_files

load_model_from_files(
    model_path: str, config_path: str, weights_path: str
) -> Module

Load a model from files (convenience function for predict API).

Parameters:

  • model_path (str) –

    Path to the model Python file.

  • config_path (str) –

    Path to the model configuration JSON file.

  • weights_path (str) –

    Path to the model weights file (.safetensors).

Returns:

  • Module

    Loaded PyTorch model instance.
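
Example

A short sketch with placeholder file paths, chaining the loaded model into predict:

model = load_model_from_files(
    model_path="model.py",
    config_path="best_config.json",
    weights_path="best_model.safetensors",
)
predictions = predict(test_dataset, model)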

Source code in src/stimulus/api/api.py
def load_model_from_files(model_path: str, config_path: str, weights_path: str) -> torch.nn.Module:
    """Load a model from files (convenience function for predict API).

    Args:
        model_path: Path to the model Python file.
        config_path: Path to the model configuration JSON file.
        weights_path: Path to the model weights file (.safetensors).

    Returns:
        Loaded PyTorch model instance.
    """
    with open(config_path) as f:
        best_config = json.load(f)

    model_class = import_class_from_file(model_path)
    model_instance = model_class(**best_config)

    weights = load_file(weights_path)
    model_instance.load_state_dict(weights)
    return model_instance

predict

predict(
    dataset: DatasetDict,
    model: StimulusModel,
    batch_size: int = 256,
) -> dict[str, Tensor]

Run model prediction on the dataset.

Parameters:

  • dataset (DatasetDict) –

    HuggingFace dataset to predict on.

  • model (StimulusModel) –

    PyTorch model instance (already loaded with weights).

  • batch_size (int, default: 256 ) –

    Batch size for prediction (default: 256).

Returns:

  • dict[str, Tensor]

    Dictionary containing prediction tensors and statistics.

Example

predictions = predict(test_dataset, trained_model)
print(predictions["predictions"])

Source code in src/stimulus/api/api.py
def predict(
    dataset: datasets.DatasetDict,
    model: StimulusModel,
    batch_size: int = 256,
) -> dict[str, torch.Tensor]:
    """Run model prediction on the dataset.

    Args:
        dataset: HuggingFace dataset to predict on.
        model: PyTorch model instance (already loaded with weights).
        batch_size: Batch size for prediction (default: 256).

    Returns:
        Dictionary containing prediction tensors and statistics.

    Example:
        >>> predictions = predict(test_dataset, trained_model)
        >>> print(predictions["predictions"])
    """
    dataset.set_format(type="torch")
    splits = [dataset[split_name] for split_name in dataset]
    all_splits = datasets.concatenate_datasets(splits)
    loader = torch.utils.data.DataLoader(all_splits, batch_size=batch_size, shuffle=False)

    # create empty tensor for predictions
    is_first_batch = True
    model.eval()

    with torch.no_grad():
        for batch in loader:
            if is_first_batch:
                _loss, statistics = model.batch(batch)
                is_first_batch = False
            else:
                _loss, temp_statistics = model.batch(batch)
                statistics = _update_statistics(statistics, temp_statistics)

    return _convert_dict_to_tensor(statistics)

split

split(
    dataset: DatasetDict,
    splitter: AbstractSplitter,
    split_columns: list[str],
    *,
    force: bool = False
) -> DatasetDict

Split a dataset using the provided splitter.

Parameters:

  • dataset (DatasetDict) –

    HuggingFace dataset to split.

  • splitter (AbstractSplitter) –

    Splitter instance (e.g., RandomSplitter, StratifiedSplitter).

  • split_columns (list[str]) –

    List of column names to use for splitting logic.

  • force (bool, default: False ) –

    Overwrite existing test split if it exists (default: False).

Returns:

  • DatasetDict

    Dataset with train/test splits.

Example

from stimulus.data.splitting.splitters import RandomSplitter
splitter = RandomSplitter(test_ratio=0.2, random_state=42)
split_dataset = split(dataset, splitter, ["target_column"])
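
If the dataset already has a test split, it is returned unchanged unless force=True, in which case the existing test split is merged back into train and the split is recomputed. A short sketch:

# Re-split a dataset that already contains a test split.
resplit_dataset = split(split_dataset, splitter, ["target_column"], force=True)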

Source code in src/stimulus/api/api.py
def split(
    dataset: datasets.DatasetDict,
    splitter: splitters.AbstractSplitter,
    split_columns: list[str],
    *,
    force: bool = False,
) -> datasets.DatasetDict:
    """Split a dataset using the provided splitter.

    Args:
        dataset: HuggingFace dataset to split.
        splitter: Splitter instance (e.g., RandomSplitter, StratifiedSplitter).
        split_columns: List of column names to use for splitting logic.
        force: Overwrite existing test split if it exists (default: False).

    Returns:
        Dataset with train/test splits.

    Example:
        >>> from stimulus.data.splitting.splitters import RandomSplitter
        >>> splitter = RandomSplitter(test_ratio=0.2, random_state=42)
        >>> split_dataset = split(dataset, splitter, ["target_column"])
    """
    if "test" in dataset and not force:
        logger.info("Test split already exists and force was set to False. Returning existing split.")
        return dataset

    if "test" in dataset and force:
        logger.info(
            "Test split already exists and force was set to True. Merging current test split into train and recalculating splits.",
        )
        dataset["train"] = datasets.concatenate_datasets([dataset["train"], dataset["test"]])
        del dataset["test"]

    dataset_with_numpy_format = dataset.with_format("numpy")
    column_data_dict = {}
    for col_name in split_columns:
        try:
            column_data_dict[col_name] = dataset_with_numpy_format["train"][col_name]
        except KeyError as err:
            raise ValueError(
                f"Column '{col_name}' not found in dataset with columns {dataset_with_numpy_format['train'].column_names}",
            ) from err

    if not column_data_dict:
        raise ValueError(
            f"No data columns were extracted for splitting. Input specified columns are {split_columns}, "
            f"dataset has columns {dataset_with_numpy_format['train'].column_names}",
        )

    train_indices, test_indices = splitter.get_split_indexes(column_data_dict)

    train_dataset = dataset_with_numpy_format["train"].select(train_indices)
    test_dataset = dataset_with_numpy_format["train"].select(test_indices)

    return datasets.DatasetDict({"train": train_dataset, "test": test_dataset})

transform

transform(
    dataset: DatasetDict,
    transforms_config: dict[str, list[AbstractTransform]],
) -> DatasetDict

Transform a dataset using the provided transformations.

Parameters:

  • dataset (DatasetDict) –

    HuggingFace dataset to transform.

  • transforms_config (dict[str, list[AbstractTransform]]) –

    Dictionary mapping column names to lists of transform instances.

Returns:

  • DatasetDict

    Transformed HuggingFace dataset.

Example

from stimulus.data.transforming.transforms import NoiseTransform
transforms_config = {"feature": [NoiseTransform(noise_level=0.1)]}
transformed_dataset = transform(dataset, transforms_config)

Source code in src/stimulus/api/api.py
def transform(
    dataset: datasets.DatasetDict,
    transforms_config: dict[str, list[transforms.AbstractTransform]],
) -> datasets.DatasetDict:
    """Transform a dataset using the provided transformations.

    Args:
        dataset: HuggingFace dataset to transform.
        transforms_config: Dictionary mapping column names to lists of transform instances.

    Returns:
        Transformed HuggingFace dataset.

    Example:
        >>> from stimulus.data.transforming.transforms import NoiseTransform
        >>> transforms_config = {"feature": [NoiseTransform(noise_level=0.1)]}
        >>> transformed_dataset = transform(dataset, transforms_config)
    """
    dataset.set_format(type="numpy")
    logger.info("Transforms initialized successfully.")

    # Apply the transformations to the data
    dataset = dataset.map(
        transform_batch,
        batched=True,
        fn_kwargs={"transforms_config": transforms_config},
    )

    # Filter out NaN values
    logger.debug(f"Dataset type: {type(dataset)}")
    dataset["train"] = dataset["train"].filter(lambda example: not any(pd.isna(value) for value in example.values()))
    dataset["test"] = dataset["test"].filter(lambda example: not any(pd.isna(value) for value in example.values()))

    return dataset

tune

tune(
    dataset: DatasetDict,
    model_class: type[Module],
    model_config: Model,
    n_trials: int = 100,
    max_samples: int = 1000,
    compute_objective_every_n_samples: int = 50,
    target_metric: str = "val_loss",
    direction: str = "minimize",
    storage: Optional[BaseStorage] = None,
    force_device: Optional[str] = None,
) -> tuple[dict[str, Any], Module, dict[str, Tensor]]

Run hyperparameter tuning using Optuna.

Parameters:

  • dataset (DatasetDict) –

    HuggingFace dataset containing train/test splits.

  • model_class (type[Module]) –

    PyTorch model class to tune.

  • model_config (Model) –

    Model configuration with tunable parameters.

  • n_trials (int, default: 100 ) –

    Number of trials to run (default: 100).

  • max_samples (int, default: 1000 ) –

    Maximum samples per trial (default: 1000).

  • compute_objective_every_n_samples (int, default: 50 ) –

    Frequency to compute objective (default: 50).

  • target_metric (str, default: 'val_loss' ) –

    Metric to optimize (default: "val_loss").

  • direction (str, default: 'minimize' ) –

    Optimization direction ("minimize" or "maximize", default: "minimize").

  • storage (Optional[BaseStorage], default: None ) –

    Optuna storage backend (default: None for in-memory).

  • force_device (Optional[str], default: None ) –

    Force specific device ("cpu", "cuda", "mps") (default: None for auto).

Returns:

  • tuple[dict[str, Any], Module, dict[str, Tensor]]

    Tuple of (best_config, best_model, best_metrics).

Example

config, model, metrics = tune(
    dataset=train_dataset,
    model_class=MyModel,
    model_config=model_config,
    n_trials=50,
)
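
A further sketch showing the optional storage and device arguments; the SQLite URL is a placeholder and assumes Optuna's standard RDBStorage backend:

import optuna

storage = optuna.storages.RDBStorage(url="sqlite:///stimulus_study.db")
config, model, metrics = tune(
    dataset=train_dataset,
    model_class=MyModel,
    model_config=model_config,
    n_trials=50,
    storage=storage,
    force_device="cpu",
)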

Source code in src/stimulus/api/api.py
def tune(
    dataset: datasets.DatasetDict,
    model_class: type[torch.nn.Module],
    model_config: model_schema.Model,
    n_trials: int = 100,
    max_samples: int = 1000,
    compute_objective_every_n_samples: int = 50,
    target_metric: str = "val_loss",
    direction: str = "minimize",
    storage: Optional[optuna.storages.BaseStorage] = None,
    force_device: Optional[str] = None,
) -> tuple[dict[str, Any], torch.nn.Module, dict[str, torch.Tensor]]:
    """Run hyperparameter tuning using Optuna.

    Args:
        dataset: HuggingFace dataset containing train/test splits.
        model_class: PyTorch model class to tune.
        model_config: Model configuration with tunable parameters.
        n_trials: Number of trials to run (default: 100).
        max_samples: Maximum samples per trial (default: 1000).
        compute_objective_every_n_samples: Frequency to compute objective (default: 50).
        target_metric: Metric to optimize (default: "val_loss").
        direction: Optimization direction ("minimize" or "maximize", default: "minimize").
        storage: Optuna storage backend (default: None for in-memory).
        force_device: Force specific device ("cpu", "cuda", "mps") (default: None for auto).

    Returns:
        Tuple of (best_config, best_model, best_metrics).

    Example:
        >>> config, model, metrics = tune(
        ...     dataset=train_dataset,
        ...     model_class=MyModel,
        ...     model_config=model_config,
        ...     n_trials=50,
        ... )
    """
    device = resolve_device(force_device=force_device, config_device=model_config.device)

    # Convert HuggingFace dataset to torch datasets
    dataset.set_format(type="torch")
    train_torch_dataset = dataset["train"]
    val_torch_dataset = dataset["test"]  # Using test as validation

    # Create temporary artifact store
    with tempfile.TemporaryDirectory() as temp_dir:
        artifact_store = optuna.artifacts.FileSystemArtifactStore(base_path=temp_dir)

        # Create objective function
        objective = optuna_tune.Objective(
            model_class=model_class,
            network_params=model_config.network_params,
            optimizer_params=model_config.optimizer_params,
            data_params=model_config.data_params,
            loss_params=model_config.loss_params,
            train_torch_dataset=train_torch_dataset,
            val_torch_dataset=val_torch_dataset,
            artifact_store=artifact_store,
            max_samples=max_samples,
            compute_objective_every_n_samples=compute_objective_every_n_samples,
            target_metric=target_metric,
            device=device,
        )

        # Get pruner and sampler
        pruner = model_config_parser.get_pruner(model_config.pruner)
        sampler = model_config_parser.get_sampler(model_config.sampler)

        # Run tuning
        study = optuna_tune.tune_loop(
            objective=objective,
            pruner=pruner,
            sampler=sampler,
            n_trials=n_trials,
            direction=direction,
            storage=storage,
        )

        # Get best trial and create best model
        best_trial = study.best_trial
        best_config = best_trial.params

        # Recreate best model
        model_suggestions = model_config_parser.suggest_parameters(best_trial, model_config.network_params)
        best_model = model_class(**model_suggestions)

        # Load best weights if available
        if "model_id" in best_trial.user_attrs:
            model_path = artifact_store.download_artifact(
                artifact_id=best_trial.user_attrs["model_id"],
                dst_path=os.path.join(temp_dir, "best_model.safetensors"),
            )
            weights = load_file(model_path)
            best_model.load_state_dict(weights)

        # Get best metrics
        best_metrics = {k: v for k, v in best_trial.user_attrs.items() if k.startswith(("train_", "val_"))}

        return best_config, best_model, best_metrics