supervised working
This commit is contained in:
@@ -82,6 +82,7 @@ class SupervisedSolver(SolverInterface):
|
||||
|
||||
# check consistency
|
||||
check_consistency(loss, (LossInterface, _Loss), subclass=False)
|
||||
self.loss = loss
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass implementation for the solver.
|
||||
@@ -90,7 +91,16 @@ class SupervisedSolver(SolverInterface):
|
||||
:return: Solver solution.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self._pina_model(x)
|
||||
|
||||
output = self._pina_model[0](x)
|
||||
|
||||
output.labels = {
|
||||
1: {
|
||||
"name": "output",
|
||||
"dof": self.problem.output_variables
|
||||
}
|
||||
}
|
||||
return output
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""Optimizer configuration for the solver.
|
||||
@@ -98,9 +108,12 @@ class SupervisedSolver(SolverInterface):
|
||||
:return: The optimizers and the schedulers
|
||||
:rtype: tuple(list, list)
|
||||
"""
|
||||
self._pina_optimizer.hook(self._pina_model.parameters())
|
||||
self._pina_scheduler.hook(self._pina_optimizer)
|
||||
return self._pina_optimizer, self._pina_scheduler
|
||||
self._pina_optimizer[0].hook(self._pina_model[0].parameters())
|
||||
self._pina_scheduler[0].hook(self._pina_optimizer[0])
|
||||
return (
|
||||
[self._pina_optimizer[0].optimizer_instance],
|
||||
[self._pina_scheduler[0].scheduler_instance]
|
||||
)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
"""Solver training step.
|
||||
@@ -113,14 +126,16 @@ class SupervisedSolver(SolverInterface):
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
|
||||
condition_idx = batch["condition"]
|
||||
condition_idx = batch.condition
|
||||
|
||||
for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
|
||||
|
||||
condition_name = self._dataloader.condition_names[condition_id]
|
||||
condition = self.problem.conditions[condition_name]
|
||||
pts = batch["pts"]
|
||||
out = batch["output"]
|
||||
pts = batch.input
|
||||
out = batch.output
|
||||
print(out)
|
||||
print(pts)
|
||||
|
||||
if condition_name not in self.problem.conditions:
|
||||
raise RuntimeError("Something wrong happened.")
|
||||
@@ -134,9 +149,11 @@ class SupervisedSolver(SolverInterface):
|
||||
output_pts = out[condition_idx == condition_id]
|
||||
input_pts = pts[condition_idx == condition_id]
|
||||
|
||||
input_pts.labels = pts.labels
|
||||
output_pts.labels = out.labels
|
||||
|
||||
loss = (
|
||||
self.loss_data(input_pts=input_pts, output_pts=output_pts)
|
||||
* condition.data_weight
|
||||
)
|
||||
loss = loss.as_subclass(torch.Tensor)
|
||||
|
||||
@@ -155,6 +172,10 @@ class SupervisedSolver(SolverInterface):
|
||||
:return: The residual loss averaged on the input coordinates
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
print(input_pts)
|
||||
print(output_pts)
|
||||
print(self.loss)
|
||||
print(self.forward(input_pts))
|
||||
return self.loss(self.forward(input_pts), output_pts)
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user