Fix SupervisedSolver GPU bug and implement GraphSolver (#346)

* Fix some bugs
* Solve bug with GPU and model_summary parameters in SupervisedSolver class
* Implement GraphSolver class
* Fix Tutorial 5
This commit is contained in:
FilippoOlivo
2024-09-21 18:55:57 +02:00
committed by Nicola Demo
parent 30f865d912
commit 2be57944ba
10 changed files with 334 additions and 164 deletions

View File

@@ -82,7 +82,10 @@ class SupervisedSolver(SolverInterface):
# check consistency
check_consistency(loss, (LossInterface, _Loss), subclass=False)
self.loss = loss
self._loss = loss
self._model = self._pina_model[0]
self._optimizer = self._pina_optimizer[0]
self._scheduler = self._pina_scheduler[0]
def forward(self, x):
"""Forward pass implementation for the solver.
@@ -92,7 +95,7 @@ class SupervisedSolver(SolverInterface):
:rtype: torch.Tensor
"""
output = self._pina_model[0](x)
output = self._model(x)
output.labels = {
1: {
@@ -108,11 +111,11 @@ class SupervisedSolver(SolverInterface):
:return: The optimizers and the schedulers
:rtype: tuple(list, list)
"""
self._pina_optimizer[0].hook(self._pina_model[0].parameters())
self._pina_scheduler[0].hook(self._pina_optimizer[0])
self._optimizer.hook(self._model.parameters())
self._scheduler.hook(self._optimizer)
return (
[self._pina_optimizer[0].optimizer_instance],
[self._pina_scheduler[0].scheduler_instance]
[self._optimizer.optimizer_instance],
[self._scheduler.scheduler_instance]
)
def training_step(self, batch, batch_idx):
@@ -170,28 +173,28 @@ class SupervisedSolver(SolverInterface):
:return: The residual loss averaged on the input coordinates
:rtype: torch.Tensor
"""
return self.loss(self.forward(input_pts), output_pts)
return self._loss(self.forward(input_pts), output_pts)
@property
def scheduler(self):
"""
Scheduler for training.
"""
return self._pina_scheduler
return self._scheduler
@property
def optimizer(self):
"""
Optimizer for training.
"""
return self._pina_optimizer
return self._optimizer
@property
def model(self):
"""
Neural network for training.
"""
return self._pina_model
return self._model
@property
def loss(self):