nano_gpt.devices

Library for working with hardware devices.

 1"""Library for working with hardware devices."""
 2
 3import torch
 4import logging
 5
 6
 7_LOGGER = logging.getLogger(__name__)
 8
 9
def get_device() -> str:
    """Return the name of the preferred torch device on this machine.

    CUDA is tried first, then Apple's MPS backend; ``"cpu"`` is the
    fallback when neither accelerator is available.
    """
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
18
19
def get_dtype(device: str) -> torch.dtype:
    """Return the default floating-point dtype to use on ``device``.

    Args:
        device: A device name such as ``"cpu"``, ``"cuda"`` or ``"mps"``
            (the values ``get_device`` returns). Indexed forms such as
            ``"cuda:0"`` or ``"mps:0"``, and ``torch.device`` objects, are
            also accepted; only the device type before ``":"`` is considered.

    Returns:
        ``torch.float16`` for MPS, ``torch.bfloat16`` for every other device.
    """
    # Strip an optional device index ("mps:0" -> "mps") so indexed device
    # strings get the same dtype as the bare device type.
    device_type = str(device).split(":", 1)[0]
    if device_type == "mps":
        return torch.float16  # bfloat16 not supported on MPS
    # CUDA, CPU and any other backend all use bfloat16.
    # NOTE(review): older CUDA GPUs may lack native bfloat16 support;
    # consider consulting torch.cuda.is_bf16_supported() if that matters.
    return torch.bfloat16
def get_device() -> str:
def get_device() -> str:
    """Return the name of the preferred torch device on this machine.

    CUDA is tried first, then Apple's MPS backend; ``"cpu"`` is the
    fallback when neither accelerator is available.
    """
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

Pick the best available device: CUDA if present, otherwise MPS, otherwise CPU.

def get_dtype(device: str) -> torch.dtype:
def get_dtype(device: str) -> torch.dtype:
    """Return the default floating-point dtype to use on ``device``.

    Args:
        device: A device name such as ``"cpu"``, ``"cuda"`` or ``"mps"``
            (the values ``get_device`` returns). Indexed forms such as
            ``"cuda:0"`` or ``"mps:0"``, and ``torch.device`` objects, are
            also accepted; only the device type before ``":"`` is considered.

    Returns:
        ``torch.float16`` for MPS, ``torch.bfloat16`` for every other device.
    """
    # Strip an optional device index ("mps:0" -> "mps") so indexed device
    # strings get the same dtype as the bare device type.
    device_type = str(device).split(":", 1)[0]
    if device_type == "mps":
        return torch.float16  # bfloat16 not supported on MPS
    # CUDA, CPU and any other backend all use bfloat16.
    # NOTE(review): older CUDA GPUs may lack native bfloat16 support;
    # consider consulting torch.cuda.is_bf16_supported() if that matters.
    return torch.bfloat16

Get the default floating-point dtype for the device (float16 on MPS, bfloat16 elsewhere).