fix Supervised/PINN solvers forward + fix tut5

This commit is contained in:
Dario Coscia
2023-11-09 11:24:00 +01:00
committed by Nicola Demo
parent 4977d55507
commit 4844640727
6 changed files with 123 additions and 79 deletions

View File

@@ -83,11 +83,11 @@ class PINN(SolverInterface):
:return: PINN solution.
:rtype: torch.Tensor
"""
# extract labels
x = x.extract(self.problem.input_variables)
# perform forward pass
# extract torch.Tensor from corresponding label
x = x.extract(self.problem.input_variables).as_subclass(torch.Tensor)
# perform forward pass (using torch.Tensor) + converting to LabelTensor
output = self.neural_net(x).as_subclass(LabelTensor)
# set the labels
# set the labels for LabelTensor
output.labels = self.problem.output_variables
return output

View File

@@ -72,11 +72,11 @@ class SupervisedSolver(SolverInterface):
:return: Solver solution.
:rtype: torch.Tensor
"""
# extract labels
x = x.extract(self.problem.input_variables)
# perform forward pass
# extract torch.Tensor from corresponding label
x = x.extract(self.problem.input_variables).as_subclass(torch.Tensor)
# perform forward pass (using torch.Tensor) + converting to LabelTensor
output = self.neural_net(x).as_subclass(LabelTensor)
# set the labels
# set the labels for LabelTensor
output.labels = self.problem.output_variables
return output
@@ -99,6 +99,44 @@ class SupervisedSolver(SolverInterface):
:rtype: LabelTensor
"""
dataloader = self.trainer.train_dataloader
condition_idx = batch['condition']
for condition_id in range(condition_idx.min(), condition_idx.max()+1):
condition_name = dataloader.condition_names[condition_id]
condition = self.problem.conditions[condition_name]
pts = batch['pts']
out = batch['output']
if condition_name not in self.problem.conditions:
raise RuntimeError('Something wrong happened.')
# for data driven mode
if not hasattr(condition, 'output_points'):
raise NotImplementedError('Supervised solver works only in data-driven mode.')
output_pts = out[condition_idx == condition_id]
input_pts = pts[condition_idx == condition_id]
loss = self.loss(self.forward(input_pts), output_pts) * condition.data_weight
loss = loss.as_subclass(torch.Tensor)
self.log('mean_loss', float(loss), prog_bar=True, logger=True)
return loss
def training_step_(self, batch, batch_idx):
"""Solver training step.
:param batch: The batch element in the dataloader.
:type batch: tuple
:param batch_idx: The batch index.
:type batch_idx: int
:return: The sum of the loss functions.
:rtype: LabelTensor
"""
for condition_name, samples in batch.items():
if condition_name not in self.problem.conditions: