Fix SupervisedSolver GPU bug and implement GraphSolver (#346)
* Fix some bugs * Solve bug with GPU and model_summary parameters in SupervisedSolver class * Implement GraphSolver class * Fix Tutorial 5
This commit is contained in:
committed by
Nicola Demo
parent
30f865d912
commit
2be57944ba
34
pina/solvers/graph.py
Normal file
34
pina/solvers/graph.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from .supervised import SupervisedSolver
|
||||
from ..graph import Graph
|
||||
|
||||
|
||||
class GraphSupervisedSolver(SupervisedSolver):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
nodes_coordinates,
|
||||
nodes_data,
|
||||
loss=None,
|
||||
optimizer=None,
|
||||
scheduler=None):
|
||||
super().__init__(problem, model, loss, optimizer, scheduler)
|
||||
if isinstance(nodes_coordinates, str):
|
||||
self._nodes_coordinates = [nodes_coordinates]
|
||||
else:
|
||||
self._nodes_coordinates = nodes_coordinates
|
||||
if isinstance(nodes_data, str):
|
||||
self._nodes_data = nodes_data
|
||||
else:
|
||||
self._nodes_data = nodes_data
|
||||
|
||||
def forward(self, input):
|
||||
input_coords = input.extract(self._nodes_coordinates)
|
||||
input_data = input.extract(self._nodes_data)
|
||||
|
||||
if not isinstance(input, Graph):
|
||||
input = Graph.build('radius', nodes_coordinates=input_coords, nodes_data=input_data, radius=0.2)
|
||||
g = self.model(input.data, edge_index=input.data.edge_index)
|
||||
g.labels = {1: {'name': 'output', 'dof': ['u']}}
|
||||
return g
|
||||
Reference in New Issue
Block a user