Add check conditions-solver consistency
This commit is contained in:
committed by
Nicola Demo
parent
f2340cd4ee
commit
a6f0336d06
@@ -7,6 +7,7 @@ from .solver import SolverInterface
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..utils import check_consistency
|
||||
from ..loss.loss_interface import LossInterface
|
||||
from ..condition import InputOutputPointsCondition
|
||||
|
||||
|
||||
class SupervisedSolver(SolverInterface):
|
||||
@@ -37,7 +38,8 @@ class SupervisedSolver(SolverInterface):
|
||||
we are seeking to approximate multiple (discretised) functions given
|
||||
multiple (discretised) input functions.
|
||||
"""
|
||||
__name__ = 'SupervisedSolver'
|
||||
|
||||
accepted_conditions_types = InputOutputPointsCondition
|
||||
|
||||
def __init__(self,
|
||||
problem,
|
||||
@@ -143,7 +145,6 @@ class SupervisedSolver(SolverInterface):
|
||||
self.log('val_loss', loss, prog_bar=True, logger=True,
|
||||
batch_size=self.get_batch_size(batch), sync_dist=True)
|
||||
|
||||
|
||||
def test_step(self, batch, batch_idx) -> STEP_OUTPUT:
|
||||
"""
|
||||
Solver test step.
|
||||
|
||||
Reference in New Issue
Block a user