1. Core Module

Depending on your installation, you can import the ai4plasma.core module directly.

1.1. Model

Base model classes for neural network training and inference in AI4Plasma.

This module provides a unified framework for neural network training, validation, and inference in the AI4Plasma project. It implements the core training pipeline infrastructure with support for distributed computing, checkpoint management, and TensorBoard monitoring.

1.1.1. Model Classes

  • BaseModel: Abstract base class for all AI4Plasma models.

  • CfgBaseModel: Configuration-driven model wrapper.

class ai4plasma.core.model.BaseModel(network)[source]

Bases: ABC

Base Model class that provides a framework for training and predicting.

This class provides a unified framework for training and inference using a neural network as the underlying computational model. It handles model initialization, loss computation, and prediction with device-agnostic computation (CPU/GPU).

This class is abstract and should be subclassed by concrete implementations that override the train() and calc_loss() methods.

device_id

The device ID for GPU computation.

Type:

int

network

The neural network model moved to the configured device.

Type:

torch.nn.Module

loss_func

The loss function used during training.

Type:

callable, optional

optimizer

The optimizer used to update network parameters.

Type:

torch.optim.Optimizer, optional

lr_scheduler

The learning rate scheduler for adaptive learning rates.

Type:

torch.optim.lr_scheduler, optional

abstractmethod calc_loss()[source]

Calculate the loss.

Computes the loss value for the current batch. This method should be overridden by subclasses to implement custom loss computation logic.

Notes

Abstract method that must be implemented by subclasses.

predict(X)[source]

Perform inference using the model.

Executes the network in evaluation mode without computing gradients, suitable for inference on new data.

Parameters:

X (torch.Tensor) – Input data for inference.

Returns:

Output predictions from the neural network.

Return type:

torch.Tensor

prepare_test_data()[source]

Set testing data.

This method should be overridden by subclasses to load and prepare the testing dataset.

prepare_train_data()[source]

Set training data.

This method should be overridden by subclasses to load and prepare the training dataset.

set_loss_func(loss_func) → None[source]

Set the loss function.

Parameters:

loss_func (callable) – The loss function to be used during training. Typically a function from torch.nn.functional or torch.nn that computes a scalar loss value.

set_lr_scheduler(lr_scheduler) → None[source]

Set the learning rate scheduler.

Parameters:

lr_scheduler (torch.optim.lr_scheduler) – The learning rate scheduler to be used during training for adaptive learning rate adjustment.

set_optimizer(optimizer) → None[source]

Set the optimizer.

Parameters:

optimizer (torch.optim.Optimizer) – The optimizer to be used during training. Examples include Adam, SGD, or other PyTorch optimizers.

abstractmethod train()[source]

Execute one training step.

Performs a single training iteration including forward pass, loss calculation, and backward pass with optimizer step. This method should be overridden by subclasses to implement custom training logic.

Notes

Abstract method that must be implemented by subclasses.
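
The abstract-method contract described for BaseModel can be illustrated with a standard-library sketch. SketchBaseModel, IncompleteModel, ConstantModel, and can_instantiate are hypothetical stand-ins written for this example, not ai4plasma classes; they only mirror the documented rule that both train() and calc_loss() must be overridden before a subclass can be instantiated.

```python
from abc import ABC, abstractmethod


class SketchBaseModel(ABC):
    """Hypothetical stand-in for BaseModel (not the real ai4plasma class)."""

    def __init__(self, network):
        self.network = network  # any callable standing in for a torch.nn.Module

    @abstractmethod
    def calc_loss(self):
        """Return the loss for the current batch."""

    @abstractmethod
    def train(self):
        """Run one training step."""


class IncompleteModel(SketchBaseModel):
    """Overrides only calc_loss(), so it is still abstract."""

    def calc_loss(self):
        return 0.0


class ConstantModel(SketchBaseModel):
    """Overrides both abstract methods, so it can be instantiated."""

    def calc_loss(self):
        # Squared error of the network output against a fixed target of 1.0.
        return (self.network(2.0) - 1.0) ** 2

    def train(self):
        return self.calc_loss()


def can_instantiate(cls):
    """Return True if cls(network) succeeds, False if ABC blocks it."""
    try:
        cls(lambda x: 0.5 * x)
        return True
    except TypeError:
        return False
```

Calling can_instantiate(IncompleteModel) returns False because Python refuses to build a class that still has abstract methods, which is the behavior a subclass author would see when one of the two overrides is missing.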

class ai4plasma.core.model.CfgBaseModel(cfg_file, network)[source]

Bases: BaseModel

Configuration-driven base model with training pipeline support.

Extends BaseModel with configuration file loading, checkpoint management, TensorBoard logging, and automated training pipeline with visualization. This class implements the full training lifecycle including initialization, per-epoch callbacks, and post-training actions.

cfg

Configuration dictionary loaded from JSON file.

Type:

dict

saved_model

Loaded model checkpoint containing model state, optimizer state, and epoch.

Type:

dict, optional

loss_list

List of (epoch, loss_value) tuples for tracking training history.

Type:

list

fig_file_list

List of figure file paths for GIF animation generation.

Type:

list

writer

TensorBoard summary writer for logging metrics.

Type:

SummaryWriter, optional

epoch

Current training epoch.

Type:

int

total_epochs

Total number of epochs to train.

Type:

int

last_epoch

Last completed epoch (used for resuming training).

Type:

int

do_after_each_epoch()[source]

Execute callbacks after each training epoch.

Handles logging to TensorBoard, generating visualization plots, saving checkpoints, and collecting loss history based on configured frequencies and callbacks.

Notes

This method should be called at the end of each epoch in the training loop. Automatically manages checkpoint saving, TensorBoard logging, and GIF figure collection based on configured frequencies.
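
As a rough illustration of frequency-gated epoch callbacks, the sketch below decides which actions fire at a given epoch. It is not the library's implementation: after_each_epoch, log_freq, and save_freq are hypothetical names, and the placeholder loss stands in for a real training step. Only the (epoch, loss_value) tuple format of loss_list comes from the documentation above.

```python
def after_each_epoch(epoch, loss, loss_list, log_freq=10, save_freq=50):
    """Return the actions a frequency-gated epoch callback would trigger.

    log_freq and save_freq are hypothetical parameter names chosen for this
    sketch; the real configuration keys may differ.
    """
    actions = []
    if epoch % log_freq == 0:
        # Record (epoch, loss) so the history can be written out later.
        loss_list.append((epoch, loss))
        actions.append("log")
    if epoch % save_freq == 0:
        actions.append("checkpoint")
    return actions


history = []
triggered = {}
for epoch in range(1, 101):
    loss = 1.0 / epoch  # placeholder loss value
    actions = after_each_epoch(epoch, loss, history)
    if actions:
        triggered[epoch] = actions
```

With these assumed frequencies, logging fires every 10 epochs and checkpointing every 50, so epochs 50 and 100 trigger both actions.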

do_after_training()[source]

Finalize training and generate outputs.

Generates animated GIF from collected figures, saves loss history to disk, closes TensorBoard writer, and optionally removes temporary files.

Notes

Should be called after training is complete to generate final outputs and cleanup temporary resources.

do_before_training(**kwargs)[source]

Initialize training pipeline before starting epochs.

Extracts runtime parameters from kwargs and loads pretrained model checkpoint if specified in configuration.

Parameters:

kwargs (dict) – Keyword arguments passed to get_kwargs() for configuration.

get_init_args()[source]

Initialize default training parameters.

Sets up default values for loss function (MSE), optimizer (Adam), epoch tracking, and storage lists for history and figures.

get_json_args(cfg)[source]

Extract and configure training parameters from configuration dictionary.

Parameters:

cfg (dict) – Configuration dictionary containing training parameters such as learning rate, number of epochs, logging frequency, file paths, etc.
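
A configuration dictionary of this kind could be read along the following lines. This is a sketch only: the key names (lr, epochs, log_freq, save_freq), their defaults, and the helper read_training_args are assumptions made for the example and are not taken from the AI4Plasma configuration schema.

```python
import json

# Hypothetical configuration; the key names are assumptions for this sketch.
cfg_text = '{"lr": 0.001, "epochs": 2000, "log_freq": 100}'
cfg = json.loads(cfg_text)


def read_training_args(cfg):
    """Pull training parameters out of a config dict, with fallbacks."""
    return {
        "lr": cfg.get("lr", 1e-3),
        "epochs": cfg.get("epochs", 1000),
        "log_freq": cfg.get("log_freq", 10),
        "save_freq": cfg.get("save_freq", 100),  # absent above, so the default applies
    }


args = read_training_args(cfg)
```

Using dict.get with a default keeps a partially filled configuration file usable, at the cost of silently accepting misspelled keys.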

get_kwargs(**kwargs)[source]

Extract and set training kwargs (callbacks, optimizer, loss functions).

Parameters:

kwargs (dict) –

  • optimizer : torch.optim.Optimizer, optional

  • calc_l2_err : callable, optional

    Function to calculate L2 error for logging.

  • plot_func_training : callable, optional

    Function to generate training plots for TensorBoard.

  • plot_func_gif : callable, optional

    Function to generate plots for GIF animation.
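
The keyword handling described above might look roughly like this. The four key names come from the list above; extract_kwargs and l2_error are hypothetical helpers for this sketch, and the relative L2 error formula and its (pred, true) signature are assumptions, since the expected signature of calc_l2_err is not documented here.

```python
KNOWN_KEYS = ("optimizer", "calc_l2_err", "plot_func_training", "plot_func_gif")


def extract_kwargs(**kwargs):
    """Keep only the documented keys that were actually supplied."""
    return {k: kwargs[k] for k in KNOWN_KEYS if kwargs.get(k) is not None}


def l2_error(pred, true):
    """Relative L2 error between two equal-length sequences of floats."""
    num = sum((p - t) ** 2 for p, t in zip(pred, true)) ** 0.5
    den = sum(t ** 2 for t in true) ** 0.5
    return num / den


# Unknown keys and keys set to None are dropped.
selected = extract_kwargs(calc_l2_err=l2_error, plot_func_gif=None, unrelated=1)
```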

load_last_epoch()[source]

Load the last completed epoch from checkpoint and compute total epochs.

Updates total_epochs based on the loaded last_epoch to support resuming training from where it was paused.

load_model(model_file, map_location=None)[source]

Load a complete checkpoint including model, optimizer, and epoch.

Orchestrates loading of model weights, optimizer state, and epoch info from a saved checkpoint file for resuming training.

Parameters:
  • model_file (str) – Path to the model checkpoint file.

  • map_location (str, optional) – Device to map the model to. If None, uses the configured default device.

load_model_from_file(model_file, map_location=None)[source]

Load a model checkpoint from file.

Parameters:
  • model_file (str) – Path to the saved model checkpoint file.

  • map_location (str, optional) – Device to map the model to. If None, uses the configured default device.

Returns:

Loaded checkpoint dictionary containing model, optimizer, and epoch info.

Return type:

dict

load_optimizer()[source]

Load optimizer state from the saved checkpoint.

Loads optimizer state including momentum, accumulated gradients, etc. from the saved checkpoint if available.

load_weights()[source]

Load model weights from the saved checkpoint.

Loads network weights from the saved model dictionary if available.

1.2. Network

Core neural network architectures for scientific computing and physics-informed learning.

This module provides flexible, highly configurable neural network building blocks optimized for physics-informed machine learning and operator learning problems in the AI4Plasma framework.

1.2.1. Network Classes

  • Network: Abstract base class for all neural network architectures.

  • FNN: Fully Connected Neural Network.

  • CNN: Convolutional Neural Network.

  • RelaxLayer: Relaxed hidden layer for NAS-PINN.

  • RelaxFNN: Relaxed fully connected network for NAS-PINN.

class ai4plasma.core.network.CNN(conv_layers, fc_layers=None, input_dim=2, act_fun=ReLU(), use_BN=False, use_pooling=True, pooling_type='max', kernel_size=3, stride=1, padding=1, pooling_kernel_size=2, pooling_stride=None, pooling_padding=0, init_method='xavier')[source]

Bases: Network

Convolutional Neural Network for scientific computing applications.

A flexible convolutional neural network designed for physics-informed machine learning and operator learning in plasma physics and other scientific domains. Supports 1D, 2D, and 3D convolutions with customizable architecture parameters, batch normalization, pooling strategies, and an optional fully connected head.

conv_layers

Channel counts for convolutional layers.

Type:

list of int

fc_layers_config

Configured layer sizes for the fully connected head.

Type:

list of int, optional

input_dim

Spatial dimension of input data (1, 2, or 3).

Type:

int

act_fun

Activation function applied after conv and fc layers.

Type:

torch.nn.Module

use_BN

Whether batch normalization is used.

Type:

bool

use_pooling

Whether pooling layers are used.

Type:

bool

conv_net

Sequential container of convolutional layers.

Type:

torch.nn.Sequential

fc_net

Sequential container of fully connected layers (lazily initialized).

Type:

torch.nn.Sequential, optional

global_pool

Global pooling layer (used if fc_layers is None).

Type:

torch.nn.Module, optional

forward(x)[source]

Forward pass through the CNN with lazy FC layer initialization.

When fc_layers is configured, the first forward pass automatically sizes the first fully connected layer to match the actual convolutional output shape. Subsequent forward passes reuse the already initialized FC layers.

Parameters:

x (torch.Tensor) –

Input tensor with shape depending on input_dim:

  • 1D: (batch_size, channels, length)

  • 2D: (batch_size, channels, height, width)

  • 3D: (batch_size, channels, depth, height, width)

Returns:

Output tensor of shape (batch_size, output_features).

Return type:

torch.Tensor

get_feature_size(input_shape)[source]

Calculate the output feature size after convolutional layers.

Used to determine the input size required for the first fully connected layer when designing network architectures manually.

Parameters:

input_shape (tuple) – Shape of input tensor (channels, spatial_dims…)

Returns:

Number of features after convolution and pooling operations.

Return type:

int
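
The feature count can be checked by hand with the standard convolution output-size formula, floor((size + 2 * padding - kernel) / stride) + 1, applied per spatial axis. The sketch below is not the library's code: conv_out and feature_size are hypothetical helpers, and it assumes that every convolutional stage is followed by one pooling stage and that out_channels lists the output channels of each stage. The default values mirror the CNN constructor signature.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of one convolution or pooling window along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def feature_size(input_shape, out_channels, kernel_size=3, stride=1,
                 padding=1, pool_kernel=2, pool_stride=None, pool_padding=0):
    """Flattened feature count after the conv+pool stages.

    Assumption: each conv stage is followed by exactly one pooling stage.
    """
    if pool_stride is None:
        pool_stride = pool_kernel  # PyTorch pooling layers default to this
    spatial = list(input_shape[1:])  # drop the channel entry
    for _ in out_channels:
        spatial = [conv_out(s, kernel_size, stride, padding) for s in spatial]
        spatial = [conv_out(s, pool_kernel, pool_stride, pool_padding)
                   for s in spatial]
    total = out_channels[-1]
    for s in spatial:
        total *= s
    return total


# 1 x 28 x 28 input, two stages with 16 and 32 output channels.
n_features = feature_size((1, 28, 28), out_channels=[16, 32])
```

With the default 3x3 kernel, stride 1, and padding 1, each convolution preserves the spatial size and each 2x2 pooling halves it, so 28 becomes 14 and then 7, giving 32 * 7 * 7 features.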

init_weights(method='xavier')[source]

Initialize network weights using the specified method.

Applies the initialization strategy to all convolutional, linear, and batch normalization layers in the network.

Parameters:

method ({'xavier', 'kaiming', 'zero'}, optional) – Weight initialization method. Default is 'xavier'.

class ai4plasma.core.network.FNN(layers, act_fun=Tanh(), use_BN=False, init_method='xavier')[source]

Bases: Network

Fully Connected Neural Network (Multi-layer Perceptron).

A flexible multi-layer fully connected neural network with customizable depth, width, activation function, and weight initialization strategy. Suitable for function approximation and physics-informed machine learning tasks.

layers

Number of neurons in each layer.

Type:

list of int

act_fun

Activation function applied between layers.

Type:

torch.nn.Module

net

The complete network model.

Type:

torch.nn.Sequential

forward(x)[source]

Forward pass through the FNN.

Parameters:

x (torch.Tensor) – Input tensor of shape (batch_size, input_dim).

Returns:

Output tensor of shape (batch_size, output_dim).

Return type:

torch.Tensor

init_weights(net, method='xavier')[source]

Initialize network weights using specified method.

Parameters:
  • net (torch.nn.Sequential) – The network whose weights to initialize.

  • method ({'xavier', 'zero'}, optional) – Initialization method. Default is 'xavier'.

linear_model(layers, activation, use_BN=False)[source]

Construct a sequential multi-layer network.

Parameters:
  • layers (list of int) – Number of neurons in each layer.

  • activation (torch.nn.Module) – Activation function to apply between layers.

  • use_BN (bool, optional) – Whether to use batch normalization. Default is False.

Returns:

Sequential model containing linear layers with optional batch norm and activation functions.

Return type:

torch.nn.Sequential
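
To make the meaning of the layers list concrete, the sketch below expands a size list into a layer plan and counts the trainable parameters of the linear layers. It uses plain Python rather than torch, and layer_plan and linear_param_count are hypothetical helpers. It assumes that layers[0] is the input dimension, layers[-1] is the output dimension, no activation follows the final linear layer, and batch normalization is off.

```python
def linear_param_count(layers):
    """Weights plus biases of the Linear layers implied by a size list."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layers, layers[1:]))


def layer_plan(layers, activation="Tanh"):
    """Describe a stack with an activation between consecutive Linear layers.

    Assumption: no activation after the final Linear layer.
    """
    plan = []
    pairs = list(zip(layers, layers[1:]))
    for i, (n_in, n_out) in enumerate(pairs):
        plan.append(f"Linear({n_in}, {n_out})")
        if i < len(pairs) - 1:
            plan.append(activation)
    return plan


layers = [2, 64, 64, 1]
plan = layer_plan(layers)
n_params = linear_param_count(layers)
```

For layers = [2, 64, 64, 1] the three linear layers hold 192, 4160, and 65 parameters respectively.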

class ai4plasma.core.network.Network[source]

Bases: Module, ABC

Abstract base class for neural network architectures.

This class serves as a base for all neural network implementations. All subclasses should override the forward() and init_weights() methods to implement specific architectures and initialization strategies.

abstractmethod forward(x)[source]

Forward pass through the network.

Parameters:

x (torch.Tensor) – Input tensor to the network.

Returns:

Output tensor from the network.

Return type:

torch.Tensor

Notes

Abstract method that must be implemented by subclasses.

abstractmethod init_weights()[source]

Initialize network weights.

This method should be overridden by subclasses to provide specific weight initialization strategies appropriate for the network architecture.

Notes

Abstract method that must be implemented by subclasses.

class ai4plasma.core.network.RelaxFNN(layers, C_in_list, neuron_list)[source]

Bases: Network

Relaxed fully connected network for NAS-PINN.

Builds a stack of relaxed layers whose architectures are controlled by learnable parameters. The network supports architecture search by learning soft selections over identity vs. nonlinear paths and neuron counts.

Parameters:
  • layers (int) – Number of relaxed layers.

  • C_in_list (list of int) – Input dimension for each layer (length equals layers).

  • neuron_list (list of int) – Candidate neuron counts for each relaxed layer.

layers

Number of relaxed layers.

Type:

int

C_in_list

Input dimensions for each layer.

Type:

list of int

neuron_list

Candidate neuron counts.

Type:

list of int

network

Stack of RelaxLayer modules.

Type:

torch.nn.ModuleList

gs

Architecture parameters with shape (layers, len(neuron_list) + 1).

Type:

torch.Tensor

arch_parameters()[source]

Return architecture parameters for optimization.

Returns:

List containing the architecture parameter tensor.

Return type:

list of torch.Tensor

build_up()[source]

Build the relaxed network stack.

Creates a ModuleList of RelaxLayer instances based on C_in_list and neuron_list.

forward(x)[source]

Forward pass through the relaxed FNN.

Parameters:

x (torch.Tensor) – Input tensor of shape (batch_size, C_in_list[0]).

Returns:

Output tensor of shape (batch_size, 1).

Return type:

torch.Tensor

init_weights()[source]

Initialize architecture parameters gs.

Creates learnable parameters for each layer, with the first two entries controlling identity vs. nonlinear mixing and the remaining entries selecting neuron counts.

gs shape: (layers, len(neuron_list) + 1), where:
  • gs[i, 0] corresponds to the weight for the identity path in layer i.

  • gs[i, 1] corresponds to the weight for the nonlinear path in layer i.

  • gs[i, 2:] correspond to the weights for selecting neuron counts in layer i.

load_gs(arch_param)[source]

Load architecture parameters from an external list.

self.arch_para is a list of length 1 whose only element is the architecture parameter tensor gs.

Parameters:

arch_param (list of torch.Tensor) – List where the first element contains the architecture parameters.

searched_neuron(threshold=0.001)[source]

Derive the discrete architecture from learned parameters.

If the identity and nonlinear weights are within threshold of each other, the layer retains a residual connection, denoted by '0+' followed by the selected neuron count. Otherwise, the dominant option is selected.

Parameters:

threshold (float, optional) – Threshold for determining if identity and nonlinear paths are close. Default is 1e-3.

Returns:

List of selected neuron descriptors per layer.

Return type:

list
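
The selection rule can be sketched in plain Python as follows. This is an interpretation of the description above, not the library's code. searched_neuron_sketch is a hypothetical helper. The assumed row layout is [identity, nonlinear, w_1, ..., w_K], with w_k scoring neuron_list[k] and neuron_list[0] standing for identity. Encoding an identity-dominated layer as 0 is also a guess.

```python
def searched_neuron_sketch(gs, neuron_list, threshold=1e-3):
    """Turn rows of architecture weights into per-layer descriptors.

    Assumed row layout: [identity, nonlinear, w_1, ..., w_K], with w_k
    scoring neuron_list[k] and neuron_list[0] standing for identity.
    """
    result = []
    for row in gs:
        identity_w, nonlinear_w, neuron_w = row[0], row[1], row[2:]
        # Candidate with the largest selection weight.
        best = max(range(len(neuron_w)), key=lambda k: neuron_w[k])
        neurons = neuron_list[best + 1]
        if abs(identity_w - nonlinear_w) < threshold:
            result.append(f"0+{neurons}")  # keep a residual connection
        elif identity_w > nonlinear_w:
            result.append(0)  # identity dominates (encoding assumed)
        else:
            result.append(neurons)  # nonlinear path dominates
    return result


gs = [
    [0.5, 0.5, 0.1, 0.9],  # paths tie: residual plus 64 neurons
    [0.9, 0.1, 0.7, 0.3],  # identity dominates
    [0.2, 0.8, 0.6, 0.4],  # nonlinear dominates with 32 neurons
]
arch = searched_neuron_sketch(gs, neuron_list=[0, 32, 64])
```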

class ai4plasma.core.network.RelaxLayer(C_in, neuron_list, act_fun=Tanh(), init_method='xavier')[source]

Bases: Network

Relaxed hidden layer for NAS-PINN architecture search.

Implements a relaxed fully connected hidden layer where the effective operation is determined by learnable architecture parameters g. The layer blends identity and nonlinear transformations with soft selections over candidate neuron counts.

Parameters:
  • C_in (int) – Input feature dimension.

  • neuron_list (list of int) – Candidate neuron counts. neuron_list[0] is treated as identity.

  • act_fun (torch.nn.Module, optional) – Activation function applied after the linear operator. Default is Tanh.

  • init_method ({'xavier', 'zero'}, optional) – Weight initialization method for the linear operator. Default is 'xavier'.

C_in

Input feature dimension.

Type:

int

neuron_list

Candidate neuron counts (including identity as index 0).

Type:

list of int

op

Linear + activation operator for the maximal neuron count.

Type:

torch.nn.Sequential

masks

Binary masks used to select subsets of neurons.

Type:

torch.Tensor

forward(x, g)[source]

Forward pass with relaxed architecture parameters.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (batch_size, C_in).

  • g (torch.Tensor) – Architecture parameter tensor for this layer. The first two entries control identity vs. nonlinear mixing; the remaining entries select neuron counts.

Returns:

Output tensor of shape (batch_size, max(neuron_list)).

Return type:

torch.Tensor
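
One way to picture the soft selection over neuron counts is a weighted blend of binary masks, each as wide as the largest candidate. The documentation does not state how the masks are built, so the sketch below assumes that a candidate of n neurons keeps the first n outputs and zeroes the rest. build_masks and blend_masks are hypothetical helpers, and plain lists stand in for tensors.

```python
def build_masks(neuron_list):
    """One 0/1 mask per non-identity candidate, padded to the widest one."""
    width = max(neuron_list)
    return [[1.0 if j < n else 0.0 for j in range(width)]
            for n in neuron_list[1:]]  # neuron_list[0] is the identity slot


def blend_masks(masks, weights):
    """Weighted sum of masks: a soft selection over neuron counts."""
    width = len(masks[0])
    return [sum(w * m[j] for w, m in zip(weights, masks)) for j in range(width)]


masks = build_masks([0, 2, 4])
soft_mask = blend_masks(masks, weights=[0.25, 0.75])
```

Multiplying the layer output elementwise by such a blended mask keeps the output width fixed at max(neuron_list), consistent with the documented output shape, while the weights control how strongly the extra neurons contribute.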

init_weights(net, method='xavier')[source]

Initialize weights for the linear operator.

Parameters:
  • net (torch.nn.Sequential) – Operator whose weights will be initialized.

  • method ({'xavier', 'zero'}, optional) – Weight initialization method. Default is 'xavier'.