Source code for ai4plasma.piml.meta_pinn

"""Meta-learning Physics-Informed Neural Networks (Meta-PINN).

This module implements Model-Agnostic Meta-Learning (MAML) for physics-informed
neural networks. It provides a task abstraction, bi-level optimization pipeline,
and training utilities for rapid adaptation to new physics tasks with limited data.

Meta-PINN Classes
-----------------
- `MetaTask`: Abstract task interface for meta-learning.
- `PINNTask`: PINN-specific task implementation with equation term losses.
- `MetaPINN`: MAML trainer for PINN task batches.

Meta-PINN References
--------------------
[1] L. Zhong, B. Wu, and Y. Wang, "Accelerating physics-informed neural network
    based 1D arc simulation by meta learning," Journal of Physics D: Applied Physics,
    vol. 56, p. 074006, 2023.
"""

import os
import traceback
from typing import Callable, Dict, List, Tuple, Optional, Union
from abc import ABC, abstractmethod
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from ai4plasma.core.network import FNN
from ai4plasma.piml.geo import Geo1D
from ai4plasma.piml.pinn import PINN, VisualizationCallback
from ai4plasma.piml.cs_pinn import calc_GL_coefs
from ai4plasma.plasma.prop import ArcPropSpline
from ai4plasma.utils.math import df_dX
from ai4plasma.config import REAL, DEVICE


[docs] class MetaTask(ABC): """ Abstract base class for defining meta-learning tasks. A meta-learning task encapsulates a specific problem instance within a broader family of related problems. In the MAML framework, each task contains support and query sets used for task-specific adaptation and meta-parameter updates. This abstraction allows the meta-learning framework to handle diverse physics problems in a unified manner by implementing the compute_loss method. It follows the Strategy design pattern, enabling different task types to be plugged into the MetaPINN framework without modifying core algorithms. Parameters ---------- task_id : str Unique identifier for this task (e.g., 'Arc_I=200A', 'Poisson_k=1.5') support_data : Dict[str, torch.Tensor], optional Dictionary mapping data names to support set tensors. Support data is used for task-specific adaptation during the inner loop. Default is empty dict. Example: {'Domain': x_domain, 'Boundary': x_boundary} query_data : Dict[str, torch.Tensor], optional Dictionary mapping data names to query set tensors. Query data is disjoint from support data and used for meta-parameter updates. Default is empty dict. Same structure as support_data but with different samples. Attributes ---------- task_id : str Unique identifier for this task instance. support_data : Dict[str, torch.Tensor] Support set for inner loop training (few-shot adaptation). query_data : Dict[str, torch.Tensor] Query set for outer loop meta-updates (meta-gradient computation). """ def __init__(self, task_id: str, support_data: Dict[str, torch.Tensor] = None, query_data: Dict[str, torch.Tensor] = None): """ Initialize a meta-learning task. Parameters ---------- task_id : str Unique identifier for this task instance. support_data : Dict[str, torch.Tensor], optional Support set data for inner loop adaptation. Keys are equation term names (e.g., 'Domain', 'Boundary'). Values are collocation point tensors. Default is None (converted to empty dict). query_data : Dict[str, torch.Tensor], optional Query set data for meta-validation. Must have same structure as support_data but with different samples. Default is None. """ super().__init__() self.task_id = task_id self.support_data = support_data or {} self.query_data = query_data or {}
[docs] @abstractmethod def compute_loss(self, network: nn.Module, data_dict: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Compute total loss for this task on given data. This abstract method defines how to evaluate the network's performance on task data. It is called during both inner loop adaptation (support set) and outer loop meta-updates (query set). Subclasses must implement this method. Parameters ---------- network : nn.Module Neural network to evaluate. May be adapted in inner loop. data_dict : Dict[str, torch.Tensor] Dictionary of data tensors for different equation terms. Keys match equation term names (e.g., 'Domain', 'Boundary'). Values are collocation points or boundary data. Returns ------- total_loss : torch.Tensor Scalar tensor representing the total weighted loss. loss_dict : Dict[str, torch.Tensor] Dictionary mapping equation term names to individual losses. Useful for monitoring training and statistical analysis. """ pass
[docs] def get_task_id(self) -> str: """ Get unique identifier for this task. Returns ------- str Task identifier (e.g., 'Arc_I=200A_R=0.01m'). """ return self.task_id
[docs] def get_support_data(self) -> Dict[str, torch.Tensor]: """ Get support set data for inner loop training. The support set is used to adapt the meta-initialized network to a specific task during the inner loop of MAML. Typically contains fewer samples than traditional PINN training. Returns ------- Dict[str, torch.Tensor] Dictionary mapping equation term names to support data tensors. """ return self.support_data
[docs] def get_query_data(self) -> Dict[str, torch.Tensor]: """ Get query set data for meta-validation. The query set is used to evaluate the adapted network and compute gradients for meta-parameter updates in the outer loop. It must be disjoint from the support set to ensure unbiased meta-learning. Returns ------- Dict[str, torch.Tensor] Dictionary mapping equation term names to query data tensors. """ return self.query_data
[docs] class PINNTask(MetaTask): """ PINN-specific meta-learning task implementation. This class bridges the PINN framework with the meta-learning paradigm by wrapping a PINN model and its equation terms into a task suitable for MAML. It implements the compute_loss method by iterating over all equation terms defined in the PINN. Parameters ---------- task_id : str Unique identifier for this physics task. pinn_model : PINN, optional PINN model instance containing equation definitions. Must have defined equation terms via add_equation(). support_data : Dict[str, torch.Tensor], optional Support set collocation points for inner loop adaptation. query_data : Dict[str, torch.Tensor], optional Query set collocation points for meta-update evaluation. Attributes ---------- pinn_model : PINN The underlying PINN model defining equation terms and residuals. loss_func : callable Loss function for comparing residuals to zero (e.g., MSE, smooth L1). """ def __init__(self, task_id: str, pinn_model: Optional[PINN] = None, support_data: Dict[str, torch.Tensor] = None, query_data: Dict[str, torch.Tensor] = None): """ Initialize a PINN task for meta-learning. Parameters ---------- task_id : str Unique identifier for this physics task. pinn_model : PINN, optional PINN model instance containing equation definitions. Must have defined equation terms via add_equation(). support_data : Dict[str, torch.Tensor], optional Support set collocation points for inner loop adaptation. query_data : Dict[str, torch.Tensor], optional Query set collocation points for meta-update evaluation. """ super().__init__(task_id, support_data, query_data) self.pinn_model = pinn_model self.loss_func = pinn_model.loss_func if pinn_model else F.mse_loss
[docs] def compute_loss(self, network: nn.Module, data_dict: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Compute total weighted loss for this PINN task. Parameters ---------- network : nn.Module Neural network to evaluate (may be adapted in inner loop). data_dict : Dict[str, torch.Tensor] Dictionary mapping equation term names to data tensors. Example: {'Domain': x_pde, 'Boundary': x_bc, 'Initial': x_ic} Returns ------- total_loss : torch.Tensor Scalar weighted sum of all equation term losses. loss_dict : Dict[str, torch.Tensor] Dictionary of individual losses for monitoring. Keys: equation term names, Values: unweighted loss tensors. """ equation_terms = list(self.pinn_model.equation_terms.values()) loss_dict = {} weighted_losses = [] for eq_term in equation_terms: # Skip terms without corresponding data if eq_term.name not in data_dict: continue # Compute residual for this equation term batch_data = data_dict[eq_term.name] residual = eq_term.compute_residual(network, batch_data) # Compute loss (residual should be zero for physics satisfaction) target = torch.zeros_like(residual) loss = self.loss_func(residual, target) loss_dict[eq_term.name] = loss # Weight and accumulate for total loss weighted_losses.append(eq_term.weight * loss) # Sum all weighted losses total_loss = sum(weighted_losses) if weighted_losses else torch.tensor(0.0, dtype=REAL('torch')) return total_loss, loss_dict
[docs] class MetaPINN: """ Meta-Learning Framework for Physics-Informed Neural Networks using MAML. This class implements the Model-Agnostic Meta-Learning (MAML) algorithm specifically designed for physics-informed neural networks. It learns optimal network initializations across multiple related physics tasks, enabling rapid adaptation to new tasks with minimal fine-tuning. Parameters ---------- train_tasks : List[PINNTask] List of PINN tasks for meta-training. Each task should have support and query data defined. test_tasks : List[PINNTask], optional List of PINN tasks for meta-testing/evaluation. Attributes ---------- train_tasks : List[PINNTask] Training tasks for meta-learning. test_tasks : List[PINNTask] Test tasks for meta-evaluation. meta_network : nn.Module The meta-network whose parameters are meta-learned. loss_func : callable Loss function for computing residuals (default: smooth_l1_loss). writer : SummaryWriter TensorBoard writer for logging meta-training progress. history : Dict Training history storing meta-train losses. visualization_callbacks : Dict[str, VisualizationCallback] Registered callbacks for real-time visualization. outer_epochs : int Number of meta-training iterations completed. inner_epochs : int Number of adaptation steps per task (inner loop). outer_lr : float Meta-learning rate (outer loop, typically 1e-4 to 1e-3). inner_lr : float Task adaptation learning rate (inner loop, typically 1e-5 to 1e-3). beta1, beta2 : float Adam optimizer momentum parameters (default 0.9, 0.999). epsilon : float Adam optimizer numerical stability constant (default 1e-8). """ def __init__(self, train_tasks: List[PINNTask], test_tasks: List[PINNTask] = None): """ Initialize Meta-PINN framework with training tasks. Parameters ---------- train_tasks : List[PINNTask] List of PINN tasks for meta-training. Each task should have support and query data defined. test_tasks : List[PINNTask], optional List of PINN tasks for meta-testing/evaluation. """ self.train_tasks = train_tasks self.test_tasks = test_tasks # Initialize meta-network from first task's PINN model self.meta_network = train_tasks[0].pinn_model.network if train_tasks else None self.loss_func = F.smooth_l1_loss self.writer = None # TensorBoard writer (initialized in meta_train) self.history = { 'meta_train_loss': [], # Meta-training loss history } self.visualization_callbacks: Dict[str, VisualizationCallback] = {} # Training state tracking for resumable training self.last_outer_epochs = 0 # Last completed epoch (for resuming) self.outer_epochs = 0 # Total target epochs self.inner_epochs = 0 # Inner loop iterations per task self.outer_lr = 1e-4 # Meta-learning rate (outer loop) self.inner_lr = 1e-5 # Task adaptation learning rate (inner loop) self.beta1 = 0.9 # Adam momentum parameter 1 self.beta2 = 0.999 # Adam momentum parameter 2 self.epsilon = 1e-8 # Adam numerical stability constant
[docs] def load_meta_model(self, checkpoint_path: str): """ Load meta-model from checkpoint for resuming training. This method restores the meta-network parameters and training state from a saved checkpoint, enabling interrupted meta-training to resume seamlessly from the last saved epoch. Parameters ---------- checkpoint_path : str Path to the checkpoint file (.pth). Should contain: - meta_network_state_dict: Learned meta-parameters - outer_epochs: Last completed meta-training epoch - inner_epochs: Inner loop iteration count - outer_lr, inner_lr: Learning rates - beta1, beta2, epsilon: Adam optimizer hyperparameters """ checkpoint = torch.load(checkpoint_path, map_location=DEVICE()) # Restore meta-network parameters self.meta_network.load_state_dict(checkpoint['meta_network_state_dict']) # Restore training state self.last_outer_epochs = checkpoint.get('outer_epochs', 0) self.inner_epochs = checkpoint.get('inner_epochs', 0) self.outer_lr = checkpoint.get('outer_lr', self.outer_lr) self.inner_lr = checkpoint.get('inner_lr', self.inner_lr) self.beta1 = checkpoint.get('beta1', self.beta1) self.beta2 = checkpoint.get('beta2', self.beta2) self.epsilon = checkpoint.get('epsilon', self.epsilon) print(f"Meta-PINN model loaded from {checkpoint_path} at epoch {self.last_outer_epochs}.")
[docs] def save_meta_model(self, epoch: int, checkpoint_path: str): """ Save meta-model checkpoint for later resumption. This method saves the current meta-network parameters and training state to disk, allowing meta-training to be resumed from this point if interrupted. Checkpoints are typically saved at regular intervals during training. Parameters ---------- epoch : int Current meta-training epoch number. checkpoint_path : str Path where to save the checkpoint file. """ torch.save({ 'meta_network_state_dict': self.meta_network.state_dict(), 'outer_epochs': epoch, 'inner_epochs': self.inner_epochs, 'outer_lr': self.outer_lr, 'inner_lr': self.inner_lr, 'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon, }, checkpoint_path) print(f"Meta-PINN model saved to {checkpoint_path}.")
[docs] def meta_train(self, outer_epochs: int, inner_epochs: int = 5, outer_lr: float = 1e-4, inner_lr: float = 1e-5, beta1: float = 0.9, beta2: float = 0.999, epsilon: float = 1e-8, print_freq: int = 10, tensorboard_logdir: str = None, log_freq: int = 50, checkpoint_dir: str = None, checkpoint_freq: int = 100, load_from_checkpoint: str = None): """ Execute meta-training using Model-Agnostic Meta-Learning (MAML). This method implements the bi-level optimization algorithm of MAML: - Inner Loop: Fast adaptation to individual tasks using gradient descent - Outer Loop: Meta-parameter update using accumulated query losses The goal is to learn an initialization that enables rapid adaptation to new physics tasks with minimal fine-tuning. After meta-training, the learned initialization can be applied to unseen tasks for few-shot learning. Parameters ---------- outer_epochs : int Number of meta-training iterations (outer loop). Typical range: 500-5000 depending on task complexity. inner_epochs : int, default=5 Number of gradient steps for task adaptation (inner loop). Typical range: 1-20 steps. outer_lr : float, default=1e-4 Meta-learning rate for outer loop. Controls update speed of meta-parameters. Typical range: 1e-5 to 1e-3. inner_lr : float, default=1e-5 Task adaptation learning rate for inner loop. Typical range: 1e-6 to 1e-3. beta1 : float, default=0.9 Adam momentum parameter for first moment estimation. beta2 : float, default=0.999 Adam momentum parameter for second moment estimation. epsilon : float, default=1e-8 Adam numerical stability constant. print_freq : int, default=10 Print training progress every print_freq epochs. tensorboard_logdir : str, optional Directory for TensorBoard logging. If None, logging is disabled. log_freq : int, default=50 Log metrics to TensorBoard every log_freq epochs. checkpoint_dir : str, optional Directory to save checkpoints during training. If None, no saves. checkpoint_freq : int, default=100 Save checkpoint every checkpoint_freq epochs. load_from_checkpoint : str, optional Path to checkpoint file for resuming training. """ if load_from_checkpoint: self.load_meta_model(load_from_checkpoint) else: self.last_outer_epochs = 0 self.outer_epochs = self.last_outer_epochs + outer_epochs self.inner_epochs = inner_epochs self.outer_lr = outer_lr self.inner_lr = inner_lr self.beta1 = beta1 self.beta2 = beta2 self.epsilon = epsilon # Setup TensorBoard if tensorboard_logdir: self.writer = SummaryWriter(tensorboard_logdir) # Create checkpoint directory if checkpoint_dir: os.makedirs(checkpoint_dir, exist_ok=True) # Initialize Adam optimizer moments for meta-parameters # Momentum (first moment): running average of gradients momentum = [0 for _, _ in enumerate(self.meta_network.parameters())] # Moving average (second moment): running average of squared gradients moving_avg = [0 for _, _ in enumerate(self.meta_network.parameters())] # Initialize meta-weights (will be updated after each meta-iteration) meta_weights = [param.data for param in self.meta_network.parameters()] # Meta-training loop (outer loop) for epoch in range(self.last_outer_epochs, self.outer_epochs): loss_qry_tot = 0 # Initialize total query loss for this meta-iteration task_weights = [param.data for param in self.meta_network.parameters()] # Store meta-parameters before inner loop loss_qry_task = [] # Track individual task query losses for monitoring # Inner loop: Adapt to each task using support set for task in self.train_tasks: # Reset meta-network to initial meta-parameters for this task for kk, param in enumerate(self.meta_network.parameters()): param.data = task_weights[kk] # Start inner loop from meta-initialization # Inner loop step 0: Compute initial support loss and gradient loss_spt, _ = task.compute_loss(self.meta_network, task.get_support_data()) # Compute gradient of support loss w.r.t. parameters grad = torch.autograd.grad(loss_spt, self.meta_network.parameters()) # One-step gradient descent for task adaptation new_weights = list(map(lambda p: p[1] - inner_lr * p[0], zip(grad, task_weights))) # Update network with adapted parameters for kk, param in enumerate(self.meta_network.parameters()): param.data = new_weights[kk] # Inner loop steps 1 to (inner_epochs-1): Continue adaptation for j in range(1, inner_epochs): # Compute support loss with updated parameters loss_spt, _ = task.compute_loss(self.meta_network, task.get_support_data()) # Compute gradient and update weights grad = torch.autograd.grad(loss_spt, self.meta_network.parameters()) new_weights = list(map(lambda p: p[1] - inner_lr * p[0], zip(grad, new_weights))) # Apply updated weights to network for kk, param in enumerate(self.meta_network.parameters()): param.data = new_weights[kk] # Evaluate adapted network on query set (for meta-update) loss_qry, _ = task.compute_loss(self.meta_network, task.get_query_data()) # Track individual task query loss (for monitoring) loss_qry_task.append(loss_qry.item()/len(self.train_tasks)) # Accumulate total query loss across all tasks loss_qry_tot += loss_qry # Outer loop: Compute meta-gradient from query loss # This is the key step in MAML: gradient of query loss w.r.t. meta-parameters update_grad = torch.autograd.grad(loss_qry, self.meta_network.parameters(), retain_graph=True) # Adam optimizer update for meta-parameters for kk, param in enumerate(meta_weights): # Update first moment (momentum) momentum[kk] = beta1*momentum[kk] + (1 - beta1)*update_grad[kk] # Update second moment (adaptive learning rate) moving_avg[kk] = beta2*moving_avg[kk] + (1 - beta2)*update_grad[kk]**2 # Bias-corrected moments corr_momentum = momentum[kk]/(1 - beta1**(epoch+1)) corr_moving_avg = moving_avg[kk]/(1 - beta2**(epoch+1)) # Adam update rule meta_weights[kk] = meta_weights[kk] - outer_lr*corr_momentum/(torch.sqrt(corr_moving_avg) + epsilon) # Average query loss across all tasks (meta-objective) loss_qry_tot /= len(self.train_tasks) self.history['meta_train_loss'].append(loss_qry_tot.item()) # Print progress at specified intervals if (epoch+1) % print_freq == 0: print('[%d/%d] Meta-Loss: %g' % (epoch+1, self.outer_epochs, loss_qry_tot.item())) # Apply updated meta-parameters to meta-network for kk, param in enumerate(self.meta_network.parameters()): param.data = meta_weights[kk] # Log to TensorBoard if self.writer and log_freq > 0 and (epoch + 1) % log_freq == 0: self.writer.add_scalar('Meta-Loss', loss_qry_tot.item(), epoch+1) self.writer.flush() # Save checkpoint periodically if checkpoint_dir and (epoch+1) % checkpoint_freq == 0: self.save_meta_model(epoch+1, os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')) # Execute visualization callbacks self._execute_visualization_callbacks( epoch, train_task=task, meta_train_loss=loss_qry_tot.item(), ) if self.writer: self.writer.close()
[docs] def meta_test(self, test_tasks: List[PINNTask], results_dir: str = None): """ Evaluate meta-learned initialization on new test tasks. This method demonstrates the meta-learning benefit by initializing the network with learned meta-parameters and performing few-shot adaptation on new tasks with minimal training data. Parameters ---------- test_tasks : List[PINNTask] List of new tasks for meta-testing. Each task should have support data defined for few-shot adaptation. results_dir : str, optional Directory to save adapted model checkpoints. One checkpoint per task will be saved if provided. """ self.test_tasks = test_tasks if results_dir: os.makedirs(results_dir, exist_ok=True) save_test_task_model_file = [os.path.join(results_dir, f'{task.get_task_id()}_meta_model.pth') for task in test_tasks] # Store meta-parameters for reinitialization across tasks task_weights = [param.data for param in self.meta_network.parameters()] # Adapt to each test task independently for i, task in enumerate(self.test_tasks): # Reset to meta-initialization for this task for kk, param in enumerate(self.meta_network.parameters()): param.data = task_weights[kk] # Few-shot adaptation: First gradient step on support set loss_spt, _ = task.compute_loss(self.meta_network, task.get_support_data()) grad = torch.autograd.grad(loss_spt, self.meta_network.parameters()) new_weights = list(map(lambda p: p[1] - self.inner_lr * p[0], zip(grad, self.meta_network.parameters()))) for kk, param in enumerate(self.meta_network.parameters()): param.data = new_weights[kk] # Continue adaptation for remaining inner loop iterations for j in range(1, self.inner_epochs): loss_spt, _ = task.compute_loss(self.meta_network, task.get_support_data()) grad = torch.autograd.grad(loss_spt, self.meta_network.parameters()) new_weights = list(map(lambda p: p[1] - self.inner_lr * p[0], zip(grad, new_weights))) for kk, param in enumerate(self.meta_network.parameters()): param.data = new_weights[kk] # Save adapted parameters for this test task torch.save({'network_state_dict': self.meta_network.state_dict()}, save_test_task_model_file[i])
[docs] def register_visualization_callback(self, callback: VisualizationCallback): """ Register a visualization callback for training monitoring. Parameters ---------- callback : VisualizationCallback Callback instance to register (from pinn.VisualizationCallback). """ self.visualization_callbacks[callback.name] = callback print(f"Registered visualization callback: {callback.name} (log_freq={callback.log_freq})")
def _execute_visualization_callbacks(self, epoch: int, **kwargs): """ Execute all registered visualization callbacks. Parameters ---------- epoch : int Current training epoch. kwargs : dict Additional information to pass to callbacks. """ if not self.writer or len(self.visualization_callbacks) == 0: return # Add meta_pinn reference to kwargs for callbacks that need it kwargs['meta_pinn'] = self for callback_name, callback in self.visualization_callbacks.items(): # Check if it's time to execute this callback if callback.log_freq > 0 and (epoch + 1) % callback.log_freq == 0: try: # Execute callback visualization (pass network instead of meta_pinn) # This maintains compatibility with PINN.VisualizationCallback interface figures = callback.visualize(self.meta_network, epoch, self.writer, **kwargs) if figures and isinstance(figures, dict): # Log figures to TensorBoard for plot_name, fig in figures.items(): if hasattr(fig, 'savefig'): # Check if it's a matplotlib figure self.writer.add_figure( f'Visualization/{callback_name}/{plot_name}', fig, global_step=epoch ) plt.close(fig) # Close figure to save memory except Exception as e: print(f"Error in visualization callback '{callback_name}': {str(e)}") traceback.print_exc()
[docs] class MetaStaArc1DNet(nn.Module): """ Neural network wrapper for meta-learning 1D stationary arc plasma problems. Parameters ---------- network : nn.Module Backbone neural network (e.g., FNN) that maps r → N(r). Input shape: [batch_size, 1] (normalized radius) Output shape: [batch_size, 1] (log temperature) """ def __init__(self, network): super(MetaStaArc1DNet, self).__init__() self.network = network # Backbone network (maps r → log T)
[docs] def forward(self, x): """ Forward pass with exponential activation. Parameters ---------- x : torch.Tensor Input normalized radius, shape [batch_size, 1] Returns ------- torch.Tensor Normalized temperature (always positive), shape [batch_size, 1]. """ out = self.network(x) # Backbone output (log-space) out = torch.exp(out) # Convert to positive temperature return out
[docs] class MetaStaArc1DModel(PINN): """ PINN model for 1D stationary arc adapted for meta-learning. This class implements a specialized PINN for solving the 1D steady-state arc plasma equation in a meta-learning context. Unlike the standard StaArc1DModel, this version is designed as a task within the MetaPINN framework for learning across multiple arc configurations. Parameters ---------- R : float Arc radius in meters (e.g., 0.01 for 10mm). I : float Arc current in amperes (e.g., 100, 150, 200). Tb : float, default=300.0 Boundary temperature at r=R in Kelvin. T_red : float, default=1e4 Temperature normalization constant (typically 10000K). backbone_net : nn.Module, default=FNN([1,100,100,100,100,1]) Backbone neural network for temperature prediction. train_data_size : int, default=500 Number of collocation points for training. test_data_size : int, default=501 Number of points for evaluation/prediction. sample_mode : str, default='uniform' Sampling strategy ('uniform', 'random', or 'lhs'). GL_degree : int, default=100 Degree of Gauss-Legendre quadrature for arc conductance integration. prop : ArcPropSpline Material property splines (κ, σ, ε_nec as functions of T). """ def __init__( self, R, I, Tb=300.0, T_red=1e4, backbone_net=FNN(layers=[1, 100, 100, 100, 100, 1]), train_data_size=500, test_data_size=501, sample_mode='uniform', GL_degree=100, prop:ArcPropSpline=None, ): self.R = R # Arc radius [m] self.I = I # Arc current [A] self.T_red = T_red # Temperature normalization constant [K] self.Tb = Tb # Boundary temperature [K] self.train_data_size = train_data_size # Number of training collocation points self.test_data_size = test_data_size # Number of test points self.sample_mode = sample_mode # Sampling strategy self.GL_degree = GL_degree # Gauss-Legendre quadrature degree # Compute Gauss-Legendre quadrature points and weights for arc conductance self.Xq, self.Wq = calc_GL_coefs(GL_degree) self.prop = prop # Material properties (σ, κ, ε_nec) # Define 1D spatial domain [0, 1] (normalized radius) self.geo = Geo1D([0.0, 1.0]) # Wrap backbone network with meta-learning specific architecture network = MetaStaArc1DNet(backbone_net) super().__init__(network) # Use smooth L1 loss (Huber loss) for robustness to outliers self.set_loss_func(F.smooth_l1_loss) def _define_loss_terms(self): """ Define physics-informed loss terms for 1D stationary arc equation. Sets up two loss terms: 1. Domain loss: PDE residual in the interior (0 < r < R) 2. Boundary loss: Symmetry condition at centerline (r = 0) """ def _pde_residual(network, x): """ Compute PDE residual for the 1D stationary arc energy equation. The energy balance in cylindrical coordinates is enforced. The electric field E is computed from arc conductance using Gauss-Legendre quadrature. Parameters ---------- network : nn.Module Meta-learning network (MetaStaArc1DNet). x : torch.Tensor Normalized radial coordinates, shape [N, 1]. Returns ------- torch.Tensor PDE residual, shape [N, 1] (should be zero for exact solution). """ # Apply boundary condition transformation: T(r) = N(r)*(1-r) + Tb # This ensures T(R=1) = Tb automatically T = network(x)*(1.0 - x) + self.Tb/self.T_red # Compute temperature-dependent material properties kappa = self.prop.kappa(T.view(-1)*self.T_red).view(-1,1) # Thermal conductivity sigma = self.prop.sigma(T.view(-1)*self.T_red).view(-1,1) # Electrical conductivity nec = self.prop.nec(T.view(-1)*self.T_red).view(-1,1) # Net emission coefficient # Compute arc conductance using Gauss-Legendre quadrature # G = πR² ∫₀¹ r·σ(T(r)) dr (integral in normalized coordinates) Tq = network(self.Xq)*(1.0 - self.Xq) + self.Tb/self.T_red sigma_q = self.prop.sigma(Tq.view(-1)*self.T_red).view(-1,1) arc_cond = np.pi*self.R*self.R*torch.sum(self.Wq*self.Xq*sigma_q) # Compute energy source terms joule = sigma*(self.I/arc_cond)**2 # Joule heating: σ·E² radiation = 4*np.pi*nec # Radiation loss: 4π·ε_nec net_energy = (joule - radiation)/self.T_red*self.R*self.R # Normalized net energy # Compute thermal conduction term: (1/r) d/dr(r·κ·dT/dr) T_x = df_dX(T, x) # First derivative: dT/dr T_term = x*kappa*T_x # r·κ·dT/dr T_xx = df_dX(T_term, x) # d/dr(r·κ·dT/dr) # PDE residual: (1/r)·d/dr(r·κ·dT/dr) + net_energy = 0 # Multiply by r to avoid singularity: d/dr(r·κ·dT/dr) + r·net_energy = 0 func = T_xx + x*net_energy return func def _bc_residual(network, x): """ Compute boundary condition residual at centerline (r = 0). Enforces the symmetry condition: dT/dr|_{r=0} = 0. This is a natural consequence of cylindrical symmetry: the temperature gradient must vanish at the axis to ensure a unique, smooth solution. Parameters ---------- network : nn.Module Meta-learning network. x : torch.Tensor Boundary point at r = 0, shape [1, 1]. Returns ------- torch.Tensor Boundary residual (should be zero), shape [1, 1]. """ # Compute temperature with boundary transformation T = network(x)*(1.0 - x) + self.Tb/self.T_red # Compute temperature gradient T_x = df_dX(T, x) # Residual: dT/dr should be zero at r=0 return T_x # Sample collocation points in the domain (0 < r < 1) x_domain = self.geo.sample_domain(self.train_data_size, mode=self.sample_mode) # Sample boundary point at centerline (r = 0) x_bc = self.geo.sample_boundary() x_bc_left = x_bc[0] # Left boundary corresponds to r=0 # Add equation terms with weights # Domain: PDE residual in interior self.add_equation('Domain', _pde_residual, weight=1.0, data=x_domain) # Boundary: Symmetry condition at centerline (weighted higher for emphasis) self.add_equation('Boundary', _bc_residual, weight=10.0, data=x_bc_left)
[docs] def predict(self, input_data: torch.Tensor) -> torch.Tensor: """ Make temperature predictions using the trained/adapted network. This method applies the trained meta-network (after adaptation) to predict temperature distributions at given radial locations. The boundary condition transformation is applied to ensure T(R) = Tb. Parameters ---------- input_data : torch.Tensor Normalized radial coordinates (0 to 1), shape [N, 1]. Returns ------- torch.Tensor Normalized temperature predictions, shape [N, 1]. Physical temperature: T_physical = output * self.T_red """ self.network.eval() # Set to evaluation mode with torch.no_grad(): # Disable gradient computation for efficiency # Apply boundary transformation: T(r) = N(r)·(1-r) + Tb output = self.network(input_data)*(1.0 - input_data) + self.Tb/self.T_red return output
[docs] class StaArc1DTask(PINNTask): """ Task wrapper for 1D stationary arc problem in meta-learning framework. This class encapsulates a specific arc discharge configuration as a meta-learning task. It creates the underlying PINN model, samples support and query sets, and provides the interface required by MetaPINN for meta-training and meta-testing. Task Definition --------------- Each task corresponds to a unique arc discharge problem instance: - Physical parameters: Arc radius R, current I, boundary temp Tb - Material properties: Plasma gas type (SF6, Ar, etc.) - Support set: Small set of collocation points for adaptation - Query set: Separate set for meta-validation Parameters ---------- task_id : str Unique identifier for this arc configuration (e.g., 'Arc_I=200A_R=10mm'). R : float Arc radius in meters. I : float Arc current in amperes. Tb : float, default=300.0 Boundary temperature at r=R in Kelvin. T_red : float, default=1e4 Temperature normalization constant. backbone_net : nn.Module, default=FNN([1,100,100,100,100,1]) Backbone neural network architecture. thermo_file : str Path to thermodynamic properties file (κ, Cp, ρ vs T). nec_file : str Path to net emission coefficient file (ε_nec vs T). support_data_size : int, default=500 Number of collocation points in support set. query_data_size : int, default=400 Number of collocation points in query set. sample_mode : str, default='uniform' Sampling strategy ('uniform', 'random', 'lhs'). Attributes ---------- R : float Arc radius (stored for reference). pinn_model : MetaStaArc1DModel Underlying PINN model for this task. support_data : Dict[str, torch.Tensor] Support set collocation points {'Domain': x_spt, 'Boundary': x_bc}. query_data : Dict[str, torch.Tensor] Query set collocation points {'Domain': x_qry, 'Boundary': x_bc}. """ def __init__( self, task_id: str, R: float, I: float, Tb: float = 300.0, T_red: float = 1e4, backbone_net: nn.Module = FNN(layers=[1, 100, 100, 100, 100, 1]), thermo_file: str = None, nec_file: str = None, support_data_size: int = 500, query_data_size: int = 400, sample_mode: str = 'uniform' ): """ Initialize a stationary arc task for meta-learning. This constructor creates a complete task instance by: 1. Setting up material properties from data files 2. Creating the underlying MetaStaArc1DModel 3. Sampling support and query sets (disjoint) 4. Registering with parent PINNTask class The support and query sets are sampled independently to ensure they are disjoint, which is crucial for unbiased meta-learning. Parameters ---------- task_id : str Unique identifier for this task (e.g., 'Arc_I=200A_R=10mm'). R : float Arc radius in meters. I : float Arc current in amperes. Tb : float, default=300.0 Boundary temperature at r=R in Kelvin. T_red : float, default=1e4 Temperature normalization constant (typically 10000K). backbone_net : nn.Module, default=FNN([1,100,100,100,100,1]) Neural network architecture for temperature prediction. thermo_file : str, optional Path to thermodynamic properties CSV file (κ, Cp, ρ vs T). nec_file : str, optional Path to net emission coefficient CSV file (ε_nec vs T). support_data_size : int, default=500 Number of collocation points in support set (inner loop). query_data_size : int, default=400 Number of collocation points in query set (outer loop). sample_mode : str, default='uniform' Sampling strategy: 'uniform', 'random', or 'lhs'. Raises ------ FileNotFoundError If thermo_file or nec_file is not found. """ # Store arc radius for reference self.R = R # Load material properties from data files prop = ArcPropSpline(thermo_file, nec_file, R) # Create underlying PINN model for this arc configuration pinn_model = MetaStaArc1DModel( R=R, I=I, Tb=Tb, T_red=T_red, backbone_net=backbone_net, train_data_size=support_data_size, # Used for model setup test_data_size=query_data_size, # Used for evaluation grid sample_mode=sample_mode, prop=prop ) # Sample support set for inner loop adaptation # Support set: small dataset for fast task-specific fine-tuning support_data = { 'Domain': pinn_model.geo.sample_domain(support_data_size, mode=sample_mode), 'Boundary': pinn_model.geo.sample_boundary()[0] } # Sample query set for meta-update (outer loop) # Query set: independent dataset for computing meta-gradients query_data = { 'Domain': pinn_model.geo.sample_domain(query_data_size, mode=sample_mode), 'Boundary': pinn_model.geo.sample_boundary()[0] } # Register with parent PINNTask class super().__init__(task_id, pinn_model, support_data, query_data)