# Source code for ai4plasma.operator.deeponet

"""DeepONet (Deep Operator Network) implementation for learning nonlinear operators.

This module provides a comprehensive framework for training and deploying DeepONet models,
which are neural networks designed to learn nonlinear operators mapping between
infinite-dimensional function spaces.

DeepONet Classes
----------------
- `DeepONet`: Core neural network architecture with automatic branch type detection
- `DeepONetDataset`: PyTorch Dataset supporting both 2D (FNN) and 4D (CNN) inputs
- `DeepONetModel`: High-level training wrapper with checkpointing and TensorBoard
"""

import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from ai4plasma.core.model import BaseModel
from ai4plasma.config import DEVICE



class DeepONet(nn.Module):
    """Operator-learning network made of a branch and a trunk sub-network.

    The branch network encodes the input functions and the trunk network
    encodes the query coordinates. Each prediction is the dot product of a
    branch latent vector with a trunk latent vector, plus one learnable bias.

    Attributes
    ----------
    branch_net : torch.nn.Module
        Sub-network applied to ``branch_inputs`` (e.g. an FNN on 2D data or a
        CNN on 4D data).
    trunk_net : torch.nn.Module
        Sub-network applied to ``trunk_inputs``.
    bias_last : torch.nn.Parameter
        Bias of shape ``(1,)``, initialised to zero and added to every output.
    branch_is_cnn : bool
        True when the branch network's class name is ``CNN`` (case-insensitive).
    """

    def __init__(self, branch_net, trunk_net):
        """Store the two sub-networks and create the output bias.

        Parameters
        ----------
        branch_net : torch.nn.Module
            Network that encodes the input functions (FNN for 2D input,
            CNN for 4D input).
        trunk_net : torch.nn.Module
            Network that encodes the evaluation coordinates.
        """
        super().__init__()
        self.branch_net = branch_net
        self.trunk_net = trunk_net
        # A single scalar offset that starts at zero.
        self.bias_last = nn.Parameter(torch.zeros(1))
        self.branch_is_cnn = self._detect_cnn_branch()

    def _detect_cnn_branch(self):
        """Report whether the branch network is a CNN.

        The check is purely name-based. It succeeds only when the branch
        network's class is literally named ``CNN`` (in any letter case).

        Returns
        -------
        bool
            True for a branch class named ``CNN``; False otherwise.
        """
        class_name = type(self.branch_net).__name__
        return class_name.upper() == 'CNN'

    def forward(self, branch_inputs, trunk_inputs):
        """Evaluate the operator for every (sample, point) pair.

        Parameters
        ----------
        branch_inputs : torch.Tensor
            Input to ``branch_net``. For an FNN branch this is
            ``(batch_size, features)``; for a CNN branch it is
            ``(batch_size, channels, height, width)``.
        trunk_inputs : torch.Tensor
            Input to ``trunk_net``, ``(num_points, features)``.

        Returns
        -------
        torch.Tensor
            Predictions of shape ``(batch_size, num_points)``.
        """
        branch_features = self.branch_net(branch_inputs)  # (batch_size, p)
        trunk_features = self.trunk_net(trunk_inputs)     # (num_points, p)

        # Contract over the shared latent dimension p.
        prediction = torch.einsum("bi,ni->bn", branch_features, trunk_features)
        # In-place add so the result keeps the einsum output's dtype.
        prediction += self.bias_last
        return prediction


class DeepONetDataset(Dataset):
    """Map-style dataset that serves (branch, trunk, target) triplets.

    Two indexing modes are supported:

    * ``split_by_branch=True``: one item per branch sample. Each item pairs
      that sample with the full set of trunk points and the matching target
      row.
    * ``split_by_branch=False``: one item per trunk point. Each item pairs that
      point with every branch sample and the matching target column.

    Attributes
    ----------
    branch_inputs : torch.Tensor
        Branch data, ``(A, M)`` for an FNN branch or ``(A, C, H, W)`` for a CNN.
    trunk_inputs : torch.Tensor
        Trunk data, ``(B, N)``.
    targets : torch.Tensor
        Targets, ``(A, B)``.
    split_by_branch : bool
        Selects the indexing mode described above.
    is_cnn_input : bool
        True when ``branch_inputs`` is 4D.
    """

    def __init__(self, branch_inputs, trunk_inputs, targets, split_by_branch=True):
        """Store the tensors and record whether the branch input is 4D.

        Parameters
        ----------
        branch_inputs : torch.Tensor
            ``(A, M)`` for FNN branches or ``(A, C, H, W)`` for CNN branches,
            where A is the number of input functions.
        trunk_inputs : torch.Tensor
            ``(B, N)`` evaluation coordinates, where B is the number of points.
        targets : torch.Tensor
            ``(A, B)`` target values.
        split_by_branch : bool, optional
            Index over branch samples (True, default) or over trunk points
            (False).
        """
        self.branch_inputs = branch_inputs
        self.trunk_inputs = trunk_inputs
        self.targets = targets
        self.split_by_branch = split_by_branch
        self.is_cnn_input = branch_inputs.dim() == 4

    def __len__(self):
        """Return A when indexing by branch samples, otherwise B."""
        indexed_tensor = self.branch_inputs if self.split_by_branch else self.trunk_inputs
        return len(indexed_tensor)

    def __getitem__(self, idx):
        """Fetch one (branch_input, trunk_input, target) triplet.

        Parameters
        ----------
        idx : int
            Position along the axis selected by ``split_by_branch``.

        Returns
        -------
        tuple
            With ``split_by_branch=True``: a branch sample of shape ``(M,)``
            (FNN) or ``(C, H, W)`` (CNN), the full ``(B, N)`` trunk tensor, and
            the ``(B,)`` target row. With ``split_by_branch=False``: the full
            branch tensor, the ``(N,)`` trunk point, and the ``(A,)`` target
            column.
        """
        if not self.split_by_branch:
            return self.branch_inputs, self.trunk_inputs[idx, :], self.targets[:, idx]

        branch_sample = (self.branch_inputs[idx] if self.is_cnn_input
                         else self.branch_inputs[idx, :])
        return branch_sample, self.trunk_inputs, self.targets[idx, :]


class DeepONetModel(BaseModel):
    """High-level training and inference wrapper for DeepONet.

    Provides data preparation, loss computation, inference, and a training
    loop with optional checkpointing, TensorBoard logging, learning-rate
    scheduling, and resuming from a checkpoint.

    Attributes
    ----------
    network : torch.nn.Module
        The DeepONet model to be trained.
    writer : torch.utils.tensorboard.SummaryWriter or None
        TensorBoard writer; only non-None while ``train`` is running with a
        ``tensorboard_logdir``.
    checkpoint_dir : str or None
        Directory for saving training checkpoints.
    start_epoch : int
        First epoch index of the most recent ``train`` call (non-zero after a
        resume).
    is_cnn_input : bool
        Set by ``prepare_train_data``; True when the branch input is 4D.
    dataset : DeepONetDataset
        Set by ``prepare_train_data``.
    dataloader : torch.utils.data.DataLoader
        Set by ``prepare_train_data``.
    """

    def __init__(self, network) -> None:
        """Initialize the DeepONetModel.

        Parameters
        ----------
        network : torch.nn.Module
            The DeepONet model to be trained (must be an instance of DeepONet).
        """
        super().__init__(network=network)
        # defaults for optional features
        self.writer = None
        self.checkpoint_dir = None
        self.start_epoch = 0

    # ------------------------------------------------------------------ data

    def prepare_train_data(self, branch_input_data, trunk_input_data, target_data,
                           split_by_branch=True, batch_size=None, shuffle=False,
                           drop_last=False):
        """Move the training data to the device and build a DataLoader.

        Parameters
        ----------
        branch_input_data : torch.Tensor
            Branch inputs: ``(A, M)`` for FNN branches or ``(A, C, H, W)`` for
            CNN branches.
        trunk_input_data : torch.Tensor
            Trunk inputs, ``(B, N)``.
        target_data : torch.Tensor
            Targets, ``(A, B)``.
        split_by_branch : bool, optional
            If True, batches are formed over branch samples. Each batch is
            ``(branch[k], trunk, target[k, :])``.
            If False, batches are formed over trunk points. Each batch is
            ``(branch, trunk[k], target[:, k])``.
            Default is True.
        batch_size : int, optional
            Batch size. If None, the whole dataset is one batch. Default is None.
        shuffle : bool, optional
            Whether to shuffle the data each epoch. Default is False.
        drop_last : bool, optional
            Whether to drop the last incomplete batch. Default is False.
        """
        device = DEVICE()
        branch_input_data = branch_input_data.to(device)
        trunk_input_data = trunk_input_data.to(device)
        target_data = target_data.to(device)

        # 4D branch input means a CNN branch; 2D means an FNN branch.
        self.is_cnn_input = branch_input_data.dim() == 4

        self.dataset = DeepONetDataset(branch_input_data, trunk_input_data,
                                       target_data, split_by_branch)

        if batch_size is None:
            # Full-batch training. Clamp to 1 because DataLoader rejects
            # batch_size=0 (an empty dataset then simply yields no batches).
            batch_size = max(1, len(self.dataset))

        # Each split mode shares a different tensor across items, so each
        # needs its own collate function.
        collate_fn = (self._collate_split_by_branch if split_by_branch
                      else self._collate_split_by_trunk)

        self.dataloader = DataLoader(dataset=self.dataset, shuffle=shuffle,
                                     batch_size=batch_size, drop_last=drop_last,
                                     collate_fn=collate_fn)

    @staticmethod
    def _collate_split_by_branch(batch):
        """Stack per-sample branch inputs and target rows; the trunk is shared.

        Returns ``(branch (k, ...), trunk (B, N), targets (k, B))``.
        """
        branch_inputs = torch.stack([item[0] for item in batch])
        trunk_inputs = batch[0][1]  # identical for every item
        targets = torch.stack([item[2] for item in batch])
        return branch_inputs, trunk_inputs, targets

    @staticmethod
    def _collate_split_by_trunk(batch):
        """Stack per-point trunk inputs and target columns; the branch is shared.

        Returns ``(branch (A, ...), trunk (k, N), targets (A, k))``. This matches
        the ``(batch, points)`` layout that ``DeepONet.forward`` produces.
        """
        branch_inputs = batch[0][0]  # identical for every item
        trunk_inputs = torch.stack([item[1] for item in batch])
        # Each item's target is an (A,) column; put the stacked points on dim 1.
        targets = torch.stack([item[2] for item in batch], dim=1)
        return branch_inputs, trunk_inputs, targets

    # ------------------------------------------------------- loss / predict

    def calc_loss(self, data):
        """Calculate the training loss for a given batch.

        Parameters
        ----------
        data : tuple
            A tuple containing (branch_input, trunk_input, target).

        Returns
        -------
        torch.Tensor
            The computed loss value (scalar).
        """
        branch_input_data, trunk_input_data, target_data = data
        predict_data = self.network(branch_input_data, trunk_input_data)
        loss = self.loss_func(predict_data, target_data)
        return loss

    def predict(self, branch_input_data, trunk_input_data):
        """Run inference with the DeepONet model.

        Switches the network to evaluation mode and runs the forward pass
        under ``torch.no_grad()``. The network is left in evaluation mode;
        ``train`` switches it back to training mode.

        Parameters
        ----------
        branch_input_data : torch.Tensor
            Input data for the branch network.
        trunk_input_data : torch.Tensor
            Input data for the trunk network.

        Returns
        -------
        torch.Tensor
            The predicted output, detached from any autograd graph.
        """
        self.network.eval()
        with torch.no_grad():
            return self.network(branch_input_data, trunk_input_data)

    # -------------------------------------------------------------- training

    def train(self, num_epochs, optimizer=None, optimizer_cfg: dict = None,
              lr: float = 1e-4, lr_scheduler=None, print_loss: bool = True,
              print_loss_freq: int = 1, tensorboard_logdir: str = None,
              save_final_model: bool = False, final_model_path: str = None,
              checkpoint_dir: str = None, checkpoint_freq: int = 1,
              resume_from: str = None):
        """Train the DeepONet model.

        Parameters
        ----------
        num_epochs : int
            Number of epochs to run in this call. When resuming, training
            continues from the checkpoint epoch + 1 for ``num_epochs`` further
            epochs.
        optimizer : torch.optim.Optimizer, optional
            Pre-built optimizer instance, used as-is. Default is None.
        optimizer_cfg : dict, optional
            Used to build the optimizer when ``optimizer`` is None, e.g.
            ``{'name': 'Adam', 'params': {'lr': 1e-4, 'weight_decay': 0}}``.
            The caller's dict is not modified. Default is None.
        lr : float, optional
            Learning rate for the default optimizer. Also used as
            ``params['lr']`` when ``optimizer_cfg`` does not set one.
            Default is 1e-4.
        lr_scheduler : object, optional
            Stepped once per epoch, after the batch loop:
            ``ReduceLROnPlateau`` receives the epoch's average loss; any object
            with a ``step`` method gets ``step()``; any other callable is called
            as ``lr_scheduler(epoch)``. Default is None.
        print_loss : bool, optional
            Whether to print the loss. Default is True.
        print_loss_freq : int, optional
            Print the loss every N epochs. Default is 1.
        tensorboard_logdir : str, optional
            If given, a SummaryWriter is created there. It is closed when
            training ends, even on error. Default is None.
        save_final_model : bool, optional
            Whether to save the model at the end of training. Default is False.
        final_model_path : str, optional
            Path for the final model. If empty, defaults to
            ``<checkpoint_dir>/final_model.pth`` or ``final_model.pth``.
            Default is None.
        checkpoint_dir : str, optional
            Directory for per-epoch checkpoints (model + optimizer + epoch).
            Default is None.
        checkpoint_freq : int, optional
            Save a checkpoint every N epochs. Default is 1.
        resume_from : str, optional
            Checkpoint file to load model/optimizer/epoch from. Default is None.

        Raises
        ------
        RuntimeError
            If prepare_train_data() has not been called before training.
        FileNotFoundError
            If the resume_from path does not exist.
        ValueError
            If optimizer_cfg names something that is not a torch optimizer.
        TypeError
            If lr_scheduler has no ``step`` method and is not callable.
        """
        if getattr(self, 'dataloader', None) is None:
            raise RuntimeError('Call prepare_train_data(...) before train(...)')

        if getattr(self, 'loss_func', None) is None:
            self.set_loss_func(nn.MSELoss())

        if tensorboard_logdir:
            os.makedirs(tensorboard_logdir, exist_ok=True)
            self.writer = SummaryWriter(tensorboard_logdir)
            print(f'Tensorboard writer created at: {tensorboard_logdir}')

        try:
            self.set_optimizer(self._build_optimizer(optimizer, optimizer_cfg, lr))

            if checkpoint_dir:
                os.makedirs(checkpoint_dir, exist_ok=True)
                self.checkpoint_dir = checkpoint_dir

            self.start_epoch = 0
            if resume_from:
                self._resume_from_checkpoint(resume_from)

            # predict() leaves the network in eval mode; training needs train
            # mode so that dropout/batch-norm layers behave correctly.
            self.network.train()

            total_epochs = self.start_epoch + num_epochs
            for epoch in range(self.start_epoch, total_epochs):
                avg_loss = self._train_one_epoch()

                if lr_scheduler is not None:
                    self._step_scheduler(lr_scheduler, epoch, avg_loss)

                if print_loss and print_loss_freq > 0 and (epoch + 1) % print_loss_freq == 0:
                    print(f'Epoch [{epoch+1}/{total_epochs}], Loss: {avg_loss:g}')
                if self.writer is not None:
                    self.writer.add_scalar('Loss', avg_loss, epoch + 1)
                    self.writer.flush()

                if (self.checkpoint_dir is not None and checkpoint_freq > 0
                        and (epoch + 1) % checkpoint_freq == 0):
                    ckpt_file = os.path.join(self.checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
                    self._save_state(ckpt_file, epoch)
                    print(f'Checkpoint saved: {ckpt_file}')

            if save_final_model:
                if not final_model_path:
                    if self.checkpoint_dir is not None:
                        final_model_path = os.path.join(self.checkpoint_dir, 'final_model.pth')
                    else:
                        final_model_path = 'final_model.pth'
                self._save_state(final_model_path, total_epochs - 1)
                print(f'Final model saved to: {final_model_path}')
        finally:
            # Always release the writer. Also drop it so a later train() call
            # without a log dir does not reuse a closed writer.
            if self.writer is not None:
                self.writer.close()
                self.writer = None

    def _build_optimizer(self, optimizer, optimizer_cfg, lr):
        """Return the optimizer to use: the given one, one built from the config, or Adam."""
        if optimizer is not None:
            return optimizer
        if isinstance(optimizer_cfg, dict):
            name = optimizer_cfg.get('name', 'Adam')
            # Copy so the caller's config dict is not mutated by the lr default.
            params = dict(optimizer_cfg.get('params') or {})
            if 'lr' not in params and lr is not None:
                params['lr'] = lr
            opt_cls = getattr(torch.optim, name, None)
            if not (isinstance(opt_cls, type) and issubclass(opt_cls, torch.optim.Optimizer)):
                raise ValueError(f'Unknown optimizer name: {name}')
            return opt_cls(self.network.parameters(), **params)
        return torch.optim.Adam(self.network.parameters(), lr=lr)

    def _resume_from_checkpoint(self, resume_from):
        """Load model/optimizer/epoch from a checkpoint and set self.start_epoch."""
        if not os.path.exists(resume_from):
            raise FileNotFoundError(f'Resume checkpoint not found: {resume_from}')
        # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
        ckpt = torch.load(resume_from, map_location=DEVICE())
        if 'model' in ckpt:
            self.network.load_state_dict(ckpt['model'])
        if 'optimizer' in ckpt and self.optimizer is not None:
            # Deliberately best-effort: an incompatible optimizer state is
            # reported and then ignored.
            try:
                self.optimizer.load_state_dict(ckpt['optimizer'])
            except Exception:
                print('Failed to load optimizer state (incompatible). Continuing with fresh optimizer state')
        if 'epoch' in ckpt:
            self.start_epoch = ckpt['epoch'] + 1
        print(f'Resuming from checkpoint {resume_from} at epoch {self.start_epoch}')

    def _train_one_epoch(self):
        """Run one optimization pass over self.dataloader; return the mean batch loss (nan if no batches)."""
        epoch_loss = 0.0
        batch_count = 0
        for batch in self.dataloader:
            loss = self.calc_loss(batch)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            epoch_loss += loss.item()
            batch_count += 1
        return epoch_loss / batch_count if batch_count > 0 else float('nan')

    @staticmethod
    def _step_scheduler(lr_scheduler, epoch, avg_loss):
        """Advance the learning-rate scheduler once, choosing the right call form."""
        if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            # Plateau schedulers require the monitored metric.
            lr_scheduler.step(avg_loss)
        elif hasattr(lr_scheduler, 'step'):
            lr_scheduler.step()
        elif callable(lr_scheduler):
            lr_scheduler(epoch)
        else:
            raise TypeError('lr_scheduler must have a step() method or be callable')

    def _save_state(self, path, epoch):
        """Save model state, optimizer state, and the epoch index to path."""
        torch.save({'model': self.network.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'epoch': epoch}, path)