Network handles forward for all solvers
This commit is contained in:
committed by
Nicola Demo
parent
4844640727
commit
c90301c204
@@ -72,13 +72,7 @@ class SupervisedSolver(SolverInterface):
|
||||
:return: Solver solution.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
# extract torch.Tensor from corresponding label
|
||||
x = x.extract(self.problem.input_variables).as_subclass(torch.Tensor)
|
||||
# perform forward pass (using torch.Tensor) + converting to LabelTensor
|
||||
output = self.neural_net(x).as_subclass(LabelTensor)
|
||||
# set the labels for LabelTensor
|
||||
output.labels = self.problem.output_variables
|
||||
return output
|
||||
return self.neural_net(x)
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""Optimizer configuration for the solver.
|
||||
@@ -125,37 +119,6 @@ class SupervisedSolver(SolverInterface):
|
||||
self.log('mean_loss', float(loss), prog_bar=True, logger=True)
|
||||
return loss
|
||||
|
||||
|
||||
def training_step_(self, batch, batch_idx):
|
||||
"""Solver training step.
|
||||
|
||||
:param batch: The batch element in the dataloader.
|
||||
:type batch: tuple
|
||||
:param batch_idx: The batch index.
|
||||
:type batch_idx: int
|
||||
:return: The sum of the loss functions.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
|
||||
for condition_name, samples in batch.items():
|
||||
|
||||
if condition_name not in self.problem.conditions:
|
||||
raise RuntimeError('Something wrong happened.')
|
||||
|
||||
condition = self.problem.conditions[condition_name]
|
||||
|
||||
# data loss
|
||||
if hasattr(condition, 'output_points'):
|
||||
input_pts, output_pts = samples
|
||||
loss = self.loss(self.forward(input_pts),
|
||||
output_pts) * condition.data_weight
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'Supervised solver works only in data-driven mode.')
|
||||
|
||||
self.log('mean_loss', float(loss), prog_bar=True, logger=True)
|
||||
return loss
|
||||
|
||||
@property
|
||||
def scheduler(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user