Fix SupervisedSolver GPU bug and implement GraphSolver (#346)

* Fix some bugs
* Solve bug with GPU and model_summary parameters in SupervisedSolver class
* Implement GraphSolver class
* Fix Tutorial 5
This commit is contained in:
FilippoOlivo
2024-09-21 18:55:57 +02:00
committed by Nicola Demo
parent 30f865d912
commit 2be57944ba
10 changed files with 334 additions and 164 deletions

34
pina/solvers/graph.py Normal file
View File

@@ -0,0 +1,34 @@
from .supervised import SupervisedSolver
from ..graph import Graph
class GraphSupervisedSolver(SupervisedSolver):
def __init__(
self,
problem,
model,
nodes_coordinates,
nodes_data,
loss=None,
optimizer=None,
scheduler=None):
super().__init__(problem, model, loss, optimizer, scheduler)
if isinstance(nodes_coordinates, str):
self._nodes_coordinates = [nodes_coordinates]
else:
self._nodes_coordinates = nodes_coordinates
if isinstance(nodes_data, str):
self._nodes_data = nodes_data
else:
self._nodes_data = nodes_data
def forward(self, input):
input_coords = input.extract(self._nodes_coordinates)
input_data = input.extract(self._nodes_data)
if not isinstance(input, Graph):
input = Graph.build('radius', nodes_coordinates=input_coords, nodes_data=input_data, radius=0.2)
g = self.model(input.data, edge_index=input.data.edge_index)
g.labels = {1: {'name': 'output', 'dof': ['u']}}
return g