Improve conditions and refactor dataset classes (#475)
* Reimplement conditions * Refactor datasets and implement LabelBatch --------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
bdad144461
commit
a0cbf1c44a
@@ -5,7 +5,7 @@ import torch
|
||||
from .solver import MultiSolverInterface
|
||||
from ..utils import check_consistency
|
||||
from ..loss.loss_interface import LossInterface
|
||||
from ..condition import InputOutputPointsCondition
|
||||
from ..condition import InputTargetCondition
|
||||
from ..utils import check_consistency
|
||||
from ..loss import LossInterface, PowerLoss
|
||||
from torch.nn.modules.loss import _Loss
|
||||
@@ -25,7 +25,7 @@ class GAROM(MultiSolverInterface):
|
||||
<https://doi.org/10.48550/arXiv.2305.15881>`_.
|
||||
"""
|
||||
|
||||
accepted_conditions_types = InputOutputPointsCondition
|
||||
accepted_conditions_types = InputTargetCondition
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -70,8 +70,8 @@ class GAROM(MultiSolverInterface):
|
||||
|
||||
.. warning::
|
||||
The algorithm works only for data-driven model. Hence in the ``problem`` definition
|
||||
the codition must only contain ``input_points`` (e.g. coefficient parameters, time
|
||||
parameters), and ``output_points``.
|
||||
the codition must only contain ``input`` (e.g. coefficient parameters, time
|
||||
parameters), and ``target``.
|
||||
"""
|
||||
|
||||
# set loss
|
||||
@@ -233,8 +233,8 @@ class GAROM(MultiSolverInterface):
|
||||
condition_loss = {}
|
||||
for condition_name, points in batch:
|
||||
parameters, snapshots = (
|
||||
points["input_points"],
|
||||
points["output_points"],
|
||||
points["input"],
|
||||
points["target"],
|
||||
)
|
||||
d_loss_real, d_loss_fake, d_loss = self._train_discriminator(
|
||||
parameters, snapshots
|
||||
@@ -257,8 +257,8 @@ class GAROM(MultiSolverInterface):
|
||||
condition_loss = {}
|
||||
for condition_name, points in batch:
|
||||
parameters, snapshots = (
|
||||
points["input_points"],
|
||||
points["output_points"],
|
||||
points["input"],
|
||||
points["target"],
|
||||
)
|
||||
snapshots_gen = self.generator(parameters)
|
||||
condition_loss[condition_name] = self._loss(
|
||||
@@ -272,8 +272,8 @@ class GAROM(MultiSolverInterface):
|
||||
condition_loss = {}
|
||||
for condition_name, points in batch:
|
||||
parameters, snapshots = (
|
||||
points["input_points"],
|
||||
points["output_points"],
|
||||
points["input"],
|
||||
points["target"],
|
||||
)
|
||||
snapshots_gen = self.generator(parameters)
|
||||
condition_loss[condition_name] = self._loss(
|
||||
|
||||
Reference in New Issue
Block a user