Improve conditions and refactor dataset classes (#475)

* Reimplement conditions

* Refactor datasets and implement LabelBatch

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-07 11:24:09 +01:00
committed by Nicola Demo
parent bdad144461
commit a0cbf1c44a
40 changed files with 943 additions and 550 deletions

View File

@@ -5,7 +5,7 @@ import torch
from .solver import MultiSolverInterface
from ..utils import check_consistency
from ..loss.loss_interface import LossInterface
from ..condition import InputOutputPointsCondition
from ..condition import InputTargetCondition
from ..utils import check_consistency
from ..loss import LossInterface, PowerLoss
from torch.nn.modules.loss import _Loss
@@ -25,7 +25,7 @@ class GAROM(MultiSolverInterface):
<https://doi.org/10.48550/arXiv.2305.15881>`_.
"""
accepted_conditions_types = InputOutputPointsCondition
accepted_conditions_types = InputTargetCondition
def __init__(
self,
@@ -70,8 +70,8 @@ class GAROM(MultiSolverInterface):
.. warning::
The algorithm works only for data-driven model. Hence in the ``problem`` definition
the codition must only contain ``input_points`` (e.g. coefficient parameters, time
parameters), and ``output_points``.
the codition must only contain ``input`` (e.g. coefficient parameters, time
parameters), and ``target``.
"""
# set loss
@@ -233,8 +233,8 @@ class GAROM(MultiSolverInterface):
condition_loss = {}
for condition_name, points in batch:
parameters, snapshots = (
points["input_points"],
points["output_points"],
points["input"],
points["target"],
)
d_loss_real, d_loss_fake, d_loss = self._train_discriminator(
parameters, snapshots
@@ -257,8 +257,8 @@ class GAROM(MultiSolverInterface):
condition_loss = {}
for condition_name, points in batch:
parameters, snapshots = (
points["input_points"],
points["output_points"],
points["input"],
points["target"],
)
snapshots_gen = self.generator(parameters)
condition_loss[condition_name] = self._loss(
@@ -272,8 +272,8 @@ class GAROM(MultiSolverInterface):
condition_loss = {}
for condition_name, points in batch:
parameters, snapshots = (
points["input_points"],
points["output_points"],
points["input"],
points["target"],
)
snapshots_gen = self.generator(parameters)
condition_loss[condition_name] = self._loss(

View File

@@ -9,8 +9,8 @@ from ...utils import check_consistency
from ...loss.loss_interface import LossInterface
from ...problem import InverseProblem
from ...condition import (
InputOutputPointsCondition,
InputPointsEquationCondition,
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition,
)
@@ -28,8 +28,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
"""
accepted_conditions_types = (
InputOutputPointsCondition,
InputPointsEquationCondition,
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition,
)
@@ -138,16 +138,16 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
for condition_name, points in batch:
self.__metric = condition_name
# if equations are passed
if "output_points" not in points:
input_pts = points["input_points"]
if "target" not in points:
input_pts = points["input"]
condition = self.problem.conditions[condition_name]
loss = loss_residuals(
input_pts.requires_grad_(), condition.equation
)
# if data are passed
else:
input_pts = points["input_points"]
output_pts = points["output_points"]
input_pts = points["input"]
output_pts = points["target"]
loss = self.loss_data(
input_pts=input_pts.requires_grad_(), output_pts=output_pts
)

View File

@@ -262,7 +262,7 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
for (
condition_name,
tensor,
) in self.trainer.data_module.train_dataset.input_points.items():
) in self.trainer.data_module.train_dataset.input.items():
self.weights_dict[condition_name].sa_weights.data = torch.rand(
(tensor.shape[0], 1), device=device
)

View File

@@ -75,8 +75,8 @@ class ReducedOrderModelSolver(SupervisedSolver):
.. warning::
This solver works only for data-driven model. Hence in the ``problem``
definition the codition must only contain ``input_points``
(e.g. coefficient parameters, time parameters), and ``output_points``.
definition the codition must only contain ``input``
(e.g. coefficient parameters, time parameters), and ``target``.
.. warning::
This solver does not currently support the possibility to pass

View File

@@ -172,7 +172,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
# assuming batch is a custom Batch object
batch_size = 0
for data in batch:
batch_size += len(data[1]["input_points"])
batch_size += len(data[1]["input"])
return batch_size
@staticmethod

View File

@@ -5,7 +5,7 @@ from torch.nn.modules.loss import _Loss
from .solver import SingleSolverInterface
from ..utils import check_consistency
from ..loss.loss_interface import LossInterface
from ..condition import InputOutputPointsCondition
from ..condition import InputTargetCondition
class SupervisedSolver(SingleSolverInterface):
@@ -37,7 +37,7 @@ class SupervisedSolver(SingleSolverInterface):
multiple (discretised) input functions.
"""
accepted_conditions_types = InputOutputPointsCondition
accepted_conditions_types = InputTargetCondition
def __init__(
self,
@@ -95,8 +95,8 @@ class SupervisedSolver(SingleSolverInterface):
condition_loss = {}
for condition_name, points in batch:
input_pts, output_pts = (
points["input_points"],
points["output_points"],
points["input"],
points["target"],
)
condition_loss[condition_name] = self.loss_data(
input_pts=input_pts, output_pts=output_pts