Source code for ai4plasma.utils.device

"""Device management utilities for GPU/CPU selection in AI4Plasma.

This module provides device abstraction and management utilities for seamless
GPU/CPU computation in the AI4Plasma framework. It handles device selection,
validation, and provides a unified interface for device configuration.

Device Functions
----------------
- `check_gpu()`: Check if GPU is available on the system.
- `select_gpu_by_id()`: Select a specific GPU by its ID.
- `torch_device()`: Create a PyTorch device object.

Device Classes
--------------
- `Device`: Centralized device management class for GPU/CPU selection and handling.
"""

import torch

def check_gpu(print_required=False):
    """Report whether PyTorch can see a usable GPU.

    The answer comes straight from ``torch.cuda.is_available()``.

    Parameters
    ----------
    print_required : bool, optional
        When True, a one-line availability message is printed to the
        console. Defaults to False.

    Returns
    -------
    bool
        True if a GPU is available, False otherwise.
    """
    gpu_found = bool(torch.cuda.is_available())
    if print_required:
        print('GPU is available!' if gpu_found else 'GPU is not available!')
    return gpu_found
def select_gpu_by_id(gpu_id=0):
    """Select GPU by its ID.

    Validates the request and then makes ``gpu_id`` the active CUDA device
    via ``torch.cuda.set_device()``.

    Parameters
    ----------
    gpu_id : int, optional
        GPU ID to select. Must be a non-negative integer and less than the
        total number of available GPUs. Defaults to 0.

    Raises
    ------
    ValueError
        If gpu_id is not an integer, is negative, if GPU is not available
        on the system, or if gpu_id exceeds the available GPU count.
    """
    if not isinstance(gpu_id, int):
        raise ValueError("gpu_id must be an integer.")
    if gpu_id < 0:
        raise ValueError(f"Invalid gpu_id {gpu_id}. gpu_id must be an integer greater than or equal to zero.")
    # Availability is checked before the count (same order as torch_device):
    # on a machine without CUDA the device count is expected to be 0, so
    # checking the count first would always report "Only 0 GPUs are
    # available." and never reach the clearer "GPU is not available." error.
    if not torch.cuda.is_available():
        raise ValueError("GPU is not available.")
    gpu_count = torch.cuda.device_count()
    if gpu_id >= gpu_count:
        raise ValueError(f"Invalid gpu_id {gpu_id}. Only {gpu_count} GPUs are available.")
    torch.cuda.set_device(gpu_id)
def torch_device(device_id=-1):
    """Build the ``torch.device`` that corresponds to ``device_id``.

    Negative IDs map to the CPU; non-negative IDs map to the CUDA device
    with that index after it has been validated.

    Parameters
    ----------
    device_id : int, optional
        Device ID to use. If >= 0, selects the GPU with that ID. If < 0,
        uses CPU. Defaults to -1 (CPU).

    Returns
    -------
    torch.device
        A PyTorch device object configured for the specified device.

    Raises
    ------
    ValueError
        If device_id is not an integer, if a GPU device_id is specified but
        GPU is not available, or if device_id exceeds the number of
        available GPUs.
    """
    if not isinstance(device_id, int):
        raise ValueError("device_id must be an integer.")

    # CPU request: nothing to validate against the CUDA runtime.
    if device_id < 0:
        return torch.device("cpu")

    if not torch.cuda.is_available():
        raise ValueError("GPU is not available, but a GPU device_id was specified.")

    gpu_total = torch.cuda.device_count()
    if device_id >= gpu_total:
        raise ValueError(f"Invalid device_id {device_id}. Only {gpu_total} GPUs are available.")

    return torch.device(f"cuda:{device_id}")
class Device:
    """Centralized holder for the active compute device (GPU or CPU).

    Wraps a device ID together with the matching ``torch.device`` so the
    rest of the code can switch devices or query the current one through a
    single object. Instances are callable: calling with no argument returns
    the current device, calling with an ID switches to it first.

    Attributes
    ----------
    device_id : int
        The current device ID. Negative values indicate CPU, non-negative
        values indicate GPU ID.
    device : torch.device
        The PyTorch device object for the current device.
    """

    def __init__(self, device_id=-1) -> None:
        """Create the manager and select the initial device.

        Parameters
        ----------
        device_id : int, optional
            Device ID to select. If < 0, uses CPU. Non-negative values
            select the GPU with that ID. Defaults to -1 (CPU).

        Raises
        ------
        ValueError
            If device_id is invalid or GPU device_id is specified but GPU
            is not available.
        """
        # Declare both attributes up front; set_device fills them in.
        self.device_id = None
        self.device = None
        self.set_device(device_id)

    def __call__(self, device_id=None):
        """Return the current device, optionally switching to a new one first.

        Parameters
        ----------
        device_id : int, optional
            Device ID to set. If None, the current device is returned
            unchanged. Defaults to None.

        Returns
        -------
        torch.device
            The current (or newly set) PyTorch device object.
        """
        if device_id is None:
            return self.device
        self.set_device(device_id)
        return self.device

    def __str__(self) -> str:
        """Return a human-readable description of the selected device.

        Returns
        -------
        str
            'Device(cpu)' for CPU or 'Device(cuda:{device_id})' for GPU.
        """
        label = "cpu" if self.device_id < 0 else f"cuda:{self.device_id}"
        return f"Device({label})"

    def set_device(self, device_id=-1) -> None:
        """Switch this manager to the device identified by ``device_id``.

        Parameters
        ----------
        device_id : int, optional
            Device ID to select. If < 0, uses CPU. Non-negative values
            select the GPU with that ID. Defaults to -1 (CPU).

        Raises
        ------
        ValueError
            If device_id is invalid or GPU device_id is specified but GPU
            is not available.
        """
        # Resolve first so a validation failure leaves the current state intact.
        resolved = torch_device(device_id)
        self.device = resolved
        self.device_id = device_id