Implementation of DataLoader and DataModule (#383)

Refactoring for 0.2
* Data module, data loader and dataset
* Refactor LabelTensor
* Refactor solvers

Co-authored-by: dario-coscia <dariocos99@gmail.com>
Filippo Olivo
2024-11-27 16:01:39 +01:00
committed by Nicola Demo
parent dd43c8304c
commit a27bd35443
34 changed files with 827 additions and 1349 deletions

View File

@@ -1,34 +0,0 @@
from .supervised import SupervisedSolver
from ..graph import Graph
class GraphSupervisedSolver(SupervisedSolver):
def __init__(
self,
problem,
model,
nodes_coordinates,
nodes_data,
loss=None,
optimizer=None,
scheduler=None):
super().__init__(problem, model, loss, optimizer, scheduler)
if isinstance(nodes_coordinates, str):
self._nodes_coordinates = [nodes_coordinates]
else:
self._nodes_coordinates = nodes_coordinates
if isinstance(nodes_data, str):
self._nodes_data = [nodes_data]
else:
self._nodes_data = nodes_data
def forward(self, input):
input_coords = input.extract(self._nodes_coordinates)
input_data = input.extract(self._nodes_data)
if not isinstance(input, Graph):
input = Graph.build('radius', nodes_coordinates=input_coords, nodes_data=input_data, radius=0.2)
g = self.model(input.data, edge_index=input.data.edge_index)
g.labels = {1: {'name': 'output', 'dof': ['u']}}
return g
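The deleted solver normalizes string arguments to lists by hand for each field. A minimal sketch of that idiom factored into a helper; the name `_as_labels` is hypothetical, not part of PINA:

```python
def _as_labels(value):
    # Normalize a single label (str) into a one-element list,
    # leaving an existing sequence of labels untouched.
    if isinstance(value, str):
        return [value]
    return list(value)

assert _as_labels("x") == ["x"]
assert _as_labels(["x", "y"]) == ["x", "y"]
```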

View File

@@ -1,14 +1,15 @@
""" Module for PINN """
import sys
from abc import ABCMeta, abstractmethod
import torch
from ...solvers.solver import SolverInterface
from pina.utils import check_consistency
from pina.loss.loss_interface import LossInterface
from pina.problem import InverseProblem
from torch.nn.modules.loss import _Loss
from ...condition import InputOutputPointsCondition
from ...solvers.solver import SolverInterface
from ...utils import check_consistency
from ...loss.loss_interface import LossInterface
from ...problem import InverseProblem
from ...condition import DomainEquationCondition
from ...optim import TorchOptimizer, TorchScheduler
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
@@ -25,13 +26,14 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
to the user to choose which problem the implemented solver inheriting from
this class is suitable for.
"""
accepted_condition_types = [DomainEquationCondition.condition_type[0],
InputOutputPointsCondition.condition_type[0]]
def __init__(
self,
models,
problem,
optimizers,
optimizers_kwargs,
schedulers,
extra_features,
loss,
):
@@ -53,11 +55,20 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
:param torch.nn.Module loss: The loss function used as minimizer,
default :class:`torch.nn.MSELoss`.
"""
if optimizers is None:
optimizers = TorchOptimizer(torch.optim.Adam, lr=0.001)
if schedulers is None:
schedulers = TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
if loss is None:
loss = torch.nn.MSELoss()
super().__init__(
models=models,
problem=problem,
optimizers=optimizers,
optimizers_kwargs=optimizers_kwargs,
schedulers=schedulers,
extra_features=extra_features,
)
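The defaults for optimizer, scheduler, and loss now live inside the constructor as `None` sentinels rather than as eager expressions in the signature. A minimal sketch of why this matters, assuming only standard Python semantics:

```python
import torch

# A default like loss=torch.nn.MSELoss() is evaluated once, at function
# definition time, and then shared by every caller. The None-sentinel
# idiom builds a fresh object per call instead:
def make_loss(loss=None):
    if loss is None:
        loss = torch.nn.MSELoss()  # fresh instance for this call only
    return loss

assert make_loss() is not make_loss()  # two distinct default instances
```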
@@ -85,7 +96,12 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
# variable will be stored with name = self.__logged_metric
self.__logged_metric = None
def training_step(self, batch, _):
self._model = self._pina_models[0]
self._optimizer = self._pina_optimizers[0]
self._scheduler = self._pina_schedulers[0]
def training_step(self, batch):
"""
The Physics Informed Solver Training Step. This function takes care
of the physics informed training step, and it must not be overridden
@@ -99,53 +115,68 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
:rtype: torch.Tensor
"""
condition_losses = []
condition_idx = batch["condition"]
condition_loss = []
for condition_name, points in batch:
if 'output_points' in points:
input_pts, output_pts = points['input_points'], points['output_points']
for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
condition_name = self._dataloader.condition_names[condition_id]
condition = self.problem.conditions[condition_name]
pts = batch["pts"]
# condition name is logged (if logs enabled)
self.__logged_metric = condition_name
if len(batch) == 2:
samples = pts[condition_idx == condition_id]
loss = self.loss_phys(samples, condition.equation)
elif len(batch) == 3:
samples = pts[condition_idx == condition_id]
ground_truth = batch["output"][condition_idx == condition_id]
loss = self.loss_data(samples, ground_truth)
loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
condition_loss.append(loss_.as_subclass(torch.Tensor))
else:
raise ValueError("Batch size not supported")
input_pts = points['input_points']
# add condition losses for each epoch
condition_losses.append(loss * condition.data_weight)
condition = self.problem.conditions[condition_name]
loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation)
condition_loss.append(loss_.as_subclass(torch.Tensor))
# clamp unknown parameters in InverseProblem (if needed)
self._clamp_params()
loss = sum(condition_loss)
self.log('train_loss', loss, prog_bar=True, on_epoch=True,
logger=True, batch_size=self.get_batch_size(batch),
sync_dist=True)
# total loss (must be a torch.Tensor), and logs
total_loss = sum(condition_losses)
self.save_logs_and_release()
return total_loss.as_subclass(torch.Tensor)
return loss
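The rewritten `training_step` assumes a batch that iterates as `(condition_name, points)` pairs, where `points` always carries `'input_points'` and, for data conditions, `'output_points'`. A hedged mock of that layout (shapes and condition names are illustrative only):

```python
import torch

mock_batch = [
    # data condition: inputs and targets are both present
    ("data", {"input_points": torch.rand(8, 2),
              "output_points": torch.rand(8, 1)}),
    # physics condition: inputs only; the equation supplies the residual
    ("physics", {"input_points": torch.rand(16, 2)}),
]

for condition_name, points in mock_batch:
    branch = "data" if "output_points" in points else "physics"
    print(f"{condition_name}: {branch} loss branch")
```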
def loss_data(self, input_tensor, output_tensor):
def validation_step(self, batch):
"""
The Physics Informed Solver validation step. It computes the data and
physics losses for every condition in the batch and logs the
aggregated ``val_loss``.
"""
condition_loss = []
for condition_name, points in batch:
if 'output_points' in points:
input_pts, output_pts = points['input_points'], points['output_points']
loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
condition_loss.append(loss_.as_subclass(torch.Tensor))
else:
input_pts = points['input_points']
condition = self.problem.conditions[condition_name]
with torch.set_grad_enabled(True):
loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation)
condition_loss.append(loss_.as_subclass(torch.Tensor))
loss = sum(condition_loss)
self.log('val_loss', loss, on_epoch=True, prog_bar=True,
logger=True, batch_size=self.get_batch_size(batch),
sync_dist=True)
def loss_data(self, input_pts, output_pts):
"""
The data loss for the PINN solver. It computes the loss between
the network output and the true solution. This function
should not be overridden unless intentionally.
:param LabelTensor input_tensor: The input to the neural networks.
:param LabelTensor output_tensor: The true solution to compare the
:param LabelTensor input_pts: The input to the neural networks.
:param LabelTensor output_pts: The true solution to compare against the
network solution.
:return: The residual loss averaged on the input coordinates
:rtype: torch.Tensor
"""
loss_value = self.loss(self.forward(input_tensor), output_tensor)
self.store_log(loss_value=float(loss_value))
return self.loss(self.forward(input_tensor), output_tensor)
return self._loss(self.forward(input_pts), output_pts)
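Stripped of solver plumbing, `loss_data` reduces to comparing `forward(input_pts)` against the targets. A self-contained sketch, assuming an MSE-style `_loss` and a plain linear model standing in for the network:

```python
import torch

loss_fn = torch.nn.MSELoss()  # stands in for self._loss
net = torch.nn.Linear(2, 1)   # stands in for self.forward

input_pts, output_pts = torch.rand(4, 2), torch.rand(4, 1)
# Same contract as loss_data: prediction vs. ground truth.
print(loss_fn(net(input_pts), output_pts))
```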
@abstractmethod
def loss_phys(self, samples, equation):
@@ -196,13 +227,17 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
:param str name: The name of the loss.
:param torch.Tensor loss_value: The value of the loss.
"""
batch_size = self.trainer.data_module.batch_size \
if self.trainer.data_module.batch_size is not None else 999
self.log(
self.__logged_metric + "_loss",
loss_value,
prog_bar=True,
logger=True,
on_epoch=True,
on_step=False,
on_step=True,
batch_size=batch_size,
)
self.__logged_res_losses.append(loss_value)

View File

@@ -9,10 +9,8 @@ except ImportError:
_LRScheduler as LRScheduler,
) # torch < 2.0
from torch.optim.lr_scheduler import ConstantLR
from .basepinn import PINNInterface
from pina.utils import check_consistency
from pina.problem import InverseProblem
@@ -56,16 +54,16 @@ class PINN(PINNInterface):
DOI: `10.1038 <https://doi.org/10.1038/s42254-021-00314-5>`_.
"""
__name__ = 'PINN'
def __init__(
self,
problem,
model,
extra_features=None,
loss=torch.nn.MSELoss(),
optimizer=torch.optim.Adam,
optimizer_kwargs={"lr": 0.001},
scheduler=ConstantLR,
scheduler_kwargs={"factor": 1, "total_iters": 0},
loss=None,
optimizer=None,
scheduler=None,
):
"""
:param AbstractProblem problem: The formulation of the problem.
@@ -82,20 +80,15 @@ class PINN(PINNInterface):
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
"""
super().__init__(
models=[model],
models=model,
problem=problem,
optimizers=[optimizer],
optimizers_kwargs=[optimizer_kwargs],
optimizers=optimizer,
schedulers=scheduler,
extra_features=extra_features,
loss=loss,
)
# check consistency
check_consistency(scheduler, LRScheduler, subclass=True)
check_consistency(scheduler_kwargs, dict)
# assign variables
self._scheduler = scheduler(self.optimizers[0], **scheduler_kwargs)
self._neural_net = self.models[0]
def forward(self, x):
@@ -126,9 +119,8 @@ class PINN(PINNInterface):
"""
residual = self.compute_residual(samples=samples, equation=equation)
loss_value = self.loss(
torch.zeros_like(residual, requires_grad=True), residual
torch.zeros_like(residual), residual
)
self.store_log(loss_value=float(loss_value))
return loss_value
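Dropping `requires_grad=True` from the zero target is correct: gradients must flow through the residual (the network side), not through a constant target. A quick standalone check of that behavior:

```python
import torch

# Leaf tensor standing in for the PDE residual produced by the network.
residual = torch.rand(5, 1, requires_grad=True)
loss = torch.nn.MSELoss()(torch.zeros_like(residual), residual)
loss.backward()
assert residual.grad is not None  # the gradient reaches the residual side
```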
def configure_optimizers(self):
@@ -141,16 +133,21 @@ class PINN(PINNInterface):
"""
# if the problem is an InverseProblem, add the unknown parameters
# to the parameters that the optimizer needs to optimize
self._optimizer.hook(self._model.parameters())
if isinstance(self.problem, InverseProblem):
self.optimizers[0].add_param_group(
{
"params": [
self._params[var]
for var in self.problem.unknown_variables
]
}
)
return self.optimizers, [self.scheduler]
self._optimizer.optimizer_instance.add_param_group(
{
"params": [
self._params[var]
for var in self.problem.unknown_variables
]
}
)
self._scheduler.hook(self._optimizer)
return ([self._optimizer.optimizer_instance],
[self._scheduler.scheduler_instance])
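`configure_optimizers` now returns lazily hooked wrapper instances. `TorchOptimizer` and `TorchScheduler` are PINA classes; the sketch below is an illustrative reconstruction of the `hook`/`*_instance` protocol the diff exercises, not the library's actual implementation:

```python
import torch

class OptimizerWrapperSketch:
    """Defers torch.optim instantiation until the parameters are known."""
    def __init__(self, optimizer_class, **kwargs):
        self.optimizer_class, self.kwargs = optimizer_class, kwargs
        self.optimizer_instance = None

    def hook(self, parameters):
        self.optimizer_instance = self.optimizer_class(parameters, **self.kwargs)

class SchedulerWrapperSketch:
    """Defers LR-scheduler instantiation until the optimizer is hooked."""
    def __init__(self, scheduler_class, **kwargs):
        self.scheduler_class, self.kwargs = scheduler_class, kwargs
        self.scheduler_instance = None

    def hook(self, optimizer_wrapper):
        self.scheduler_instance = self.scheduler_class(
            optimizer_wrapper.optimizer_instance, **self.kwargs)

model = torch.nn.Linear(2, 1)
opt = OptimizerWrapperSketch(torch.optim.Adam, lr=0.001)
opt.hook(model.parameters())
sch = SchedulerWrapperSketch(torch.optim.lr_scheduler.ConstantLR, factor=0.5)
sch.hook(opt)
# Mirrors the return shape expected by Lightning:
optimizers, schedulers = [opt.optimizer_instance], [sch.scheduler_instance]
```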
@property
def scheduler(self):

View File

@@ -2,7 +2,7 @@
from abc import ABCMeta, abstractmethod
from ..model.network import Network
import pytorch_lightning
import lightning
from ..utils import check_consistency
from ..problem import AbstractProblem
from ..optim import Optimizer, Scheduler
@@ -10,7 +10,8 @@ import torch
import sys
class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
"""
Solver base class. This class is a wrapper of the
LightningModule class, inheriting all the
@@ -83,7 +84,6 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
" optimizers.")
# extra features handling
self._pina_models = models
self._pina_optimizers = optimizers
self._pina_schedulers = schedulers
@@ -94,7 +94,7 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
pass
@abstractmethod
def training_step(self, batch, batch_idx):
def training_step(self, batch):
pass
@abstractmethod
@@ -138,8 +138,16 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
Check that every condition declared by the problem is supported by this solver.
"""
for _, condition in problem.conditions.items():
if not set(self.accepted_condition_types).issubset(
condition.condition_type):
if not set(condition.condition_type).issubset(
set(self.accepted_condition_types)):
raise ValueError(
f'{self.__name__} support only dose not support condition '
f'{self.__name__} does not support condition '
f'{condition.condition_type}')
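Note the subset test also flips direction here: previously the solver's accepted types had to be a subset of the condition's types; now every type a condition declares must be among the solver's accepted types. In plain Python:

```python
accepted = {"supervised", "physics"}

# A condition passes iff all of its declared types are accepted.
assert set(["supervised"]).issubset(accepted)   # accepted
assert not set(["graph"]).issubset(accepted)    # would raise ValueError above
```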
@staticmethod
def get_batch_size(batch):
# batch yields (condition_name, points) pairs; accumulate input points
# across all conditions
batch_size = 0
for data in batch:
batch_size += len(data[1]['input_points'])
return batch_size
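`get_batch_size` walks the same `(condition_name, points)` pairs as the training loop and sums the input points per condition. Reusing the hypothetical batch layout sketched earlier:

```python
import torch

batch = [
    ("data", {"input_points": torch.rand(8, 2)}),
    ("physics", {"input_points": torch.rand(16, 2)}),
]
# Equivalent to what get_batch_size accumulates:
assert sum(len(p["input_points"]) for _, p in batch) == 24
```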

View File

@@ -1,12 +1,14 @@
""" Module for SupervisedSolver """
import torch
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch.nn.modules.loss import _Loss
from ..optim import TorchOptimizer, TorchScheduler
from .solver import SolverInterface
from ..label_tensor import LabelTensor
from ..utils import check_consistency
from ..loss.loss_interface import LossInterface
from ..condition import InputOutputPointsCondition
class SupervisedSolver(SolverInterface):
@@ -37,7 +39,7 @@ class SupervisedSolver(SolverInterface):
we are seeking to approximate multiple (discretised) functions given
multiple (discretised) input functions.
"""
accepted_condition_types = ['supervised']
accepted_condition_types = [InputOutputPointsCondition.condition_type[0]]
__name__ = 'SupervisedSolver'
def __init__(self,
@@ -46,7 +48,8 @@ class SupervisedSolver(SolverInterface):
loss=None,
optimizer=None,
scheduler=None,
extra_features=None):
extra_features=None,
use_lt=True):
"""
:param AbstractProblem problem: The formulation of the problem.
:param torch.nn.Module model: The neural network model to use.
@@ -72,14 +75,19 @@ class SupervisedSolver(SolverInterface):
problem=problem,
optimizers=optimizer,
schedulers=scheduler,
extra_features=extra_features)
extra_features=extra_features,
use_lt=use_lt)
# check consistency
check_consistency(loss, (LossInterface, _Loss), subclass=False)
check_consistency(loss, (LossInterface, _Loss, torch.nn.Module),
subclass=False)
self._loss = loss
self._model = self._pina_models[0]
self._optimizer = self._pina_optimizers[0]
self._scheduler = self._pina_schedulers[0]
self.validation_condition_losses = {
k: {'loss': [],
'count': []} for k in self.problem.conditions.keys()}
def forward(self, x):
"""Forward pass implementation for the solver.
@@ -105,7 +113,7 @@ class SupervisedSolver(SolverInterface):
return ([self._optimizer.optimizer_instance],
[self._scheduler.scheduler_instance])
def training_step(self, batch, batch_idx):
def training_step(self, batch):
"""Solver training step.
:param batch: The batch element in the dataloader.
@@ -115,33 +123,37 @@ class SupervisedSolver(SolverInterface):
:return: The sum of the loss functions.
:rtype: torch.Tensor
"""
condition_idx = batch.supervised.condition_indices
for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
condition_name = self._dataloader.condition_names[condition_id]
condition = self.problem.conditions[condition_name]
pts = batch.supervised.input_points
out = batch.supervised.output_points
if condition_name not in self.problem.conditions:
raise RuntimeError("Something wrong happened.")
# for data driven mode
if not hasattr(condition, "output_points"):
raise NotImplementedError(
f"{type(self).__name__} works only in data-driven mode.")
output_pts = out[condition_idx == condition_id]
input_pts = pts[condition_idx == condition_id]
input_pts.labels = pts.labels
output_pts.labels = out.labels
loss = self.loss_data(input_pts=input_pts, output_pts=output_pts)
loss = loss.as_subclass(torch.Tensor)
self.log("mean_loss", float(loss), prog_bar=True, logger=True)
condition_loss = []
for condition_name, points in batch:
input_pts, output_pts = points['input_points'], points['output_points']
loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
condition_loss.append(loss_.as_subclass(torch.Tensor))
loss = sum(condition_loss)
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True,
batch_size=self.get_batch_size(batch), sync_dist=True)
return loss
def validation_step(self, batch):
"""
Solver validation step.
"""
condition_loss = []
for condition_name, points in batch:
input_pts, output_pts = points['input_points'], points['output_points']
loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
condition_loss.append(loss_.as_subclass(torch.Tensor))
loss = sum(condition_loss)
self.log('val_loss', loss, prog_bar=True, logger=True,
batch_size=self.get_batch_size(batch), sync_dist=True)
def test_step(self, batch, batch_idx) -> STEP_OUTPUT:
"""
Solver test step.
"""
raise NotImplementedError("Test step not implemented yet.")
def loss_data(self, input_pts, output_pts):
"""
The data loss for the Supervised solver. It computes the loss between