# Source code for ai4plasma.operator.deepcsnet

"""DeepCSNet - Deep learning for electron-impact cross section prediction.

This module provides a specialized neural network architecture for predicting
electron-impact cross sections in plasma physics applications. DeepCSNet employs
a novel coefficient-subnet structure that separately processes different input
feature types before combining them through tensor operations.

DeepCSNet Classes
-----------------
- `DeepCSNet`: Modular neural network with coefficient-subnet architecture
- `DeepCSNetDataset`: Custom dataset class for cross section data handling
- `DeepCSNetModel`: High-level wrapper for training and inference

DeepCSNet References
--------------------
[1] Y. Wang and L. Zhong, "DeepCSNet: a deep learning method for predicting
    electron-impact doubly differential ionization cross sections,"
    Plasma Sources Science and Technology, vol. 33, no. 10, p. 105012, 2024.
"""

import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from ai4plasma.core.model import BaseModel
from ai4plasma.config import DEVICE



class DeepCSNet(nn.Module):
    """Coefficient-subnet neural network for cross section prediction.

    Architecture
    ------------
    The network consists of up to three sub-networks that process different
    input modalities:

    1. **Molecule Net** (optional, for multi-molecule scenarios):
       - Processes molecular descriptors
       - Input: [batch_size, molecule_features]
       - Output: [batch_size, molecule_hidden_dim]
    2. **Energy Net** (optional, for energy-dependent cross sections):
       - Processes incident energy features
       - Input: [batch_size, energy_features]
       - Output: [batch_size, energy_hidden_dim]
    3. **Trunk Net** (required):
       - Processes output coordinate features
       - Input: [n_points, coordinate_features]
       - Output: [n_points, trunk_hidden_dim]

    The branch output and the trunk output are contracted over the hidden
    dimension (``einsum('bi,ni->bn')``), then a learnable scalar bias is added.

    Operation Modes
    ---------------
    - **SMC (Single-Molecule)**: Energy Net + Trunk Net
    - **MMC (Multi-Molecule)**: Molecule Net + (optional Energy Net) + Trunk Net

    Attributes
    ----------
    molecule_net : torch.nn.Module or None
        Sub-network for processing molecular features. If None, operates in SMC mode.
    energy_net : torch.nn.Module or None
        Sub-network for processing energy features. If None in MMC mode,
        only molecule features are used.
    trunk_net : torch.nn.Module
        Sub-network for processing coordinate features (output dimensions). Required.
    bias_last : torch.nn.Parameter
        Learnable scalar bias term added to final output. Shape: [1].
    module : str
        Operating mode: "SMC" or "MMC".

    Notes
    -----
    - At least one of molecule_net or energy_net must be provided.
    - The branch hidden width must equal ``trunk_hidden_dim``. In MMC mode
      with an energy input, the molecule and energy outputs are concatenated,
      so ``molecule_hidden_dim + energy_hidden_dim`` must equal
      ``trunk_hidden_dim``. Without an energy input, ``molecule_hidden_dim``
      alone must equal it.
    """

    def __init__(self,
                 trunk_net: nn.Module,
                 molecule_net: nn.Module = None,
                 energy_net: nn.Module = None,
                 ):
        """Initialize the DeepCSNet model.

        Parameters
        ----------
        trunk_net : torch.nn.Module
            Sub-network for processing coordinate features (output dimensions).
            Input shape: [n_points, coordinate_dim]
            Output shape: [n_points, trunk_hidden_dim]
        molecule_net : torch.nn.Module, optional
            Sub-network for processing molecular features.
            Input shape: [batch_size, molecule_dim]
            Output shape: [batch_size, molecule_hidden_dim]
            If provided, operates in MMC mode. Default is None.
        energy_net : torch.nn.Module, optional
            Sub-network for processing energy features.
            Input shape: [batch_size, energy_dim]
            Output shape: [batch_size, energy_hidden_dim]
            Required in SMC mode, optional in MMC mode. Default is None.

        Raises
        ------
        ValueError
            If trunk_net is None (trunk network is mandatory).
            If both molecule_net and energy_net are None (at least one required).
        """
        super(DeepCSNet, self).__init__()
        if trunk_net is None:
            raise ValueError("Trunk net must be provided for DeepCSNet.")
        if molecule_net is None and energy_net is None:
            raise ValueError("At least one of molecule_net or energy_net must be provided for DeepCSNet.")

        self.molecule_net = molecule_net
        self.energy_net = energy_net
        self.trunk_net = trunk_net
        self.bias_last = nn.Parameter(torch.zeros(1))  # Initialize bias term to zero
        # The mode is fixed at construction: a molecule net implies multi-molecule (MMC).
        self.module = "SMC" if molecule_net is None else "MMC"

    def forward(self, trunk_input, molecule_input=None, energy_input=None):
        """Forward pass through DeepCSNet.

        Parameters
        ----------
        trunk_input : torch.Tensor
            Input for trunk network (coordinate features). Shape: [n_points, coordinate_dim]
        molecule_input : torch.Tensor, optional
            Input for molecule network. Shape: [batch_size, molecule_dim].
            Required in MMC mode, ignored in SMC mode.
        energy_input : torch.Tensor, optional
            Input for energy network. Shape: [batch_size, energy_dim].
            Required in SMC mode, optional in MMC mode (only if energy_net exists).

        Returns
        -------
        torch.Tensor
            Predicted cross section values. Shape: [batch_size, n_points].

        Raises
        ------
        ValueError
            If the input required by the current mode is missing, or an
            energy_input is given to an MMC network built without energy_net.
        """
        trunk_output = self.trunk_net(trunk_input)

        if self.module == "SMC":
            if energy_input is None:
                raise ValueError("energy_input is required in SMC mode.")
            branch = self.energy_net(energy_input)
        else:  # MMC
            if molecule_input is None:
                raise ValueError("molecule_input is required in MMC mode.")
            branch = self.molecule_net(molecule_input)
            if energy_input is not None:
                if self.energy_net is None:
                    raise ValueError("energy_input was given but this DeepCSNet has no energy_net.")
                # Molecule and energy coefficients are stacked along the hidden axis.
                branch = torch.cat((branch, self.energy_net(energy_input)), dim=1)

        # Contract branch coefficients with trunk basis over the hidden dimension.
        return torch.einsum('bi,ni->bn', branch, trunk_output) + self.bias_last
class DeepCSNetDataset(Dataset):
    """Custom PyTorch Dataset for DeepCSNet training and inference.

    DeepCSNet uses separate input modalities: branch inputs (molecule and/or
    energy features) that vary per case, and trunk inputs (coordinates) that
    are shared across cases. This dataset indexes one of these two axes.

    Overview
    --------
    - Branch inputs (molecule/energy): shape [n_cases, features]
    - Trunk inputs (coordinates): shape [n_points, coord_dim]
    - Targets: cross section values, shape [n_cases, n_points] (optional)

    Two splitting strategies are supported:

    1. ``split_by_mole=True``: iterate over cases (molecules/energies)
    2. ``split_by_mole=False``: iterate over coordinates

    Attributes
    ----------
    molecule_inputs : torch.Tensor or None
        Molecular feature tensor, shape [n_cases, molecule_dim].
    energy_inputs : torch.Tensor or None
        Energy feature tensor, shape [n_cases, energy_dim].
    trunk_inputs : torch.Tensor
        Coordinate feature tensor, shape [n_points, coordinate_dim].
    targets : torch.Tensor or None
        Target cross section tensor, shape [n_cases, n_points].
    split_by_mole : bool
        Splitting strategy flag (True: by cases, False: by coordinates).
    """

    def __init__(self,
                 trunk_inputs: torch.Tensor,
                 molecule_inputs: torch.Tensor = None,
                 energy_inputs: torch.Tensor = None,
                 targets: torch.Tensor = None,
                 split_by_mole: bool = True
                 ):
        """Initialize the DeepCSNet dataset.

        Parameters
        ----------
        trunk_inputs : torch.Tensor
            Coordinate features (output dimensions). Shape: [n_points, coordinate_dim]
        molecule_inputs : torch.Tensor, optional
            Molecular features tensor. Shape: [n_cases, molecule_dim]. Default is None.
        energy_inputs : torch.Tensor, optional
            Energy features tensor. Shape: [n_cases, energy_dim]. Default is None.
        targets : torch.Tensor, optional
            Target cross section values. Shape: [n_cases, n_points]. Default is None,
            in which case the target slot of every sample is None.
        split_by_mole : bool, default=True
            - True: split by case index (iterate over molecules/energies)
            - False: split by coordinate index (iterate over output points)

        Raises
        ------
        ValueError
            If both molecule_inputs and energy_inputs are None.
        """
        if molecule_inputs is None and energy_inputs is None:
            raise ValueError("At least one of molecule_inputs or energy_inputs must be provided for DeepCSNetDataset.")

        self.molecule_inputs = molecule_inputs
        self.energy_inputs = energy_inputs
        self.trunk_inputs = trunk_inputs
        self.targets = targets
        self.split_by_mole = split_by_mole

    def __len__(self):
        """Return the number of cases (split_by_mole=True) or output coordinates (False)."""
        if not self.split_by_mole:
            return len(self.trunk_inputs)
        # Prefer energy_inputs for the case count, falling back to molecule_inputs.
        branch = self.energy_inputs if self.energy_inputs is not None else self.molecule_inputs
        return len(branch)

    def __getitem__(self, idx):
        """Retrieve a single sample.

        Parameters
        ----------
        idx : int
            Case index (split_by_mole=True) or coordinate index (False).

        Returns
        -------
        tuple
            (molecule_input, energy_input, trunk_input, target).

            When split_by_mole=True:
              - molecule_input: tensor [molecule_dim] or None
              - energy_input: tensor [energy_dim] or None
              - trunk_input: tensor [n_points, coordinate_dim] (shared)
              - target: tensor [n_points] or None

            When split_by_mole=False:
              - molecule_input: tensor [n_cases, molecule_dim] or None (shared)
              - energy_input: tensor [n_cases, energy_dim] or None (shared)
              - trunk_input: tensor [coordinate_dim]
              - target: tensor [n_cases] or None
        """
        has_targets = self.targets is not None

        if not self.split_by_mole:
            # Split by coordinate: branch inputs are returned whole.
            target = self.targets[:, idx] if has_targets else None
            return self.molecule_inputs, self.energy_inputs, self.trunk_inputs[idx, :], target

        # Split by case: the trunk input is returned whole.
        molecule = self.molecule_inputs[idx, :] if self.molecule_inputs is not None else None
        energy = self.energy_inputs[idx, :] if self.energy_inputs is not None else None
        target = self.targets[idx, :] if has_targets else None
        return molecule, energy, self.trunk_inputs, target
class DeepCSNetModel(BaseModel):
    """High-level training and inference wrapper for DeepCSNet.

    Provides data preparation and batching, a training loop, checkpointing,
    and inference.

    Attributes
    ----------
    network : torch.nn.Module
        The underlying DeepCSNet architecture.
    dataset : DeepCSNetDataset
        Custom dataset instance (set by prepare_train_data).
    dataloader : torch.utils.data.DataLoader
        PyTorch DataLoader with a split-aware collate function.
    loss_func : callable
        Loss function (defaults to MSELoss).
    optimizer : torch.optim.Optimizer
        Optimizer for gradient-based training.
    writer : torch.utils.tensorboard.SummaryWriter or None
        TensorBoard writer for logging (only set during train()).
    checkpoint_dir : str or None
        Directory for saving checkpoints.
    start_epoch : int
        Starting epoch for resuming training.
    """

    def __init__(self, network: nn.Module):
        """Initialize the DeepCSNetModel.

        Parameters
        ----------
        network : torch.nn.Module
            The DeepCSNet architecture to be trained.
        """
        super().__init__(network=network)

        # defaults for optional features
        self.writer = None
        self.checkpoint_dir = None
        self.start_epoch = 0

    def prepare_train_data(self,
                           trunk_inputs: torch.Tensor,
                           molecule_inputs: torch.Tensor = None,
                           energy_inputs: torch.Tensor = None,
                           targets: torch.Tensor = None,
                           split_by_mole: bool = True,
                           batch_size: int = None,
                           shuffle: bool = False,
                           drop_last: bool = False,
                           ):
        """Prepare training data and create a DataLoader.

        Parameters
        ----------
        trunk_inputs : torch.Tensor
            Coordinate features. Shape: [n_points, coordinate_dim]
        molecule_inputs : torch.Tensor, optional
            Molecular features. Shape: [n_cases, molecule_dim]. Default is None.
        energy_inputs : torch.Tensor, optional
            Energy features. Shape: [n_cases, energy_dim]. Default is None.
        targets : torch.Tensor
            Target cross section values. Shape: [n_cases, n_points].
            Required for training (kept optional in the signature for compatibility).
        split_by_mole : bool, default=True
            - True: batch over cases; every batch shares the full trunk_inputs.
            - False: batch over coordinates; every batch shares the full branch inputs.
        batch_size : int, optional
            Number of samples per batch. If None, uses the full dataset.
        shuffle : bool, default=False
            Whether to shuffle the dataset at the beginning of each epoch.
        drop_last : bool, default=False
            Whether to drop the last incomplete batch.

        Raises
        ------
        ValueError
            If targets is None, or if both molecule_inputs and energy_inputs are None.
        """
        if targets is None:
            raise ValueError("targets must be provided to prepare training data.")

        # Move data to the configured device (e.g., GPU)
        if molecule_inputs is not None:
            molecule_inputs = molecule_inputs.to(DEVICE())
        if energy_inputs is not None:
            energy_inputs = energy_inputs.to(DEVICE())
        trunk_inputs = trunk_inputs.to(DEVICE())
        targets = targets.to(DEVICE())

        self.dataset = DeepCSNetDataset(trunk_inputs=trunk_inputs,
                                        molecule_inputs=molecule_inputs,
                                        energy_inputs=energy_inputs,
                                        targets=targets,
                                        split_by_mole=split_by_mole)

        # Set batch size to the full dataset size if not specified
        batch_size = len(self.dataset) if batch_size is None else batch_size

        # The default collate would stack the shared tensor too, so each split
        # mode gets its own collate that keeps the shared part unbatched.
        collate_fn = self._collate_by_case if split_by_mole else self._collate_by_point

        self.dataloader = DataLoader(dataset=self.dataset,
                                     shuffle=shuffle,
                                     batch_size=batch_size,
                                     drop_last=drop_last,
                                     collate_fn=collate_fn)

    @staticmethod
    def _collate_by_case(batch):
        """Collate case-indexed samples (split_by_mole=True).

        Returns (molecule [B, molecule_dim] or None, energy [B, energy_dim] or None,
        trunk [n_points, coordinate_dim] shared, targets [B, n_points]).
        """
        molecules = [item[0] for item in batch]
        energies = [item[1] for item in batch]
        molecule_batch = torch.stack(molecules) if molecules[0] is not None else None
        energy_batch = torch.stack(energies) if energies[0] is not None else None
        trunk_batch = batch[0][2]  # identical for every item in the batch
        target_batch = torch.stack([item[3] for item in batch])
        return molecule_batch, energy_batch, trunk_batch, target_batch

    @staticmethod
    def _collate_by_point(batch):
        """Collate coordinate-indexed samples (split_by_mole=False).

        Returns (molecule [n_cases, molecule_dim] or None shared,
        energy [n_cases, energy_dim] or None shared, trunk [B, coordinate_dim],
        targets [n_cases, B]) so that targets match the network output layout.
        """
        molecule_batch = batch[0][0]  # identical for every item in the batch
        energy_batch = batch[0][1]
        trunk_batch = torch.stack([item[2] for item in batch])
        # Each item carries a [n_cases] column; stack along dim=1 -> [n_cases, B].
        target_batch = torch.stack([item[3] for item in batch], dim=1)
        return molecule_batch, energy_batch, trunk_batch, target_batch

    def calc_loss(self, data):
        """Calculate the training loss for a batch.

        Parameters
        ----------
        data : tuple
            (molecule_inputs, energy_inputs, trunk_inputs, targets) as produced
            by the DataLoader.

        Returns
        -------
        torch.Tensor
            Scalar loss value from loss_func(predictions, targets).
        """
        molecule_inputs, energy_inputs, trunk_inputs, targets = data
        predictions = self.network(trunk_input=trunk_inputs,
                                   molecule_input=molecule_inputs,
                                   energy_input=energy_inputs)
        return self.loss_func(predictions, targets)

    def predict(self, trunk_input, molecule_input=None, energy_input=None):
        """Perform inference using the trained DeepCSNet model.

        Puts the network in evaluation mode (it stays there; train() switches
        it back) and runs the forward pass without tracking gradients.

        Parameters
        ----------
        trunk_input : torch.Tensor
            Coordinate features. Shape: [n_points, coordinate_dim]
        molecule_input : torch.Tensor, optional
            Molecular features. Shape: [batch_size, molecule_dim]. Required in MMC mode.
        energy_input : torch.Tensor, optional
            Energy features. Shape: [batch_size, energy_dim]. Required in SMC mode.

        Returns
        -------
        torch.Tensor
            Predicted cross section values. Shape: [batch_size, n_points]
        """
        self.network.eval()
        with torch.no_grad():
            return self.network(trunk_input=trunk_input, molecule_input=molecule_input, energy_input=energy_input)

    def _build_optimizer(self, optimizer_cfg, lr):
        """Create an optimizer from an optional config dict, defaulting to Adam(lr)."""
        if not isinstance(optimizer_cfg, dict):
            return torch.optim.Adam(self.network.parameters(), lr=lr)

        name = optimizer_cfg.get('name', 'Adam')
        # Copy so that injecting the default lr never mutates the caller's config.
        params = dict(optimizer_cfg.get('params', {}))
        if 'lr' not in params and lr is not None:
            params['lr'] = lr
        opt_cls = getattr(torch.optim, name, None)
        if not (isinstance(opt_cls, type) and issubclass(opt_cls, torch.optim.Optimizer)):
            raise ValueError(f'Unknown optimizer name: {name}')
        return opt_cls(self.network.parameters(), **params)

    def _resume_from_checkpoint(self, resume_from):
        """Load model/optimizer state and the next epoch index from a checkpoint file."""
        if not os.path.exists(resume_from):
            raise FileNotFoundError(f'Resume checkpoint not found: {resume_from}')
        ckpt = torch.load(resume_from, map_location=DEVICE())
        if 'model' in ckpt:
            self.network.load_state_dict(ckpt['model'])
        if 'optimizer' in ckpt and self.optimizer is not None:
            try:
                self.optimizer.load_state_dict(ckpt['optimizer'])
            except Exception:
                # Deliberate best-effort: a different optimizer config is allowed on resume.
                print('Failed to load optimizer state (incompatible). Continuing with fresh optimizer state')
        if 'epoch' in ckpt:
            self.start_epoch = ckpt['epoch'] + 1
        print(f'Resuming from checkpoint {resume_from} at epoch {self.start_epoch}')

    @staticmethod
    def _step_lr_scheduler(lr_scheduler, epoch, avg_loss):
        """Advance a scheduler once per epoch.

        ReduceLROnPlateau receives the epoch loss. Objects with ``step()`` are
        stepped. Plain callables are called with the epoch index.
        """
        if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            lr_scheduler.step(avg_loss)
        elif hasattr(lr_scheduler, 'step'):
            lr_scheduler.step()
        else:
            lr_scheduler(epoch)

    def train(self,
              num_epochs: int,
              optimizer: torch.optim.Optimizer = None,
              optimizer_cfg: dict = None,
              lr: float = 1e-4,
              lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
              print_loss: bool = True,
              print_loss_freq: int = 1,
              tensorboard_logdir: str = None,
              save_final_model: bool = True,
              final_model_path: str = None,
              checkpoint_dir: str = None,
              checkpoint_freq: int = 1,
              resume_from: str = None,
              ):
        """Train the DeepCSNet model.

        Parameters
        ----------
        num_epochs : int
            TOTAL target number of epochs (not additional). When resuming from a
            checkpoint saved at epoch index 499 (i.e. 500 epochs done) with
            num_epochs=1000, training continues for the remaining 500 epochs.
        optimizer : torch.optim.Optimizer, optional
            Pre-configured optimizer instance, used directly when provided.
        optimizer_cfg : dict, optional
            ``{'name': 'OptimizerName', 'params': {...}}`` used when optimizer is None.
            The dict is not modified. Any torch.optim optimizer name is accepted.
        lr : float, default=1e-4
            Learning rate used when the optimizer is built here and no 'lr'
            is given in optimizer_cfg['params'].
        lr_scheduler : scheduler or callable, optional
            Stepped once per epoch after the epoch loss is computed.
            ReduceLROnPlateau is passed the average epoch loss. Other objects
            with ``step()`` are stepped without arguments. Plain callables are
            called with the 0-based epoch index.
        print_loss : bool, default=True
            Whether to print loss to console during training.
        print_loss_freq : int, default=1
            Frequency (in epochs) for printing loss.
        tensorboard_logdir : str, optional
            Directory for TensorBoard logs. If provided, loss is logged each epoch.
        save_final_model : bool, default=True
            Whether to save the final trained model after all epochs.
        final_model_path : str, optional
            Path for the final model. Defaults to <checkpoint_dir>/final_model.pth
            if a checkpoint directory is set, otherwise './final_model.pth'.
        checkpoint_dir : str, optional
            Directory for periodic checkpoints named 'checkpoint_epoch_N.pth'.
        checkpoint_freq : int, default=1
            Frequency (in epochs) for saving checkpoints.
        resume_from : str, optional
            Checkpoint path to resume model state, optimizer state and epoch from.

        Raises
        ------
        RuntimeError
            If prepare_train_data() has not been called.
        ValueError
            If optimizer_cfg specifies an unknown optimizer name.
        FileNotFoundError
            If resume_from path does not exist.
        """
        if getattr(self, 'dataloader', None) is None:
            raise RuntimeError("Call prepare_train_data(...) before train(...)")

        # Set default loss if not set
        if getattr(self, 'loss_func', None) is None:
            self.set_loss_func(nn.MSELoss())

        if tensorboard_logdir:
            os.makedirs(tensorboard_logdir, exist_ok=True)
            self.writer = SummaryWriter(tensorboard_logdir)
            print(f'Tensorboard writer created at: {tensorboard_logdir}')

        if optimizer is not None:
            self.set_optimizer(optimizer)
        else:
            self.set_optimizer(self._build_optimizer(optimizer_cfg, lr))

        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
            self.checkpoint_dir = checkpoint_dir

        self.start_epoch = 0
        if resume_from:
            self._resume_from_checkpoint(resume_from)

        # num_epochs is the total target; never go below where the checkpoint left off.
        total_epochs = max(num_epochs, self.start_epoch)

        # predict() leaves the network in eval mode; make sure dropout/batch-norm train properly.
        self.network.train()

        try:
            for epoch in range(self.start_epoch, total_epochs):
                epoch_loss = 0.0
                batch_count = 0
                for batch in self.dataloader:
                    loss = self.calc_loss(batch)

                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    epoch_loss += loss.item()
                    batch_count += 1

                avg_loss = epoch_loss / batch_count if batch_count > 0 else float('nan')

                if lr_scheduler is not None:
                    self._step_lr_scheduler(lr_scheduler, epoch, avg_loss)

                if print_loss and print_loss_freq > 0 and (epoch + 1) % print_loss_freq == 0:
                    print(f'Epoch [{epoch+1}/{total_epochs}], Loss: {avg_loss:g}')
                if self.writer is not None:
                    self.writer.add_scalar('Loss', avg_loss, epoch + 1)
                    self.writer.flush()

                if self.checkpoint_dir is not None and checkpoint_freq > 0 and (epoch + 1) % checkpoint_freq == 0:
                    ckpt_file = os.path.join(self.checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
                    torch.save({'model': self.network.state_dict(),
                                'optimizer': self.optimizer.state_dict(),
                                'epoch': epoch}, ckpt_file)
                    print(f'Checkpoint saved: {ckpt_file}')

            if save_final_model:
                if not final_model_path:
                    if self.checkpoint_dir is not None:
                        final_model_path = os.path.join(self.checkpoint_dir, 'final_model.pth')
                    else:
                        final_model_path = 'final_model.pth'
                torch.save({'model': self.network.state_dict(),
                            'optimizer': self.optimizer.state_dict(),
                            'epoch': total_epochs - 1}, final_model_path)
                print(f'Final model saved to: {final_model_path}')
        finally:
            # Always release the log file, even if training raised.
            if self.writer is not None:
                self.writer.close()
                self.writer = None