Skip to content

device_utils

Device utilities for PyTorch model training and inference.

Functions:

  • get_device

    Get the appropriate device (CPU/GPU) for computation.

  • resolve_device

    Resolve device based on priority: force_device > config_device > auto-detection.

get_device

get_device() -> device

Get the appropriate device (Apple MPS, CUDA GPU, or CPU) for computation.

Returns:

  • device

    torch.device: The selected computation device

Source code in src/stimulus/learner/device_utils.py (lines 49–78):
def get_device() -> torch.device:
    """Pick the best available computation device.

    Preference order is MPS (Apple Metal), then CUDA, then CPU. MPS is only
    selected if a tiny probe tensor can actually be moved onto it; if that
    probe raises ``RuntimeError`` the function returns CPU immediately
    (CUDA is not considered in that case).

    Returns:
        torch.device: The selected computation device
    """
    if torch.backends.mps.is_available():
        try:
            mps_device = torch.device("mps")
            # `is_available()` can report True even when the backend cannot
            # allocate, so do a real 1x1 allocation to confirm it works.
            probe = torch.ones((1, 1)).to(mps_device)
            del probe  # drop the probe allocation right away
        except RuntimeError as err:
            logger.warning(f"MPS available but failed to initialize: {err}")
            logger.warning("Falling back to CPU")
            return torch.device("cpu")
        logger.info("Using MPS (Metal Performance Shaders) device")
        return mps_device

    if not torch.cuda.is_available():
        logger.info("Using CPU (GPU not available)")
        return torch.device("cpu")

    # Report the first CUDA device (index 0) and its total memory in GiB.
    name_of_gpu = torch.cuda.get_device_name(0)
    total_gib = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    logger.info(f"Using GPU: {name_of_gpu} with {total_gib:.2f} GB memory")
    return torch.device("cuda")

resolve_device

resolve_device(
    force_device: Optional[str] = None,
    config_device: Optional[str] = None,
) -> device

Resolve device based on priority: force_device > config_device > auto-detection.

Parameters:

  • force_device (Optional[str], default: None ) –

    Device specified via CLI or function parameter (highest priority).

  • config_device (Optional[str], default: None ) –

    Device specified in model configuration (medium priority).

Returns:

  • device

    torch.device: The resolved computation device.

Raises:

  • RuntimeError

    If a forced or configured device is invalid or unavailable.

Source code in src/stimulus/learner/device_utils.py (lines 11–46):
def _validate_device(spec: str, label: str) -> torch.device:
    """Parse ``spec`` into a ``torch.device`` and verify its backend is usable here.

    Args:
        spec: Device string such as ``"cpu"``, ``"cuda"``, ``"cuda:1"`` or ``"mps"``.
        label: Human-readable description of where ``spec`` came from; it
            prefixes the error message.

    Returns:
        torch.device: The parsed device.

    Raises:
        RuntimeError: If ``spec`` is not a valid device string, or it names a
            CUDA/MPS device that is not available on this machine.
    """
    unavailable = f"{label} is not available. Please use a valid device."
    try:
        device = torch.device(spec)
    except RuntimeError as e:
        raise RuntimeError(unavailable) from e

    # `torch.device(...)` only validates syntax: "cuda" parses fine on a
    # machine without CUDA. Check availability explicitly so callers fail here
    # rather than later, when a tensor is first moved to the device.
    reason = None
    if device.type == "cuda":
        if not torch.cuda.is_available():
            reason = "CUDA is not available on this machine"
        elif device.index is not None and device.index >= torch.cuda.device_count():
            reason = f"only {torch.cuda.device_count()} CUDA device(s) are visible"
    elif device.type == "mps" and not torch.backends.mps.is_available():
        reason = "MPS is not available on this machine"

    if reason is not None:
        raise RuntimeError(f"{unavailable} Reason: {reason}.")
    return device


def resolve_device(force_device: Optional[str] = None, config_device: Optional[str] = None) -> torch.device:
    """Resolve device based on priority: force_device > config_device > auto-detection.

    Args:
        force_device: Device specified via CLI or function parameter (highest priority).
        config_device: Device specified in model configuration (medium priority).

    Returns:
        torch.device: The resolved computation device. When neither argument
        is given, the result of ``get_device()`` auto-detection is returned.

    Raises:
        RuntimeError: If a forced or configured device is invalid or unavailable
            (e.g. ``"cuda"`` requested on a machine without CUDA).
    """
    if force_device is not None:
        device = _validate_device(force_device, f"Forced device '{force_device}'")
        logger.info(f"Using force-specified device: {force_device}")
        return device

    # config_device is only consulted when no forced device was given.
    if config_device is not None:
        device = _validate_device(
            config_device,
            f"Device '{config_device}' specified in model configuration",
        )
        logger.info(f"Using config-specified device: {config_device}")
        return device

    return get_device()