1. Core Module
Depending on your installation, you can import the ai4plasma.core module directly.
1.1. Model
Base model classes for neural network training and inference in AI4Plasma.
This module provides a unified framework for neural network training, validation, and inference in the AI4Plasma project. It implements the core training pipeline infrastructure with support for distributed computing, checkpoint management, and TensorBoard monitoring.
1.1.1. Model Classes
BaseModel: Abstract base class for all AI4Plasma models.
CfgBaseModel: Configuration-driven model wrapper.
- class ai4plasma.core.model.BaseModel(network)[source]
Bases: ABC
Base model class that provides a framework for training and prediction.
This class provides a unified framework for training and inference using a neural network as the underlying computational model. It handles model initialization, loss computation, and prediction with device-agnostic computation (CPU/GPU).
This class is abstract and should be subclassed by concrete implementations that override the train() and calc_loss() methods.
- device_id
The device ID for GPU computation.
- Type:
int
- network
The neural network model moved to the configured device.
- Type:
torch.nn.Module
- loss_func
The loss function used during training.
- Type:
callable, optional
- optimizer
The optimizer used to update network parameters.
- Type:
torch.optim.Optimizer, optional
- lr_scheduler
The learning rate scheduler for adaptive learning rates.
- Type:
torch.optim.lr_scheduler, optional
- abstractmethod calc_loss()[source]
Calculate the loss.
Computes the loss value for the current batch. This method should be overridden by subclasses to implement custom loss computation logic.
Notes
Abstract method that must be implemented by subclasses.
- predict(X)[source]
Perform inference using the model.
Executes the network in evaluation mode without computing gradients, suitable for inference on new data.
- Parameters:
X (torch.Tensor) – Input data for inference.
- Returns:
Output predictions from the neural network.
- Return type:
torch.Tensor
- prepare_test_data()[source]
Set testing data.
This method should be overridden by subclasses to load and prepare the testing dataset.
- prepare_train_data()[source]
Set training data.
This method should be overridden by subclasses to load and prepare the training dataset.
- set_loss_func(loss_func) → None[source]
Set the loss function.
- Parameters:
loss_func (callable) – The loss function to be used during training. Typically a function from torch.nn.functional or torch.nn that computes a scalar loss value.
- set_lr_scheduler(lr_scheduler) → None[source]
Set the learning rate scheduler.
- Parameters:
lr_scheduler (torch.optim.lr_scheduler) – The learning rate scheduler to be used during training for adaptive learning rate adjustment.
- set_optimizer(optimizer) → None[source]
Set the optimizer.
- Parameters:
optimizer (torch.optim.Optimizer) – The optimizer to be used during training. Examples include Adam, SGD, or other PyTorch optimizers.
- abstractmethod train()[source]
Execute one training step.
Performs a single training iteration including forward pass, loss calculation, and backward pass with optimizer step. This method should be overridden by subclasses to implement custom training logic.
Notes
Abstract method that must be implemented by subclasses.
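The subclassing contract described for BaseModel can be illustrated with a stand-alone sketch. It does not import ai4plasma; instead it mirrors the documented pattern with plain PyTorch. The names SketchBaseModel and LineFit, the data attributes X and y, and the body of train() are assumptions made for illustration, and the real BaseModel may differ in signatures and internals.

```python
from abc import ABC, abstractmethod

import torch
import torch.nn as nn


class SketchBaseModel(ABC):
    """Hypothetical stand-in that mirrors the documented BaseModel contract."""

    def __init__(self, network):
        self.network = network
        self.loss_func = None
        self.optimizer = None

    def set_loss_func(self, loss_func):
        self.loss_func = loss_func

    def set_optimizer(self, optimizer):
        self.optimizer = optimizer

    def predict(self, X):
        # Evaluation mode and no gradient tracking, as the predict() entry describes.
        self.network.eval()
        with torch.no_grad():
            return self.network(X)

    @abstractmethod
    def calc_loss(self):
        ...

    @abstractmethod
    def train(self):
        ...


class LineFit(SketchBaseModel):
    """Fits y = 2x + 1; the data attributes X and y are illustrative."""

    def prepare_train_data(self):
        self.X = torch.linspace(-1.0, 1.0, 32).unsqueeze(1)
        self.y = 2.0 * self.X + 1.0

    def calc_loss(self):
        return self.loss_func(self.network(self.X), self.y)

    def train(self):
        # One step: forward pass, loss, backward pass, optimizer update.
        self.network.train()
        self.optimizer.zero_grad()
        loss = self.calc_loss()
        loss.backward()
        self.optimizer.step()
        return loss.item()


torch.manual_seed(0)
model = LineFit(nn.Linear(1, 1))
model.set_loss_func(nn.MSELoss())
model.set_optimizer(torch.optim.Adam(model.network.parameters(), lr=0.05))
model.prepare_train_data()
losses = [model.train() for _ in range(200)]
prediction = model.predict(torch.tensor([[0.5]]))
```

Instantiating SketchBaseModel directly raises TypeError because the two abstract methods are not implemented, which is the usual way an ABC enforces the override requirement.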
- class ai4plasma.core.model.CfgBaseModel(cfg_file, network)[source]
Bases: BaseModel
Configuration-driven base model with training pipeline support.
Extends BaseModel with configuration file loading, checkpoint management, TensorBoard logging, and automated training pipeline with visualization. This class implements the full training lifecycle including initialization, per-epoch callbacks, and post-training actions.
- cfg
Configuration dictionary loaded from JSON file.
- Type:
dict
- saved_model
Loaded model checkpoint containing model state, optimizer state, and epoch.
- Type:
dict, optional
- loss_list
List of (epoch, loss_value) tuples for tracking training history.
- Type:
list
- fig_file_list
List of figure file paths for GIF animation generation.
- Type:
list
- writer
TensorBoard summary writer for logging metrics.
- Type:
SummaryWriter, optional
- epoch
Current training epoch.
- Type:
int
- total_epochs
Total number of epochs to train.
- Type:
int
- last_epoch
Last completed epoch (used for resuming training).
- Type:
int
- do_after_each_epoch()[source]
Execute callbacks after each training epoch.
Handles logging to TensorBoard, generating visualization plots, saving checkpoints, and collecting loss history based on configured frequencies and callbacks.
Notes
This method should be called at the end of each epoch in the training loop. Automatically manages checkpoint saving, TensorBoard logging, and GIF figure collection based on configured frequencies.
- do_after_training()[source]
Finalize training and generate outputs.
Generates animated GIF from collected figures, saves loss history to disk, closes TensorBoard writer, and optionally removes temporary files.
Notes
Should be called after training is complete to generate final outputs and cleanup temporary resources.
- do_before_training(**kwargs)[source]
Initialize training pipeline before starting epochs.
Extracts runtime parameters from kwargs and loads pretrained model checkpoint if specified in configuration.
- Parameters:
kwargs (dict) – Keyword arguments passed to get_kwargs() for configuration.
- get_init_args()[source]
Initialize default training parameters.
Sets up default values for loss function (MSE), optimizer (Adam), epoch tracking, and storage lists for history and figures.
- get_json_args(cfg)[source]
Extract and configure training parameters from configuration dictionary.
- Parameters:
cfg (dict) – Configuration dictionary containing training parameters such as learning rate, number of epochs, logging frequency, file paths, etc.
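Since cfg is a dictionary loaded from a JSON file, a minimal sketch of that loading step may help. It uses only the standard library. The key names (lr, num_epochs, log_freq, model_file, save_freq) are hypothetical, because the actual configuration schema is not listed in this section.

```python
import json
import os
import tempfile

# Hypothetical keys: the real configuration schema is not listed in this section.
example_cfg = {"lr": 1e-3, "num_epochs": 500, "log_freq": 50, "model_file": "model.pth"}

with tempfile.TemporaryDirectory() as tmp_dir:
    cfg_file = os.path.join(tmp_dir, "cfg.json")
    with open(cfg_file, "w", encoding="utf-8") as f:
        json.dump(example_cfg, f, indent=2)

    # Read the file back into a plain dictionary, as the cfg attribute describes.
    with open(cfg_file, encoding="utf-8") as f:
        cfg = json.load(f)

# Use .get() with defaults so missing optional keys do not raise KeyError.
learning_rate = cfg.get("lr", 1e-3)
total_epochs = cfg.get("num_epochs", 1000)
save_freq = cfg.get("save_freq", 100)  # absent above, so the default applies
```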
- get_kwargs(**kwargs)[source]
Extract and set training kwargs (callbacks, optimizer, loss functions).
- Parameters:
kwargs (dict) – Supported keys:
optimizer (torch.optim.Optimizer, optional)
calc_l2_err (callable, optional) – Function to calculate L2 error for logging.
plot_func_training (callable, optional) – Function to generate training plots for TensorBoard.
plot_func_gif (callable, optional) – Function to generate plots for GIF animation.
- load_last_epoch()[source]
Load the last completed epoch from checkpoint and compute total epochs.
Updates total_epochs based on the loaded last_epoch to support resuming training from where it was paused.
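The exact arithmetic behind this update is not spelled out here. One plausible reading, shown below as an assumption, is that the configured epoch count means additional epochs, so the total is offset by the epochs already completed. The function name resume_epoch_range and its parameters are hypothetical.

```python
def resume_epoch_range(last_epoch, epochs_to_run):
    """Return (total_epochs, epochs) for a run resumed after last_epoch.

    Assumption: the configured epoch count means additional epochs, so the
    total is offset by the epochs already completed.
    """
    total_epochs = last_epoch + epochs_to_run
    epochs = range(last_epoch + 1, total_epochs + 1)
    return total_epochs, epochs


# Fresh run: nothing completed yet.
fresh_total, fresh_epochs = resume_epoch_range(last_epoch=0, epochs_to_run=100)

# Resumed run: 100 epochs are already recorded in the checkpoint.
resumed_total, resumed_epochs = resume_epoch_range(last_epoch=100, epochs_to_run=50)
```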
- load_model(model_file, map_location=None)[source]
Load a complete checkpoint including model, optimizer, and epoch.
Orchestrates loading of model weights, optimizer state, and epoch info from a saved checkpoint file for resuming training.
- Parameters:
model_file (str) – Path to the model checkpoint file.
map_location (str, optional) – Device to map the model to. If None, uses the configured default device.
- load_model_from_file(model_file, map_location=None)[source]
Load a model checkpoint from file.
- Parameters:
model_file (str) – Path to the saved model checkpoint file.
map_location (str, optional) – Device to map the model to. If None, uses the configured default device.
- Returns:
Loaded checkpoint dictionary containing model, optimizer, and epoch info.
- Return type:
dict
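A checkpoint holding model state, optimizer state, and epoch can be written and restored with standard PyTorch calls. The sketch below round-trips such a dictionary through an in-memory buffer. The key names model_state_dict, optimizer_state_dict, and epoch are assumptions; the layout that CfgBaseModel actually writes may differ.

```python
import io

import torch
import torch.nn as nn

network = nn.Linear(2, 1)
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)

# Hypothetical key names; the actual checkpoint layout may differ.
checkpoint = {
    "model_state_dict": network.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": 42,
}

# An in-memory buffer stands in for the checkpoint file on disk.
buffer = io.BytesIO()
torch.save(checkpoint, buffer)
buffer.seek(0)

# map_location remaps stored tensors, e.g. GPU-saved weights onto the CPU.
saved_model = torch.load(buffer, map_location="cpu")

restored = nn.Linear(2, 1)
restored.load_state_dict(saved_model["model_state_dict"])
restored_opt = torch.optim.Adam(restored.parameters(), lr=1e-3)
restored_opt.load_state_dict(saved_model["optimizer_state_dict"])
last_epoch = saved_model["epoch"]
```

Passing map_location="cpu" is the usual way to open a checkpoint saved on a GPU machine when no GPU is available.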
1.2. Network
Core neural network architectures for scientific computing and physics-informed learning.
This module provides flexible, highly configurable neural network building blocks optimized for physics-informed machine learning and operator learning problems in the AI4Plasma framework.
1.2.1. Network Classes
Network: Abstract base class for all neural network architectures.
FNN: Fully Connected Neural Network.
CNN: Convolutional Neural Network.
RelaxLayer: Relaxed hidden layer for NAS-PINN.
RelaxFNN: Relaxed fully connected network for NAS-PINN.
- class ai4plasma.core.network.CNN(conv_layers, fc_layers=None, input_dim=2, act_fun=ReLU(), use_BN=False, use_pooling=True, pooling_type='max', kernel_size=3, stride=1, padding=1, pooling_kernel_size=2, pooling_stride=None, pooling_padding=0, init_method='xavier')[source]
Bases: Network
Convolutional Neural Network for scientific computing applications.
A flexible convolutional neural network designed for physics-informed machine learning and operator learning in plasma physics and other scientific domains. Supports 1D, 2D, and 3D convolutions with customizable architecture parameters, batch normalization, pooling strategies, and an optional fully connected head.
- conv_layers
Channel counts for convolutional layers.
- Type:
list of int
- fc_layers_config
Configured layer sizes for the fully connected head.
- Type:
list of int, optional
- input_dim
Spatial dimension of input data (1, 2, or 3).
- Type:
int
- act_fun
Activation function applied after conv and fc layers.
- Type:
torch.nn.Module
- use_BN
Whether batch normalization is used.
- Type:
bool
- use_pooling
Whether pooling layers are used.
- Type:
bool
- conv_net
Sequential container of convolutional layers.
- Type:
torch.nn.Sequential
- fc_net
Sequential container of fully connected layers (lazily initialized).
- Type:
torch.nn.Sequential, optional
- global_pool
Global pooling layer (used if fc_layers is None).
- Type:
torch.nn.Module, optional
- forward(x)[source]
Forward pass through the CNN with lazy FC layer initialization.
On the first forward pass with fc_layers, automatically adjusts the fully connected layer input size based on the actual conv output shape. Subsequent forwards use the pre-initialized FC layers.
- Parameters:
x (torch.Tensor) –
Input tensor with shape depending on input_dim:
1D: (batch_size, channels, length)
2D: (batch_size, channels, height, width)
3D: (batch_size, channels, depth, height, width)
- Returns:
Output tensor of shape (batch_size, output_features).
- Return type:
torch.Tensor
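The lazy initialization described for forward() can be sketched as follows. This is a simplified 2D illustration written with plain PyTorch, not the library's implementation. The class name LazyHeadCNN and its single conv block are assumptions.

```python
import torch
import torch.nn as nn


class LazyHeadCNN(nn.Module):
    """Hypothetical 2D example: the FC head is created on the first forward pass."""

    def __init__(self, out_features):
        super().__init__()
        self.conv_net = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.out_features = out_features
        self.fc_net = None  # built lazily once the conv output size is known

    def forward(self, x):
        features = self.conv_net(x).flatten(start_dim=1)
        if self.fc_net is None:
            # First call: size the linear layer from the actual conv output.
            self.fc_net = nn.Linear(features.shape[1], self.out_features).to(x.device)
        return self.fc_net(features)


model = LazyHeadCNN(out_features=4)
y = model(torch.rand(2, 1, 16, 16))
```

With a pattern like this, the head's parameters do not exist until the first forward pass, so it is generally safer to run one dummy batch through the network before building the optimizer.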
- get_feature_size(input_shape)[source]
Calculate the output feature size after convolutional layers.
Used to determine the input size required for the first fully connected layer when designing network architectures manually.
- Parameters:
input_shape (tuple) – Shape of input tensor (channels, spatial_dims…)
- Returns:
Number of features after convolution and pooling operations.
- Return type:
int
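The feature count can also be worked out by hand from the standard convolution and pooling size formula, floor((n + 2p - k) / s) + 1 per spatial axis. The sketch below applies it with the constructor defaults listed above. It assumes one pooling operation after every convolutional layer, dilation 1, and floor rounding; the helper names are hypothetical and the library's own layer ordering may differ.

```python
import math


def conv_out_size(size, kernel_size, stride, padding):
    """Output length of one spatial axis after a convolution or pooling op."""
    return (size + 2 * padding - kernel_size) // stride + 1


def feature_size(input_shape, conv_layers, kernel_size=3, stride=1, padding=1,
                 use_pooling=True, pooling_kernel_size=2, pooling_stride=None,
                 pooling_padding=0):
    """Flattened feature count, assuming one pooling op after every conv layer."""
    spatial = list(input_shape[1:])  # drop the channel entry
    if pooling_stride is None:
        pooling_stride = pooling_kernel_size  # PyTorch's pooling default
    for _ in conv_layers:
        spatial = [conv_out_size(s, kernel_size, stride, padding) for s in spatial]
        if use_pooling:
            spatial = [conv_out_size(s, pooling_kernel_size, pooling_stride,
                                     pooling_padding) for s in spatial]
    return conv_layers[-1] * math.prod(spatial)


n_features = feature_size((1, 28, 28), conv_layers=[16, 32])
n_features_no_pool = feature_size((1, 28, 28), conv_layers=[16, 32], use_pooling=False)
```

For a 28 x 28 input, each 3 x 3 convolution with padding 1 preserves the size and each 2 x 2 pooling halves it, giving 28 → 14 → 7 and 32 · 7 · 7 = 1568 features.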
- init_weights(method='xavier')[source]
Initialize network weights using the specified method.
Applies the initialization strategy to all convolutional, linear, and batch normalization layers in the network.
- Parameters:
method ({'xavier', 'kaiming', 'zero'}, optional) – Weight initialization method. Default is ‘xavier’.
- class ai4plasma.core.network.FNN(layers, act_fun=Tanh(), use_BN=False, init_method='xavier')[source]
Bases: Network
Fully Connected Neural Network (Multi-layer Perceptron).
A flexible multi-layer fully connected neural network with customizable depths, widths, activation functions, and weight initialization strategies. Suitable for function approximation and physics-informed machine learning tasks.
- layers
Number of neurons in each layer.
- Type:
list of int
- act_fun
Activation function applied between layers.
- Type:
torch.nn.Module
- net
The complete network model.
- Type:
torch.nn.Sequential
- forward(x)[source]
Forward pass through the FNN.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_size, input_dim).
- Returns:
Output tensor of shape (batch_size, output_dim).
- Return type:
torch.Tensor
- init_weights(net, method='xavier')[source]
Initialize network weights using specified method.
- Parameters:
net (torch.nn.Sequential) – The network whose weights to initialize.
method ({'xavier', 'zero'}, optional) – Initialization method. Default is ‘xavier’.
- linear_model(layers, activation, use_BN=False)[source]
Construct a sequential multi-layer network.
- Parameters:
layers (list of int) – Number of neurons in each layer.
activation (torch.nn.Module) – Activation function to apply between layers.
use_BN (bool, optional) – Whether to use batch normalization. Default is False.
- Returns:
Sequential model containing linear layers with optional batch norm and activation functions.
- Return type:
torch.nn.Sequential
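To make the layers convention concrete, here is a minimal sketch of how a list such as [2, 32, 32, 1] could map onto a torch.nn.Sequential. It is written with plain PyTorch and is not the library's code. The helper names build_mlp and init_xavier, the choice to skip the activation after the output layer, and the use of the uniform Xavier variant are all assumptions.

```python
import torch
import torch.nn as nn


def build_mlp(layers, activation, use_BN=False):
    """Stack Linear layers; activation (and optional BatchNorm) between them only."""
    modules = []
    for i in range(len(layers) - 1):
        modules.append(nn.Linear(layers[i], layers[i + 1]))
        is_hidden = i < len(layers) - 2
        if is_hidden:
            if use_BN:
                modules.append(nn.BatchNorm1d(layers[i + 1]))
            modules.append(activation)
    return nn.Sequential(*modules)


def init_xavier(net):
    """Xavier-initialize weights and zero the biases (uniform variant assumed)."""
    for module in net.modules():
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            nn.init.zeros_(module.bias)


net = build_mlp([2, 32, 32, 1], nn.Tanh())
init_xavier(net)
out = net(torch.rand(8, 2))
```

Here the first entry of layers is the input dimension and the last is the output dimension, so a list of four sizes yields three linear layers.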
- class ai4plasma.core.network.Network[source]
Bases: Module, ABC
Abstract base class for neural network architectures.
This class serves as a base for all neural network implementations. All subclasses should override the forward() and init_weights() methods to implement specific architectures and initialization strategies.
- class ai4plasma.core.network.RelaxFNN(layers, C_in_list, neuron_list)[source]
Bases: Network
Relaxed fully connected network for NAS-PINN.
Builds a stack of relaxed layers whose architectures are controlled by learnable parameters. The network supports architecture search by learning soft selections over identity vs. nonlinear paths and neuron counts.
- Parameters:
layers (int) – Number of relaxed layers.
C_in_list (list of int) – Input dimension for each layer (length equals layers).
neuron_list (list of int) – Candidate neuron counts for each relaxed layer.
- layers
Number of relaxed layers.
- Type:
int
- C_in_list
Input dimensions for each layer.
- Type:
list of int
- neuron_list
Candidate neuron counts.
- Type:
list of int
- network
Stack of RelaxLayer modules.
- Type:
torch.nn.ModuleList
- gs
Architecture parameters with shape (layers, len(neuron_list) + 1).
- Type:
torch.Tensor
- arch_parameters()[source]
Return architecture parameters for optimization.
- Returns:
List containing the architecture parameter tensor.
- Return type:
list of torch.Tensor
- build_up()[source]
Build the relaxed network stack.
Creates a ModuleList of RelaxLayer instances based on C_in_list and neuron_list.
- forward(x)[source]
Forward pass through the relaxed FNN.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_size, C_in_list[0]).
- Returns:
Output tensor of shape (batch_size, 1).
- Return type:
torch.Tensor
- init_weights()[source]
Initialize architecture parameters gs.
Creates learnable parameters for each layer, with the first two entries controlling identity vs. nonlinear mixing and the remaining entries selecting neuron counts.
- gs shape: (layers, len(neuron_list) + 1), where:
gs[i, 0] corresponds to the weight for the identity path in layer i.
gs[i, 1] corresponds to the weight for the nonlinear path in layer i.
gs[i, 2:] correspond to the weights for selecting neuron counts in layer i.
- load_gs(arch_param)[source]
Load architecture parameters from an external list.
self.arch_para is a list of length 1 whose first element is the architecture parameter tensor gs.
- Parameters:
arch_param (list of torch.Tensor) – List where the first element contains the architecture parameters.
- searched_neuron(threshold=0.001)[source]
Derive the discrete architecture from learned parameters.
If the identity vs. nonlinear weights are close, retain a residual connection and append the selected neuron count. The residual connection is denoted by ‘0+’ followed by the selected neuron count. Otherwise, select the dominant option.
- Parameters:
threshold (float, optional) – Threshold for determining if identity and nonlinear paths are close. Default is 1e-3.
- Returns:
List of selected neuron descriptors per layer.
- Return type:
list
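The decoding rule described for searched_neuron() can be sketched in plain Python. This is an illustrative reading, not the library's code: it assumes the raw weights are compared directly (no softmax), that neuron_list[0] is the identity placeholder so gs[i, 2:] indexes neuron_list[1:], and that a dominant identity path is reported as 0. The function name decode_architecture is hypothetical, and nested lists stand in for the gs tensor.

```python
def decode_architecture(gs, neuron_list, threshold=1e-3):
    """Turn per-layer weights into discrete descriptors (illustrative rule only)."""
    widths = neuron_list[1:]  # neuron_list[0] is the identity placeholder
    result = []
    for row in gs:
        identity_w, nonlinear_w, width_ws = row[0], row[1], row[2:]
        # Pick the candidate width with the largest weight.
        best = widths[max(range(len(width_ws)), key=width_ws.__getitem__)]
        if abs(identity_w - nonlinear_w) < threshold:
            result.append(f"0+{best}")  # paths tie: keep a residual connection
        elif identity_w > nonlinear_w:
            result.append(0)            # identity dominates: skip the layer
        else:
            result.append(best)         # nonlinear path dominates
    return result


neuron_list = [0, 10, 20, 30]
gs = [
    [0.5000, 0.5004, 0.1, 0.7, 0.2],  # near tie, width 20 preferred
    [0.9, 0.1, 0.3, 0.3, 0.4],        # identity dominates
    [0.2, 0.8, 0.6, 0.3, 0.1],        # nonlinear dominates, width 10 preferred
]
architecture = decode_architecture(gs, neuron_list)
```

Each row of gs has len(neuron_list) + 1 entries, matching the documented shape of the architecture parameters.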
- class ai4plasma.core.network.RelaxLayer(C_in, neuron_list, act_fun=Tanh(), init_method='xavier')[source]
Bases: Network
Relaxed hidden layer for NAS-PINN architecture search.
Implements a relaxed fully connected hidden layer where the effective operation is determined by learnable architecture parameters g. The layer blends identity and nonlinear transformations with soft selections over candidate neuron counts.
- Parameters:
C_in (int) – Input feature dimension.
neuron_list (list of int) – Candidate neuron counts. neuron_list[0] is treated as identity.
act_fun (torch.nn.Module, optional) – Activation function applied after the linear operator. Default is Tanh.
init_method ({'xavier', 'zero'}, optional) – Weight initialization method for the linear operator. Default is ‘xavier’.
- C_in
Input feature dimension.
- Type:
int
- neuron_list
Candidate neuron counts (including identity as index 0).
- Type:
list of int
- op
Linear + activation operator for the maximal neuron count.
- Type:
torch.nn.Sequential
- masks
Binary masks used to select subsets of neurons.
- Type:
torch.Tensor
- forward(x, g)[source]
Forward pass with relaxed architecture parameters.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_size, C_in).
g (torch.Tensor) – Architecture parameter tensor for this layer. The first two entries control identity vs. nonlinear mixing; the remaining entries select neuron counts.
- Returns:
Output tensor of shape (batch_size, max(neuron_list)).
- Return type:
torch.Tensor
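The masks attribute is described only as binary masks that select subsets of neurons. One common way to realize a soft selection over widths, shown here as an assumption rather than the library's method, is to build one binary row per candidate width and blend the rows with softmax weights. The helper names are hypothetical, and plain lists are used in place of tensors to keep the arithmetic visible.

```python
import math


def build_masks(widths):
    """One binary row per candidate width: ones for kept neurons, zeros after."""
    max_width = max(widths)
    return [[1.0 if j < w else 0.0 for j in range(max_width)] for w in widths]


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    shifted = [math.exp(v - peak) for v in values]
    total = sum(shifted)
    return [s / total for s in shifted]


def soft_mask(selection_logits, masks):
    """Blend the binary masks with softmax weights (assumed relaxation)."""
    weights = softmax(selection_logits)
    return [sum(w * row[j] for w, row in zip(weights, masks))
            for j in range(len(masks[0]))]


masks = build_masks([2, 4])
blended = soft_mask([0.0, 0.0], masks)  # equal preference for widths 2 and 4
```

Multiplying the output of the maximal-width operator by such a blended mask keeps the output shape at (batch_size, max(neuron_list)) while scaling down the neurons that only the wider candidates use.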