Add check conditions-solver consistency

This commit is contained in:
FilippoOlivo
2025-01-16 19:42:46 +01:00
committed by Nicola Demo
parent f2340cd4ee
commit a6f0336d06
3 changed files with 20 additions and 22 deletions

View File

@@ -7,6 +7,7 @@ from .solver import SolverInterface
from ..label_tensor import LabelTensor
from ..utils import check_consistency
from ..loss.loss_interface import LossInterface
from ..condition import InputOutputPointsCondition
class SupervisedSolver(SolverInterface):
@@ -37,7 +38,8 @@ class SupervisedSolver(SolverInterface):
we are seeking to approximate multiple (discretised) functions given
multiple (discretised) input functions.
"""
__name__ = 'SupervisedSolver'
accepted_conditions_types = InputOutputPointsCondition
def __init__(self,
problem,
@@ -143,7 +145,6 @@ class SupervisedSolver(SolverInterface):
self.log('val_loss', loss, prog_bar=True, logger=True,
batch_size=self.get_batch_size(batch), sync_dist=True)
def test_step(self, batch, batch_idx) -> STEP_OUTPUT:
"""
Solver test step.