protocols ¶
Protocol definitions for stimulus models.
Classes:
- StimulusModel – Protocol for stimulus models with batch and inference methods.
StimulusModel ¶
Bases: Protocol
Protocol for stimulus models with batch and inference methods.
Methods:
- batch – Process a batch and return loss and metrics.
- eval – Set model to evaluation mode.
- inference – Run inference on a batch and return loss and metrics.
- to – Move model to device.
- train_batch – Process a training batch and return loss, metrics, and optionally per-sample data.
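Because StimulusModel is a Protocol, a class conforms by providing these methods; it does not need to inherit from anything. The sketch below is a minimal illustration under stated assumptions: `StimulusModelLike` is a hypothetical local mirror of the documented method names (not the library's own class), `Tensor` is aliased to `Any` so the example runs without PyTorch, and `ToyModel` with its dispatch inside `batch` is invented for illustration only.

```python
from __future__ import annotations

from typing import Any, Protocol, runtime_checkable

# Stand-in for torch.Tensor so the sketch runs without PyTorch (assumption).
Tensor = Any


@runtime_checkable
class StimulusModelLike(Protocol):
    """Hypothetical local mirror of the documented method names."""

    def batch(
        self, batch: dict[str, Tensor], optimizer: Any | None = None, **loss_dict: Any
    ) -> tuple[Tensor, dict[str, Any]]: ...

    def eval(self) -> None: ...

    def inference(self, batch: dict[str, Tensor], **loss_dict: Any) -> Any: ...

    def to(self, device: Any) -> StimulusModelLike: ...

    def train_batch(
        self, batch: dict[str, Tensor], optimizer: Any, **loss_dict: Any
    ) -> Any: ...


class ToyModel:
    """Toy model: the 'loss' is the mean squared value of batch['x']."""

    def __init__(self) -> None:
        self.training = True
        self.device = "cpu"

    def inference(self, batch, **loss_dict):
        values = batch["x"]
        loss = sum(v * v for v in values) / len(values)
        return loss, {"loss": loss}

    def train_batch(self, batch, optimizer, **loss_dict):
        # A real model would compute gradients and step the optimizer here.
        return self.inference(batch, **loss_dict)

    def batch(self, batch, optimizer=None, **loss_dict):
        # Assumed dispatch: no optimizer means inference, otherwise training.
        if optimizer is None:
            return self.inference(batch, **loss_dict)
        return self.train_batch(batch, optimizer, **loss_dict)

    def eval(self) -> None:
        self.training = False

    def to(self, device):
        self.device = device
        return self


model = ToyModel()
# runtime_checkable only verifies that the methods exist, not their signatures.
assert isinstance(model, StimulusModelLike)
```

The `isinstance` check is a convenience for quick sanity checks; a static type checker gives a stricter comparison against the real protocol's signatures.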
batch ¶
batch(
    batch: dict[str, Tensor],
    optimizer: Any | None = None,
    **loss_dict: Any
) -> tuple[Tensor, dict[str, Any]]
Process a batch and return loss and metrics.
Parameters:
- batch (dict[str, Tensor]) – Dictionary of input tensors
- optimizer (Any | None, default: None) – Optional optimizer for training
- **loss_dict (Any, default: {}) – Additional loss function arguments
Returns:
- tuple[Tensor, dict[str, Any]] – Tuple of (loss tensor, metrics dictionary)
Source code in src/stimulus/typing/protocols.py
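From the caller's side, the same `batch` method can serve both training and evaluation loops, with the `optimizer` argument deciding which one is meant. The following is a sketch rather than the library's own loop: `run_epoch` and `ConstantModel` are hypothetical names, and the stub's behaviour (a different loss with and without an optimizer) exists only to make the two paths visible.

```python
from __future__ import annotations

from typing import Any, Iterable


def run_epoch(
    model: Any,
    batches: Iterable[dict[str, Any]],
    optimizer: Any | None = None,
    **loss_dict: Any,
) -> float:
    """Average the loss returned by model.batch over all batches."""
    total, count = 0.0, 0
    for b in batches:
        loss, _metrics = model.batch(b, optimizer=optimizer, **loss_dict)
        total += float(loss)  # float() also accepts a one-element torch.Tensor
        count += 1
    # Guard against an empty iterable so the average stays defined.
    return total / max(count, 1)


class ConstantModel:
    """Hypothetical stub: reports loss 1.0 with an optimizer, 2.0 without."""

    def batch(self, batch, optimizer=None, **loss_dict):
        loss = 2.0 if optimizer is None else 1.0
        return loss, {"loss": loss}


validation_loss = run_epoch(ConstantModel(), [{}, {}])
training_loss = run_epoch(ConstantModel(), [{}, {}], optimizer=object())
```

Passing `optimizer=None` through unchanged keeps a single loop usable for both phases, which is one reason a combined `batch` entry point can be convenient.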
eval ¶
eval() -> None
Set model to evaluation mode.
Source code in src/stimulus/typing/protocols.py
inference ¶
Run inference on a batch and return loss and metrics.
Parameters:
- batch (dict[str, Tensor]) – Dictionary of input tensors
- **loss_dict (Any, default: {}) – Additional loss function arguments
Returns:
- Loss and metrics for the batch
Source code in src/stimulus/typing/protocols.py
to ¶
to(device: device) -> StimulusModel
Move model to device.
Parameters:
- device (device) – Target device
Returns:
- StimulusModel – Model on the target device
Source code in src/stimulus/typing/protocols.py
train_batch ¶
train_batch(
    batch: dict[str, Tensor],
    optimizer: Any,
    **loss_dict: Any
) -> (
    tuple[Tensor, dict[str, Any]]
    | tuple[Tensor, dict[str, Any], dict[str, Any]]
)
Process a training batch and return loss, metrics, and optionally per-sample data.
Parameters:
- batch (dict[str, Tensor]) – Dictionary of input tensors
- optimizer (Any) – Optimizer for training
- **loss_dict (Any, default: {}) – Additional loss function arguments
Returns:
- tuple[Tensor, dict[str, Any]] | tuple[Tensor, dict[str, Any], dict[str, Any]] – Tuple of (loss tensor, metrics dictionary) or tuple of (loss tensor, metrics dictionary, per-sample dictionary)
Source code in src/stimulus/typing/protocols.py
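Since `train_batch` may return either two or three items, calling code has to handle both shapes. One way to do that, sketched below, is to normalize the result to a three-item tuple up front. The helper name `unpack_train_result` is hypothetical, and substituting an empty dictionary for missing per-sample data is an assumption made for this example, not documented behaviour.

```python
from __future__ import annotations

from typing import Any


def unpack_train_result(
    result: tuple[Any, ...],
) -> tuple[Any, dict[str, Any], dict[str, Any]]:
    """Return (loss, metrics, per_sample); per_sample is {} when absent."""
    if len(result) == 2:
        loss, metrics = result
        return loss, metrics, {}
    if len(result) == 3:
        loss, metrics, per_sample = result
        return loss, metrics, per_sample
    raise ValueError(f"expected 2 or 3 items, got {len(result)}")


# Both documented return shapes go through the same code path afterwards.
loss, metrics, per_sample = unpack_train_result((0.5, {"acc": 0.9}))
```

Normalizing once at the call site keeps the rest of a training loop free of length checks.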