2. Physics-Informed Machine Learning (PIML)
This module implements Physics-Informed Neural Networks (PINNs) and their variants tailored for plasma physics.
2.1. Geometry
Geometry classes and sampling utilities for Physics-Informed Neural Networks (PINNs), supporting flexible domain and boundary sampling for 1D, 2D, and time-dependent problems.
2.1.1. Core Geometry Classes
GeoTime: Temporal domain [ts, te]
Geo1D: Spatial domain [xl, xu]
Geo1DTime: Space-time domain for 1D problems
GeoPoly2D: Polygonal domain in 2D
GeoRect2D: Rectangular domain in 2D
GeoPoly2DTime: Space-time domain for 2D problems
- class ai4plasma.piml.geo.Geo1D(points: List[float])
Bases: Geometry
One-dimensional spatial domain (line segment) for 1D problems.
This class represents a 1D interval [xl, xu] and provides methods to sample spatial points within this interval. It’s fundamental for 1D steady-state problems or spatial discretization in 1D time-dependent problems.
- xl
Lower bound of the spatial domain.
- Type:
float
- xu
Upper bound of the spatial domain.
- Type:
float
- create_domain(points: List[float])
Create the 1D spatial domain with specified bounds.
- Parameters:
points (List[float]) – Two-element list [xl, xu] defining the interval.
- sample_boundary(to_tensor: bool = True, require_grad: bool = True) -> List[ndarray | Tensor]
Sample the spatial boundary points (left and right endpoints).
- Parameters:
to_tensor (bool, default=True) – If True, return PyTorch tensors.
require_grad (bool, default=True) – If True and to_tensor=True, enable gradient computation.
- Returns:
List containing two elements [x_left, x_right], each with shape (1, 1).
- Return type:
List[Union[np.ndarray, torch.Tensor]]
- sample_domain(N: int, mode: str | SamplingMode = SamplingMode.UNIFORM, include_boundary: bool = False, to_tensor: bool = True, require_grad: bool = True) -> ndarray | Tensor
Sample spatial points within the 1D domain.
Generates N spatial points in the interval [xl, xu] using the specified sampling strategy. Points are returned in ascending order.
- Parameters:
N (int) – Number of spatial points to sample.
mode (Union[str, SamplingMode], default=SamplingMode.UNIFORM) – Sampling strategy (‘uniform’ or ‘random’).
include_boundary (bool, default=False) – If True, explicitly include xl and xu.
to_tensor (bool, default=True) – If True, return PyTorch tensor.
require_grad (bool, default=True) – If True and to_tensor=True, enable gradient computation.
- Returns:
Array or tensor of shape (N, 1) containing sampled points.
- Return type:
Union[np.ndarray, torch.Tensor]
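As a rough illustration of the two sampling modes, the sketch below reimplements interval sampling with NumPy only. `sample_interval` is a hypothetical helper and not part of ai4plasma. How it handles include_boundary (dropping or adding the endpoints) is an assumption, so the library's actual point placement may differ.

```python
import numpy as np

def sample_interval(xl, xu, n, mode="uniform", include_boundary=False, seed=None):
    """Hypothetical helper: n sorted points in [xl, xu], returned with shape (n, 1)."""
    if mode == "uniform":
        if include_boundary:
            x = np.linspace(xl, xu, n)
        else:
            # Assumption: build an (n + 2)-point grid and drop both endpoints.
            x = np.linspace(xl, xu, n + 2)[1:-1]
    elif mode == "random":
        rng = np.random.default_rng(seed)
        if include_boundary:
            if n < 2:
                raise ValueError("need n >= 2 to include both endpoints")
            # Assumption: two of the n points are reserved for xl and xu.
            x = np.concatenate(([xl], rng.uniform(xl, xu, n - 2), [xu]))
        else:
            x = rng.uniform(xl, xu, n)
        x = np.sort(x)  # ascending order, as the documentation states
    else:
        raise ValueError(f"unsupported mode: {mode!r}")
    return x.reshape(-1, 1)

x_uniform = sample_interval(0.0, 1.0, 5)
x_random = sample_interval(0.0, 1.0, 8, mode="random", include_boundary=True, seed=0)
```

Returning a column of shape (N, 1) rather than a flat vector matches the documented return shape and lets the points be fed directly to a network whose input layer expects one feature.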
- class ai4plasma.piml.geo.Geo1DTime(points: List[float], ts: float, te: float)
Bases: Geometry
Composite 1D space-time domain for time-dependent 1D problems.
This class combines a 1D spatial domain (Geo1D) and a temporal domain (GeoTime) to handle time-dependent problems. It delegates spatial and temporal sampling to respective component geometries.
Applications
Suitable for solving time-dependent 1D PDEs:
  - Heat equation: du/dt = alpha * d2u/dx2
  - Wave equation: d2u/dt2 = c^2 * d2u/dx2
  - Advection-diffusion: du/dt + v * du/dx = D * d2u/dx2
- sample_all_domain(Nx: int, Nt: int, mode: List[str | SamplingMode] = ['uniform', 'uniform'], include_boundary: bool = False, to_tensor: bool = True, require_grad: bool = True) -> Tuple
Sample all domains including interior and boundaries for PINN training.
Primary method for generating training data for time-dependent 1D PDE problems. Generates:
  1. Interior space-time points for the PDE residual
  2. Spatial boundaries across time for boundary conditions
  3. Initial time boundary for initial conditions
  4. Final time boundary for terminal conditions
- Parameters:
Nx (int) – Number of spatial sampling points.
Nt (int) – Number of temporal sampling points.
mode (List[Union[str, SamplingMode]], default=['uniform', 'uniform']) – List of sampling modes [space_mode, time_mode].
include_boundary (bool, default=False) – If True, include temporal boundaries in t samples.
to_tensor (bool, default=True) – If True, return as PyTorch tensors.
require_grad (bool, default=True) – If True and to_tensor=True, enable gradient computation.
- Returns:
Tuple of (xt, (xb, xbt0, xbt1)) where:
  - xt: Interior space-time points, shape (Nx*Nt, 2)
  - xb: List of [left_boundary, right_boundary] across time
  - xbt0: Initial condition points at t=ts, shape (Nx, 2)
  - xbt1: Final condition points at t=te, shape (Nx, 2)
- Return type:
Tuple[Union[np.ndarray, torch.Tensor], Tuple]
- sample_domain(N: int, mode: str | SamplingMode = SamplingMode.UNIFORM, include_boundary: bool = False, to_tensor: bool = True, require_grad: bool = True) -> ndarray | Tensor
Sample spatial points only (delegates to geo_space).
Note: Samples only the spatial domain. For space-time sampling, use sample_all_domain() instead.
- Parameters:
N (int) – Number of spatial samples.
mode (Union[str, SamplingMode], default=SamplingMode.UNIFORM) – Sampling strategy (‘uniform’ or ‘random’).
include_boundary (bool, default=False) – If True, include spatial boundaries.
to_tensor (bool, default=True) – If True, return PyTorch tensor.
require_grad (bool, default=True) – If True and to_tensor=True, enable gradient computation.
- Returns:
Spatial points with shape (N, 1).
- Return type:
Union[np.ndarray, torch.Tensor]
- sample_space_time_list(x: ndarray, t_list: List[float], to_tensor: bool = True, require_grad: bool = True) -> List[ndarray | Tensor]
Generate spatial snapshots at specific time instances (delegates to geo_time).
- Parameters:
x (np.ndarray) – Spatial points with shape (N, 1).
t_list (List[float]) – List of specific time values.
to_tensor (bool, default=True) – If True, return PyTorch tensors.
require_grad (bool, default=True) – If True and to_tensor=True, enable gradient computation.
- Returns:
List of space-time coordinate arrays.
- Return type:
List[Union[np.ndarray, torch.Tensor]]
- class ai4plasma.piml.geo.GeoPoly2D(points: ndarray)
Bases: Geometry
Two-dimensional polygonal domain for 2D spatial problems.
This class represents an arbitrary polygon defined by its vertices and uses the Shapely library for geometric operations. It supports sampling interior points using rejection sampling and boundary points on each polygon edge.
Polygon Properties
The polygon can be convex or non-convex, but should be:
  - Simple (edges don’t cross each other)
  - Non-degenerate (has non-zero area)
  - Properly oriented (vertices in consistent order)
- points
Array of polygon vertices, shape (n_vertices, 2).
- Type:
np.ndarray
- points_num
Number of vertices.
- Type:
int
- polygon
Shapely polygon object for geometric operations.
- Type:
Polygon
- bound
Bounding box (xmin, ymin, xmax, ymax).
- Type:
tuple
- create_domain(points: ndarray)
Create the 2D polygonal domain from vertices.
- Parameters:
points (np.ndarray) – Array of polygon vertices with shape (n_vertices, 2).
- get_bounding_box(geo: Polygon) -> Tuple[float, float, float, float]
Get the axis-aligned bounding box of a geometry.
- Parameters:
geo (Polygon) – Shapely geometry object.
- Returns:
Tuple of (xmin, ymin, xmax, ymax) defining the bounding box.
- Return type:
Tuple[float, float, float, float]
- sample_boundary(N_list: List[int], mode: str | SamplingMode = SamplingMode.UNIFORM, to_tensor: bool = True, require_grad: bool = True) -> List[ndarray | Tensor]
Sample points on the boundary edges of the polygon.
Generates points along each edge with specified distribution. Each edge is sampled independently.
- Parameters:
N_list (List[int]) – Number of points per edge. Its length must equal the number of polygon vertices (one entry per edge).
mode (Union[str, SamplingMode], default=SamplingMode.UNIFORM) – Sampling strategy (‘uniform’ or ‘random’).
to_tensor (bool, default=True) – If True, return PyTorch tensors.
require_grad (bool, default=True) – If True, enable gradient computation.
- Returns:
List of boundary arrays, one per edge, shape (N_i, 2).
- Return type:
List[Union[np.ndarray, torch.Tensor]]
- Raises:
ValueError – If N_list length doesn’t match polygon edges.
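Per-edge boundary sampling can be pictured as linear interpolation between consecutive vertices, with the last edge closing the polygon. The sketch below shows only the uniform case with NumPy. `polygon_edge_points` is a hypothetical helper and not the library's implementation. Including both edge endpoints, so that corners are shared between adjacent edges, is an assumption.

```python
import numpy as np

def polygon_edge_points(vertices, n_list):
    """Hypothetical helper: evenly spaced points on each polygon edge."""
    vertices = np.asarray(vertices, dtype=float)
    if len(n_list) != len(vertices):
        raise ValueError("n_list must have one entry per polygon edge")
    edges = []
    for i, n in enumerate(n_list):
        start = vertices[i]
        end = vertices[(i + 1) % len(vertices)]  # wrap around to close the polygon
        s = np.linspace(0.0, 1.0, n).reshape(-1, 1)  # interpolation parameter in [0, 1]
        edges.append(start + s * (end - start))      # shape (n, 2)
    return edges

square = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
edge_points = polygon_edge_points(square, [3, 3, 3, 3])
```

Keeping one array per edge, instead of concatenating them, makes it straightforward to attach a different boundary condition to each edge.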
- sample_domain(N: int, mode: str | SamplingMode = '', to_tensor: bool = True, require_grad: bool = True) -> ndarray | Tensor
Sample interior points within the polygonal domain using rejection sampling.
Uses rejection sampling: randomly generates points in the bounding box and accepts only those inside the polygon.
- Parameters:
N (int) – Number of interior points to sample.
mode (Union[str, SamplingMode], default='') – Sampling mode (currently unused, reserved for future).
to_tensor (bool, default=True) – If True, return PyTorch tensor.
require_grad (bool, default=True) – If True and to_tensor=True, enable gradient computation.
- Returns:
Array or tensor of shape (N, 2) containing sampled points.
- Return type:
Union[np.ndarray, torch.Tensor]
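The rejection-sampling idea can be sketched without Shapely. The documentation says the library relies on Shapely for geometric operations; the example below swaps in a plain ray-casting point-in-polygon test so that it only needs NumPy. Both function names are hypothetical.

```python
import numpy as np

def point_in_polygon(px, py, vertices):
    """Ray-casting test: True if (px, py) lies inside the polygon."""
    inside = False
    n = len(vertices)
    for i in range(n):
        x1, y1 = vertices[i]
        x2, y2 = vertices[(i + 1) % n]
        # Only edges that straddle the horizontal line y = py can be crossed.
        if (y1 > py) != (y2 > py):
            x_cross = x1 + (py - y1) * (x2 - x1) / (y2 - y1)
            if px < x_cross:
                inside = not inside  # each crossing to the right flips the state
    return inside

def rejection_sample(vertices, n, seed=None):
    """Draw n points uniformly inside the polygon; returns shape (n, 2)."""
    vertices = np.asarray(vertices, dtype=float)
    xmin, ymin = vertices.min(axis=0)
    xmax, ymax = vertices.max(axis=0)
    rng = np.random.default_rng(seed)
    accepted = []
    while len(accepted) < n:
        px, py = rng.uniform(xmin, xmax), rng.uniform(ymin, ymax)
        if point_in_polygon(px, py, vertices):
            accepted.append((px, py))
    return np.array(accepted)

# L-shaped (non-convex) polygon: the square [1, 2] x [1, 2] is cut out.
l_shape = [(0, 0), (2, 0), (2, 1), (1, 1), (1, 2), (0, 2)]
pts = rejection_sample(l_shape, 200, seed=0)
```

The acceptance rate equals the polygon area divided by the bounding-box area (3/4 for this L-shape), so thin or strongly non-convex polygons need proportionally more candidate draws.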
- class ai4plasma.piml.geo.GeoPoly2DTime(points: ndarray, ts: float, te: float)
Bases: Geometry
Composite 2D space-time domain for time-dependent 2D problems.
This class combines a 2D polygonal spatial domain (GeoPoly2D) and a temporal domain (GeoTime) to handle time-dependent problems in complex 2D geometries.
Applications
2D heat conduction in complex geometries
Fluid flow in irregular domains
Electromagnetic field evolution
Reaction-diffusion systems
- sample_all_domain(Nx: int, Nt: int, Nb_list: List[int], mode: List[str | SamplingMode] = ['uniform', 'uniform', 'uniform'], include_boundary: bool = False, to_tensor: bool = True, require_grad: bool = True) -> Tuple
Sample all domains including interior and boundaries for PINN training.
Generates comprehensive training data for time-dependent 2D PDE problems.
- Parameters:
Nx (int) – Number of spatial sampling points in the interior.
Nt (int) – Number of temporal sampling points.
Nb_list (List[int]) – List specifying number of samples per polygon edge. Must have length equal to number of polygon vertices.
mode (List[Union[str, SamplingMode]], default=['uniform', 'uniform', 'uniform']) – List of sampling modes [space_mode, time_mode, boundary_mode].
include_boundary (bool, default=False) – If True, include temporal boundaries in t samples.
to_tensor (bool, default=True) – If True, return as PyTorch tensors.
require_grad (bool, default=True) – If True and to_tensor=True, enable gradient computation.
- Returns:
Tuple of (xt, (xb, xbt0, xbt1)) where:
  - xt: Interior space-time points, shape (Nx*Nt, 3)
  - xb: List of boundary segments across time
  - xbt0: Initial condition points at t=ts, shape (Nx, 3)
  - xbt1: Final condition points at t=te, shape (Nx, 3)
- Return type:
Tuple[Union[np.ndarray, torch.Tensor], Tuple]
- class ai4plasma.piml.geo.GeoRect2D(xmin: float, xmax: float, ymin: float, ymax: float)
Bases: Geometry
Two-dimensional rectangular domain for 2D spatial problems.
This class represents an axis-aligned rectangle defined by its bounds in x and y directions. It provides efficient sampling methods for both interior (domain) and boundary points, with support for uniform grid sampling and random sampling strategies.
The rectangle is defined by [xmin, xmax] × [ymin, ymax] and supports:
  - Uniform grid sampling for structured collocation points
  - Random sampling for stochastic training approaches
  - Boundary sampling with controllable density on each edge
- xmin
Minimum x-coordinate
- Type:
float
- xmax
Maximum x-coordinate
- Type:
float
- ymin
Minimum y-coordinate
- Type:
float
- ymax
Maximum y-coordinate
- Type:
float
- width
Rectangle width (xmax - xmin)
- Type:
float
- height
Rectangle height (ymax - ymin)
- Type:
float
- area
Rectangle area (width × height)
- Type:
float
- create_domain(xmin: float, xmax: float, ymin: float, ymax: float)
Create the 2D rectangular domain from bounds.
- Parameters:
xmin (float) – Minimum x-coordinate.
xmax (float) – Maximum x-coordinate.
ymin (float) – Minimum y-coordinate.
ymax (float) – Maximum y-coordinate.
- Raises:
ValueError – If bounds are invalid (min >= max in any dimension).
- sample_boundary(N_list: List[int] | int = None, mode: str | SamplingMode = SamplingMode.UNIFORM, to_tensor: bool = True, require_grad: bool = True) -> List[ndarray | Tensor]
Sample points on the four rectangle boundary edges.
The four edges are ordered as: left, right, bottom, top. Each edge is sampled independently with uniform or random strategy.
Edge Definitions
Left (Edge 0): x = xmin, y ∈ [ymin, ymax]
Right (Edge 1): x = xmax, y ∈ [ymin, ymax]
Bottom (Edge 2): y = ymin, x ∈ [xmin, xmax]
Top (Edge 3): y = ymax, x ∈ [xmin, xmax]
- Parameters:
N_list (Union[List[int], int], default=None) – Samples per edge [N_left, N_right, N_bottom, N_top]. An int applies the same N to all 4 edges, a list of 4 ints sets N per edge, and None defaults to 50 per edge.
mode (Union[str, SamplingMode], default=SamplingMode.UNIFORM) – Strategy: ‘uniform’ (evenly spaced) or ‘random’ (uniformly distributed).
to_tensor (bool, default=True) – If True, return PyTorch tensors.
require_grad (bool, default=True) – If True, enable gradient computation.
- Returns:
Four arrays [left, right, bottom, top], each with shape (N_i, 2).
- Return type:
List[Union[np.ndarray, torch.Tensor]]
- Raises:
ValueError – If N_list length ≠ 4 or values ≤ 0.
- sample_domain(N: int | Tuple[int, int] | List[int], mode: str | SamplingMode = SamplingMode.RANDOM, to_tensor: bool = True, require_grad: bool = True) -> ndarray | Tensor
Sample interior points within the rectangular domain.
Supports two strategies:
  1. UNIFORM: Nx × Ny grid
  2. RANDOM: N uniformly random points
- Parameters:
N (Union[int, Tuple[int, int], List[int]]) – Grid resolution or number of samples. An int in uniform mode gives an N×N grid, (Nx, Ny) in uniform mode gives an Nx×Ny grid, and an int in random mode gives N random points.
mode (Union[str, SamplingMode], default=SamplingMode.RANDOM) – Sampling strategy (‘uniform’ or ‘random’).
to_tensor (bool, default=True) – If True, return PyTorch tensor.
require_grad (bool, default=True) – If True, enable gradient computation.
- Returns:
Sampled points with shape (M, 2), where M = Nx*Ny for uniform mode and M = N for random mode.
- Return type:
Union[np.ndarray, torch.Tensor]
- Raises:
ValueError – If N is invalid or mode unsupported.
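The relationship between N, mode, and the output shape can be illustrated with a NumPy-only sketch. `rect_domain` is a hypothetical helper and not the library's code. Placing the uniform grid so that it includes the rectangle's edges is an assumption.

```python
import numpy as np

def rect_domain(xmin, xmax, ymin, ymax, n, mode="random", seed=None):
    """Hypothetical helper: points in [xmin, xmax] x [ymin, ymax], shape (M, 2)."""
    if mode == "uniform":
        # An int means an N x N grid; a pair means an Nx x Ny grid.
        nx, ny = (n, n) if isinstance(n, int) else n
        xs = np.linspace(xmin, xmax, nx)  # assumption: grid includes the edges
        ys = np.linspace(ymin, ymax, ny)
        gx, gy = np.meshgrid(xs, ys, indexing="ij")
        return np.column_stack([gx.ravel(), gy.ravel()])
    if mode == "random":
        if not isinstance(n, int) or n <= 0:
            raise ValueError("random mode expects a positive int")
        rng = np.random.default_rng(seed)
        return rng.uniform([xmin, ymin], [xmax, ymax], size=(n, 2))
    raise ValueError(f"unsupported mode: {mode!r}")

grid = rect_domain(0.0, 2.0, 0.0, 1.0, (3, 4), mode="uniform")  # M = 3 * 4
cloud = rect_domain(0.0, 2.0, 0.0, 1.0, 7, mode="random", seed=0)  # M = 7
```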
- class ai4plasma.piml.geo.GeoTime(ts: float, te: float)
Bases: Geometry
One-dimensional temporal domain for time-dependent problems.
This class represents a time interval [ts, te] and provides methods to sample time points within this interval. It supports both uniform and random sampling, and can optionally include boundary time points.
- ts
Start time of the temporal domain.
- Type:
float
- te
End time of the temporal domain.
- Type:
float
- create_domain(ts: float, te: float)
Create the temporal domain with specified time bounds.
- Parameters:
ts (float) – Start time for the domain.
te (float) – End time for the domain.
- sample_boundary(to_tensor: bool = True, require_grad: bool = True) -> List[ndarray | Tensor]
Sample the temporal boundary points (start and end times).
- Parameters:
to_tensor (bool, default=True) – If True, return PyTorch tensors.
require_grad (bool, default=True) – If True and to_tensor=True, enable gradient computation.
- Returns:
List containing two elements [ts_point, te_point], each with shape (1, 1).
- Return type:
List[Union[np.ndarray, torch.Tensor]]
- sample_domain(N: int, mode: str | SamplingMode = SamplingMode.UNIFORM, include_boundary: bool = False, to_tensor: bool = True, require_grad: bool = True) -> ndarray | Tensor
Sample time points within the temporal domain.
Generates N time points in the interval [ts, te] using the specified sampling strategy. Points are returned in ascending order.
- Parameters:
N (int) – Number of time points to sample.
mode (Union[str, SamplingMode], default=SamplingMode.UNIFORM) – Sampling strategy (‘uniform’ or ‘random’).
include_boundary (bool, default=False) – If True, explicitly include ts and te in samples.
to_tensor (bool, default=True) – If True, return PyTorch tensor.
require_grad (bool, default=True) – If True and to_tensor=True, enable gradient computation.
- Returns:
Array or tensor of shape (N, 1) containing sampled time points.
- Return type:
Union[np.ndarray, torch.Tensor]
- sample_space_time(x: ndarray, t: ndarray, to_tensor: bool = True, require_grad: bool = True) -> ndarray | Tensor
Combine spatial and temporal sampling points into space-time coordinates.
Creates a Cartesian product of spatial points x and temporal points t. The output is organized such that for each time point, all spatial points are listed sequentially.
- Parameters:
x (np.ndarray) – Spatial coordinates with shape (N, d).
t (np.ndarray) – Temporal coordinates with shape (M, 1).
to_tensor (bool, default=True) – If True, return PyTorch tensor.
require_grad (bool, default=True) – If True and to_tensor=True, enable gradient computation.
- Returns:
Space-time coordinates with shape (N*M, d+1).
- Return type:
Union[np.ndarray, torch.Tensor]
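The Cartesian product and its ordering (all spatial points listed for the first time value, then for the second, and so on) can be reproduced in a few lines of NumPy. `space_time_product` is a hypothetical stand-in for illustration and omits the tensor conversion.

```python
import numpy as np

def space_time_product(x, t):
    """Combine x of shape (N, d) and t of shape (M, 1) into shape (N*M, d+1)."""
    n, m = x.shape[0], t.shape[0]
    x_rep = np.tile(x, (m, 1))       # the full spatial block, repeated once per time
    t_rep = np.repeat(t, n, axis=0)  # each time value held for n consecutive rows
    return np.hstack([x_rep, t_rep])

x = np.array([[0.0], [1.0], [2.0]])  # N = 3, d = 1
t = np.array([[10.0], [20.0]])       # M = 2
xt = space_time_product(x, t)
# Rows: (0,10), (1,10), (2,10), (0,20), (1,20), (2,20)
```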
- sample_space_time_boundary(x: ndarray, xb_list: List[ndarray], t: ndarray, tb_list: List[ndarray], to_tensor: bool = True, require_grad: bool = True) -> Tuple
Generate space-time boundary sampling points for PINN boundary conditions.
Creates three types of boundary point collections:
  1. Spatial boundaries across all time points
  2. Initial condition points at t=ts
  3. Final condition points at t=te
- Parameters:
x (np.ndarray) – Interior spatial points with shape (N, d).
xb_list (List[np.ndarray]) – List of spatial boundary arrays, e.g., [x_left, x_right].
t (np.ndarray) – Temporal points with shape (M, 1).
tb_list (List[np.ndarray]) – Boundary times [t_start, t_end].
to_tensor (bool, default=True) – If True, return PyTorch tensors.
require_grad (bool, default=True) – If True and to_tensor=True, enable gradient computation.
- Returns:
Tuple of (xb, xt0, xt1) where xb is list of boundary arrays, xt0 is initial condition points, xt1 is final condition points.
- Return type:
Tuple[List, Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]
- sample_space_time_list(x: ndarray, t_list: List[float], to_tensor: bool = True, require_grad: bool = True) -> List[ndarray | Tensor]
Generate space-time snapshots at specific time instances.
Creates spatial snapshots at discrete time points in t_list. Useful for evaluating solutions at specific observation times.
- Parameters:
x (np.ndarray) – Spatial points with shape (N, d).
t_list (List[float]) – List of specific time values for snapshots.
to_tensor (bool, default=True) – If True, return PyTorch tensors.
require_grad (bool, default=True) – If True and to_tensor=True, enable gradient computation.
- Returns:
List of space-time coordinate arrays, one per time in t_list. Each element has shape (N, d+1).
- Return type:
List[Union[np.ndarray, torch.Tensor]]
- class ai4plasma.piml.geo.Geometry
Bases: object
Abstract base class for geometric domains in PINN problems.
This class defines the interface that all geometry subclasses must implement, including domain creation, interior sampling, and boundary sampling. It provides a consistent API for handling various geometric shapes in physics simulations.
Interface Requirements
Subclasses must implement:
  - create_domain(): Define the geometric domain
  - sample_domain(): Generate sampling points inside the domain
  - sample_boundary(): Generate sampling points on the boundary
- create_domain(*args, **kwargs)
Create and initialize the geometric domain.
This abstract method must be overridden by subclasses.
- Raises:
NotImplementedError – Must be implemented by subclasses.
- sample_boundary(to_tensor: bool = True, require_grad: bool = True, **kwargs)
Sample points on the geometric boundary.
Abstract method. Must be implemented by subclasses.
- Parameters:
to_tensor (bool, default=True) – If True, return PyTorch tensor.
require_grad (bool, default=True) – If True, enable gradient computation.
**kwargs (dict) – Additional parameters.
- Returns:
Boundary points.
- Return type:
Union[torch.Tensor, np.ndarray, list]
- Raises:
NotImplementedError – Must be implemented by subclasses.
- sample_domain(N: int, mode: str | SamplingMode = SamplingMode.UNIFORM, to_tensor: bool = True, require_grad: bool = True, **kwargs)
Sample points inside the geometric domain.
Abstract method. Must be implemented by subclasses.
- Parameters:
N (int) – Number of sampling points.
mode (Union[str, SamplingMode], default=SamplingMode.UNIFORM) – Sampling strategy.
to_tensor (bool, default=True) – If True, return PyTorch tensor.
require_grad (bool, default=True) – If True, enable gradient computation.
**kwargs (dict) – Additional parameters.
- Returns:
Sampled points.
- Return type:
Union[torch.Tensor, np.ndarray]
- Raises:
NotImplementedError – Must be implemented by subclasses.
- class ai4plasma.piml.geo.SamplingMode(*values)
Bases: Enum
Enumeration for sampling strategies to avoid magic strings.
- UNIFORM
Evenly spaced sampling with uniform grid spacing.
- Type:
str
- RANDOM
Random sampling with uniform distribution.
- Type:
str
- LHS
Latin Hypercube Sampling (reserved for future implementation).
- Type:
str
- LHS = 'lhs'
- RANDOM = 'random'
- UNIFORM = 'uniform'
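Because the sampling methods accept either a string or a SamplingMode member, some normalization step is presumably applied to the mode argument. The sketch below shows one way this could work with the standard enum module. `SamplingModeSketch` mirrors the three documented values but is a local stand-in, and `normalize_mode` is a hypothetical helper. Neither is taken from the library source.

```python
from enum import Enum

class SamplingModeSketch(Enum):
    """Local stand-in mirroring the documented member values."""
    UNIFORM = "uniform"
    RANDOM = "random"
    LHS = "lhs"

def normalize_mode(mode):
    """Hypothetical helper: accept an enum member or its string value."""
    if isinstance(mode, SamplingModeSketch):
        return mode
    # Looking up by value raises ValueError for unknown strings.
    return SamplingModeSketch(str(mode).lower())

mode_a = normalize_mode("uniform")
mode_b = normalize_mode(SamplingModeSketch.RANDOM)
```

Comparing against enum members after normalization keeps typos such as 'unifrom' from silently selecting a default branch.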
2.2. PINN
Advanced Physics-Informed Neural Networks (PINNs) for solving PDEs and multi-physics problems.
This module implements a flexible and extensible framework for training Physics-Informed Neural Networks (PINNs) to solve complex partial differential equations (PDEs) and coupled multi-physics problems. PINNs embed physics knowledge directly into neural networks through residual-based loss functions, enabling accurate solutions without requiring large amounts of labeled training data.
2.2.1. PINN Classes
EquationTerm: Encapsulates a single physics constraint with residual function.
VisualizationCallback: Abstract base for custom visualization during training.
PINN: Main physics-informed neural network model class.
- class ai4plasma.piml.pinn.EquationTerm(name: str, residual_fn: Callable, weight: float = 1.0, data: Tensor = None)
Bases: object
Represents a single physics equation term with its residual function and weight.
This class encapsulates one component of the physics loss function in a PINN model. For example, a PDE problem might have separate equation terms for:
  - Interior domain residual (PDE satisfaction)
  - Boundary condition residuals
  - Initial condition residuals
  - Additional constraints or regularization terms
Each term has:
  - A residual function that evaluates how well the network satisfies the equation
  - A weight factor that controls its contribution to the total loss
  - Associated data points where the residual is evaluated
  - Optional DataLoader for batch training on large datasets
Attributes:
- name : str
Unique identifier for this equation term (e.g., ‘domain’, ‘boundary_left’, ‘initial_condition’). Used for logging and loss tracking.
- residual_fn : Callable
Function with signature residual_fn(network, data) -> torch.Tensor. Computes the residual (error) at the given data points. Should return a tensor whose values are close to zero when the physics is satisfied.
- weight : float
Multiplicative weight for this loss term. Higher weights make the optimizer prioritize satisfying this equation. Typical range: 0.1 to 100.
- data : torch.Tensor
Input data points where the residual is evaluated. Shape depends on problem dimensionality: (N, d) for d-dimensional problems with N points.
- dataloader : DataLoader or None
PyTorch DataLoader for batch training. Created via create_dataloader() when batch training is enabled. None for full-batch training.
- compute_residual(network, batch_data: Tensor = None)
Compute residual using the neural network.
This method evaluates the residual function at the specified data points. The residual represents how well the network solution satisfies the physics equation at those points. Ideally, residuals should be close to zero.
Parameters:
- network : nn.Module
The neural network model (PINN solution).
- batch_data : torch.Tensor, optional
Batch of data points for residual computation. If None, uses self.data. This allows for batch training where different batches are used in different iterations.
Returns:
- torch.Tensor
Residual values at the evaluation points. Shape depends on the residual function but is typically (N,) or (N, output_dim) for N points.
- create_dataloader(batch_size: int, shuffle: bool = False, drop_last: bool = False)
Create a PyTorch DataLoader for batched training on this equation term.
For large datasets, batch training is more memory-efficient and can lead to better generalization. This method wraps the data tensor in a DataLoader that provides automatic batching and optional shuffling.
Parameters:
- batch_size : int
Number of samples per batch. Smaller batches use less memory but may be noisier. Typical values: 32, 64, 128, 256, 512.
- shuffle : bool, optional
Whether to shuffle the data at each epoch. Default: False. Shuffling can improve convergence but changes the order of samples.
- drop_last : bool, optional
Whether to drop the last incomplete batch if the dataset size is not divisible by batch_size. Default: False.
Returns:
- DataLoader or None
PyTorch DataLoader for iterating over batches, or None if no data is available.
- get_dataloader()
Retrieve the current DataLoader for this equation term.
Returns:
- DataLoader or None
The previously created DataLoader, or None if create_dataloader() has not been called yet or if the data was updated.
- update_data(new_data: Tensor)
Update the data points for this equation term.
Useful for:
  - Adaptive sampling (resampling points in regions with high error)
  - Time-dependent problems (updating temporal points)
  - Progressive training (starting with coarse then fine grids)
Parameters:
- new_data : torch.Tensor
New input data tensor.
- update_weight(new_weight: float)
Update the weight of this equation term.
This allows dynamic adjustment of loss weights during training, which can be useful for:
  - Curriculum learning (gradually emphasizing different terms)
  - Adaptive weighting based on loss magnitudes
  - Manual tuning during training
Parameters:
- new_weight : float
New weight value for this term.
- class ai4plasma.piml.pinn.PINN(network)
Bases: BaseModel, ABC
Advanced Physics-Informed Neural Network (PINN) base class for solving PDEs and multi-physics problems.
This class implements a flexible framework for training neural networks to solve partial differential equations (PDEs) by incorporating physics constraints directly into the loss function. Unlike traditional supervised learning, PINNs learn from:
  1. Governing equations (PDE residuals in the domain)
  2. Boundary conditions (BCs)
  3. Initial conditions (ICs) for time-dependent problems
  4. Optional observational data
Attributes:
- network : nn.Module
Neural network that approximates the solution u(x,t,…)
- equation_terms : Dict[str, EquationTerm]
Dictionary of physics equation terms indexed by name
- writer : SummaryWriter or None
TensorBoard writer for logging (None if not configured)
- optimizer : torch.optim.Optimizer or None
Optimizer for training
- loss_func : nn.Module
Loss function (typically MSELoss for PINNs)
- start_epoch : int
Starting epoch for training (non-zero if resumed)
- training_history : Dict
Historical record of losses and epochs
- adaptive_weights : bool
Whether to use adaptive loss weighting
- weight_update_freq : int
Frequency of weight updates (if adaptive)
- visualization_callbacks : Dict[str, VisualizationCallback]
Registered visualization callbacks
- add_equation(name: str, residual_fn: Callable, weight: float = 1.0, data: Tensor = None)
Add a physics equation term to the model for loss calculation during training.
Registers a new equation term with the PINN model. Each equation represents one component of the multi-objective loss function being minimized during training. Typical equations include domain PDEs, boundary conditions, initial conditions, and data constraints.
Parameters:
- name : str
Unique identifier for this equation term. Used for:
  - Loss tracking and logging
  - Weight management and adjustment
  - Identifying equations in get_equation_info()
  - Accessing individual loss contributions
Should be descriptive (e.g., ‘domain_pde’, ‘bc_left’, ‘initial’, ‘data_fit’) to make training logs readable.
- residual_fn : Callable
Function that computes the residual (error) at given points.
Signature: residual_fn(network: nn.Module, data: torch.Tensor) -> torch.Tensor
Arguments of residual_fn:
  - network: The neural network (nn.Module) being trained
  - data: Input tensor of evaluation points, shape (N, d)
Return value of residual_fn:
  - Residual tensor, typically shape (N,) or (N, output_dim)
  - Residual should be ~0 at points satisfying the equation
  - Must maintain the computational graph for backpropagation
Implementation tips:
  - Use torch.autograd.grad() to compute derivatives
  - Set create_graph=True to enable second derivatives
  - All operations should be differentiable
  - Return the residual (not the loss)
- weight : float, optional
Loss weight for this equation term in the total loss. Default: 1.0
Interpretation:
  - weight = 1.0: default/unit contribution
  - weight > 1.0: emphasize this constraint
  - weight << 1.0: de-emphasize relative to other terms
  - weight = 0.0: effectively disabled
Typical values:
  - Domain PDE: 1.0-5.0
  - Essential BC: 10.0-100.0
  - Natural BC: 1.0-10.0
  - Initial conditions: 5.0-10.0
  - Data approximation: 0.01-1.0
Can be adjusted later via set_equation_weight() or enable_adaptive_weights().
- data : torch.Tensor, optional
Input data tensor where the equation is evaluated. Default: None
Shape: (N, d) where:
  - N: number of evaluation points
  - d: input dimensionality (problem dependent)
Examples:
  - 1D domain: shape (1000, 1)
  - 2D domain: shape (10000, 2)
  - 2D + time: shape (5000, 3) for (x, y, t)
  - Boundary on 2D: shape (100, 2)
Can be None initially and set later via set_equation_data(). For batched training, data is used to create DataLoaders automatically.
Returns:
- None
The equation is registered internally and accessible via get_equation(name).
Raises:
- None
No explicit error checking is performed at registration. Errors surface later, when the term is evaluated, if:
  - residual_fn is not callable
  - data is incompatible with the network input dimension
If name already exists, the existing term is overwritten silently.
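A residual function following the implementation tips listed for residual_fn might look like the sketch below. It uses the 1D heat equation du/dt = alpha * d2u/dx2 as the example. The function name `heat_residual`, the (x, t) column order, the value of alpha, and the small test network are all assumptions made for illustration. Only torch.autograd.grad and standard torch.nn modules are used.

```python
import torch
import torch.nn as nn

def heat_residual(network, data, alpha=0.1):
    """Residual u_t - alpha * u_xx; data columns are assumed to be (x, t)."""
    # Work on a fresh leaf tensor so derivatives w.r.t. the inputs are available.
    xt = data.detach().clone().requires_grad_(True)
    u = network(xt)  # shape (N, 1)
    # First derivatives; create_graph=True keeps the graph for second derivatives.
    grads = torch.autograd.grad(u, xt, grad_outputs=torch.ones_like(u),
                                create_graph=True)[0]
    u_x, u_t = grads[:, 0:1], grads[:, 1:2]
    u_xx = torch.autograd.grad(u_x, xt, grad_outputs=torch.ones_like(u_x),
                               create_graph=True)[0][:, 0:1]
    return u_t - alpha * u_xx  # the residual itself, not a loss

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 1))
points = torch.rand(8, 2)
residual = heat_residual(net, points)  # shape (8, 1)
```

A function with this signature could then be passed as the residual_fn argument of add_equation, for example under the name ‘domain_pde’. A smooth activation such as Tanh matters here: with ReLU the second derivative u_xx would be zero almost everywhere.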
- calc_loss(weights_override: Dict[str, float] = None, batch_data: Dict[str, Tensor] = None) Tuple[Tensor, Dict[str, Tensor]][source]
Calculate total loss as weighted sum of individual physics equations.
This is the core loss computation function for PINN training. It evaluates all registered equation residuals, computes individual losses, and returns both the total weighted loss and a breakdown by equation.
- The total loss is computed as:
L_total = Σᵢ wᵢ * L(residualᵢ)
where wᵢ is the weight for equation i, and L() is the loss function (typically MSE).
2. Parameters:
- weights_overrideDict[str, float], optional
Temporary weight overrides for specific equation terms. Default: None - Used to temporarily change weights without modifying the model state - Only affects this loss calculation; permanent weights unchanged - Example: {‘domain’: 0.5, ‘boundary’: 10.0} - Terms not in dictionary use their current model weights
- batch_dataDict[str, torch.Tensor], optional
Batch data for each equation term (for batched training). Default: None - Maps equation names to their batch data tensors - Used when batch processing large datasets - If None, uses full dataset for each equation term - Example: {‘domain’: batch_tensor, ‘boundary’: batch_tensor}
2. Returns:
- total_losstorch.Tensor
Scalar tensor representing weighted sum of all losses. This is what gets backpropagated during training.
- loss_dictDict[str, torch.Tensor]
Dictionary mapping equation names to their individual loss values. Each value is a scalar tensor (unweighted individual loss). Example: {‘domain’: 0.01, ‘boundary’: 0.05, ‘initial’: 0.02}
2. Raises:
- RuntimeError
If no equations have been defined (empty equation_terms dict). Message: ‘No equations defined. Implement _define_loss_terms() in subclass.’
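The weighted sum and the effect of weights_override can be illustrated with a minimal sketch. It is not the library's implementation: plain Python lists stand in for residual tensors, MSE is assumed as the loss function, and the helper names mse and weighted_total_loss are hypothetical.

```python
from typing import Dict, List, Optional, Tuple


def mse(residual: List[float]) -> float:
    """Mean squared error of a residual vector against zero."""
    return sum(r * r for r in residual) / len(residual)


def weighted_total_loss(
    residuals: Dict[str, List[float]],
    weights: Dict[str, float],
    weights_override: Optional[Dict[str, float]] = None,
) -> Tuple[float, Dict[str, float]]:
    """Return (total weighted loss, unweighted per-term losses)."""
    if not residuals:
        raise RuntimeError("No equations defined.")
    # Overrides are merged into a copy, so the stored weights stay unchanged.
    effective = {**weights, **(weights_override or {})}
    loss_dict = {name: mse(res) for name, res in residuals.items()}
    total = sum(effective[name] * loss for name, loss in loss_dict.items())
    return total, loss_dict


residuals = {"domain": [1.0, -1.0], "boundary": [2.0]}
weights = {"domain": 1.0, "boundary": 1.0}
total, loss_dict = weighted_total_loss(residuals, weights)
total_bc, _ = weighted_total_loss(residuals, weights, {"boundary": 10.0})
```

Here total is 1.0 * 1.0 + 1.0 * 4.0 = 5.0, while the override raises the boundary contribution to 40.0 for that single call and leaves weights untouched.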
- compute_residual(name: str, input_data: Tensor = None) Tensor[source]
Compute residual for a specific equation term.
Evaluates how well the network satisfies a particular physics equation at specified points. Residual values represent the error in the PDE or constraint at those points. Ideally, residuals should be close to zero for a well-trained PINN.
2. Parameters:
- name : str
Name of the equation term whose residual to compute. Must be a registered equation name (from add_equation).
- input_data : torch.Tensor, optional
Input points where residual is evaluated. Default: None - If None: uses the stored data for this equation term - If provided: temporarily uses this data (does not modify model) - Should have same dimensionality as training data
2. Returns:
- torch.Tensor
Residual values at the evaluation points. Shape depends on the residual_fn implementation, typically: - (N,) or (N, 1) for scalar output - (N, m) for vector output with m components
2. Raises:
- ValueError
If the equation term name is not found in the model. Message: ‘Equation term “{name}” not found’
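As an illustration of what a residual function typically looks like, the sketch below evaluates the residual of the toy equation du/dx - cos(x) = 0 with automatic differentiation. The (network, x) call signature and the name derivative_residual are assumptions made for this example; the signature expected by add_equation() may differ.

```python
import torch


def derivative_residual(network, x: torch.Tensor) -> torch.Tensor:
    """Residual of du/dx - cos(x) = 0 at the points x, shape (N, 1)."""
    # Work on a detached copy so the caller's tensor is not modified.
    x = x.detach().clone().requires_grad_(True)
    u = network(x)
    du_dx = torch.autograd.grad(
        u, x, grad_outputs=torch.ones_like(u), create_graph=True
    )[0]
    return du_dx - torch.cos(x)


# torch.sin stands in for a perfectly trained network: it solves the ODE exactly.
x = torch.linspace(0.0, 1.0, 5).reshape(-1, 1)
residual = derivative_residual(torch.sin, x)
```

Because sin(x) satisfies the equation exactly, the residual is zero at every point and has shape (N, 1), matching the scalar-output case described above.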
- create_lr_scheduler(scheduler_name: str, **kwargs)[source]
Create and set a learning rate scheduler for the current optimizer.
This method creates a learning rate scheduler that will be used during training. It requires an optimizer to be already set (via create_optimizer() or train()).
2. Parameters:
- scheduler_name : str
Name of the scheduler class from torch.optim.lr_scheduler Common options: - ‘StepLR’: Decay LR by gamma every step_size epochs - ‘ExponentialLR’: Decay LR exponentially with gamma each epoch - ‘CosineAnnealingLR’: Annealing with cosine function - ‘ReduceLROnPlateau’: Reduce LR when metric plateaus - ‘CyclicLR’: Cyclically vary learning rate - ‘LambdaLR’: Apply custom function to LR
- kwargs : dict
Parameters specific to the scheduler Examples: - ‘StepLR’: step_size (int), gamma (float, default=0.1) - ‘ExponentialLR’: gamma (float) - ‘CosineAnnealingLR’: T_max (int), eta_min (float, default=0) - ‘ReduceLROnPlateau’: mode (‘min’/’max’), factor, patience, threshold, etc.
2. Returns:
- torch.optim.lr_scheduler._LRScheduler
The created scheduler instance
2. Raises:
- RuntimeError
If no optimizer has been set yet
- ValueError
If the specified scheduler name is not found
- create_optimizer(optimizer_name: str = 'Adam', lr: float = 0.0001, **kwargs)[source]
Create and set an optimizer for the network.
This method allows setting up an optimizer before calling train(), which enables users to create a learning rate scheduler using the optimizer before training starts.
2. Parameters:
- optimizer_name : str, optional
Name of the optimizer class from torch.optim. Default: ‘Adam’ Common options: ‘Adam’, ‘SGD’, ‘LBFGS’, ‘RMSprop’, ‘AdamW’, ‘Adamax’
- lr : float, optional
Learning rate. Default: 1e-4
- kwargs : dict
Additional parameters specific to the optimizer Examples: - For ‘Adam’: weight_decay, betas=(0.9, 0.999), eps, amsgrad - For ‘SGD’: momentum, weight_decay, nesterov - For ‘LBFGS’: max_iter, max_eval, line_search_fn
2. Returns:
- torch.optim.Optimizer
The created optimizer instance (also stored in self.optimizer)
2. Raises:
- ValueError
If the specified optimizer name is not found in torch.optim
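Name-based creation of optimizers and schedulers can be sketched with plain PyTorch as follows. This is one plausible way to resolve a class name, not necessarily how create_optimizer() and create_lr_scheduler() are implemented; build_optimizer and build_scheduler are hypothetical helpers.

```python
import torch


def build_optimizer(params, optimizer_name: str = "Adam", lr: float = 1e-4, **kwargs):
    """Look up an optimizer class by name in torch.optim and instantiate it."""
    optimizer_cls = getattr(torch.optim, optimizer_name, None)
    is_optimizer = isinstance(optimizer_cls, type) and issubclass(
        optimizer_cls, torch.optim.Optimizer
    )
    if not is_optimizer:
        raise ValueError(f"Optimizer '{optimizer_name}' not found in torch.optim")
    return optimizer_cls(params, lr=lr, **kwargs)


def build_scheduler(optimizer, scheduler_name: str, **kwargs):
    """Look up a scheduler class by name in torch.optim.lr_scheduler."""
    if optimizer is None:
        raise RuntimeError("Create an optimizer before creating a scheduler")
    scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_name, None)
    if not isinstance(scheduler_cls, type):
        raise ValueError(f"Scheduler '{scheduler_name}' not found")
    return scheduler_cls(optimizer, **kwargs)


net = torch.nn.Linear(1, 1)
optimizer = build_optimizer(net.parameters(), "Adam", lr=1e-3)
scheduler = build_scheduler(optimizer, "StepLR", step_size=2, gamma=0.5)

for _ in range(2):
    optimizer.step()   # no gradients yet, so parameters are left unchanged
    scheduler.step()
current_lr = optimizer.param_groups[0]["lr"]  # halved once after two steps
```

The scheduler needs the optimizer at construction time, which is why the optimizer has to exist before a scheduler can be created.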
- enable_adaptive_weights(enable: bool = True, update_freq: int = 10)[source]
Enable or disable adaptive weight adjustment during training.
Adaptive weighting automatically balances different physics equations by adjusting their loss weights based on current loss magnitudes. This helps prevent one term from dominating the total loss and ensures all constraints are satisfied. Particularly useful for: - Multi-physics problems with equations of different scales - Problems where loss magnitudes vary significantly between terms - Avoiding manual tuning of loss weights - Curriculum learning where priorities change during training
- The adaptive weight for each term is computed as:
new_weight = current_weight * (average_loss / current_term_loss)
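This inverse scaling lowers the weight of a term whose loss is above the average and raises the weight of a term whose loss is below it, so that the weighted contributions are pulled toward a common magnitude and no single term dominates the total loss.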
2. Parameters:
- enable : bool, optional
Whether to enable adaptive weighting. Default: True - True: Enable adaptive weight adjustments - False: Disable adaptive weighting (use fixed weights)
- update_freq : int, optional
Frequency of weight updates in epochs. Default: 10 - The weights are updated every N epochs - Larger values = less frequent updates (more stable training) - Smaller values = more responsive to loss changes
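The update rule stated above can be worked through on a small example. The sketch applies the formula literally in plain Python; rebalance_weights and the eps guard against a vanishing loss are assumptions of this example, not part of the documented API.

```python
from typing import Dict


def rebalance_weights(
    weights: Dict[str, float], losses: Dict[str, float], eps: float = 1e-12
) -> Dict[str, float]:
    """Apply new_weight = current_weight * (average_loss / term_loss)."""
    average_loss = sum(losses.values()) / len(losses)
    # eps is a hypothetical guard against division by a vanishing loss.
    return {
        name: weights[name] * average_loss / max(losses[name], eps)
        for name in weights
    }


weights = {"domain": 1.0, "boundary": 1.0}
losses = {"domain": 1.0, "boundary": 4.0}
new_weights = rebalance_weights(weights, losses)
# Weighted contribution of each term after the update.
weighted = {name: new_weights[name] * losses[name] for name in losses}
```

With an average loss of 2.5, the weights become 2.5 for 'domain' and 0.625 for 'boundary'. Starting from equal weights, both terms then contribute 2.5 to the total loss.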
- get_equation(name: str) EquationTerm | None[source]
Retrieve an equation term by its unique name.
This method provides access to a specific equation term object, useful for: - Inspecting equation properties and current weights - Modifying equation data or functions programmatically - Computing residuals for analysis - Debugging equation configuration
2. Parameters:
- name : str
Unique identifier of the equation term to retrieve. Should match the name used in add_equation() call.
2. Returns:
- EquationTerm or None
The requested EquationTerm object if it exists, None otherwise. The EquationTerm contains: - name: identifier string - residual_fn: function for computing residuals - weight: current weight in the loss function - data: input data points - dataloader: optional batched data loader
- get_equation_info() Dict[source]
Get comprehensive information about all defined equations.
This method returns a summary of all registered equation terms, including their weights and associated data shapes. Useful for: - Inspecting model configuration before training - Debugging weight imbalances - Verifying data shapes and problem dimensions - Monitoring multi-physics problem setup
2. Returns:
- Dict
Dictionary with equation names as keys and info dicts as values.
- get_training_history() Dict[source]
Retrieve the complete training history (losses and epochs).
Returns a copy of the internal training history, useful for: - Analyzing convergence behavior after training - Creating custom plots of loss over time - Comparing different training runs - Detecting training anomalies or divergence - Implementing early stopping or custom training logic
2. Returns:
- Dict
Copy of training history with the following keys: - ‘loss’: List[float] - Total weighted loss values, one per epoch - ‘epoch’: List[int] - Epoch numbers (typically 1, 2, 3, …)
- static plot_1d_comparison(x_data: ndarray, y_pred: ndarray, y_true: ndarray = None, y_ref: ndarray = None, title: str = '1D Comparison', xlabel: str = 'x', ylabel: str = 'y') Figure[source]
Create a 1D comparison plot of predictions vs ground truth/reference.
Generates a line plot comparing predicted and reference solutions along a 1D domain. Useful for visualizing solution accuracy in 1D problems.
2. Parameters:
- x_data : np.ndarray
x-coordinates (domain points). Shape: (N,)
- y_pred : np.ndarray
Predicted solution values. Shape: (N,)
- y_true : np.ndarray, optional
Ground truth/analytical solution. Default: None. If provided, it is plotted as a dashed red line.
- y_ref : np.ndarray, optional
Reference/alternative solution. Default: None. If provided, it is plotted as a dash-dot green line.
- title : str, optional
Plot title. Default: ‘1D Comparison’
- xlabel : str, optional
x-axis label. Default: ‘x’
- ylabel : str, optional
y-axis label. Default: ‘y’
2. Returns:
- plt.Figure
Matplotlib figure object that can be displayed, saved, or logged.
- static plot_2d_comparison(data_pred: ndarray, data_true: ndarray = None, title_pred: str = 'Prediction', title_true: str = 'Ground Truth', cbar_label: str = 'value', figsize=(16, 6)) Figure[source]
Create side-by-side 2D comparison plots (predicted vs ground truth).
Generates a figure with 1 or 2 subplots to compare 2D field predictions with analytical or numerical solutions side-by-side for easy comparison.
2. Parameters:
- data_pred : np.ndarray
Predicted 2D field data. Shape: (M, N)
- data_true : np.ndarray, optional
Ground truth/analytical 2D field data. Default: None. If provided, a 2-panel figure is created; if None, a 1-panel figure shows only the prediction.
- title_pred : str, optional
Title for prediction subplot. Default: ‘Prediction’
- title_true : str, optional
Title for ground truth subplot. Default: ‘Ground Truth’
- cbar_label : str, optional
Colorbar label for both subplots. Default: ‘value’
- figsize : tuple, optional
Figure size (width, height) in inches. Default: (16, 6). Typical values are (16, 6) for a 2-panel figure and (8, 6) for a 1-panel figure.
2. Returns:
- plt.Figure
Matplotlib figure with 1 or 2 subplots and colorbars.
- static plot_2d_heatmap(data: ndarray, title: str = '2D Heatmap', xlabel: str = 'x', ylabel: str = 'y', cbar_label: str = 'value') Figure[source]
Create a 2D heatmap (contour plot) visualization.
Generates a 2D false-color image showing spatial field values. Commonly used for visualizing: - Solution fields in 2D spatial domains - Error distributions over 2D regions - Residual magnitude maps
2. Parameters:
- data : np.ndarray
2D array of values to visualize. Shape: (M, N). Each element represents the field value at that grid point.
- title : str, optional
Plot title. Default: ‘2D Heatmap’
- xlabel : str, optional
x-axis label. Default: ‘x’
- ylabel : str, optional
y-axis label. Default: ‘y’
- cbar_label : str, optional
Colorbar label. Default: ‘value’
2. Returns:
- plt.Figure
Matplotlib figure object with colorbar for value scale.
- static plot_error_heatmap(data_pred: ndarray, data_true: ndarray, title: str = 'Absolute Error', cbar_label: str = 'error') Figure[source]
Create an error heatmap showing absolute difference between predictions and truth.
Visualizes the spatial distribution of prediction errors using a false-color heatmap.
2. Parameters:
- data_pred : np.ndarray
Predicted 2D field data. Shape: (M, N)
- data_true : np.ndarray
Ground truth 2D field data. Shape: (M, N). Must have the same shape as data_pred.
- title : str, optional
Plot title. Default: ‘Absolute Error’
- cbar_label : str, optional
Colorbar label. Default: ‘error’
2. Returns:
- plt.Figure
Matplotlib figure with error heatmap and colorbar.
- static plot_residual_distribution(residuals: ndarray, title: str = 'Residual Distribution', xlabel: str = 'Residual', ylabel: str = 'Frequency') Figure[source]
Create a histogram visualization of residual distribution.
2. Parameters:
- residuals : np.ndarray
Array of residual values. Shape: (N,) or any other shape (the array is flattened internally). Values represent how well each equation is satisfied at the evaluation points.
- title : str, optional
Plot title. Default: ‘Residual Distribution’
- xlabel : str, optional
x-axis label. Default: ‘Residual’
- ylabel : str, optional
y-axis label. Default: ‘Frequency’
2. Returns:
- plt.Figure
Matplotlib figure with histogram and statistical annotations.
- predict(input_data: Tensor) Tensor[source]
Make predictions using the trained network.
Evaluates the neural network at given input points to generate predictions. The network operates in evaluation mode (dropout disabled, batch normalization using its running statistics) and without gradient computation for efficiency.
2. Parameters:
- input_data : torch.Tensor
Input data for prediction. Shape depends on problem dimensionality: - 1D problem: (N, 1) where N is number of points - 2D problem: (N, 2) for spatial coordinates - 2D+time: (N, 3) for (x, y, t) coordinates
Input should be the same dimensionality as training data.
2. Returns:
- torch.Tensor
Network output (predictions) at input points. Shape typically (N, output_dim) where output_dim depends on problem.
Examples: - For single scalar field: shape (N, 1) - For vector field: shape (N, vector_dim) - Tensor is returned on the same device as the network
- register_visualization_callback(callback: VisualizationCallback)[source]
Register a visualization callback to be executed during training.
Visualization callbacks enable real-time plotting and analysis during training without modifying the core training loop. Useful for: - Monitoring solution evolution over training - Comparing predicted vs analytical solutions - Visualizing residuals and error distributions - Creating animations of training progress - Tracking problem-specific metrics
The callback’s visualize() method is called periodically (every log_freq epochs) and the returned figures are automatically logged to TensorBoard.
2. Parameters:
- callback : VisualizationCallback
Visualization callback instance. Must be a subclass of VisualizationCallback and implement the visualize() method. The callback should define: - name: unique identifier for the callback - log_freq: execution frequency (every N epochs) - visualize(): method that creates and returns matplotlib figures
- remove_equation(name: str)[source]
Remove an equation term from the model by name.
This method allows dynamic removal of physics equations that were previously added to the PINN model. Useful for removing constraints during different training phases or switching between problem configurations.
2. Parameters:
- name : str
Unique name of the equation term to remove. Must be a name that was previously added via add_equation().
- set_all_equation_weights(weights: Dict[str, float])[source]
Set weights for multiple equation terms at once.
This convenience method allows updating all equation weights in a single call, rather than calling set_equation_weight() multiple times. Useful for: - Switching between different loss configurations - Curriculum learning schedules (gradually changing all weights) - Rebalancing the loss function for different training phases - Implementing weight scheduling algorithms
2. Parameters:
- weights : Dict[str, float]
Dictionary mapping equation term names to their new weight values. Format: {‘equation_name’: weight_value, …} Example: {‘domain’: 1.0, ‘boundary’: 20.0, ‘initial’: 5.0}
2. Raises:
- ValueError
If any equation name in the dictionary is not found in the model. The error is raised by the internal set_equation_weight() call.
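A curriculum-style schedule, one of the uses listed above, might be sketched as a linear blend between two weight dictionaries. The helper interpolate_weights is hypothetical and written in plain Python; the resulting dictionary has the format expected by set_all_equation_weights().

```python
from typing import Dict


def interpolate_weights(
    start: Dict[str, float], end: Dict[str, float], epoch: int, num_epochs: int
) -> Dict[str, float]:
    """Linearly blend two weight dictionaries as training progresses."""
    frac = min(max(epoch / num_epochs, 0.0), 1.0)  # clamp progress to [0, 1]
    return {name: (1.0 - frac) * start[name] + frac * end[name] for name in start}


start = {"domain": 1.0, "boundary": 20.0}
end = {"domain": 5.0, "boundary": 10.0}
halfway = interpolate_weights(start, end, epoch=500, num_epochs=1000)

# In a phased training run one might then call, once per phase:
# pinn.set_all_equation_weights(halfway)
```

Halfway through the schedule the weights are 3.0 for 'domain' and 15.0 for 'boundary'.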
- set_equation_data(name: str, data: Tensor)[source]
Update the evaluation points (data) for a specific equation term.
This method allows changing which points are used to evaluate a particular equation’s residual. Essential for: - Adaptive mesh refinement / point resampling in high-error regions - Progressive training (coarse to fine grids) - Time-stepping problems (updating temporal points) - Importance sampling (focusing on difficult regions) - Dynamically adding new training data during training
2. Parameters:
- name : str
Unique name of the equation term whose data to update. Must be a name that was previously added via add_equation().
- data : torch.Tensor
New input data tensor for this equation term. Shape should be (N, d) where N is number of points and d is dimensionality. Example shapes: - 1D problem: (1000, 1) - 2D problem: (10000, 2) - 2D+time problem: (5000, 3) for (x, y, t) coordinates
2. Raises:
- ValueError
If the equation name is not found in the model. Message: ‘Equation term “{name}” not found’
- set_equation_weight(name: str, weight: float)[source]
Update the weight (loss contribution) of a specific equation term.
The weight controls how much this particular equation contributes to the total loss function during training. This is critical for: - Balancing multi-physics problems with competing objectives - Emphasizing important constraints (e.g., boundary conditions) - Curriculum learning (gradually changing weights during training) - Handling multi-scale problems with disparate magnitudes
2. Parameters:
- name : str
Unique name of the equation term whose weight to update. Must be a name that was previously added via add_equation().
- weight : float
New weight value for this equation term. Interpretation: - weight > 0: emphasizes this equation in the total loss - weight = 1.0: default/baseline influence - weight >> 1.0: strongly enforces the constraint (typical: 5-100) - weight << 1.0: reduces constraint importance (typical: 0.01-0.1) - weight = 0: effectively disables the equation
2. Raises:
- ValueError
If the equation name is not found in the model. Message: ‘Equation term “{name}” not found’
- train(num_epochs, optimizer=None, optimizer_cfg: Dict = None, lr: float = 0.0001, lr_scheduler=None, weights_override: Dict[str, float] = None, print_loss: bool = True, print_loss_freq: int = 1, tensorboard_logdir: str = None, save_final_model: bool = False, final_model_path: str = None, checkpoint_dir: str = None, checkpoint_freq: int = 1, resume_from: str = None, batch_size: int | None = None, shuffle_batches: bool = False, visualization_kwargs: Dict = None)[source]
Train the PINN model with advanced options and optional batch processing.
2. Parameters:
num_epochs (int): number of epochs to train
optimizer (torch.optim.Optimizer): pre-built optimizer instance (optional)
optimizer_cfg (dict): configuration used to build the optimizer. Example: {'name': 'Adam', 'params': {'lr': 1e-4, 'weight_decay': 0}}
lr (float): default learning rate if optimizer is not specified (default: 1e-4)
lr_scheduler: optional pre-built learning rate scheduler instance. It can also be created via create_lr_scheduler() before calling train().
weights_override (dict): temporarily override loss weights for specific terms
print_loss (bool): whether to print loss each epoch
print_loss_freq (int): print loss every N epochs
tensorboard_logdir (str): path for tensorboard logs
save_final_model (bool): whether to save final model
final_model_path (str): path to save final model
checkpoint_dir (str): directory for epoch checkpoints
checkpoint_freq (int): save checkpoint every N epochs
resume_from (str): path to checkpoint to resume from
batch_size (int): batch size for training data loading. If None (default), loads all data at once.
shuffle_batches (bool): whether to shuffle batches during training
visualization_kwargs (dict): additional arguments to pass to visualization callbacks
2. Optimizer Setup Approaches (in order of priority):
1. Pre-set via create_optimizer(): most flexible approach
>>> pinn.create_optimizer('Adam', lr=1e-3)
>>> pinn.train(num_epochs=1000)
2. Pre-set via set_optimizer(): direct optimizer assignment
>>> opt = torch.optim.SGD(pinn.network.parameters(), lr=0.01, momentum=0.9)
>>> pinn.set_optimizer(opt)
>>> pinn.train(num_epochs=1000)
3. Pass the optimizer parameter to train()
>>> pinn.train(num_epochs=1000, optimizer=custom_optimizer)
4. Pass the optimizer_cfg parameter to train()
>>> pinn.train(num_epochs=1000, optimizer_cfg={'name': 'SGD', 'params': {'lr': 0.01}})
5. Default: Adam with the specified lr
>>> pinn.train(num_epochs=1000, lr=1e-3)
2. Learning Rate Scheduler Setup Approaches:
1. Pre-created via create_lr_scheduler(): recommended approach
>>> pinn.create_optimizer('Adam', lr=1e-3)
>>> pinn.create_lr_scheduler('StepLR', step_size=500, gamma=0.5)
>>> pinn.train(num_epochs=2000)
2. Pass a pre-built scheduler to train()
>>> pinn.create_optimizer('Adam', lr=1e-3)
>>> scheduler = torch.optim.lr_scheduler.StepLR(pinn.optimizer, step_size=500, gamma=0.5)
>>> pinn.train(num_epochs=2000, lr_scheduler=scheduler)
- class ai4plasma.piml.pinn.VisualizationCallback(name: str, log_freq: int = 10)[source]
Bases: object
Base class for custom visualization callbacks executed during PINN training.
Visualization callbacks allow you to create and log custom plots or figures during training without modifying the core PINN training loop. This is useful for: - Monitoring solution evolution over time - Comparing predictions with analytical solutions - Visualizing residuals and errors - Creating animations of training progress - Tracking problem-specific metrics
The callback is executed at regular intervals (every log_freq epochs) and can generate matplotlib figures that are automatically logged to TensorBoard.
Subclasses must implement the visualize() method which receives: - The current network state - Current epoch number - TensorBoard writer - Additional kwargs from the training loop (e.g., loss_dict, total_loss)
2. Attributes:
- name : str
Unique identifier for this callback. Used in TensorBoard logging paths and console output. Example: ‘1D_Solution’, ‘2D_Heatmap’
- log_freq : int
Frequency of visualization (every N epochs). Higher values reduce overhead but provide less frequent feedback. Typical range: 10-100.
- abstractmethod visualize(network, epoch: int, writer: SummaryWriter, **kwargs) Dict[str, Figure][source]
Perform custom visualization and return matplotlib figures.
This method is called automatically during training at the specified frequency. It should create one or more matplotlib figures showing relevant information about the current state of training.
2. Parameters:
- network : nn.Module
The neural network being trained. Set to eval() mode before inference if you don’t want dropout/batchnorm to affect visualization.
- epoch : int
Current epoch number (1-indexed). Useful for labeling plots.
- writer : SummaryWriter
TensorBoard writer instance. Can be used for additional custom logging if needed, though figures are logged automatically.
- kwargs : dict
Additional arguments passed from the training loop, which may include: - ‘loss_dict’: Dict mapping equation names to their loss values - ‘total_loss’: Total weighted loss value - Any custom kwargs passed to train() via visualization_kwargs parameter
2. Returns:
- Dict[str, plt.Figure]
Dictionary mapping plot names to matplotlib Figure objects. Each figure will be logged to TensorBoard at the path: ‘Visualization/{callback_name}/{plot_name}’
Example return values: {‘comparison’: fig1, ‘error_heatmap’: fig2, ‘residuals’: fig3}
Return None or empty dict if no visualization should be logged.
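The callback contract described above (periodic execution, a dictionary of named figures, and the ‘Visualization/{callback_name}/{plot_name}’ logging path) can be sketched without any plotting dependency. Everything here is a stand-in: SketchCallback, LossSummary and dispatch are hypothetical names, strings replace matplotlib figures, and the `epoch % log_freq` test is an assumption about how "every log_freq epochs" is decided.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional


class SketchCallback(ABC):
    """Stand-in for the callback base class: a name plus a logging frequency."""

    def __init__(self, name: str, log_freq: int = 10):
        self.name = name
        self.log_freq = log_freq

    @abstractmethod
    def visualize(self, network, epoch: int, writer, **kwargs) -> Optional[Dict[str, Any]]:
        ...


class LossSummary(SketchCallback):
    """Returns one named 'figure' (a string here) built from the training kwargs."""

    def visualize(self, network, epoch, writer, **kwargs):
        return {"summary": f"epoch {epoch}: loss={kwargs['total_loss']:.3f}"}


def dispatch(callbacks: List[SketchCallback], network, epoch: int, writer, **kwargs):
    """Run the callbacks that are due at this epoch and collect what they return."""
    logged: Dict[str, Any] = {}
    for cb in callbacks:
        if epoch % cb.log_freq != 0:
            continue  # not due at this epoch
        figures = cb.visualize(network, epoch, writer, **kwargs) or {}
        for plot_name, fig in figures.items():
            logged[f"Visualization/{cb.name}/{plot_name}"] = fig
    return logged


callbacks = [LossSummary("loss", log_freq=10)]
due = dispatch(callbacks, network=None, epoch=10, writer=None, total_loss=0.125)
skipped = dispatch(callbacks, network=None, epoch=7, writer=None, total_loss=0.125)
```

A real callback would return matplotlib Figure objects instead of strings and subclass VisualizationCallback directly.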
2.3. CS-PINN (Coefficient-Subnet PINN)
CS-PINN models for plasma simulations in AI4Plasma.
This module provides specialized Physics-Informed Neural Network (PINN) implementations and visualization tools for simulating plasma discharge using cubic spline interpolation and automatic differentiation on PyTorch.
2.3.1. CS-PINN Classes
The CS-PINN (Coefficient-Subnet Physics-Informed Neural Network) framework provides the following classes:
StaArc1DModel: PINN model for steady-state 1D arc plasma.
TraArc1DTempModel: PINN model for transient 1D arc plasma without radial velocity.
TraArc1DModel: PINN model for transient 1D arc plasma with radial velocity.
StaArc1DVisCallback: Visualization callback for steady-state arc plasma simulation.
TraArc1DTempVisCallback: Visualization callback for transient arc plasma simulation.
2.3.2. CS-PINN References
- [1] L. Zhong, B. Wu, and Y. Wang, “Low-temperature plasma simulation based on physics-informed neural networks: Frameworks and preliminary applications,” Physics of Fluids, vol. 34, no. 8, p. 087116, 2022.
- [2] L. Zhong, Q. Gu, and B. Wu, “Deep learning for thermal plasma simulation: Solving 1-D arc model as an example,” Computer Physics Communications, vol. 257, p. 107496, 2020.
- class ai4plasma.piml.cs_pinn.StaArc1DModel(R, I, Tb=300.0, T_red=10000.0, backbone_net=FNN( (act_fun): Tanh() (net): Sequential( (linear1): Linear(in_features=1, out_features=100, bias=True) (activation1): Tanh() (linear2): Linear(in_features=100, out_features=100, bias=True) (activation2): Tanh() (linear3): Linear(in_features=100, out_features=100, bias=True) (activation3): Tanh() (linear4): Linear(in_features=100, out_features=100, bias=True) (activation4): Tanh() (linear5): Linear(in_features=100, out_features=1, bias=True) ) ), train_data_size=500, test_data_size=501, sample_mode='uniform', GL_degree=100, prop: ArcPropSpline = None)[source]
Bases: PINN
PINN model for solving 1D steady-state arc discharge energy equation.
Implements a Physics-Informed Neural Network specifically designed for arc plasma simulations. Solves the Elenbaas-Heller equation considering Joule heating, thermal conduction, and radiation losses. The model enforces automatic boundary conditions through the StaArc1DNet architecture.
- R
Arc radius [m]
- Type:
float
- I
Arc current [A]
- Type:
float
- T_red
Temperature reduction factor for normalization [K]
- Type:
float
- Tb
Boundary temperature at r = R [K]
- Type:
float
- Xq
Gauss-Legendre quadrature abscissae for arc conductance integral
- Type:
torch.Tensor
- Wq
Gauss-Legendre quadrature weights
- Type:
torch.Tensor
- prop
Material properties interpolation object
- Type:
ArcPropSpline
Parameters (Constructor):
- R
Arc radius [m]. Normalized to 1.0 in the network.
- Type:
float
- I
Arc current [A]. Used to compute electric field and Joule heating.
- Type:
float
- Tb
Boundary temperature at r = R [K] (default: 300.0). Set to ambient temperature at the arc boundary.
- Type:
float, optional
- T_red
Temperature reduction factor for normalization [K] (default: 1e4). Used for non-dimensionalization: T_normalized = T_physical / T_red
- Type:
float, optional
- backbone_net
Backbone neural network that outputs the unscaled network function. Default: FNN with architecture [1, 100, 100, 100, 100, 1]. Will be wrapped in StaArc1DNet for boundary condition enforcement.
- Type:
nn.Module, optional
- train_data_size
Number of training collocation points in the domain (default: 500). Collocation points are sampled according to sample_mode.
- Type:
int, optional
- test_data_size
Number of test/evaluation points for visualization (default: 501).
- Type:
int, optional
- sample_mode
Collocation point sampling strategy (default: ‘uniform’). Options: ‘uniform’ (uniform grid), ‘lhs’ (Latin hypercube), ‘random’
- Type:
str, optional
- GL_degree
Degree of Gauss-Legendre quadrature for arc conductance integral (default: 100). Higher degree provides better accuracy for integral computation.
- Type:
int, optional
- prop
Arc material properties object (ArcPropSpline instance) for temperature-dependent properties like κ, σ, ε_nec. If None, properties must be provided externally.
- Type:
ArcPropSpline, optional
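The role of GL_degree can be illustrated with a small Gauss-Legendre sketch. It assumes the usual cylindrical form of the arc conductance per unit length, G = 2π ∫₀ᴿ σ(r) r dr, which this page does not spell out; arc_conductance is a hypothetical helper and a constant conductivity replaces the temperature-dependent σ(T(r)).

```python
import numpy as np


def arc_conductance(sigma, R: float, degree: int = 100) -> float:
    """Approximate G = 2*pi * integral_0^R sigma(r) * r dr by Gauss-Legendre."""
    nodes, weights = np.polynomial.legendre.leggauss(degree)  # defined on [-1, 1]
    r = 0.5 * R * (nodes + 1.0)   # map abscissae to [0, R]
    w = 0.5 * R * weights         # rescale weights by the Jacobian R / 2
    return float(2.0 * np.pi * np.sum(w * sigma(r) * r))


R = 0.01  # arc radius [m]
# Constant conductivity of 1e4 S/m, for which G = pi * R**2 * sigma in closed form.
G_const = arc_conductance(lambda r: np.full_like(r, 1.0e4), R)
```

For this constant-conductivity case the quadrature reproduces the closed-form value π · R² · σ = π to rounding error. For the strongly temperature-dependent conductivity of a real arc, a higher degree improves the accuracy of the integral.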
- class ai4plasma.piml.cs_pinn.StaArc1DNet(network, R=1.0, Tb=0.03)[source]
Bases: Module
Neural network wrapper for solving 1D stationary arc equation with automatic boundary condition enforcement.
- network
Backbone neural network (e.g., FNN) that maps r → N(r)
- Type:
nn.Module
- R
Normalized arc radius (default: 1.0)
- Type:
float
- Tb
Normalized boundary temperature at r = R (default: 0.03)
- Type:
float
Parameters (Constructor):
- network
Backbone neural network (e.g., FNN) that maps r → N(r)
- Type:
nn.Module
- R
Normalized arc radius (default: 1.0)
- Type:
float, optional
- Tb
Normalized boundary temperature at r = R (default: 0.03)
- Type:
float, optional
- forward(x)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class ai4plasma.piml.cs_pinn.StaArc1DVisCallback(model: StaArc1DModel, log_freq: int = 50, save_history: bool = True, history_freq: int = None, T_csv_file: str = None, gif_enabled: bool = False, gif_dir: str = None, gif_freq: int = None, gif_duration_ms: int = 300, gif_cleanup_tmp: bool = True)[source]
Bases: VisualizationCallback
Custom visualization callback for 1D stationary arc PINN training.
- This callback provides comprehensive monitoring of the arc model training:
Real-time TensorBoard logging during training
Training history tracking for post-training animation
Temperature distribution evolution and error tracking
Loss convergence monitoring
- save_final_results(network: Module, save_dir: str = None, epoch: int = None, **kwargs)[source]
Save final result figures similar to TensorBoard panels and a dedicated loss plot.
Outputs:
- final_panels.png: the same 2x2 panel figure used in TensorBoard
- loss_curve.png: standalone loss curve using collected history
Parameters:
- network: trained network used to generate the final panel figure
- save_dir: output directory; defaults to gif_dir
- epoch: epoch number to show in the title; if None, uses the last recorded epoch or 0
- kwargs: optional training info for annotations (e.g., total_loss)
- save_gif(gif_path: str = None, duration_ms: int = None, loop: int = 0)[source]
Assemble saved frames into a GIF showing T vs T_ref and loss evolution.
Parameters:
- gif_path: optional output path; defaults to <gif_dir>/<gif_filename>
- duration_ms: per-frame duration in milliseconds; defaults to gif_duration_ms
- loop: number of loops (0 = infinite)
- visualize(network, epoch: int, writer: SummaryWriter, **kwargs)[source]
Generate visualization plots for the current training epoch.
This method is called automatically by PINN.train() at specified intervals. It performs two main tasks:
1. Creates comparison plots and logs them to TensorBoard
2. Saves prediction snapshots for training animation (if enabled)
2. Parameters:
- network : nn.Module
The neural network being trained
- epoch : int
Current training epoch number
- writer : SummaryWriter
TensorBoard writer for logging figures
- kwargs : dict
Additional training information: ‘total_loss’ (total loss value, torch.Tensor or float) and ‘loss_dict’ (dictionary of individual loss terms).
2. Returns:
- dict
Dictionary mapping plot names to matplotlib figures
- class ai4plasma.piml.cs_pinn.TraArc1DModel(R, I, Tb=300.0, Tinit_func=None, T_red=10000.0, t_red=0.001, backbone_net=FNN( (act_fun): Tanh() (net): Sequential( (linear1): Linear(in_features=2, out_features=300, bias=True) (activation1): Tanh() (linear2): Linear(in_features=300, out_features=300, bias=True) (activation2): Tanh() (linear3): Linear(in_features=300, out_features=300, bias=True) (activation3): Tanh() (linear4): Linear(in_features=300, out_features=300, bias=True) (activation4): Tanh() (linear5): Linear(in_features=300, out_features=300, bias=True) (activation5): Tanh() (linear6): Linear(in_features=300, out_features=300, bias=True) (activation6): Tanh() (linear7): Linear(in_features=300, out_features=2, bias=True) ) ), train_data_x_size=200, train_data_t_size=100, sample_mode='uniform', prop: ArcPropSpline = None)[source]
Bases: PINN
PINN model for solving coupled 1D transient arc plasma equations with temperature and radial velocity.
Implements a Physics-Informed Neural Network specifically designed for simulating transient (time-dependent) arc discharge phenomena with full coupling between energy and momentum transport. Solves coupled nonlinear transient equations considering temperature evolution, radial velocity, thermal conduction, convection, and radiation effects.
2. Physical Constraints
T(R, t) = Tb (Dirichlet boundary condition at arc radius)
∂T/∂r(0, t) = 0 (symmetry condition for temperature at centerline)
V(0, t) = 0 (symmetry condition for velocity at centerline)
T(r, 0) = Tinit_func(r) (initial temperature distribution)
Mass conservation through continuity equation
2. Parameters (Constructor)
- R : float
Arc radius [m]
- I : float
Arc current [A]
- Tb : float, optional
Boundary temperature at r = R [K] (default: 300.0)
- Tinit_func : callable
Function that returns initial temperature profile T(r)
- T_red : float, optional
Temperature reduction factor for normalization [K] (default: 1e4)
- t_red : float, optional
Time reduction factor for normalization [s] (default: 1e-3)
- backbone_net : nn.Module, optional
Backbone neural network with 2 outputs (T, V) (default: FNN with 7 layers)
- train_data_x_size : int, optional
Number of spatial training collocation points (default: 200)
- train_data_t_size : int, optional
Number of temporal training collocation points (default: 100)
- sample_mode : str, optional
Sampling strategy: ‘uniform’, ‘lhs’, or ‘random’ (default: ‘uniform’)
- prop : ArcPropSpline, optional
Arc material properties object (default: None)
- class ai4plasma.piml.cs_pinn.TraArc1DNet(network, R=1.0, Tb=0.03)[source]
Bases: Module
Neural network wrapper for coupled 1D transient arc equations with automatic boundary condition enforcement.
This network simultaneously predicts both temperature T(r, t) and radial velocity V(r, t) while automatically satisfying multiple boundary conditions: 1. Temperature boundary condition: T(R, t) = Tb (at arc boundary) 2. Velocity symmetry condition: V(0, t) = 0 (at centerline)
- The network applies transformations to the backbone network outputs:
T(r, t) = (r - R) · N₁(r, t) + Tb V(r, t) = r · N₂(r, t)
where N₁(r, t) and N₂(r, t) are the two outputs of the backbone network. This design is for coupled energy and momentum transport in arc discharges, where the automatic boundary condition enforcement reduces training complexity.
Parameters (Constructor):
network (nn.Module) – Backbone neural network (e.g., FNN) that maps (r, t) → [N₁(r, t), N₂(r, t)]. Input shape: [batch_size, 2], where [:, 0] = r and [:, 1] = t. Output shape: [batch_size, 2], representing [temperature, velocity] predictions.
R (float, optional) – Normalized arc radius (default: 1.0).
Tb (float, optional) – Normalized boundary temperature at r = R (default: 0.03).
- forward(x)[source]
Forward pass with automatic boundary condition enforcement for both T and V.
- Parameters:
x (torch.Tensor) – Input tensor of shape [batch_size, 2], where x[:, 0] is r (normalized radius, range [0, 1]) and x[:, 1] is t (normalized time, range [0, 1]).
- Returns:
T (torch.Tensor) – Temperature T(r, t) in normalized units, shape [batch_size, 1]. Satisfies T(R, t) = Tb for all t by construction.
V (torch.Tensor) – Radial velocity V(r, t) in normalized units, shape [batch_size, 1]. Satisfies V(0, t) = 0 for all t by construction (symmetry).
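The two output transformations can be sketched as a small PyTorch module. This is a minimal illustration of the hard-constraint idea, not the library's implementation: the class name HardBCArcNet and the tiny nn.Sequential backbone are hypothetical stand-ins for TraArc1DNet and FNN.

```python
import torch
from torch import nn


class HardBCArcNet(nn.Module):
    """Hypothetical stand-in that applies the documented output transforms."""

    def __init__(self, network, R=1.0, Tb=0.03):
        super().__init__()
        self.network = network
        self.R = R
        self.Tb = Tb

    def forward(self, x):
        r = x[:, 0:1]                      # normalized radius column
        out = self.network(x)              # backbone outputs [N1, N2]
        N1, N2 = out[:, 0:1], out[:, 1:2]
        T = (r - self.R) * N1 + self.Tb    # T(R, t) = Tb by construction
        V = r * N2                         # V(0, t) = 0 by construction
        return T, V


torch.manual_seed(0)
backbone = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2))
net = HardBCArcNet(backbone)

t = torch.linspace(0.0, 1.0, 5).unsqueeze(1)
T_edge, _ = net(torch.cat([torch.ones_like(t), t], dim=1))    # r = R = 1
_, V_axis = net(torch.cat([torch.zeros_like(t), t], dim=1))   # r = 0
print(T_edge.flatten(), V_axis.flatten())
```

Because the factors (r - R) and r vanish exactly at the respective boundaries, the conditions hold for any backbone weights, including an untrained network.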
- class ai4plasma.piml.cs_pinn.TraArc1DTempModel(R, I, Tb=300.0, Tinit_func=None, T_red=10000.0, t_red=0.001, backbone_net=FNN( (act_fun): Tanh() (net): Sequential( (linear1): Linear(in_features=2, out_features=200, bias=True) (activation1): Tanh() (linear2): Linear(in_features=200, out_features=200, bias=True) (activation2): Tanh() (linear3): Linear(in_features=200, out_features=200, bias=True) (activation3): Tanh() (linear4): Linear(in_features=200, out_features=200, bias=True) (activation4): Tanh() (linear5): Linear(in_features=200, out_features=200, bias=True) (activation5): Tanh() (linear6): Linear(in_features=200, out_features=200, bias=True) (activation6): Tanh() (linear7): Linear(in_features=200, out_features=1, bias=True) ) ), train_data_x_size=200, train_data_t_size=100, sample_mode='uniform', prop: ArcPropSpline = None)[source]
Bases: PINN
PINN model for solving the 1D transient arc discharge energy equation without radial velocity.
This class implements a Physics-Informed Neural Network specifically designed for simulating transient (time-dependent) arc discharge phenomena. The model solves the nonlinear transient energy balance equation considering time-dependent temperature evolution, thermal conduction, and radiation effects.
Physical Constraints:
T(R, t) = Tb (Dirichlet boundary condition at arc radius)
∂T/∂r(0, t) = 0 (symmetry condition at centerline)
T(r, 0) = Tinit_func(r) (initial temperature distribution)
- class ai4plasma.piml.cs_pinn.TraArc1DTempNet(network, R=1.0, Tb=0.03)[source]
Bases: Module
Neural network wrapper for solving the 1D transient arc temperature equation with automatic boundary condition enforcement.
This network automatically satisfies the Dirichlet boundary condition at r = R (arc radius) for all time steps by construction, eliminating the need for explicit boundary loss terms. The temperature output is modified to ensure T(R, t) = Tb for all t.
The network applies a transformation to the backbone network output:
T(r, t) = (r - R) · N(r, t) + Tb
where N(r, t) is the backbone network output (takes both r and t as inputs), R is the normalized arc radius (typically 1.0), and Tb is the boundary temperature (normalized). At r = R, the boundary condition T(R, t) = Tb is satisfied for all t. Unlike the stationary case (StaArc1DNet), this network handles time-dependent temperature evolution.
Parameters (Constructor):
network (nn.Module) – Backbone neural network (e.g., FNN) that maps (r, t) → N(r, t). Input shape: [batch_size, 2], where [:, 0] = r and [:, 1] = t. Output shape: [batch_size, 1], representing the network prediction.
R (float, optional) – Normalized arc radius (default: 1.0).
Tb (float, optional) – Normalized boundary temperature at r = R (default: 0.03).
- forward(x)[source]
Forward pass with automatic boundary condition enforcement.
- Parameters:
x (torch.Tensor) – Input tensor of shape [batch_size, 2], where x[:, 0] is r (normalized radius, range [0, 1]) and x[:, 1] is t (normalized time, range [0, 1]).
- Returns:
Temperature T(r, t) in normalized units, shape [batch_size, 1]. Satisfies T(R, t) = Tb for all t by construction.
- Return type:
torch.Tensor
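The transformation enforces only the Dirichlet condition at r = R. The symmetry condition ∂T/∂r(0, t) = 0 listed for TraArc1DTempModel is presumably handled as a loss term; how the library does this is not documented here. The sketch below shows one common way to build such a penalty with automatic differentiation. The backbone, the helper temperature, and the mean-squared form of the penalty are all assumptions made for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
R, Tb = 1.0, 0.03
backbone = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 1))


def temperature(x):
    """Hard-constrained temperature: T(r, t) = (r - R) * N(r, t) + Tb."""
    return (x[:, 0:1] - R) * backbone(x) + Tb


# Collocation points on the centerline r = 0 at random normalized times.
t = torch.rand(32, 1)
x_axis = torch.cat([torch.zeros_like(t), t], dim=1).requires_grad_(True)

T_axis = temperature(x_axis)
# Gradient of T with respect to the inputs (r, t); column 0 is dT/dr.
grad = torch.autograd.grad(T_axis.sum(), x_axis, create_graph=True)[0]
dT_dr = grad[:, 0:1]

# Soft penalty that vanishes when the symmetry condition holds.
symmetry_loss = torch.mean(dT_dr ** 2)
print(symmetry_loss.item())
```

Passing create_graph=True keeps the derivative differentiable, so the penalty can be backpropagated to the backbone weights during training.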
- class ai4plasma.piml.cs_pinn.TraArc1DTempVisCallback(model: TraArc1DTempModel, log_freq: int = 50, save_history: bool = True, history_freq: int = None, x_eval: ndarray = array([[0.], [0.005], [0.01], ..., [0.99], [0.995], [1.]], dtype=float32), t_eval: list = [0.1, 0.5, 0.9], T_csv_file: list[str] = ['', '', ''], gif_enabled: bool = False, gif_dir: str = None, gif_freq: int = None, gif_duration_ms: int = 300, gif_cleanup_tmp: bool = True)[source]
Bases: VisualizationCallback
Custom visualization callback for 1D transient arc PINN training (temperature only).
This callback provides comprehensive monitoring of the transient arc model training:
Real-time TensorBoard logging during training
Training history tracking for post-training animation
Temperature distribution evolution at multiple time steps
Material property evolution over time
Loss convergence monitoring
- save_final_results(network: Module, save_dir: str = None, epoch: int = None, **kwargs)[source]
Save final result figures showing temperature evolution and loss curve.
Outputs:
final_panels_transient.png: multi-panel figure with temperature at different times and loss curve
loss_curve_transient.png: standalone loss curve plot
center_temp_evolution.png: center temperature evolution over time at different epochs
- Parameters:
network (nn.Module) – Trained network for final prediction.
save_dir (str, optional) – Output directory; defaults to gif_dir.
epoch (int, optional) – Epoch number for the title; if None, uses the last recorded epoch or 0.
kwargs (dict) – Optional training info.
- save_gif(gif_path: str = None, duration_ms: int = None, loop: int = 0)[source]
Assemble saved frames into a GIF showing transient temperature evolution and loss.
- Parameters:
gif_path (str, optional) – Output path for the GIF; defaults to <gif_dir>/training_animation_transient.gif.
duration_ms (int, optional) – Per-frame duration in milliseconds; defaults to gif_duration_ms.
loop (int, default=0) – Number of loops (0 = infinite).
- visualize(network, epoch: int, writer: SummaryWriter, **kwargs)[source]
Generate visualization plots for the current training epoch.
- Parameters:
network (nn.Module) – The neural network being trained.
epoch (int) – Current training epoch number.
writer (SummaryWriter) – TensorBoard writer for logging figures.
kwargs (dict) – Additional training information (e.g., 'total_loss').
- Returns:
Dictionary mapping plot names to matplotlib figures.
- Return type:
dict
- class ai4plasma.piml.cs_pinn.TraArc1DVelNet(network)[source]
Bases: Module
Neural network wrapper for solving the 1D transient arc with radial velocity using automatic boundary condition enforcement.
This network automatically satisfies the boundary condition at r = 0 (symmetry) by construction. The velocity output is modified to ensure V(0, t) = 0 for all time steps, which is a physical requirement due to radial symmetry.
The network applies a transformation to the backbone network output:
V(r, t) = r · N(r, t)
where N(r, t) is the backbone network output (takes both r and t as inputs), and r is the normalized radial coordinate. At r = 0, the symmetry condition V(0, t) = 0 is automatically satisfied. In arc discharge simulations, the radial velocity must be zero at the centerline due to cylindrical symmetry.
Parameters (Constructor):
network (nn.Module) – Backbone neural network (e.g., FNN) that maps (r, t) → N(r, t). Input shape: [batch_size, 2], where [:, 0] = r and [:, 1] = t. Output shape: [batch_size, 1], representing the velocity prediction.
- forward(x)[source]
Forward pass with automatic symmetry condition enforcement.
- Parameters:
x (torch.Tensor) – Input tensor of shape [batch_size, 2], where x[:, 0] is r (normalized radius, range [0, 1]) and x[:, 1] is t (normalized time, range [0, 1]).
- Returns:
Radial velocity V(r, t) in normalized units, shape [batch_size, 1]. Satisfies V(0, t) = 0 for all t by construction (symmetry).
- Return type:
torch.Tensor
- class ai4plasma.piml.cs_pinn.TraArc1DVisCallback(model: TraArc1DModel, log_freq: int = 50, save_history: bool = True, history_freq: int = None, x_eval: ndarray = array([[0.], [0.005], [0.01], ..., [0.99], [0.995], [1.]], dtype=float32), t_eval: list = [0.1, 0.5, 0.9], TV_csv_file: list[str] = ['', '', ''], gif_enabled: bool = False, gif_dir: str = None, gif_freq: int = None, gif_duration_ms: int = 300, gif_cleanup_tmp: bool = True)[source]
Bases: VisualizationCallback
Custom visualization callback for 1D transient arc PINN training with temperature and radial velocity.
This callback provides comprehensive monitoring of the coupled transient arc model training:
Real-time TensorBoard logging during training
Side-by-side comparison of temperature and velocity at multiple time steps
Training history tracking for post-training animation
Reference data comparison with error metrics (when CSV files provided)
Material property evolution monitoring over time
Loss convergence tracking with logarithmic scale visualization
- save_final_results(network: Module, save_dir: str = None, epoch: int = None, **kwargs)[source]
Save final result figures showing temperature evolution and loss curve.
Outputs:
final_panels_transient.png: multi-panel figure with temperature at different times and loss curve
loss_curve_transient.png: standalone loss curve plot
center_temp_evolution.png: center temperature evolution over time at different epochs
- Parameters:
network (nn.Module) – Trained network for final prediction.
save_dir (str, optional) – Output directory; defaults to gif_dir.
epoch (int, optional) – Epoch number for the title; if None, uses the last recorded epoch or 0.
kwargs (dict) – Optional training info.
- save_gif(gif_path: str = None, duration_ms: int = None, loop: int = 0)[source]
Assemble saved frames into a GIF showing transient temperature evolution and loss.
- Parameters:
gif_path (str, optional) – Output path for the GIF; defaults to <gif_dir>/training_animation_transient.gif.
duration_ms (int, optional) – Per-frame duration in milliseconds; defaults to gif_duration_ms.
loop (int, default=0) – Number of loops (0 = infinite).
- visualize(network, epoch: int, writer: SummaryWriter, **kwargs)[source]
Generate visualization plots for the current training epoch.
- Parameters:
network (nn.Module) – The neural network being trained.
epoch (int) – Current training epoch number.
writer (SummaryWriter) – TensorBoard writer for logging figures.
kwargs (dict) – Additional training information (e.g., 'total_loss').
- Returns:
Dictionary mapping plot names to matplotlib figures.
- Return type:
dict
- ai4plasma.piml.cs_pinn.calc_GL_coefs(degree)[source]
Calculate Gauss-Legendre quadrature coefficients for arc conductance integral.
- Parameters:
degree (int) – Degree of Gauss-Legendre quadrature (number of quadrature points)
- Returns:
Xq (torch.Tensor) – Abscissae (quadrature points) mapped to [0, 1], shape (degree, 1).
Wq (torch.Tensor) – Quadrature weights rescaled for the interval [0, 1], shape (degree, 1).
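A Gauss-Legendre rule on [0, 1] can be obtained from the standard rule on [-1, 1] by an affine change of variable. The sketch below uses NumPy's leggauss for this and evaluates a conductance-style integral ∫₀¹ σ(r) · 2πr dr. The function name gl_coefs_unit_interval, the test profile σ(r) = 1 - r², and the use of NumPy arrays instead of torch tensors are assumptions for illustration; the library's own implementation may differ.

```python
import numpy as np


def gl_coefs_unit_interval(degree):
    """Gauss-Legendre nodes and weights mapped from [-1, 1] to [0, 1]."""
    x, w = np.polynomial.legendre.leggauss(degree)
    Xq = 0.5 * (x + 1.0)   # affine map of the nodes
    Wq = 0.5 * w           # Jacobian of the map rescales the weights
    return Xq.reshape(-1, 1), Wq.reshape(-1, 1)


Xq, Wq = gl_coefs_unit_interval(8)


def sigma(r):
    """Hypothetical normalized conductivity profile."""
    return 1.0 - r ** 2


# Integral of sigma(r) * 2*pi*r over [0, 1]; the exact value is pi / 2.
G = float(np.sum(Wq * sigma(Xq) * 2.0 * np.pi * Xq))
print(G, np.pi / 2)
```

An n-point rule integrates polynomials up to degree 2n - 1 exactly, so the cubic integrand here is reproduced to rounding error.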
- ai4plasma.piml.cs_pinn.get_TVfunc_from_file(csv_file)[source]
Load reference temperature and velocity profiles from CSV file for comparison.
- Parameters:
csv_file (str) – Path to CSV file containing reference temperature and velocity data. The CSV should have columns: ‘r(m)’ (radius in meters), ‘T(K)’ (temperature in Kelvin), and ‘V(m/s)’ (velocity in m/s).
- Returns:
T_spline (function) – Cubic spline interpolation function for temperature at radius r
V_spline (function) – Cubic spline interpolation function for velocity at radius r
- Raises:
FileNotFoundError – If csv_file does not exist.
- ai4plasma.piml.cs_pinn.get_Tfunc_from_file(csv_file)[source]
Load reference temperature profile from CSV file for comparison.
- Parameters:
csv_file (str) – Path to CSV file containing reference temperature data. The CSV should have columns: ‘r(m)’ (radius in meters) and ‘T(K)’ (temperature in Kelvin).
- Returns:
A cubic spline interpolation function T_spline(r) that returns interpolated temperature at radius r.
- Return type:
function
- Raises:
FileNotFoundError – If csv_file does not exist.
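The documented behavior (read 'r(m)' and 'T(K)' columns, return a cubic spline, raise FileNotFoundError for a missing file) can be approximated as follows. This is a sketch under stated assumptions, not the library's code: the function name load_T_spline is hypothetical, the CSV is parsed with the standard csv module, the spline comes from scipy.interpolate.CubicSpline, and the reference data is synthetic.

```python
import csv
import os
import tempfile

import numpy as np
from scipy.interpolate import CubicSpline


def load_T_spline(csv_file):
    """Return a cubic spline T(r) built from columns 'r(m)' and 'T(K)'."""
    if not os.path.isfile(csv_file):
        raise FileNotFoundError(csv_file)
    with open(csv_file, newline="") as f:
        rows = list(csv.DictReader(f))
    r = np.array([float(row["r(m)"]) for row in rows])
    T = np.array([float(row["T(K)"]) for row in rows])
    order = np.argsort(r)              # CubicSpline needs increasing r
    return CubicSpline(r[order], T[order])


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "reference.csv")
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["r(m)", "T(K)"])
        for k in range(11):
            r_k = k * 1.0e-3                               # 0 to 10 mm
            T_k = 300.0 + 1.0e4 * (1.0 - (k / 10.0) ** 2)  # synthetic parabola
            writer.writerow([f"{r_k:.6e}", f"{T_k:.6e}"])
    T_spline = load_T_spline(path)

T_mid = float(T_spline(0.0045))
print(T_mid)
```

The returned object is callable on scalars or arrays of radii, which makes it convenient for comparing a prediction against reference data on an arbitrary evaluation grid.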
2.4. RK-PINN (Runge-Kutta PINN)
RK-PINN models for plasma simulations in AI4Plasma.
This module implements a Runge-Kutta Physics-Informed Neural Network (RK-PINN) for solving 1D corona discharge problems. It extends standard PINN methodology with implicit Runge-Kutta time integration schemes for improved accuracy and stability in temporal evolution of corona discharge phenomena.
2.4.1. RK-PINN Classes
Corona1DRKNet: Neural network with built-in boundary condition enforcement.
Corona1DRKModel: PINN model for corona discharge using RK time stepping.
Corona1DRKVisCallback: Visualization callback with multi-panel plots.
2.4.2. RK-PINN References
- [1] L. Zhong, B. Wu, and Y. Wang, “Low-temperature plasma simulation based on physics-informed neural networks: Frameworks and preliminary applications,” Physics of Fluids, vol. 34, no. 8, p. 087116, 2022.
- class ai4plasma.piml.rk_pinn.Corona1DRKModel(R, T, P, V0, dt, Ne_init_func=None, Np_func=None, N_red=1000000000000000.0, t_red=5e-09, V_red=10000.0, gamma=0.066, train_data_size=500, sample_mode='uniform', q=50, backbone_net=FNN( (act_fun): Tanh() (net): Sequential( (linear1): Linear(in_features=1, out_features=300, bias=True) (activation1): Tanh() (linear2): Linear(in_features=300, out_features=300, bias=True) (activation2): Tanh() (linear3): Linear(in_features=300, out_features=300, bias=True) (activation3): Tanh() (linear4): Linear(in_features=300, out_features=300, bias=True) (activation4): Tanh() (linear5): Linear(in_features=300, out_features=102, bias=True) ) ), prop: CoronaPropSpline = None)[source]
Bases: PINN
PINN model for solving 1D corona discharge equations using implicit Runge-Kutta time integration.
This class implements a Physics-Informed Neural Network with implicit Runge-Kutta (RK) temporal discretization for corona discharge simulations. Unlike standard PINNs, which typically obtain temporal derivatives by automatic differentiation, RK-PINN discretizes time with implicit RK formulas, enabling better control over temporal accuracy and stability.
The model solves coupled nonlinear equations for electric potential (Φ) and electron density (Ne) in a cylindrical corona discharge geometry, with physics-based loss functions that enforce both the governing PDEs and boundary conditions.
- Parameters:
R (float) – Domain radius [m].
T (float) – Gas temperature [K].
P (float) – Gas pressure [Pa].
V0 (float) – Applied voltage at electrode [V].
dt (float) – Normalized time step (Δt_norm).
Ne_init_func (callable) – Initial condition function for electron density Ne(r). Required, although the signature defaults to None.
Np_func (callable) – Positive ion density function Np(r). Required, although the signature defaults to None.
N_red (float, default=1e15) – Electron density reduction factor [m⁻³].
t_red (float, default=5e-9) – Time reduction factor [s].
V_red (float, default=10e3) – Voltage reduction factor [V].
gamma (float, default=0.066) – Secondary electron emission coefficient [dimensionless].
train_data_size (int, default=500) – Number of training collocation points.
sample_mode ({'uniform', 'lhs', 'random'}, default='uniform') – Sampling strategy for collocation points.
q (int, default=50) – Order of implicit Runge-Kutta method (stages = q+1).
backbone_net (nn.Module, default=FNN([1, 300, 300, 300, 300, 102])) – Neural network architecture.
prop (CoronaPropSpline, optional) – Material property object (transport coefficients).
- class ai4plasma.piml.rk_pinn.Corona1DRKNet(network, q, R=1.0, V0=None)[source]
Bases: Module
Neural network wrapper for solving 1D corona discharge with automatic boundary condition enforcement for coupled potential and electron density fields.
This network automatically satisfies boundary conditions at the domain edges by construction, reducing training complexity and improving physical consistency.
- Parameters:
network (nn.Module) – Backbone neural network (e.g., FNN) that maps r → [N₁, N₂]. Input shape: [batch_size, 1] representing radial coordinate r. Output shape: [batch_size, 2*(q+1)] representing [Φ stages, Ne stages].
q (int) – Order of the implicit Runge-Kutta method. Determines number of RK stages: q+1. Total network outputs: 2*(q+1) [q+1 for Φ, q+1 for Ne].
R (float, optional) – Normalized domain radius (default: 1.0). For physical radius Rphys meters, normalize as r_norm = r_phys / Rphys.
V0 (float, optional) – Normalized applied voltage at electrode r=0 (default: None). Physical voltage V₀phys is normalized as V_norm = V₀phys / V_red.
- forward(x)[source]
Forward pass through the network with boundary condition enforcement.
- Parameters:
x (torch.Tensor) – Radial coordinates, shape (batch_size, 1).
- Returns:
Phi (torch.Tensor) – Electric potential at all RK stages, shape (batch_size, q+1).
Ne (torch.Tensor) – Electron density at all RK stages, shape (batch_size, q+1).
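The shape bookkeeping can be checked with a few lines: for the default q = 50, the backbone produces 2·(q+1) = 102 outputs per radial point, which matches the out_features=102 of the default backbone, and the wrapper returns two blocks of q+1 columns. The sketch below only illustrates the split. The column ordering (Φ block first, Ne block second) is an assumption, and the boundary-condition transforms applied by the wrapper are not reproduced.

```python
import numpy as np

q = 50                 # RK order (documented default)
batch_size = 4
n_out = 2 * (q + 1)    # total backbone outputs per radial point

# Random stand-in for the raw backbone output of shape [batch_size, 2*(q+1)].
rng = np.random.default_rng(0)
raw = rng.normal(size=(batch_size, n_out))

# Assumed layout: first q+1 columns for the potential, last q+1 for Ne.
Phi_raw = raw[:, : q + 1]
Ne_raw = raw[:, q + 1 :]
print(n_out, Phi_raw.shape, Ne_raw.shape)
```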
- class ai4plasma.piml.rk_pinn.Corona1DRKVisCallback(model: Corona1DRKModel, log_freq: int = 50, save_history: bool = True, history_freq: int = None, x_eval: ndarray = array([[0.], [0.005], [0.01], ..., [0.99], [0.995], [1.]], dtype=float32), corona_csv_file: str = None, gif_enabled: bool = False, gif_dir: str = None, gif_freq: int = None, gif_duration_ms: int = 300, gif_cleanup_tmp: bool = True)[source]
Bases: VisualizationCallback
Custom visualization callback for 1D corona discharge RK-PINN model training.
This callback provides comprehensive real-time monitoring and post-training visualization capabilities for corona discharge simulations using the RK-PINN framework. It generates publication-quality figures showing electric potential (Φ), electron density (Ne) evolution, and training convergence metrics.
- Parameters:
model (Corona1DRKModel) – The corona discharge model instance. Provides access to geometry, physical parameters, and material properties.
log_freq (int, default=50) – Frequency (in epochs) for logging visualizations to TensorBoard. Example: log_freq=50 means visualize every 50 epochs.
save_history (bool, default=True) – Whether to save prediction snapshots for creating training animations. Set to False to save memory if animations are not needed.
history_freq (int, optional) – Frequency (in epochs) for saving history snapshots. If None, defaults to log_freq. Use larger values (e.g., 200) for very long training runs.
x_eval (np.ndarray, shape (n_r, 1), optional) – Radial evaluation grid for visualization (normalized radius 0→1). Default: 201 points linearly spaced from 0 to 1. Use finer grids for higher-resolution visualizations. Use coarser grids for faster evaluation.
corona_csv_file (str, optional) – Path to CSV file containing reference Φ and Ne profiles. CSV columns required: ‘r(cm)’, ‘U(V)’, ‘Ne(m^-3)’. If None, reference comparison is skipped. Enables error analysis and validation against experimental/reference data.
gif_enabled (bool, default=False) – Whether to save per-epoch frames and generate a training animation GIF. When True, PNG frames are automatically saved at gif_freq intervals and assembled into an animated GIF at training end. Useful for post-training analysis and presentations.
gif_dir (str, optional) – Output directory for GIF animation and final plots. If None, defaults to current working directory.
gif_freq (int, optional) – Frequency (in epochs) to save frames for GIF creation. If None, uses history_freq. Larger values (e.g., 500) produce shorter GIFs with fewer frames.
gif_duration_ms (int, default=300) – Duration per frame in milliseconds for the GIF animation.
gif_cleanup_tmp (bool, default=True) – Whether to automatically delete temporary PNG frames after GIF creation. Set to False to retain frames for manual re-assembly or inspection.
- save_final_results(network: Module, save_dir: str = None, epoch: int = None, **kwargs)[source]
Save final result figures showing potential and density distributions.
This method generates and saves publication-quality figures at the end of training. It demonstrates the trained model’s final predictions without further training dynamics.
- Parameters:
network (nn.Module) – Trained neural network for final prediction generation. Should be in evaluation mode (will be set by this method).
save_dir (str, optional) – Output directory for final figure files. If None, defaults to self.gif_dir. Directory is created if it doesn’t exist.
epoch (int, optional) – Epoch number to display in figure title. If None, uses the last recorded epoch from history. If no history, defaults to 0.
kwargs (dict) – Optional training information. 'total_loss': final loss value for annotation.
- save_gif(gif_path: str = None, duration_ms: int = None, loop: int = 0)[source]
Assemble saved PNG frames into an animated GIF showing training progress.
This method combines temporary PNG frames collected during training into a single GIF animation. The animation shows how the Ne and Φ predictions evolve during training, along with the loss curve convergence.
- Parameters:
gif_path (str, optional) – Output path for the GIF animation file. If None, defaults to <gif_dir>/training_animation.gif. Directory is created if it doesn’t exist.
duration_ms (int, optional) – Per-frame duration in milliseconds for the GIF playback. If None, uses self.gif_duration_ms (default 300 ms).
loop (int, default=0) – Number of animation loops. 0 = infinite loop (gif continues cycling). n > 0 = gif cycles n times then stops.
- visualize(network, epoch: int, writer: SummaryWriter, **kwargs)[source]
Generate visualization plots for the current training epoch.
This method is the main visualization callback invoked automatically by the PINN.train() method at specified intervals (controlled by log_freq). It performs three main tasks:
1. Evaluates the network on a uniform radial grid
2. Saves predictions to history (for animation)
3. Creates a multi-panel figure and saves it for TensorBoard and GIF
- Parameters:
network (nn.Module) – The neural network being trained (Corona1DRKNet).
epoch (int) – Current training epoch number.
writer (SummaryWriter) – TensorBoard writer for logging figures and metrics.
kwargs (dict) – Additional training information. 'total_loss': total loss value (torch.Tensor or float). 'loss_dict': dictionary of individual loss terms.
- Returns:
Dictionary mapping visualization names to matplotlib figures
- Return type:
dict
- ai4plasma.piml.rk_pinn.get_PhiNe_func_from_file(csv_file)[source]
Load reference potential and electron density profiles from CSV file for comparison.
This function reads experimental or reference simulation data from a CSV file containing electric potential (Φ) and electron density (Ne) profiles as functions of radius. The data is interpolated using cubic splines for smooth evaluation at arbitrary radial positions.
- Parameters:
csv_file (str) – Path to CSV file containing reference corona discharge data.
- Returns:
Phi_spline (scipy.interpolate.CubicSpline) – Returns interpolated potential (V) at radius r (m).
Ne_spline (scipy.interpolate.CubicSpline) – Returns interpolated electron density (m⁻³) at radius r (m).
- Raises:
FileNotFoundError – If the CSV file does not exist at the specified path.
KeyError – If required columns are missing from the CSV file.
- ai4plasma.piml.rk_pinn.load_butcher_table(q)[source]
Load Butcher tableau for implicit Runge-Kutta method.
This function loads pre-computed Butcher tableau coefficients for implicit Runge-Kutta methods from a .npy file. The file path is constructed relative to the location of this module file, ensuring correct loading regardless of the current working directory.
If the local file is not found, the function will attempt to download it from HuggingFace Datasets (repository: mathboylinlin/ai4plasma_butcher_table).
- Parameters:
q (int) – Order of the Runge-Kutta method.
- Returns:
Butcher tableau coefficients for the specified order, shape (q+1, q).
- Return type:
torch.Tensor
- Raises:
FileNotFoundError – If the Butcher table cannot be loaded locally and download fails.
ImportError – If huggingface_hub is not installed when download is attempted.
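To make the (q+1, q) shape concrete, the sketch below builds a small tableau by hand and applies one implicit RK step to the linear test equation du/dt = λu. It assumes that the first q rows hold the stage matrix A and the last row holds the weights b; the library's actual file layout is not documented here. The two-stage Gauss-Legendre coefficients are standard textbook values, and the function irk_step is hypothetical. In RK-PINN the stage values are network outputs constrained through the loss rather than obtained from a linear solve, so this example only illustrates the time discretization itself.

```python
import numpy as np

# Two-stage Gauss-Legendre method (q = 2), stacked as [A; b] with shape (q+1, q).
s3 = np.sqrt(3.0)
A = np.array([[0.25, 0.25 - s3 / 6.0],
              [0.25 + s3 / 6.0, 0.25]])
b = np.array([0.5, 0.5])
table = np.vstack([A, b])


def irk_step(u_n, lam, dt, table):
    """One implicit RK step for du/dt = lam * u (assumed [A; b] layout)."""
    q = table.shape[1]
    A, b = table[:q], table[q]
    # Stage values U solve U = u_n + dt * lam * A @ U (a linear system here).
    U = np.linalg.solve(np.eye(q) - dt * lam * A, u_n * np.ones(q))
    # End-of-step value combines the stage derivatives with weights b.
    u_next = u_n + dt * lam * (b @ U)
    return U, u_next


U, u1 = irk_step(u_n=1.0, lam=-2.0, dt=0.1, table=table)
print(table.shape, u1, np.exp(-0.2))
```

With q stage values plus the end-of-step value, each field contributes q+1 quantities per spatial point, in line with the q+1 outputs per field described for Corona1DRKNet.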
2.5. Meta-PINN (Meta-Learning PINN)
Meta-learning Physics-Informed Neural Networks (Meta-PINN).
This module implements Model-Agnostic Meta-Learning (MAML) for physics-informed neural networks. It provides a task abstraction, bi-level optimization pipeline, and training utilities for rapid adaptation to new physics tasks with limited data.
2.5.1. Meta-PINN Classes
MetaTask: Abstract task interface for meta-learning.
PINNTask: PINN-specific task implementation with equation term losses.
MetaPINN: MAML trainer for PINN task batches.
2.5.2. Meta-PINN References
- [1] L. Zhong, B. Wu, and Y. Wang, “Accelerating physics-informed neural network based 1D arc simulation by meta learning,” Journal of Physics D: Applied Physics, vol. 56, p. 074006, 2023.
- class ai4plasma.piml.meta_pinn.MetaPINN(train_tasks: List[PINNTask], test_tasks: List[PINNTask] = None)[source]
Bases: object
Meta-Learning Framework for Physics-Informed Neural Networks using MAML.
This class implements the Model-Agnostic Meta-Learning (MAML) algorithm specifically designed for physics-informed neural networks. It learns optimal network initializations across multiple related physics tasks, enabling rapid adaptation to new tasks with minimal fine-tuning.
- Parameters:
- meta_network
The meta-network whose parameters are meta-learned.
- Type:
nn.Module
- loss_func
Loss function for computing residuals (default: smooth_l1_loss).
- Type:
callable
- writer
TensorBoard writer for logging meta-training progress.
- Type:
SummaryWriter
- history
Training history storing meta-train losses.
- Type:
Dict
- visualization_callbacks
Registered callbacks for real-time visualization.
- Type:
Dict[str, VisualizationCallback]
- outer_epochs
Number of meta-training iterations completed.
- Type:
int
- inner_epochs
Number of adaptation steps per task (inner loop).
- Type:
int
- outer_lr
Meta-learning rate (outer loop, typically 1e-4 to 1e-3).
- Type:
float
- inner_lr
Task adaptation learning rate (inner loop, typically 1e-5 to 1e-3).
- Type:
float
- beta1, beta2
Adam optimizer momentum parameters (default 0.9, 0.999).
- Type:
float
- epsilon
Adam optimizer numerical stability constant (default 1e-8).
- Type:
float
- load_meta_model(checkpoint_path: str)[source]
Load meta-model from checkpoint for resuming training.
This method restores the meta-network parameters and training state from a saved checkpoint, enabling interrupted meta-training to resume seamlessly from the last saved epoch.
- Parameters:
checkpoint_path (str) – Path to the checkpoint file (.pth). Should contain:
- meta_network_state_dict: Learned meta-parameters
- outer_epochs: Last completed meta-training epoch
- inner_epochs: Inner loop iteration count
- outer_lr, inner_lr: Learning rates
- beta1, beta2, epsilon: Adam optimizer hyperparameters
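A checkpoint with the listed keys could be written and read back roughly as follows. The key names come from the description above; everything else (the toy network, the hyperparameter values, and the in-memory buffer used instead of a .pth file) is an assumption made to keep the sketch self-contained.

```python
import io

import torch
from torch import nn


def make_network():
    """Toy stand-in for the meta-network."""
    return nn.Sequential(nn.Linear(2, 8), nn.Tanh(), nn.Linear(8, 1))


torch.manual_seed(0)
meta_network = make_network()

checkpoint = {
    "meta_network_state_dict": meta_network.state_dict(),
    "outer_epochs": 100,   # last completed meta-training epoch
    "inner_epochs": 5,
    "outer_lr": 1e-4,
    "inner_lr": 1e-5,
    "beta1": 0.9,
    "beta2": 0.999,
    "epsilon": 1e-8,
}

# Round trip through an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
torch.save(checkpoint, buffer)
buffer.seek(0)
restored = torch.load(buffer, map_location="cpu")

# Restore the weights into a freshly constructed network.
fresh = make_network()
fresh.load_state_dict(restored["meta_network_state_dict"])
print(restored["outer_epochs"], restored["outer_lr"])
```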
- meta_test(test_tasks: List[PINNTask], results_dir: str = None)[source]
Evaluate meta-learned initialization on new test tasks.
This method demonstrates the meta-learning benefit by initializing the network with learned meta-parameters and performing few-shot adaptation on new tasks with minimal training data.
- Parameters:
test_tasks (List[PINNTask]) – List of new tasks for meta-testing. Each task should have support data defined for few-shot adaptation.
results_dir (str, optional) – Directory to save adapted model checkpoints. One checkpoint per task will be saved if provided.
- meta_train(outer_epochs: int, inner_epochs: int = 5, outer_lr: float = 0.0001, inner_lr: float = 1e-05, beta1: float = 0.9, beta2: float = 0.999, epsilon: float = 1e-08, print_freq: int = 10, tensorboard_logdir: str = None, log_freq: int = 50, checkpoint_dir: str = None, checkpoint_freq: int = 100, load_from_checkpoint: str = None)[source]
Execute meta-training using Model-Agnostic Meta-Learning (MAML).
This method implements the bi-level optimization algorithm of MAML:
Inner Loop: Fast adaptation to individual tasks using gradient descent.
Outer Loop: Meta-parameter update using accumulated query losses.
The goal is to learn an initialization that enables rapid adaptation to new physics tasks with minimal fine-tuning. After meta-training, the learned initialization can be applied to unseen tasks for few-shot learning.
- Parameters:
outer_epochs (int) – Number of meta-training iterations (outer loop). Typical range: 500-5000 depending on task complexity.
inner_epochs (int, default=5) – Number of gradient steps for task adaptation (inner loop). Typical range: 1-20 steps.
outer_lr (float, default=1e-4) – Meta-learning rate for outer loop. Controls update speed of meta-parameters. Typical range: 1e-5 to 1e-3.
inner_lr (float, default=1e-5) – Task adaptation learning rate for inner loop. Typical range: 1e-6 to 1e-3.
beta1 (float, default=0.9) – Adam momentum parameter for first moment estimation.
beta2 (float, default=0.999) – Adam momentum parameter for second moment estimation.
epsilon (float, default=1e-8) – Adam numerical stability constant.
print_freq (int, default=10) – Print training progress every print_freq epochs.
tensorboard_logdir (str, optional) – Directory for TensorBoard logging. If None, logging is disabled.
log_freq (int, default=50) – Log metrics to TensorBoard every log_freq epochs.
checkpoint_dir (str, optional) – Directory to save checkpoints during training. If None, no saves.
checkpoint_freq (int, default=100) – Save checkpoint every checkpoint_freq epochs.
load_from_checkpoint (str, optional) – Path to checkpoint file for resuming training.
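The bi-level structure described for meta_train can be illustrated with a minimal, dependency-free sketch. It is not the library's implementation. Each toy task has the scalar loss (theta - c)^2 in place of a PINN residual, the support and query losses coincide, and the outer step is plain gradient descent rather than Adam. The names inner_adapt and toy_meta_train are hypothetical.

```python
def inner_adapt(theta, center, inner_lr, inner_epochs):
    """Inner loop: gradient descent on L(theta) = (theta - center)**2.

    Returns the adapted parameter and d(adapted)/d(theta), which the outer
    loop needs in order to differentiate through the adaptation steps.
    """
    adapted, sensitivity = theta, 1.0
    for _ in range(inner_epochs):
        grad = 2.0 * (adapted - center)
        adapted = adapted - inner_lr * grad
        sensitivity *= 1.0 - 2.0 * inner_lr  # chain rule through one step
    return adapted, sensitivity


def toy_meta_train(centers, outer_epochs=200, inner_epochs=5,
                   outer_lr=0.1, inner_lr=0.05):
    """Outer loop: update the shared initialization from query losses."""
    theta = 0.0  # meta-parameter (shared initialization)
    history = []
    for _ in range(outer_epochs):
        meta_grad, query_loss = 0.0, 0.0
        for center in centers:
            adapted, sensitivity = inner_adapt(theta, center, inner_lr, inner_epochs)
            query_loss += (adapted - center) ** 2
            # Gradient of the query loss with respect to the initialization.
            meta_grad += 2.0 * (adapted - center) * sensitivity
        theta -= outer_lr * meta_grad / len(centers)
        history.append(query_loss / len(centers))
    return theta, history


theta, history = toy_meta_train([1.0, 2.0, 3.0])
```

In this toy setting the learned initialization settles near the mean of the task optima, the point from which every task is reachable in a few inner steps.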
- register_visualization_callback(callback: VisualizationCallback)[source]
Register a visualization callback for training monitoring.
- Parameters:
callback (VisualizationCallback) – Callback instance to register (from pinn.VisualizationCallback).
- save_meta_model(epoch: int, checkpoint_path: str)[source]
Save meta-model checkpoint for later resumption.
This method saves the current meta-network parameters and training state to disk, allowing meta-training to be resumed from this point if interrupted. Checkpoints are typically saved at regular intervals during training.
- Parameters:
epoch (int) – Current meta-training epoch number.
checkpoint_path (str) – Path where to save the checkpoint file.
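A rough sketch of the save and resume cycle, using the checkpoint fields listed for load_meta_model. This is an illustration under assumptions, not the library's code: pickle stands in for the serializer only to keep the example dependency-free (the documented files use the .pth extension), the weight values are placeholders, and save_checkpoint and load_checkpoint are hypothetical helper names.

```python
import os
import pickle
import tempfile


def save_checkpoint(path, state):
    """Write the training state to disk (pickle is a stand-in serializer)."""
    with open(path, "wb") as fh:
        pickle.dump(state, fh)


def load_checkpoint(path):
    """Read the training state back from disk."""
    with open(path, "rb") as fh:
        return pickle.load(fh)


state = {
    "meta_network_state_dict": {"linear1.weight": [[0.1], [0.2]]},  # placeholder
    "outer_epochs": 300,
    "inner_epochs": 5,
    "outer_lr": 1e-4,
    "inner_lr": 1e-5,
    "beta1": 0.9,
    "beta2": 0.999,
    "epsilon": 1e-8,
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "meta_epoch_300.pkl")  # hypothetical file name
    save_checkpoint(path, state)
    restored = load_checkpoint(path)

# Meta-training would continue from the last completed outer epoch.
start_epoch = restored["outer_epochs"]
```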
- class ai4plasma.piml.meta_pinn.MetaStaArc1DModel(R, I, Tb=300.0, T_red=10000.0, backbone_net=FNN([1, 100, 100, 100, 100, 1]), train_data_size=500, test_data_size=501, sample_mode='uniform', GL_degree=100, prop: ArcPropSpline = None)[source]
Bases:
PINN
PINN model for 1D stationary arc adapted for meta-learning.
This class implements a specialized PINN for solving the 1D steady-state arc plasma equation in a meta-learning context. Unlike the standard StaArc1DModel, this version is designed as a task within the MetaPINN framework for learning across multiple arc configurations.
- Parameters:
R (float) – Arc radius in meters (e.g., 0.01 for 10 mm).
I (float) – Arc current in amperes (e.g., 100, 150, 200).
Tb (float, default=300.0) – Boundary temperature at r=R in Kelvin.
T_red (float, default=1e4) – Temperature normalization constant (typically 10000 K).
backbone_net (nn.Module, default=FNN([1,100,100,100,100,1])) – Backbone neural network for temperature prediction.
train_data_size (int, default=500) – Number of collocation points for training.
test_data_size (int, default=501) – Number of points for evaluation/prediction.
sample_mode (str, default='uniform') – Sampling strategy (‘uniform’, ‘random’, or ‘lhs’).
GL_degree (int, default=100) – Degree of Gauss-Legendre quadrature for arc conductance integration.
prop (ArcPropSpline) – Material property splines (κ, σ, ε_nec as functions of T).
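To make the role of GL_degree concrete, the following sketch applies Gauss-Legendre quadrature to a conductance integral. It assumes the standard cylindrical arc column form G = 2π ∫₀ᴿ σ(r) r dr (conductance per unit length), which the source does not spell out, and it hard-codes a 3-point rule instead of the default degree of 100. The function arc_conductance and the constant conductivity are hypothetical.

```python
import math

# 3-point Gauss-Legendre nodes and weights on [-1, 1].
GL_NODES = (-math.sqrt(3.0 / 5.0), 0.0, math.sqrt(3.0 / 5.0))
GL_WEIGHTS = (5.0 / 9.0, 8.0 / 9.0, 5.0 / 9.0)


def arc_conductance(sigma_of_r, R):
    """Approximate G = 2*pi * integral_0^R sigma(r) * r dr."""
    half = 0.5 * R
    total = 0.0
    for node, weight in zip(GL_NODES, GL_WEIGHTS):
        r = half * (node + 1.0)  # map the node from [-1, 1] to [0, R]
        total += weight * sigma_of_r(r) * r
    return 2.0 * math.pi * half * total


# Uniform conductivity of 1000 S/m over a 10 mm radius (illustrative values).
G = arc_conductance(lambda r: 1000.0, R=0.01)
```

For a uniform conductivity the result reduces to π R² σ, which gives a quick sanity check on the node mapping and the Jacobian factor.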
- predict(input_data: Tensor) Tensor[source]
Make temperature predictions using the trained/adapted network.
This method applies the trained meta-network (after adaptation) to predict temperature distributions at given radial locations. The boundary condition transformation is applied to ensure T(R) = Tb.
- Parameters:
input_data (torch.Tensor) – Normalized radial coordinates (0 to 1), shape [N, 1].
- Returns:
Normalized temperature predictions, shape [N, 1]. Physical temperature: T_physical = output * self.T_red
- Return type:
torch.Tensor
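The conversion from normalized output to physical temperature can be sketched as follows. The denormalization T_physical = output * T_red comes from the description above. The hard boundary-condition transform is only one possible form that satisfies T(R) = Tb exactly and is an assumption of this sketch, since the library's actual transformation is not documented here. Both function names are hypothetical.

```python
def hard_bc_temperature(raw_output, r_norm, Tb=300.0, T_red=1.0e4):
    """Illustrative transform: equals Tb / T_red at r_norm = 1 for any raw_output."""
    return Tb / T_red + (1.0 - r_norm) * raw_output


def to_physical(t_norm, T_red=1.0e4):
    """Undo the temperature normalization: T_physical = output * T_red."""
    return t_norm * T_red


raw = 1.7  # stand-in for a backbone network output
T_wall = to_physical(hard_bc_temperature(raw, r_norm=1.0))  # at r = R
T_axis = to_physical(hard_bc_temperature(raw, r_norm=0.0))  # at r = 0
```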
- class ai4plasma.piml.meta_pinn.MetaStaArc1DNet(network)[source]
Bases:
Module
Neural network wrapper for meta-learning 1D stationary arc plasma problems.
- Parameters:
network (nn.Module) – Backbone neural network (e.g., FNN) that maps r → N(r). Input shape: [batch_size, 1] (normalized radius). Output shape: [batch_size, 1] (log temperature).
- class ai4plasma.piml.meta_pinn.MetaTask(task_id: str, support_data: Dict[str, Tensor] = None, query_data: Dict[str, Tensor] = None)[source]
Bases:
ABC
Abstract base class for defining meta-learning tasks.
A meta-learning task encapsulates a specific problem instance within a broader family of related problems. In the MAML framework, each task contains support and query sets used for task-specific adaptation and meta-parameter updates.
This abstraction allows the meta-learning framework to handle diverse physics problems in a unified manner by implementing the compute_loss method. It follows the Strategy design pattern, enabling different task types to be plugged into the MetaPINN framework without modifying core algorithms.
- Parameters:
task_id (str) – Unique identifier for this task (e.g., ‘Arc_I=200A’, ‘Poisson_k=1.5’)
support_data (Dict[str, torch.Tensor], optional) – Dictionary mapping data names to support set tensors. Support data is used for task-specific adaptation during the inner loop. Default is empty dict. Example: {‘Domain’: x_domain, ‘Boundary’: x_boundary}
query_data (Dict[str, torch.Tensor], optional) – Dictionary mapping data names to query set tensors. Query data is disjoint from support data and used for meta-parameter updates. Default is empty dict. Same structure as support_data but with different samples.
- task_id
Unique identifier for this task instance.
- Type:
str
- support_data
Support set for inner loop training (few-shot adaptation).
- Type:
Dict[str, torch.Tensor]
- query_data
Query set for outer loop meta-updates (meta-gradient computation).
- Type:
Dict[str, torch.Tensor]
- abstractmethod compute_loss(network: Module, data_dict: Dict[str, Tensor]) Tuple[Tensor, Dict[str, Tensor]][source]
Compute total loss for this task on given data.
This abstract method defines how to evaluate the network’s performance on task data. It is called during both inner loop adaptation (support set) and outer loop meta-updates (query set). Subclasses must implement this method.
- Parameters:
network (nn.Module) – Neural network to evaluate. May be adapted in inner loop.
data_dict (Dict[str, torch.Tensor]) – Dictionary of data tensors for different equation terms. Keys match equation term names (e.g., ‘Domain’, ‘Boundary’). Values are collocation points or boundary data.
- Returns:
total_loss (torch.Tensor) – Scalar tensor representing the total weighted loss.
loss_dict (Dict[str, torch.Tensor]) – Dictionary mapping equation term names to individual losses. Useful for monitoring training and statistical analysis.
- get_query_data() Dict[str, Tensor][source]
Get query set data for meta-validation.
The query set is used to evaluate the adapted network and compute gradients for meta-parameter updates in the outer loop. It must be disjoint from the support set to ensure unbiased meta-learning.
- Returns:
Dictionary mapping equation term names to query data tensors.
- Return type:
Dict[str, torch.Tensor]
- get_support_data() Dict[str, Tensor][source]
Get support set data for inner loop training.
The support set is used to adapt the meta-initialized network to a specific task during the inner loop of MAML. Typically contains fewer samples than traditional PINN training.
- Returns:
Dictionary mapping equation term names to support data tensors.
- Return type:
Dict[str, torch.Tensor]
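The Strategy-style interface described for MetaTask can be sketched with a small stand-in. This is not the library class: ToyMetaTask only mirrors the documented method names, plain Python lists replace tensors, the network is any callable, and LineFitTask with its slope parameter is a hypothetical task used purely for illustration.

```python
from abc import ABC, abstractmethod


class ToyMetaTask(ABC):
    """Stand-in that mirrors the documented MetaTask interface."""

    def __init__(self, task_id, support_data=None, query_data=None):
        self.task_id = task_id
        self.support_data = support_data if support_data is not None else {}
        self.query_data = query_data if query_data is not None else {}

    @abstractmethod
    def compute_loss(self, network, data_dict):
        """Return (total_loss, loss_dict) for the given data."""

    def get_support_data(self):
        return self.support_data

    def get_query_data(self):
        return self.query_data


class LineFitTask(ToyMetaTask):
    """Hypothetical task: the network should reproduce y = slope * x."""

    def __init__(self, task_id, slope, support_data=None, query_data=None):
        super().__init__(task_id, support_data, query_data)
        self.slope = slope

    def compute_loss(self, network, data_dict):
        loss_dict = {}
        for name, xs in data_dict.items():
            residuals = [network(x) - self.slope * x for x in xs]
            loss_dict[name] = sum(r * r for r in residuals) / len(residuals)
        return sum(loss_dict.values()), loss_dict


task = LineFitTask("slope=2", 2.0,
                   support_data={"Domain": [0.0, 1.0]},
                   query_data={"Domain": [0.5]})
exact_loss, _ = task.compute_loss(lambda x: 2.0 * x, task.get_support_data())
zero_loss, zero_terms = task.compute_loss(lambda x: 0.0, task.get_support_data())
```

Because the training loop only calls compute_loss, get_support_data, and get_query_data, a new problem family can be added by writing another subclass without touching the loop itself.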
- class ai4plasma.piml.meta_pinn.PINNTask(task_id: str, pinn_model: PINN | None = None, support_data: Dict[str, Tensor] = None, query_data: Dict[str, Tensor] = None)[source]
Bases:
MetaTask
PINN-specific meta-learning task implementation.
This class bridges the PINN framework with the meta-learning paradigm by wrapping a PINN model and its equation terms into a task suitable for MAML. It implements the compute_loss method by iterating over all equation terms defined in the PINN.
- Parameters:
task_id (str) – Unique identifier for this physics task.
pinn_model (PINN, optional) – PINN model instance containing equation definitions. Must have defined equation terms via add_equation().
support_data (Dict[str, torch.Tensor], optional) – Support set collocation points for inner loop adaptation.
query_data (Dict[str, torch.Tensor], optional) – Query set collocation points for meta-update evaluation.
- loss_func
Loss function for comparing residuals to zero (e.g., MSE, smooth L1).
- Type:
callable
- compute_loss(network: Module, data_dict: Dict[str, Tensor]) Tuple[Tensor, Dict[str, Tensor]][source]
Compute total weighted loss for this PINN task.
- Parameters:
network (nn.Module) – Neural network to evaluate (may be adapted in inner loop).
data_dict (Dict[str, torch.Tensor]) – Dictionary mapping equation term names to data tensors. Example: {‘Domain’: x_pde, ‘Boundary’: x_bc, ‘Initial’: x_ic}
- Returns:
total_loss (torch.Tensor) – Scalar weighted sum of all equation term losses.
loss_dict (Dict[str, torch.Tensor]) – Dictionary of individual losses for monitoring. Keys: equation term names, Values: unweighted loss tensors.
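The relationship between the weighted total and the unweighted per-term losses can be sketched in plain Python. This is a hedged illustration rather than the library's code: the residual values and term weights are made up, lists replace tensors, and smooth_l1 reimplements the standard smooth L1 formula with beta = 1 and mean reduction.

```python
def smooth_l1(residuals, beta=1.0):
    """Mean smooth L1 loss of the residuals against zero."""
    total = 0.0
    for r in residuals:
        a = abs(r)
        total += 0.5 * a * a / beta if a < beta else a - 0.5 * beta
    return total / len(residuals)


def weighted_task_loss(residuals_by_term, weights):
    """Return (weighted total, unweighted per-term losses)."""
    loss_dict = {name: smooth_l1(res) for name, res in residuals_by_term.items()}
    total_loss = sum(weights[name] * loss for name, loss in loss_dict.items())
    return total_loss, loss_dict


# Hypothetical residuals and weights for two equation terms.
residuals = {"Domain": [0.5, -2.0], "Boundary": [0.2]}
weights = {"Domain": 1.0, "Boundary": 10.0}
total_loss, loss_dict = weighted_task_loss(residuals, weights)
```

Keeping loss_dict unweighted lets the individual terms be compared across runs even when the weights change.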
- class ai4plasma.piml.meta_pinn.StaArc1DTask(task_id: str, R: float, I: float, Tb: float = 300.0, T_red: float = 10000.0, backbone_net: Module = FNN([1, 100, 100, 100, 100, 1]), thermo_file: str = None, nec_file: str = None, support_data_size: int = 500, query_data_size: int = 400, sample_mode: str = 'uniform')[source]
Bases:
PINNTask
Task wrapper for 1D stationary arc problem in meta-learning framework.
This class encapsulates a specific arc discharge configuration as a meta-learning task. It creates the underlying PINN model, samples support and query sets, and provides the interface required by MetaPINN for meta-training and meta-testing.
Task Definition
Each task corresponds to a unique arc discharge problem instance:
Physical parameters: Arc radius R, current I, boundary temperature Tb.
Material properties: Plasma gas type (SF6, Ar, etc.).
Support set: Small set of collocation points for adaptation.
Query set: Separate set for meta-validation.
- Parameters:
task_id (str) – Unique identifier for this arc configuration (e.g., ‘Arc_I=200A_R=10mm’).
R (float) – Arc radius in meters.
I (float) – Arc current in amperes.
Tb (float, default=300.0) – Boundary temperature at r=R in Kelvin.
T_red (float, default=1e4) – Temperature normalization constant.
backbone_net (nn.Module, default=FNN([1,100,100,100,100,1])) – Backbone neural network architecture.
thermo_file (str) – Path to thermodynamic properties file (κ, Cp, ρ vs T).
nec_file (str) – Path to net emission coefficient file (ε_nec vs T).
support_data_size (int, default=500) – Number of collocation points in support set.
query_data_size (int, default=400) – Number of collocation points in query set.
sample_mode (str, default='uniform') – Sampling strategy (‘uniform’, ‘random’, or ‘lhs’).
- R
Arc radius (stored for reference).
- Type:
float
- pinn_model
Underlying PINN model for this task.
- Type:
- support_data
Support set collocation points {‘Domain’: x_spt, ‘Boundary’: x_bc}.
- Type:
Dict[str, torch.Tensor]
- query_data
Query set collocation points {‘Domain’: x_qry, ‘Boundary’: x_bc}.
- Type:
Dict[str, torch.Tensor]
2.6. NAS-PINN (Neural Architecture Search for PINN)
Neural Architecture Search for Physics-Informed Neural Networks (NAS-PINN).
This module implements the NAS-PINN framework for automatically searching the optimal architecture of Physics-Informed Neural Networks (PINNs) to solve partial differential equations (PDEs). It uses a differentiable architecture search method within a relaxed search space to find network architectures that balance accuracy and computational efficiency.
2.6.1. NAS-PINN Classes
NasPINN: Main class implementing the architecture search framework.
2.6.2. NAS-PINN References
- [1] Y. Wang, L. Zhong, “NAS-PINN: Neural architecture search-guided physics-informed neural network for solving PDEs,” Journal of Computational Physics, vol. 496, p. 112603, 2024.
- class ai4plasma.piml.nas_pinn.NasPINN(pinn_model: PINN)[source]
Bases:
object
Neural Architecture Search for Physics-Informed Neural Networks (NAS-PINN).
This class implements the NAS-PINN framework for automated architecture search in Physics-Informed Neural Networks (PINNs). It performs differentiable architecture search to find the optimal network architecture for solving partial differential equations (PDEs) within a given search space. The framework combines bi-level optimization: inner loop for weight adaptation and outer loop for architecture parameter optimization.
- pinn_model
The PINN model instance with relaxed FNN structure and searchable architecture parameters.
- Type:
PINN
- writer
TensorBoard writer for logging training metrics. Initialized during search process.
- Type:
SummaryWriter, optional
- history
Dictionary storing training history, including ‘search_loss’ trajectory.
- Type:
dict
- visualization_callback
Dictionary storing visualization callbacks indexed by name.
- Type:
dict
- last_outer_epochs
Number of completed outer loop iterations (useful for resuming training).
- Type:
int
- outer_epochs
Target total number of outer loop iterations.
- Type:
int
- inner_epochs
Number of inner loop iterations per outer loop step.
- Type:
int
- outer_opt
Optimizer for outer loop (architecture parameter updates).
- Type:
torch.optim.Optimizer
- inner_opt
Optimizer for inner loop (network weight updates).
- Type:
torch.optim.Optimizer
- load_nas_model(checkpoint_path: str)[source]
Load NAS-PINN model and training state from checkpoint.
Restores the network parameters, architecture parameters, and optimizer states from a previously saved checkpoint file. This enables resuming interrupted training from the exact point where it was saved.
- Parameters:
checkpoint_path (str) – Path to the checkpoint file (.pth format).
- save_nas_model(epoch: int, checkpoint_path: str)[source]
Save NAS-PINN model and training state to checkpoint.
Saves the current network parameters, architecture parameters, and optimizer states to enable training resumption. Checkpoints are essential for long-running architecture searches that may be interrupted.
- Parameters:
epoch (int) – Current epoch number in the outer loop (for tracking and logging progress).
checkpoint_path (str) – Path where the checkpoint file should be saved (.pth format).
- search(outer_epochs: int, inner_epochs: int, outer_opt: Optimizer = None, inner_opt: Optimizer = None, print_freq: int = 10, tensorboard_logdir: str = None, log_freq: int = 50, checkpoint_dir: str = None, checkpoint_freq: int = 100, load_from_checkpoint: str = None, final_model_path: str = None)[source]
Search for optimal architecture parameters using differentiable search.
Executes the NAS-PINN algorithm with bi-level optimization:
Inner Loop: Updates network weights using calc_loss() with fixed architecture parameters (g).
Outer Loop: Updates architecture parameters using calc_loss_archi() with the adapted network weights.
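The alternation between weight updates and architecture updates can be sketched on a toy problem. This is an illustration under assumptions, not the NAS-PINN implementation: a single architecture parameter g blends two candidate operations (x and x²) through a sigmoid, the target is y = 3x, gradients are written out by hand, and the outer loop uses a separate set of validation points. All names are hypothetical.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def loss_and_grads(w, g, xs):
    """Mean squared error of w * (p0 * x + (1 - p0) * x**2) against 3 * x."""
    p0 = sigmoid(g)
    loss = grad_w = grad_g = 0.0
    for x in xs:
        mix = p0 * x + (1.0 - p0) * x * x
        r = w * mix - 3.0 * x
        loss += r * r
        grad_w += 2.0 * r * mix
        grad_g += 2.0 * r * w * (x - x * x) * p0 * (1.0 - p0)
    n = len(xs)
    return loss / n, grad_w / n, grad_g / n


def toy_search(outer_epochs=500, inner_epochs=5, inner_lr=0.05, outer_lr=0.05):
    train_xs = [-1.0, -0.5, 0.5, 1.0]
    val_xs = [-0.75, -0.25, 0.25, 0.75]
    w, g = 1.0, 0.0  # network weight and architecture parameter
    history = []
    for _ in range(outer_epochs):
        # Inner loop: update the weight with the architecture parameter fixed.
        for _ in range(inner_epochs):
            _, grad_w, _ = loss_and_grads(w, g, train_xs)
            w -= inner_lr * grad_w
        # Outer loop: update the architecture parameter with the weight fixed.
        val_loss, _, grad_g = loss_and_grads(w, g, val_xs)
        g -= outer_lr * grad_g
        history.append(val_loss)
    return w, sigmoid(g), history


w, p0, history = toy_search()
```

Since the target is linear, the search shifts the mixing weight p0 toward the linear candidate, which is the toy analogue of reading off the searched architecture after training.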
- Parameters:
outer_epochs (int) – Number of search iterations (outer loop). The typical range depends on PDE complexity and spans 500 to 500,000.
inner_epochs (int) – Number of gradient steps for weight adaptation per outer epoch (inner loop). Typical range: 1-20 steps, controls inner loop optimization depth.
outer_opt (torch.optim.Optimizer, optional) – Optimizer for outer loop architecture parameter updates. If None, defaults to Adam(lr=1e-5).
inner_opt (torch.optim.Optimizer, optional) – Optimizer for inner loop weight updates. If None, defaults to Adam(lr=1e-4).
print_freq (int, default=10) – Print training loss statistics every print_freq epochs to console.
tensorboard_logdir (str, optional) – Directory for TensorBoard event logs. If None, TensorBoard logging is disabled. Metrics are written every log_freq epochs.
log_freq (int, default=50) – Log ‘Loss’ and ‘Loss-archi’ metrics to TensorBoard every log_freq epochs.
checkpoint_dir (str, optional) – Directory to save periodic checkpoints. If None, no checkpoints are saved. Directory is created if it doesn’t exist.
checkpoint_freq (int, default=100) – Save checkpoint every checkpoint_freq epochs for training resumption.
load_from_checkpoint (str, optional) – Path to checkpoint file for resuming interrupted training. If provided, restores network parameters, architecture parameters, optimizer states, and training epoch count from checkpoint.
final_model_path (str, optional) – Path to save the final model after the architecture search completes. If provided, the final network state, including the architecture parameters, is saved at this location.
- Returns:
Training history is stored in self.history[‘search_loss’]. Network and architecture parameters are updated in-place. Final network architecture can be extracted via self.pinn_model.network.searched_neuron().
- Return type: