supervised working

This commit is contained in:
Nicola Demo
2024-08-08 16:19:52 +02:00
parent 5245a0b68c
commit 9d9c2aa23e
61 changed files with 375 additions and 262 deletions

View File

@@ -205,6 +205,8 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
# put everything in a list if only one input
if not isinstance(model, list):
model = [model]
if not isinstance(scheduler, list):
scheduler = [scheduler]
if not isinstance(optimizer, list):
optimizer = [optimizer]

View File

@@ -82,6 +82,7 @@ class SupervisedSolver(SolverInterface):
# check consistency
check_consistency(loss, (LossInterface, _Loss), subclass=False)
self.loss = loss
def forward(self, x):
"""Forward pass implementation for the solver.
@@ -90,7 +91,16 @@ class SupervisedSolver(SolverInterface):
:return: Solver solution.
:rtype: torch.Tensor
"""
return self._pina_model(x)
output = self._pina_model[0](x)
output.labels = {
1: {
"name": "output",
"dof": self.problem.output_variables
}
}
return output
def configure_optimizers(self):
"""Optimizer configuration for the solver.
@@ -98,9 +108,12 @@ class SupervisedSolver(SolverInterface):
:return: The optimizers and the schedulers
:rtype: tuple(list, list)
"""
self._pina_optimizer.hook(self._pina_model.parameters())
self._pina_scheduler.hook(self._pina_optimizer)
return self._pina_optimizer, self._pina_scheduler
self._pina_optimizer[0].hook(self._pina_model[0].parameters())
self._pina_scheduler[0].hook(self._pina_optimizer[0])
return (
[self._pina_optimizer[0].optimizer_instance],
[self._pina_scheduler[0].scheduler_instance]
)
def training_step(self, batch, batch_idx):
"""Solver training step.
@@ -113,14 +126,16 @@ class SupervisedSolver(SolverInterface):
:rtype: LabelTensor
"""
condition_idx = batch["condition"]
condition_idx = batch.condition
for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
condition_name = self._dataloader.condition_names[condition_id]
condition = self.problem.conditions[condition_name]
pts = batch["pts"]
out = batch["output"]
pts = batch.input
out = batch.output
print(out)
print(pts)
if condition_name not in self.problem.conditions:
raise RuntimeError("Something wrong happened.")
@@ -134,9 +149,11 @@ class SupervisedSolver(SolverInterface):
output_pts = out[condition_idx == condition_id]
input_pts = pts[condition_idx == condition_id]
input_pts.labels = pts.labels
output_pts.labels = out.labels
loss = (
self.loss_data(input_pts=input_pts, output_pts=output_pts)
* condition.data_weight
)
loss = loss.as_subclass(torch.Tensor)
@@ -155,6 +172,10 @@ class SupervisedSolver(SolverInterface):
:return: The residual loss averaged on the input coordinates
:rtype: torch.Tensor
"""
print(input_pts)
print(output_pts)
print(self.loss)
print(self.forward(input_pts))
return self.loss(self.forward(input_pts), output_pts)
@property