"""CS-PINN models for plasma simulations in AI4Plasma.
This module provides specialized Physics-Informed Neural Network (PINN) implementations
and visualization tools for simulating plasma discharge using cubic spline interpolation
and automatic differentiation on PyTorch.
CS-PINN Classes
---------------
The CS-PINN (Coefficient-Subnet Physics-Informed Neural Network) framework integrates:
- `StaArc1DModel`: PINN model for steady-state 1D arc plasma.
- `TraArc1DTempModel`: PINN model for transient 1D arc plasma without radial velocity.
- `TraArc1DModel`: PINN model for transient 1D arc plasma with radial velocity.
- `StaArc1DVisCallback`: Visualization callback for steady-state arc plasma simulation.
- `TraArc1DTempVisCallback`: Visualization callback for transient arc plasma simulation.
CS-PINN References
------------------
[1] L. Zhong, B. Wu, and Y. Wang, "Low-temperature plasma simulation based on
physics-informed neural networks: Frameworks and preliminary applications,"
Physics of Fluids, vol. 34, no. 8, p. 087116, 2022.
[2] L. Zhong, Q. Gu, and B. Wu, "Deep learning for thermal plasma simulation:
Solving 1-D arc model as an example," Computer Physics Communications,
vol. 257, p. 107496, 2020.
"""
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy import interpolate as intp
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from ai4plasma.core.network import FNN
from ai4plasma.config import REAL
from ai4plasma.utils.common import numpy2torch
from ai4plasma.utils.math import df_dX, calc_relative_l2_err
from ai4plasma.plasma.prop import ArcPropSpline
from ai4plasma.utils.io import img2gif
from .geo import Geo1D, Geo1DTime
from .pinn import PINN, VisualizationCallback
[docs]
class StaArc1DNet(nn.Module):
    """
    Wrapper network for the 1D stationary arc equation that hard-codes the
    outer boundary condition.

    The backbone output N(r) is transformed as

        T(r) = (r - R) * N(r) + Tb

    so that T(R) = Tb holds by construction, whatever the backbone predicts.

    Attributes
    ----------
    network : nn.Module
        Backbone neural network (e.g., FNN) that maps r → N(r)
    R : float
        Normalized arc radius
    Tb : float
        Normalized boundary temperature at r = R

    Parameters (Constructor)
    ------------------------
    network : nn.Module
        Backbone neural network (e.g., FNN) that maps r → N(r)
    R : float, optional
        Normalized arc radius (default: 1.0)
    Tb : float, optional
        Normalized boundary temperature at r = R (default: 0.03)
    """

    def __init__(self, network, R=1.0, Tb=0.03):
        super().__init__()
        self.network = network
        self.R = R    # normalized arc radius
        self.Tb = Tb  # normalized boundary temperature at r = R

    def forward(self, x):
        """
        Evaluate the boundary-constrained temperature.

        Parameters
        ----------
        x : torch.Tensor
            Normalized radial coordinates passed to the backbone.

        Returns
        -------
        torch.Tensor
            ``(x - R) * network(x) + Tb``; equals ``Tb`` wherever ``x == R``.
        """
        backbone_out = self.network(x)
        # The (x - R) factor vanishes at the wall, leaving exactly Tb there.
        return (x - self.R)*backbone_out + self.Tb
[docs]
def calc_GL_coefs(degree):
    """
    Build Gauss-Legendre quadrature nodes and weights for the arc
    conductance integral.

    Parameters
    ----------
    degree : int
        Number of quadrature points passed to
        ``numpy.polynomial.legendre.leggauss``.

    Returns
    -------
    Xq : torch.Tensor
        Quadrature nodes mapped affinely from [-1, 1] to [0, 1],
        built from a column array of shape (degree, 1).
    Wq : torch.Tensor
        Quadrature weights as a column of shape (degree, 1). They are the
        weights returned by ``leggauss`` for the interval [-1, 1] and are
        NOT rescaled here, so an integral over [0, 1] is
        ``0.5 * sum(Wq * f(Xq))``.
    """
    nodes, weights = np.polynomial.legendre.leggauss(degree)

    # Column vectors in the project's working precision.
    nodes_col = nodes.reshape((-1, 1)).astype(REAL())
    weights_col = weights.reshape((-1, 1)).astype(REAL())

    # Affine map of the nodes from [-1, 1] onto [0, 1].
    Xq = numpy2torch(nodes_col*0.5 + 0.5, require_grad=False)
    Wq = numpy2torch(weights_col, require_grad=False)
    return Xq, Wq
[docs]
def get_Tfunc_from_file(csv_file):
    """
    Load a reference temperature profile from a CSV file and return it as a
    cubic-spline interpolant.

    Parameters
    ----------
    csv_file : str
        Path to a CSV file with the columns 'r(m)' (radius) and
        'T(K)' (temperature).

    Returns
    -------
    scipy.interpolate.CubicSpline
        Callable ``T_spline(r)`` giving the interpolated temperature at
        radius ``r``. Extrapolation outside the tabulated range is enabled.

    Raises
    ------
    FileNotFoundError
        If ``csv_file`` does not exist.
    KeyError
        If one of the expected columns is missing from the file.
    ValueError
        Raised by ``CubicSpline`` if the radii are not distinct (e.g.
        duplicated rows) or contain non-finite values.
    """
    if not os.path.exists(csv_file):
        raise FileNotFoundError(f"CSV file not found: {csv_file}")

    df = pd.read_csv(csv_file)
    r_data = df['r(m)'].values.astype(REAL())
    T_data = df['T(K)'].values.astype(REAL())

    # CubicSpline requires strictly increasing abscissae; order the samples
    # by radius so that files whose rows are not already sorted still load.
    order = np.argsort(r_data, kind='stable')
    r_data = r_data[order]
    T_data = T_data[order]

    return intp.CubicSpline(r_data, T_data, extrapolate=True)
[docs]
class StaArc1DModel(PINN):
    """
    PINN model for solving 1D steady-state arc discharge energy equation.

    Implements a Physics-Informed Neural Network for arc plasma simulations.
    Solves the Elenbaas-Heller equation considering Joule heating, thermal
    conduction, and radiation losses on the normalized radius x = r/R with
    the temperature scaled by ``T_red``. The boundary condition at r = R is
    enforced by the StaArc1DNet architecture; the symmetry condition at
    r = 0 is imposed through a loss term.

    Attributes
    ----------
    R : float
        Arc radius [m]
    I : float
        Arc current [A]
    T_red : float
        Temperature reduction factor for normalization [K]
    Tb : float
        Boundary temperature at r = R [K]
    Xq : torch.Tensor
        Gauss-Legendre quadrature abscissae for arc conductance integral
    Wq : torch.Tensor
        Gauss-Legendre quadrature weights
    geo : Geo1D
        Geometry object for domain and boundary sampling
    prop : ArcPropSpline
        Material properties interpolation object

    Parameters (Constructor)
    ------------------------
    R : float
        Arc radius [m]. Normalized to 1.0 in the network.
    I : float
        Arc current [A]. Used to compute electric field and Joule heating.
    Tb : float, optional
        Boundary temperature at r = R [K] (default: 300.0).
    T_red : float, optional
        Temperature reduction factor for normalization [K] (default: 1e4).
        Used for non-dimensionalization: T_normalized = T_physical / T_red
    backbone_net : nn.Module, optional
        Backbone neural network that outputs the unscaled network function.
        If None (default), a new FNN with architecture
        [1, 100, 100, 100, 100, 1] is created for this model.
        It is wrapped in StaArc1DNet for boundary condition enforcement.
    train_data_size : int, optional
        Number of training collocation points in the domain (default: 500).
        Collocation points are sampled according to sample_mode.
    test_data_size : int, optional
        Number of test/evaluation points for visualization (default: 501).
    sample_mode : str, optional
        Collocation point sampling strategy (default: 'uniform').
        Options: 'uniform' (uniform grid), 'lhs' (Latin hypercube), 'random'
    GL_degree : int, optional
        Degree of Gauss-Legendre quadrature for arc conductance integral
        (default: 100).
    prop : ArcPropSpline, optional
        Arc material properties object providing the temperature-dependent
        properties κ, σ and ε_nec. It is required by the PDE residual; if
        left as None it must be assigned before the loss is evaluated.
    """

    def __init__(
        self,
        R,
        I,
        Tb=300.0,
        T_red=1e4,
        backbone_net=None,
        train_data_size=500,
        test_data_size=501,
        sample_mode='uniform',
        GL_degree=100,
        prop: ArcPropSpline = None,
    ):
        # A network built in the signature would be created once, when the
        # class is defined, and then shared (weights included) by every
        # model that relies on the default. Build a fresh one per instance.
        if backbone_net is None:
            backbone_net = FNN(layers=[1, 100, 100, 100, 100, 1])

        self.R = R
        self.I = I
        self.T_red = T_red
        self.Tb = Tb
        self.train_data_size = train_data_size
        self.test_data_size = test_data_size
        self.sample_mode = sample_mode
        self.GL_degree = GL_degree
        self.Xq, self.Wq = calc_GL_coefs(GL_degree)
        self.prop = prop
        # The network works on the normalized radius x = r/R in [0, 1].
        self.geo = Geo1D([0.0, 1.0])

        network = StaArc1DNet(backbone_net, R=1.0, Tb=Tb/T_red)
        super().__init__(network)
        self.set_loss_func(F.smooth_l1_loss)

    def _define_loss_terms(self):
        """
        Define physics-informed loss terms for the steady-state arc model.

        Registers two residuals:
        1. Energy PDE in the domain (steady-state energy balance)
        2. Boundary condition at r=0 (symmetry: dT/dr = 0)
        """

        def _pde_residual(network, x):
            """
            PDE residual for stationary arc equation in normalized coordinates:

                d/dx( x κ dT/dx ) + x (σ E² - 4π ε_nec) R² / T_red
            """
            T = network(x)

            # Material properties are evaluated at the physical temperature.
            T_phys = T.view(-1)*self.T_red
            kappa = self.prop.kappa(T_phys).view(-1, 1)
            sigma = self.prop.sigma(T_phys).view(-1, 1)
            nec = self.prop.nec(T_phys).view(-1, 1)

            # Arc conductance G = 2π R² ∫_0^1 σ x dx. The quadrature weights
            # belong to [-1, 1], so the factor 1/2 of the interval mapping
            # turns 2π into π.
            Tq = network(self.Xq)
            sigma_q = self.prop.sigma(Tq.view(-1)*self.T_red).view(-1, 1)
            arc_cond = np.pi*self.R*self.R*torch.sum(self.Wq*self.Xq*sigma_q)

            joule = sigma*(self.I/arc_cond)**2   # σ E² with E = I / G
            radiation = 4*np.pi*nec
            net_energy = (joule - radiation)/self.T_red*self.R*self.R

            T_x = df_dX(T, x)
            T_term = x*kappa*T_x
            T_xx = df_dX(T_term, x)
            # Written multiplied by x so that no division by the radius occurs.
            return T_xx + x*net_energy

        def _bc_residual(network, x):
            """
            Boundary condition residual at r=0 (symmetry condition dT/dr = 0).
            """
            T = network(x)
            return df_dX(T, x)

        # Sample domain collocation points
        x_domain = self.geo.sample_domain(self.train_data_size, mode=self.sample_mode)

        # Sample boundary points; the first entry is used as the r = 0 boundary
        x_bc = self.geo.sample_boundary()
        x_bc_left = x_bc[0]

        # Add equation terms with weights
        self.add_equation('Domain', _pde_residual, weight=1.0, data=x_domain)
        self.add_equation('Left Boundary', _bc_residual, weight=10.0, data=x_bc_left)
[docs]
class StaArc1DVisCallback(VisualizationCallback):
    """
    Custom visualization callback for 1D stationary arc PINN training.

    This callback provides monitoring of the arc model training:
    1. Multi-panel figures returned for TensorBoard logging during training
    2. Training history tracking for post-training animation
    3. Temperature distribution evolution and error tracking
    4. Loss convergence monitoring
    """

    def __init__(self, model: 'StaArc1DModel', log_freq: int = 50,
                 save_history: bool = True, history_freq: int = None,
                 T_csv_file: str = None,
                 gif_enabled: bool = False,
                 gif_dir: str = None,
                 gif_freq: int = None,
                 gif_duration_ms: int = 300,
                 gif_cleanup_tmp: bool = True):
        """
        Initialize the arc visualization callback.

        Parameters
        ----------
        model : StaArc1DModel
            The arc model instance, needed to access material properties and parameters
        log_freq : int, default=50
            Frequency (in epochs) for logging visualizations to TensorBoard.
        save_history : bool, default=True
            Whether to save prediction snapshots for creating training animations.
            Set to False if you don't need animations to save memory.
        history_freq : int, optional
            Frequency (in epochs) for saving history snapshots.
            If None, defaults to log_freq.
        T_csv_file : str, optional
            Path to CSV file containing reference temperature data for comparison.
            If provided, the reference profile is loaded and used for error analysis;
            otherwise ``T_ref`` and ``T_ref_func`` are None.
        gif_enabled : bool, default=False
            Whether to save frames and allow generating a GIF animation.
        gif_dir : str, optional
            Output directory for the final GIF and final summary plots.
            If None, defaults to the current working directory.
            Temporary frames are stored in ``gif_dir/tmp_frames`` and the GIF
            is named 'training_animation.gif'.
        gif_freq : int, optional
            Frequency (in epochs) to save frames for the GIF. If None, uses history_freq.
        gif_duration_ms : int, default=300
            Duration per frame in milliseconds in the final GIF.
        gif_cleanup_tmp : bool, default=True
            Whether to remove the temporary PNG frames after saving the GIF.
        """
        super().__init__(name='CS-PINN_1D_Arc', log_freq=log_freq)
        self.model = model

        # Create evaluation grid for visualization
        self.x = model.geo.sample_domain(model.test_data_size, mode='uniform',
                                         include_boundary=True, to_tensor=False)

        # Physical parameters for context
        self.T_red = model.T_red
        self.R = model.R
        self.I = model.I
        self.Tb = model.Tb

        # The reference attributes are always defined: _make_figure tests
        # `self.T_ref is not None`, which must also work without a CSV file.
        self.T_ref_func = None
        self.T_ref = None
        if T_csv_file is not None:
            self.T_ref_func = get_Tfunc_from_file(T_csv_file)
            self.T_ref = self.T_ref_func(self.x*self.R)  # Physical units

        # Training history tracking for animation
        self.save_history = save_history
        self.history_freq = history_freq if history_freq is not None else log_freq
        self.history = {
            'epochs': [],            # List of epoch numbers
            'axis': self.x,          # Evaluation grid (fixed)
            'T_ref': self.T_ref,     # Reference temperature profile (or None)
            'T': [],                 # List of temperature profiles (reduced units)
            'losses': [],            # List of total loss values, aligned with 'epochs'
            'arc_conductance': [],   # Reserved; not filled by this class
            'integral_powers': [],   # Reserved; not filled by this class
            'T_center': []           # List of temperatures at the first grid point
        }

        # GIF configuration/state
        self.gif_enabled = gif_enabled
        self.gif_dir = os.getcwd() if gif_dir is None else gif_dir
        self.gif_tmp_dir = os.path.join(self.gif_dir, 'tmp_frames')
        self.gif_filename = 'training_animation.gif'
        self.gif_freq = gif_freq if gif_freq is not None else self.history_freq
        self.gif_duration_ms = gif_duration_ms
        self.gif_cleanup_tmp = gif_cleanup_tmp
        self._gif_frames = []
        if self.gif_enabled:
            os.makedirs(self.gif_dir, exist_ok=True)
            os.makedirs(self.gif_tmp_dir, exist_ok=True)

    @staticmethod
    def _loss_to_float(loss) -> float:
        """Return a loss given as tensor, number or None as a float (NaN for None)."""
        if loss is None:
            return float('nan')
        if isinstance(loss, torch.Tensor):
            return loss.item()
        return float(loss)

    def _compute_material_properties(self, T_reduced: np.ndarray) -> dict:
        """
        Compute material properties at given temperature profile.

        Parameters
        ----------
        T_reduced : np.ndarray
            Temperature profile in reduced units (multiplied by T_red here
            to obtain the physical temperature)

        Returns
        -------
        dict
            Dictionary containing:
            - 'T_physical': Physical temperature
            - 'kappa': Thermal conductivity, or None if the model has no `prop`
            - 'sigma': Electrical conductivity, or None if the model has no `prop`
            - 'nec': Net emission coefficient, or None if the model has no `prop`
            - 'max_T': Maximum temperature
            - 'T_center': Temperature at the first entry of the profile
        """
        T_physical = T_reduced * self.T_red
        result = {
            'T_physical': T_physical,
            'kappa': None,
            'sigma': None,
            'nec': None,
            'max_T': T_physical.max(),
            'T_center': T_physical[0]
        }

        # Without a material property object only the temperature data is returned
        if self.model.prop is None:
            return result

        with torch.no_grad():
            T_tensor = numpy2torch(T_physical, require_grad=False)
            result['kappa'] = self.model.prop.kappa(T_tensor).cpu().numpy()
            result['sigma'] = self.model.prop.sigma(T_tensor).cpu().numpy()
            result['nec'] = self.model.prop.nec(T_tensor).cpu().numpy()
        return result

    def _make_figure(self, epoch: int, T_reduced: np.ndarray, props: dict, **kwargs) -> plt.Figure:
        """
        Build the multi-panel matplotlib figure used for both TensorBoard logging
        and GIF frames.

        Parameters
        ----------
        epoch : int
            Current epoch number
        T_reduced : np.ndarray
            Model output in reduced units
        props : dict
            Material property dictionary from _compute_material_properties()
        kwargs : dict
            Optional training info (e.g., total_loss)

        Returns
        -------
        matplotlib Figure
        """
        fig = plt.figure(figsize=(14, 10))
        gs = fig.add_gridspec(2, 2, hspace=0.35, wspace=0.3)

        # Panel 1: Temperature profile - Prediction vs Reference
        ax1 = fig.add_subplot(gs[0, 0])
        T_physical = T_reduced * self.T_red
        ax1.plot(self.x, T_physical, 'b-', linewidth=2.5, label='CS-PINN Prediction')

        if self.T_ref is not None:
            ax1.plot(self.x, self.T_ref, 'r--', linewidth=2.0, label='Reference', alpha=0.8)
            # Compute and display error metrics
            error = np.abs(T_physical.flatten() - self.T_ref.flatten())
            # The small offset guards against division by zero
            relative_error = error / (self.T_ref.flatten() + 1e-10) * 100
            max_error = error.max()
            mean_error = error.mean()
            max_rel_error = relative_error.max()
            rel_l2_error = calc_relative_l2_err(self.T_ref, T_physical)
            info_text = (f'Epoch {epoch}\n'
                         f'Max T: {props["max_T"]:.0f} K\n'
                         f'Center T: {props["T_center"]:.0f} K\n'
                         f'Max Error: {max_error:.1f} K\n'
                         f'Mean Error: {mean_error:.1f} K\n'
                         f'Max Rel Error: {max_rel_error:.2f}%\n'
                         f'Rel L2 Error: {rel_l2_error:.5g}')
        else:
            ax1.fill_between(self.x.flatten(), T_physical.flatten(), alpha=0.3)
            info_text = (f'Epoch {epoch}\n'
                         f'Max T: {props["max_T"]:.0f} K\n'
                         f'Center T: {props["T_center"]:.0f} K')

        ax1.set_xlabel('Normalized radius r/R', fontsize=11)
        ax1.set_ylabel('Temperature (K)', fontsize=11)
        ax1.set_title('Temperature Distribution', fontsize=12, fontweight='bold')
        ax1.grid(True, alpha=0.3)
        ax1.legend(loc='best', fontsize=10)
        ax1.text(0.05, 0.10, info_text,
                 transform=ax1.transAxes, fontsize=9, verticalalignment='bottom',
                 bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        # Panel 2: Material properties (kappa and sigma)
        ax2 = fig.add_subplot(gs[0, 1])
        if props['kappa'] is not None and props['sigma'] is not None:
            ax2_twin = ax2.twinx()
            line1 = ax2.plot(self.x, props['kappa'], 'g-', linewidth=2, label='κ(T) Thermal conductivity')
            line2 = ax2_twin.plot(self.x, props['sigma'], 'orange', linewidth=2, linestyle='--',
                                  label='σ(T) Electrical conductivity')
            ax2.set_xlabel('Normalized radius r/R', fontsize=11)
            ax2.set_ylabel('κ (W/(m·K))', fontsize=10, color='g')
            ax2_twin.set_ylabel('σ (1/(Ω·m))', fontsize=10, color='orange')
            ax2.set_title('Material Properties', fontsize=12, fontweight='bold')
            ax2.grid(True, alpha=0.3)
            ax2.tick_params(axis='y', labelcolor='g')
            ax2_twin.tick_params(axis='y', labelcolor='orange')
            # One legend for the lines of both y-axes
            lines = line1 + line2
            labels = [line.get_label() for line in lines]
            ax2.legend(lines, labels, loc='best', fontsize=9)
        else:
            ax2.text(0.5, 0.5, 'Material properties\nnot available',
                     ha='center', va='center', transform=ax2.transAxes,
                     fontsize=12, style='italic', color='gray')
            ax2.set_title('Material Properties', fontsize=12, fontweight='bold')

        # Panel 3: Radiation term (net emission coefficient)
        ax3 = fig.add_subplot(gs[1, 0])
        if props['nec'] is not None:
            ax3.plot(self.x, props['nec'], 'r-', linewidth=2)
            ax3.fill_between(self.x.flatten(), props['nec'].flatten(), alpha=0.3, color='red')
            ax3.set_xlabel('Normalized radius r/R', fontsize=11)
            ax3.set_ylabel('nec (W/m³)', fontsize=11)
            ax3.set_title('Radiation Term (Net Emission Coefficient)', fontsize=12, fontweight='bold')
            ax3.grid(True, alpha=0.3)
            ax3.text(0.05, 0.1, f'Max nec: {props["nec"].max():.2e} W/m³',
                     transform=ax3.transAxes, fontsize=10, verticalalignment='bottom',
                     bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.5))
        else:
            ax3.text(0.5, 0.5, 'Radiation term\nnot available',
                     ha='center', va='center', transform=ax3.transAxes,
                     fontsize=12, style='italic', color='gray')
            ax3.set_title('Radiation Term', fontsize=12, fontweight='bold')

        # Panel 4: Loss curve (training history)
        ax4 = fig.add_subplot(gs[1, 1])
        if self.history['losses']:
            loss_epochs = self.history['epochs']
            loss_values = self.history['losses']
            ax4.semilogy(loss_epochs, loss_values, 'purple', linewidth=2.5, marker='o', markersize=4)
            ax4.set_xlabel('Epoch', fontsize=11)
            ax4.set_ylabel('Loss (log scale)', fontsize=11)
            ax4.set_title('Training Loss Convergence', fontsize=12, fontweight='bold')
            ax4.grid(True, alpha=0.3, which='both')
            current_loss = kwargs.get('total_loss', None)
            if current_loss is not None:
                loss_val = self._loss_to_float(current_loss)
                ax4.text(0.95, 0.95, f'Current loss: {loss_val:.2e}',
                         transform=ax4.transAxes, fontsize=10, verticalalignment='top',
                         horizontalalignment='right',
                         bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.7))
        else:
            ax4.text(0.5, 0.5, 'No loss history yet',
                     ha='center', va='center', transform=ax4.transAxes,
                     fontsize=12, style='italic', color='gray')
            ax4.set_title('Training Loss Convergence', fontsize=12, fontweight='bold')

        fig.suptitle(f'Arc 1D PINN Model - Epoch {epoch}', fontsize=14, fontweight='bold', y=0.995)
        return fig

    def visualize(self, network, epoch: int, writer: SummaryWriter, **kwargs):
        """
        Generate visualization plots for the current training epoch.

        It performs these tasks:
        1. Evaluates the network on the fixed grid and builds the panel figure
        2. Saves prediction snapshots for training animation (if enabled)
        3. Saves a PNG frame for the GIF (if enabled)

        Parameters
        ----------
        network : nn.Module
            The neural network being trained
        epoch : int
            Current training epoch number
        writer : SummaryWriter
            TensorBoard writer (not used directly by this method)
        kwargs : dict
            Additional training information:
            - 'total_loss': Total loss value (torch.Tensor or float)
            - 'loss_dict': Dictionary of individual loss terms

        Returns
        -------
        dict
            Dictionary mapping plot names to matplotlib figures
        """
        network.eval()

        # Generate predictions on evaluation grid (in reduced units)
        with torch.no_grad():
            T_reduced = network(numpy2torch(self.x, require_grad=False)).cpu().numpy()

        # Compute material properties at current temperature profile
        props = self._compute_material_properties(T_reduced.flatten())

        # Save history for animation (only at specified frequency)
        if self.save_history and epoch % self.history_freq == 0:
            self.history['epochs'].append(epoch)
            self.history['T'].append(T_reduced.copy())
            self.history['T_center'].append(props['T_center'])
            # One loss entry per recorded epoch keeps 'epochs' and 'losses'
            # the same length (they are plotted against each other); a
            # missing loss is stored as NaN. Plain floats are accepted too.
            self.history['losses'].append(self._loss_to_float(kwargs.get('total_loss', None)))

        # Create figure with multiple subplots
        fig = self._make_figure(epoch=epoch, T_reduced=T_reduced, props=props, **kwargs)

        # Optionally save frame for GIF (best effort: a failed frame must not stop training)
        if self.gif_enabled and (epoch % self.gif_freq == 0):
            frame_path = os.path.join(self.gif_tmp_dir, f'epoch_{epoch:06d}.png')
            try:
                fig.savefig(frame_path, dpi=120)
                self._gif_frames.append(frame_path)
            except Exception as e:
                print(f'Warning: failed to save GIF frame at epoch {epoch}: {e}')

        return {'arc_visualization': fig}

    def save_gif(self, gif_path: str = None, duration_ms: int = None, loop: int = 0):
        """
        Assemble saved frames into a GIF showing T vs T_ref and loss evolution.

        Parameters
        ----------
        gif_path : str, optional
            Output path; defaults to the file ``gif_filename`` inside ``gif_dir``
        duration_ms : int, optional
            Per-frame duration in milliseconds; defaults to gif_duration_ms
        loop : int
            Loop count forwarded to img2gif (0 = infinite)
        """
        if not self.gif_enabled:
            print('GIF not enabled; nothing to save.')
            return
        if len(self._gif_frames) == 0:
            print('No frames collected; GIF not created.')
            return

        out_path = gif_path if gif_path is not None else os.path.join(self.gif_dir, self.gif_filename)
        # A bare filename has an empty directory part, and os.makedirs('')
        # raises; in that case the file goes to the current directory.
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        frames_sorted = sorted(self._gif_frames)
        try:
            img2gif(frames_sorted, out_path, duration=(duration_ms or self.gif_duration_ms), loop=loop)
            print(f'GIF saved: {out_path}')
        except Exception as e:
            print(f'Failed to create GIF: {e}')
            return

        if self.gif_cleanup_tmp:
            # Best-effort cleanup of temporary frames
            try:
                for f in frames_sorted:
                    if os.path.exists(f):
                        os.remove(f)
                if os.path.isdir(self.gif_tmp_dir) and len(os.listdir(self.gif_tmp_dir)) == 0:
                    os.rmdir(self.gif_tmp_dir)
            except Exception as e:
                print(f'Warning: failed cleanup of tmp frames: {e}')

    def save_final_results(self, network: nn.Module, save_dir: str = None, epoch: int = None, **kwargs):
        """
        Save final result figures similar to TensorBoard panels and a dedicated loss plot.

        Outputs:
        - final_panels.png: the same 2x2 panel figure used in TensorBoard
        - loss_curve.png: standalone loss curve using collected history

        Parameters
        ----------
        network : nn.Module
            Trained network to generate the final panel figure
        save_dir : str, optional
            Output directory; defaults to gif_dir
        epoch : int, optional
            Epoch number to show in the title; if None, uses last recorded or 0
        kwargs : dict
            Optional training info for annotations (e.g., total_loss)
        """
        out_dir = save_dir or self.gif_dir
        os.makedirs(out_dir, exist_ok=True)

        # Compute current prediction and properties
        network.eval()
        with torch.no_grad():
            T_reduced = network(numpy2torch(self.x, require_grad=False)).cpu().numpy()
        props = self._compute_material_properties(T_reduced.flatten())

        # Determine epoch label
        if epoch is None:
            epoch = self.history['epochs'][-1] if self.history['epochs'] else 0

        # Save panel figure
        fig = self._make_figure(epoch=epoch, T_reduced=T_reduced, props=props, **kwargs)
        panels_path = os.path.join(out_dir, 'final_panels.png')
        try:
            fig.savefig(panels_path, dpi=150, bbox_inches='tight')
            print(f'Final panels saved: {panels_path}')
        except Exception as e:
            print(f'Failed to save final panels: {e}')
        finally:
            # Release the figure even when saving failed
            plt.close(fig)

        # Save standalone loss curve
        if self.history['losses']:
            plt.figure(figsize=(7, 5))
            plt.semilogy(self.history['epochs'], self.history['losses'], 'purple', linewidth=2.0, marker='o', markersize=4)
            plt.xlabel('Epoch')
            plt.ylabel('Loss (log scale)')
            plt.title('Training Loss Curve')
            plt.grid(True, alpha=0.3, which='both')
            loss_path = os.path.join(out_dir, 'loss_curve.png')
            try:
                plt.savefig(loss_path, dpi=150, bbox_inches='tight')
                print(f'Loss curve saved: {loss_path}')
            except Exception as e:
                print(f'Failed to save loss curve: {e}')
            finally:
                plt.close()
        else:
            print('No loss history recorded; loss_curve.png not created.')
[docs]
class TraArc1DTempNet(nn.Module):
    """
    Wrapper network for the 1D transient arc temperature equation that
    hard-codes the outer boundary condition.

    The backbone output N(r, t) is transformed as

        T(r, t) = (r - R) · N(r, t) + Tb

    Because the factor (r - R) vanishes at r = R, the Dirichlet condition
    T(R, t) = Tb holds for every t by construction and needs no explicit
    boundary loss term. Compared with StaArc1DNet, the input carries a
    second column for time.

    Parameters (Constructor)
    ------------------------
    network : nn.Module
        Backbone neural network (e.g., FNN) that maps (r, t) → N(r, t)
        Input shape: [batch_size, 2] where [:, 0]=r and [:, 1]=t
        Output shape: [batch_size, 1]
    R : float, optional
        Normalized arc radius (default: 1.0)
    Tb : float, optional
        Normalized boundary temperature at r = R (default: 0.03)
    """

    def __init__(self, network, R=1.0, Tb=0.03):
        super().__init__()
        self.network = network
        self.R = R    # normalized arc radius
        self.Tb = Tb  # normalized boundary temperature at r = R

    def forward(self, x):
        """
        Forward pass with automatic boundary condition enforcement.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape [batch_size, 2] where:
            x[:, 0] : r (normalized radius)
            x[:, 1] : t (normalized time)

        Returns
        -------
        torch.Tensor
            Temperature T(r, t) in normalized units; equals Tb wherever
            the radius column equals R.
        """
        backbone_out = self.network(x)
        # Only the radius column enters the boundary factor.
        radius = x[:, 0:1]
        return backbone_out*(radius - self.R) + self.Tb
[docs]
class TraArc1DTempModel(PINN):
    """
    PINN model for solving 1D transient arc discharge energy equation without radial velocity.

    Solves the time-dependent energy balance with thermal conduction and
    radiation loss on the normalized coordinates x = r/R and t/t_red, with
    the temperature scaled by ``T_red``. The Joule heating term is set to
    zero in this model, so the current ``I`` is stored but does not enter
    the residual.

    Physical Constraints
    --------------------
    - T(R, t) = Tb (Dirichlet boundary condition, built into TraArc1DTempNet)
    - ∂T/∂r(0, t) = 0 (symmetry condition at centerline, loss term)
    - T(r, 0) = Tinit_func(r) (initial temperature distribution, loss term)

    Parameters (Constructor)
    ------------------------
    R : float
        Arc radius. Normalized to 1.0 in the network.
    I : float
        Arc current (stored only).
    Tb : float, optional
        Boundary temperature at r = R (default: 300.0).
    Tinit_func : callable
        Initial temperature profile. It is called with a NumPy array of
        physical radii (normalized radius times R) and its result is
        divided by T_red. Required.
    T_red : float, optional
        Temperature reduction factor for normalization (default: 1e4).
    t_red : float, optional
        Time reduction factor for normalization (default: 1e-3).
    backbone_net : nn.Module, optional
        Backbone network mapping (r, t) to the unscaled output. If None
        (default), a new FNN with architecture
        [2, 200, 200, 200, 200, 200, 200, 1] is created for this model.
    train_data_x_size : int, optional
        Number of collocation points along the radius (default: 200).
    train_data_t_size : int, optional
        Number of collocation points along time (default: 100).
    sample_mode : str, optional
        Sampling mode used for both the radial and the time axis
        (default: 'uniform').
    prop : ArcPropSpline, optional
        Material property object providing kappa, Cp, rho and nec. It is
        required by the PDE residual.

    Raises
    ------
    ValueError
        If ``Tinit_func`` is None.
    """

    def __init__(
        self,
        R,
        I,
        Tb=300.0,
        Tinit_func=None,
        T_red=1e4,
        t_red=1e-3,
        backbone_net=None,
        train_data_x_size=200,
        train_data_t_size=100,
        sample_mode='uniform',
        prop: ArcPropSpline = None,
    ):
        if Tinit_func is None:
            raise ValueError('Tinit_func (initial condition) must be provided for transient arc model.')

        # A network built in the signature would be created once, when the
        # class is defined, and then shared (weights included) by every
        # model that relies on the default. Build a fresh one per instance.
        if backbone_net is None:
            backbone_net = FNN(layers=[2, 200, 200, 200, 200, 200, 200, 1])

        self.R = R
        self.I = I
        self.T_red = T_red
        self.Tb = Tb
        self.Tinit_func = Tinit_func
        self.t_red = t_red
        self.train_data_x_size = train_data_x_size
        self.train_data_t_size = train_data_t_size
        self.sample_mode = sample_mode
        self.prop = prop
        # Normalized radius in [0, 1] and normalized time in [0, 1].
        self.geo = Geo1DTime([0.0, 1.0], ts=0.0, te=1.0)

        network = TraArc1DTempNet(backbone_net, R=1.0, Tb=Tb/T_red)
        super().__init__(network)
        self.set_loss_func(F.smooth_l1_loss)

    def _define_loss_terms(self):
        """
        Define physics-informed loss terms for the transient arc model without radial velocity.

        Registers three residuals:
        1. Energy PDE in the domain (transient temperature evolution without convection)
        2. Boundary condition at r=0 (symmetry: ∂T/∂r = 0)
        3. Initial condition at t=0 (prescribed temperature distribution)
        """

        def _pde_residual(network, x):
            """
            PDE residual for transient arc equation without radial velocity.

            Parameters
            ----------
            network : nn.Module
                Neural network model
            x : torch.Tensor
                Input tensor where x[:,0] is r and x[:,1] is t
            """
            T = network(x)

            # Material properties are evaluated at the physical temperature.
            T_phys = T.view(-1)*self.T_red
            kappa = self.prop.kappa(T_phys).view(-1, 1)
            Cp = self.prop.Cp(T_phys).view(-1, 1)
            rho = self.prop.rho(T_phys).view(-1, 1)
            nec = self.prop.nec(T_phys).view(-1, 1)

            joule = 0  # no Joule heating in this model
            radiation = 4*np.pi*nec
            net_energy = joule - radiation

            T_x = df_dX(T, x)
            T_r = T_x[:, 0:1]
            T_t = T_x[:, 1:2]

            r = x[:, 0:1]
            T_term = r*kappa*T_r
            T_xx = df_dX(T_term, x)
            T_rr = T_xx[:, 0:1]   # ∂/∂r ( r κ ∂T/∂r )

            # NOTE(review): the conduction term is divided by r, so a
            # collocation point at r = 0 would give a non-finite residual;
            # confirm that the domain sampler excludes the axis.
            func = T_t - (net_energy*(self.t_red/self.T_red) + T_rr/r*(self.t_red/self.R/self.R))/(rho*Cp)
            return func

        def _bc_residual(network, x):
            """
            Boundary condition residual at r=0 (symmetry condition).
            """
            T = network(x)
            T_x = df_dX(T, x)
            return T_x[:, 0:1]

        def _init_residual(network, x):
            """
            Initial condition residual at t=0.
            """
            T = network(x)
            return T - Ti

        # Sample domain collocation points
        xt_domain, xt_bc = self.geo.sample_all_domain(Nx=self.train_data_x_size,
                                                      Nt=self.train_data_t_size,
                                                      mode=[self.sample_mode, self.sample_mode])
        xb = xt_bc[0][0]
        xi = xt_bc[1]

        # Target of the initial condition in reduced units. Tinit_func works
        # on NumPy radii in physical units; the result is always converted
        # to a tensor so that `T - Ti` is a tensor operation whichever type
        # the sampler returned.
        if isinstance(xi, torch.Tensor):
            xi_np = xi.detach().cpu().numpy()
        else:  # np.ndarray
            xi_np = xi
        Ti = numpy2torch(self.Tinit_func(xi_np[:, 0:1]*self.R)/self.T_red)

        # Add equation terms with weights
        self.add_equation('Domain', _pde_residual, weight=1.0, data=xt_domain)
        self.add_equation('Left Boundary', _bc_residual, weight=10.0, data=xb)
        self.add_equation('Initial Condition', _init_residual, weight=10.0, data=xi)
[docs]
class TraArc1DTempVisCallback(VisualizationCallback):
    """
    Custom visualization callback for 1D transient arc PINN training (temperature only).

    The callback evaluates the network on a fixed radial grid at a set of
    normalized time points and uses the result to:

    1. Build a multi-panel figure (one temperature panel per time point plus a
       loss-curve panel) that is returned to the caller for logging
    2. Optionally record prediction and loss snapshots in ``self.history``
    3. Optionally save figure frames to disk and assemble them into a GIF
    4. Save final summary figures after training
    """

    def __init__(self, model: 'TraArc1DTempModel', log_freq: int = 50,
                 save_history: bool = True, history_freq: int = None,
                 x_eval: np.ndarray = None,
                 t_eval: list = None,
                 T_csv_file: list[str] = None,
                 gif_enabled: bool = False,
                 gif_dir: str = None,
                 gif_freq: int = None,
                 gif_duration_ms: int = 300,
                 gif_cleanup_tmp: bool = True):
        """
        Initialize the transient arc visualization callback.

        Parameters:
        -----------
        model : TraArc1DTempModel
            Transient arc model (temperature only). Its ``geo``, ``T_red``,
            ``t_red``, ``R``, ``I``, ``Tb``, ``Tinit_func`` and ``prop``
            attributes are read by this callback.
        log_freq : int, default=50
            Logging frequency (in epochs) forwarded to the base callback.
        save_history : bool, default=True
            Whether to store prediction snapshots in ``self.history``.
        history_freq : int, optional
            Frequency (in epochs) for saving history snapshots.
            If None, defaults to log_freq.
        x_eval : np.ndarray, optional
            Radial evaluation grid (normalized radius).
            If None, 201 points linearly spaced from 0 to 1, shape (201, 1).
        t_eval : list or np.ndarray, optional
            Normalized time points at which temperature panels are drawn.
            If None, defaults to [0.1, 0.5, 0.9].
        T_csv_file : list[str], optional
            Paths to CSV files with reference temperature data, one per entry
            of t_eval, loaded through ``get_Tfunc_from_file``.
            - An empty string '' or None skips the reference for that time point
            - A list shorter than t_eval is padded with "no reference" entries
            - If None, no reference data is used
            Example: ['ref_0.1ms.csv', 'ref_0.5ms.csv', 'ref_0.9ms.csv']
        gif_enabled : bool, default=False
            Whether to save figure frames (every gif_freq epochs) for later
            assembly into a GIF by ``save_gif``.
        gif_dir : str, optional
            Output directory for the GIF and final summary plots.
            If None, defaults to the current working directory.
        gif_freq : int, optional
            Frequency (in epochs) to save frames. If None, uses history_freq.
        gif_duration_ms : int, default=300
            Duration per frame in milliseconds for the final GIF.
        gif_cleanup_tmp : bool, default=True
            Whether to delete the temporary PNG frames after GIF creation.
        """
        super().__init__(name='CS-PINN_1D_Arc_Transient_noV', log_freq=log_freq)

        # Defaults are created per instance: default argument values are
        # evaluated only once, so arrays/lists in the signature would be
        # shared between every callback object.
        if x_eval is None:
            x_eval = np.linspace(0, 1, 201, dtype=REAL()).reshape(-1, 1)
        if t_eval is None:
            t_eval = [0.1, 0.5, 0.9]
        if T_csv_file is None:
            T_csv_file = []

        self.model = model
        self.x_eval = x_eval
        self.t_eval = t_eval
        self.xt_list = model.geo.sample_space_time_list(x_eval, t_eval, require_grad=False)

        # One reference function (or None) per time point. Empty entries never
        # reach the loader, and missing trailing entries mean "no reference",
        # so any number of time points can be used without reference files.
        ref_funcs = [get_Tfunc_from_file(csv_file) if csv_file else None
                     for csv_file in T_csv_file]
        ref_funcs.extend([None] * (len(t_eval) - len(ref_funcs)))
        self.T_ref_func_list = ref_funcs

        # Physical parameters
        self.T_red = model.T_red
        self.t_red = model.t_red
        self.R = model.R
        self.I = model.I
        self.Tb = model.Tb
        self.Tinit_func = model.Tinit_func

        # Training history tracking
        self.save_history = save_history
        self.history_freq = history_freq if history_freq is not None else log_freq
        self.history = {
            'epochs': [],              # Epoch numbers of the stored snapshots
            'r_eval': self.x_eval,     # Radial evaluation grid (fixed)
            't_eval': self.t_eval,     # Time evaluation points (fixed)
            'T': [],                   # Stacked network predictions, one array per snapshot
            'losses': [],              # Total loss values (only when provided)
            'T_center_t': [],          # Prediction at the first radial point for every time
        }

        # GIF configuration/state
        self.gif_enabled = gif_enabled
        self.gif_dir = os.getcwd() if gif_dir is None else gif_dir
        self.gif_tmp_dir = os.path.join(self.gif_dir, 'tmp_frames')
        self.gif_filename = 'training_animation.gif'
        self.gif_freq = gif_freq if gif_freq is not None else self.history_freq
        self.gif_duration_ms = gif_duration_ms
        self.gif_cleanup_tmp = gif_cleanup_tmp
        self._gif_frames = []
        if self.gif_enabled:
            os.makedirs(self.gif_dir, exist_ok=True)
            os.makedirs(self.gif_tmp_dir, exist_ok=True)

    @staticmethod
    def _loss_to_float(loss) -> float:
        """Convert a loss given as a tensor or a plain number to a Python float."""
        if isinstance(loss, torch.Tensor):
            return loss.item()
        return float(loss)

    def _predict_profiles(self, network) -> np.ndarray:
        """
        Evaluate the network on the radial grid at every time in ``t_eval``.

        The network is switched to eval mode for the prediction and put back
        into training mode afterwards if it was training before, so calling
        the callback does not silently change the training state.

        Returns:
        --------
        np.ndarray
            Stacked raw network outputs (reduced units), one entry per time point.
        """
        was_training = getattr(network, 'training', False)
        network.eval()
        r_col = self.x_eval.reshape(-1, 1)
        profiles = []
        try:
            with torch.no_grad():
                for t_val in self.t_eval:
                    # Pair every radial point with the current time value
                    t_col = np.repeat(t_val, len(r_col)).astype(REAL()).reshape(-1, 1)
                    xt_grid = np.hstack([r_col, t_col])
                    prediction = network(numpy2torch(xt_grid, require_grad=False))
                    profiles.append(prediction.cpu().numpy())
        finally:
            if was_training:
                network.train()
        return np.array(profiles)

    def _compute_material_properties_at_t(self, T_physical: np.ndarray, t_reduced: float) -> dict:
        """
        Compute material properties at a given time step and temperature profile.

        Parameters:
        -----------
        T_physical : np.ndarray
            Temperature in physical units (K), shape (n_r,)
        t_reduced : float
            Normalized time value; multiplied by ``t_red`` in the output

        Returns:
        --------
        dict
            Keys 'T_physical', 'kappa', 'sigma', 'nec', 'max_T', 'T_center'
            and 't_physical'. The three property entries are None when the
            model has no property object.
        """
        props = {'T_physical': T_physical, 'kappa': None, 'sigma': None, 'nec': None}
        arc_prop = self.model.prop
        if arc_prop is not None:
            with torch.no_grad():
                T_tensor = numpy2torch(T_physical, require_grad=False)
                props['kappa'] = arc_prop.kappa(T_tensor).cpu().numpy()
                props['sigma'] = arc_prop.sigma(T_tensor).cpu().numpy()
                props['nec'] = arc_prop.nec(T_tensor).cpu().numpy()
        props['max_T'] = T_physical.max()
        props['T_center'] = T_physical[0]
        props['t_physical'] = t_reduced * self.t_red
        return props

    def _make_figure(self, epoch: int, T_reduced_all: np.ndarray, **kwargs) -> plt.Figure:
        """
        Build multi-panel figure showing temperature evolution over time steps.

        Parameters:
        -----------
        epoch : int
            Current training epoch
        T_reduced_all : np.ndarray
            Temperature predictions in reduced units, indexed by time point
            along the first axis
        kwargs : dict
            Optional training info (e.g., total_loss)

        Returns:
        --------
        plt.Figure
            Matplotlib figure with temperature panels in grid layout and
            loss curve panel spanning the bottom row
        """
        n_time = len(self.t_eval)
        # Temperature panels in a grid (max 3 columns, at least 1 so an empty
        # t_eval does not divide by zero); loss panel spans the bottom row.
        n_cols = max(1, min(3, n_time))
        n_rows_temp = (n_time + n_cols - 1) // n_cols
        n_rows_total = n_rows_temp + 1
        fig = plt.figure(figsize=(5*n_cols, 4*n_rows_temp + 3))
        gs = fig.add_gridspec(n_rows_total, n_cols, hspace=0.4, wspace=0.3)

        for i, t_val in enumerate(self.t_eval):
            ax = fig.add_subplot(gs[i // n_cols, i % n_cols])
            T_phys = T_reduced_all[i] * self.T_red
            t_physical = t_val * self.t_red
            ax.plot(self.x_eval, T_phys, 'b-', linewidth=2.5, label='CS-PINN Prediction')

            T_ref_func = self.T_ref_func_list[i] if i < len(self.T_ref_func_list) else None
            if T_ref_func is not None:
                # Give the reference the same shape as the prediction so the
                # error metrics compare point by point instead of broadcasting.
                T_ref = np.asarray(T_ref_func(self.x_eval*self.R)).reshape(np.shape(T_phys))
                ax.plot(self.x_eval, T_ref, 'r--', linewidth=2.0, label='Reference', alpha=0.8)
                error = np.abs(T_phys.flatten() - T_ref.flatten())
                relative_error = error / (T_ref.flatten() + 1e-10) * 100
                max_error = error.max()
                mean_error = error.mean()
                max_rel_error = relative_error.max()
                rel_l2_error = calc_relative_l2_err(T_ref, T_phys)
                info_text = (f'Max Error: {max_error:.1f} K\n'
                             f'Mean Error: {mean_error:.1f} K\n'
                             f'Max Rel Error: {max_rel_error:.2f}%\n'
                             f'Rel L2 Error: {rel_l2_error:.5g}')
            else:
                # The helper expects a 1-D profile; with a column vector its
                # 'T_center' would be a 1-element array, which cannot be
                # formatted with ':.0f'.
                props = self._compute_material_properties_at_t(np.ravel(T_phys), t_val)
                info_text = (f'Epoch {epoch}\n'
                             f'Max T: {props["max_T"]:.0f} K\n'
                             f'Center T: {props["T_center"]:.0f} K')

            ax.set_xlabel('Normalized radius r/R', fontsize=10)
            ax.set_ylabel('Temperature (K)', fontsize=10)
            ax.set_title(f'Time t = {t_physical*1e3:.2f} ms (epoch={epoch})', fontsize=11, fontweight='bold')
            ax.grid(True, alpha=0.3)
            ax.legend(loc='best', fontsize=10)
            ax.text(0.05, 0.05, info_text,
                    transform=ax.transAxes, fontsize=9, verticalalignment='bottom',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        # Loss curve panel spanning the entire bottom row
        ax_loss = fig.add_subplot(gs[n_rows_temp, :])
        if self.history['losses']:
            ax_loss.semilogy(self.history['epochs'], self.history['losses'],
                             'purple', linewidth=2.5, marker='o', markersize=4)
            ax_loss.set_xlabel('Epoch', fontsize=11)
            ax_loss.set_ylabel('Loss (log scale)', fontsize=11)
            ax_loss.set_title('Training Loss Convergence', fontsize=12, fontweight='bold')
            ax_loss.grid(True, alpha=0.3, which='both')
            current_loss = kwargs.get('total_loss', None)
            if current_loss is not None:
                loss_val = self._loss_to_float(current_loss)
                ax_loss.text(0.98, 0.95, f'Current loss: {loss_val:.2e}',
                             transform=ax_loss.transAxes, fontsize=10, verticalalignment='top',
                             horizontalalignment='right',
                             bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.7))
        else:
            ax_loss.text(0.5, 0.5, 'No loss history yet',
                         ha='center', va='center', transform=ax_loss.transAxes,
                         fontsize=12, style='italic', color='gray')
            ax_loss.set_title('Training Loss', fontsize=12, fontweight='bold')

        fig.suptitle(f'Arc 1D Transient PINN Model - Epoch {epoch}',
                     fontsize=14, fontweight='bold', y=0.995)
        return fig

    def visualize(self, network, epoch: int, writer: SummaryWriter, **kwargs):
        """
        Generate visualization plots for the current training epoch.

        Parameters:
        -----------
        network : nn.Module
            The neural network being trained
        epoch : int
            Current training epoch number
        writer : SummaryWriter
            TensorBoard writer (not used directly here; the figure is returned)
        kwargs : dict
            Additional training information (e.g., 'total_loss')

        Returns:
        --------
        dict
            Dictionary mapping plot names to matplotlib figures
        """
        T_reduced_all = self._predict_profiles(network)

        # Save history for animation
        if self.save_history and epoch % self.history_freq == 0:
            self.history['epochs'].append(epoch)
            self.history['T'].append(T_reduced_all.copy())
            # Prediction at the first radial grid point for each time
            self.history['T_center_t'].append(T_reduced_all[:, 0])
            total_loss = kwargs.get('total_loss', None)
            if total_loss is not None:
                self.history['losses'].append(self._loss_to_float(total_loss))

        fig = self._make_figure(epoch=epoch, T_reduced_all=T_reduced_all, **kwargs)

        # Optionally save frame for GIF (best effort: a failed frame must not stop training)
        if self.gif_enabled and (epoch % self.gif_freq == 0):
            frame_path = os.path.join(self.gif_tmp_dir, f'epoch_{epoch:06d}.png')
            try:
                fig.savefig(frame_path, dpi=120)
                self._gif_frames.append(frame_path)
            except Exception as e:
                print(f'Warning: failed to save GIF frame at epoch {epoch}: {e}')
        return {'transient_visualization': fig}

    def save_gif(self, gif_path: str = None, duration_ms: int = None, loop: int = 0):
        """
        Assemble saved frames into a GIF showing transient temperature evolution and loss.

        Parameters:
        -----------
        gif_path : str, optional
            Output path for the GIF; defaults to <gif_dir>/training_animation.gif
        duration_ms : int, optional
            Per-frame duration in milliseconds; defaults to gif_duration_ms
        loop : int, default=0
            Loop count forwarded to ``img2gif``
        """
        if not self.gif_enabled:
            print('GIF not enabled; nothing to save.')
            return
        if not self._gif_frames:
            print('No frames collected; GIF not created.')
            return

        out_path = gif_path if gif_path is not None else os.path.join(self.gif_dir, self.gif_filename)
        # A bare file name has an empty dirname, and os.makedirs('') raises.
        out_parent = os.path.dirname(out_path)
        if out_parent:
            os.makedirs(out_parent, exist_ok=True)

        frames_sorted = sorted(self._gif_frames)
        try:
            img2gif(frames_sorted, out_path, duration=(duration_ms or self.gif_duration_ms), loop=loop)
            print(f'GIF saved: {out_path}')
        except Exception as e:
            print(f'Failed to create GIF: {e}')
            return

        if self.gif_cleanup_tmp:
            try:
                for frame in frames_sorted:
                    if os.path.exists(frame):
                        os.remove(frame)
                # Remove the frame directory only when nothing else is left in it
                if os.path.isdir(self.gif_tmp_dir) and len(os.listdir(self.gif_tmp_dir)) == 0:
                    os.rmdir(self.gif_tmp_dir)
            except Exception as e:
                print(f'Warning: failed cleanup of tmp frames: {e}')

    def save_final_results(self, network: nn.Module, save_dir: str = None, epoch: int = None, **kwargs):
        """
        Save final result figures showing temperature evolution and loss curve.

        Outputs:
        - final_panels.png: multi-panel figure with temperature at different times and loss curve
        - loss_curve.png: standalone loss curve plot (only when loss history exists)

        Parameters:
        -----------
        network : nn.Module
            Trained network for final prediction
        save_dir : str, optional
            Output directory; defaults to gif_dir
        epoch : int, optional
            Epoch number for title; if None, uses last recorded or 0
        kwargs : dict
            Optional training info
        """
        out_dir = save_dir or self.gif_dir
        os.makedirs(out_dir, exist_ok=True)

        T_reduced_all = self._predict_profiles(network)

        if epoch is None:
            epoch = self.history['epochs'][-1] if self.history['epochs'] else 0

        # Multi-panel figure (temperature profiles + loss curve)
        fig = self._make_figure(epoch=epoch, T_reduced_all=T_reduced_all, **kwargs)
        panels_path = os.path.join(out_dir, 'final_panels.png')
        try:
            fig.savefig(panels_path, dpi=150, bbox_inches='tight')
            print(f'Final panels saved: {panels_path}')
        except Exception as e:
            print(f'Failed to save final panels: {e}')
        finally:
            # Close even when saving failed so the figure is not leaked
            plt.close(fig)

        # Standalone loss curve
        if self.history['losses']:
            loss_fig = plt.figure(figsize=(7, 5))
            plt.semilogy(self.history['epochs'], self.history['losses'], 'purple',
                         linewidth=2.0, marker='o', markersize=4)
            plt.xlabel('Epoch', fontsize=11)
            plt.ylabel('Loss (log scale)', fontsize=11)
            plt.title('Training Loss Curve (Transient Arc Model)', fontsize=12, fontweight='bold')
            plt.grid(True, alpha=0.3, which='both')
            loss_path = os.path.join(out_dir, 'loss_curve.png')
            try:
                plt.savefig(loss_path, dpi=150, bbox_inches='tight')
                print(f'Loss curve saved: {loss_path}')
            except Exception as e:
                print(f'Failed to save loss curve: {e}')
            finally:
                plt.close(loss_fig)
        else:
            print('No loss history recorded; loss_curve.png not created.')
[docs]
class TraArc1DVelNet(nn.Module):
    """
    Backbone wrapper that builds the centerline symmetry condition of the
    radial velocity into the network output for the 1D transient arc.

    The backbone output N(r, t) is scaled by the radial coordinate:

        V(r, t) = r · N(r, t)

    so V(0, t) = 0 holds for every t by construction, without needing a
    dedicated loss term for it.

    Parameters (Constructor)
    ------------------------
    network : nn.Module
        Backbone neural network (e.g., FNN) that maps (r, t) → N(r, t).
        It is called on inputs whose column 0 is r and column 1 is t, and
        its output is multiplied element-wise by the r column.
    """

    def __init__(self, network):
        super().__init__()
        self.network = network

    def forward(self, x):
        """
        Evaluate the velocity with the symmetry condition enforced.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor whose column 0 is r (normalized radius) and
            column 1 is t (normalized time)

        Returns
        -------
        torch.Tensor
            Backbone output multiplied by r, i.e. V(r, t); exactly zero
            wherever r = 0
        """
        radius = x[:, 0:1]  # keep the column dimension so the product broadcasts row-wise
        return radius * self.network(x)
[docs]
class TraArc1DNet(nn.Module):
    """
    Backbone wrapper for the coupled 1D transient arc problem that predicts
    temperature and radial velocity with two conditions built in.

    The two backbone outputs N₁(r, t) and N₂(r, t) are transformed as

        T(r, t) = (r - R) · N₁(r, t) + Tb
        V(r, t) = r · N₂(r, t)

    so that, by construction,

    1. T(R, t) = Tb (temperature at the arc boundary)
    2. V(0, t) = 0  (velocity symmetry at the centerline)

    Parameters (Constructor)
    ------------------------
    network : nn.Module
        Backbone neural network (e.g., FNN) mapping (r, t) → [N₁, N₂];
        column 0 of its output feeds T, column 1 feeds V
    R : float, optional
        Normalized arc radius (default: 1.0)
    Tb : float, optional
        Normalized boundary temperature at r = R (default: 0.03)
    """

    def __init__(self, network, R=1.0, Tb=0.03):
        super().__init__()
        self.network = network
        self.R = R    # reduced arc radius
        self.Tb = Tb  # reduced boundary temperature at r = R

    def forward(self, x):
        """
        Evaluate T and V with their built-in conditions.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor whose column 0 is r (normalized radius) and
            column 1 is t (normalized time)

        Returns
        -------
        T : torch.Tensor
            Temperature in normalized units; equals Tb wherever r = R
        V : torch.Tensor
            Radial velocity; equals zero wherever r = 0
        """
        radius = x[:, 0:1]
        backbone_out = self.network(x)
        temperature = backbone_out[:, 0:1] * (radius - self.R) + self.Tb
        velocity = radius * backbone_out[:, 1:2]
        return temperature, velocity
[docs]
def get_TVfunc_from_file(csv_file):
    """
    Build interpolants of reference temperature and velocity profiles from a CSV file.

    Parameters
    ----------
    csv_file : str
        Path to a CSV file with the columns 'r(m)', 'T(K)' and 'V(m/s)'.

    Returns
    -------
    T_spline : scipy.interpolate.CubicSpline
        Cubic spline of the 'T(K)' column over the 'r(m)' column
    V_spline : scipy.interpolate.CubicSpline
        Cubic spline of the 'V(m/s)' column over the 'r(m)' column

    Both splines are created with extrapolation enabled.

    Raises
    ------
    FileNotFoundError
        If csv_file does not exist.
    """
    if not os.path.exists(csv_file):
        raise FileNotFoundError(f"CSV file not found: {csv_file}")

    table = pd.read_csv(csv_file)
    # Read all three columns (cast to the project's real dtype) before any
    # spline is built.
    radius, velocity, temperature = (
        table[column].values.astype(REAL())
        for column in ('r(m)', 'V(m/s)', 'T(K)')
    )

    velocity_spline = intp.CubicSpline(radius, velocity, extrapolate=True)
    temperature_spline = intp.CubicSpline(radius, temperature, extrapolate=True)
    return temperature_spline, velocity_spline
[docs]
class TraArc1DModel(PINN):
    """
    PINN model for solving coupled 1D transient arc plasma equations with temperature
    and radial velocity.

    The model trains a single network with two outputs (T, V) on a normalized
    space-time domain r ∈ [0, 1], t ∈ [0, 1]. The loss combines an energy
    equation residual (conduction, convection and radiation; the Joule term
    is set to zero), a continuity equation residual, a symmetry condition for
    the temperature at r = 0 and the initial temperature distribution.

    Physical Constraints
    --------------------
    - T(R, t) = Tb (enforced by the TraArc1DNet wrapper)
    - ∂T/∂r(0, t) = 0 (loss term 'Left Boundary')
    - V(0, t) = 0 (enforced by the TraArc1DNet wrapper)
    - T(r, 0) = Tinit_func(r) (loss term 'Initial Condition')
    - Mass conservation through continuity equation (loss term 'Domain_V')

    Parameters (Constructor)
    ------------------------
    R : float
        Arc radius [m]
    I : float
        Arc current [A]
    Tb : float, optional
        Boundary temperature at r = R [K] (default: 300.0)
    Tinit_func : callable
        Function that returns the initial temperature profile T(r) for radii
        given in meters. Required; a ValueError is raised when it is None.
    T_red : float, optional
        Temperature reduction factor for normalization [K] (default: 1e4)
    t_red : float, optional
        Time reduction factor for normalization [s] (default: 1e-3)
    backbone_net : nn.Module, optional
        Backbone neural network with 2 inputs (r, t) and 2 outputs (T, V).
        If None, a new FNN with layers [2, 300, 300, 300, 300, 300, 300, 2]
        is created for this model.
    train_data_x_size : int, optional
        Number of spatial training collocation points (default: 200)
    train_data_t_size : int, optional
        Number of temporal training collocation points (default: 100)
    sample_mode : str, optional
        Sampling strategy passed to the geometry sampler for both the
        spatial and the temporal direction (default: 'uniform')
    prop : ArcPropSpline, optional
        Arc material properties object (default: None). It is needed as soon
        as the PDE residuals are evaluated.
    """

    def __init__(
        self,
        R,
        I,
        Tb=300.0,
        Tinit_func=None,
        T_red=1e4,
        t_red=1e-3,
        backbone_net=None,
        train_data_x_size=200,
        train_data_t_size=100,
        sample_mode='uniform',
        prop: ArcPropSpline = None,
    ):
        if Tinit_func is None:
            raise ValueError('Tinit_func (initial condition) must be provided for transient arc model.')
        # Build the default backbone per instance. A network placed in the
        # signature is created once at definition time and would be shared
        # (weights included) by every model that relies on the default.
        if backbone_net is None:
            backbone_net = FNN(layers=[2, 300, 300, 300, 300, 300, 300, 2])

        self.R = R
        self.I = I
        self.T_red = T_red
        self.Tb = Tb
        self.Tinit_func = Tinit_func
        self.t_red = t_red
        self.train_data_x_size = train_data_x_size
        self.train_data_t_size = train_data_t_size
        self.sample_mode = sample_mode
        self.prop = prop
        self.geo = Geo1DTime([0.0, 1.0], ts=0.0, te=1.0)

        # The wrapper works in normalized units: radius 1 and Tb scaled by T_red
        network = TraArc1DNet(backbone_net, R=1.0, Tb=Tb/T_red)
        super().__init__(network)
        self.set_loss_func(F.smooth_l1_loss)

    def _define_loss_terms(self):
        """
        Define physics-informed loss terms for the coupled transient arc model.

        Registers four equations:
        1. 'Domain_T': energy PDE residual in the domain
        2. 'Domain_V': continuity PDE residual in the domain
        3. 'Left Boundary': symmetry condition ∂T/∂r = 0 at r = 0
        4. 'Initial Condition': prescribed temperature distribution at t = 0
        """
        def _pde_T_residual(network, x):
            """
            PDE residual for transient arc equation (Temperature equation).

            Parameters
            ----------
            network : nn.Module
                Neural network model
            x : torch.Tensor
                Input tensor where x[:,0] is r and x[:,1] is t
            """
            T, V = network(x)
            # Properties are evaluated on the flattened temperature in Kelvin
            T_phys = T.view(-1)*self.T_red
            kappa = self.prop.kappa(T_phys).view(-1, 1)
            Cp = self.prop.Cp(T_phys).view(-1, 1)
            rho = self.prop.rho(T_phys).view(-1, 1)
            nec = self.prop.nec(T_phys).view(-1, 1)

            joule = 0  # no Joule heating term in this model
            radiation = 4*np.pi*nec
            net_energy = joule - radiation

            T_x = df_dX(T, x)
            T_r = T_x[:, 0:1]
            T_t = T_x[:, 1:2]
            r = x[:, 0:1]
            # Radial derivative of the conductive flux term r·kappa·dT/dr
            T_term = r*kappa*T_r
            T_rr = df_dX(T_term, x)[:, 0:1]

            # Scale factors come from r = R·r̃, t = t_red·t̃ and T = T_red·T̃
            func = (T_t + V*T_r*(self.t_red/self.R)
                    - (net_energy*(self.t_red/self.T_red)
                       + T_rr/r*(self.t_red/self.R/self.R))/(rho*Cp))
            return func

        def _pde_V_residual(network, x):
            """
            PDE residual for transient arc equation (Velocity equation).

            Parameters
            ----------
            network : nn.Module
                Neural network model
            x : torch.Tensor
                Input tensor where x[:,0] is r and x[:,1] is t
            """
            T, V = network(x)
            rho = self.prop.rho(T.view(-1)*self.T_red).view(-1, 1)
            r = x[:, 0:1]
            rho_t = df_dX(rho, x)[:, 1:2]
            # Radial derivative of the mass flux term r·rho·V
            V_r = df_dX(r*rho*V, x)[:, 0:1]
            func = rho_t + V_r/r*(self.t_red/self.R)
            return func

        def _bc_T_residual(network, x):
            """
            Boundary condition residual at r=0 (symmetry condition).
            """
            T, _ = network(x)
            return df_dX(T, x)[:, 0:1]

        def _init_T_residual(network, x):
            """
            Initial condition residual at t=0.
            """
            T, _ = network(x)
            return T - Ti

        # Sample domain collocation points
        xt_domain, xt_bc = self.geo.sample_all_domain(Nx=self.train_data_x_size,
                                                      Nt=self.train_data_t_size,
                                                      mode=[self.sample_mode, self.sample_mode])
        xb = xt_bc[0][0]
        xi = xt_bc[1]

        # Initial temperature at the sampled points, in reduced units. It is
        # always converted to a tensor so that `T - Ti` in the residual never
        # mixes a torch tensor with a NumPy array.
        xi_np = xi.detach().cpu().numpy() if isinstance(xi, torch.Tensor) else xi
        Ti = numpy2torch(self.Tinit_func(xi_np[:, 0:1]*self.R)/self.T_red)

        # Add equation terms with weights
        self.add_equation('Domain_T', _pde_T_residual, weight=1.0, data=xt_domain)
        self.add_equation('Domain_V', _pde_V_residual, weight=1.0, data=xt_domain)
        self.add_equation('Left Boundary', _bc_T_residual, weight=10.0, data=xb)
        self.add_equation('Initial Condition', _init_T_residual, weight=10.0, data=xi)
[docs]
class TraArc1DVisCallback(VisualizationCallback):
"""
Custom visualization callback for 1D transient arc PINN training with temperature and radial velocity.
This callback provides comprehensive monitoring of the coupled transient arc model training:
1. Real-time TensorBoard logging during training
2. Side-by-side comparison of temperature and velocity at multiple time steps
3. Training history tracking for post-training animation
4. Reference data comparison with error metrics (when CSV files provided)
5. Material property evolution monitoring over time
6. Loss convergence tracking with logarithmic scale visualization
"""
def __init__(self, model: 'TraArc1DModel', log_freq: int = 50,
             save_history: bool = True, history_freq: int = None,
             x_eval: np.ndarray = None,
             t_eval: list = None,
             TV_csv_file: list[str] = None,
             gif_enabled: bool = False,
             gif_dir: str = None,
             gif_freq: int = None,
             gif_duration_ms: int = 300,
             gif_cleanup_tmp: bool = True):
    """
    Initialize the visualization callback for the coupled (T, V) transient arc model.

    Parameters:
    -----------
    model : TraArc1DModel
        Coupled transient arc model. Its ``geo``, ``T_red``, ``t_red``, ``R``,
        ``I``, ``Tb`` and ``Tinit_func`` attributes are read here.
    log_freq : int, default=50
        Logging frequency (in epochs) forwarded to the base callback.
    save_history : bool, default=True
        Whether to store prediction snapshots in ``self.history``.
    history_freq : int, optional
        Frequency (in epochs) for saving history snapshots.
        If None, defaults to log_freq.
    x_eval : np.ndarray, optional
        Radial evaluation grid (normalized radius).
        If None, 201 points linearly spaced from 0 to 1, shape (201, 1).
    t_eval : list or np.ndarray, optional
        Normalized time points at which T and V snapshots are displayed.
        If None, defaults to [0.1, 0.5, 0.9].
    TV_csv_file : list[str], optional
        Paths to CSV files with reference temperature and velocity data,
        one per entry of t_eval, loaded through ``get_TVfunc_from_file``
        (columns 'r(m)', 'T(K)', 'V(m/s)').
        - An empty string '' or None skips the reference for that time point
        - A list shorter than t_eval is padded with "no reference" entries
        - If None, no reference data is used
        Example: ['ref_0.1ms.csv', 'ref_0.5ms.csv', 'ref_0.9ms.csv']
    gif_enabled : bool, default=False
        Whether frames should be saved for a training animation GIF.
        When True, the output and temporary frame directories are created.
    gif_dir : str, optional
        Output directory for the GIF and final summary plots.
        If None, defaults to the current working directory.
    gif_freq : int, optional
        Frequency (in epochs) to save frames. If None, uses history_freq.
    gif_duration_ms : int, default=300
        Duration per frame in milliseconds for the final GIF.
    gif_cleanup_tmp : bool, default=True
        Whether to delete the temporary PNG frames after GIF creation.
    """
    super().__init__(name='CS-PINN_1D_Arc_Transient', log_freq=log_freq)

    # Defaults are created per instance: default argument values are
    # evaluated only once, so arrays/lists in the signature would be shared
    # between every callback object.
    if x_eval is None:
        x_eval = np.linspace(0, 1, 201, dtype=REAL()).reshape(-1, 1)
    if t_eval is None:
        t_eval = [0.1, 0.5, 0.9]
    if TV_csv_file is None:
        TV_csv_file = []

    self.model = model
    self.x_eval = x_eval
    self.t_eval = t_eval
    self.xt_list = model.geo.sample_space_time_list(x_eval, t_eval, require_grad=False)

    # One (T, V) reference pair per time point. Empty entries never reach the
    # loader, and missing trailing entries mean "no reference", so the lists
    # can always be indexed by time-point index.
    tv_funcs = [get_TVfunc_from_file(csv_file) if csv_file else (None, None)
                for csv_file in TV_csv_file]
    tv_funcs.extend([(None, None)] * (len(t_eval) - len(tv_funcs)))
    self.T_ref_func_list = [T_func for T_func, _ in tv_funcs]
    self.V_ref_func_list = [V_func for _, V_func in tv_funcs]

    # Physical parameters
    self.T_red = model.T_red
    self.t_red = model.t_red
    self.R = model.R
    self.I = model.I
    self.Tb = model.Tb
    self.Tinit_func = model.Tinit_func

    # Training history tracking
    self.save_history = save_history
    self.history_freq = history_freq if history_freq is not None else log_freq
    self.history = {
        'epochs': [],              # Epoch numbers of the stored snapshots
        'r_eval': self.x_eval,     # Radial evaluation grid (fixed)
        't_eval': self.t_eval,     # Time evaluation points (fixed)
        'T': [],                   # Temperature predictions, one entry per snapshot
        'V': [],                   # Velocity predictions, one entry per snapshot
        'losses': [],              # Total loss values
        'T_center_t': [],          # Center temperature at all time points
    }

    # GIF configuration/state
    self.gif_enabled = gif_enabled
    self.gif_dir = os.getcwd() if gif_dir is None else gif_dir
    self.gif_tmp_dir = os.path.join(self.gif_dir, 'tmp_frames')
    self.gif_filename = 'training_animation.gif'
    self.gif_freq = gif_freq if gif_freq is not None else self.history_freq
    self.gif_duration_ms = gif_duration_ms
    self.gif_cleanup_tmp = gif_cleanup_tmp
    self._gif_frames = []
    if self.gif_enabled:
        os.makedirs(self.gif_dir, exist_ok=True)
        os.makedirs(self.gif_tmp_dir, exist_ok=True)
def _compute_material_properties_at_t(self, T_physical: np.ndarray, t_reduced: float) -> dict:
    """
    Collect material properties and simple statistics for one temperature profile.

    Parameters:
    -----------
    T_physical : np.ndarray
        Temperature in physical units (K), shape (n_r,)
    t_reduced : float
        Normalized time value; multiplied by ``t_red`` in the output

    Returns:
    --------
    dict
        Keys 'T_physical', 'kappa', 'sigma', 'nec', 'max_T', 'T_center' and
        't_physical'. 'kappa', 'sigma' and 'nec' are None when the model has
        no property object; otherwise they are NumPy arrays evaluated from it.
    """
    transport = {'kappa': None, 'sigma': None, 'nec': None}
    arc_prop = self.model.prop
    if arc_prop is not None:
        # Properties are only evaluated, never differentiated, here
        with torch.no_grad():
            T_tensor = numpy2torch(T_physical, require_grad=False)
            for prop_name in transport:
                transport[prop_name] = getattr(arc_prop, prop_name)(T_tensor).cpu().numpy()

    return {
        'T_physical': T_physical,
        **transport,
        'max_T': T_physical.max(),
        'T_center': T_physical[0],  # first grid point
        't_physical': t_reduced * self.t_red,
    }
def _make_figure(self, epoch: int, T_reduced_all: np.ndarray, V_reduced_all: np.ndarray, **kwargs) -> plt.Figure:
    """
    Render one row of temperature/velocity panels per evaluation time plus a loss panel.

    For every entry of ``self.t_eval`` the temperature prediction (multiplied
    by ``self.T_red``) is drawn in the left column and the velocity
    prediction (plotted as given, no scaling) in the right column. When a
    reference function is registered for that time step it is overlaid and
    error statistics are shown in the info box; otherwise the box shows
    simple extrema. The bottom row spans both columns and shows the recorded
    loss history.

    Parameters:
    -----------
    epoch : int
        Epoch number shown in the figure title
    T_reduced_all : np.ndarray
        Temperature predictions, indexed by time step along the first axis
    V_reduced_all : np.ndarray
        Velocity predictions, indexed by time step along the first axis
    kwargs : dict
        Optional training info; only 'total_loss' is read here

    Returns:
    --------
    plt.Figure
        The assembled figure (not closed; the caller owns it)
    """
    num_times = len(self.t_eval)

    # Grid: num_times rows of (T | V) panels and one shorter row for the loss curve.
    fig = plt.figure(figsize=(14, 4*num_times + 3))
    grid = fig.add_gridspec(num_times + 1, 2, hspace=0.35, wspace=0.25,
                            height_ratios=[1]*num_times + [0.8])

    for row, t_val in enumerate(self.t_eval):
        T_phys = T_reduced_all[row] * self.T_red
        V_phys = V_reduced_all[row]
        t_physical = t_val * self.t_red
        props = self._compute_material_properties_at_t(T_phys, t_val)

        # ---- Left column: temperature profile ----
        ax_T = fig.add_subplot(grid[row, 0])
        ax_T.plot(self.x_eval, T_phys, 'b-', linewidth=2.5, label='CS-PINN')

        T_ref_func = self.T_ref_func_list[row]
        if T_ref_func is None:
            info_text_T = (f'Max T: {props["max_T"]:.0f} K\n'
                           f'Center: {props["T_center"]:.0f} K')
        else:
            # Reference functions are evaluated at x_eval scaled by self.R.
            T_ref = T_ref_func(self.x_eval*self.R)
            ax_T.plot(self.x_eval, T_ref, 'r--', linewidth=2.0, label='Reference', alpha=0.8)
            abs_err_T = np.abs(T_phys.flatten() - T_ref.flatten())
            # Small offset keeps the division finite where the reference is zero.
            rel_err_T = abs_err_T / (T_ref.flatten() + 1e-10) * 100
            max_error = abs_err_T.max()
            mean_error = abs_err_T.mean()
            max_rel_error = rel_err_T.max()
            rel_l2_error = calc_relative_l2_err(T_ref, T_phys)
            info_text_T = (f'Max Error: {max_error:.1f} K\n'
                           f'Mean Error: {mean_error:.1f} K\n'
                           f'Max Rel Error: {max_rel_error:.2f}%\n'
                           f'Rel L2 Error: {rel_l2_error:.5g}')

        ax_T.set_xlabel('Normalized radius r/R', fontsize=11)
        ax_T.set_ylabel('Temperature (K)', fontsize=11)
        ax_T.set_title(f'Temperature @ t = {t_physical*1e3:.2f} ms', fontsize=12, fontweight='bold')
        ax_T.grid(True, alpha=0.3)
        ax_T.legend(loc='best', fontsize=10)
        ax_T.text(0.05, 0.05, info_text_T,
                  transform=ax_T.transAxes, fontsize=9, verticalalignment='bottom',
                  bbox={'boxstyle': 'round', 'facecolor': 'wheat', 'alpha': 0.6})

        # ---- Right column: velocity profile ----
        ax_V = fig.add_subplot(grid[row, 1])
        ax_V.plot(self.x_eval, V_phys, 'g-', linewidth=2.5, label='CS-PINN')

        V_ref_func = self.V_ref_func_list[row]
        if V_ref_func is None:
            info_text_V = f'Max: {V_phys.max():.3g} m/s'
        else:
            V_ref = V_ref_func(self.x_eval*self.R)
            ax_V.plot(self.x_eval, V_ref, 'r--', linewidth=2.0, label='Reference', alpha=0.8)
            abs_err_V = np.abs(V_phys.flatten() - V_ref.flatten())
            # Velocity can change sign, so the relative error uses |V_ref|.
            rel_err_V = abs_err_V / (np.abs(V_ref.flatten()) + 1e-10) * 100
            max_error = abs_err_V.max()
            mean_error = abs_err_V.mean()
            max_rel_error = rel_err_V.max()
            rel_l2_error = calc_relative_l2_err(V_ref, V_phys)
            info_text_V = (f'Max Error: {max_error:.3g} m/s\n'
                           f'Mean Error: {mean_error:.3g} m/s\n'
                           f'Max Rel Error: {max_rel_error:.2f}%\n'
                           f'Rel L2 Error: {rel_l2_error:.5g}')

        ax_V.set_xlabel('Normalized radius r/R', fontsize=11)
        ax_V.set_ylabel('Velocity (m/s)', fontsize=11)
        ax_V.set_title(f'Velocity @ t = {t_physical*1e3:.2f} ms', fontsize=12, fontweight='bold')
        ax_V.grid(True, alpha=0.3)
        ax_V.legend(loc='best', fontsize=10)
        ax_V.text(0.05, 0.05, info_text_V,
                  transform=ax_V.transAxes, fontsize=9, verticalalignment='bottom',
                  bbox={'boxstyle': 'round', 'facecolor': 'lightblue', 'alpha': 0.6})

    # ---- Bottom row: loss history across both columns ----
    ax_loss = fig.add_subplot(grid[num_times, :])
    recorded_losses = self.history['losses']
    if not recorded_losses:
        ax_loss.text(0.5, 0.5, 'No loss history yet',
                     ha='center', va='center', transform=ax_loss.transAxes,
                     fontsize=12, style='italic', color='gray')
        ax_loss.set_title('Training Loss', fontsize=12, fontweight='bold')
    else:
        # The recorded epochs and losses are plotted against each other,
        # so the two history lists are assumed to have equal length.
        ax_loss.semilogy(self.history['epochs'], recorded_losses, 'purple',
                         linewidth=2.5, marker='o', markersize=4)
        ax_loss.set_xlabel('Epoch', fontsize=11)
        ax_loss.set_ylabel('Loss (log scale)', fontsize=11)
        ax_loss.set_title('Training Loss Convergence', fontsize=12, fontweight='bold')
        ax_loss.grid(True, alpha=0.3, which='both')

        current_loss = kwargs.get('total_loss', None)
        if current_loss is not None:
            loss_val = current_loss.item() if isinstance(current_loss, torch.Tensor) else float(current_loss)
            ax_loss.text(0.98, 0.95, f'Current loss: {loss_val:.2e}',
                         transform=ax_loss.transAxes, fontsize=10, verticalalignment='top',
                         horizontalalignment='right',
                         bbox={'boxstyle': 'round', 'facecolor': 'lightyellow', 'alpha': 0.7})

    fig.suptitle(f'Arc 1D Transient PINN Model (T & V) - Epoch {epoch}',
                 fontsize=14, fontweight='bold', y=0.995)
    return fig
[docs]
def visualize(self, network, epoch: int, writer: SummaryWriter, **kwargs):
    """
    Evaluate the network on the fixed (r, t) grid, record history and build the epoch figure.

    The network is switched to evaluation mode for the forward passes and is
    put back into the mode it was in on entry, so calling this in the middle
    of training does not leave the model stuck in eval mode.

    Parameters:
    -----------
    network : nn.Module
        The neural network being trained; it is called with a two-column
        (r, t) input and must return a ``(T, V)`` pair of tensors
    epoch : int
        Current training epoch number
    writer : SummaryWriter
        TensorBoard writer; accepted for interface compatibility, not used
        inside this method
    kwargs : dict
        Additional training information. 'total_loss' (tensor or number) is
        stored in the history when present; all kwargs are forwarded to
        ``_make_figure``

    Returns:
    --------
    dict
        ``{'transient_visualization': fig}`` with the matplotlib figure
        for this epoch
    """
    # Remember the current mode so the forward passes below do not
    # permanently flip a network that is still being trained.
    was_training = network.training
    network.eval()

    T_reduced_all = []
    V_reduced_all = []
    # The radial column is identical for every time step; build it once.
    x_r = self.x_eval.reshape(-1, 1)
    try:
        with torch.no_grad():
            for t_val in self.t_eval:
                # Pair every radial point with the current time value.
                t_reps = np.repeat(t_val, len(x_r)).astype(REAL()).reshape(-1, 1)
                xt_grid = np.hstack([x_r, t_reps])

                T_reduced, V_reduced = network(numpy2torch(xt_grid, require_grad=False))
                T_reduced_all.append(T_reduced.cpu().numpy())
                V_reduced_all.append(V_reduced.cpu().numpy())
    finally:
        network.train(was_training)

    # Stack the per-time predictions along a leading time axis.
    T_reduced_all = np.array(T_reduced_all)
    V_reduced_all = np.array(V_reduced_all)

    # Save history for animation
    if self.save_history and epoch % self.history_freq == 0:
        self.history['epochs'].append(epoch)
        self.history['T'].append(T_reduced_all.copy())
        self.history['V'].append(V_reduced_all.copy())

        # First radial sample at every time step
        # (presumably the arc axis r = 0; depends on how x_eval is built).
        T_center_at_times = T_reduced_all[:, 0]
        self.history['T_center_t'].append(T_center_at_times)

        # The loss is only recorded when the caller supplied one.
        total_loss = kwargs.get('total_loss', None)
        if total_loss is not None:
            if isinstance(total_loss, torch.Tensor):
                self.history['losses'].append(total_loss.item())
            else:
                self.history['losses'].append(float(total_loss))

    # Create figure
    fig = self._make_figure(epoch=epoch, T_reduced_all=T_reduced_all, V_reduced_all=V_reduced_all, **kwargs)

    # Optionally save frame for GIF; a failed write only produces a warning
    # so that visualization problems never interrupt training.
    if self.gif_enabled and (epoch % self.gif_freq == 0):
        frame_path = os.path.join(self.gif_tmp_dir, f'epoch_{epoch:06d}.png')
        try:
            fig.savefig(frame_path, dpi=120)
            self._gif_frames.append(frame_path)
        except Exception as e:
            print(f'Warning: failed to save GIF frame at epoch {epoch}: {e}')

    return {'transient_visualization': fig}
[docs]
def save_gif(self, gif_path: str = None, duration_ms: int = None, loop: int = 0):
    """
    Assemble the frames collected by ``visualize`` into an animated GIF.

    Nothing is written when GIF output is disabled or no frame has been
    collected. A failure while building the GIF is reported on stdout and
    the collected frames are left in place.

    Parameters:
    -----------
    gif_path : str, optional
        Output path for the GIF; defaults to ``<gif_dir>/<gif_filename>``
        (``training_animation.gif`` unless changed). A bare filename with no
        directory part is accepted and written relative to the current
        working directory.
    duration_ms : int, optional
        Per-frame duration in milliseconds; defaults to gif_duration_ms
    loop : int, default=0
        Number of loops (0 = infinite)
    """
    if not self.gif_enabled:
        print('GIF not enabled; nothing to save.')
        return
    if len(self._gif_frames) == 0:
        print('No frames collected; GIF not created.')
        return

    out_path = gif_path if gif_path is not None else os.path.join(self.gif_dir, self.gif_filename)

    # os.path.dirname() is '' for a bare filename and os.makedirs('') raises
    # FileNotFoundError, so only create the directory when there is one.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Frame names embed a zero-padded epoch, so a plain sort is chronological.
    frames_sorted = sorted(self._gif_frames)
    try:
        img2gif(frames_sorted, out_path, duration=(duration_ms or self.gif_duration_ms), loop=loop)
        print(f'GIF saved: {out_path}')
    except Exception as e:
        print(f'Failed to create GIF: {e}')
        return

    if self.gif_cleanup_tmp:
        try:
            for f in frames_sorted:
                if os.path.exists(f):
                    os.remove(f)
            if os.path.isdir(self.gif_tmp_dir) and len(os.listdir(self.gif_tmp_dir)) == 0:
                os.rmdir(self.gif_tmp_dir)
        except Exception as e:
            print(f'Warning: failed cleanup of tmp frames: {e}')
        # Forget frames that no longer exist so a later call does not try to
        # build a GIF from deleted files.
        self._gif_frames = [f for f in self._gif_frames if os.path.exists(f)]
[docs]
def save_final_results(self, network: nn.Module, save_dir: str = None, epoch: int = None, **kwargs):
    """
    Evaluate the trained network once more and write the summary figures to disk.

    Outputs written by the code below:
    - final_panels.png: multi-panel figure with temperature and velocity at
      each evaluation time plus the loss curve
    - loss_curve.png: standalone loss curve (only when loss history exists)

    NOTE(review): earlier documentation of this method also listed
    center_temp_evolution.png; nothing in this part of the method writes
    it -- confirm whether that output is still intended.

    Parameters:
    -----------
    network : nn.Module
        Trained network for final prediction
    save_dir : str, optional
        Output directory; defaults to gif_dir
    epoch : int, optional
        Epoch number for title; if None, uses last recorded or 0
    kwargs : dict
        Optional training info, forwarded to ``_make_figure``
    """
    out_dir = save_dir or self.gif_dir
    os.makedirs(out_dir, exist_ok=True)

    # Generate final predictions
    network.eval()
    T_reduced_all = []
    V_reduced_all = []
    # The radial column is the same for every time step; build it once.
    x_r = self.x_eval.reshape(-1, 1)
    with torch.no_grad():
        for t_val in self.t_eval:
            t_reps = np.repeat(t_val, len(x_r)).astype(REAL()).reshape(-1, 1)
            xt_grid = np.hstack([x_r, t_reps])
            T_reduced, V_reduced = network(numpy2torch(xt_grid, require_grad=False))
            T_reduced_all.append(T_reduced.cpu().numpy())
            V_reduced_all.append(V_reduced.cpu().numpy())
    T_reduced_all = np.array(T_reduced_all)
    V_reduced_all = np.array(V_reduced_all)

    # Determine epoch label
    if epoch is None:
        epoch = self.history['epochs'][-1] if self.history['epochs'] else 0

    # Save multi-panel figure (temperature & velocity profiles + loss curve)
    fig = self._make_figure(epoch=epoch, T_reduced_all=T_reduced_all, V_reduced_all=V_reduced_all, **kwargs)
    panels_path = os.path.join(out_dir, 'final_panels.png')
    try:
        fig.savefig(panels_path, dpi=150, bbox_inches='tight')
        print(f'Final panels saved: {panels_path}')
    except Exception as e:
        print(f'Failed to save final panels: {e}')
    finally:
        # Close on both paths so a failed write does not leak the figure.
        plt.close(fig)

    # Save standalone loss curve
    if self.history['losses']:
        loss_fig = plt.figure(figsize=(7, 5))
        try:
            plt.semilogy(self.history['epochs'], self.history['losses'], 'purple',
                         linewidth=2.0, marker='o', markersize=4)
            plt.xlabel('Epoch', fontsize=11)
            plt.ylabel('Loss (log scale)', fontsize=11)
            plt.title('Training Loss Curve (Transient Arc Model)', fontsize=12, fontweight='bold')
            plt.grid(True, alpha=0.3, which='both')
            loss_path = os.path.join(out_dir, 'loss_curve.png')
            try:
                plt.savefig(loss_path, dpi=150, bbox_inches='tight')
                print(f'Loss curve saved: {loss_path}')
            except Exception as e:
                print(f'Failed to save loss curve: {e}')
        finally:
            # Close this specific figure even if plotting itself raised.
            plt.close(loss_fig)
    else:
        print('No loss history recorded; loss_curve.png not created.')