Improve conditions and refactor dataset classes (#475)

* Reimplement conditions

* Refactor datasets and implement LabelBatch

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-07 11:24:09 +01:00
committed by Nicola Demo
parent bdad144461
commit a0cbf1c44a
40 changed files with 943 additions and 550 deletions

View File

@@ -5,7 +5,7 @@ import torch
from .solver import MultiSolverInterface
from ..utils import check_consistency
from ..loss.loss_interface import LossInterface
from ..condition import InputOutputPointsCondition
from ..condition import InputTargetCondition
from ..utils import check_consistency
from ..loss import LossInterface, PowerLoss
from torch.nn.modules.loss import _Loss
@@ -25,7 +25,7 @@ class GAROM(MultiSolverInterface):
<https://doi.org/10.48550/arXiv.2305.15881>`_.
"""
accepted_conditions_types = InputOutputPointsCondition
accepted_conditions_types = InputTargetCondition
def __init__(
self,
@@ -70,8 +70,8 @@ class GAROM(MultiSolverInterface):
.. warning::
The algorithm works only for data-driven model. Hence in the ``problem`` definition
the codition must only contain ``input_points`` (e.g. coefficient parameters, time
parameters), and ``output_points``.
the codition must only contain ``input`` (e.g. coefficient parameters, time
parameters), and ``target``.
"""
# set loss
@@ -233,8 +233,8 @@ class GAROM(MultiSolverInterface):
condition_loss = {}
for condition_name, points in batch:
parameters, snapshots = (
points["input_points"],
points["output_points"],
points["input"],
points["target"],
)
d_loss_real, d_loss_fake, d_loss = self._train_discriminator(
parameters, snapshots
@@ -257,8 +257,8 @@ class GAROM(MultiSolverInterface):
condition_loss = {}
for condition_name, points in batch:
parameters, snapshots = (
points["input_points"],
points["output_points"],
points["input"],
points["target"],
)
snapshots_gen = self.generator(parameters)
condition_loss[condition_name] = self._loss(
@@ -272,8 +272,8 @@ class GAROM(MultiSolverInterface):
condition_loss = {}
for condition_name, points in batch:
parameters, snapshots = (
points["input_points"],
points["output_points"],
points["input"],
points["target"],
)
snapshots_gen = self.generator(parameters)
condition_loss[condition_name] = self._loss(