Network handles forward for all solvers

This commit is contained in:
Dario Coscia
2023-11-09 15:16:57 +01:00
committed by Nicola Demo
parent 4844640727
commit c90301c204
5 changed files with 63 additions and 67 deletions

View File

@@ -72,13 +72,7 @@ class SupervisedSolver(SolverInterface):
:return: Solver solution.
:rtype: torch.Tensor
"""
# extract torch.Tensor from corresponding label
x = x.extract(self.problem.input_variables).as_subclass(torch.Tensor)
# perform forward pass (using torch.Tensor) + converting to LabelTensor
output = self.neural_net(x).as_subclass(LabelTensor)
# set the labels for LabelTensor
output.labels = self.problem.output_variables
return output
return self.neural_net(x)
def configure_optimizers(self):
"""Optimizer configuration for the solver.
@@ -125,37 +119,6 @@ class SupervisedSolver(SolverInterface):
self.log('mean_loss', float(loss), prog_bar=True, logger=True)
return loss
def training_step_(self, batch, batch_idx):
"""Solver training step.
:param batch: The batch element in the dataloader.
:type batch: tuple
:param batch_idx: The batch index.
:type batch_idx: int
:return: The sum of the loss functions.
:rtype: LabelTensor
"""
for condition_name, samples in batch.items():
if condition_name not in self.problem.conditions:
raise RuntimeError('Something wrong happened.')
condition = self.problem.conditions[condition_name]
# data loss
if hasattr(condition, 'output_points'):
input_pts, output_pts = samples
loss = self.loss(self.forward(input_pts),
output_pts) * condition.data_weight
else:
raise RuntimeError(
'Supervised solver works only in data-driven mode.')
self.log('mean_loss', float(loss), prog_bar=True, logger=True)
return loss
@property
def scheduler(self):
"""