
generic_utils

Utility functions for general-purpose operations such as seed setting and tensor manipulation.

Functions:

  • ensure_at_least_1d

    Ensure the given tensor is not zero-dimensional; if it is, add one dimension.

  • set_general_seeds

    Set all relevant random seeds to a given value.

ensure_at_least_1d

ensure_at_least_1d(tensor: Tensor) -> Tensor

Ensure the given tensor is not zero-dimensional; if it is, add one dimension.

Source code in src/stimulus/utils/generic_utils.py
import torch


def ensure_at_least_1d(tensor: torch.Tensor) -> torch.Tensor:
    """Ensure the given tensor is not zero-dimensional; if it is, add one dimension."""
    # A 0-d (scalar) tensor becomes shape (1,); tensors with dim >= 1 are returned unchanged.
    if tensor.dim() == 0:
        tensor = tensor.unsqueeze(0)
    return tensor
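The helper matters mostly when scalar tensors have to be combined with other tensors. For example, `torch.cat` rejects zero-dimensional inputs. The sketch below copies the function so it runs on its own. The example values are made up for illustration.

```python
import torch


def ensure_at_least_1d(tensor: torch.Tensor) -> torch.Tensor:
    """Copy of the function above so this sketch runs standalone."""
    if tensor.dim() == 0:
        tensor = tensor.unsqueeze(0)
    return tensor


scalar = torch.tensor(1.0)          # shape torch.Size([]), i.e. zero-dimensional
vector = torch.tensor([2.0, 3.0])   # shape torch.Size([2])

# torch.cat refuses zero-dimensional inputs, so lift the scalar to shape (1,) first.
combined = torch.cat([ensure_at_least_1d(scalar), ensure_at_least_1d(vector)])
print(combined.shape)  # torch.Size([3])
print(combined)        # tensor([1., 2., 3.])
```

Tensors that already have at least one dimension are returned as the same object, so calling the helper on every element of a mixed list is cheap.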

set_general_seeds

set_general_seeds(seed_value: Union[int, None]) -> None

Set all relevant random seeds to a given value.

Especially useful with Ray Tune: as of Ray 2.23, Ray does not provide a "generic" global seed.

Source code in src/stimulus/utils/generic_utils.py
import random
from typing import Union

import numpy as np
import torch


def set_general_seeds(seed_value: Union[int, None]) -> None:
    """Set all relevant random seeds to a given value.

    Especially useful with Ray Tune: as of Ray 2.23, Ray does not provide a "generic" global seed.
    """
    # Set the Python seed.
    random.seed(seed_value)

    # Set the NumPy seed.
    np.random.seed(seed_value)

    # Set the torch seed. Unlike the two calls above, torch.manual_seed cannot take None
    # as input, so it is only called when a seed value is given.
    if seed_value is not None:
        torch.manual_seed(seed_value)
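One simple way to check the seeding is to reseed with the same value and confirm that each generator repeats its sequence. The sketch below copies the function so it runs standalone. `draw` is a hypothetical helper introduced only for this illustration. The sketch also checks the `None` branch: Python and NumPy are reseeded, but torch's generator is left as it was.

```python
import random
from typing import Union

import numpy as np
import torch


def set_general_seeds(seed_value: Union[int, None]) -> None:
    """Copy of the function above so this sketch runs standalone."""
    random.seed(seed_value)
    np.random.seed(seed_value)
    if seed_value is not None:
        torch.manual_seed(seed_value)


def draw() -> tuple:
    """Hypothetical helper: draw one value from each of the three generators."""
    return (random.random(), float(np.random.rand()), torch.rand(1).item())


# Same seed twice -> identical draws from Python, NumPy and torch.
set_general_seeds(42)
first = draw()
set_general_seeds(42)
second = draw()
print(first == second)  # True

# With None, torch.manual_seed is skipped, so torch's generator state is untouched.
torch.manual_seed(7)
expected = torch.rand(1).item()
torch.manual_seed(7)
set_general_seeds(None)
torch_untouched = torch.rand(1).item() == expected
print(torch_untouched)  # True
```

In a Ray Tune setup, the same idea applies per trial: call the function at the start of the trainable with a seed taken from the trial's config (an assumed convention, not something the source prescribes).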