Fix SupervisedSolver GPU bug and implement GraphSolver (#346)
* Fix some bugs * Solve bug with GPU and model_summary parameters in SupervisedSolver class * Implement GraphSolver class * Fix Tutorial 5
This commit is contained in:
committed by
Nicola Demo
parent
30f865d912
commit
2be57944ba
@@ -17,3 +17,4 @@ from .pinns import *
|
||||
from .supervised import SupervisedSolver
|
||||
from .rom import ReducedOrderModelSolver
|
||||
from .garom import GAROM
|
||||
from .graph import GraphSupervisedSolver
|
||||
|
||||
34
pina/solvers/graph.py
Normal file
34
pina/solvers/graph.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from .supervised import SupervisedSolver
|
||||
from ..graph import Graph
|
||||
|
||||
|
||||
class GraphSupervisedSolver(SupervisedSolver):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
nodes_coordinates,
|
||||
nodes_data,
|
||||
loss=None,
|
||||
optimizer=None,
|
||||
scheduler=None):
|
||||
super().__init__(problem, model, loss, optimizer, scheduler)
|
||||
if isinstance(nodes_coordinates, str):
|
||||
self._nodes_coordinates = [nodes_coordinates]
|
||||
else:
|
||||
self._nodes_coordinates = nodes_coordinates
|
||||
if isinstance(nodes_data, str):
|
||||
self._nodes_data = nodes_data
|
||||
else:
|
||||
self._nodes_data = nodes_data
|
||||
|
||||
def forward(self, input):
|
||||
input_coords = input.extract(self._nodes_coordinates)
|
||||
input_data = input.extract(self._nodes_data)
|
||||
|
||||
if not isinstance(input, Graph):
|
||||
input = Graph.build('radius', nodes_coordinates=input_coords, nodes_data=input_data, radius=0.2)
|
||||
g = self.model(input.data, edge_index=input.data.edge_index)
|
||||
g.labels = {1: {'name': 'output', 'dof': ['u']}}
|
||||
return g
|
||||
@@ -82,7 +82,10 @@ class SupervisedSolver(SolverInterface):
|
||||
|
||||
# check consistency
|
||||
check_consistency(loss, (LossInterface, _Loss), subclass=False)
|
||||
self.loss = loss
|
||||
self._loss = loss
|
||||
self._model = self._pina_model[0]
|
||||
self._optimizer = self._pina_optimizer[0]
|
||||
self._scheduler = self._pina_scheduler[0]
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass implementation for the solver.
|
||||
@@ -92,7 +95,7 @@ class SupervisedSolver(SolverInterface):
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
|
||||
output = self._pina_model[0](x)
|
||||
output = self._model(x)
|
||||
|
||||
output.labels = {
|
||||
1: {
|
||||
@@ -108,11 +111,11 @@ class SupervisedSolver(SolverInterface):
|
||||
:return: The optimizers and the schedulers
|
||||
:rtype: tuple(list, list)
|
||||
"""
|
||||
self._pina_optimizer[0].hook(self._pina_model[0].parameters())
|
||||
self._pina_scheduler[0].hook(self._pina_optimizer[0])
|
||||
self._optimizer.hook(self._model.parameters())
|
||||
self._scheduler.hook(self._optimizer)
|
||||
return (
|
||||
[self._pina_optimizer[0].optimizer_instance],
|
||||
[self._pina_scheduler[0].scheduler_instance]
|
||||
[self._optimizer.optimizer_instance],
|
||||
[self._scheduler.scheduler_instance]
|
||||
)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
@@ -170,28 +173,28 @@ class SupervisedSolver(SolverInterface):
|
||||
:return: The residual loss averaged on the input coordinates
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self.loss(self.forward(input_pts), output_pts)
|
||||
return self._loss(self.forward(input_pts), output_pts)
|
||||
|
||||
@property
|
||||
def scheduler(self):
|
||||
"""
|
||||
Scheduler for training.
|
||||
"""
|
||||
return self._pina_scheduler
|
||||
return self._scheduler
|
||||
|
||||
@property
|
||||
def optimizer(self):
|
||||
"""
|
||||
Optimizer for training.
|
||||
"""
|
||||
return self._pina_optimizer
|
||||
return self._optimizer
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
"""
|
||||
Neural network for training.
|
||||
"""
|
||||
return self._pina_model
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def loss(self):
|
||||
|
||||
Reference in New Issue
Block a user