Improve conditions and refactor dataset classes (#475)
* Reimplement conditions * Refactor datasets and implement LabelBatch --------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
bdad144461
commit
a0cbf1c44a
@@ -5,7 +5,7 @@ from torch.nn.modules.loss import _Loss
|
||||
from .solver import SingleSolverInterface
|
||||
from ..utils import check_consistency
|
||||
from ..loss.loss_interface import LossInterface
|
||||
from ..condition import InputOutputPointsCondition
|
||||
from ..condition import InputTargetCondition
|
||||
|
||||
|
||||
class SupervisedSolver(SingleSolverInterface):
|
||||
@@ -37,7 +37,7 @@ class SupervisedSolver(SingleSolverInterface):
|
||||
multiple (discretised) input functions.
|
||||
"""
|
||||
|
||||
accepted_conditions_types = InputOutputPointsCondition
|
||||
accepted_conditions_types = InputTargetCondition
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -95,8 +95,8 @@ class SupervisedSolver(SingleSolverInterface):
|
||||
condition_loss = {}
|
||||
for condition_name, points in batch:
|
||||
input_pts, output_pts = (
|
||||
points["input_points"],
|
||||
points["output_points"],
|
||||
points["input"],
|
||||
points["target"],
|
||||
)
|
||||
condition_loss[condition_name] = self.loss_data(
|
||||
input_pts=input_pts, output_pts=output_pts
|
||||
|
||||
Reference in New Issue
Block a user