3. Operator Learning Module

This module contains neural operators for learning mappings between function spaces.

3.1. DeepONet

DeepONet (Deep Operator Network) implementation for learning nonlinear operators.

This module provides classes for training and deploying DeepONet models, which are neural networks designed to learn nonlinear operators that map between infinite-dimensional function spaces.

3.1.1. DeepONet Classes

  • DeepONet: Core neural network architecture with automatic branch type detection

  • DeepONetDataset: PyTorch Dataset supporting both 2D (FNN) and 4D (CNN) inputs

  • DeepONetModel: High-level training wrapper with checkpointing and TensorBoard

class ai4plasma.operator.deeponet.DeepONet(branch_net, trunk_net)[source]

Bases: Module

Neural network architecture for learning nonlinear operators.

DeepONet consists of two sub-networks: a branch network and a trunk network. The branch network processes input functions (supports both FNN and CNN architectures), while the trunk network processes spatial/temporal locations (typically FNN). The outputs of these networks are combined via inner product to produce the final prediction.

branch_net

Branch network (FNN or CNN) that processes input functions.

Type:

torch.nn.Module

trunk_net

Trunk network (typically FNN) that processes coordinates.

Type:

torch.nn.Module

bias_last

Learnable bias term added to final output.

Type:

torch.nn.Parameter

branch_is_cnn

Flag indicating if branch is CNN (4D) or FNN (2D).

Type:

bool

forward(branch_inputs, trunk_inputs)[source]

Forward pass through DeepONet.

Automatically handles both FNN (2D input) and CNN (4D input) branch networks. Combines branch and trunk outputs via inner product (Einstein summation) and adds a learnable bias term.

Parameters:
  • branch_inputs (torch.Tensor) –

    Input data for branch network.

    • FNN: shape (batch_size, features)

    • CNN: shape (batch_size, channels, height, width)

  • trunk_inputs (torch.Tensor) – Input for trunk network, shape (num_points, features).

Returns:

Output prediction of DeepONet, shape (batch_size, num_points).

Return type:

torch.Tensor
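The branch-trunk combination described for forward() can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the ai4plasma source: the class name TinyDeepONet, the layer sizes, and the shared latent width p are hypothetical and chosen only to make the shapes concrete.

```python
import torch
from torch import nn


class TinyDeepONet(nn.Module):
    """Illustrative stand-in (hypothetical), not the ai4plasma class."""

    def __init__(self, branch_net: nn.Module, trunk_net: nn.Module):
        super().__init__()
        self.branch_net = branch_net
        self.trunk_net = trunk_net
        self.bias_last = nn.Parameter(torch.zeros(1))

    def forward(self, branch_inputs, trunk_inputs):
        b = self.branch_net(branch_inputs)  # (batch_size, p)
        t = self.trunk_net(trunk_inputs)    # (num_points, p)
        # Inner product over the shared latent dimension p, plus the bias.
        return torch.einsum("bp,np->bn", b, t) + self.bias_last


torch.manual_seed(0)
p = 16  # assumed latent width shared by branch and trunk outputs

# FNN branch: consumes 2D input of shape (batch_size, features).
fnn_branch = nn.Sequential(nn.Linear(50, 32), nn.Tanh(), nn.Linear(32, p))
# CNN branch: consumes 4D input of shape (batch_size, channels, height, width).
cnn_branch = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(4, p),
)
trunk = nn.Sequential(nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, p))
points = torch.randn(100, 2)  # (num_points, features)

out_fnn = TinyDeepONet(fnn_branch, trunk)(torch.randn(8, 50), points)
out_cnn = TinyDeepONet(cnn_branch, trunk)(torch.randn(8, 1, 12, 12), points)
print(out_fnn.shape, out_cnn.shape)
```

Both calls yield a tensor of shape (batch_size, num_points), because the branch network reduces either input layout to a (batch_size, p) coefficient matrix before the inner product.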

class ai4plasma.operator.deeponet.DeepONetDataset(branch_inputs, trunk_inputs, targets, split_by_branch=True)[source]

Bases: Dataset

PyTorch Dataset for DeepONet supporting FNN and CNN branch networks.

This dataset handles the organization of branch-trunk-target triplets for DeepONet training. It supports flexible splitting strategies and automatically detects whether the branch input is for FNN (2D) or CNN (4D) networks.

branch_inputs

Input data for the branch network.

Type:

torch.Tensor

trunk_inputs

Input data for the trunk network.

Type:

torch.Tensor

targets

Target output data.

Type:

torch.Tensor

split_by_branch

Splitting strategy flag.

Type:

bool

is_cnn_input

Flag indicating if branch input is 4D (CNN) or 2D (FNN).

Type:

bool
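The two splitting strategies can be pictured with a small stand-alone dataset. The sketch below is hypothetical: the class name SplitDataset and the exact items returned by __getitem__ are assumptions made for illustration, since the documentation above does not specify what a single sample contains.

```python
import torch
from torch.utils.data import Dataset


class SplitDataset(Dataset):
    """Hypothetical sketch of the two splitting strategies."""

    def __init__(self, branch_inputs, trunk_inputs, targets, split_by_branch=True):
        self.branch_inputs = branch_inputs
        self.trunk_inputs = trunk_inputs
        self.targets = targets
        self.split_by_branch = split_by_branch
        # A 4D branch tensor is treated as CNN input, a 2D one as FNN input.
        self.is_cnn_input = branch_inputs.dim() == 4

    def __len__(self):
        if self.split_by_branch:
            return self.branch_inputs.shape[0]
        return self.trunk_inputs.shape[0]

    def __getitem__(self, idx):
        if self.split_by_branch:
            # One input function and its targets at every trunk point.
            return self.branch_inputs[idx], self.targets[idx]
        # One trunk point and the targets of every input function there.
        return self.trunk_inputs[idx], self.targets[:, idx]


branch = torch.randn(6, 50)     # 6 input functions, 50 sensor values each
trunk = torch.randn(100, 2)     # 100 query points in 2 coordinates
targets = torch.randn(6, 100)   # one target per (function, point) pair

by_branch = SplitDataset(branch, trunk, targets, split_by_branch=True)
by_trunk = SplitDataset(branch, trunk, targets, split_by_branch=False)
print(len(by_branch), len(by_trunk))
```

Under these assumptions, splitting by branch yields one sample per input function, while splitting by trunk yields one sample per query point.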

class ai4plasma.operator.deeponet.DeepONetModel(network)[source]

Bases: BaseModel

High-level training and inference wrapper for DeepONet.

This class provides a comprehensive interface for preparing training data, calculating loss, making predictions, and training the DeepONet model with advanced features including checkpointing, TensorBoard logging, learning rate scheduling, and resume capabilities.

network

The DeepONet model to be trained.

Type:

torch.nn.Module

writer

TensorBoard writer for logging.

Type:

torch.utils.tensorboard.SummaryWriter, optional

checkpoint_dir

Directory for saving training checkpoints.

Type:

str, optional

start_epoch

Starting epoch for training (useful for resume).

Type:

int

is_cnn_input

Flag indicating if branch input is CNN (4D) or FNN (2D).

Type:

bool

dataset

Dataset instance for training.

Type:

DeepONetDataset

dataloader

DataLoader instance for batched training.

Type:

torch.utils.data.DataLoader

calc_loss(data)[source]

Calculate the training loss for a given batch.

Parameters:

data (tuple) – A tuple containing (branch_input, trunk_input, target).

Returns:

The computed loss value (scalar).

Return type:

torch.Tensor

predict(branch_input_data, trunk_input_data)[source]

Perform inference using the trained DeepONet model.

Sets the model to evaluation mode and performs a forward pass without computing gradients.

Parameters:
  • branch_input_data (torch.Tensor) – Input data for the branch network.

  • trunk_input_data (torch.Tensor) – Input data for the trunk network.

Returns:

The predicted output of the DeepONet model.

Return type:

torch.Tensor
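The inference pattern described above (evaluation mode plus a gradient-free forward pass) looks roughly as follows in plain PyTorch. The helper name predict_sketch and the small dropout network are hypothetical and serve only to show the effect of eval mode.

```python
import torch
from torch import nn


def predict_sketch(network: nn.Module, *inputs):
    """Hypothetical helper: eval mode plus a gradient-free forward pass."""
    network.eval()          # disables dropout, freezes batch-norm statistics
    with torch.no_grad():   # no autograd graph is recorded
        return network(*inputs)


torch.manual_seed(0)
net = nn.Sequential(nn.Linear(3, 8), nn.Dropout(0.5), nn.Linear(8, 1))
x = torch.randn(4, 3)

y1 = predict_sketch(net, x)
y2 = predict_sketch(net, x)
print(y1.requires_grad, torch.equal(y1, y2))
```

Because dropout is inactive in evaluation mode, repeated calls on the same input return identical outputs, and the result carries no gradient history.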

prepare_train_data(branch_input_data, trunk_input_data, target_data, split_by_branch=True, batch_size=None, shuffle=False, drop_last=False)[source]

Prepare the training data and create a DataLoader.

Supports both FNN and CNN branch networks. Automatically moves data to the configured device (CPU/GPU) and creates a custom collate function to handle mixed tensor dimensions.

Parameters:
  • branch_input_data (torch.Tensor) –

    Input data for the branch network.

    • For FNN: shape (A, M) where A is samples, M is features

    • For CNN: shape (A, C, H, W) where C is channels, H, W are spatial dims

  • trunk_input_data (torch.Tensor) – Input data for the trunk network, with shape (B, N).

  • target_data (torch.Tensor) – Target output data, with shape (A, B).

  • split_by_branch (bool, optional) – If True, split by branch sample indices; if False, by trunk sample indices. Default is True.

  • batch_size (int, optional) – The batch size for the DataLoader. If None, use the full dataset as a single batch. Default is None.

  • shuffle (bool, optional) – Whether to shuffle the data in the DataLoader. Default is False.

  • drop_last (bool, optional) – Whether to drop the last incomplete batch if dataset size is not divisible by batch_size. Default is False.
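A rough sketch of how such a DataLoader could be assembled in plain PyTorch is given below. It is not the library's implementation: the collate function, the use of TensorDataset, and the mapping of batch_size=None to "one batch holding the full dataset" are assumptions that mirror the parameter descriptions above. Device placement is omitted.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

A, M, B, N = 10, 50, 100, 2
branch = torch.randn(A, M)    # (A, M): A samples, M features (FNN case)
trunk = torch.randn(B, N)     # (B, N): B query points, N coordinates
targets = torch.randn(A, B)   # (A, B)

# Split by branch sample: each item pairs one branch row with its target row.
dataset = TensorDataset(branch, targets)


def collate(samples):
    # samples is a list of (branch_row, target_row) pairs.
    branch_batch = torch.stack([s[0] for s in samples])
    target_batch = torch.stack([s[1] for s in samples])
    # The trunk points are shared, so they are attached once per batch.
    return branch_batch, trunk, target_batch


batch_size = None  # assumed to mean "use the full dataset as one batch"
loader = DataLoader(
    dataset,
    batch_size=len(dataset) if batch_size is None else batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=collate,
)

branch_batch, trunk_batch, target_batch = next(iter(loader))
print(branch_batch.shape, trunk_batch.shape, target_batch.shape)
```

Since torch.stack works on tensors of any rank, the same collate function also handles CNN branch inputs of shape (A, C, H, W).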

train(num_epochs, optimizer=None, optimizer_cfg: dict = None, lr: float = 0.0001, lr_scheduler=None, print_loss: bool = True, print_loss_freq: int = 1, tensorboard_logdir: str = None, save_final_model: bool = False, final_model_path: str = None, checkpoint_dir: str = None, checkpoint_freq: int = 1, resume_from: str = None)[source]

Train the DeepONet model with comprehensive options.

Supports pre-built optimizer instances or configuration dictionaries for construction. Also supports learning-rate adjustment, TensorBoard logging, checkpointing, and resuming.

Parameters:
  • num_epochs (int) – Total number of epochs to train. When resuming, this is the total target, not the number of additional epochs.

  • optimizer (torch.optim.Optimizer, optional) – Pre-built optimizer instance. If provided, used as-is. Default is None.

  • optimizer_cfg (dict, optional) – Configuration used to build the optimizer when optimizer is None. Example: {'name': 'Adam', 'params': {'lr': 1e-4, 'weight_decay': 0}}. Default is None.

  • lr (float, optional) – Default learning rate when building a default optimizer. Default is 1e-4.

  • lr_scheduler (torch.optim.lr_scheduler._LRScheduler or callable, optional) – Optional scheduler to step each epoch. Default is None.

  • print_loss (bool, optional) – Whether to print loss each epoch. Default is True.

  • print_loss_freq (int, optional) – Print loss every N epochs. Default is 1.

  • tensorboard_logdir (str, optional) – Path to write TensorBoard logs. If provided, a SummaryWriter is created. Default is None.

  • save_final_model (bool, optional) – Whether to save final model at the end of training. Default is False.

  • final_model_path (str, optional) – Path to save final model if save_final_model=True. Default is None.

  • checkpoint_dir (str, optional) – Directory to save epoch checkpoints (model+optimizer+epoch). Default is None.

  • checkpoint_freq (int, optional) – Save checkpoint every N epochs. Default is 1.

  • resume_from (str, optional) – Path to checkpoint file to resume from. If provided, loads model/optimizer/epoch. Default is None.

Raises:
  • RuntimeError – If prepare_train_data() has not been called before training.

  • FileNotFoundError – If resume_from path does not exist.

  • ValueError – If optimizer_cfg contains an unknown optimizer name.
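The optimizer options above imply a resolution order: a pre-built optimizer first, then optimizer_cfg, then a default built from lr. One way to express that order is sketched below. The function name build_optimizer, the choice of Adam as the fallback, and the rule that lr fills in a missing 'lr' entry are assumptions, not confirmed behavior of train().

```python
import torch
from torch import nn


def build_optimizer(params, optimizer=None, optimizer_cfg=None, lr=1e-4):
    """Hypothetical helper mirroring the documented optimizer options."""
    if optimizer is not None:
        return optimizer  # a pre-built instance is used as-is
    if optimizer_cfg is None:
        # Assumption: Adam is the fallback when nothing else is given.
        return torch.optim.Adam(params, lr=lr)
    name = optimizer_cfg["name"]
    opt_cls = getattr(torch.optim, name, None)
    if not (isinstance(opt_cls, type) and issubclass(opt_cls, torch.optim.Optimizer)):
        raise ValueError(f"Unknown optimizer name: {name}")
    kwargs = dict(optimizer_cfg.get("params", {}))
    kwargs.setdefault("lr", lr)  # lr applies only if the config omits it
    return opt_cls(params, **kwargs)


net = nn.Linear(2, 1)
sgd = build_optimizer(
    net.parameters(),
    optimizer_cfg={"name": "SGD", "params": {"momentum": 0.9}},
    lr=1e-3,
)
default = build_optimizer(net.parameters())

try:
    build_optimizer(net.parameters(), optimizer_cfg={"name": "NoSuchOptimizer"})
    raised = False
except ValueError:
    raised = True

print(type(sgd).__name__, type(default).__name__, raised)
```

Checking that the looked-up attribute is an Optimizer subclass prevents names such as lr_scheduler, which exist in torch.optim but are not optimizers, from slipping through.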

3.2. DeepCSNet

DeepCSNet - Deep learning for electron-impact cross section prediction.

This module provides a specialized neural network architecture for predicting electron-impact cross sections in plasma physics applications. DeepCSNet employs a novel coefficient-subnet structure that separately processes different input feature types before combining them through tensor operations.

3.2.1. DeepCSNet Classes

  • DeepCSNet: Modular neural network with coefficient-subnet architecture

  • DeepCSNetDataset: Custom dataset class for cross section data handling

  • DeepCSNetModel: High-level wrapper for training and inference

3.2.2. DeepCSNet References

[1] Y. Wang and L. Zhong, “DeepCSNet: a deep learning method for predicting electron-impact doubly differential ionization cross sections,” Plasma Sources Science and Technology, vol. 33, no. 10, p. 105012, 2024.

class ai4plasma.operator.deepcsnet.DeepCSNet(trunk_net: Module, molecule_net: Module = None, energy_net: Module = None)[source]

Bases: Module

Coefficient-subnet neural network for cross section prediction.

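Architecture

The network consists of up to three sub-networks that process different input modalities. The trunk network is required; the molecule and energy networks are optional, but at least one of them must be provided: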

  1. Molecule Net (optional, for multi-molecule scenarios):

    • Processes molecular descriptors

    • Input: [batch_size, molecule_features]

    • Output: [batch_size, molecule_hidden_dim]

  2. Energy Net (optional, for energy-dependent cross sections):

    • Processes incident energy features

    • Input: [batch_size, energy_features]

    • Output: [batch_size, energy_hidden_dim]

  3. Trunk Net (required):

    • Processes output coordinate features

    • Input: [n_points, coordinate_features]

    • Output: [n_points, trunk_hidden_dim]

Operation Modes

  • SMC (Single-Molecule): Energy Net + Trunk Net

  • MMC (Multi-Molecule): Molecule Net + (optional Energy Net) + Trunk Net

molecule_net

Sub-network for processing molecular features. If None, operates in SMC mode.

Type:

torch.nn.Module or None

energy_net

Sub-network for processing energy features. If None in MMC mode, only molecule features are used.

Type:

torch.nn.Module or None

trunk_net

Sub-network for processing coordinate features (output dimensions). Required.

Type:

torch.nn.Module

bias_last

Learnable scalar bias term added to final output. Shape: [1].

Type:

torch.nn.Parameter

module

Operating mode: “SMC” or “MMC”.

Type:

str

Notes

  • At least one of molecule_net or energy_net must be provided

  • Output dimensions of all branch networks must match trunk network output dimension

  • In MMC mode with both networks, outputs are concatenated

forward(trunk_input, molecule_input=None, energy_input=None)[source]

Forward pass through DeepCSNet.

Processes inputs through respective sub-networks and combines outputs via tensor product operations.

Parameters:
  • trunk_input (torch.Tensor) – Input for trunk network (coordinate features). Shape: [n_points, coordinate_dim]. Examples: scattering angles, ejected electron energies.

  • molecule_input (torch.Tensor, optional) – Input for molecule network (molecular features). Shape: [batch_size, molecule_dim]. Required in MMC mode; should be None in SMC mode.

  • energy_input (torch.Tensor, optional) – Input for energy network (energy features). Shape: [batch_size, energy_dim]. Required in SMC mode; optional in MMC mode.

Returns:

Predicted cross section values. Shape: [batch_size, n_points]. Each row holds the cross sections of one case at all output coordinates.

Return type:

torch.Tensor
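The interplay of the optional sub-networks can be sketched in plain PyTorch as follows. This is a hypothetical illustration, not the published architecture or the ai4plasma source. The class name TinyCSNet and its mode attribute are invented for the example. Merging the molecule and energy outputs by concatenation along the feature axis is an assumption taken from the Notes, so in this sketch the coefficient widths must sum to the trunk width.

```python
import torch
from torch import nn


class TinyCSNet(nn.Module):
    """Hypothetical coefficient-subnet sketch, not the ai4plasma class."""

    def __init__(self, trunk_net, molecule_net=None, energy_net=None):
        super().__init__()
        if molecule_net is None and energy_net is None:
            raise ValueError("Provide molecule_net, energy_net, or both.")
        self.trunk_net = trunk_net
        self.molecule_net = molecule_net
        self.energy_net = energy_net
        self.bias_last = nn.Parameter(torch.zeros(1))
        # Without a molecule network the model runs in single-molecule mode.
        self.mode = "SMC" if molecule_net is None else "MMC"

    def forward(self, trunk_input, molecule_input=None, energy_input=None):
        parts = []
        if self.molecule_net is not None:
            parts.append(self.molecule_net(molecule_input))
        if self.energy_net is not None:
            parts.append(self.energy_net(energy_input))
        # Assumption: coefficient outputs are concatenated along features.
        coeff = torch.cat(parts, dim=1)      # (batch_size, p)
        basis = self.trunk_net(trunk_input)  # (n_points, p)
        return torch.einsum("bp,np->bn", coeff, basis) + self.bias_last


torch.manual_seed(0)
coords = torch.randn(30, 2)    # e.g. scattering angle, ejected energy
energy = torch.randn(5, 1)     # incident energy per case
molecule = torch.randn(5, 4)   # molecular descriptors per case

smc = TinyCSNet(nn.Linear(2, 16), energy_net=nn.Linear(1, 16))
mmc = TinyCSNet(nn.Linear(2, 16), molecule_net=nn.Linear(4, 8),
                energy_net=nn.Linear(1, 8))

out_smc = smc(coords, energy_input=energy)
out_mmc = mmc(coords, molecule_input=molecule, energy_input=energy)
print(smc.mode, out_smc.shape, mmc.mode, out_mmc.shape)
```

In both modes the result has shape [batch_size, n_points], matching the documented return value.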

class ai4plasma.operator.deepcsnet.DeepCSNetDataset(trunk_inputs: Tensor, molecule_inputs: Tensor = None, energy_inputs: Tensor = None, targets: Tensor = None, split_by_mole: bool = True)[source]

Bases: Dataset

Custom PyTorch Dataset for DeepCSNet training and inference.

This dataset class handles the specialized data structure required by DeepCSNet, where different input modalities (molecular features, energy features, coordinate features) need to be organized and batched appropriately for the coefficient-subnet architecture.

Overview

Unlike standard datasets where each sample is independent, DeepCSNet requires:

  • Branch inputs (molecule/energy): vary per case, shape [batch_size, features]

  • Trunk inputs (coordinates): shared across all cases, shape [n_points, coord_dim]

  • Targets: cross section values, shape [batch_size, n_points]

This dataset supports two splitting strategies:

  1. split_by_mole=True: Iterate over cases (molecules/energies)

  2. split_by_mole=False: Iterate over coordinates

molecule_inputs

Molecular feature tensor, shape [n_cases, molecule_dim].

Type:

torch.Tensor or None

energy_inputs

Energy feature tensor, shape [n_cases, energy_dim].

Type:

torch.Tensor or None

trunk_inputs

Coordinate feature tensor, shape [n_points, coordinate_dim].

Type:

torch.Tensor

targets

Target cross section tensor, shape [n_cases, n_points].

Type:

torch.Tensor or None

split_by_mole

Splitting strategy flag (True: by cases, False: by coordinates).

Type:

bool

class ai4plasma.operator.deepcsnet.DeepCSNetModel(network: Module)[source]

Bases: BaseModel

High-level training and inference wrapper for DeepCSNet.

This class provides a complete pipeline for working with DeepCSNet, including: data preparation and batching, training loop, model checkpointing, and inference.

network

The underlying DeepCSNet architecture.

Type:

torch.nn.Module

dataset

Custom dataset instance (set by prepare_train_data).

Type:

DeepCSNetDataset

dataloader

PyTorch DataLoader with custom collate function.

Type:

torch.utils.data.DataLoader

loss_func

Loss function (defaults to MSELoss).

Type:

callable

optimizer

Optimizer for gradient-based training.

Type:

torch.optim.Optimizer

writer

TensorBoard writer for logging.

Type:

torch.utils.tensorboard.SummaryWriter or None

checkpoint_dir

Directory for saving checkpoints.

Type:

str or None

start_epoch

Starting epoch for resuming training.

Type:

int

calc_loss(data)[source]

Calculate the training loss for a given batch.

Performs forward pass and computes loss using configured loss function.

Parameters:

data (tuple) –

A 4-tuple (molecule_inputs, energy_inputs, trunk_inputs, targets):

  • molecule_inputs: [batch_size, molecule_dim] or None

  • energy_inputs: [batch_size, energy_dim] or None

  • trunk_inputs: [n_points, coordinate_dim]

  • targets: [batch_size, n_points]

Returns:

Scalar loss value from loss_func(predictions, targets).

Return type:

torch.Tensor

predict(trunk_input, molecule_input=None, energy_input=None)[source]

Perform inference using the trained DeepCSNet model.

Sets the network to evaluation mode and performs a forward pass.

Parameters:
  • trunk_input (torch.Tensor) – Coordinate features for output dimensions. Shape: [n_points, coordinate_dim]

  • molecule_input (torch.Tensor, optional) – Molecular features for cases. Shape: [batch_size, molecule_dim]. Required in MMC mode, should be None in SMC mode.

  • energy_input (torch.Tensor, optional) – Energy features for cases. Shape: [batch_size, energy_dim]. Required in SMC mode, optional in MMC mode.

Returns:

Predicted cross section values. Shape: [batch_size, n_points]

Return type:

torch.Tensor

prepare_train_data(trunk_inputs: Tensor, molecule_inputs: Tensor = None, energy_inputs: Tensor = None, targets: Tensor = None, split_by_mole: bool = True, batch_size: int = None, shuffle: bool = False, drop_last: bool = False)[source]

Prepare training data and create a DataLoader.

Converts raw tensor data into a DeepCSNetDataset and wraps it with a DataLoader that uses a custom collate function.

Parameters:
  • trunk_inputs (torch.Tensor) – Coordinate features (output dimensions). Shape: [n_points, coordinate_dim]

  • molecule_inputs (torch.Tensor, optional) – Molecular features. Shape: [n_cases, molecule_dim]. Default is None.

  • energy_inputs (torch.Tensor, optional) – Energy features. Shape: [n_cases, energy_dim]. Default is None.

  • targets (torch.Tensor, optional) – Target cross section values. Shape: [n_cases, n_points]. Default is None.

  • split_by_mole (bool, default=True) –

    Splitting strategy:

    • True: Iterate over cases; trunk_inputs are replicated

    • False: Iterate over coordinates; branch inputs are replicated

  • batch_size (int, optional) – Number of samples per batch. If None, uses full dataset. Default is None.

  • shuffle (bool, default=False) – Whether to shuffle the dataset at the beginning of each epoch.

  • drop_last (bool, default=False) – Whether to drop the last incomplete batch.
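Because molecule_inputs or energy_inputs may be None, PyTorch's default collation cannot be used directly, which is one reason a custom collate function is needed. The sketch below shows one way to pass None through unchanged while splitting by case. The names CaseDataset and collate_cases and the per-item layout are hypothetical, not the library's.

```python
from functools import partial

import torch
from torch.utils.data import DataLoader, Dataset


class CaseDataset(Dataset):
    """Hypothetical per-case dataset; either branch input may be None."""

    def __init__(self, molecule_inputs, energy_inputs, targets):
        self.molecule_inputs = molecule_inputs
        self.energy_inputs = energy_inputs
        self.targets = targets

    def __len__(self):
        return self.targets.shape[0]

    def __getitem__(self, i):
        mol = None if self.molecule_inputs is None else self.molecule_inputs[i]
        eng = None if self.energy_inputs is None else self.energy_inputs[i]
        return mol, eng, self.targets[i]


def collate_cases(samples, trunk_inputs):
    def stack_or_none(rows):
        # A missing modality stays None instead of being stacked.
        return None if rows[0] is None else torch.stack(rows)

    molecule = stack_or_none([s[0] for s in samples])
    energy = stack_or_none([s[1] for s in samples])
    targets = torch.stack([s[2] for s in samples])
    # The shared coordinate grid is attached once per batch.
    return molecule, energy, trunk_inputs, targets


n_cases, n_points = 7, 20
trunk = torch.randn(n_points, 2)
dataset = CaseDataset(None, torch.randn(n_cases, 1), torch.randn(n_cases, n_points))
collate = partial(collate_cases, trunk_inputs=trunk)

loader = DataLoader(dataset, batch_size=3, drop_last=False, collate_fn=collate)
loader_drop = DataLoader(dataset, batch_size=3, drop_last=True, collate_fn=collate)

molecule_b, energy_b, trunk_b, targets_b = next(iter(loader))
print(len(loader), len(loader_drop), molecule_b, energy_b.shape, targets_b.shape)
```

With 7 cases and a batch size of 3, keeping the last incomplete batch gives 3 batches, while drop_last=True gives 2.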

train(num_epochs: int, optimizer: Optimizer = None, optimizer_cfg: dict = None, lr: float = 0.0001, lr_scheduler: _LRScheduler = None, print_loss: bool = True, print_loss_freq: int = 1, tensorboard_logdir: str = None, save_final_model: bool = True, final_model_path: str = None, checkpoint_dir: str = None, checkpoint_freq: int = 1, resume_from: str = None)[source]

Train the DeepCSNet model with comprehensive configuration.

Implements a complete training pipeline with flexible optimizer configuration, learning rate scheduling, checkpoint management, and TensorBoard logging.

Parameters:
  • num_epochs (int) – Total number of training epochs. When resuming, this is the total target epoch count, not the number of additional epochs. Example: resuming from epoch 500 with num_epochs=1000 trains for 500 more epochs (500→1000).

  • optimizer (torch.optim.Optimizer, optional) – Pre-configured optimizer instance. If provided, used directly. If None, created based on optimizer_cfg or lr. Default is None.

  • optimizer_cfg (dict, optional) – Configuration dictionary for building the optimizer when optimizer=None. Format: {'name': 'OptimizerName', 'params': {param_dict}}. Example: {'name': 'Adam', 'params': {'lr': 1e-4, 'weight_decay': 1e-5}}. Supported names: any optimizer in torch.optim (Adam, SGD, AdamW, etc.). Default is None.

  • lr (float, default=1e-4) – Default learning rate when creating the optimizer automatically. Used only if optimizer=None and 'lr' is not in optimizer_cfg['params']. Typical range: 1e-5 to 1e-3.

  • lr_scheduler (torch.optim.lr_scheduler._LRScheduler or callable, optional) – Learning rate scheduler for adaptive adjustment. Can be a PyTorch scheduler or a custom callable. If provided, .step() is called after each epoch. Default is None.

  • print_loss (bool, default=True) – Whether to print loss to console during training.

  • print_loss_freq (int, default=1) – Frequency (in epochs) for printing loss. Example: print_loss_freq=10 prints every 10 epochs.

  • tensorboard_logdir (str, optional) – Directory path for TensorBoard log files. If provided, a SummaryWriter is created and the loss is logged each epoch. View logs with: tensorboard --logdir=<tensorboard_logdir>. Default is None.

  • save_final_model (bool, default=True) – Whether to save the final trained model after all epochs.

  • final_model_path (str, optional) – File path for saving the final model if save_final_model=True. If None and checkpoint_dir is set, the model is saved to <checkpoint_dir>/final_model.pth; if None and checkpoint_dir is None, it is saved to './final_model.pth'. Default is None.

  • checkpoint_dir (str, optional) – Directory path for saving periodic training checkpoints. If provided, checkpoints named 'checkpoint_epoch_N.pth' are saved, which allows recovery from an interrupted training run. Default is None.

  • checkpoint_freq (int, default=1) – Frequency (in epochs) for saving checkpoints. Example: checkpoint_freq=100 saves every 100 epochs.

  • resume_from (str, optional) – File path to checkpoint for resuming training. If provided, loads model state, optimizer state, and epoch number. Default is None.

Raises:
  • RuntimeError – If prepare_train_data() has not been called.

  • ValueError – If optimizer_cfg specifies unknown optimizer name.

  • FileNotFoundError – If resume_from path does not exist.
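The checkpoint and resume options above can be illustrated with a small stand-alone training loop. This is a sketch under stated assumptions rather than the library's code: the checkpoint dictionary keys, the helper names, and storing the number of completed epochs in the checkpoint are all hypothetical. Only the file name pattern checkpoint_epoch_N.pth and the "total, not additional" meaning of num_epochs are taken from the parameter descriptions.

```python
import os
import tempfile

import torch
from torch import nn


def save_checkpoint(path, network, optimizer, epoch):
    torch.save(
        {
            "model_state_dict": network.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch,  # assumed: number of completed epochs
        },
        path,
    )


def load_checkpoint(path, network, optimizer):
    if not os.path.exists(path):
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    state = torch.load(path, map_location="cpu")
    network.load_state_dict(state["model_state_dict"])
    optimizer.load_state_dict(state["optimizer_state_dict"])
    return state["epoch"]


def train(network, optimizer, x, y, num_epochs, checkpoint_dir,
          checkpoint_freq=1, resume_from=None):
    start_epoch = 0
    if resume_from is not None:
        start_epoch = load_checkpoint(resume_from, network, optimizer)
    loss_func = nn.MSELoss()
    epochs_run = 0
    # num_epochs is the total target, so a resumed run covers the remainder.
    for epoch in range(start_epoch, num_epochs):
        optimizer.zero_grad()
        loss = loss_func(network(x), y)
        loss.backward()
        optimizer.step()
        epochs_run += 1
        if (epoch + 1) % checkpoint_freq == 0:
            name = f"checkpoint_epoch_{epoch + 1}.pth"
            save_checkpoint(os.path.join(checkpoint_dir, name),
                            network, optimizer, epoch + 1)
    return epochs_run


torch.manual_seed(0)
net = nn.Linear(2, 1)
opt = torch.optim.SGD(net.parameters(), lr=1e-2)
x, y = torch.randn(16, 2), torch.randn(16, 1)

with tempfile.TemporaryDirectory() as ckpt_dir:
    first = train(net, opt, x, y, num_epochs=4, checkpoint_dir=ckpt_dir,
                  checkpoint_freq=2)
    saved = sorted(os.listdir(ckpt_dir))
    resumed = train(net, opt, x, y, num_epochs=5, checkpoint_dir=ckpt_dir,
                    resume_from=os.path.join(ckpt_dir, "checkpoint_epoch_2.pth"))

print(first, saved, resumed)
```

The first run trains 4 epochs and writes checkpoints after epochs 2 and 4. Resuming from the epoch-2 checkpoint with num_epochs=5 then trains only the 3 remaining epochs.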