* Fix some bugs * Solve bug with GPU and model_summary parameters in SupervisedSolver class * Implement GraphSolver class * Fix Tutorial 5
35 lines
1.1 KiB
Python
35 lines
1.1 KiB
Python
from .supervised import SupervisedSolver
|
|
from ..graph import Graph
|
|
|
|
|
|
class GraphSupervisedSolver(SupervisedSolver):
|
|
|
|
def __init__(
|
|
self,
|
|
problem,
|
|
model,
|
|
nodes_coordinates,
|
|
nodes_data,
|
|
loss=None,
|
|
optimizer=None,
|
|
scheduler=None):
|
|
super().__init__(problem, model, loss, optimizer, scheduler)
|
|
if isinstance(nodes_coordinates, str):
|
|
self._nodes_coordinates = [nodes_coordinates]
|
|
else:
|
|
self._nodes_coordinates = nodes_coordinates
|
|
if isinstance(nodes_data, str):
|
|
self._nodes_data = nodes_data
|
|
else:
|
|
self._nodes_data = nodes_data
|
|
|
|
def forward(self, input):
|
|
input_coords = input.extract(self._nodes_coordinates)
|
|
input_data = input.extract(self._nodes_data)
|
|
|
|
if not isinstance(input, Graph):
|
|
input = Graph.build('radius', nodes_coordinates=input_coords, nodes_data=input_data, radius=0.2)
|
|
g = self.model(input.data, edge_index=input.data.edge_index)
|
|
g.labels = {1: {'name': 'output', 'dof': ['u']}}
|
|
return g
|