Implementation of DataLoader and DataModule (#383)

Refactoring for 0.2
* Data module, data loader and dataset
* Refactor LabelTensor
* Refactor solvers

Co-authored-by: dario-coscia <dariocos99@gmail.com>
Author: Filippo Olivo, 2024-11-27 16:01:39 +01:00
Committed by: Nicola Demo
Parent: dd43c8304c
Commit: a27bd35443
34 changed files with 827 additions and 1349 deletions
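
Before the hunks, a note on what changes structurally: solvers no longer index a flat batch dict by condition id; they iterate over (condition_name, points) pairs produced by the new data module. A minimal, hypothetical sketch of that layout (plain tensors stand in for PINA's LabelTensor; the names follow the diff below, the container type is an assumption):

    import torch

    # Hypothetical stand-ins for per-condition data (plain tensors here;
    # PINA passes LabelTensor objects in practice).
    boundary_pts = torch.rand(16, 2)
    sensor_pts = torch.rand(8, 2)
    sensor_vals = torch.rand(8, 1)

    # Layout iterated by the refactored training_step below: pairs of
    # (condition_name, points); 'output_points' is present only for
    # supervised input/output conditions.
    batch = [
        ("boundary", {"input_points": boundary_pts}),
        ("data", {"input_points": sensor_pts, "output_points": sensor_vals}),
    ]

    for condition_name, points in batch:
        if "output_points" in points:
            print(condition_name, "-> data loss (network output vs. ground truth)")
        else:
            print(condition_name, "-> physics loss (equation residual)")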

File: basepinn.py

@@ -1,14 +1,15 @@
 """ Module for PINN """
-import sys
 from abc import ABCMeta, abstractmethod
 import torch
-from ...solvers.solver import SolverInterface
-from pina.utils import check_consistency
-from pina.loss.loss_interface import LossInterface
-from pina.problem import InverseProblem
-from torch.nn.modules.loss import _Loss
+from ...condition import InputOutputPointsCondition
+from ...solvers.solver import SolverInterface
+from ...utils import check_consistency
+from ...loss.loss_interface import LossInterface
+from ...problem import InverseProblem
 from ...condition import DomainEquationCondition
+from ...optim import TorchOptimizer, TorchScheduler

 torch.pi = torch.acos(torch.zeros(1)).item() * 2  # which is 3.1415927410125732
@@ -25,13 +26,14 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
     to the user to choose which problem the implemented solver inheriting from
     this class is suitable for.
     """
+    accepted_condition_types = [DomainEquationCondition.condition_type[0],
+                                InputOutputPointsCondition.condition_type[0]]

     def __init__(
         self,
         models,
         problem,
         optimizers,
-        optimizers_kwargs,
+        schedulers,
         extra_features,
         loss,
     ):
@@ -53,11 +55,20 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
         :param torch.nn.Module loss: The loss function used as minimizer,
             default :class:`torch.nn.MSELoss`.
         """
+        if optimizers is None:
+            optimizers = TorchOptimizer(torch.optim.Adam, lr=0.001)
+        if schedulers is None:
+            schedulers = TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
+        if loss is None:
+            loss = torch.nn.MSELoss()
+
         super().__init__(
             models=models,
             problem=problem,
             optimizers=optimizers,
-            optimizers_kwargs=optimizers_kwargs,
+            schedulers=schedulers,
             extra_features=extra_features,
         )
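
A side note on the None defaults adopted here: the old PINN signature (second file below) used mutable and stateful default arguments, such as optimizer_kwargs={"lr": 0.001} and loss=torch.nn.MSELoss(), which Python evaluates once at definition time and shares across all calls. Whether or not that motivated the change, the None-sentinel style avoids a classic pitfall, demonstrated here:

    # Default arguments are evaluated once, at definition time, and then
    # shared by every call: the classic Python mutable-default pitfall.
    def old_style(optimizer_kwargs={"lr": 0.001}):
        optimizer_kwargs["lr"] *= 10   # mutates the shared default dict
        return optimizer_kwargs["lr"]

    print(old_style())   # 0.01
    print(old_style())   # 0.1, state leaked across calls

    # The refactored pattern: None sentinel, fresh object per call.
    def new_style(optimizer_kwargs=None):
        if optimizer_kwargs is None:
            optimizer_kwargs = {"lr": 0.001}
        optimizer_kwargs["lr"] *= 10
        return optimizer_kwargs["lr"]

    print(new_style())   # 0.01
    print(new_style())   # 0.01, isolated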
@@ -85,7 +96,12 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
         # variable will be stored with name = self.__logged_metric
         self.__logged_metric = None

-    def training_step(self, batch, _):
+        self._model = self._pina_models[0]
+        self._optimizer = self._pina_optimizers[0]
+        self._scheduler = self._pina_schedulers[0]
+
+    def training_step(self, batch):
         """
         The Physics Informed Solver training step. This function takes care
         of the physics informed training step, and it must not be overridden
@@ -99,53 +115,68 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
         :rtype: LabelTensor
         """
-        condition_losses = []
-        condition_idx = batch["condition"]
-        for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
-            condition_name = self._dataloader.condition_names[condition_id]
-            condition = self.problem.conditions[condition_name]
-            pts = batch["pts"]
-            # condition name is logged (if logs enabled)
-            self.__logged_metric = condition_name
-            if len(batch) == 2:
-                samples = pts[condition_idx == condition_id]
-                loss = self.loss_phys(samples, condition.equation)
-            elif len(batch) == 3:
-                samples = pts[condition_idx == condition_id]
-                ground_truth = batch["output"][condition_idx == condition_id]
-                loss = self.loss_data(samples, ground_truth)
-            else:
-                raise ValueError("Batch size not supported")
-            # add condition losses for each epoch
-            condition_losses.append(loss * condition.data_weight)
+        condition_loss = []
+        for condition_name, points in batch:
+            if 'output_points' in points:
+                input_pts, output_pts = points['input_points'], points['output_points']
+                loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
+            else:
+                input_pts = points['input_points']
+                condition = self.problem.conditions[condition_name]
+                loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation)
+            condition_loss.append(loss_.as_subclass(torch.Tensor))
         # clamp unknown parameters in InverseProblem (if needed)
         self._clamp_params()
-        # total loss (must be a torch.Tensor), and logs
-        total_loss = sum(condition_losses)
-        self.save_logs_and_release()
-        return total_loss.as_subclass(torch.Tensor)
+        loss = sum(condition_loss)
+        self.log('train_loss', loss, prog_bar=True, on_epoch=True,
+                 logger=True, batch_size=self.get_batch_size(batch),
+                 sync_dist=True)
+        return loss

+    def validation_step(self, batch):
+        """
+        The Physics Informed Solver validation step. It mirrors the
+        training step on the validation batch, re-enabling gradients
+        for the physics loss, and logs the validation loss.
+        """
+        condition_loss = []
+        for condition_name, points in batch:
+            if 'output_points' in points:
+                input_pts, output_pts = points['input_points'], points['output_points']
+                loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
+            else:
+                input_pts = points['input_points']
+                condition = self.problem.conditions[condition_name]
+                with torch.set_grad_enabled(True):
+                    loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation)
+            condition_loss.append(loss_.as_subclass(torch.Tensor))
+        loss = sum(condition_loss)
+        self.log('val_loss', loss, on_epoch=True, prog_bar=True,
+                 logger=True, batch_size=self.get_batch_size(batch),
+                 sync_dist=True)

-    def loss_data(self, input_tensor, output_tensor):
+    def loss_data(self, input_pts, output_pts):
         """
         The data loss for the PINN solver. It computes the loss between
         the network output and the true solution. This function
         should not be overridden unless intentional.

-        :param LabelTensor input_tensor: The input to the neural networks.
-        :param LabelTensor output_tensor: The true solution to compare the
+        :param LabelTensor input_pts: The input to the neural networks.
+        :param LabelTensor output_pts: The true solution to compare with
             the network solution.
         :return: The data loss averaged on the input coordinates.
         :rtype: torch.Tensor
         """
-        loss_value = self.loss(self.forward(input_tensor), output_tensor)
-        self.store_log(loss_value=float(loss_value))
-        return self.loss(self.forward(input_tensor), output_tensor)
+        return self._loss(self.forward(input_pts), output_pts)

     @abstractmethod
     def loss_phys(self, samples, equation):
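
To make the two loss paths concrete, here is a torch-only miniature of what loss_data and loss_phys reduce to (the linear net and toy residual are placeholders, not PINA objects):

    import torch

    # Placeholder network and data (assumptions, not PINA classes).
    net = torch.nn.Linear(2, 1)
    input_pts = torch.rand(8, 2)
    output_pts = torch.rand(8, 1)
    loss_fn = torch.nn.MSELoss()

    # loss_data in miniature: forward pass compared with the true solution.
    data_loss = loss_fn(net(input_pts), output_pts)

    # loss_phys in miniature: an equation residual is driven to zero, so
    # the target is simply a zero tensor with the residual's shape.
    residual = net(input_pts) - 1.0   # toy residual, not a real PDE
    phys_loss = loss_fn(torch.zeros_like(residual), residual)
    print(float(data_loss), float(phys_loss))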
@@ -196,13 +227,17 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
         :param str name: The name of the loss.
         :param torch.Tensor loss_value: The value of the loss.
         """
+        batch_size = self.trainer.data_module.batch_size \
+            if self.trainer.data_module.batch_size is not None else 999
         self.log(
             self.__logged_metric + "_loss",
             loss_value,
             prog_bar=True,
             logger=True,
             on_epoch=True,
-            on_step=False,
+            on_step=True,
+            batch_size=batch_size,
         )
         self.__logged_res_losses.append(loss_value)
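
One subtlety in this file worth flagging: Lightning runs validation under torch.no_grad(), while the physics residual needs derivatives with respect to the inputs; that is why validation_step wraps loss_phys in torch.set_grad_enabled(True). A self-contained illustration of the mechanism:

    import torch

    x = torch.rand(4, 1)
    net = torch.nn.Linear(1, 1)

    with torch.no_grad():                    # what Lightning wraps validation in
        with torch.set_grad_enabled(True):   # what validation_step re-enables
            x.requires_grad_()
            u = net(x)
            # an input derivative, as any PDE residual would need
            du_dx = torch.autograd.grad(u.sum(), x, create_graph=True)[0]
    print(du_dx.shape)                       # torch.Size([4, 1])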

File: pinn.py

@@ -9,10 +9,8 @@ except ImportError:
         _LRScheduler as LRScheduler,
     )  # torch < 2.0
-from torch.optim.lr_scheduler import ConstantLR

 from .basepinn import PINNInterface
-from pina.utils import check_consistency
 from pina.problem import InverseProblem
@@ -56,16 +54,16 @@ class PINN(PINNInterface):
         DOI: `10.1038 <https://doi.org/10.1038/s42254-021-00314-5>`_.
     """

+    __name__ = 'PINN'
+
     def __init__(
         self,
         problem,
         model,
         extra_features=None,
-        loss=torch.nn.MSELoss(),
-        optimizer=torch.optim.Adam,
-        optimizer_kwargs={"lr": 0.001},
-        scheduler=ConstantLR,
-        scheduler_kwargs={"factor": 1, "total_iters": 0},
+        loss=None,
+        optimizer=None,
+        scheduler=None,
     ):
         """
         :param AbstractProblem problem: The formulation of the problem.
@@ -82,20 +80,15 @@ class PINN(PINNInterface):
-        :param dict scheduler_kwargs: LR scheduler constructor keyword args.
         """
         super().__init__(
-            models=[model],
+            models=model,
             problem=problem,
-            optimizers=[optimizer],
-            optimizers_kwargs=[optimizer_kwargs],
+            optimizers=optimizer,
+            schedulers=scheduler,
             extra_features=extra_features,
             loss=loss,
         )
-
-        # check consistency
-        check_consistency(scheduler, LRScheduler, subclass=True)
-        check_consistency(scheduler_kwargs, dict)
-        # assign variables
-        self._scheduler = scheduler(self.optimizers[0], **scheduler_kwargs)
-        self._neural_net = self.models[0]

     def forward(self, x):
@@ -126,9 +119,8 @@ class PINN(PINNInterface):
         """
         residual = self.compute_residual(samples=samples, equation=equation)
         loss_value = self.loss(
-            torch.zeros_like(residual, requires_grad=True), residual
+            torch.zeros_like(residual), residual
         )
         self.store_log(loss_value=float(loss_value))
         return loss_value

     def configure_optimizers(self):
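
Dropping requires_grad=True from the zero target above is sound: the loss needs gradients only through the residual, not through a constant target. A quick check that nothing is lost:

    import torch

    residual = torch.rand(5, requires_grad=True)
    # Zero target without requires_grad: gradients still flow through
    # `residual` (the second argument), exactly as loss_phys needs.
    loss = torch.nn.MSELoss()(torch.zeros_like(residual), residual)
    loss.backward()
    print(residual.grad is not None)   # True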
@@ -141,16 +133,21 @@ class PINN(PINNInterface):
         """
         # if the problem is an InverseProblem, add the unknown parameters
         # to the parameters that the optimizer needs to optimize
+        self._optimizer.hook(self._model.parameters())
         if isinstance(self.problem, InverseProblem):
-            self.optimizers[0].add_param_group(
-                {
-                    "params": [
-                        self._params[var]
-                        for var in self.problem.unknown_variables
-                    ]
-                }
-            )
-        return self.optimizers, [self.scheduler]
+            self._optimizer.optimizer_instance.add_param_group(
+                {
+                    "params": [
+                        self._params[var]
+                        for var in self.problem.unknown_variables
+                    ]
+                }
+            )
+        self._scheduler.hook(self._optimizer)
+        return ([self._optimizer.optimizer_instance],
+                [self._scheduler.scheduler_instance])

     @property
     def scheduler(self):
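
The new TorchOptimizer and TorchScheduler wrappers appear in this diff only through their call sites (hook, optimizer_instance, scheduler_instance, and the TorchOptimizer(torch.optim.Adam, lr=0.001) defaults). The sketch below is consistent with those call sites, assuming deferred construction; it is an illustration, not PINA's shipped implementation:

    import torch

    class TorchOptimizer:
        """Deferred optimizer: stores class + kwargs, builds the concrete
        torch.optim instance once parameters are available via hook()."""
        def __init__(self, optimizer_class, **kwargs):
            self.optimizer_class = optimizer_class
            self.kwargs = kwargs
            self.optimizer_instance = None

        def hook(self, parameters):
            self.optimizer_instance = self.optimizer_class(parameters, **self.kwargs)

    class TorchScheduler:
        """Same idea for LR schedulers: hooked onto a built optimizer."""
        def __init__(self, scheduler_class, **kwargs):
            self.scheduler_class = scheduler_class
            self.kwargs = kwargs
            self.scheduler_instance = None

        def hook(self, optimizer):
            self.scheduler_instance = self.scheduler_class(
                optimizer.optimizer_instance, **self.kwargs)

    # Usage mirroring configure_optimizers above:
    model = torch.nn.Linear(2, 1)
    opt = TorchOptimizer(torch.optim.Adam, lr=0.001)
    sch = TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
    opt.hook(model.parameters())
    sch.hook(opt)
    print(type(opt.optimizer_instance).__name__, type(sch.scheduler_instance).__name__)

The design point this enables: the solver can be built before its model parameters exist on the right device, and Lightning's configure_optimizers performs the actual binding late, which is also where the InverseProblem unknowns are appended as an extra parameter group.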