Implementation of DataLoader and DataModule (#383)
Refactoring for 0.2

* Data module, data loader and dataset
* Refactor LabelTensor
* Refactor solvers

Co-authored-by: dario-coscia <dariocos99@gmail.com>
committed by Nicola Demo
parent dd43c8304c
commit a27bd35443
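For context, a minimal usage sketch of the refactored solver stack is shown below. It is not part of this commit: the import paths, the `Trainer` entry point and the `batch_size` routing to the new data module are assumptions inferred from this diff, while the `use_lt` flag and the `SupervisedSolver` constructor come from the changes themselves.

```python
# Hypothetical usage sketch (not part of this commit): how the refactored
# SupervisedSolver and the new data pipeline are expected to fit together.
import torch
from pina.solvers import SupervisedSolver   # assumed import path
from pina.trainer import Trainer            # assumed import path

# `problem` stands for any data-driven PINA problem with input/output points.
solver = SupervisedSolver(
    problem=problem,
    model=torch.nn.Linear(2, 1),   # toy placeholder model
    use_lt=True,                   # flag introduced by this refactor
)

# The Trainer is assumed to build the new DataModule/DataLoader internally and
# to hand batches of the form [(condition_name, {'input_points': ...,
# 'output_points': ...}), ...] to the solver's training_step.
trainer = Trainer(solver, batch_size=32, max_epochs=100)
trainer.train()
```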
@@ -1,34 +0,0 @@
from .supervised import SupervisedSolver
from ..graph import Graph


class GraphSupervisedSolver(SupervisedSolver):

    def __init__(
            self,
            problem,
            model,
            nodes_coordinates,
            nodes_data,
            loss=None,
            optimizer=None,
            scheduler=None):
        super().__init__(problem, model, loss, optimizer, scheduler)
        if isinstance(nodes_coordinates, str):
            self._nodes_coordinates = [nodes_coordinates]
        else:
            self._nodes_coordinates = nodes_coordinates
        if isinstance(nodes_data, str):
            self._nodes_data = nodes_data
        else:
            self._nodes_data = nodes_data

    def forward(self, input):
        input_coords = input.extract(self._nodes_coordinates)
        input_data = input.extract(self._nodes_data)

        if not isinstance(input, Graph):
            input = Graph.build('radius', nodes_coordinates=input_coords, nodes_data=input_data, radius=0.2)
        g = self.model(input.data, edge_index=input.data.edge_index)
        g.labels = {1: {'name': 'output', 'dof': ['u']}}
        return g
@@ -1,14 +1,15 @@
""" Module for PINN """

import sys
from abc import ABCMeta, abstractmethod
import torch

from ...solvers.solver import SolverInterface
from pina.utils import check_consistency
from pina.loss.loss_interface import LossInterface
from pina.problem import InverseProblem
from torch.nn.modules.loss import _Loss
from ...condition import InputOutputPointsCondition
from ...solvers.solver import SolverInterface
from ...utils import check_consistency
from ...loss.loss_interface import LossInterface
from ...problem import InverseProblem
from ...condition import DomainEquationCondition
from ...optim import TorchOptimizer, TorchScheduler

torch.pi = torch.acos(torch.zeros(1)).item() * 2  # which is 3.1415927410125732
@@ -25,13 +26,14 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
    to the user to choose which problem the implemented solver inheriting from
    this class is suitable for.
    """

    accepted_condition_types = [DomainEquationCondition.condition_type[0],
                                InputOutputPointsCondition.condition_type[0]]
    def __init__(
        self,
        models,
        problem,
        optimizers,
        optimizers_kwargs,
        schedulers,
        extra_features,
        loss,
    ):
@@ -53,11 +55,20 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
        :param torch.nn.Module loss: The loss function used as minimizer,
            default :class:`torch.nn.MSELoss`.
        """
        if optimizers is None:
            optimizers = TorchOptimizer(torch.optim.Adam, lr=0.001)

        if schedulers is None:
            schedulers = TorchScheduler(torch.optim.lr_scheduler.ConstantLR)

        if loss is None:
            loss = torch.nn.MSELoss()

        super().__init__(
            models=models,
            problem=problem,
            optimizers=optimizers,
            optimizers_kwargs=optimizers_kwargs,
            schedulers=schedulers,
            extra_features=extra_features,
        )
@@ -85,7 +96,12 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
        # variable will be stored with name = self.__logged_metric
        self.__logged_metric = None

    def training_step(self, batch, _):
        self._model = self._pina_models[0]
        self._optimizer = self._pina_optimizers[0]
        self._scheduler = self._pina_schedulers[0]


    def training_step(self, batch):
        """
        The Physics Informed Solver training step. This function takes care
        of the physics informed training step, and it must not be overridden
@@ -99,53 +115,68 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
        :rtype: LabelTensor
        """

        condition_losses = []
        condition_idx = batch["condition"]
        condition_loss = []
        for condition_name, points in batch:
            if 'output_points' in points:
                input_pts, output_pts = points['input_points'], points['output_points']

        for condition_id in range(condition_idx.min(), condition_idx.max() + 1):

            condition_name = self._dataloader.condition_names[condition_id]
            condition = self.problem.conditions[condition_name]
            pts = batch["pts"]
            # condition name is logged (if logs enabled)
            self.__logged_metric = condition_name

            if len(batch) == 2:
                samples = pts[condition_idx == condition_id]
                loss = self.loss_phys(samples, condition.equation)
            elif len(batch) == 3:
                samples = pts[condition_idx == condition_id]
                ground_truth = batch["output"][condition_idx == condition_id]
                loss = self.loss_data(samples, ground_truth)
                loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
                condition_loss.append(loss_.as_subclass(torch.Tensor))
            else:
                raise ValueError("Batch size not supported")
                input_pts = points['input_points']

            # add condition losses for each epoch
            condition_losses.append(loss * condition.data_weight)
                condition = self.problem.conditions[condition_name]

                loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation)
                condition_loss.append(loss_.as_subclass(torch.Tensor))
                condition_loss.append(loss_.as_subclass(torch.Tensor))
        # clamp unknown parameters in InverseProblem (if needed)
        self._clamp_params()
        loss = sum(condition_loss)
        self.log('train_loss', loss, prog_bar=True, on_epoch=True,
                 logger=True, batch_size=self.get_batch_size(batch),
                 sync_dist=True)

        # total loss (must be a torch.Tensor), and logs
        total_loss = sum(condition_losses)
        self.save_logs_and_release()
        return total_loss.as_subclass(torch.Tensor)
        return loss

    def loss_data(self, input_tensor, output_tensor):
    def validation_step(self, batch):
        """
        TODO: add docstring
        """
        condition_loss = []
        for condition_name, points in batch:
            if 'output_points' in points:
                input_pts, output_pts = points['input_points'], points['output_points']
                loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
                condition_loss.append(loss_.as_subclass(torch.Tensor))
            else:
                input_pts = points['input_points']

                condition = self.problem.conditions[condition_name]
                with torch.set_grad_enabled(True):
                    loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation)
                condition_loss.append(loss_.as_subclass(torch.Tensor))
                condition_loss.append(loss_.as_subclass(torch.Tensor))
        # clamp unknown parameters in InverseProblem (if needed)

        loss = sum(condition_loss)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True,
                 logger=True, batch_size=self.get_batch_size(batch),
                 sync_dist=True)

    def loss_data(self, input_pts, output_pts):
        """
        The data loss for the PINN solver. It computes the loss between
        the network output and the true solution. This function
        should not be overridden unless intentionally.

        :param LabelTensor input_tensor: The input to the neural networks.
        :param LabelTensor output_tensor: The true solution to compare the
        :param LabelTensor input_pts: The input to the neural networks.
        :param LabelTensor output_pts: The true solution to compare the
            network solution.
        :return: The residual loss averaged on the input coordinates
        :rtype: torch.Tensor
        """
        loss_value = self.loss(self.forward(input_tensor), output_tensor)
        self.store_log(loss_value=float(loss_value))
        return self.loss(self.forward(input_tensor), output_tensor)
        return self._loss(self.forward(input_pts), output_pts)

    @abstractmethod
    def loss_phys(self, samples, equation):
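To make the new control flow concrete, below is a small self-contained mock (plain PyTorch, no PINA imports) of the batch layout the rewritten training_step/validation_step iterate over: a sequence of (condition_name, points) pairs where points carries 'input_points' and, for data-driven conditions, 'output_points'. The dictionary keys are taken from the diff above; the model, data and physics residual are purely illustrative.

```python
import torch

# Mock batch in the shape the new steps iterate over (inferred from the diff):
# data-driven conditions carry 'output_points', physics conditions do not.
batch = [
    ("data", {"input_points": torch.rand(8, 2),
              "output_points": torch.rand(8, 1)}),
    ("physics", {"input_points": torch.rand(16, 2)}),
]

mse = torch.nn.MSELoss()
model = torch.nn.Linear(2, 1)

condition_loss = []
for condition_name, points in batch:
    if "output_points" in points:
        # supervised-style condition: compare prediction with ground truth
        loss_ = mse(model(points["input_points"]), points["output_points"])
    else:
        # physics-style condition: a dummy residual stands in for loss_phys
        residual = model(points["input_points"].requires_grad_())
        loss_ = mse(residual, torch.zeros_like(residual))
    condition_loss.append(loss_)

loss = sum(condition_loss)   # the total loss returned by training_step
print(float(loss))
```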
@@ -196,13 +227,17 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
        :param str name: The name of the loss.
        :param torch.Tensor loss_value: The value of the loss.
        """
        batch_size = self.trainer.data_module.batch_size \
            if self.trainer.data_module.batch_size is not None else 999

        self.log(
            self.__logged_metric + "_loss",
            loss_value,
            prog_bar=True,
            logger=True,
            on_epoch=True,
            on_step=False,
            on_step=True,
            batch_size=batch_size,
        )
        self.__logged_res_losses.append(loss_value)
@@ -9,10 +9,8 @@ except ImportError:
        _LRScheduler as LRScheduler,
    )  # torch < 2.0

from torch.optim.lr_scheduler import ConstantLR

from .basepinn import PINNInterface
from pina.utils import check_consistency
from pina.problem import InverseProblem
@@ -56,16 +54,16 @@ class PINN(PINNInterface):
    DOI: `10.1038 <https://doi.org/10.1038/s42254-021-00314-5>`_.
    """

    __name__ = 'PINN'

    def __init__(
        self,
        problem,
        model,
        extra_features=None,
        loss=torch.nn.MSELoss(),
        optimizer=torch.optim.Adam,
        optimizer_kwargs={"lr": 0.001},
        scheduler=ConstantLR,
        scheduler_kwargs={"factor": 1, "total_iters": 0},
        loss=None,
        optimizer=None,
        scheduler=None,
    ):
        """
        :param AbstractProblem problem: The formulation of the problem.
@@ -82,20 +80,15 @@ class PINN(PINNInterface):
        :param dict scheduler_kwargs: LR scheduler constructor keyword args.
        """
        super().__init__(
            models=[model],
            models=model,
            problem=problem,
            optimizers=[optimizer],
            optimizers_kwargs=[optimizer_kwargs],
            optimizers=optimizer,
            schedulers=scheduler,
            extra_features=extra_features,
            loss=loss,
        )

        # check consistency
        check_consistency(scheduler, LRScheduler, subclass=True)
        check_consistency(scheduler_kwargs, dict)

        # assign variables
        self._scheduler = scheduler(self.optimizers[0], **scheduler_kwargs)
        self._neural_net = self.models[0]

    def forward(self, x):
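A sketch of how a PINN is constructed after this change: optimizer and scheduler are now passed as PINA wrapper objects (or left as None for the defaults shown in PINNInterface), instead of raw torch classes plus separate kwargs dictionaries. The wrapper names and keyword arguments below mirror the defaulting code in this diff; the import paths and the `problem` object are assumptions/placeholders.

```python
import torch
from pina.optim import TorchOptimizer, TorchScheduler  # wrappers used in this diff
from pina.solvers import PINN                           # assumed import path

# `problem` is a placeholder for any PINA problem definition.
solver = PINN(
    problem=problem,
    model=torch.nn.Sequential(torch.nn.Linear(2, 20), torch.nn.Tanh(),
                              torch.nn.Linear(20, 1)),
    optimizer=TorchOptimizer(torch.optim.Adam, lr=0.001),
    scheduler=TorchScheduler(torch.optim.lr_scheduler.ConstantLR),
    loss=torch.nn.MSELoss(),
)
```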
@@ -126,9 +119,8 @@ class PINN(PINNInterface):
        """
        residual = self.compute_residual(samples=samples, equation=equation)
        loss_value = self.loss(
            torch.zeros_like(residual, requires_grad=True), residual
            torch.zeros_like(residual), residual
        )
        self.store_log(loss_value=float(loss_value))
        return loss_value

    def configure_optimizers(self):
@@ -141,16 +133,21 @@ class PINN(PINNInterface):
        """
        # if the problem is an InverseProblem, add the unknown parameters
        # to the parameters that the optimizer needs to optimize

        self._optimizer.hook(self._model.parameters())
        if isinstance(self.problem, InverseProblem):
            self.optimizers[0].add_param_group(
                {
                    "params": [
                        self._params[var]
                        for var in self.problem.unknown_variables
                    ]
                }
            )
        return self.optimizers, [self.scheduler]
            self._optimizer.optimizer_instance.add_param_group(
                {
                    "params": [
                        self._params[var]
                        for var in self.problem.unknown_variables
                    ]
                }
            )
        self._scheduler.hook(self._optimizer)
        return ([self._optimizer.optimizer_instance],
                [self._scheduler.scheduler_instance])

    @property
    def scheduler(self):
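configure_optimizers above relies on the wrapper objects exposing a hook(...) method and an optimizer_instance / scheduler_instance attribute. The diff does not show their implementation; a minimal stand-in consistent with how they are called here might look as follows (class names match the diff, but the bodies are assumptions, not the actual pina.optim code).

```python
import torch


class TorchOptimizer:
    """Minimal stand-in: defers optimizer construction until hook() is called."""

    def __init__(self, optimizer_class, **kwargs):
        self.optimizer_class = optimizer_class
        self.kwargs = kwargs
        self.optimizer_instance = None

    def hook(self, parameters):
        # build the torch optimizer on the model parameters
        self.optimizer_instance = self.optimizer_class(parameters, **self.kwargs)


class TorchScheduler:
    """Minimal stand-in: defers scheduler construction until hook() is called."""

    def __init__(self, scheduler_class, **kwargs):
        self.scheduler_class = scheduler_class
        self.kwargs = kwargs
        self.scheduler_instance = None

    def hook(self, optimizer):
        # attach the scheduler to the already-built optimizer instance
        self.scheduler_instance = self.scheduler_class(
            optimizer.optimizer_instance, **self.kwargs)
```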
@@ -2,7 +2,7 @@

from abc import ABCMeta, abstractmethod
from ..model.network import Network
import pytorch_lightning
import lightning
from ..utils import check_consistency
from ..problem import AbstractProblem
from ..optim import Optimizer, Scheduler
@@ -10,7 +10,8 @@ import torch
import sys


class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):

class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
    """
    Solver base class. This class is a wrapper of the
    LightningModule class, inheriting all the
@@ -83,7 +84,6 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
                " optimizers.")

        # extra features handling

        self._pina_models = models
        self._pina_optimizers = optimizers
        self._pina_schedulers = schedulers
@@ -94,7 +94,7 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
        pass

    @abstractmethod
    def training_step(self, batch, batch_idx):
    def training_step(self, batch):
        pass

    @abstractmethod
@@ -138,8 +138,16 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
        TODO
        """
        for _, condition in problem.conditions.items():
            if not set(self.accepted_condition_types).issubset(
                    condition.condition_type):
            if not set(condition.condition_type).issubset(
                    set(self.accepted_condition_types)):
                raise ValueError(
                    f'{self.__name__} support only dose not support condition '
                    f'{self.__name__} does not support condition '
                    f'{condition.condition_type}')

    @staticmethod
    def get_batch_size(batch):
        # Assuming batch is your custom Batch object
        batch_size = 0
        for data in batch:
            batch_size += len(data[1]['input_points'])
        return batch_size
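Given the batch layout assumed throughout this diff (pairs of a condition name and a points dictionary), get_batch_size simply sums the number of input points over all conditions. A tiny self-contained illustration, with made-up tensor shapes:

```python
import torch

batch = [
    ("data", {"input_points": torch.rand(8, 2), "output_points": torch.rand(8, 1)}),
    ("physics", {"input_points": torch.rand(16, 2)}),
]

batch_size = 0
for data in batch:                      # data = (condition_name, points_dict)
    batch_size += len(data[1]["input_points"])
print(batch_size)                       # 24
```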
@@ -1,12 +1,14 @@
""" Module for SupervisedSolver """

import torch
from pytorch_lightning.utilities.types import STEP_OUTPUT
from sympy.strategies.branch import condition
from torch.nn.modules.loss import _Loss
from ..optim import TorchOptimizer, TorchScheduler
from .solver import SolverInterface
from ..label_tensor import LabelTensor
from ..utils import check_consistency
from ..loss.loss_interface import LossInterface
from ..condition import InputOutputPointsCondition


class SupervisedSolver(SolverInterface):
@@ -37,7 +39,7 @@ class SupervisedSolver(SolverInterface):
    we are seeking to approximate multiple (discretised) functions given
    multiple (discretised) input functions.
    """
    accepted_condition_types = ['supervised']
    accepted_condition_types = [InputOutputPointsCondition.condition_type[0]]
    __name__ = 'SupervisedSolver'

    def __init__(self,
@@ -46,7 +48,8 @@ class SupervisedSolver(SolverInterface):
                 loss=None,
                 optimizer=None,
                 scheduler=None,
                 extra_features=None):
                 extra_features=None,
                 use_lt=True):
        """
        :param AbstractProblem problem: The formulation of the problem.
        :param torch.nn.Module model: The neural network model to use.
@@ -72,14 +75,19 @@ class SupervisedSolver(SolverInterface):
            problem=problem,
            optimizers=optimizer,
            schedulers=scheduler,
            extra_features=extra_features)
            extra_features=extra_features,
            use_lt=use_lt)

        # check consistency
        check_consistency(loss, (LossInterface, _Loss), subclass=False)
        check_consistency(loss, (LossInterface, _Loss, torch.nn.Module),
                          subclass=False)
        self._loss = loss
        self._model = self._pina_models[0]
        self._optimizer = self._pina_optimizers[0]
        self._scheduler = self._pina_schedulers[0]
        self.validation_condition_losses = {
            k: {'loss': [],
                'count': []} for k in self.problem.conditions.keys()}

    def forward(self, x):
        """Forward pass implementation for the solver.
@@ -105,7 +113,7 @@ class SupervisedSolver(SolverInterface):
        return ([self._optimizer.optimizer_instance],
                [self._scheduler.scheduler_instance])

    def training_step(self, batch, batch_idx):
    def training_step(self, batch):
        """Solver training step.

        :param batch: The batch element in the dataloader.
@@ -115,33 +123,37 @@ class SupervisedSolver(SolverInterface):
        :return: The sum of the loss functions.
        :rtype: LabelTensor
        """
        condition_idx = batch.supervised.condition_indices

        for condition_id in range(condition_idx.min(), condition_idx.max() + 1):

            condition_name = self._dataloader.condition_names[condition_id]
            condition = self.problem.conditions[condition_name]
            pts = batch.supervised.input_points
            out = batch.supervised.output_points
            if condition_name not in self.problem.conditions:
                raise RuntimeError("Something wrong happened.")

            # for data driven mode
            if not hasattr(condition, "output_points"):
                raise NotImplementedError(
                    f"{type(self).__name__} works only in data-driven mode.")
            output_pts = out[condition_idx == condition_id]
            input_pts = pts[condition_idx == condition_id]

            input_pts.labels = pts.labels
            output_pts.labels = out.labels

            loss = self.loss_data(input_pts=input_pts, output_pts=output_pts)
            loss = loss.as_subclass(torch.Tensor)

        self.log("mean_loss", float(loss), prog_bar=True, logger=True)
        condition_loss = []
        for condition_name, points in batch:
            input_pts, output_pts = points['input_points'], points['output_points']
            loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
            condition_loss.append(loss_.as_subclass(torch.Tensor))
        loss = sum(condition_loss)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True,
                 batch_size=self.get_batch_size(batch), sync_dist=True)
        return loss

    def validation_step(self, batch):
        """
        Solver validation step.
        """
        condition_loss = []
        for condition_name, points in batch:
            input_pts, output_pts = points['input_points'], points['output_points']
            loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
            condition_loss.append(loss_.as_subclass(torch.Tensor))
        loss = sum(condition_loss)
        self.log('val_loss', loss, prog_bar=True, logger=True,
                 batch_size=self.get_batch_size(batch), sync_dist=True)


    def test_step(self, batch, batch_idx) -> STEP_OUTPUT:
        """
        Solver test step.
        """

        raise NotImplementedError("Test step not implemented yet.")

    def loss_data(self, input_pts, output_pts):
        """
        The data loss for the Supervised solver. It computes the loss between