Dev Update (#582)
* Fix adaptive refinement (#571)
* Remove collector
* Fixes
* rm unnecessary comment
* fix advection (#581)
* Fix tutorial .html link (#580)
* fix problem data collection for v0.1 (#584)
* Message Passing Module (#516)
  * add deep tensor network block
  * add interaction network block
  * add radial field network block
  * add schnet block
  * add equivariant network block
  * fix + tests + doc files
  * fix egnn + equivariance/invariance tests
* add type checker (#527)

Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>
Co-authored-by: Dario Coscia <dariocos99@gmail.com>
Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>
Co-authored-by: Filippo Olivo <filippo@filippoolivo.com>
Co-authored-by: Giovanni Canali <115086358+GiovanniCanali@users.noreply.github.com>
@@ -122,6 +122,17 @@ Blocks
    Continuous Convolution Block <model/block/convolution.rst>
    Orthogonal Block <model/block/orthogonal.rst>

Message Passing
-------------------

.. toctree::
    :titlesonly:

    Deep Tensor Network Block <model/block/message_passing/deep_tensor_network_block.rst>
    E(n) Equivariant Network Block <model/block/message_passing/en_equivariant_network_block.rst>
    Interaction Network Block <model/block/message_passing/interaction_network_block.rst>
    Radial Field Network Block <model/block/message_passing/radial_field_network_block.rst>


Reduction and Embeddings
--------------------------
@@ -238,7 +249,8 @@ Callbacks

    Processing callback <callback/processing_callback.rst>
    Optimizer callback <callback/optimizer_callback.rst>
    Refinement callback <callback/adaptive_refinment_callback.rst>
    R3 Refinement callback <callback/refinement/r3_refinement.rst>
    Refinement Interface callback <callback/refinement/refinement_interface.rst>
    Weighting callback <callback/linear_weight_update_callback.rst>

Losses and Weightings
@@ -1,7 +1,7 @@
Refinement callbacks
=======================

.. currentmodule:: pina.callback.adaptive_refinement_callback
.. currentmodule:: pina.callback.refinement

.. autoclass:: R3Refinement
    :members:
    :show-inheritance:
@@ -0,0 +1,7 @@
Refinement Interface
=======================

.. currentmodule:: pina.callback.refinement

.. autoclass:: RefinementInterface
    :members:
    :show-inheritance:
@@ -0,0 +1,8 @@
Deep Tensor Network Block
==================================
.. currentmodule:: pina.model.block.message_passing.deep_tensor_network_block

.. autoclass:: DeepTensorNetworkBlock
    :members:
    :show-inheritance:
    :noindex:
@@ -0,0 +1,8 @@
E(n) Equivariant Network Block
==================================
.. currentmodule:: pina.model.block.message_passing.en_equivariant_network_block

.. autoclass:: EnEquivariantNetworkBlock
    :members:
    :show-inheritance:
    :noindex:
@@ -0,0 +1,8 @@
Interaction Network Block
==================================
.. currentmodule:: pina.model.block.message_passing.interaction_network_block

.. autoclass:: InteractionNetworkBlock
    :members:
    :show-inheritance:
    :noindex:
@@ -0,0 +1,8 @@
Radial Field Network Block
==================================
.. currentmodule:: pina.model.block.message_passing.radial_field_network_block

.. autoclass:: RadialFieldNetworkBlock
    :members:
    :show-inheritance:
    :noindex:
@@ -2,13 +2,13 @@

__all__ = [
    "SwitchOptimizer",
    "R3Refinement",
    "MetricTracker",
    "PINAProgressBar",
    "LinearWeightUpdate",
    "R3Refinement",
]

from .optimizer_callback import SwitchOptimizer
from .adaptive_refinement_callback import R3Refinement
from .processing_callback import MetricTracker, PINAProgressBar
from .linear_weight_update_callback import LinearWeightUpdate
from .refinement import R3Refinement
@@ -1,181 +0,0 @@
"""Module for the R3Refinement callback."""

import importlib.metadata
import torch
from lightning.pytorch.callbacks import Callback
from ..label_tensor import LabelTensor
from ..utils import check_consistency


class R3Refinement(Callback):
    """
    PINA Implementation of an R3 Refinement Callback.
    """

    def __init__(self, sample_every):
        """
        This callback implements the R3 (Retain-Resample-Release) routine for
        sampling new points based on adaptive search.
        The algorithm incrementally accumulates collocation points in regions
        of high PDE residuals, and releases those with low residuals.
        Points are sampled uniformly in all regions where sampling is needed.

        .. seealso::

            Original Reference: Daw, Arka, et al. *Mitigating Propagation
            Failures in Physics-informed Neural Networks
            using Retain-Resample-Release (R3) Sampling. (2023)*.
            DOI: `10.48550/arXiv.2207.02338
            <https://doi.org/10.48550/arXiv.2207.02338>`_

        :param int sample_every: Frequency for sampling.
        :raises ValueError: If `sample_every` is not an integer.

        Example:
            >>> r3_callback = R3Refinement(sample_every=5)
        """
        raise NotImplementedError(
            "R3Refinement callback is being refactored in the pina "
            f"{importlib.metadata.metadata('pina-mathlab')['Version']} "
            "version. Please use version 0.1 if R3Refinement is required."
        )

        # super().__init__()

        # # sample every
        # check_consistency(sample_every, int)
        # self._sample_every = sample_every
        # self._const_pts = None

    # def _compute_residual(self, trainer):
    #     """
    #     Computes the residuals for a PINN object.

    #     :return: the total loss, and pointwise loss.
    #     :rtype: tuple
    #     """

    #     # extract the solver and device from trainer
    #     solver = trainer.solver
    #     device = trainer._accelerator_connector._accelerator_flag
    #     precision = trainer.precision
    #     if precision == "64-true":
    #         precision = torch.float64
    #     elif precision == "32-true":
    #         precision = torch.float32
    #     else:
    #         raise RuntimeError(
    #             "Currently R3Refinement is only implemented "
    #             "for precision '32-true' and '64-true', set "
    #             "Trainer precision to match one of the "
    #             "available precisions."
    #         )

    #     # compute residual
    #     res_loss = {}
    #     tot_loss = []
    #     for location in self._sampling_locations:
    #         condition = solver.problem.conditions[location]
    #         pts = solver.problem.input_pts[location]
    #         # send points to correct device
    #         pts = pts.to(device=device, dtype=precision)
    #         pts = pts.requires_grad_(True)
    #         pts.retain_grad()
    #         # PINN loss: equation evaluated only for sampling locations
    #         target = condition.equation.residual(pts, solver.forward(pts))
    #         res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
    #         tot_loss.append(torch.abs(target))

    #     print(tot_loss)

    #     return torch.vstack(tot_loss), res_loss

    # def _r3_routine(self, trainer):
    #     """
    #     R3 refinement main routine.

    #     :param Trainer trainer: PINA Trainer.
    #     """
    #     # compute residual (all device possible)
    #     tot_loss, res_loss = self._compute_residual(trainer)
    #     tot_loss = tot_loss.as_subclass(torch.Tensor)

    #     # !!!!!! From now everything is performed on CPU !!!!!!

    #     # average loss
    #     avg = (tot_loss.mean()).to("cpu")
    #     old_pts = {}  # points to be retained
    #     for location in self._sampling_locations:
    #         pts = trainer._model.problem.input_pts[location]
    #         labels = pts.labels
    #         pts = pts.cpu().detach().as_subclass(torch.Tensor)
    #         residuals = res_loss[location].cpu()
    #         mask = (residuals > avg).flatten()
    #         if any(mask):  # append residuals greater than average
    #             pts = (pts[mask]).as_subclass(LabelTensor)
    #             pts.labels = labels
    #             old_pts[location] = pts
    #             numb_pts = self._const_pts[location] - len(old_pts[location])
    #             # sample new points
    #             trainer._model.problem.discretise_domain(
    #                 numb_pts, "random", locations=[location]
    #             )

    #         else:  # if no res greater than average, samples all uniformly
    #             numb_pts = self._const_pts[location]
    #             # sample new points
    #             trainer._model.problem.discretise_domain(
    #                 numb_pts, "random", locations=[location]
    #             )
    #     # adding previous population points
    #     trainer._model.problem.add_points(old_pts)

    #     # update dataloader
    #     trainer._create_or_update_loader()

    # def on_train_start(self, trainer, _):
    #     """
    #     Callback function called at the start of training.

    #     This method extracts the locations for sampling from the problem
    #     conditions and calculates the total population.

    #     :param trainer: The trainer object managing the training process.
    #     :type trainer: pytorch_lightning.Trainer
    #     :param _: Placeholder argument (not used).

    #     :return: None
    #     :rtype: None
    #     """
    #     # extract locations for sampling
    #     problem = trainer.solver.problem
    #     locations = []
    #     for condition_name in problem.conditions:
    #         condition = problem.conditions[condition_name]
    #         if hasattr(condition, "location"):
    #             locations.append(condition_name)
    #     self._sampling_locations = locations

    #     # extract total population
    #     const_pts = {}  # for each location, store the pts to keep constant
    #     for location in self._sampling_locations:
    #         pts = trainer._model.problem.input_pts[location]
    #         const_pts[location] = len(pts)
    #     self._const_pts = const_pts

    # def on_train_epoch_end(self, trainer, __):
    #     """
    #     Callback function called at the end of each training epoch.

    #     This method triggers the R3 routine for refinement if the current
    #     epoch is a multiple of `_sample_every`.

    #     :param trainer: The trainer object managing the training process.
    #     :type trainer: pytorch_lightning.Trainer
    #     :param __: Placeholder argument (not used).

    #     :return: None
    #     :rtype: None
    #     """
    #     if trainer.current_epoch % self._sample_every == 0:
    #         self._r3_routine(trainer)
pina/callback/refinement/__init__.py (new file, 11 lines)
@@ -0,0 +1,11 @@
"""
Module for Pina Refinement callbacks.
"""

__all__ = [
    "RefinementInterface",
    "R3Refinement",
]

from .refinement_interface import RefinementInterface
from .r3_refinement import R3Refinement
pina/callback/refinement/r3_refinement.py (new file, 88 lines)
@@ -0,0 +1,88 @@
"""Module for the R3Refinement callback."""

import torch
from torch import nn
from torch.nn.modules.loss import _Loss
from .refinement_interface import RefinementInterface
from ...label_tensor import LabelTensor
from ...utils import check_consistency
from ...loss import LossInterface


class R3Refinement(RefinementInterface):
    """
    PINA Implementation of an R3 Refinement Callback.
    """

    def __init__(
        self, sample_every, residual_loss=nn.L1Loss, condition_to_update=None
    ):
        """
        This callback implements the R3 (Retain-Resample-Release) routine for
        sampling new points based on adaptive search.
        The algorithm incrementally accumulates collocation points in regions
        of high PDE residuals, and releases those with low residuals.
        Points are sampled uniformly in all regions where sampling is needed.

        .. seealso::

            Original Reference: Daw, Arka, et al. *Mitigating Propagation
            Failures in Physics-informed Neural Networks
            using Retain-Resample-Release (R3) Sampling. (2023)*.
            DOI: `10.48550/arXiv.2207.02338
            <https://doi.org/10.48550/arXiv.2207.02338>`_

        :param int sample_every: Frequency for sampling.
        :param residual_loss: The loss function used to compute the residuals.
        :type residual_loss: LossInterface | ~torch.nn.modules.loss._Loss
        :param condition_to_update: The conditions to update during the
            refinement process. If None, all conditions with a domain will
            be updated. Default is None.
        :type condition_to_update: list(str) | tuple(str) | str
        :raises ValueError: If the condition_to_update is not a string or
            iterable of strings.
        :raises TypeError: If the residual_loss is not a subclass of
            torch.nn.Module.

        Example:
            >>> r3_callback = R3Refinement(sample_every=5)
        """
        super().__init__(sample_every, condition_to_update)
        # check consistency loss
        check_consistency(residual_loss, (LossInterface, _Loss), subclass=True)
        self.loss_fn = residual_loss(reduction="none")

    def sample(self, current_points, condition_name, solver):
        """
        Sample new points based on the R3 refinement strategy.

        :param current_points: Current points in the domain.
        :param condition_name: Name of the condition to update.
        :param PINNInterface solver: The solver object.
        :return: New points sampled based on the R3 strategy.
        :rtype: LabelTensor
        """
        # Compute residuals for the given condition (average over fields)
        condition = solver.problem.conditions[condition_name]
        target = solver.compute_residual(
            current_points.requires_grad_(True), condition.equation
        )
        residuals = self.loss_fn(target, torch.zeros_like(target)).mean(
            dim=tuple(range(1, target.ndim))
        )

        # Prepare new points
        labels = current_points.labels
        domain_name = solver.problem.conditions[condition_name].domain
        domain = solver.problem.domains[domain_name]
        num_old_points = self.initial_population_size[condition_name]
        mask = (residuals > residuals.mean()).flatten()

        if mask.any():  # Use high-residual points
            pts = current_points[mask]
            pts.labels = labels
            retain_pts = len(pts)
            samples = domain.sample(num_old_points - retain_pts, "random")
            return LabelTensor.cat([pts, samples])
        return domain.sample(num_old_points, "random")
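
A minimal usage sketch of the new callback; `problem` and `model` are assumed to be an already-discretised PINA problem and a torch module, as in the package tutorials:

from pina import Trainer
from pina.solver import PINN
from pina.callback.refinement import R3Refinement

solver = PINN(problem=problem, model=model)
trainer = Trainer(
    solver=solver,
    max_epochs=1000,
    # retain high-residual points, resample the rest every 100 epochs
    callbacks=[R3Refinement(sample_every=100)],
)
trainer.train()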
pina/callback/refinement/refinement_interface.py (new file, 155 lines)
@@ -0,0 +1,155 @@
"""
RefinementInterface class for handling the refinement of points in a neural
network training process.
"""

from abc import ABCMeta, abstractmethod
from lightning.pytorch import Callback
from ...utils import check_consistency
from ...solver.physics_informed_solver import PINNInterface


class RefinementInterface(Callback, metaclass=ABCMeta):
    """
    Interface class of Refinement approaches.
    """

    def __init__(self, sample_every, condition_to_update=None):
        """
        Initializes the RefinementInterface.

        :param int sample_every: The number of epochs between each refinement.
        :param condition_to_update: The conditions to update during the
            refinement process. If None, all conditions with a domain will be
            updated. Default is None.
        :type condition_to_update: list(str) | tuple(str) | str
        """
        # check consistency of the input
        check_consistency(sample_every, int)
        if condition_to_update is not None:
            if isinstance(condition_to_update, str):
                condition_to_update = [condition_to_update]
            if not isinstance(condition_to_update, (list, tuple)):
                raise ValueError(
                    "'condition_to_update' must be an iterable of strings."
                )
            check_consistency(condition_to_update, str)
        # store
        self.sample_every = sample_every
        self._condition_to_update = condition_to_update
        self._dataset = None
        self._initial_population_size = None

    def on_train_start(self, trainer, solver):
        """
        Called when the training begins. It initializes the conditions and
        dataset.

        :param ~lightning.pytorch.trainer.trainer.Trainer trainer: The trainer
            object.
        :param ~pina.solver.solver.SolverInterface solver: The solver
            object associated with the trainer.
        :raises RuntimeError: If the solver is not a PINNInterface.
        :raises RuntimeError: If the conditions do not have a domain to sample
            from.
        """
        # check we have valid condition names
        if self._condition_to_update is None:
            self._condition_to_update = [
                name
                for name, cond in solver.problem.conditions.items()
                if hasattr(cond, "domain")
            ]

        for cond in self._condition_to_update:
            if cond not in solver.problem.conditions:
                raise RuntimeError(
                    f"Condition '{cond}' not found in "
                    f"{list(solver.problem.conditions.keys())}."
                )
            if not hasattr(solver.problem.conditions[cond], "domain"):
                raise RuntimeError(
                    f"Condition '{cond}' does not contain a domain to "
                    "sample from."
                )
        # check solver
        if not isinstance(solver, PINNInterface):
            raise RuntimeError(
                "Refinement strategies are currently implemented only "
                "for physics informed based solvers. Please use a Solver "
                "inheriting from 'PINNInterface'."
            )
        # store dataset
        self._dataset = trainer.datamodule.train_dataset
        # compute initial population size
        self._initial_population_size = self._compute_population_size(
            self._condition_to_update
        )
        return super().on_train_start(trainer, solver)

    def on_train_epoch_end(self, trainer, solver):
        """
        Performs the refinement at the end of each training epoch (if needed).

        :param ~lightning.pytorch.trainer.trainer.Trainer trainer: The trainer
            object.
        :param PINNInterface solver: The solver object.
        """
        if (trainer.current_epoch % self.sample_every == 0) and (
            trainer.current_epoch != 0
        ):
            self._update_points(solver)
        return super().on_train_epoch_end(trainer, solver)

    @abstractmethod
    def sample(self, current_points, condition_name, solver):
        """
        Samples new points based on the condition.

        :param current_points: Current points in the domain.
        :param condition_name: Name of the condition to update.
        :param PINNInterface solver: The solver object.
        :return: New points sampled based on the refinement strategy.
        :rtype: LabelTensor
        """

    @property
    def dataset(self):
        """
        Returns the dataset for training.
        """
        return self._dataset

    @property
    def initial_population_size(self):
        """
        Returns the initial number of training points for each condition.
        """
        return self._initial_population_size

    def _update_points(self, solver):
        """
        Performs the refinement of the points.

        :param PINNInterface solver: The solver object.
        """
        new_points = {}
        for name in self._condition_to_update:
            current_points = self.dataset.conditions_dict[name]["input"]
            new_points[name] = {
                "input": self.sample(current_points, name, solver)
            }
        self.dataset.update_data(new_points)

    def _compute_population_size(self, conditions):
        """
        Computes the number of points in the dataset for each condition.

        :param conditions: List of conditions to compute the number of points.
        :return: Dictionary with the population size for each condition.
        :rtype: dict
        """
        return {
            cond: len(self.dataset.conditions_dict[cond]["input"])
            for cond in conditions
        }
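
The interface is meant to be subclassed; only `sample` is abstract. A sketch of a hypothetical strategy that simply redraws the whole population uniformly, using only the hooks defined above:

from pina.callback.refinement import RefinementInterface


class UniformResampling(RefinementInterface):
    """Hypothetical baseline: redraw all points uniformly at each refinement."""

    def sample(self, current_points, condition_name, solver):
        # Draw a fresh population of the original size from the domain
        # attached to this condition.
        domain_name = solver.problem.conditions[condition_name].domain
        domain = solver.problem.domains[domain_name]
        return domain.sample(
            self.initial_population_size[condition_name], "random"
        )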
@@ -1,129 +0,0 @@
"""Module for the Collector class."""

from .graph import Graph
from .utils import check_consistency


class Collector:
    """
    Collector class for retrieving data from different conditions in the
    problem.
    """

    def __init__(self, problem):
        """
        Initialize the Collector class by creating a hook between the collector
        and the problem and initializing the data collections (dictionary where
        data will be stored).

        :param pina.problem.abstract_problem.AbstractProblem problem: The
            problem to collect data from.
        """
        # creating a hook between collector and problem
        self.problem = problem

        # those variables are used for the dataloading
        self._data_collections = {name: {} for name in self.problem.conditions}
        self.conditions_name = dict(enumerate(self.problem.conditions))

        # variables used to check that all conditions are sampled
        self._is_conditions_ready = {
            name: False for name in self.problem.conditions
        }
        self.full = False

    @property
    def full(self):
        """
        Returns ``True`` if the collector is full. The collector is considered
        full if all conditions have entries in the ``data_collection``
        dictionary.

        :return: ``True`` if all conditions are ready, ``False`` otherwise.
        :rtype: bool
        """
        return all(self._is_conditions_ready.values())

    @full.setter
    def full(self, value):
        """
        Set the ``_full`` variable.

        :param bool value: The value to set the ``_full`` variable.
        """
        check_consistency(value, bool)
        self._full = value

    @property
    def data_collections(self):
        """
        Return the data collections (dictionary where data is stored).

        :return: The data collections where the data is stored.
        :rtype: dict
        """
        return self._data_collections

    @property
    def problem(self):
        """
        Problem connected to the collector.

        :return: The problem from which the data is collected.
        :rtype: pina.problem.abstract_problem.AbstractProblem
        """
        return self._problem

    @problem.setter
    def problem(self, value):
        """
        Set the problem connected to the collector.

        :param pina.problem.abstract_problem.AbstractProblem value: The problem
            to connect to the collector.
        """
        self._problem = value

    def store_fixed_data(self):
        """
        Store inside data collections the fixed data of the problem. These come
        from the conditions that do not require sampling.
        """
        # loop over all conditions
        for condition_name, condition in self.problem.conditions.items():
            # if the condition is not ready and domain is not an attribute
            # of the condition, we get and store the data
            if (not self._is_conditions_ready[condition_name]) and (
                not hasattr(condition, "domain")
            ):
                # get data
                keys = condition.__slots__
                values = [getattr(condition, name) for name in keys]
                self.data_collections[condition_name] = dict(zip(keys, values))
                # condition now is ready
                self._is_conditions_ready[condition_name] = True

    def store_sample_domains(self):
        """
        Store inside data collections the sampled data of the problem. These
        come from the conditions that require sampling (e.g.
        :class:`~pina.condition.domain_equation_condition.\
DomainEquationCondition`).
        """
        for condition_name in self.problem.conditions:
            condition = self.problem.conditions[condition_name]
            if not hasattr(condition, "domain"):
                continue

            samples = self.problem.discretised_domains[condition.domain]

            self.data_collections[condition_name] = {
                "input": samples,
                "equation": condition.equation,
            }
@@ -12,7 +12,6 @@ from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from ..label_tensor import LabelTensor
from .dataset import PinaDatasetFactory, PinaTensorDataset
from ..collector import Collector


class DummyDataloader:
@@ -330,9 +329,7 @@ class PinaDataModule(LightningDataModule):
        self.pin_memory = pin_memory

        # Collect data
        collector = Collector(problem)
        collector.store_fixed_data()
        collector.store_sample_domains()
        problem.collect_data()

        # Check if the splits are correct
        self._check_slit_sizes(train_size, test_size, val_size)
@@ -361,7 +358,9 @@ class PinaDataModule(LightningDataModule):
        # raises NotImplementedError
        self.val_dataloader = super().val_dataloader

        self.collector_splits = self._create_splits(collector, splits_dict)
        self.data_splits = self._create_splits(
            problem.collected_data, splits_dict
        )
        self.transfer_batch_to_device = self._transfer_batch_to_device

    def setup(self, stage=None):
@@ -376,15 +375,15 @@ class PinaDataModule(LightningDataModule):
        """
        if stage == "fit" or stage is None:
            self.train_dataset = PinaDatasetFactory(
                self.collector_splits["train"],
                self.data_splits["train"],
                max_conditions_lengths=self.find_max_conditions_lengths(
                    "train"
                ),
                automatic_batching=self.automatic_batching,
            )
            if "val" in self.collector_splits.keys():
            if "val" in self.data_splits.keys():
                self.val_dataset = PinaDatasetFactory(
                    self.collector_splits["val"],
                    self.data_splits["val"],
                    max_conditions_lengths=self.find_max_conditions_lengths(
                        "val"
                    ),
@@ -392,7 +391,7 @@ class PinaDataModule(LightningDataModule):
                )
        elif stage == "test":
            self.test_dataset = PinaDatasetFactory(
                self.collector_splits["test"],
                self.data_splits["test"],
                max_conditions_lengths=self.find_max_conditions_lengths("test"),
                automatic_batching=self.automatic_batching,
            )
@@ -473,7 +472,7 @@ class PinaDataModule(LightningDataModule):
        for (
            condition_name,
            condition_dict,
        ) in collector.data_collections.items():
        ) in collector.items():
            len_data = len(condition_dict["input"])
            if self.shuffle:
                _apply_shuffle(condition_dict, len_data)
@@ -540,7 +539,7 @@ class PinaDataModule(LightningDataModule):
        """

        max_conditions_lengths = {}
        for k, v in self.collector_splits[split].items():
        for k, v in self.data_splits[split].items():
            if self.batch_size is None:
                max_conditions_lengths[k] = len(v["input"])
            elif self.repeat:
@@ -239,6 +239,22 @@ class PinaTensorDataset(PinaDataset):
        """
        return {k: v["input"] for k, v in self.conditions_dict.items()}

    def update_data(self, new_conditions_dict):
        """
        Update the dataset with new data, replacing the current entries with
        those provided in ``new_conditions_dict``.

        :param dict new_conditions_dict: Dictionary containing the new data.
        :return: None
        """
        for condition, data in new_conditions_dict.items():
            if condition in self.conditions_dict:
                self.conditions_dict[condition].update(data)
            else:
                self.conditions_dict[condition] = data
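
    # Usage sketch with hypothetical names: the refinement callbacks above
    # refresh the sampled inputs of a condition via
    #
    #     dataset.update_data({"interior": {"input": new_points}})
    #
    # which overwrites the "input" entry of condition "interior".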


class PinaGraphDataset(PinaDataset):
    """
pina/model/block/message_passing/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
"""Module for the message passing blocks of the graph neural models."""

__all__ = [
    "InteractionNetworkBlock",
    "DeepTensorNetworkBlock",
    "EnEquivariantNetworkBlock",
    "RadialFieldNetworkBlock",
]

from .interaction_network_block import InteractionNetworkBlock
from .deep_tensor_network_block import DeepTensorNetworkBlock
from .en_equivariant_network_block import EnEquivariantNetworkBlock
from .radial_field_network_block import RadialFieldNetworkBlock
pina/model/block/message_passing/deep_tensor_network_block.py (new file, 138 lines)
@@ -0,0 +1,138 @@
"""Module for the Deep Tensor Network block."""

import torch
from torch_geometric.nn import MessagePassing
from ....utils import check_positive_integer


class DeepTensorNetworkBlock(MessagePassing):
    """
    Implementation of the Deep Tensor Network block.

    This block is used to perform message-passing between nodes and edges in a
    graph neural network, following the scheme proposed by Schutt et al. in
    2017. It serves as an inner block in a larger graph neural network
    architecture.

    The message between two nodes connected by an edge is computed by applying
    a linear transformation to the sender node features and the edge features,
    followed by a non-linear activation function. Messages are then aggregated
    using an aggregation scheme (e.g., sum, mean, min, max, or product).

    The update step is performed by a simple addition of the incoming messages
    to the node features.

    .. seealso::

        **Original reference**: Schutt, K., Arbabzadah, F., Chmiela, S. et al.
        (2017). *Quantum-Chemical Insights from Deep Tensor Neural Networks*.
        Nature Communications 8, 13890 (2017).
        DOI: `<https://doi.org/10.1038/ncomms13890>`_.
    """

    def __init__(
        self,
        node_feature_dim,
        edge_feature_dim,
        activation=torch.nn.Tanh,
        aggr="add",
        node_dim=-2,
        flow="source_to_target",
    ):
        """
        Initialization of the :class:`DeepTensorNetworkBlock` class.

        :param int node_feature_dim: The dimension of the node features.
        :param int edge_feature_dim: The dimension of the edge features.
        :param torch.nn.Module activation: The activation function.
            Default is :class:`torch.nn.Tanh`.
        :param str aggr: The aggregation scheme to use for message passing.
            Available options are "add", "mean", "min", "max", "mul".
            See :class:`torch_geometric.nn.MessagePassing` for more details.
            Default is "add".
        :param int node_dim: The axis along which to propagate. Default is -2.
        :param str flow: The direction of message passing. Available options
            are "source_to_target" and "target_to_source".
            The "source_to_target" flow means that messages are sent from
            the source node to the target node, while the "target_to_source"
            flow means that messages are sent from the target node to the
            source node. See :class:`torch_geometric.nn.MessagePassing` for
            more details. Default is "source_to_target".
        :raises AssertionError: If `node_feature_dim` is not a positive integer.
        :raises AssertionError: If `edge_feature_dim` is not a positive integer.
        """
        super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)

        # Check values
        check_positive_integer(node_feature_dim, strict=True)
        check_positive_integer(edge_feature_dim, strict=True)

        # Activation function
        self.activation = activation()

        # Layer for processing node features
        self.node_layer = torch.nn.Linear(
            in_features=node_feature_dim,
            out_features=node_feature_dim,
            bias=True,
        )

        # Layer for processing edge features
        self.edge_layer = torch.nn.Linear(
            in_features=edge_feature_dim,
            out_features=node_feature_dim,
            bias=True,
        )

        # Layer for computing the message
        self.message_layer = torch.nn.Linear(
            in_features=node_feature_dim,
            out_features=node_feature_dim,
            bias=False,
        )

    def forward(self, x, edge_index, edge_attr):
        """
        Forward pass of the block, triggering the message-passing routine.

        :param x: The node features.
        :type x: torch.Tensor | LabelTensor
        :param torch.Tensor edge_index: The edge indices.
        :param edge_attr: The edge attributes.
        :type edge_attr: torch.Tensor | LabelTensor
        :return: The updated node features.
        :rtype: torch.Tensor
        """
        return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        """
        Compute the message to be passed between nodes and edges.

        :param x_j: The node features of the sender nodes.
        :type x_j: torch.Tensor | LabelTensor
        :param edge_attr: The edge attributes.
        :type edge_attr: torch.Tensor | LabelTensor
        :return: The message to be passed.
        :rtype: torch.Tensor
        """
        # Process node and edge features
        filter_node = self.node_layer(x_j)
        filter_edge = self.edge_layer(edge_attr)

        # Compute the message to be passed
        message = self.message_layer(filter_node * filter_edge)

        return self.activation(message)

    def update(self, message, x):
        """
        Update the node features with the received messages.

        :param torch.Tensor message: The message to be passed.
        :param x: The node features.
        :type x: torch.Tensor | LabelTensor
        :return: The updated node features.
        :rtype: torch.Tensor
        """
        return x + message
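
A shape-level sketch of the block in isolation (random data, illustrative sizes):

import torch
from pina.model.block.message_passing import DeepTensorNetworkBlock

block = DeepTensorNetworkBlock(node_feature_dim=16, edge_feature_dim=8)
x = torch.rand(10, 16)                      # 10 nodes, 16 features each
edge_index = torch.randint(0, 10, (2, 40))  # 40 random directed edges
edge_attr = torch.rand(40, 8)               # one 8-dim attribute per edge
out = block(x, edge_index, edge_attr)       # update step returns x + messages
assert out.shape == (10, 16)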
pina/model/block/message_passing/en_equivariant_network_block.py (new file, 229 lines)
@@ -0,0 +1,229 @@
"""Module for the E(n) Equivariant Graph Neural Network block."""

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree
from ....utils import check_positive_integer
from ....model import FeedForward


class EnEquivariantNetworkBlock(MessagePassing):
    """
    Implementation of the E(n) Equivariant Graph Neural Network block.
    This block is used to perform message-passing between nodes and edges in a
    graph neural network, following the scheme proposed by Satorras et al. in
    2021. It serves as an inner block in a larger graph neural network
    architecture.

    The message between two nodes connected by an edge is computed by applying
    a multi-layer perceptron (MLP) to the sender and recipient node features
    and the edge features, together with the squared euclidean distance between
    the sender and recipient node positions.
    Messages are then aggregated using an aggregation scheme (e.g., sum, mean,
    min, max, or product).

    The update step is performed by applying another MLP to the concatenation
    of the incoming messages and the node features. Here, also the node
    positions are updated by adding the incoming messages divided by the
    degree of the recipient node.

    .. seealso::

        **Original reference**: Satorras, V. G., Hoogeboom, E., Welling, M.
        (2021). *E(n) Equivariant Graph Neural Networks.*
        In International Conference on Machine Learning.
        DOI: `<https://doi.org/10.48550/arXiv.2102.09844>`_.
    """

    def __init__(
        self,
        node_feature_dim,
        edge_feature_dim,
        pos_dim,
        hidden_dim=64,
        n_message_layers=2,
        n_update_layers=2,
        activation=torch.nn.SiLU,
        aggr="add",
        node_dim=-2,
        flow="source_to_target",
    ):
        """
        Initialization of the :class:`EnEquivariantNetworkBlock` class.

        :param int node_feature_dim: The dimension of the node features.
        :param int edge_feature_dim: The dimension of the edge features.
        :param int pos_dim: The dimension of the position features.
        :param int hidden_dim: The dimension of the hidden features.
            Default is 64.
        :param int n_message_layers: The number of layers in the message
            network. Default is 2.
        :param int n_update_layers: The number of layers in the update network.
            Default is 2.
        :param torch.nn.Module activation: The activation function.
            Default is :class:`torch.nn.SiLU`.
        :param str aggr: The aggregation scheme to use for message passing.
            Available options are "add", "mean", "min", "max", "mul".
            See :class:`torch_geometric.nn.MessagePassing` for more details.
            Default is "add".
        :param int node_dim: The axis along which to propagate. Default is -2.
        :param str flow: The direction of message passing. Available options
            are "source_to_target" and "target_to_source".
            The "source_to_target" flow means that messages are sent from
            the source node to the target node, while the "target_to_source"
            flow means that messages are sent from the target node to the
            source node. See :class:`torch_geometric.nn.MessagePassing` for
            more details. Default is "source_to_target".
        :raises AssertionError: If `node_feature_dim` is not a positive integer.
        :raises AssertionError: If `edge_feature_dim` is a negative integer.
        :raises AssertionError: If `pos_dim` is not a positive integer.
        :raises AssertionError: If `hidden_dim` is not a positive integer.
        :raises AssertionError: If `n_message_layers` is not a positive integer.
        :raises AssertionError: If `n_update_layers` is not a positive integer.
        """
        super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)

        # Check values
        check_positive_integer(node_feature_dim, strict=True)
        check_positive_integer(edge_feature_dim, strict=False)
        check_positive_integer(pos_dim, strict=True)
        check_positive_integer(hidden_dim, strict=True)
        check_positive_integer(n_message_layers, strict=True)
        check_positive_integer(n_update_layers, strict=True)

        # Layer for computing the message
        self.message_net = FeedForward(
            input_dimensions=2 * node_feature_dim + edge_feature_dim + 1,
            output_dimensions=pos_dim,
            inner_size=hidden_dim,
            n_layers=n_message_layers,
            func=activation,
        )

        # Layer for updating the node features
        self.update_feat_net = FeedForward(
            input_dimensions=node_feature_dim + pos_dim,
            output_dimensions=node_feature_dim,
            inner_size=hidden_dim,
            n_layers=n_update_layers,
            func=activation,
        )

        # Layer for updating the node positions
        # The output dimension is set to 1 for equivariant updates
        self.update_pos_net = FeedForward(
            input_dimensions=pos_dim,
            output_dimensions=1,
            inner_size=hidden_dim,
            n_layers=n_update_layers,
            func=activation,
        )

    def forward(self, x, pos, edge_index, edge_attr=None):
        """
        Forward pass of the block, triggering the message-passing routine.

        :param x: The node features.
        :type x: torch.Tensor | LabelTensor
        :param pos: The euclidean coordinates of the nodes.
        :type pos: torch.Tensor | LabelTensor
        :param torch.Tensor edge_index: The edge indices.
        :param edge_attr: The edge attributes. Default is None.
        :type edge_attr: torch.Tensor | LabelTensor
        :return: The updated node features and node positions.
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        return self.propagate(
            edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr
        )

    def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
        """
        Compute the message to be passed between nodes and edges.

        :param x_i: The node features of the recipient nodes.
        :type x_i: torch.Tensor | LabelTensor
        :param x_j: The node features of the sender nodes.
        :type x_j: torch.Tensor | LabelTensor
        :param pos_i: The node coordinates of the recipient nodes.
        :type pos_i: torch.Tensor | LabelTensor
        :param pos_j: The node coordinates of the sender nodes.
        :type pos_j: torch.Tensor | LabelTensor
        :param edge_attr: The edge attributes.
        :type edge_attr: torch.Tensor | LabelTensor
        :return: The message to be passed.
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        # Compute the squared euclidean distance between the sender and
        # recipient nodes
        diff = pos_i - pos_j
        dist = torch.norm(diff, dim=-1, keepdim=True) ** 2

        # Compute the message input
        if edge_attr is None:
            input_ = torch.cat((x_i, x_j, dist), dim=-1)
        else:
            input_ = torch.cat((x_i, x_j, dist, edge_attr), dim=-1)

        # Compute the messages and their equivariant counterpart
        m_ij = self.message_net(input_)
        message = diff * self.update_pos_net(m_ij)

        return message, m_ij

    def aggregate(self, inputs, index, ptr=None, dim_size=None):
        """
        Aggregate the messages at the nodes during message passing.

        This method receives a tuple of tensors corresponding to the messages
        to be aggregated. Both messages are aggregated separately according to
        the specified aggregation scheme.

        :param tuple(torch.Tensor) inputs: Tuple containing two messages to
            aggregate.
        :param index: The indices of target nodes for each message. This tensor
            specifies which node each message is aggregated into.
        :type index: torch.Tensor | LabelTensor
        :param ptr: Optional tensor to specify the slices of messages for each
            node (used in some aggregation strategies). Default is None.
        :type ptr: torch.Tensor | LabelTensor
        :param int dim_size: Optional size of the output dimension, i.e.,
            number of nodes. Default is None.
        :return: Tuple of aggregated tensors corresponding to (aggregated
            messages for position updates, aggregated messages for feature
            updates).
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        # Unpack the messages from the inputs
        message, m_ij = inputs

        # Aggregate messages as usual using self.aggr method
        agg_message = super().aggregate(message, index, ptr, dim_size)
        agg_m_ij = super().aggregate(m_ij, index, ptr, dim_size)

        return agg_message, agg_m_ij

    def update(self, aggregated_inputs, x, pos, edge_index):
        """
        Update the node features and the node coordinates with the received
        messages.

        :param tuple(torch.Tensor) aggregated_inputs: The messages to be passed.
        :param x: The node features.
        :type x: torch.Tensor | LabelTensor
        :param pos: The euclidean coordinates of the nodes.
        :type pos: torch.Tensor | LabelTensor
        :param torch.Tensor edge_index: The edge indices.
        :return: The updated node features and node positions.
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        # aggregated_inputs is the tuple (agg_message, agg_m_ij)
        agg_message, agg_m_ij = aggregated_inputs

        # Update node features with aggregated m_ij
        x = self.update_feat_net(torch.cat((x, agg_m_ij), dim=-1))

        # Degree for normalization of position updates
        c = degree(edge_index[1], pos.shape[0]).unsqueeze(-1).clamp(min=1)
        pos = pos + agg_message / c

        return x, pos
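
Since equivariance is the point of this block, a small sanity-check sketch of the property (random data; `Q` is a random orthogonal matrix, so the check covers rotations and reflections):

import torch
from pina.model.block.message_passing import EnEquivariantNetworkBlock

block = EnEquivariantNetworkBlock(
    node_feature_dim=8, edge_feature_dim=0, pos_dim=3
)
x = torch.rand(10, 8)
pos = torch.rand(10, 3)
edge_index = torch.randint(0, 10, (2, 30))

Q = torch.linalg.qr(torch.rand(3, 3)).Q         # random orthogonal matrix
h1, p1 = block(x, pos, edge_index)
h2, p2 = block(x, pos @ Q.T, edge_index)
assert torch.allclose(h1, h2, atol=1e-5)        # features are E(n)-invariant
assert torch.allclose(p1 @ Q.T, p2, atol=1e-5)  # positions are equivariant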
pina/model/block/message_passing/interaction_network_block.py (new file, 149 lines)
@@ -0,0 +1,149 @@
"""Module for the Interaction Network block."""

import torch
from torch_geometric.nn import MessagePassing
from ....utils import check_positive_integer
from ....model import FeedForward


class InteractionNetworkBlock(MessagePassing):
    """
    Implementation of the Interaction Network block.

    This block is used to perform message-passing between nodes and edges in a
    graph neural network, following the scheme proposed by Battaglia et al. in
    2016. It serves as an inner block in a larger graph neural network
    architecture.

    The message between two nodes connected by an edge is computed by applying
    a multi-layer perceptron (MLP) to the concatenation of the sender and
    recipient node features. Messages are then aggregated using an aggregation
    scheme (e.g., sum, mean, min, max, or product).

    The update step is performed by applying another MLP to the concatenation
    of the incoming messages and the node features.

    .. seealso::

        **Original reference**: Battaglia, P. W., et al. (2016).
        *Interaction Networks for Learning about Objects, Relations and
        Physics*.
        In Advances in Neural Information Processing Systems (NeurIPS 2016).
        DOI: `<https://doi.org/10.48550/arXiv.1612.00222>`_.
    """

    def __init__(
        self,
        node_feature_dim,
        edge_feature_dim=0,
        hidden_dim=64,
        n_message_layers=2,
        n_update_layers=2,
        activation=torch.nn.SiLU,
        aggr="add",
        node_dim=-2,
        flow="source_to_target",
    ):
        """
        Initialization of the :class:`InteractionNetworkBlock` class.

        :param int node_feature_dim: The dimension of the node features.
        :param int edge_feature_dim: The dimension of the edge features.
            If edge_attr is not provided, it is assumed to be 0.
            Default is 0.
        :param int hidden_dim: The dimension of the hidden features.
            Default is 64.
        :param int n_message_layers: The number of layers in the message
            network. Default is 2.
        :param int n_update_layers: The number of layers in the update network.
            Default is 2.
        :param torch.nn.Module activation: The activation function.
            Default is :class:`torch.nn.SiLU`.
        :param str aggr: The aggregation scheme to use for message passing.
            Available options are "add", "mean", "min", "max", "mul".
            See :class:`torch_geometric.nn.MessagePassing` for more details.
            Default is "add".
        :param int node_dim: The axis along which to propagate. Default is -2.
        :param str flow: The direction of message passing. Available options
            are "source_to_target" and "target_to_source".
            The "source_to_target" flow means that messages are sent from
            the source node to the target node, while the "target_to_source"
            flow means that messages are sent from the target node to the
            source node. See :class:`torch_geometric.nn.MessagePassing` for
            more details. Default is "source_to_target".
        :raises AssertionError: If `node_feature_dim` is not a positive integer.
        :raises AssertionError: If `hidden_dim` is not a positive integer.
        :raises AssertionError: If `n_message_layers` is not a positive integer.
        :raises AssertionError: If `n_update_layers` is not a positive integer.
        :raises AssertionError: If `edge_feature_dim` is not a non-negative
            integer.
        """
        super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)

        # Check values
        check_positive_integer(node_feature_dim, strict=True)
        check_positive_integer(hidden_dim, strict=True)
        check_positive_integer(n_message_layers, strict=True)
        check_positive_integer(n_update_layers, strict=True)
        check_positive_integer(edge_feature_dim, strict=False)

        # Message network
        self.message_net = FeedForward(
            input_dimensions=2 * node_feature_dim + edge_feature_dim,
            output_dimensions=hidden_dim,
            inner_size=hidden_dim,
            n_layers=n_message_layers,
            func=activation,
        )

        # Update network
        self.update_net = FeedForward(
            input_dimensions=node_feature_dim + hidden_dim,
            output_dimensions=node_feature_dim,
            inner_size=hidden_dim,
            n_layers=n_update_layers,
            func=activation,
        )

    def forward(self, x, edge_index, edge_attr=None):
        """
        Forward pass of the block, triggering the message-passing routine.

        :param x: The node features.
        :type x: torch.Tensor | LabelTensor
        :param torch.Tensor edge_index: The edge indices.
        :param edge_attr: The edge attributes. Default is None.
        :type edge_attr: torch.Tensor | LabelTensor
        :return: The updated node features.
        :rtype: torch.Tensor
        """
        return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        """
        Compute the message to be passed between nodes and edges.

        :param x_i: The node features of the recipient nodes.
        :type x_i: torch.Tensor | LabelTensor
        :param x_j: The node features of the sender nodes.
        :type x_j: torch.Tensor | LabelTensor
        :param edge_attr: The edge attributes.
        :type edge_attr: torch.Tensor | LabelTensor
        :return: The message to be passed.
        :rtype: torch.Tensor
        """
        if edge_attr is None:
            input_ = torch.cat((x_i, x_j), dim=-1)
        else:
            input_ = torch.cat((x_i, x_j, edge_attr), dim=-1)
        return self.message_net(input_)

    def update(self, message, x):
        """
        Update the node features with the received messages.

        :param torch.Tensor message: The message to be passed.
        :param x: The node features.
        :type x: torch.Tensor | LabelTensor
        :return: The updated node features.
        :rtype: torch.Tensor
        """
        return self.update_net(torch.cat((x, message), dim=-1))
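
A minimal usage sketch; edge attributes are optional and omitted here, matching the `edge_feature_dim=0` default:

import torch
from pina.model.block.message_passing import InteractionNetworkBlock

block = InteractionNetworkBlock(node_feature_dim=4, hidden_dim=32)
x = torch.rand(5, 4)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])  # a directed chain
out = block(x, edge_index)  # edge_attr defaults to None
assert out.shape == (5, 4)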
pina/model/block/message_passing/radial_field_network_block.py (new file, 126 lines)
@@ -0,0 +1,126 @@
"""Module for the Radial Field Network block."""

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops
from ....utils import check_positive_integer
from ....model import FeedForward


class RadialFieldNetworkBlock(MessagePassing):
    """
    Implementation of the Radial Field Network block.

    This block is used to perform message-passing between nodes and edges in a
    graph neural network, following the scheme proposed by Köhler et al. in
    2020. It serves as an inner block in a larger graph neural network
    architecture.

    The message between two nodes connected by an edge is computed by applying
    a linear transformation to the norm of the difference between the sender
    and recipient node features, together with the radial distance between the
    sender and recipient node features, followed by a non-linear activation
    function. Messages are then aggregated using an aggregation scheme
    (e.g., sum, mean, min, max, or product).

    The update step is performed by a simple addition of the incoming messages
    to the node features.

    .. seealso::

        **Original reference**: Köhler, J., Klein, L., Noé, F. (2020).
        *Equivariant Flows: Exact Likelihood Generative Learning for Symmetric
        Densities*.
        In International Conference on Machine Learning.
        DOI: `<https://doi.org/10.48550/arXiv.2006.02425>`_.
    """

    def __init__(
        self,
        node_feature_dim,
        hidden_dim=64,
        n_layers=2,
        activation=torch.nn.Tanh,
        aggr="add",
        node_dim=-2,
        flow="source_to_target",
    ):
        """
        Initialization of the :class:`RadialFieldNetworkBlock` class.

        :param int node_feature_dim: The dimension of the node features.
        :param int hidden_dim: The dimension of the hidden features.
            Default is 64.
        :param int n_layers: The number of layers in the network. Default is 2.
        :param torch.nn.Module activation: The activation function.
            Default is :class:`torch.nn.Tanh`.
        :param str aggr: The aggregation scheme to use for message passing.
            Available options are "add", "mean", "min", "max", "mul".
            See :class:`torch_geometric.nn.MessagePassing` for more details.
            Default is "add".
        :param int node_dim: The axis along which to propagate. Default is -2.
        :param str flow: The direction of message passing. Available options
            are "source_to_target" and "target_to_source".
            The "source_to_target" flow means that messages are sent from
            the source node to the target node, while the "target_to_source"
            flow means that messages are sent from the target node to the
            source node. See :class:`torch_geometric.nn.MessagePassing` for
            more details. Default is "source_to_target".
        :raises AssertionError: If `node_feature_dim` is not a positive integer.
        :raises AssertionError: If `hidden_dim` is not a positive integer.
        :raises AssertionError: If `n_layers` is not a positive integer.
        """
        super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)

        # Check values
        check_positive_integer(node_feature_dim, strict=True)
        check_positive_integer(hidden_dim, strict=True)
        check_positive_integer(n_layers, strict=True)

        # Layer for processing node features
        self.radial_net = FeedForward(
            input_dimensions=1,
            output_dimensions=1,
            inner_size=hidden_dim,
            n_layers=n_layers,
            func=activation,
        )

    def forward(self, x, edge_index):
        """
        Forward pass of the block, triggering the message-passing routine.

        :param x: The node features.
        :type x: torch.Tensor | LabelTensor
        :param torch.Tensor edge_index: The edge indices.
        :return: The updated node features.
        :rtype: torch.Tensor
        """
        edge_index, _ = remove_self_loops(edge_index)
        return self.propagate(edge_index=edge_index, x=x)

    def message(self, x_i, x_j):
        """
        Compute the message to be passed between nodes and edges.

        :param x_i: The node features of the recipient nodes.
        :type x_i: torch.Tensor | LabelTensor
        :param x_j: The node features of the sender nodes.
        :type x_j: torch.Tensor | LabelTensor
        :return: The message to be passed.
        :rtype: torch.Tensor
        """
        r = x_i - x_j
        return self.radial_net(torch.norm(r, dim=1, keepdim=True)) * r

    def update(self, message, x):
        """
        Update the node features with the received messages.

        :param torch.Tensor message: The message to be passed.
        :param x: The node features.
        :type x: torch.Tensor | LabelTensor
        :return: The updated node features.
        :rtype: torch.Tensor
        """
        return x + message
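
A minimal usage sketch; the block acts directly on vector node features, e.g. particle positions:

import torch
from pina.model.block.message_passing import RadialFieldNetworkBlock

block = RadialFieldNetworkBlock(node_feature_dim=3)
x = torch.rand(6, 3)                       # e.g. node positions in 3D
edge_index = torch.randint(0, 6, (2, 12))  # self-loops are removed internally
out = block(x, edge_index)                 # update step returns x + messages
assert out.shape == x.shape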
|
||||
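For context, a minimal usage sketch of the block above (not part of the upstream diff); it mirrors the shapes used in the new message-passing tests further down and assumes the import path introduced by this PR:

    import torch
    from pina.model.block.message_passing import RadialFieldNetworkBlock

    # Toy graph: 10 nodes with 3 features each, 20 random directed edges.
    x = torch.rand(10, 3)
    edge_index = torch.randint(0, 10, (2, 20))

    block = RadialFieldNetworkBlock(node_feature_dim=3, hidden_dim=64, n_layers=2)
    out = block(x=x, edge_index=edge_index)  # same shape as x, i.e. (10, 3)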
@@ -275,44 +275,42 @@ def fast_laplacian(output_, input_, components, d, method="std"):
def fast_advection(output_, input_, velocity_field, components, d):
    """
    Perform the advection operation on the ``output_`` with respect to the
    ``input``. This operator support vector-valued functions with multiple input
    coordinates.
    ``input``. This operator supports vector-valued functions with multiple
    input coordinates.

    Unlike ``advection``, this function performs no internal checks on input and
    output tensors. The user is required to specify both ``components`` and
    ``d`` as lists of strings. It is designed to enhance computation speed.

    :param LabelTensor output_: The output tensor on which the advection is
        computed.
        computed. It includes both the velocity and the quantity to be advected.
    :param LabelTensor input_: the input tensor with respect to which advection
        is computed.
    :param str velocity_field: The name of the output variable used as velocity
        field. It must be chosen among the output labels.
    :param list[str] velocity_field: The name of the output variables used as
        velocity field. It must be chosen among the output labels.
    :param list[str] components: The names of the output variables for which to
        compute the advection. It must be a subset of the output labels.
    :param list[str] d: The names of the input variables with respect to which
        the advection is computed. It must be a subset of the input labels.
    :return: The computed advection tensor.
    :rtype: torch.Tensor
    :rtype: LabelTensor
    """
    # Add a dimension to the velocity field for following operations
    velocity = output_.extract(velocity_field).unsqueeze(-1)

    # Remove the velocity field from the components
    filter_components = [c for c in components if c != velocity_field]

    # Compute the gradient
    grads = fast_grad(
        output_=output_, input_=input_, components=filter_components, d=d
        output_=output_, input_=input_, components=components, d=d
    )

    # Reshape into [..., len(filter_components), len(d)]
    tmp = grads.reshape(*output_.shape[:-1], len(filter_components), len(d))
    tmp = grads.reshape(*output_.shape[:-1], len(components), len(d))

    # Transpose to [..., len(d), len(filter_components)]
    tmp = tmp.transpose(-1, -2)

    return (tmp * velocity).sum(dim=tmp.tensor.ndim - 2)
    adv = (tmp * velocity).sum(dim=tmp.tensor.ndim - 2)
    return LabelTensor(adv, labels=[f"adv_{c}" for c in components])


def grad(output_, input_, components=None, d=None):
@@ -425,15 +423,16 @@ def laplacian(output_, input_, components=None, d=None, method="std"):

def advection(output_, input_, velocity_field, components=None, d=None):
    """
    Perform the advection operation on the ``output_`` with respect to the
    ``input``. This operator support vector-valued functions with multiple input
    coordinates.
    ``input``. This operator supports vector-valued functions with multiple
    input coordinates.

    :param LabelTensor output_: The output tensor on which the advection is
        computed.
        computed. It includes both the velocity and the quantity to be advected.
    :param LabelTensor input_: the input tensor with respect to which advection
        is computed.
    :param str velocity_field: The name of the output variable used as velocity
    :param velocity_field: The name of the output variables used as velocity
        field. It must be chosen among the output labels.
    :type velocity_field: str | list[str]
    :param components: The names of the output variables for which to compute
        the advection. It must be a subset of the output labels.
        If ``None``, all output variables are considered. Default is ``None``.
@@ -444,18 +443,29 @@ def advection(output_, input_, velocity_field, components=None, d=None):
    :type d: str | list[str]
    :raises TypeError: If the input tensor is not a LabelTensor.
    :raises TypeError: If the output tensor is not a LabelTensor.
    :raises RuntimeError: If the velocity field is not in the output labels.
    :raises RuntimeError: If the velocity field is not a subset of the output
        labels.
    :raises RuntimeError: If the dimensionality of the velocity field does not
        match that of the input tensor.
    :return: The computed advection tensor.
    :rtype: torch.Tensor
    :rtype: LabelTensor
    """
    components, d = _check_values(
        output_=output_, input_=input_, components=components, d=d
    )

    # Check if velocity field is present in the output labels
    if velocity_field not in output_.labels:
    # Map velocity_field to a list if it is a string
    if isinstance(velocity_field, str):
        velocity_field = [velocity_field]

    # Check if all the velocity_field labels are present in the output labels
    if not all(vi in output_.labels for vi in velocity_field):
        raise RuntimeError("Velocity labels missing from output tensor.")

    # Check if the velocity has the same dimensionality as the input tensor
    if len(velocity_field) != len(d):
        raise RuntimeError(
            f"Velocity {velocity_field} is not present in the output labels."
            "Velocity dimensionality does not match input dimensionality."
        )

    return fast_advection(
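For context, a minimal sketch of the updated ``advection`` call (not part of the diff), mirroring ``test_advection_scalar`` in the test changes below; the velocity components and the advected quantity live in the same output ``LabelTensor``:

    import torch
    from pina import LabelTensor
    from pina.operator import advection

    # 3D coordinates, a random velocity field (ux, uy, uz), and a scalar field c.
    input_ = LabelTensor(torch.rand((20, 3), requires_grad=True), ["x", "y", "z"])
    velocity = torch.rand((20, 3))
    field = torch.sum(input_**2, dim=-1, keepdim=True)
    output_ = LabelTensor(torch.cat((velocity, field), dim=1), ["ux", "uy", "uz", "c"])

    adv = advection(
        output_=output_,
        input_=input_,
        velocity_field=["ux", "uy", "uz"],
        components=["c"],
        d=["x", "y", "z"],
    )
    # adv is a LabelTensor with labels ["adv_c"] and shape (20, 1).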
@@ -1,12 +1,13 @@
"""Module for the AbstractProblem class."""

from abc import ABCMeta, abstractmethod
import warnings
from copy import deepcopy
from ..utils import check_consistency
from ..domain import DomainInterface, CartesianDomain
from ..condition.domain_equation_condition import DomainEquationCondition
from ..label_tensor import LabelTensor
from ..utils import merge_tensors
from ..utils import merge_tensors, custom_warning_format


class AbstractProblem(metaclass=ABCMeta):
@@ -23,14 +24,11 @@ class AbstractProblem(metaclass=ABCMeta):
        Initialization of the :class:`AbstractProblem` class.
        """
        self._discretised_domains = {}
        # create collector to manage problem data

        # create hook conditions <-> problems
        for condition_name in self.conditions:
            self.conditions[condition_name].problem = self

        self._batching_dimension = 0

        # Store in domains dict all the domains object directly passed to
        # ConditionInterface. Done for back compatibility with PINA <0.2
        if not hasattr(self, "domains"):
@@ -41,41 +39,57 @@ class AbstractProblem(metaclass=ABCMeta):
                self.domains[cond_name] = cond.domain
                cond.domain = cond_name

        self._collected_data = {}

    @property
    def batching_dimension(self):
    def collected_data(self):
        """
        Get batching dimension.
        Return the collected data from the problem's conditions. If some domains
        are not sampled, they will not be returned by collected data.

        :return: The batching dimension.
        :rtype: int
        :return: The collected data. Keys are condition names, and values are
            dictionaries containing the input points and the corresponding
            equations or target points.
        :rtype: dict
        """
        return self._batching_dimension

    @batching_dimension.setter
    def batching_dimension(self, value):
        """
        Set the batching dimension.

        :param int value: The batching dimension.
        """
        self._batching_dimension = value
        # collect data so far
        self.collect_data()
        # raise warning if some sample data are missing
        if not self.are_all_domains_discretised:
            warnings.formatwarning = custom_warning_format
            warnings.filterwarnings("always", category=RuntimeWarning)
            warning_message = "\n".join(
                [
                    f"""{" " * 13} ---> Domain {key} {
                        "sampled" if key in self.discretised_domains
                        else
                        "not sampled"}"""
                    for key in self.domains
                ]
            )
            warnings.warn(
                "Some of the domains are still not sampled. Consider calling "
                "problem.discretise_domain function for all domains before "
                "accessing the collected data:\n"
                f"{warning_message}",
                RuntimeWarning,
            )
        return self._collected_data

    # back compatibility 0.1
    @property
    def input_pts(self):
        """
        Return a dictionary mapping condition names to their corresponding
        input points.
        input points. If some domains are not sampled, they will not be returned
        and the corresponding condition will be empty.

        :return: The input points of the problem.
        :rtype: dict
        """
        to_return = {}
        for cond_name, cond in self.conditions.items():
            if hasattr(cond, "input"):
                to_return[cond_name] = cond.input
            elif hasattr(cond, "domain"):
                to_return[cond_name] = self._discretised_domains[cond.domain]
        for cond_name, data in self.collected_data.items():
            to_return[cond_name] = data["input"]
        return to_return

    @property
@@ -300,3 +314,29 @@ class AbstractProblem(metaclass=ABCMeta):
            self.discretised_domains[k] = LabelTensor.vstack(
                [self.discretised_domains[k], v]
            )

    def collect_data(self):
        """
        Aggregate data from the problem's conditions into a single dictionary.
        """
        data = {}
        # Iterate over the conditions and collect data
        for condition_name in self.conditions:
            condition = self.conditions[condition_name]
            # Check if the condition has a domain attribute
            if hasattr(condition, "domain"):
                # Only store the discretisation points if the domain is
                # in the dictionary
                if condition.domain in self.discretised_domains:
                    samples = self.discretised_domains[condition.domain]
                    data[condition_name] = {
                        "input": samples,
                        "equation": condition.equation,
                    }
            else:
                # If the condition does not have a domain attribute, store
                # the input and target points
                keys = condition.__slots__
                values = [getattr(condition, name) for name in keys]
                data[condition_name] = dict(zip(keys, values))
        self._collected_data = data
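For context (not part of the diff): with the collector removed, data collection now lives on the problem itself. A minimal sketch, mirroring ``test_aggregate_data`` in the test changes below:

    from pina.problem.zoo import Poisson2DSquareProblem as Poisson

    problem = Poisson()
    problem.discretise_domain(10, "grid")  # sample all domains first
    data = problem.collected_data          # calls collect_data() internally
    # data maps each condition name to {"input": ..., "equation": ...} for
    # domain/equation conditions, or {"input": ..., "target": ...} for
    # supervised conditions.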
pina/type_checker.py (new file)
@@ -0,0 +1,93 @@
"""Module for enforcing type hints in Python functions."""

import inspect
import typing
import logging


def enforce_types(func):
    """
    Function decorator to enforce type hints at runtime.

    This decorator checks the types of the arguments and of the return value of
    the decorated function against the type hints specified in the function
    signature. If the types do not match, a TypeError is raised.
    Type checking is only performed when the logging level is set to `DEBUG`.

    :param Callable func: The function to be decorated.
    :return: The decorated function with enforced type hints.
    :rtype: Callable

    :Example:

    >>> @enforce_types
    ... def dummy_function(a: int, b: float) -> float:
    ...     return a+b

    # This always works.
    dummy_function(1, 2.0)

    # This raises a TypeError for the second argument, if logging is set to
    # `DEBUG`.
    dummy_function(1, "Hello, world!")


    >>> @enforce_types
    ... def dummy_function2(a: int, right: bool) -> float:
    ...     if right:
    ...         return float(a)
    ...     else:
    ...         return "Hello, world!"

    # This always works.
    dummy_function2(1, right=True)

    # This raises a TypeError for the return value if logging is set to
    # `DEBUG`.
    dummy_function2(1, right=False)
    """

    def wrapper(*args, **kwargs):
        """
        Wrapper function to enforce type hints.

        :param tuple args: Positional arguments passed to the function.
        :param dict kwargs: Keyword arguments passed to the function.
        :raises TypeError: If the argument or return type does not match the
            specified type hints.
        :return: The result of the decorated function.
        :rtype: Any
        """
        level = logging.getLevelName(logging.getLogger().getEffectiveLevel())

        # Enforce type hints only in debug mode
        if level != "DEBUG":
            return func(*args, **kwargs)

        # Get the type hints for the function arguments
        hints = typing.get_type_hints(func)
        sig = inspect.signature(func)
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()

        for arg_name, arg_value in bound.arguments.items():
            expected_type = hints.get(arg_name)
            if expected_type and not isinstance(arg_value, expected_type):
                raise TypeError(
                    f"Argument '{arg_name}' must be {expected_type.__name__}, "
                    f"but got {type(arg_value).__name__}!"
                )

        # Get the type hints for the return values
        return_type = hints.get("return")
        result = func(*args, **kwargs)

        if return_type and not isinstance(result, return_type):
            raise TypeError(
                f"Return value must be {return_type.__name__}, "
                f"but got {type(result).__name__}!"
            )

        return result

    return wrapper
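For context (not part of the diff), a minimal sketch of the decorator in use; the checks are gated on the root logger level, so they cost nothing unless the level is DEBUG. The function `scale` below is a hypothetical example, not part of the library:

    import logging
    from pina.type_checker import enforce_types

    @enforce_types
    def scale(a: int, factor: float) -> float:  # hypothetical example function
        return a * factor

    logging.getLogger().setLevel(logging.DEBUG)  # enable runtime type checks
    scale(2, 1.5)   # passes
    scale(2, "no")  # raises TypeError while the level is DEBUG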
@@ -193,3 +193,22 @@ def chebyshev_roots(n):
    k = torch.arange(n)
    nodes = torch.sort(torch.cos(pi * (k + 0.5) / n))[0]
    return nodes


def check_positive_integer(value, strict=True):
    """
    Check if the value is a positive integer.

    :param int value: The value to check.
    :param bool strict: If True, the value must be strictly positive.
        Default is True.
    :raises AssertionError: If the value is not a positive integer.
    """
    if strict:
        assert (
            isinstance(value, int) and value > 0
        ), f"Expected a strictly positive integer, got {value}."
    else:
        assert (
            isinstance(value, int) and value >= 0
        ), f"Expected a non-negative integer, got {value}."
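For context (not part of the diff), the two modes in short:

    from pina.utils import check_positive_integer

    check_positive_integer(3)                # passes: strictly positive
    check_positive_integer(0, strict=False)  # passes: zero allowed when strict=False
    check_positive_integer(0)                # AssertionError: strict mode needs value > 0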
@@ -1,45 +1,58 @@
import pytest

from torch.nn import MSELoss

from pina.solver import PINN
from pina.trainer import Trainer
from pina.model import FeedForward
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
from pina.callback import R3Refinement
from pina.callback.refinement import R3Refinement


# make the problem
poisson_problem = Poisson()
boundaries = ["g1", "g2", "g3", "g4"]
n = 10
poisson_problem.discretise_domain(n, "grid", domains=boundaries)
poisson_problem.discretise_domain(n, "grid", domains="D")
poisson_problem.discretise_domain(10, "grid", domains=["g1", "g2", "g3", "g4"])
poisson_problem.discretise_domain(10, "grid", domains="D")
model = FeedForward(
    len(poisson_problem.input_variables), len(poisson_problem.output_variables)
)

# make the solver
solver = PINN(problem=poisson_problem, model=model)


# def test_r3constructor():
#     R3Refinement(sample_every=10)
def test_constructor():
    # good constructor
    R3Refinement(sample_every=10)
    R3Refinement(sample_every=10, residual_loss=MSELoss)
    R3Refinement(sample_every=10, condition_to_update=["D"])
    # wrong constructor
    with pytest.raises(ValueError):
        R3Refinement(sample_every="str")
    with pytest.raises(ValueError):
        R3Refinement(sample_every=10, condition_to_update=3)


# def test_r3refinment_routine():
#     # make the trainer
#     trainer = Trainer(solver=solver,
#                       callback=[R3Refinement(sample_every=1)],
#                       accelerator='cpu',
#                       max_epochs=5)
#     trainer.train()

# def test_r3refinment_routine():
#     model = FeedForward(len(poisson_problem.input_variables),
#                         len(poisson_problem.output_variables))
#     solver = PINN(problem=poisson_problem, model=model)
#     trainer = Trainer(solver=solver,
#                       callback=[R3Refinement(sample_every=1)],
#                       accelerator='cpu',
#                       max_epochs=5)
#     before_n_points = {loc : len(pts) for loc, pts in trainer.solver.problem.input_pts.items()}
#     trainer.train()
#     after_n_points = {loc : len(pts) for loc, pts in trainer.solver.problem.input_pts.items()}
#     assert before_n_points == after_n_points
@pytest.mark.parametrize(
    "condition_to_update", [["D", "g1"], ["D", "g1", "g2", "g3", "g4"]]
)
def test_sample(condition_to_update):
    trainer = Trainer(
        solver=solver,
        callbacks=[
            R3Refinement(
                sample_every=1, condition_to_update=condition_to_update
            )
        ],
        accelerator="cpu",
        max_epochs=5,
    )
    before_n_points = {
        loc: len(trainer.solver.problem.input_pts[loc])
        for loc in condition_to_update
    }
    trainer.train()
    after_n_points = {
        loc: len(trainer.data_module.train_dataset.input[loc])
        for loc in condition_to_update
    }
    assert before_n_points == trainer.callbacks[0].initial_population_size
    assert before_n_points == after_n_points
@@ -1,135 +0,0 @@
import torch
import pytest
from pina import Condition, LabelTensor, Graph
from pina.condition import InputTargetCondition, DomainEquationCondition
from pina.graph import RadiusGraph
from pina.problem import AbstractProblem, SpatialProblem
from pina.domain import CartesianDomain
from pina.equation.equation import Equation
from pina.equation.equation_factory import FixedValue
from pina.operator import laplacian
from pina.collector import Collector


def test_supervised_tensor_collector():
    class SupervisedProblem(AbstractProblem):
        output_variables = None
        conditions = {
            "data1": Condition(
                input=torch.rand((10, 2)),
                target=torch.rand((10, 2)),
            ),
            "data2": Condition(
                input=torch.rand((20, 2)),
                target=torch.rand((20, 2)),
            ),
            "data3": Condition(
                input=torch.rand((30, 2)),
                target=torch.rand((30, 2)),
            ),
        }

    problem = SupervisedProblem()
    collector = Collector(problem)
    for v in collector.conditions_name.values():
        assert v in problem.conditions.keys()


def test_pinn_collector():
    def laplace_equation(input_, output_):
        force_term = torch.sin(input_.extract(["x"]) * torch.pi) * torch.sin(
            input_.extract(["y"]) * torch.pi
        )
        delta_u = laplacian(output_.extract(["u"]), input_)
        return delta_u - force_term

    my_laplace = Equation(laplace_equation)
    in_ = LabelTensor(
        torch.tensor([[0.0, 1.0]], requires_grad=True), ["x", "y"]
    )
    out_ = LabelTensor(torch.tensor([[0.0]], requires_grad=True), ["u"])

    class Poisson(SpatialProblem):
        output_variables = ["u"]
        spatial_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]})

        conditions = {
            "gamma1": Condition(
                domain=CartesianDomain({"x": [0, 1], "y": 1}),
                equation=FixedValue(0.0),
            ),
            "gamma2": Condition(
                domain=CartesianDomain({"x": [0, 1], "y": 0}),
                equation=FixedValue(0.0),
            ),
            "gamma3": Condition(
                domain=CartesianDomain({"x": 1, "y": [0, 1]}),
                equation=FixedValue(0.0),
            ),
            "gamma4": Condition(
                domain=CartesianDomain({"x": 0, "y": [0, 1]}),
                equation=FixedValue(0.0),
            ),
            "D": Condition(
                domain=CartesianDomain({"x": [0, 1], "y": [0, 1]}),
                equation=my_laplace,
            ),
            "data": Condition(input=in_, target=out_),
        }

        def poisson_sol(self, pts):
            return -(
                torch.sin(pts.extract(["x"]) * torch.pi)
                * torch.sin(pts.extract(["y"]) * torch.pi)
            ) / (2 * torch.pi**2)

        truth_solution = poisson_sol

    problem = Poisson()
    boundaries = ["gamma1", "gamma2", "gamma3", "gamma4"]
    problem.discretise_domain(10, "grid", domains=boundaries)
    problem.discretise_domain(10, "grid", domains="D")

    collector = Collector(problem)
    collector.store_fixed_data()
    collector.store_sample_domains()

    for k, v in problem.conditions.items():
        if isinstance(v, InputTargetCondition):
            assert list(collector.data_collections[k].keys()) == [
                "input",
                "target",
            ]

    for k, v in problem.conditions.items():
        if isinstance(v, DomainEquationCondition):
            assert list(collector.data_collections[k].keys()) == [
                "input",
                "equation",
            ]


def test_supervised_graph_collector():
    pos = torch.rand((100, 3))
    x = [torch.rand((100, 3)) for _ in range(10)]
    graph_list_1 = [RadiusGraph(pos=pos, radius=0.4, x=x_) for x_ in x]
    out_1 = torch.rand((10, 100, 3))

    pos = torch.rand((50, 3))
    x = [torch.rand((50, 3)) for _ in range(10)]
    graph_list_2 = [RadiusGraph(pos=pos, radius=0.4, x=x_) for x_ in x]
    out_2 = torch.rand((10, 50, 3))

    class SupervisedProblem(AbstractProblem):
        output_variables = None
        conditions = {
            "data1": Condition(input=graph_list_1, target=out_1),
            "data2": Condition(input=graph_list_2, target=out_2),
        }

    problem = SupervisedProblem()
    collector = Collector(problem)
    collector.store_fixed_data()
    # assert all(collector._is_conditions_ready.values())
    for v in collector.conditions_name.values():
        assert v in problem.conditions.keys()
tests/test_messagepassing/test_deep_tensor_network_block.py (new file)
@@ -0,0 +1,59 @@
import pytest
import torch
from pina.model.block.message_passing import DeepTensorNetworkBlock

# Data for testing
x = torch.rand(10, 3)
edge_index = torch.randint(0, 10, (2, 20))
edge_attr = torch.randn(20, 2)


@pytest.mark.parametrize("node_feature_dim", [1, 3])
@pytest.mark.parametrize("edge_feature_dim", [3, 5])
def test_constructor(node_feature_dim, edge_feature_dim):

    DeepTensorNetworkBlock(
        node_feature_dim=node_feature_dim,
        edge_feature_dim=edge_feature_dim,
    )

    # Should fail if node_feature_dim is negative
    with pytest.raises(AssertionError):
        DeepTensorNetworkBlock(
            node_feature_dim=-1, edge_feature_dim=edge_feature_dim
        )

    # Should fail if edge_feature_dim is negative
    with pytest.raises(AssertionError):
        DeepTensorNetworkBlock(
            node_feature_dim=node_feature_dim, edge_feature_dim=-1
        )


def test_forward():

    model = DeepTensorNetworkBlock(
        node_feature_dim=x.shape[1],
        edge_feature_dim=edge_attr.shape[1],
    )

    output_ = model(edge_index=edge_index, x=x, edge_attr=edge_attr)
    assert output_.shape == x.shape


def test_backward():

    model = DeepTensorNetworkBlock(
        node_feature_dim=x.shape[1],
        edge_feature_dim=edge_attr.shape[1],
    )

    output_ = model(
        edge_index=edge_index,
        x=x.requires_grad_(),
        edge_attr=edge_attr.requires_grad_(),
    )

    loss = torch.mean(output_)
    loss.backward()
    assert x.grad.shape == x.shape
tests/test_messagepassing/test_equivariant_network_block.py (new file)
@@ -0,0 +1,165 @@
import pytest
import torch
from pina.model.block.message_passing import EnEquivariantNetworkBlock

# Data for testing
x = torch.rand(10, 4)
pos = torch.rand(10, 3)
edge_index = torch.randint(0, 10, (2, 20))
edge_attr = torch.randn(20, 2)


@pytest.mark.parametrize("node_feature_dim", [1, 3])
@pytest.mark.parametrize("edge_feature_dim", [0, 2])
@pytest.mark.parametrize("pos_dim", [2, 3])
def test_constructor(node_feature_dim, edge_feature_dim, pos_dim):

    EnEquivariantNetworkBlock(
        node_feature_dim=node_feature_dim,
        edge_feature_dim=edge_feature_dim,
        pos_dim=pos_dim,
        hidden_dim=64,
        n_message_layers=2,
        n_update_layers=2,
    )

    # Should fail if node_feature_dim is negative
    with pytest.raises(AssertionError):
        EnEquivariantNetworkBlock(
            node_feature_dim=-1,
            edge_feature_dim=edge_feature_dim,
            pos_dim=pos_dim,
        )

    # Should fail if edge_feature_dim is negative
    with pytest.raises(AssertionError):
        EnEquivariantNetworkBlock(
            node_feature_dim=node_feature_dim,
            edge_feature_dim=-1,
            pos_dim=pos_dim,
        )

    # Should fail if pos_dim is negative
    with pytest.raises(AssertionError):
        EnEquivariantNetworkBlock(
            node_feature_dim=node_feature_dim,
            edge_feature_dim=edge_feature_dim,
            pos_dim=-1,
        )

    # Should fail if hidden_dim is negative
    with pytest.raises(AssertionError):
        EnEquivariantNetworkBlock(
            node_feature_dim=node_feature_dim,
            edge_feature_dim=edge_feature_dim,
            pos_dim=pos_dim,
            hidden_dim=-1,
        )

    # Should fail if n_message_layers is negative
    with pytest.raises(AssertionError):
        EnEquivariantNetworkBlock(
            node_feature_dim=node_feature_dim,
            edge_feature_dim=edge_feature_dim,
            pos_dim=pos_dim,
            n_message_layers=-1,
        )

    # Should fail if n_update_layers is negative
    with pytest.raises(AssertionError):
        EnEquivariantNetworkBlock(
            node_feature_dim=node_feature_dim,
            edge_feature_dim=edge_feature_dim,
            pos_dim=pos_dim,
            n_update_layers=-1,
        )


@pytest.mark.parametrize("edge_feature_dim", [0, 2])
def test_forward(edge_feature_dim):

    model = EnEquivariantNetworkBlock(
        node_feature_dim=x.shape[1],
        edge_feature_dim=edge_feature_dim,
        pos_dim=pos.shape[1],
        hidden_dim=64,
        n_message_layers=2,
        n_update_layers=2,
    )

    if edge_feature_dim == 0:
        output_ = model(edge_index=edge_index, x=x, pos=pos)
    else:
        output_ = model(
            edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr
        )

    assert output_[0].shape == x.shape
    assert output_[1].shape == pos.shape


@pytest.mark.parametrize("edge_feature_dim", [0, 2])
def test_backward(edge_feature_dim):

    model = EnEquivariantNetworkBlock(
        node_feature_dim=x.shape[1],
        edge_feature_dim=edge_feature_dim,
        pos_dim=pos.shape[1],
        hidden_dim=64,
        n_message_layers=2,
        n_update_layers=2,
    )

    if edge_feature_dim == 0:
        output_ = model(
            edge_index=edge_index,
            x=x.requires_grad_(),
            pos=pos.requires_grad_(),
        )
    else:
        output_ = model(
            edge_index=edge_index,
            x=x.requires_grad_(),
            pos=pos.requires_grad_(),
            edge_attr=edge_attr.requires_grad_(),
        )

    loss = torch.mean(output_[0])
    loss.backward()
    assert x.grad.shape == x.shape
    assert pos.grad.shape == pos.shape


def test_equivariance():

    # Graph to be fully connected and undirected
    edge_index = torch.combinations(torch.arange(x.shape[0]), r=2).T
    edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

    # Random rotation (det(rotation) should be 1)
    rotation = torch.linalg.qr(torch.rand(pos.shape[-1], pos.shape[-1])).Q
    if torch.det(rotation) < 0:
        rotation[:, 0] *= -1

    # Random translation
    translation = torch.rand(1, pos.shape[-1])

    model = EnEquivariantNetworkBlock(
        node_feature_dim=x.shape[1],
        edge_feature_dim=0,
        pos_dim=pos.shape[1],
        hidden_dim=64,
        n_message_layers=2,
        n_update_layers=2,
    ).eval()

    h1, pos1 = model(edge_index=edge_index, x=x, pos=pos)
    h2, pos2 = model(
        edge_index=edge_index, x=x, pos=pos @ rotation.T + translation
    )

    # Transform model output
    pos1_transformed = (pos1 @ rotation.T) + translation

    assert torch.allclose(pos2, pos1_transformed, atol=1e-5)
    assert torch.allclose(h1, h2, atol=1e-5)
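For reference (not part of the diff), the property this test verifies can be stated compactly. Writing the block as a map $f(h, p) = (h', p')$ on node features $h$ and positions $p$, the assertions check that for a random rotation $R \in SO(n)$ and translation $t$:

$$f(h,\; p R^{\top} + t) = \bigl(h',\; p' R^{\top} + t\bigr), \qquad \text{where } (h', p') = f(h, p),$$

i.e. positions transform equivariantly while node features stay invariant, which is exactly what the two ``allclose`` assertions encode.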
tests/test_messagepassing/test_interaction_network_block.py (new file)
@@ -0,0 +1,84 @@
import pytest
import torch
from pina.model.block.message_passing import InteractionNetworkBlock

# Data for testing
x = torch.rand(10, 3)
edge_index = torch.randint(0, 10, (2, 20))
edge_attr = torch.randn(20, 2)


@pytest.mark.parametrize("node_feature_dim", [1, 3])
@pytest.mark.parametrize("edge_feature_dim", [0, 2])
def test_constructor(node_feature_dim, edge_feature_dim):

    InteractionNetworkBlock(
        node_feature_dim=node_feature_dim,
        edge_feature_dim=edge_feature_dim,
        hidden_dim=64,
        n_message_layers=2,
        n_update_layers=2,
    )

    # Should fail if node_feature_dim is negative
    with pytest.raises(AssertionError):
        InteractionNetworkBlock(node_feature_dim=-1)

    # Should fail if edge_feature_dim is negative
    with pytest.raises(AssertionError):
        InteractionNetworkBlock(node_feature_dim=3, edge_feature_dim=-1)

    # Should fail if hidden_dim is negative
    with pytest.raises(AssertionError):
        InteractionNetworkBlock(node_feature_dim=3, hidden_dim=-1)

    # Should fail if n_message_layers is negative
    with pytest.raises(AssertionError):
        InteractionNetworkBlock(node_feature_dim=3, n_message_layers=-1)

    # Should fail if n_update_layers is negative
    with pytest.raises(AssertionError):
        InteractionNetworkBlock(node_feature_dim=3, n_update_layers=-1)


@pytest.mark.parametrize("edge_feature_dim", [0, 2])
def test_forward(edge_feature_dim):

    model = InteractionNetworkBlock(
        node_feature_dim=x.shape[1],
        edge_feature_dim=edge_feature_dim,
        hidden_dim=64,
        n_message_layers=2,
        n_update_layers=2,
    )

    if edge_feature_dim == 0:
        output_ = model(edge_index=edge_index, x=x)
    else:
        output_ = model(edge_index=edge_index, x=x, edge_attr=edge_attr)
    assert output_.shape == x.shape


@pytest.mark.parametrize("edge_feature_dim", [0, 2])
def test_backward(edge_feature_dim):

    model = InteractionNetworkBlock(
        node_feature_dim=x.shape[1],
        edge_feature_dim=edge_feature_dim,
        hidden_dim=64,
        n_message_layers=2,
        n_update_layers=2,
    )

    if edge_feature_dim == 0:
        output_ = model(edge_index=edge_index, x=x.requires_grad_())
    else:
        output_ = model(
            edge_index=edge_index,
            x=x.requires_grad_(),
            edge_attr=edge_attr.requires_grad_(),
        )

    loss = torch.mean(output_)
    loss.backward()
    assert x.grad.shape == x.shape
tests/test_messagepassing/test_radial_field_network_block.py (new file)
@@ -0,0 +1,92 @@
import pytest
import torch
from pina.model.block.message_passing import RadialFieldNetworkBlock

# Data for testing
x = torch.rand(10, 3)
edge_index = torch.randint(0, 10, (2, 20))


@pytest.mark.parametrize("node_feature_dim", [1, 3])
def test_constructor(node_feature_dim):

    RadialFieldNetworkBlock(
        node_feature_dim=node_feature_dim,
        hidden_dim=64,
        n_layers=2,
    )

    # Should fail if node_feature_dim is negative
    with pytest.raises(AssertionError):
        RadialFieldNetworkBlock(
            node_feature_dim=-1,
            hidden_dim=64,
            n_layers=2,
        )

    # Should fail if hidden_dim is negative
    with pytest.raises(AssertionError):
        RadialFieldNetworkBlock(
            node_feature_dim=node_feature_dim,
            hidden_dim=-1,
            n_layers=2,
        )

    # Should fail if n_layers is negative
    with pytest.raises(AssertionError):
        RadialFieldNetworkBlock(
            node_feature_dim=node_feature_dim,
            hidden_dim=64,
            n_layers=-1,
        )


def test_forward():

    model = RadialFieldNetworkBlock(
        node_feature_dim=x.shape[1],
        hidden_dim=64,
        n_layers=2,
    )

    output_ = model(edge_index=edge_index, x=x)
    assert output_.shape == x.shape


def test_backward():

    model = RadialFieldNetworkBlock(
        node_feature_dim=x.shape[1],
        hidden_dim=64,
        n_layers=2,
    )

    output_ = model(edge_index=edge_index, x=x.requires_grad_())
    loss = torch.mean(output_)
    loss.backward()
    assert x.grad.shape == x.shape


def test_equivariance():

    # Graph to be fully connected and undirected
    edge_index = torch.combinations(torch.arange(x.shape[0]), r=2).T
    edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

    # Random rotation (det(rotation) should be 1)
    rotation = torch.linalg.qr(torch.rand(x.shape[-1], x.shape[-1])).Q
    if torch.det(rotation) < 0:
        rotation[:, 0] *= -1

    # Random translation
    translation = torch.rand(1, x.shape[-1])

    model = RadialFieldNetworkBlock(node_feature_dim=x.shape[1]).eval()

    pos1 = model(edge_index=edge_index, x=x)
    pos2 = model(edge_index=edge_index, x=x @ rotation.T + translation)

    # Transform model output
    pos1_transformed = (pos1 @ rotation.T) + translation

    assert torch.allclose(pos2, pos1_transformed, atol=1e-5)
@@ -296,22 +296,183 @@ def test_laplacian(f):
        laplacian(output_=output_, input_=input_, components=["a", "b", "c"])


def test_advection():
def test_advection_scalar():

    # Define input and output
    # Define 3-dimensional input
    input_ = torch.rand((20, 3), requires_grad=True)
    input_ = LabelTensor(input_, ["x", "y", "z"])
    output_ = LabelTensor(input_**2, ["u", "v", "c"])

    # Define the velocity field
    velocity = output_.extract(["c"])
    # Define 3-dimensional velocity field and quantity to be advected
    velocity = torch.rand((20, 3), requires_grad=True)
    field = torch.sum(input_**2, dim=-1, keepdim=True)

    # Compute the true advection and the pina advection
    pina_advection = advection(
        output_=output_, input_=input_, velocity_field="c"
    # Combine velocity and field into a LabelTensor
    labels = ["ux", "uy", "uz", "c"]
    output_ = LabelTensor(torch.cat((velocity, field), dim=1), labels)

    # Compute the pina advection
    components = ["c"]
    pina_adv = advection(
        output_=output_,
        input_=input_,
        velocity_field=["ux", "uy", "uz"],
        components=components,
        d=["x", "y", "z"],
    )
    true_advection = velocity * 2 * input_.extract(["x", "y"])

    # Check the shape of the advection
    assert pina_advection.shape == (*output_.shape[:-1], output_.shape[-1] - 1)
    assert torch.allclose(pina_advection, true_advection)
    # Compute the true advection
    grads = 2 * input_
    true_adv = torch.sum(grads * velocity, dim=grads.ndim - 1, keepdim=True)

    # Check the shape, labels, and value of the advection
    assert pina_adv.shape == (*output_.shape[:-1], len(components))
    assert pina_adv.labels == ["adv_c"]
    assert torch.allclose(pina_adv, true_adv)

    # Should fail if input not a LabelTensor
    with pytest.raises(TypeError):
        advection(
            output_=output_,
            input_=input_.tensor,
            velocity_field=["ux", "uy", "uz"],
        )

    # Should fail if output not a LabelTensor
    with pytest.raises(TypeError):
        advection(
            output_=output_.tensor,
            input_=input_,
            velocity_field=["ux", "uy", "uz"],
        )

    # Should fail for non-existent input labels
    with pytest.raises(RuntimeError):
        advection(
            output_=output_,
            input_=input_,
            d=["x", "a"],
            velocity_field=["ux", "uy", "uz"],
        )

    # Should fail for non-existent output labels
    with pytest.raises(RuntimeError):
        advection(
            output_=output_,
            input_=input_,
            components=["a", "b", "c"],
            velocity_field=["ux", "uy", "uz"],
        )

    # Should fail if velocity_field labels are not present in the output labels
    with pytest.raises(RuntimeError):
        advection(
            output_=output_,
            input_=input_,
            velocity_field=["ux", "uy", "nonexistent"],
            components=["c"],
        )

    # Should fail if velocity_field dimensionality does not match input tensor
    with pytest.raises(RuntimeError):
        advection(
            output_=output_,
            input_=input_,
            velocity_field=["ux", "uy"],
            components=["c"],
        )


def test_advection_vector():

    # Define 3-dimensional input
    input_ = torch.rand((20, 3), requires_grad=True)
    input_ = LabelTensor(input_, ["x", "y", "z"])

    # Define 3-dimensional velocity field
    velocity = torch.rand((20, 3), requires_grad=True)

    # Define 2-dimensional field to be advected
    field_1 = torch.sum(input_**2, dim=-1, keepdim=True)
    field_2 = torch.sum(input_**3, dim=-1, keepdim=True)

    # Combine velocity and field into a LabelTensor
    labels = ["ux", "uy", "uz", "c1", "c2"]
    output_ = LabelTensor(
        torch.cat((velocity, field_1, field_2), dim=1), labels
    )

    # Compute the pina advection
    components = ["c1", "c2"]
    pina_adv = advection(
        output_=output_,
        input_=input_,
        velocity_field=["ux", "uy", "uz"],
        components=components,
        d=["x", "y", "z"],
    )

    # Compute the true gradients of the fields "c1", "c2"
    grads1 = 2 * input_
    grads2 = 3 * input_**2

    # Compute the true advection for each field
    true_adv1 = torch.sum(grads1 * velocity, dim=grads1.ndim - 1, keepdim=True)
    true_adv2 = torch.sum(grads2 * velocity, dim=grads2.ndim - 1, keepdim=True)
    true_adv = torch.cat((true_adv1, true_adv2), dim=-1)

    # Check the shape, labels, and value of the advection
    assert pina_adv.shape == (*output_.shape[:-1], len(components))
    assert pina_adv.labels == ["adv_c1", "adv_c2"]
    assert torch.allclose(pina_adv, true_adv)

    # Should fail if input not a LabelTensor
    with pytest.raises(TypeError):
        advection(
            output_=output_,
            input_=input_.tensor,
            velocity_field=["ux", "uy", "uz"],
        )

    # Should fail if output not a LabelTensor
    with pytest.raises(TypeError):
        advection(
            output_=output_.tensor,
            input_=input_,
            velocity_field=["ux", "uy", "uz"],
        )

    # Should fail for non-existent input labels
    with pytest.raises(RuntimeError):
        advection(
            output_=output_,
            input_=input_,
            d=["x", "a"],
            velocity_field=["ux", "uy", "uz"],
        )

    # Should fail for non-existent output labels
    with pytest.raises(RuntimeError):
        advection(
            output_=output_,
            input_=input_,
            components=["a", "b", "c"],
            velocity_field=["ux", "uy", "uz"],
        )

    # Should fail if velocity_field labels are not present in the output labels
    with pytest.raises(RuntimeError):
        advection(
            output_=output_,
            input_=input_,
            velocity_field=["ux", "uy", "nonexistent"],
            components=["c"],
        )

    # Should fail if velocity_field dimensionality does not match input tensor
    with pytest.raises(RuntimeError):
        advection(
            output_=output_,
            input_=input_,
            velocity_field=["ux", "uy"],
            components=["c"],
        )
@@ -4,6 +4,11 @@ from pina.problem.zoo import Poisson2DSquareProblem as Poisson
from pina import LabelTensor
from pina.domain import Union
from pina.domain import CartesianDomain
from pina.condition import (
    Condition,
    InputTargetCondition,
    DomainEquationCondition,
)


def test_discretise_domain():
@@ -45,6 +50,24 @@ def test_variables_correct_order_sampling():
    )


def test_input_pts():
    n = 10
    poisson_problem = Poisson()
    poisson_problem.discretise_domain(n, "grid")
    assert sorted(list(poisson_problem.input_pts.keys())) == sorted(
        list(poisson_problem.conditions.keys())
    )


def test_collected_data():
    n = 10
    poisson_problem = Poisson()
    poisson_problem.discretise_domain(n, "grid")
    assert sorted(list(poisson_problem.collected_data.keys())) == sorted(
        list(poisson_problem.conditions.keys())
    )


def test_add_points():
    poisson_problem = Poisson()
    poisson_problem.discretise_domain(0, "random", domains=["D"])
@@ -84,3 +107,23 @@ def test_wrong_custom_sampling_logic(mode):
    }
    with pytest.raises(RuntimeError):
        poisson_problem.discretise_domain(sample_rules=sampling_rules)


def test_aggregate_data():
    poisson_problem = Poisson()
    poisson_problem.conditions["data"] = Condition(
        input=LabelTensor(torch.tensor([[0.0, 1.0]]), labels=["x", "y"]),
        target=LabelTensor(torch.tensor([[0.0]]), labels=["u"]),
    )
    poisson_problem.discretise_domain(0, "random", domains="all")
    poisson_problem.collect_data()
    assert isinstance(poisson_problem.collected_data, dict)
    for name, conditions in poisson_problem.conditions.items():
        assert name in poisson_problem.collected_data.keys()
        if isinstance(conditions, InputTargetCondition):
            assert "input" in poisson_problem.collected_data[name].keys()
            assert "target" in poisson_problem.collected_data[name].keys()
        elif isinstance(conditions, DomainEquationCondition):
            assert "input" in poisson_problem.collected_data[name].keys()
            assert "target" not in poisson_problem.collected_data[name].keys()
            assert "equation" in poisson_problem.collected_data[name].keys()
tests/test_type_checker.py (new file)
@@ -0,0 +1,55 @@
import pytest
import logging
import math
from pina.type_checker import enforce_types


# Definition of a test function for arguments
@enforce_types
def foo_function1(a: int, b: float) -> float:
    return a + b


# Definition of a test function for return values
@enforce_types
def foo_function2(a: int, right: bool) -> float:
    if right:
        return float(a)
    else:
        return "Hello, world!"


def test_argument_type_checking():

    # Setting logging level to INFO, which should not trigger type checking
    logging.getLogger().setLevel(logging.INFO)

    # Both should work, even if the arguments are not of the expected type
    assert math.isclose(foo_function1(a=1, b=2.0), 3.0)
    assert math.isclose(foo_function1(a=1, b=2), 3.0)

    # Setting logging level to DEBUG, which should trigger type checking
    logging.getLogger().setLevel(logging.DEBUG)

    # The second should fail, as the second argument is an int
    assert math.isclose(foo_function1(a=1, b=2.0), 3.0)
    with pytest.raises(TypeError):
        foo_function1(a=1, b=2)


def test_return_type_checking():

    # Setting logging level to INFO, which should not trigger type checking
    logging.getLogger().setLevel(logging.INFO)

    # Both should work, even if the return value is not of the expected type
    assert math.isclose(foo_function2(a=1, right=True), 1.0)
    assert foo_function2(a=1, right=False) == "Hello, world!"

    # Setting logging level to DEBUG, which should trigger type checking
    logging.getLogger().setLevel(logging.DEBUG)

    # The second should fail, as the return value is a string
    assert math.isclose(foo_function2(a=1, right=True), 1.0)
    with pytest.raises(TypeError):
        foo_function2(a=1, right=False)
@@ -1,12 +1,9 @@
import torch

from pina.utils import merge_tensors
from pina.label_tensor import LabelTensor
from pina import LabelTensor
from pina.domain import EllipsoidDomain, CartesianDomain
from pina.utils import check_consistency
import pytest
from pina.domain import DomainInterface

from pina import LabelTensor
from pina.utils import merge_tensors, check_consistency, check_positive_integer
from pina.domain import EllipsoidDomain, CartesianDomain, DomainInterface


def test_merge_tensors():
@@ -50,3 +47,24 @@ def test_check_consistency_incorrect():
        check_consistency(torch.Tensor, DomainInterface, subclass=True)
    with pytest.raises(ValueError):
        check_consistency(ellipsoid1, torch.Tensor)


@pytest.mark.parametrize("value", [0, 1, 2, 3, 10])
@pytest.mark.parametrize("strict", [True, False])
def test_check_positive_integer(value, strict):
    if value != 0:
        check_positive_integer(value, strict=strict)
    else:
        check_positive_integer(value, strict=False)

    # Should fail if value is negative
    with pytest.raises(AssertionError):
        check_positive_integer(-1, strict=strict)

    # Should fail if value is not an integer
    with pytest.raises(AssertionError):
        check_positive_integer(1.5, strict=strict)

    # Should fail if value is not a number
    with pytest.raises(AssertionError):
        check_positive_integer("string", strict=strict)
tutorials/README.md (vendored)
@@ -9,39 +9,39 @@ The table below provides an overview of each tutorial. All tutorials are also av
| Description | Tutorial |
|---------------|-----------|
Introductory Tutorial: A Beginner’s Guide to PINA|[[.ipynb](tutorial17/tutorial.ipynb),[.py](tutorial17/tutorial.py),[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial17/tutorial.html)]|
How to build a `Problem` in PINA|[[.ipynb](tutorial16/tutorial.ipynb),[.py](tutorial16/tutorial.py),[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial16/tutorial.html)]|
Introduction to Solver classes|[[.ipynb](tutorial18/tutorial.ipynb),[.py](tutorial18/tutorial.py),[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial18/tutorial.html)]|
Introduction to `Trainer` class|[[.ipynb](tutorial11/tutorial.ipynb),[.py](tutorial11/tutorial.py),[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial11/tutorial.html)]|
Data structure for SciML: `Tensor`, `LabelTensor`, `Data` and `Graph` |[[.ipynb](tutorial19/tutorial.ipynb),[.py](tutorial19/tutorial.py),[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial19/tutorial.html)]|
Building geometries with `DomainInterface` class|[[.ipynb](tutorial6/tutorial.ipynb),[.py](tutorial6/tutorial.py),[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial6/tutorial.html)]|
Introduction to PINA `Equation` class|[[.ipynb](tutorial12/tutorial.ipynb),[.py](tutorial12/tutorial.py),[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial12/tutorial.html)]|
Introductory Tutorial: A Beginner’s Guide to PINA|[[.ipynb](tutorial17/tutorial.ipynb),[.py](tutorial17/tutorial.py),[.html](http://mathlab.github.io/PINA/tutorial17/tutorial.html)]|
How to build a `Problem` in PINA|[[.ipynb](tutorial16/tutorial.ipynb),[.py](tutorial16/tutorial.py),[.html](http://mathlab.github.io/PINA/tutorial16/tutorial.html)]|
Introduction to Solver classes|[[.ipynb](tutorial18/tutorial.ipynb),[.py](tutorial18/tutorial.py),[.html](http://mathlab.github.io/PINA/tutorial18/tutorial.html)]|
Introduction to `Trainer` class|[[.ipynb](tutorial11/tutorial.ipynb),[.py](tutorial11/tutorial.py),[.html](http://mathlab.github.io/PINA/tutorial11/tutorial.html)]|
Data structure for SciML: `Tensor`, `LabelTensor`, `Data` and `Graph` |[[.ipynb](tutorial19/tutorial.ipynb),[.py](tutorial19/tutorial.py),[.html](http://mathlab.github.io/PINA/tutorial19/tutorial.html)]|
Building geometries with `DomainInterface` class|[[.ipynb](tutorial6/tutorial.ipynb),[.py](tutorial6/tutorial.py),[.html](http://mathlab.github.io/PINA/tutorial6/tutorial.html)]|
Introduction to PINA `Equation` class|[[.ipynb](tutorial12/tutorial.ipynb),[.py](tutorial12/tutorial.py),[.html](http://mathlab.github.io/PINA/tutorial12/tutorial.html)]|


## Physics Informed Neural Networks
| Description | Tutorial |
|---------------|-----------|
Introductory Tutorial: Physics Informed Neural Networks with PINA |[[.ipynb](tutorial1/tutorial.ipynb),[.py](tutorial1/tutorial.py),[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial1/tutorial.html)]|
Enhancing PINNs with Extra Features to solve the Poisson Problem |[[.ipynb](tutorial2/tutorial.ipynb),[.py](tutorial2/tutorial.py),[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial2/tutorial.html)]|
Applying Hard Constraints in PINNs to solve the Wave Problem |[[.ipynb](tutorial3/tutorial.ipynb),[.py](tutorial3/tutorial.py),[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial3/tutorial.html)]|
Applying Periodic Boundary Conditions in PINNs to solve the Helmotz Problem |[[.ipynb](tutorial9/tutorial.ipynb),[.py](tutorial9/tutorial.py),[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial9/tutorial.html)]|
Inverse Problem Solving with Physics-Informed Neural Network |[[.ipynb](tutorial7/tutorial.ipynb),[.py](tutorial7/tutorial.py),[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial7/tutorial.html)]|
Learning Multiscale PDEs Using Fourier Feature Networks|[[.ipynb](tutorial13/tutorial.ipynb),[.py](tutorial13/tutorial.py),[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial13/tutorial.html)]|
Learning Bifurcating PDE Solutions with Physics-Informed Deep Ensembles|[[.ipynb](tutorial14/tutorial.ipynb),[.py](tutorial14/tutorial.py),[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial14/tutorial.html)]|
Introductory Tutorial: Physics Informed Neural Networks with PINA |[[.ipynb](tutorial1/tutorial.ipynb),[.py](tutorial1/tutorial.py),[.html](http://mathlab.github.io/PINA/tutorial1/tutorial.html)]|
Enhancing PINNs with Extra Features to solve the Poisson Problem |[[.ipynb](tutorial2/tutorial.ipynb),[.py](tutorial2/tutorial.py),[.html](http://mathlab.github.io/PINA/tutorial2/tutorial.html)]|
Applying Hard Constraints in PINNs to solve the Wave Problem |[[.ipynb](tutorial3/tutorial.ipynb),[.py](tutorial3/tutorial.py),[.html](http://mathlab.github.io/PINA/tutorial3/tutorial.html)]|
Applying Periodic Boundary Conditions in PINNs to solve the Helmholtz Problem |[[.ipynb](tutorial9/tutorial.ipynb),[.py](tutorial9/tutorial.py),[.html](http://mathlab.github.io/PINA/tutorial9/tutorial.html)]|
Inverse Problem Solving with Physics-Informed Neural Network |[[.ipynb](tutorial7/tutorial.ipynb),[.py](tutorial7/tutorial.py),[.html](http://mathlab.github.io/PINA/tutorial7/tutorial.html)]|
Learning Multiscale PDEs Using Fourier Feature Networks|[[.ipynb](tutorial13/tutorial.ipynb),[.py](tutorial13/tutorial.py),[.html](http://mathlab.github.io/PINA/tutorial13/tutorial.html)]|
Learning Bifurcating PDE Solutions with Physics-Informed Deep Ensembles|[[.ipynb](tutorial14/tutorial.ipynb),[.py](tutorial14/tutorial.py),[.html](http://mathlab.github.io/PINA/tutorial14/tutorial.html)]|


## Neural Operator Learning
| Description | Tutorial |
|---------------|-----------|
Introductory Tutorial: Neural Operator Learning with PINA |[[.ipynb](tutorial21/tutorial.ipynb),[.py](tutorial21/tutorial.py),[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial21/tutorial.html)]|
Modeling 2D Darcy Flow with the Fourier Neural Operator |[[.ipynb](tutorial5/tutorial.ipynb),[.py](tutorial5/tutorial.py),[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial5/tutorial.html)]|
Solving the Kuramoto–Sivashinsky Equation with Averaging Neural Operator |[[.ipynb](tutorial10/tutorial.ipynb),[.py](tutorial10/tutorial.py),[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial10/tutorial.html)]|
Introductory Tutorial: Neural Operator Learning with PINA |[[.ipynb](tutorial21/tutorial.ipynb),[.py](tutorial21/tutorial.py),[.html](http://mathlab.github.io/PINA/tutorial21/tutorial.html)]|
Modeling 2D Darcy Flow with the Fourier Neural Operator |[[.ipynb](tutorial5/tutorial.ipynb),[.py](tutorial5/tutorial.py),[.html](http://mathlab.github.io/PINA/tutorial5/tutorial.html)]|
Solving the Kuramoto–Sivashinsky Equation with Averaging Neural Operator |[[.ipynb](tutorial10/tutorial.ipynb),[.py](tutorial10/tutorial.py),[.html](http://mathlab.github.io/PINA/tutorial10/tutorial.html)]|

## Supervised Learning
| Description | Tutorial |
|---------------|-----------|
Introductory Tutorial: Supervised Learning with PINA |[[.ipynb](tutorial20/tutorial.ipynb),[.py](tutorial20/tutorial.py),[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial20/tutorial.html)]|
Chemical Properties Prediction with Graph Neural Networks |[[.ipynb](tutorial15/tutorial.ipynb),[.py](tutorial15/tutorial.py),[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial15/tutorial.html)]|
Unstructured Convolutional Autoencoders with Continuous Convolution |[[.ipynb](tutorial4/tutorial.ipynb),[.py](tutorial4/tutorial.py),[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial4/tutorial.html)]|
Reduced Order Modeling with POD-RBF and POD-NN Approaches for Fluid Dynamics| [[.ipynb](tutorial8/tutorial.ipynb),[.py](tutorial8/tutorial.py),[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial8/tutorial.html)]|
Introductory Tutorial: Supervised Learning with PINA |[[.ipynb](tutorial20/tutorial.ipynb),[.py](tutorial20/tutorial.py),[.html](http://mathlab.github.io/PINA/tutorial20/tutorial.html)]|
Chemical Properties Prediction with Graph Neural Networks |[[.ipynb](tutorial15/tutorial.ipynb),[.py](tutorial15/tutorial.py),[.html](http://mathlab.github.io/PINA/tutorial15/tutorial.html)]|
Unstructured Convolutional Autoencoders with Continuous Convolution |[[.ipynb](tutorial4/tutorial.ipynb),[.py](tutorial4/tutorial.py),[.html](http://mathlab.github.io/PINA/tutorial4/tutorial.html)]|
Reduced Order Modeling with POD-RBF and POD-NN Approaches for Fluid Dynamics| [[.ipynb](tutorial8/tutorial.ipynb),[.py](tutorial8/tutorial.py),[.html](http://mathlab.github.io/PINA/tutorial8/tutorial.html)]|
2 tutorials/tutorial11/tutorial.ipynb vendored
@@ -287,7 +287,7 @@
"metadata": {},
"source": [
"<p align=\"center\">\n",
"<img src=\"../static/logging.png\" alt=\"Logging API\" width=\"400\"/>\n",
" <img src=\"http://raw.githubusercontent.com/mathLab/PINA/master/tutorials/static/logging.png\" alt=\"Logging API\" width=\"400\"/>\n",
"</p>"
]
},
2 tutorials/tutorial14/tutorial.ipynb vendored
@@ -58,7 +58,7 @@
"This approach allows the ensemble to capture different perspectives of the problem, leading to more accurate and reliable predictions.\n",
"\n",
"<p align=\"center\">\n",
" <img src=\"../static/deep_ensemble.png\" alt=\"PINA Workflow\" width=\"600\"/>\n",
" <img src=\"http://raw.githubusercontent.com/mathLab/PINA/master/tutorials/static/deep_ensemble.png\" alt=\"Deep ensemble\" width=\"600\"/>\n",
"</p>\n",
"\n",
"The image above illustrates a Deep Ensemble setup, where multiple models attempt to predict the text from an image. While individual models may make errors (e.g., predicting \"PONY\" instead of \"PINA\"), combining their outputs—such as taking the majority vote—often leads to the correct result. This ensemble effect improves reliability by mitigating the impact of individual model biases.\n",
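The majority-vote idea described in the cell above is easy to make concrete. Here is a minimal, illustrative sketch; the `ensemble_predict` helper and the toy models are assumptions for this example, not part of PINA:

```python
# Minimal sketch of ensemble-by-majority-vote; illustrative only, not PINA's API.
from collections import Counter

def ensemble_predict(models, x):
    """Collect one prediction per model and return the majority vote."""
    votes = [model(x) for model in models]
    return Counter(votes).most_common(1)[0][0]

# Toy usage mirroring the figure: one model misreads the text, two do not.
models = [lambda x: "PINA", lambda x: "PONY", lambda x: "PINA"]
print(ensemble_predict(models, x=None))  # -> "PINA"
```

For regression tasks the same idea is usually realized by averaging the ensemble outputs instead of voting.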
5 tutorials/tutorial17/tutorial.ipynb vendored
@@ -11,9 +11,10 @@
"[](https://colab.research.google.com/github/mathLab/PINA/blob/master/tutorials/tutorial17/tutorial.ipynb)\n",
"\n",
"<p align=\"left\">\n",
" <img src=\"../static/pina_logo.png\" alt=\"PINA Logo\" width=\"90\"/>\n",
" <img src=\"https://raw.githubusercontent.com/mathLab/PINA/master/readme/pina_logo.png\" alt=\"PINA logo\" width=\"90\"/>\n",
"</p>\n",
"\n",
"\n",
"Welcome to **PINA**!\n",
"\n",
"PINA [1] is an open-source Python library designed for **Scientific Machine Learning (SciML)** tasks, particularly involving:\n",
@@ -39,7 +40,7 @@
"## The PINA Workflow \n",
"\n",
"<p align=\"center\">\n",
" <img src=\"../static/pina_wokflow.png\" alt=\"PINA Workflow\" width=\"1000\"/>\n",
" <img src=\"http://raw.githubusercontent.com/mathLab/PINA/master/tutorials/static/pina_wokflow.png\" alt=\"PINA Workflow\" width=\"1000\"/>\n",
"</p>\n",
"\n",
"Solving a differential problem in **PINA** involves four main steps:\n",
3 tutorials/tutorial21/tutorial.ipynb vendored
@@ -163,10 +163,9 @@
"At their core, **Neural Operators** transform an input function $a$ into an output function $u$. The general structure of a Neural Operator consists of three key components:\n",
"\n",
"<p align=\"center\">\n",
" <img src=\"../static/neural_operator.png\" alt=\"Neural Operators\" width=\"800\"/>\n",
" <img src=\"http://raw.githubusercontent.com/mathLab/PINA/master/tutorials/static/neural_operator.png\" alt=\"Neural Operators\" width=\"800\"/>\n",
"</p>\n",
"\n",
"\n",
"1. **Encoder**: The encoder maps the input into a specific embedding space.\n",
"\n",
"2. **Processor**: The processor consists of multiple layers performing **function convolutions**, which is the core computational unit in a Neural Operator. \n",
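To make the encoder/processor structure in the cell above concrete, here is a hedged sketch of the three-block layout (the third block, a decoder, completes the usual neural-operator layout; the hunk ends before the tutorial lists it). `ToyNeuralOperator` and its plain linear layers are illustrative stand-ins, not PINA's implementation:

```python
# Hedged sketch of the encoder -> processor -> decoder structure; the names,
# shapes, and layer choices are assumptions for illustration, not PINA's API.
import torch
import torch.nn as nn

class ToyNeuralOperator(nn.Module):
    def __init__(self, in_dim=1, hidden=32, out_dim=1, n_layers=3):
        super().__init__()
        # Encoder: lift pointwise samples of the input function into an embedding space.
        self.encoder = nn.Linear(in_dim, hidden)
        # Processor: stacked layers standing in for the function convolutions.
        self.processor = nn.ModuleList(
            [nn.Linear(hidden, hidden) for _ in range(n_layers)]
        )
        # Decoder: project the embedding back to output-function values.
        self.decoder = nn.Linear(hidden, out_dim)

    def forward(self, a):
        # a: (batch, n_points, in_dim) samples of the input function.
        v = self.encoder(a)
        for layer in self.processor:
            v = torch.tanh(layer(v))
        return self.decoder(v)  # (batch, n_points, out_dim)
```

In a real neural operator the processor layers implement the function convolutions mentioned above (e.g., spectral convolutions in a Fourier Neural Operator) rather than plain linear maps.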
2 tutorials/tutorial9/tutorial.ipynb vendored
@@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tutorial: Applying Periodic Boundary Conditions in PINNs to solve the Helmotz Problem\n",
"# Tutorial: Applying Periodic Boundary Conditions in PINNs to solve the Helmholtz Problem\n",
"\n",
"[](https://colab.research.google.com/github/mathLab/PINA/blob/master/tutorials/tutorial9/tutorial.ipynb)\n",
"\n",