protocols
Protocol definitions for stimulus models.
Classes:
- StimulusModel – Protocol for stimulus models with batch and inference methods.
StimulusModel
Bases: Protocol
Protocol for stimulus models with batch and inference methods.
Methods:
- batch – Process a batch and return loss and metrics.
- eval – Set model to evaluation mode.
- inference – Run inference on a batch and return loss and metrics.
- to – Move model to device.
- train_batch – Process a training batch and return loss, metrics, and optionally per-sample data.
batch
batch(
batch: dict[str, Tensor],
optimizer: Any | None = None,
**loss_dict: Any
) -> tuple[Tensor, dict[str, Any]]
Process a batch and return loss and metrics.
Parameters:
- batch (dict[str, Tensor]) – Dictionary of input tensors
- optimizer (Any | None, default: None) – Optional optimizer for training
- **loss_dict (Any, default: {}) – Additional loss function arguments
Returns:
- tuple[Tensor, dict[str, Any]] – Tuple of (loss tensor, metrics dictionary)
Source code in src/stimulus/typing/protocols.py (lines 12 to 28)
eval
eval() -> None
Set model to evaluation mode.
Source code in src/stimulus/typing/protocols.py (lines 65 to 67)
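For a model backed by PyTorch, eval() would typically forward to nn.Module.eval(), which switches layers such as Dropout and BatchNorm to inference behaviour. A minimal sketch, assuming a hypothetical wrapper class that holds its network in self.net; since the protocol declares eval() -> None, the wrapper discards the module that nn.Module.eval() returns:

```python
import torch
from torch import nn


class DropoutModel:
    """Hypothetical wrapper; eval() forwards to the wrapped nn.Module."""

    def __init__(self) -> None:
        self.net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

    def eval(self) -> None:
        # nn.Module.eval() disables dropout; its return value is ignored.
        self.net.eval()


torch.manual_seed(0)
model = DropoutModel()
x = torch.ones(1, 4)
model.eval()
with torch.no_grad():
    # In evaluation mode Dropout is the identity, so repeated calls agree.
    first, second = model.net(x), model.net(x)
```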
inference
Run inference on a batch and return loss and metrics.
Parameters:
- batch (dict[str, Tensor]) – Dictionary of input tensors
- **loss_dict (Any, default: {}) – Additional loss function arguments
Returns:
- Tuple of (loss tensor, metrics dictionary)
Source code in src/stimulus/typing/protocols.py (lines 49 to 63)
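Callers that only need inference can depend on a narrower structural type. The sketch below assumes a hypothetical SupportsInference protocol (not part of stimulus) and a stub model whose loss is the mean absolute value of batch["x"]; it averages the loss that inference returns over several batches:

```python
from collections.abc import Iterable
from typing import Any, Protocol

import torch


class SupportsInference(Protocol):
    """Hypothetical narrow protocol: just the inference method."""

    def inference(
        self, batch: dict[str, torch.Tensor], **loss_dict: Any
    ) -> tuple[torch.Tensor, dict[str, Any]]: ...


def mean_inference_loss(
    model: SupportsInference, batches: Iterable[dict[str, torch.Tensor]]
) -> float:
    """Average the per-batch loss returned by model.inference."""
    total, count = 0.0, 0
    for batch in batches:
        loss, _metrics = model.inference(batch)
        total += loss.item()
        count += 1
    return total / count if count else float("nan")


class MeanAbsModel:
    """Stub model: loss is the mean absolute value of batch["x"] (illustrative)."""

    def inference(self, batch, **loss_dict):
        loss = batch["x"].abs().mean()
        return loss, {"mae": loss.item()}


batches = [{"x": torch.tensor([1.0, -3.0])}, {"x": torch.tensor([4.0, 0.0])}]
avg = mean_inference_loss(MeanAbsModel(), batches)  # 2.0
```

This averages per-batch losses without weighting by batch size; weighting would be needed if batches differ in length.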
to
to(device: device) -> StimulusModel
Move model to device.
Parameters:
- device (device) – Target device
Returns:
- StimulusModel – Model on the target device
Source code in src/stimulus/typing/protocols.py (lines 69 to 78)
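Since to() returns a StimulusModel, an implementation can move its parameters and return itself so calls can be chained, mirroring nn.Module.to, which moves a module in place and returns it. A minimal sketch with a hypothetical wrapper class; the CUDA fallback to CPU is an assumption about the caller's environment:

```python
import torch
from torch import nn


class DeviceAwareModel:
    """Hypothetical wrapper showing the to(device) -> model contract."""

    def __init__(self) -> None:
        self.net = nn.Linear(3, 1)

    def to(self, device: torch.device) -> "DeviceAwareModel":
        # nn.Module.to moves parameters in place; returning self allows chaining.
        self.net.to(device)
        return self


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeviceAwareModel()
moved = model.to(device)
```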
train_batch
train_batch(
batch: dict[str, Tensor],
optimizer: Any,
**loss_dict: Any
) -> (
tuple[Tensor, dict[str, Any]]
| tuple[Tensor, dict[str, Any], dict[str, Any]]
)
Process a training batch and return loss, metrics, and optionally per-sample data.
Parameters:
- batch (dict[str, Tensor]) – Dictionary of input tensors
- optimizer (Any) – Optimizer for training
- **loss_dict (Any, default: {}) – Additional loss function arguments
Returns:
- tuple[Tensor, dict[str, Any]] | tuple[Tensor, dict[str, Any], dict[str, Any]] – Either a tuple of (loss tensor, metrics dictionary) or a tuple of (loss tensor, metrics dictionary, per-sample dictionary)
Source code in src/stimulus/typing/protocols.py (lines 30 to 47)