Improve conditions and refactor dataset classes (#475)
* Reimplement conditions * Refactor datasets and implement LabelBatch --------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
bdad144461
commit
a0cbf1c44a
@@ -5,7 +5,7 @@ import torch
|
||||
from .solver import MultiSolverInterface
|
||||
from ..utils import check_consistency
|
||||
from ..loss.loss_interface import LossInterface
|
||||
from ..condition import InputOutputPointsCondition
|
||||
from ..condition import InputTargetCondition
|
||||
from ..utils import check_consistency
|
||||
from ..loss import LossInterface, PowerLoss
|
||||
from torch.nn.modules.loss import _Loss
|
||||
@@ -25,7 +25,7 @@ class GAROM(MultiSolverInterface):
|
||||
<https://doi.org/10.48550/arXiv.2305.15881>`_.
|
||||
"""
|
||||
|
||||
accepted_conditions_types = InputOutputPointsCondition
|
||||
accepted_conditions_types = InputTargetCondition
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -70,8 +70,8 @@ class GAROM(MultiSolverInterface):
|
||||
|
||||
.. warning::
|
||||
The algorithm works only for data-driven model. Hence in the ``problem`` definition
|
||||
the codition must only contain ``input_points`` (e.g. coefficient parameters, time
|
||||
parameters), and ``output_points``.
|
||||
the codition must only contain ``input`` (e.g. coefficient parameters, time
|
||||
parameters), and ``target``.
|
||||
"""
|
||||
|
||||
# set loss
|
||||
@@ -233,8 +233,8 @@ class GAROM(MultiSolverInterface):
|
||||
condition_loss = {}
|
||||
for condition_name, points in batch:
|
||||
parameters, snapshots = (
|
||||
points["input_points"],
|
||||
points["output_points"],
|
||||
points["input"],
|
||||
points["target"],
|
||||
)
|
||||
d_loss_real, d_loss_fake, d_loss = self._train_discriminator(
|
||||
parameters, snapshots
|
||||
@@ -257,8 +257,8 @@ class GAROM(MultiSolverInterface):
|
||||
condition_loss = {}
|
||||
for condition_name, points in batch:
|
||||
parameters, snapshots = (
|
||||
points["input_points"],
|
||||
points["output_points"],
|
||||
points["input"],
|
||||
points["target"],
|
||||
)
|
||||
snapshots_gen = self.generator(parameters)
|
||||
condition_loss[condition_name] = self._loss(
|
||||
@@ -272,8 +272,8 @@ class GAROM(MultiSolverInterface):
|
||||
condition_loss = {}
|
||||
for condition_name, points in batch:
|
||||
parameters, snapshots = (
|
||||
points["input_points"],
|
||||
points["output_points"],
|
||||
points["input"],
|
||||
points["target"],
|
||||
)
|
||||
snapshots_gen = self.generator(parameters)
|
||||
condition_loss[condition_name] = self._loss(
|
||||
|
||||
@@ -9,8 +9,8 @@ from ...utils import check_consistency
|
||||
from ...loss.loss_interface import LossInterface
|
||||
from ...problem import InverseProblem
|
||||
from ...condition import (
|
||||
InputOutputPointsCondition,
|
||||
InputPointsEquationCondition,
|
||||
InputTargetCondition,
|
||||
InputEquationCondition,
|
||||
DomainEquationCondition,
|
||||
)
|
||||
|
||||
@@ -28,8 +28,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
"""
|
||||
|
||||
accepted_conditions_types = (
|
||||
InputOutputPointsCondition,
|
||||
InputPointsEquationCondition,
|
||||
InputTargetCondition,
|
||||
InputEquationCondition,
|
||||
DomainEquationCondition,
|
||||
)
|
||||
|
||||
@@ -138,16 +138,16 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
for condition_name, points in batch:
|
||||
self.__metric = condition_name
|
||||
# if equations are passed
|
||||
if "output_points" not in points:
|
||||
input_pts = points["input_points"]
|
||||
if "target" not in points:
|
||||
input_pts = points["input"]
|
||||
condition = self.problem.conditions[condition_name]
|
||||
loss = loss_residuals(
|
||||
input_pts.requires_grad_(), condition.equation
|
||||
)
|
||||
# if data are passed
|
||||
else:
|
||||
input_pts = points["input_points"]
|
||||
output_pts = points["output_points"]
|
||||
input_pts = points["input"]
|
||||
output_pts = points["target"]
|
||||
loss = self.loss_data(
|
||||
input_pts=input_pts.requires_grad_(), output_pts=output_pts
|
||||
)
|
||||
|
||||
@@ -262,7 +262,7 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
|
||||
for (
|
||||
condition_name,
|
||||
tensor,
|
||||
) in self.trainer.data_module.train_dataset.input_points.items():
|
||||
) in self.trainer.data_module.train_dataset.input.items():
|
||||
self.weights_dict[condition_name].sa_weights.data = torch.rand(
|
||||
(tensor.shape[0], 1), device=device
|
||||
)
|
||||
|
||||
@@ -75,8 +75,8 @@ class ReducedOrderModelSolver(SupervisedSolver):
|
||||
|
||||
.. warning::
|
||||
This solver works only for data-driven model. Hence in the ``problem``
|
||||
definition the codition must only contain ``input_points``
|
||||
(e.g. coefficient parameters, time parameters), and ``output_points``.
|
||||
definition the codition must only contain ``input``
|
||||
(e.g. coefficient parameters, time parameters), and ``target``.
|
||||
|
||||
.. warning::
|
||||
This solver does not currently support the possibility to pass
|
||||
|
||||
@@ -172,7 +172,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
# assuming batch is a custom Batch object
|
||||
batch_size = 0
|
||||
for data in batch:
|
||||
batch_size += len(data[1]["input_points"])
|
||||
batch_size += len(data[1]["input"])
|
||||
return batch_size
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -5,7 +5,7 @@ from torch.nn.modules.loss import _Loss
|
||||
from .solver import SingleSolverInterface
|
||||
from ..utils import check_consistency
|
||||
from ..loss.loss_interface import LossInterface
|
||||
from ..condition import InputOutputPointsCondition
|
||||
from ..condition import InputTargetCondition
|
||||
|
||||
|
||||
class SupervisedSolver(SingleSolverInterface):
|
||||
@@ -37,7 +37,7 @@ class SupervisedSolver(SingleSolverInterface):
|
||||
multiple (discretised) input functions.
|
||||
"""
|
||||
|
||||
accepted_conditions_types = InputOutputPointsCondition
|
||||
accepted_conditions_types = InputTargetCondition
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -95,8 +95,8 @@ class SupervisedSolver(SingleSolverInterface):
|
||||
condition_loss = {}
|
||||
for condition_name, points in batch:
|
||||
input_pts, output_pts = (
|
||||
points["input_points"],
|
||||
points["output_points"],
|
||||
points["input"],
|
||||
points["target"],
|
||||
)
|
||||
condition_loss[condition_name] = self.loss_data(
|
||||
input_pts=input_pts, output_pts=output_pts
|
||||
|
||||
Reference in New Issue
Block a user