Improve conditions and refactor dataset classes (#475)

* Reimplement conditions

* Refactor datasets and implement LabelBatch

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-07 11:24:09 +01:00
committed by Nicola Demo
parent bdad144461
commit a0cbf1c44a
40 changed files with 943 additions and 550 deletions

View File

@@ -5,7 +5,7 @@ from torch.nn.modules.loss import _Loss
from .solver import SingleSolverInterface
from ..utils import check_consistency
from ..loss.loss_interface import LossInterface
from ..condition import InputOutputPointsCondition
from ..condition import InputTargetCondition
class SupervisedSolver(SingleSolverInterface):
@@ -37,7 +37,7 @@ class SupervisedSolver(SingleSolverInterface):
multiple (discretised) input functions.
"""
accepted_conditions_types = InputOutputPointsCondition
accepted_conditions_types = InputTargetCondition
def __init__(
self,
@@ -95,8 +95,8 @@ class SupervisedSolver(SingleSolverInterface):
condition_loss = {}
for condition_name, points in batch:
input_pts, output_pts = (
points["input_points"],
points["output_points"],
points["input"],
points["target"],
)
condition_loss[condition_name] = self.loss_data(
input_pts=input_pts, output_pts=output_pts