3. Operator Learning Module
This module contains neural operators for learning mappings between function spaces.
3.1. DeepONet
DeepONet (Deep Operator Network) implementation for learning nonlinear operators.
This module provides a comprehensive framework for training and deploying DeepONet models, which are neural networks designed to learn nonlinear operators mapping between infinite-dimensional function spaces.
3.1.1. DeepONet Classes
DeepONet: Core neural network architecture with automatic branch type detection
DeepONetDataset: PyTorch Dataset supporting both 2D (FNN) and 4D (CNN) inputs
DeepONetModel: High-level training wrapper with checkpointing and TensorBoard
- class ai4plasma.operator.deeponet.DeepONet(branch_net, trunk_net)[source]
Bases: Module
Neural network architecture for learning nonlinear operators.
DeepONet consists of two sub-networks: a branch network and a trunk network. The branch network processes input functions (supports both FNN and CNN architectures), while the trunk network processes spatial/temporal locations (typically FNN). The outputs of these networks are combined via inner product to produce the final prediction.
- branch_net
Branch network (FNN or CNN) that processes input functions.
- Type:
torch.nn.Module
- trunk_net
Trunk network (typically FNN) that processes coordinates.
- Type:
torch.nn.Module
- bias_last
Learnable bias term added to final output.
- Type:
torch.nn.Parameter
- branch_is_cnn
True if the branch network is a CNN (4D input), False if it is an FNN (2D input).
- Type:
bool
- forward(branch_inputs, trunk_inputs)[source]
Forward pass through DeepONet.
Automatically handles both FNN (2D input) and CNN (4D input) branch networks. Combines branch and trunk outputs via inner product (Einstein summation) and adds a learnable bias term.
- Parameters:
branch_inputs (torch.Tensor) –
Input data for branch network.
FNN: shape (batch_size, features)
CNN: shape (batch_size, channels, height, width)
trunk_inputs (torch.Tensor) – Input for trunk network, shape (num_points, features).
- Returns:
Output prediction of DeepONet, shape (batch_size, num_points).
- Return type:
torch.Tensor
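The combination step described for forward() can be sketched in a few lines of PyTorch. This is a minimal illustration under stated assumptions, not the library's implementation: the layer sizes, the latent width p, and the helper name combine are all chosen for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

p = 16  # latent width shared by branch and trunk outputs (assumed value)

# Stand-ins for an FNN branch and trunk network (hypothetical sizes).
branch_net = nn.Sequential(nn.Linear(20, 32), nn.Tanh(), nn.Linear(32, p))
trunk_net = nn.Sequential(nn.Linear(1, 32), nn.Tanh(), nn.Linear(32, p))
bias_last = nn.Parameter(torch.zeros(1))


def combine(branch_inputs, trunk_inputs):
    """Inner product of branch and trunk features plus a learnable bias."""
    b = branch_net(branch_inputs)  # (batch_size, p)
    t = trunk_net(trunk_inputs)    # (num_points, p)
    # Sum over the shared latent index i: out[b, n] = sum_i b[b, i] * t[n, i]
    return torch.einsum("bi,ni->bn", b, t) + bias_last


branch_inputs = torch.randn(4, 20)                    # 4 input functions, 20 sensor values each
trunk_inputs = torch.linspace(0, 1, 10).unsqueeze(1)  # 10 query points, 1 coordinate each
out = combine(branch_inputs, trunk_inputs)
print(tuple(out.shape))  # (4, 10)
```

Every input function is evaluated at every query point, which is why the output has shape (batch_size, num_points) rather than one value per sample.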
- class ai4plasma.operator.deeponet.DeepONetDataset(branch_inputs, trunk_inputs, targets, split_by_branch=True)[source]
Bases: Dataset
PyTorch Dataset for DeepONet supporting FNN and CNN branch networks.
This dataset handles the organization of branch-trunk-target triplets for DeepONet training. It supports flexible splitting strategies and automatically detects whether the branch input is for FNN (2D) or CNN (4D) networks.
- branch_inputs
Input data for the branch network.
- Type:
torch.Tensor
- trunk_inputs
Input data for the trunk network.
- Type:
torch.Tensor
- targets
Target output data.
- Type:
torch.Tensor
- split_by_branch
Splitting strategy flag.
- Type:
bool
- is_cnn_input
True if the branch input is 4D (CNN), False if it is 2D (FNN).
- Type:
bool
- class ai4plasma.operator.deeponet.DeepONetModel(network)[source]
Bases: BaseModel
High-level training and inference wrapper for DeepONet.
This class provides a comprehensive interface for preparing training data, calculating loss, making predictions, and training the DeepONet model with advanced features including checkpointing, TensorBoard logging, learning rate scheduling, and resume capabilities.
- network
The DeepONet model to be trained.
- Type:
torch.nn.Module
- writer
TensorBoard writer for logging.
- Type:
torch.utils.tensorboard.SummaryWriter, optional
- checkpoint_dir
Directory for saving training checkpoints.
- Type:
str, optional
- start_epoch
Starting epoch for training (useful for resume).
- Type:
int
- is_cnn_input
True if the branch input is 4D (CNN), False if it is 2D (FNN).
- Type:
bool
- dataset
Dataset instance for training.
- Type:
DeepONetDataset
- dataloader
DataLoader instance for batched training.
- Type:
torch.utils.data.DataLoader
- calc_loss(data)[source]
Calculate the training loss for a given batch.
- Parameters:
data (tuple) – A tuple containing (branch_input, trunk_input, target).
- Returns:
The computed loss value (scalar).
- Return type:
torch.Tensor
- predict(branch_input_data, trunk_input_data)[source]
Perform inference using the trained DeepONet model.
Sets the model to evaluation mode and performs forward pass without computing gradients.
- Parameters:
branch_input_data (torch.Tensor) – Input data for the branch network.
trunk_input_data (torch.Tensor) – Input data for the trunk network.
- Returns:
The predicted output of the DeepONet model.
- Return type:
torch.Tensor
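The inference pattern described for predict() (evaluation mode plus a gradient-free forward pass) is standard PyTorch and can be sketched as follows. The helper and the stand-in network are hypothetical; a real DeepONet would receive branch and trunk inputs.

```python
import torch
import torch.nn as nn


def predict(network, *inputs):
    """Run a forward pass in evaluation mode without tracking gradients."""
    network.eval()         # dropout off, batch norm uses running statistics
    with torch.no_grad():  # skip autograd bookkeeping during inference
        return network(*inputs)


# Hypothetical stand-in network used only to exercise the helper.
net = nn.Sequential(nn.Linear(3, 8), nn.Dropout(0.5), nn.Linear(8, 1))
y = predict(net, torch.randn(5, 3))
print(y.requires_grad, net.training)  # False False
```

Note that the network stays in evaluation mode afterwards; call network.train() before resuming training if the model uses dropout or batch normalization.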
- prepare_train_data(branch_input_data, trunk_input_data, target_data, split_by_branch=True, batch_size=None, shuffle=False, drop_last=False)[source]
Prepare the training data and create a DataLoader.
Supports both FNN and CNN branch networks. Automatically moves data to the configured device (CPU/GPU) and creates a custom collate function to handle mixed tensor dimensions.
- Parameters:
branch_input_data (torch.Tensor) –
Input data for the branch network.
For FNN: shape (A, M) where A is samples, M is features
For CNN: shape (A, C, H, W) where C is channels, H, W are spatial dims
trunk_input_data (torch.Tensor) – Input data for the trunk network, with shape (B, N).
target_data (torch.Tensor) – Target output data, with shape (A, B).
split_by_branch (bool, optional) – If True, split by branch sample indices; if False, by trunk sample indices. Default is True.
batch_size (int, optional) – The batch size for the DataLoader. If None, use the full dataset as a single batch. Default is None.
shuffle (bool, optional) – Whether to shuffle the data in the DataLoader. Default is False.
drop_last (bool, optional) – Whether to drop the last incomplete batch if dataset size is not divisible by batch_size. Default is False.
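The shape contract above, (A, M) or (A, C, H, W) for the branch, (B, N) for the trunk, and (A, B) for the targets, can be checked before calling prepare_train_data(). The helper below is a hypothetical convenience sketch and not part of the documented API.

```python
import torch


def check_deeponet_shapes(branch, trunk, target):
    """Validate the branch/trunk/target layout and report the branch type."""
    if branch.dim() not in (2, 4):
        raise ValueError("branch input must be 2D (FNN) or 4D (CNN)")
    if trunk.dim() != 2:
        raise ValueError("trunk input must be 2D with shape (B, N)")
    A, B = branch.shape[0], trunk.shape[0]
    if tuple(target.shape) != (A, B):
        raise ValueError(f"target must have shape ({A}, {B}), got {tuple(target.shape)}")
    return branch.dim() == 4  # True when the branch input is CNN-style


is_cnn = check_deeponet_shapes(
    torch.randn(8, 1, 16, 16),  # A=8 samples, C=1 channel, H=W=16
    torch.randn(50, 2),         # B=50 query points, N=2 coordinates
    torch.randn(8, 50),         # one target value per (sample, point) pair
)
print(is_cnn)  # True
```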
- train(num_epochs, optimizer=None, optimizer_cfg: dict = None, lr: float = 0.0001, lr_scheduler=None, print_loss: bool = True, print_loss_freq: int = 1, tensorboard_logdir: str = None, save_final_model: bool = False, final_model_path: str = None, checkpoint_dir: str = None, checkpoint_freq: int = 1, resume_from: str = None)[source]
Train the DeepONet model with comprehensive options.
Supports pre-built optimizer instances or configuration dictionaries for construction. Also supports learning-rate adjustment, TensorBoard logging, checkpointing, and resuming.
- Parameters:
num_epochs (int) – Number of epochs to train (total, not additional when resuming).
optimizer (torch.optim.Optimizer, optional) – Pre-built optimizer instance. If provided, used as-is. Default is None.
optimizer_cfg (dict, optional) – Configuration used to build the optimizer when optimizer is None. Example: {'name': 'Adam', 'params': {'lr': 1e-4, 'weight_decay': 0}}. Default is None.
lr (float, optional) – Default learning rate when building a default optimizer. Default is 1e-4.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler or callable, optional) – Optional scheduler to step each epoch. Default is None.
print_loss (bool, optional) – Whether to print loss each epoch. Default is True.
print_loss_freq (int, optional) – Print loss every N epochs. Default is 1.
tensorboard_logdir (str, optional) – Path to write TensorBoard logs. If provided, a SummaryWriter is created. Default is None.
save_final_model (bool, optional) – Whether to save final model at the end of training. Default is False.
final_model_path (str, optional) – Path to save final model if save_final_model=True. Default is None.
checkpoint_dir (str, optional) – Directory to save epoch checkpoints (model+optimizer+epoch). Default is None.
checkpoint_freq (int, optional) – Save checkpoint every N epochs. Default is 1.
resume_from (str, optional) – Path to checkpoint file to resume from. If provided, loads model/optimizer/epoch. Default is None.
- Raises:
RuntimeError – If prepare_train_data() has not been called before training.
FileNotFoundError – If resume_from path does not exist.
ValueError – If optimizer_cfg contains an unknown optimizer name.
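One plausible way to resolve the optimizer, optimizer_cfg, and lr arguments is sketched below. It is an illustration only: the function name build_optimizer, the choice of Adam as the default optimizer, and the fallback of lr into the config are assumptions, not confirmed library behavior.

```python
import torch
import torch.nn as nn


def build_optimizer(params, optimizer=None, optimizer_cfg=None, lr=1e-4):
    """Resolve an optimizer from an instance, a config dict, or a default."""
    if optimizer is not None:
        return optimizer                        # pre-built instance is used as-is
    if optimizer_cfg is None:
        return torch.optim.Adam(params, lr=lr)  # assumed default optimizer
    name = optimizer_cfg["name"]
    opt_cls = getattr(torch.optim, name, None)
    # Reject names that do not resolve to an Optimizer subclass in torch.optim.
    if not (isinstance(opt_cls, type) and issubclass(opt_cls, torch.optim.Optimizer)):
        raise ValueError(f"Unknown optimizer name: {name!r}")
    kwargs = dict(optimizer_cfg.get("params", {}))
    kwargs.setdefault("lr", lr)                 # use lr only if the config omits it
    return opt_cls(params, **kwargs)


net = nn.Linear(4, 1)
cfg = {"name": "Adam", "params": {"lr": 1e-4, "weight_decay": 0}}
opt = build_optimizer(net.parameters(), optimizer_cfg=cfg)
print(type(opt).__name__, opt.param_groups[0]["lr"])  # Adam 0.0001
```

Looking the class up by name keeps the config serializable, so the same dictionary can be stored alongside a checkpoint.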
3.2. DeepCSNet
DeepCSNet - Deep learning for electron-impact cross section prediction.
This module provides a specialized neural network architecture for predicting electron-impact cross sections in plasma physics applications. DeepCSNet employs a novel coefficient-subnet structure that separately processes different input feature types before combining them through tensor operations.
3.2.1. DeepCSNet Classes
DeepCSNet: Modular neural network with coefficient-subnet architecture
DeepCSNetDataset: Custom dataset class for cross section data handling
DeepCSNetModel: High-level wrapper for training and inference
3.2.2. DeepCSNet References
- [1] Y. Wang and L. Zhong, “DeepCSNet: a deep learning method for predicting electron-impact doubly differential ionization cross sections,” Plasma Sources Science and Technology, vol. 33, no. 10, p. 105012, 2024.
- class ai4plasma.operator.deepcsnet.DeepCSNet(trunk_net: Module, molecule_net: Module = None, energy_net: Module = None)[source]
Bases: Module
Coefficient-subnet neural network for cross section prediction.
Architecture
The network consists of up to three optional sub-networks that process different input modalities:
Molecule Net (optional, for multi-molecule scenarios):
Processes molecular descriptors
Input: [batch_size, molecule_features]
Output: [batch_size, molecule_hidden_dim]
Energy Net (optional, for energy-dependent cross sections):
Processes incident energy features
Input: [batch_size, energy_features]
Output: [batch_size, energy_hidden_dim]
Trunk Net (required):
Processes output coordinate features
Input: [n_points, coordinate_features]
Output: [n_points, trunk_hidden_dim]
Operation Modes
SMC (Single-Molecule): Energy Net + Trunk Net
MMC (Multi-Molecule): Molecule Net + (optional Energy Net) + Trunk Net
- molecule_net
Sub-network for processing molecular features. If None, operates in SMC mode.
- Type:
torch.nn.Module or None
- energy_net
Sub-network for processing energy features. If None in MMC mode, only molecule features are used.
- Type:
torch.nn.Module or None
- trunk_net
Sub-network for processing coordinate features (output dimensions). Required.
- Type:
torch.nn.Module
- bias_last
Learnable scalar bias term added to final output. Shape: [1].
- Type:
torch.nn.Parameter
- module
Operating mode: “SMC” or “MMC”.
- Type:
str
Notes
At least one of molecule_net or energy_net must be provided
Output dimensions of all branch networks must match trunk network output dimension
In MMC mode with both networks, outputs are concatenated
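The mode rules above reduce to a small decision: no molecule network means SMC, otherwise MMC, and at least one coefficient sub-network is required. The sketch below encodes only that rule; the function name resolve_mode is hypothetical and plain objects stand in for torch.nn.Module instances.

```python
def resolve_mode(molecule_net=None, energy_net=None):
    """Return 'SMC' or 'MMC' depending on which sub-networks are supplied."""
    if molecule_net is None and energy_net is None:
        raise ValueError("At least one of molecule_net or energy_net must be provided")
    # Without a molecule network, only the energy network drives the coefficients.
    return "SMC" if molecule_net is None else "MMC"


# Placeholder objects stand in for real sub-networks.
print(resolve_mode(energy_net=object()))                         # SMC
print(resolve_mode(molecule_net=object()))                       # MMC
print(resolve_mode(molecule_net=object(), energy_net=object()))  # MMC
```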
- forward(trunk_input, molecule_input=None, energy_input=None)[source]
Forward pass through DeepCSNet.
Processes inputs through respective sub-networks and combines outputs via tensor product operations.
- Parameters:
trunk_input (torch.Tensor) – Input for trunk network (coordinate features). Shape: [n_points, coordinate_dim]. Examples: scattering angles, ejected electron energies.
molecule_input (torch.Tensor, optional) – Input for molecule network (molecular features). Shape: [batch_size, molecule_dim]. Required in MMC mode, should be None in SMC mode.
energy_input (torch.Tensor, optional) – Input for energy network (energy features). Shape: [batch_size, energy_dim]. Required in SMC mode, optional in MMC mode.
- Returns:
Predicted cross section values. Shape: [batch_size, n_points]. Each row holds the cross sections for one case at all output coordinates.
- Return type:
torch.Tensor
- class ai4plasma.operator.deepcsnet.DeepCSNetDataset(trunk_inputs: Tensor, molecule_inputs: Tensor = None, energy_inputs: Tensor = None, targets: Tensor = None, split_by_mole: bool = True)[source]
Bases: Dataset
Custom PyTorch Dataset for DeepCSNet training and inference.
This dataset class handles the specialized data structure required by DeepCSNet, where different input modalities (molecular features, energy features, coordinate features) need to be organized and batched appropriately for the coefficient-subnet architecture.
Overview
Unlike standard datasets where each sample is independent, DeepCSNet requires:
Branch inputs (molecule/energy): vary per case, shape [batch_size, features]
Trunk inputs (coordinates): shared across all cases, shape [n_points, coord_dim]
Targets: cross section values, shape [batch_size, n_points]
This dataset supports two splitting strategies:
split_by_mole=True: Iterate over cases (molecules/energies)
split_by_mole=False: Iterate over coordinates
- molecule_inputs
Molecular feature tensor, shape [n_cases, molecule_dim].
- Type:
torch.Tensor or None
- trunk_inputs
Coordinate feature tensor, shape [n_points, coordinate_dim].
- Type:
torch.Tensor
- split_by_mole
Splitting strategy flag (True: by cases, False: by coordinates).
- Type:
bool
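The two splitting strategies can be pictured with a toy dataset. This is a sketch of the idea only, not the DeepCSNetDataset implementation: the class name, the single generic branch tensor, and the exact items returned are assumptions.

```python
import torch
from torch.utils.data import Dataset


class SplitSketchDataset(Dataset):
    """Toy dataset showing the two iteration strategies (not the library class)."""

    def __init__(self, branch, trunk, targets, split_by_case=True):
        self.branch, self.trunk, self.targets = branch, trunk, targets
        self.split_by_case = split_by_case

    def __len__(self):
        # One item per case, or one item per coordinate point.
        return self.branch.shape[0] if self.split_by_case else self.trunk.shape[0]

    def __getitem__(self, i):
        if self.split_by_case:
            # Case i: its branch features, all coordinates, its row of targets.
            return self.branch[i], self.trunk, self.targets[i]
        # Point i: all branch features, its coordinates, its column of targets.
        return self.branch, self.trunk[i], self.targets[:, i]


branch = torch.randn(6, 3)    # n_cases=6, 3 branch features
trunk = torch.randn(40, 2)    # n_points=40, 2 coordinates
targets = torch.randn(6, 40)  # [n_cases, n_points]

by_case = SplitSketchDataset(branch, trunk, targets, split_by_case=True)
by_point = SplitSketchDataset(branch, trunk, targets, split_by_case=False)
print(len(by_case), len(by_point))  # 6 40
```

Splitting by case yields one item per molecule/energy case with the full coordinate grid attached, whereas splitting by coordinates yields one item per grid point with all cases attached.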
- class ai4plasma.operator.deepcsnet.DeepCSNetModel(network: Module)[source]
Bases:
BaseModelHigh-level training and inference wrapper for DeepCSNet.
This class provides a complete pipeline for working with DeepCSNet, including: data preparation and batching, training loop, model checkpointing, and inference.
- network
The underlying DeepCSNet architecture.
- Type:
torch.nn.Module
- dataset
Custom dataset instance (set by prepare_train_data).
- Type:
DeepCSNetDataset
- dataloader
PyTorch DataLoader with custom collate function.
- Type:
torch.utils.data.DataLoader
- loss_func
Loss function (defaults to MSELoss).
- Type:
callable
- optimizer
Optimizer for gradient-based training.
- Type:
torch.optim.Optimizer
- start_epoch
Starting epoch for resuming training.
- Type:
int
- calc_loss(data)[source]
Calculate the training loss for a given batch.
Performs forward pass and computes loss using configured loss function.
- Parameters:
data (tuple) –
A 4-tuple (molecule_inputs, energy_inputs, trunk_inputs, targets):
molecule_inputs: [batch_size, molecule_dim] or None
energy_inputs: [batch_size, energy_dim] or None
trunk_inputs: [n_points, coordinate_dim]
targets: [batch_size, n_points]
- Returns:
Scalar loss value from loss_func(predictions, targets).
- Return type:
torch.Tensor
- predict(trunk_input, molecule_input=None, energy_input=None)[source]
Perform inference using the trained DeepCSNet model.
Sets network to evaluation mode and performs forward pass.
- Parameters:
trunk_input (torch.Tensor) – Coordinate features for output dimensions. Shape: [n_points, coordinate_dim]
molecule_input (torch.Tensor, optional) – Molecular features for cases. Shape: [batch_size, molecule_dim]. Required in MMC mode, should be None in SMC mode.
energy_input (torch.Tensor, optional) – Energy features for cases. Shape: [batch_size, energy_dim]. Required in SMC mode, optional in MMC mode.
- Returns:
Predicted cross section values. Shape: [batch_size, n_points]
- Return type:
torch.Tensor
- prepare_train_data(trunk_inputs: Tensor, molecule_inputs: Tensor = None, energy_inputs: Tensor = None, targets: Tensor = None, split_by_mole: bool = True, batch_size: int = None, shuffle: bool = False, drop_last: bool = False)[source]
Prepare training data and create a DataLoader.
Converts raw tensor data into a DeepCSNetDataset and wraps it with a DataLoader that uses a custom collate function.
- Parameters:
trunk_inputs (torch.Tensor) – Coordinate features (output dimensions). Shape: [n_points, coordinate_dim]
molecule_inputs (torch.Tensor, optional) – Molecular features. Shape: [n_cases, molecule_dim]. Default is None.
energy_inputs (torch.Tensor, optional) – Energy features. Shape: [n_cases, energy_dim]. Default is None.
targets (torch.Tensor, optional) – Target cross section values. Shape: [n_cases, n_points]. Default is None.
split_by_mole (bool, default=True) –
Splitting strategy.
True: Iterate over cases, trunk_inputs replicated
False: Iterate over coordinates, branch inputs replicated
batch_size (int, optional) – Number of samples per batch. If None, uses full dataset. Default is None.
shuffle (bool, default=False) – Whether to shuffle the dataset at the beginning of each epoch.
drop_last (bool, default=False) – Whether to drop the last incomplete batch.
- train(num_epochs: int, optimizer: Optimizer = None, optimizer_cfg: dict = None, lr: float = 0.0001, lr_scheduler: _LRScheduler = None, print_loss: bool = True, print_loss_freq: int = 1, tensorboard_logdir: str = None, save_final_model: bool = True, final_model_path: str = None, checkpoint_dir: str = None, checkpoint_freq: int = 1, resume_from: str = None)[source]
Train the DeepCSNet model with comprehensive configuration.
Implements a complete training pipeline with flexible optimizer configuration, learning rate scheduling, checkpoint management, and TensorBoard logging.
- Parameters:
num_epochs (int) – Total number of training epochs. When resuming, this is the total target, not the number of additional epochs. Example: resuming from epoch 500 with num_epochs=1000 trains for 500 more epochs (500→1000).
optimizer (torch.optim.Optimizer, optional) – Pre-configured optimizer instance. If provided, used directly. If None, created based on optimizer_cfg or lr. Default is None.
optimizer_cfg (dict, optional) – Dictionary configuration for building the optimizer when optimizer=None. Format: {'name': 'OptimizerName', 'params': {param_dict}}. Example: {'name': 'Adam', 'params': {'lr': 1e-4, 'weight_decay': 1e-5}}. Supported names: any optimizer in torch.optim (Adam, SGD, AdamW, etc.). Default is None.
lr (float, default=1e-4) – Default learning rate when creating the optimizer automatically. Used only if optimizer=None and 'lr' is not in optimizer_cfg['params']. Typical range: 1e-5 to 1e-3.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler or callable, optional) – Learning rate scheduler for adaptive adjustment. Can be PyTorch scheduler or custom callable. If provided, .step() is called after each epoch. Default is None.
print_loss (bool, default=True) – Whether to print loss to console during training.
print_loss_freq (int, default=1) – Frequency (in epochs) for printing loss. Example: print_loss_freq=10 prints every 10 epochs.
tensorboard_logdir (str, optional) – Directory path for TensorBoard log files. If provided, creates a SummaryWriter and logs the loss each epoch. View logs with:
tensorboard --logdir=<tensorboard_logdir>
Default is None.
save_final_model (bool, default=True) – Whether to save the final trained model after all epochs.
final_model_path (str, optional) –
File path for saving the final model if save_final_model=True. Default is None.
If None and checkpoint_dir is set: saves to <checkpoint_dir>/final_model.pth
If None and checkpoint_dir is None: saves to './final_model.pth'
checkpoint_dir (str, optional) – Directory path for saving periodic training checkpoints. If provided, checkpoints named ‘checkpoint_epoch_N.pth’ are saved. Enables training interruption recovery. Default is None.
checkpoint_freq (int, default=1) – Frequency (in epochs) for saving checkpoints. Example: checkpoint_freq=100 saves every 100 epochs.
resume_from (str, optional) – File path to checkpoint for resuming training. If provided, loads model state, optimizer state, and epoch number. Default is None.
- Raises:
RuntimeError – If prepare_train_data() has not been called.
ValueError – If optimizer_cfg specifies unknown optimizer name.
FileNotFoundError – If resume_from path does not exist.
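The resume arithmetic and file-naming rules described for train() can be worked through in plain Python. The helper names are hypothetical, and the 1-based epoch counting and the exact epochs at which checkpoints fire are assumptions made for illustration.

```python
import os


def resolve_final_model_path(final_model_path=None, checkpoint_dir=None):
    """Apply the documented fallback order for the final model path."""
    if final_model_path is not None:
        return final_model_path
    if checkpoint_dir is not None:
        return os.path.join(checkpoint_dir, "final_model.pth")
    return "./final_model.pth"


def checkpoint_epochs(start_epoch, num_epochs, checkpoint_freq):
    """Epochs (assumed 1-based) at which a checkpoint would be written."""
    return [e for e in range(start_epoch + 1, num_epochs + 1) if e % checkpoint_freq == 0]


# Resuming from epoch 500 with a total target of 1000 epochs.
start_epoch, num_epochs = 500, 1000
remaining = num_epochs - start_epoch
ckpts = checkpoint_epochs(start_epoch, num_epochs, checkpoint_freq=100)
names = [f"checkpoint_epoch_{e}.pth" for e in ckpts]

print(remaining)  # 500
print(ckpts)      # [600, 700, 800, 900, 1000]
print(resolve_final_model_path(checkpoint_dir="runs/ckpt"))
```

Because num_epochs is a total target, a resumed run with num_epochs no larger than the saved epoch would have nothing left to train under this reading.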