# Source code for ai4plasma.core.network

"""Core neural network architectures for scientific computing and physics-informed learning.

This module provides flexible, highly configurable neural network building blocks optimized
for physics-informed machine learning and operator learning problems in the AI4Plasma framework.

Network Classes
---------------
- `Network`: Abstract base class for all neural network architectures.
- `FNN`: Fully Connected Neural Network.
- `CNN`: Convolutional Neural Network.
- `RelaxLayer`: Relaxed hidden layer for NAS-PINN.
- `RelaxFNN`: Relaxed fully connected network for NAS-PINN.
"""

from abc import ABC, abstractmethod
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from ai4plasma.config import REAL, DEVICE


class Network(nn.Module, ABC):
    """Abstract base class for neural network architectures.

    Combines ``torch.nn.Module`` with ``abc.ABC`` so that a subclass can only
    be instantiated once it implements both :meth:`init_weights` and
    :meth:`forward`.

    Notes
    -----
    The concrete subclasses in this module override :meth:`init_weights` with
    additional arguments of their own (for example a target module and/or an
    initialization method name); the abstract signature below is only the
    minimal contract.

    Attributes
    ----------
    None
    """

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def init_weights(self):
        """Initialize network weights.

        Subclasses provide the weight initialization strategy appropriate
        for their architecture.

        Notes
        -----
        Abstract method that must be implemented by subclasses.
        """

    @abstractmethod
    def forward(self, x):
        """Forward pass through the network.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor to the network.

        Returns
        -------
        torch.Tensor
            Output tensor from the network.

        Notes
        -----
        Abstract method that must be implemented by subclasses.
        """

class FNN(Network):
    """Fully Connected Neural Network (Multi-layer Perceptron).

    A multi-layer fully connected network with configurable depth, widths,
    activation function and weight initialization. Suitable for function
    approximation and physics-informed machine learning tasks.

    Attributes
    ----------
    layers : list of int
        Number of neurons in each layer.
    act_fun : torch.nn.Module
        Activation function applied between layers.
    net : torch.nn.Sequential
        The complete network model.
    """

    def __init__(self, layers, act_fun=nn.Tanh(), use_BN=False, init_method='xavier') -> None:
        """Initialize the FNN.

        Parameters
        ----------
        layers : list of int
            Number of neurons in each layer. The first and last values are
            the input and output dimensions, respectively. At least two
            entries are required (shorter sequences fail with ``IndexError``
            while the output layer is built).
        act_fun : torch.nn.Module, optional
            Activation function applied between layers. Default is Tanh().
        use_BN : bool, optional
            Whether to use batch normalization after each linear layer
            (except the output layer). Default is False.
        init_method : {'xavier', 'zero'}, optional
            Weight initialization method. 'xavier' uses Xavier/Glorot normal
            initialization, 'zero' sets all weights to zero. Any other value
            falls back to 'xavier'. Default is 'xavier'.
        """
        super().__init__()
        self.layers = layers
        self.act_fun = act_fun
        self.net = self.linear_model(layers, act_fun, use_BN)
        self.init_weights(self.net, init_method)

    def linear_model(self, layers, activation, use_BN=False):
        """Construct a sequential multi-layer network.

        Parameters
        ----------
        layers : list of int
            Number of neurons in each layer.
        activation : torch.nn.Module
            Activation function to apply between layers.
        use_BN : bool, optional
            Whether to use batch normalization. Default is False.

        Returns
        -------
        torch.nn.Sequential
            Linear layers named ``linear1 .. linearN``; every hidden linear
            layer is followed by an optional ``BN<i>`` and an
            ``activation<i>`` module. The output layer has no activation.
        """
        # presumably the project-wide floating dtype; resolved once per build
        dtype = REAL('torch')
        n_layers = len(layers)

        model = nn.Sequential()
        for i in range(n_layers - 2):
            pos = i + 1
            model.add_module(f'linear{pos}', nn.Linear(layers[i], layers[i + 1], dtype=dtype))
            if use_BN:
                model.add_module(f'BN{pos}', nn.BatchNorm1d(layers[i + 1], dtype=dtype))
            # The very same activation instance is registered at every hidden
            # position, so a parametrized activation would share its weights.
            model.add_module(f'activation{pos}', activation)

        # Output layer: purely linear.
        model.add_module(f'linear{n_layers - 1}', nn.Linear(layers[-2], layers[-1], dtype=dtype))
        return model

    def init_weights(self, net, method='xavier'):
        """Initialize the weights of every ``nn.Linear`` found in ``net``.

        Parameters
        ----------
        net : torch.nn.Module
            The network whose linear layers are initialized (normally the
            ``nn.Sequential`` built by :meth:`linear_model`).
        method : {'xavier', 'zero'}, optional
            Initialization method; anything other than 'zero' is treated as
            'xavier'. Default is 'xavier'.
        """
        for m in net.modules():
            if not isinstance(m, nn.Linear):
                continue
            if method == 'zero':
                nn.init.constant_(m.weight, 0.0)
            else:  # default: 'xavier'
                nn.init.xavier_normal_(m.weight)
            # A linear layer created with bias=False has no bias tensor.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    def forward(self, x):
        """Forward pass through the FNN.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, input_dim).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, output_dim).
        """
        return self.net(x)

class CNN(Network):
    """Convolutional Neural Network for scientific computing applications.

    Supports 1D, 2D and 3D convolutions with configurable kernel, stride and
    padding, optional batch normalization, optional max/avg pooling, and
    either a fully connected head or global average pooling as the output
    stage.

    Attributes
    ----------
    conv_layers : list of int
        Channel counts for convolutional layers.
    fc_layers_config : list of int, optional
        Layer sizes for the fully connected head as passed by the caller.
    fc_layers : list of int, optional
        Layer sizes actually used for the head (set on first forward pass).
    input_dim : int
        Spatial dimension of input data (1, 2, or 3).
    act_fun : torch.nn.Module
        Activation function applied after conv and fc layers.
    use_BN : bool
        Whether batch normalization is used.
    use_pooling : bool
        Whether pooling layers are used.
    conv_net : torch.nn.Sequential
        Sequential container of convolutional layers.
    fc_net : torch.nn.Sequential, optional
        Fully connected head; ``None`` until the first forward pass.
    global_pool : torch.nn.Module, optional
        Global pooling layer (only present if ``fc_layers`` is None).

    Notes
    -----
    The fully connected head is created during the first call to
    :meth:`forward`, so its parameters do not exist before that call (for
    example when an optimizer is constructed right after ``__init__``).
    """

    def __init__(self, conv_layers, fc_layers=None, input_dim=2, act_fun=nn.ReLU(),
                 use_BN=False, use_pooling=True, pooling_type='max',
                 kernel_size=3, stride=1, padding=1,
                 pooling_kernel_size=2, pooling_stride=None, pooling_padding=0,
                 init_method='xavier') -> None:
        """Initialize the CNN.

        Parameters
        ----------
        conv_layers : list of int
            Channel counts for convolutional layers. The first value must
            match the input channels.
        fc_layers : list of int, optional
            Neuron counts for fully connected layers. If None, global average
            pooling is used. If provided, the first value is replaced by the
            flattened conv feature size when the two differ. Default is None.
        input_dim : int, optional
            Spatial dimension of input data (1, 2, or 3). Default is 2.
        act_fun : torch.nn.Module, optional
            Activation function. Default is ReLU().
        use_BN : bool, optional
            Whether to use batch normalization. Default is False.
        use_pooling : bool, optional
            Whether to use pooling layers. Default is True.
        pooling_type : {'max', 'avg'}, optional
            Type of pooling to use; any value other than 'max' selects
            average pooling. Default is 'max'.
        kernel_size : int or tuple, optional
            Convolution kernel size. Default is 3.
        stride : int or tuple, optional
            Convolution stride. Default is 1.
        padding : int or tuple, optional
            Convolution padding. Default is 1.
        pooling_kernel_size : int or tuple, optional
            Pooling kernel size. Default is 2.
        pooling_stride : int or tuple, optional
            Pooling stride. If None, defaults to pooling_kernel_size.
        pooling_padding : int or tuple, optional
            Pooling padding. Default is 0.
        init_method : {'xavier', 'kaiming', 'zero'}, optional
            Weight initialization method. Default is 'xavier'.

        Raises
        ------
        ValueError
            If input_dim is not 1, 2, or 3.
        """
        super().__init__()

        if input_dim not in [1, 2, 3]:
            raise ValueError("input_dim must be 1, 2, or 3")

        self.conv_layers = conv_layers
        self.fc_layers_config = fc_layers  # original, unadjusted config
        self.input_dim = input_dim
        self.act_fun = act_fun
        self.use_BN = use_BN
        self.use_pooling = use_pooling
        self.pooling_type = pooling_type
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.pooling_kernel_size = pooling_kernel_size
        self.pooling_stride = pooling_stride if pooling_stride is not None else pooling_kernel_size
        self.pooling_padding = pooling_padding
        self.init_method = init_method

        self.conv_net = self._build_conv_layers(
            conv_layers, kernel_size, stride, padding,
            pooling_kernel_size, self.pooling_stride, pooling_padding)

        # The FC head needs the flattened conv output size, which is only
        # known once real input is seen, so it is built lazily in forward().
        self.fc_net = None
        self.fc_layers = None
        self._fc_initialized = False

        self.use_fc = fc_layers is not None
        if not self.use_fc:
            self.global_pool = self._get_global_pool()

        self.init_weights(init_method)

    def _get_conv_layer(self):
        """Return the convolution class (Conv1d/Conv2d/Conv3d) for ``input_dim``."""
        return {1: nn.Conv1d, 2: nn.Conv2d}.get(self.input_dim, nn.Conv3d)

    def _get_bn_layer(self, channels):
        """Return a BatchNorm1d/2d/3d layer for ``channels`` feature maps.

        Parameters
        ----------
        channels : int
            Number of channels for batch normalization.
        """
        bn_cls = {1: nn.BatchNorm1d, 2: nn.BatchNorm2d}.get(self.input_dim, nn.BatchNorm3d)
        return bn_cls(channels, dtype=REAL('torch'))

    def _get_pool_layer(self, kernel_size=2, stride=None, padding=0):
        """Return the pooling layer matching ``pooling_type`` and ``input_dim``.

        Parameters
        ----------
        kernel_size : int or tuple, optional
            Kernel size for pooling. Default is 2.
        stride : int or tuple, optional
            Stride for pooling. If None, defaults to kernel_size.
        padding : int or tuple, optional
            Padding for pooling. Default is 0.

        Returns
        -------
        torch.nn.Module
            A MaxPool{1,2,3}d layer when ``pooling_type == 'max'``, otherwise
            an AvgPool{1,2,3}d layer.
        """
        if stride is None:
            stride = kernel_size
        if self.pooling_type == 'max':
            pool_cls = {1: nn.MaxPool1d, 2: nn.MaxPool2d}.get(self.input_dim, nn.MaxPool3d)
        else:  # avg pooling
            pool_cls = {1: nn.AvgPool1d, 2: nn.AvgPool2d}.get(self.input_dim, nn.AvgPool3d)
        return pool_cls(kernel_size, stride=stride, padding=padding)

    def _get_global_pool(self):
        """Return an AdaptiveAvgPool{1,2,3}d layer with output size 1."""
        pool_cls = {1: nn.AdaptiveAvgPool1d, 2: nn.AdaptiveAvgPool2d}.get(
            self.input_dim, nn.AdaptiveAvgPool3d)
        return pool_cls(1)

    def _calculate_feature_size(self, input_tensor):
        """Return the flattened per-sample feature size of the conv output.

        Runs ``input_tensor`` through the convolutional stack without
        recording gradients. :meth:`forward` no longer calls this helper (it
        reads the size from the conv output it computes anyway); it is kept
        for existing callers.

        Parameters
        ----------
        input_tensor : torch.Tensor
            Input tensor to push through the conv stack.

        Returns
        -------
        int
            The flattened feature size after convolution and pooling.
        """
        with torch.no_grad():
            out = self.conv_net(input_tensor)
        return out.view(out.size(0), -1).size(1)

    def _build_conv_layers(self, layers, kernel_size, stride, padding,
                           pooling_kernel_size, pooling_stride, pooling_padding):
        """Build the convolutional stack.

        Parameters
        ----------
        layers : list of int
            Channel numbers for each convolutional layer.
        kernel_size, stride, padding : int or tuple
            Convolution hyper-parameters.
        pooling_kernel_size, pooling_stride, pooling_padding : int or tuple
            Pooling hyper-parameters.

        Returns
        -------
        torch.nn.Sequential
            For every consecutive channel pair: ``conv<i>``, optional
            ``bn<i>``, ``activation<i>`` and optional ``pool<i>``.
        """
        conv_cls = self._get_conv_layer()
        model = nn.Sequential()

        for i in range(len(layers) - 1):
            pos = i + 1
            model.add_module(f'conv{pos}', conv_cls(
                layers[i], layers[i + 1],
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dtype=REAL('torch')
            ))
            if self.use_BN:
                model.add_module(f'bn{pos}', self._get_bn_layer(layers[i + 1]))
            # The same activation instance is reused at every position.
            model.add_module(f'activation{pos}', self.act_fun)
            if self.use_pooling:
                model.add_module(f'pool{pos}', self._get_pool_layer(
                    kernel_size=pooling_kernel_size,
                    stride=pooling_stride,
                    padding=pooling_padding
                ))
        return model

    def _build_fc_layers(self, layers):
        """Build the fully connected head.

        Parameters
        ----------
        layers : list of int
            Neuron counts for each fully connected layer.

        Returns
        -------
        torch.nn.Sequential
            ``fc<i>`` linear layers, each followed by ``fc_activation<i>``
            except for the last one.
        """
        model = nn.Sequential()
        last = len(layers) - 2
        for i in range(len(layers) - 1):
            model.add_module(f'fc{i + 1}',
                             nn.Linear(layers[i], layers[i + 1], dtype=REAL('torch')))
            if i < last:
                model.add_module(f'fc_activation{i + 1}', self.act_fun)
        return model

    @staticmethod
    def _init_module(m, method):
        """Apply ``method`` to a single module if it is a conv/linear/BN layer."""
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
            if method == 'zero':
                nn.init.constant_(m.weight, 0.0)
            elif method == 'kaiming':
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            else:  # default: 'xavier'
                nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)

    def init_weights(self, method='xavier'):
        """Initialize network weights using the specified method.

        Applies the strategy to every convolutional, linear and batch
        normalization layer currently registered on the network (the lazily
        built FC head is included once it exists).

        Parameters
        ----------
        method : {'xavier', 'kaiming', 'zero'}, optional
            Weight initialization method; unknown values behave like
            'xavier'. Default is 'xavier'.
        """
        for m in self.modules():
            self._init_module(m, method)

    def _init_fc_head(self, feature_size, x):
        """Create the FC head for ``feature_size`` inputs on ``x``'s device."""
        configured = self.fc_layers_config
        if configured[0] != feature_size:
            print(f"[INFO] CNN: Automatically adjusted fc_layers[0] from {configured[0]} "
                  f"to {feature_size} (based on actual input shape {x.shape})")
            # list(...) so that a tuple config can be adjusted as well
            self.fc_layers = [feature_size] + list(configured[1:])
        else:
            self.fc_layers = configured

        fc_net = self._build_fc_layers(self.fc_layers)
        # init_weights() ran in __init__ before this head existed, so the
        # configured init_method has to be applied to the new layers here;
        # the conv stack is deliberately left untouched.
        for m in fc_net.modules():
            self._init_module(m, self.init_method)
        self.fc_net = fc_net.to(x.device)
        self._fc_initialized = True

    def forward(self, x):
        """Forward pass through the CNN with lazy FC head creation.

        On the first call with ``fc_layers`` configured, the head is built
        using the flattened size of the conv output as its input width. The
        conv stack is evaluated exactly once per call, so batch-norm running
        statistics are not updated twice on the first batch.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape depending on input_dim:
            - 1D: (batch_size, channels, length)
            - 2D: (batch_size, channels, height, width)
            - 3D: (batch_size, channels, depth, height, width)

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, output_features).
        """
        out = self.conv_net(x)

        if not self.use_fc:
            # Global average pooling, then flatten to (batch_size, channels)
            out = self.global_pool(out)
            return out.view(out.size(0), -1)

        out = out.view(out.size(0), -1)  # (batch_size, flattened_features)
        if not self._fc_initialized:
            self._init_fc_head(out.size(1), x)
        return self.fc_net(out)

    def get_feature_size(self, input_shape):
        """Calculate the output feature size after convolutional layers.

        Useful for choosing the input size of the first fully connected
        layer when designing architectures manually.

        Parameters
        ----------
        input_shape : tuple
            Shape of a single input sample (channels, spatial_dims...).

        Returns
        -------
        int
            Number of features after convolution and pooling operations.
        """
        # Build the probe on the conv stack's device so this also works
        # after the model has been moved off the CPU.
        ref = next(self.conv_net.parameters(), None)
        device = ref.device if ref is not None else None
        dummy_input = torch.zeros(1, *input_shape, dtype=REAL('torch'), device=device)

        # Probe in eval mode: a zero batch of one must neither update
        # batch-norm running statistics nor trip the training-mode
        # batch-size check. Per-module flags are restored afterwards.
        modes = [(m, m.training) for m in self.conv_net.modules()]
        self.conv_net.eval()
        try:
            with torch.no_grad():
                out = self.conv_net(dummy_input)
        finally:
            for m, flag in modes:
                m.training = flag
        return out.view(1, -1).size(1)

class RelaxLayer(Network):
    """Relaxed hidden layer for NAS-PINN architecture search.

    A fully connected hidden layer whose effective operation is controlled
    by architecture parameters ``g`` passed to :meth:`forward`: the output
    blends a zero-padded identity path with a masked linear+activation path,
    where the mask is a soft selection over candidate neuron counts.

    Parameters
    ----------
    C_in : int
        Input feature dimension.
    neuron_list : list of int
        Candidate neuron counts. ``neuron_list[0]`` stands for the identity
        operation; ``neuron_list[-1]`` is the maximum width.
    act_fun : torch.nn.Module, optional
        Activation applied after the linear operator. Default is Tanh.
    init_method : {'xavier', 'zero'}, optional
        Weight initialization method for the linear operator.

    Attributes
    ----------
    C_in : int
        Input feature dimension.
    neuron_list : list of int
        Candidate neuron counts (including identity as index 0).
    op : torch.nn.Sequential
        Linear + activation operator for the maximal neuron count.
    masks : torch.Tensor
        Binary masks of shape (len(neuron_list) - 1, neuron_list[-1]); row
        ``k`` has ones in its first ``neuron_list[k + 1]`` columns.
        Registered as a non-persistent buffer: it follows ``module.to(...)``
        but is not written to ``state_dict``.
    """

    def __init__(self, C_in, neuron_list, act_fun=nn.Tanh(), init_method='xavier'):
        """Initialize a relaxed hidden layer.

        Parameters
        ----------
        C_in : int
            Input feature dimension.
        neuron_list : list of int
            Candidate neuron counts. The last value is the maximum width.
        act_fun : torch.nn.Module, optional
            Activation applied after the linear operator. Default is Tanh.
        init_method : {'xavier', 'zero'}, optional
            Weight initialization method. Default is 'xavier'.

        Raises
        ------
        ValueError
            If ``neuron_list`` holds no candidate width after the identity
            entry.
        """
        super().__init__()
        self.C_in = C_in
        self.neuron_list = neuron_list  # neuron_list[0] represents Identity

        candidates = neuron_list[1:]
        if len(candidates) == 0:
            raise ValueError(
                "neuron_list must contain the identity entry followed by "
                "at least one candidate neuron count")
        max_width = neuron_list[-1]

        self.op = nn.Sequential(
            nn.Linear(C_in, max_width),
            act_fun)

        # One mask row per candidate: `neuron` ones followed by zeros up to
        # the maximum width.
        rows = []
        for neuron in candidates:
            one = torch.ones(1, int(neuron))
            if neuron < max_width:
                zero = torch.zeros(1, int(max_width - neuron))
                rows.append(torch.cat([one, zero], 1))
            else:
                rows.append(one)
        masks = torch.cat(rows, 0).to(DEVICE())
        # Buffer instead of a plain attribute so the masks move together
        # with the module; non-persistent keeps state_dict keys unchanged.
        self.register_buffer('masks', masks, persistent=False)

        self.init_weights(self.op, init_method)

    def forward(self, x, g):
        """Forward pass with relaxed architecture parameters.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, C_in).
        g : torch.Tensor
            Architecture parameters for this layer. The first two entries
            weight the identity and nonlinear paths (after a softmax); the
            remaining entries weight the candidate neuron counts.

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, neuron_list[-1]).
        """
        # double-level: softmax(g[:2]) -> (a, b); softmax(g[2:]) -> widths
        g = g.view(1, -1)
        ab = F.softmax(g[:, :2], dim=-1)
        a, b = ab[0, 0], ab[0, 1]
        g_neuron = F.softmax(g[:, 2:], dim=-1)

        # Zero-pad the input to the maximum width; allocated from x so the
        # padding always shares x's device and dtype.
        patch = x.new_zeros(x.size(0), self.neuron_list[-1] - self.C_in)
        identity = torch.cat((x, patch), 1)

        # Soft per-neuron gate: expectation of the masks under g_neuron.
        weight = torch.mm(g_neuron, self.masks).T
        nonlinear = self.op(x) * weight[:, 0] * b
        return nonlinear + a * identity

    def init_weights(self, net, method='xavier'):
        """Initialize the weights of every ``nn.Linear`` found in ``net``.

        Parameters
        ----------
        net : torch.nn.Module
            Operator whose linear layers will be initialized.
        method : {'xavier', 'zero'}, optional
            Initialization method; anything other than 'zero' is treated as
            'xavier'. Default is 'xavier'.
        """
        for m in net.modules():
            if not isinstance(m, nn.Linear):
                continue
            if method == 'zero':
                nn.init.constant_(m.weight, 0.0)
            else:  # default: 'xavier'
                nn.init.xavier_normal_(m.weight)
            # A linear layer created with bias=False has no bias tensor.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

class RelaxFNN(Network):
    """Relaxed fully connected network for NAS-PINN.

    A stack of :class:`RelaxLayer` modules followed by a linear output layer
    with a single output feature. The per-layer architecture parameters
    ``gs`` are kept outside ``nn.Module.parameters()`` (plain tensor with
    ``requires_grad=True``) and are exposed through
    :meth:`arch_parameters`.

    Parameters
    ----------
    layers : int
        Number of relaxed layers.
    C_in_list : list of int
        Input dimension for each layer (length equals ``layers``).
    neuron_list : list of int
        Candidate neuron counts for each relaxed layer.

    Attributes
    ----------
    layers : int
        Number of relaxed layers.
    C_in_list : list of int
        Input dimensions for each layer.
    neuron_list : list of int
        Candidate neuron counts.
    network : torch.nn.ModuleList
        Stack of RelaxLayer modules.
    output_layer : torch.nn.Linear
        Maps the last relaxed layer's ``neuron_list[-1]`` features to 1.
    gs : torch.Tensor
        Architecture parameters with shape (layers, len(neuron_list) + 1).
    """

    def __init__(self, layers, C_in_list, neuron_list):
        """Initialize the relaxed FNN.

        Parameters
        ----------
        layers : int
            Number of relaxed layers.
        C_in_list : list of int
            Input dimension for each layer.
        neuron_list : list of int
            Candidate neuron counts for each layer.
        """
        super().__init__()
        self.layers = layers
        self.C_in_list = C_in_list
        self.neuron_list = neuron_list

        # Every RelaxLayer emits neuron_list[-1] features regardless of its
        # input width, so the output layer must be sized from that value
        # (sizing it from C_in_list[-1] only works when the two coincide).
        self.output_layer = nn.Linear(neuron_list[-1], 1)

        self.init_weights()
        self.build_up()

    def init_weights(self):
        """Initialize architecture parameters ``gs``.

        ``gs`` has shape (layers, len(neuron_list) + 1) and is drawn from a
        normal distribution scaled by 1e-3. For layer ``i``:

        - ``gs[i, 0]``: weight of the identity path.
        - ``gs[i, 1]``: weight of the nonlinear path.
        - ``gs[i, 2:]``: weights for selecting the neuron count.
        """
        # Sampled on the CPU first and moved afterwards, then flagged as a
        # leaf tensor that requires gradients.
        gs = 1e-3 * torch.randn(self.layers, len(self.neuron_list) + 1)
        self.gs = gs.to(DEVICE()).requires_grad_(True)
        self.arch_para = [self.gs]

    def load_gs(self, arch_param):
        """Load architecture parameters from an external list.

        Parameters
        ----------
        arch_param : list of torch.Tensor
            List whose first element is the architecture parameter tensor
            (same layout as ``arch_parameters()`` returns).
        """
        self.gs = arch_param[0]
        self.arch_para = [self.gs]

    def arch_parameters(self):
        """Return architecture parameters for optimization.

        Returns
        -------
        list of torch.Tensor
            One-element list containing the architecture parameter tensor.
        """
        self.arch_para = [self.gs]
        return self.arch_para

    def build_up(self):
        """Build the relaxed network stack.

        Creates a ModuleList with one RelaxLayer per entry of the first
        ``layers`` values of ``C_in_list``, all sharing ``neuron_list``.
        """
        self.network = nn.ModuleList()
        for i in range(self.layers):
            self.network.append(RelaxLayer(self.C_in_list[i], self.neuron_list))

    @staticmethod
    def _dominant_width_index(g):
        """Return the ``neuron_list`` index of the strongest width choice in ``g``."""
        # +1 because neuron_list[0] is the identity entry
        return int(torch.argmax(F.softmax(g[2:], dim=-1))) + 1

    def searched_neuron(self, threshold=1e-3):
        """Derive the discrete architecture from learned parameters.

        Per layer, the softmax weights of the identity and nonlinear paths
        are compared. If they differ by more than ``threshold`` the dominant
        option is selected: ``neuron_list[0]`` for identity, otherwise the
        neuron count with the largest selection weight. If they are within
        ``threshold``, a residual connection is kept and reported as the
        string ``'0+'`` followed by the selected neuron count.

        Parameters
        ----------
        threshold : float, optional
            Threshold for deciding whether the two paths are close.
            Default is 1e-3.

        Returns
        -------
        list
            One descriptor per layer: an entry of ``neuron_list`` or a
            ``'0+<count>'`` string.
        """
        final_neuron = []
        for i in range(self.layers):
            g = self.gs[i]
            ab = F.softmax(g[:2], dim=-1)
            if abs(ab[0] - ab[1]) > threshold:
                if torch.argmax(ab) == 0:
                    index = 0
                else:
                    index = self._dominant_width_index(g)
                neuron = self.neuron_list[index]
            else:
                index = self._dominant_width_index(g)
                neuron = '0+' + str(self.neuron_list[index])
            final_neuron.append(neuron)
        return final_neuron

    def forward(self, x):
        """Forward pass through the relaxed FNN.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, C_in_list[0]).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, 1).
        """
        for i in range(self.layers):
            x = self.network[i](x, self.gs[i])
        return self.output_layer(x)