
protocols

Protocol definitions for stimulus models.

Classes:

  • StimulusModel

    Protocol for stimulus models with batch and inference methods.

StimulusModel

Bases: Protocol

Protocol for stimulus models with batch and inference methods.

Methods:

  • batch

    Process a batch and return loss and metrics.

  • eval

    Set model to evaluation mode.

  • inference

    Run inference on a batch and return loss and metrics.

  • to

    Move model to device.

  • train_batch

    Process a training batch and return loss, metrics, and optionally per-sample data.
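
Because StimulusModel is a typing.Protocol, a class conforms structurally by defining these methods with compatible signatures. It does not need to inherit from StimulusModel. The sketch below assumes PyTorch and a hypothetical batch layout with keys "x" (features) and "y" (targets). The class name ToyRegressor, the "loss" metric key, and the convention that batch() trains when given an optimizer and evaluates otherwise are illustrative assumptions, not part of the documented API. Subclassing torch.nn.Module supplies to() and eval(). Note that nn.Module.eval() returns the module rather than None, and callers can simply ignore that value.

```python
from __future__ import annotations

from typing import Any

import torch
from torch import Tensor, nn


class ToyRegressor(nn.Module):
    """Hypothetical linear model that structurally matches StimulusModel.

    eval() and to() are inherited unchanged from torch.nn.Module.
    """

    def __init__(self, n_features: int = 4) -> None:
        super().__init__()
        self.linear = nn.Linear(n_features, 1)

    def forward(self, x: Tensor) -> Tensor:
        # (batch, n_features) -> (batch,)
        return self.linear(x).squeeze(-1)

    def _loss(self, batch: dict[str, Tensor], **loss_dict: Any) -> Tensor:
        # Extra keyword arguments (e.g. reduction="sum") are forwarded to the loss.
        return nn.functional.mse_loss(self(batch["x"]), batch["y"], **loss_dict)

    def train_batch(
        self, batch: dict[str, Tensor], optimizer: Any, **loss_dict: Any
    ) -> tuple[Tensor, dict[str, Any]]:
        # One optimisation step; this sketch never returns per-sample data.
        optimizer.zero_grad()
        loss = self._loss(batch, **loss_dict)
        loss.backward()
        optimizer.step()
        return loss.detach(), {"loss": loss.item()}

    def inference(
        self, batch: dict[str, Tensor], **loss_dict: Any
    ) -> tuple[Tensor, dict[str, Any]]:
        # No gradients are needed when the weights are not updated.
        with torch.no_grad():
            loss = self._loss(batch, **loss_dict)
        return loss, {"loss": loss.item()}

    def batch(
        self,
        batch: dict[str, Tensor],
        optimizer: Any | None = None,
        **loss_dict: Any,
    ) -> tuple[Tensor, dict[str, Any]]:
        # Assumed convention: an optimizer means "train", None means "evaluate".
        if optimizer is None:
            return self.inference(batch, **loss_dict)
        return self.train_batch(batch, optimizer, **loss_dict)


torch.manual_seed(0)
model = ToyRegressor()
data = {"x": torch.randn(8, 4), "y": torch.randn(8)}
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

model.train()
train_loss, train_metrics = model.batch(data, optimizer)
model.eval()
val_loss, val_metrics = model.batch(data, reduction="sum")
```

Here model.batch(data, optimizer) performs one SGD step, while model.batch(data, reduction="sum") evaluates without touching the weights and passes reduction through **loss_dict to mse_loss.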

batch

batch(
    batch: dict[str, Tensor],
    optimizer: Any | None = None,
    **loss_dict: Any
) -> tuple[Tensor, dict[str, Any]]

Process a batch and return loss and metrics.

Parameters:

  • batch (dict[str, Tensor]) –

    Dictionary of input tensors

  • optimizer (Any | None, default: None ) –

    Optional optimizer for training

  • **loss_dict (Any, default: {} ) –

    Additional loss function arguments

Returns:

  • tuple[Tensor, dict[str, Any]]

    Tuple of (loss tensor, metrics dictionary)

Source code in src/stimulus/typing/protocols.py

def batch(
    self,
    batch: dict[str, Tensor],
    optimizer: Any | None = None,
    **loss_dict: Any,
) -> tuple[Tensor, dict[str, Any]]:
    """Process a batch and return loss and metrics.

    Args:
        batch: Dictionary of input tensors
        optimizer: Optional optimizer for training
        **loss_dict: Additional loss function arguments

    Returns:
        Tuple of (loss tensor, metrics dictionary)
    """
    ...

eval

eval() -> None

Set model to evaluation mode.

Source code in src/stimulus/typing/protocols.py

def eval(self) -> None:
    """Set model to evaluation mode."""
    ...

inference

inference(
    batch: dict[str, Tensor], **loss_dict: Any
) -> tuple[Tensor, dict[str, Any]]

Run inference on a batch and return loss and metrics.

Parameters:

  • batch (dict[str, Tensor]) –

    Dictionary of input tensors

  • **loss_dict (Any, default: {} ) –

    Additional loss function arguments

Returns:

  • tuple[Tensor, dict[str, Any]]

    Tuple of (loss tensor, metrics dictionary)

Source code in src/stimulus/typing/protocols.py

def inference(
    self,
    batch: dict[str, Tensor],
    **loss_dict: Any,
) -> tuple[Tensor, dict[str, Any]]:
    """Run inference on a batch and return loss and metrics.

    Args:
        batch: Dictionary of input tensors
        **loss_dict: Additional loss function arguments

    Returns:
        Tuple of (loss tensor, metrics dictionary)
    """
    ...

to

to(device: device) -> StimulusModel

Move model to device.

Parameters:

  • device (device) –

    Target device

Returns:

  • StimulusModel

    Model on the target device

Source code in src/stimulus/typing/protocols.py

def to(self, device: torch.device) -> "StimulusModel":
    """Move model to device.

    Args:
        device: Target device

    Returns:
        Model on the target device
    """
    ...
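
Since to() returns the model, the usual PyTorch idiom of reassigning its result applies. It concerns the model only, so the tensors in each batch dictionary presumably have to be placed on the same device before calling batch(), train_batch(), or inference(). A sketch assuming PyTorch; pick_device and move_batch are hypothetical helpers, and nn.Linear stands in for a StimulusModel purely to demonstrate to().

```python
from __future__ import annotations

import torch
from torch import Tensor, nn


def pick_device() -> torch.device:
    # Prefer a CUDA GPU when one is visible, otherwise fall back to the CPU.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def move_batch(batch: dict[str, Tensor], device: torch.device) -> dict[str, Tensor]:
    """Return a new dict with every tensor on `device` (hypothetical helper)."""
    return {key: value.to(device) for key, value in batch.items()}


device = pick_device()
# Stand-in module; a StimulusModel's to() likewise returns the model.
model = nn.Linear(4, 1).to(device)
batch = move_batch({"x": torch.randn(2, 4), "y": torch.randn(2)}, device)
```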

train_batch

train_batch(
    batch: dict[str, Tensor],
    optimizer: Any,
    **loss_dict: Any
) -> (
    tuple[Tensor, dict[str, Any]]
    | tuple[Tensor, dict[str, Any], dict[str, Any]]
)

Process a training batch and return loss, metrics, and optionally per-sample data.

Parameters:

  • batch (dict[str, Tensor]) –

    Dictionary of input tensors

  • optimizer (Any) –

    Optimizer for training

  • **loss_dict (Any, default: {} ) –

    Additional loss function arguments

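Returns:

  • tuple[Tensor, dict[str, Any]] | tuple[Tensor, dict[str, Any], dict[str, Any]]

    Tuple of (loss tensor, metrics dictionary) or tuple of (loss tensor, metrics dictionary, per-sample dictionary)
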
Source code in src/stimulus/typing/protocols.py

def train_batch(
    self,
    batch: dict[str, Tensor],
    optimizer: Any,
    **loss_dict: Any,
) -> tuple[Tensor, dict[str, Any]] | tuple[Tensor, dict[str, Any], dict[str, Any]]:
    """Process a training batch and return loss, metrics, and optionally per-sample data.

    Args:
        batch: Dictionary of input tensors
        optimizer: Optimizer for training
        **loss_dict: Additional loss function arguments

    Returns:
        Tuple of (loss tensor, metrics dictionary) or
        Tuple of (loss tensor, metrics dictionary, per-sample dictionary)
    """
    ...
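
Because train_batch may return either two or three elements, calling code cannot assume a fixed tuple length. A minimal sketch of one way to normalise the result; the helper name unpack_train_output and the choice to substitute None when no per-sample dictionary is returned are assumptions. It is plain Python, so stand-in values replace real tensors in the usage line.

```python
from __future__ import annotations

from typing import Any


def unpack_train_output(
    result: tuple[Any, dict[str, Any]] | tuple[Any, dict[str, Any], dict[str, Any]],
) -> tuple[Any, dict[str, Any], dict[str, Any] | None]:
    """Normalise train_batch output to (loss, metrics, per_sample or None)."""
    if len(result) == 2:
        loss, metrics = result
        return loss, metrics, None
    if len(result) == 3:
        loss, metrics, per_sample = result
        return loss, metrics, per_sample
    raise ValueError(f"expected a 2- or 3-tuple, got {len(result)} elements")


# Usage with stand-in values: per_sample is None for a two-element result.
loss, metrics, per_sample = unpack_train_output((0.5, {"acc": 0.9}))
```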