Fix SupervisedSolver GPU bug and implement GraphSolver (#346)
* Fix some bugs * Solve bug with GPU and model_summary parameters in SupervisedSolver class * Implement GraphSolver class * Fix Tutorial 5
This commit is contained in:
committed by
Nicola Demo
parent
30f865d912
commit
2be57944ba
@@ -82,7 +82,10 @@ class SupervisedSolver(SolverInterface):
|
||||
|
||||
# check consistency
|
||||
check_consistency(loss, (LossInterface, _Loss), subclass=False)
|
||||
self.loss = loss
|
||||
self._loss = loss
|
||||
self._model = self._pina_model[0]
|
||||
self._optimizer = self._pina_optimizer[0]
|
||||
self._scheduler = self._pina_scheduler[0]
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass implementation for the solver.
|
||||
@@ -92,7 +95,7 @@ class SupervisedSolver(SolverInterface):
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
|
||||
output = self._pina_model[0](x)
|
||||
output = self._model(x)
|
||||
|
||||
output.labels = {
|
||||
1: {
|
||||
@@ -108,11 +111,11 @@ class SupervisedSolver(SolverInterface):
|
||||
:return: The optimizers and the schedulers
|
||||
:rtype: tuple(list, list)
|
||||
"""
|
||||
self._pina_optimizer[0].hook(self._pina_model[0].parameters())
|
||||
self._pina_scheduler[0].hook(self._pina_optimizer[0])
|
||||
self._optimizer.hook(self._model.parameters())
|
||||
self._scheduler.hook(self._optimizer)
|
||||
return (
|
||||
[self._pina_optimizer[0].optimizer_instance],
|
||||
[self._pina_scheduler[0].scheduler_instance]
|
||||
[self._optimizer.optimizer_instance],
|
||||
[self._scheduler.scheduler_instance]
|
||||
)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
@@ -170,28 +173,28 @@ class SupervisedSolver(SolverInterface):
|
||||
:return: The residual loss averaged on the input coordinates
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self.loss(self.forward(input_pts), output_pts)
|
||||
return self._loss(self.forward(input_pts), output_pts)
|
||||
|
||||
@property
|
||||
def scheduler(self):
|
||||
"""
|
||||
Scheduler for training.
|
||||
"""
|
||||
return self._pina_scheduler
|
||||
return self._scheduler
|
||||
|
||||
@property
|
||||
def optimizer(self):
|
||||
"""
|
||||
Optimizer for training.
|
||||
"""
|
||||
return self._pina_optimizer
|
||||
return self._optimizer
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
"""
|
||||
Neural network for training.
|
||||
"""
|
||||
return self._pina_model
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def loss(self):
|
||||
|
||||
Reference in New Issue
Block a user