Dev Update (#582)

* Fix adaptive refinement (#571)


---------

Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>

* Remove collector

* Fixes

* Fixes

* rm unnecessary comment

* fix advection (#581)

* Fix tutorial .html link (#580)

* fix problem data collection for v0.1 (#584)

* Message Passing Module (#516)

* add deep tensor network block

* add interaction network block

* add radial field network block

* add schnet block

* add equivariant network block

* fix + tests + doc files

* fix egnn + equivariance/invariance tests

Co-authored-by: Dario Coscia <dariocos99@gmail.com>

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>

* add type checker (#527)

---------

Co-authored-by: Filippo Olivo <filippo@filippoolivo.com>
Co-authored-by: Giovanni Canali <115086358+GiovanniCanali@users.noreply.github.com>
Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>
This commit is contained in:
Dario Coscia
2025-06-13 17:34:37 +02:00
committed by GitHub
parent 6b355b45de
commit 7bf7d34d0f
40 changed files with 1963 additions and 581 deletions

View File

@@ -2,13 +2,13 @@
# Public names of this package, i.e. what a star-import of it exposes.
# NOTE(review): "R3Refinement" is listed twice below; this looks like the
# removed and the added line of the diff hunk rendered together -- confirm
# against the real file before deduplicating.
__all__ = [
"SwitchOptimizer",
"R3Refinement",
"MetricTracker",
"PINAProgressBar",
"LinearWeightUpdate",
"R3Refinement",
]
from .optimizer_callback import SwitchOptimizer
from .adaptive_refinement_callback import R3Refinement
from .processing_callback import MetricTracker, PINAProgressBar
from .linear_weight_update_callback import LinearWeightUpdate
from .refinement import R3Refinement

View File

@@ -1,181 +0,0 @@
"""Module for the R3Refinement callback."""
import importlib.metadata
import torch
from lightning.pytorch.callbacks import Callback
from ..label_tensor import LabelTensor
from ..utils import check_consistency
class R3Refinement(Callback):
    """
    PINA Implementation of an R3 Refinement Callback.

    The callback is currently disabled: instantiating it always raises
    :class:`NotImplementedError`. The previous implementation is kept as
    commented-out code after the constructor.
    """

    def __init__(self, sample_every):
        """
        This callback implements the R3 (Retain-Resample-Release) routine for
        sampling new points based on adaptive search.
        The algorithm incrementally accumulates collocation points in regions
        of high PDE residuals, and releases those with low residuals.
        Points are sampled uniformly in all regions where sampling is needed.

        .. seealso::

            Original Reference: Daw, Arka, et al. *Mitigating Propagation
            Failures in Physics-informed Neural Networks
            using Retain-Resample-Release (R3) Sampling. (2023)*.
            DOI: `10.48550/arXiv.2207.02338
            <https://doi.org/10.48550/arXiv.2207.02338>`_

        :param int sample_every: Frequency for sampling. Currently unused,
            since the constructor raises before storing it.
        :raises NotImplementedError: Always, while the callback is being
            refactored.
        """
        # Resolve the installed version *before* building the error, so that
        # missing distribution metadata (e.g. running from a source tree)
        # cannot replace the intended NotImplementedError with a
        # PackageNotFoundError.
        try:
            version = importlib.metadata.version("pina-mathlab")
        except importlib.metadata.PackageNotFoundError:
            version = "current"

        raise NotImplementedError(
            "R3Refinement callback is being refactored in the pina "
            f"{version} "
            "version. Please use version 0.1 if R3Refinement is required."
        )
# super().__init__()
# # sample every
# check_consistency(sample_every, int)
# self._sample_every = sample_every
# self._const_pts = None
# def _compute_residual(self, trainer):
# """
# Computes the residuals for a PINN object.
# :return: the total loss, and pointwise loss.
# :rtype: tuple
# """
# # extract the solver and device from trainer
# solver = trainer.solver
# device = trainer._accelerator_connector._accelerator_flag
# precision = trainer.precision
# if precision == "64-true":
# precision = torch.float64
# elif precision == "32-true":
# precision = torch.float32
# else:
# raise RuntimeError(
# "Currently R3Refinement is only implemented "
# "for precision '32-true' and '64-true', set "
# "Trainer precision to match one of the "
# "available precisions."
# )
# # compute residual
# res_loss = {}
# tot_loss = []
# for location in self._sampling_locations:
# condition = solver.problem.conditions[location]
# pts = solver.problem.input_pts[location]
# # send points to correct device
# pts = pts.to(device=device, dtype=precision)
# pts = pts.requires_grad_(True)
# pts.retain_grad()
# # PINN loss: equation evaluated only for sampling locations
# target = condition.equation.residual(pts, solver.forward(pts))
# res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
# tot_loss.append(torch.abs(target))
# print(tot_loss)
# return torch.vstack(tot_loss), res_loss
# def _r3_routine(self, trainer):
# """
# R3 refinement main routine.
# :param Trainer trainer: PINA Trainer.
# """
# # compute residual (all device possible)
# tot_loss, res_loss = self._compute_residual(trainer)
# tot_loss = tot_loss.as_subclass(torch.Tensor)
# # !!!!!! From now everything is performed on CPU !!!!!!
# # average loss
# avg = (tot_loss.mean()).to("cpu")
# old_pts = {} # points to be retained
# for location in self._sampling_locations:
# pts = trainer._model.problem.input_pts[location]
# labels = pts.labels
# pts = pts.cpu().detach().as_subclass(torch.Tensor)
# residuals = res_loss[location].cpu()
# mask = (residuals > avg).flatten()
# if any(mask): # append residuals greater than average
# pts = (pts[mask]).as_subclass(LabelTensor)
# pts.labels = labels
# old_pts[location] = pts
# numb_pts = self._const_pts[location] - len(old_pts[location])
# # sample new points
# trainer._model.problem.discretise_domain(
# numb_pts, "random", locations=[location]
# )
# else: # if no res greater than average, samples all uniformly
# numb_pts = self._const_pts[location]
# # sample new points
# trainer._model.problem.discretise_domain(
# numb_pts, "random", locations=[location]
# )
# # adding previous population points
# trainer._model.problem.add_points(old_pts)
# # update dataloader
# trainer._create_or_update_loader()
# def on_train_start(self, trainer, _):
# """
# Callback function called at the start of training.
# This method extracts the locations for sampling from the problem
# conditions and calculates the total population.
# :param trainer: The trainer object managing the training process.
# :type trainer: pytorch_lightning.Trainer
# :param _: Placeholder argument (not used).
# :return: None
# :rtype: None
# """
# # extract locations for sampling
# problem = trainer.solver.problem
# locations = []
# for condition_name in problem.conditions:
# condition = problem.conditions[condition_name]
# if hasattr(condition, "location"):
# locations.append(condition_name)
# self._sampling_locations = locations
# # extract total population
# const_pts = {} # for each location, store the pts to keep constant
# for location in self._sampling_locations:
# pts = trainer._model.problem.input_pts[location]
# const_pts[location] = len(pts)
# self._const_pts = const_pts
# def on_train_epoch_end(self, trainer, __):
# """
# Callback function called at the end of each training epoch.
# This method triggers the R3 routine for refinement if the current
# epoch is a multiple of `_sample_every`.
# :param trainer: The trainer object managing the training process.
# :type trainer: pytorch_lightning.Trainer
# :param __: Placeholder argument (not used).
# :return: None
# :rtype: None
# """
# if trainer.current_epoch % self._sample_every == 0:
# self._r3_routine(trainer)

View File

@@ -0,0 +1,11 @@
"""Refinement callbacks for PINA."""

# Names that make up the public interface of this package.
__all__ = ["RefinementInterface", "R3Refinement"]
from .refinement_interface import RefinementInterface
from .r3_refinement import R3Refinement

View File

@@ -0,0 +1,88 @@
"""Module for the R3Refinement callback."""
import torch
from torch import nn
from torch.nn.modules.loss import _Loss
from .refinement_interface import RefinementInterface
from ...label_tensor import LabelTensor
from ...utils import check_consistency
from ...loss import LossInterface
class R3Refinement(RefinementInterface):
    """
    R3 (Retain-Resample-Release) refinement callback for PINA.

    Points whose residual exceeds the mean residual are retained; the rest of
    the population is redrawn at random from the domain of the condition, so
    that the size recorded in ``initial_population_size`` is requested again.

    .. seealso::

        Original Reference: Daw, Arka, et al. *Mitigating Propagation
        Failures in Physics-informed Neural Networks
        using Retain-Resample-Release (R3) Sampling. (2023)*.
        DOI: `10.48550/arXiv.2207.02338
        <https://doi.org/10.48550/arXiv.2207.02338>`_
    """

    def __init__(
        self, sample_every, residual_loss=nn.L1Loss, condition_to_update=None
    ):
        """
        Initialization of the :class:`R3Refinement` callback.

        :param int sample_every: Frequency for sampling.
        :param residual_loss: Loss *class* (not an instance) used to score the
            residuals. It is instantiated with ``reduction="none"``.
            Default is :class:`torch.nn.L1Loss`.
        :type residual_loss: LossInterface | ~torch.nn.modules.loss._Loss
        :param condition_to_update: The conditions to update during the
            refinement process. If ``None``, the choice is delegated to the
            parent class. Default is ``None``.
        :type condition_to_update: list(str) | tuple(str) | str

        An invalid ``residual_loss`` is rejected by ``check_consistency``;
        the exception type is defined there (not visible from this module).

        Example:
            >>> r3_callback = R3Refinement(sample_every=5)
        """
        super().__init__(sample_every, condition_to_update)

        # The argument must be a loss class, since it is instantiated here
        # with an element-wise (non-reduced) output.
        check_consistency(residual_loss, (LossInterface, _Loss), subclass=True)
        self.loss_fn = residual_loss(reduction="none")

    def sample(self, current_points, condition_name, solver):
        """
        Build the next population of points for one condition.

        :param current_points: Current points in the domain. Gradient
            tracking is switched on in place for this tensor.
        :param condition_name: Name of the condition to update.
        :param PINNInterface solver: The solver object.
        :return: The retained high-residual points followed by freshly
            sampled ones, or an entirely fresh sample when no point lies
            above the mean residual.
        :rtype: LabelTensor
        """
        condition = solver.problem.conditions[condition_name]

        # Residual of the equation on the current points
        tracked_points = current_points.requires_grad_(True)
        residual = solver.compute_residual(tracked_points, condition.equation)

        # One score per point: average the element-wise loss over every
        # dimension but the first one
        reduce_dims = tuple(range(1, residual.ndim))
        elementwise = self.loss_fn(residual, torch.zeros_like(residual))
        scores = elementwise.mean(dim=reduce_dims)

        # Everything needed to draw new points
        point_labels = current_points.labels
        domain = solver.problem.domains[condition.domain]
        population = self.initial_population_size[condition_name]

        keep = (scores > scores.mean()).flatten()

        # Release everything when no point is above the mean residual
        if not keep.any():
            return domain.sample(population, "random")

        # Retain the high-residual points and resample the remainder
        retained = current_points[keep]
        retained.labels = point_labels
        fresh = domain.sample(population - len(retained), "random")
        return LabelTensor.cat([retained, fresh])

View File

@@ -0,0 +1,155 @@
"""
RefinementInterface class for handling the refinement of points in a neural
network training process.
"""
from abc import ABCMeta, abstractmethod
from lightning.pytorch import Callback
from ...utils import check_consistency
from ...solver.physics_informed_solver import PINNInterface
class RefinementInterface(Callback, metaclass=ABCMeta):
    """
    Interface class of Refinement approaches.

    Subclasses implement :meth:`sample`; this class takes care of validating
    the conditions at the start of training and of replacing the input points
    of the training dataset every ``sample_every`` epochs.
    """

    def __init__(self, sample_every, condition_to_update=None):
        """
        Initializes the RefinementInterface.

        :param int sample_every: The number of epochs between each refinement.
        :param condition_to_update: The conditions to update during the
            refinement process. If None, all conditions with a domain will be
            updated. Default is None.
        :type condition_to_update: list(str) | tuple(str) | str
        :raises ValueError: If ``condition_to_update`` is neither a string,
            a list nor a tuple.
        """
        super().__init__()

        # check consistency of the input
        check_consistency(sample_every, int)
        if condition_to_update is not None:
            # a single name is accepted and normalised to a list
            if isinstance(condition_to_update, str):
                condition_to_update = [condition_to_update]
            if not isinstance(condition_to_update, (list, tuple)):
                raise ValueError(
                    "'condition_to_update' must be iter of strings."
                )
            check_consistency(condition_to_update, str)

        # store
        self.sample_every = sample_every
        self._condition_to_update = condition_to_update
        # both are filled in by on_train_start
        self._dataset = None
        self._initial_population_size = None

    def on_train_start(self, trainer, solver):
        """
        Called when the training begins. It initializes the conditions and
        dataset.

        :param ~lightning.pytorch.trainer.trainer.Trainer trainer: The trainer
            object.
        :param ~pina.solver.solver.SolverInterface solver: The solver
            object associated with the trainer.
        :raises RuntimeError: If a requested condition is not defined in the
            problem.
        :raises RuntimeError: If the conditions do not have a domain to sample
            from.
        :raises RuntimeError: If the solver is not a PINNInterface.
        """
        conditions = solver.problem.conditions

        # default: every condition exposing a domain
        if self._condition_to_update is None:
            self._condition_to_update = [
                name
                for name, cond in conditions.items()
                if hasattr(cond, "domain")
            ]

        # check we have valid conditions names
        for cond in self._condition_to_update:
            if cond not in conditions:
                raise RuntimeError(
                    f"Condition '{cond}' not found in "
                    f"{list(solver.problem.conditions.keys())}."
                )
            if not hasattr(conditions[cond], "domain"):
                raise RuntimeError(
                    f"Condition '{cond}' does not contain a domain to "
                    "sample from."
                )

        # check solver
        if not isinstance(solver, PINNInterface):
            raise RuntimeError(
                "Refinment strategies are currently implemented only "
                "for physics informed based solvers. Please use a Solver "
                "inheriting from 'PINNInterface'."
            )

        # store dataset
        self._dataset = trainer.datamodule.train_dataset

        # compute initial population size
        self._initial_population_size = self._compute_population_size(
            self._condition_to_update
        )

        # delegate to the matching parent hook (the train-start one, not the
        # epoch-start one)
        return super().on_train_start(trainer, solver)

    def on_train_epoch_end(self, trainer, solver):
        """
        Performs the refinement at the end of each training epoch (if needed).

        The points are updated when the current epoch is a non-zero multiple
        of ``sample_every``.

        :param ~lightning.pytorch.trainer.trainer.Trainer trainer: The trainer
            object.
        :param PINNInterface solver: The solver object.
        """
        epoch = trainer.current_epoch
        if epoch != 0 and epoch % self.sample_every == 0:
            self._update_points(solver)
        return super().on_train_epoch_end(trainer, solver)

    @abstractmethod
    def sample(self, current_points, condition_name, solver):
        """
        Samples new points based on the condition.

        :param current_points: Current points in the domain.
        :param condition_name: Name of the condition to update.
        :param PINNInterface solver: The solver object.
        :return: New points sampled according to the refinement strategy
            implemented by the subclass.
        :rtype: LabelTensor
        """

    @property
    def dataset(self):
        """
        Returns the dataset for training.

        It is ``None`` until :meth:`on_train_start` has been called.
        """
        return self._dataset

    @property
    def initial_population_size(self):
        """
        Returns the number of input points each updated condition had when
        the training started, as a dictionary keyed by condition name.

        It is ``None`` until :meth:`on_train_start` has been called.
        """
        return self._initial_population_size

    def _update_points(self, solver):
        """
        Performs the refinement of the points.

        For every condition to update, the current input points are passed to
        :meth:`sample` and the result replaces them in the dataset.

        :param PINNInterface solver: The solver object.
        """
        new_points = {}
        for name in self._condition_to_update:
            current_points = self.dataset.conditions_dict[name]["input"]
            new_points[name] = {
                "input": self.sample(current_points, name, solver)
            }
        self.dataset.update_data(new_points)

    def _compute_population_size(self, conditions):
        """
        Computes the number of points in the dataset for each condition.

        :param conditions: List of conditions to compute the number of points.
        :return: Dictionary with the population size for each condition.
        :rtype: dict
        """
        return {
            cond: len(self.dataset.conditions_dict[cond]["input"])
            for cond in conditions
        }