nano_gpt.devices
Library for working with hardware devices.
"""Library for working with hardware devices."""

import logging

import torch

_LOGGER = logging.getLogger(__name__)


def get_device() -> str:
    """Pick the best device available.

    Preference order is CUDA, then Apple MPS, then CPU.

    Returns:
        One of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    return device


def get_dtype(device: str) -> torch.dtype:
    """Get the floating-point dtype to use for the given device.

    Args:
        device: A device string such as ``"cpu"``, ``"cuda"`` or ``"mps"``,
            optionally with an index suffix (``"cuda:0"``, ``"mps:0"``).
            The value is passed through ``str()``, so a ``torch.device``
            also works.

    Returns:
        ``torch.float16`` for MPS devices, ``torch.bfloat16`` otherwise.
    """
    # Drop an optional ":<index>" suffix so "mps:0" is treated like "mps".
    device_type = str(device).split(":", 1)[0]
    if device_type == "mps":
        return torch.float16  # bfloat16 not supported on MPS
    # CUDA, CPU and any other device type use bfloat16.
    return torch.bfloat16
def get_device() -> str:
def get_device() -> str:
    """Return the name of the preferred torch device.

    CUDA wins if available; otherwise Apple MPS; otherwise ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    # MPS is only consulted when CUDA is absent.
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
Pick the best device available.
def get_dtype(device: str) -> torch.dtype:
def get_dtype(device: str) -> torch.dtype:
    """Get the floating-point dtype to use for the given device.

    Args:
        device: A device string such as ``"cpu"``, ``"cuda"`` or ``"mps"``,
            optionally with an index suffix (``"cuda:0"``, ``"mps:0"``).
            The value is passed through ``str()``, so a ``torch.device``
            also works.

    Returns:
        ``torch.float16`` for MPS devices, ``torch.bfloat16`` otherwise.
    """
    # Drop an optional ":<index>" suffix so "mps:0" is treated like "mps".
    device_type = str(device).split(":", 1)[0]
    if device_type == "mps":
        return torch.float16  # bfloat16 not supported on MPS
    # CUDA, CPU and any other device type use bfloat16.
    return torch.bfloat16
Get the floating-point dtype to use for the device.