Fix SupervisedSolver GPU bug and implement GraphSolver (#346)

* Fix some bugs
* Solve bug with GPU and model_summary parameters in SupervisedSolver class
* Implement GraphSolver class
* Fix Tutorial 5
This commit is contained in:
FilippoOlivo
2024-09-21 18:55:57 +02:00
committed by Nicola Demo
parent 30f865d912
commit 2be57944ba
10 changed files with 334 additions and 164 deletions

View File

@@ -4,6 +4,7 @@ from .sample_dataset import SamplePointDataset
from .data_dataset import DataPointDataset
from .pina_batch import Batch
class SamplePointLoader:
"""
This class is used to create a dataloader to use during the training.
@@ -95,7 +96,7 @@ class SamplePointLoader:
self.batch_output_pts = torch.tensor_split(
dataset.output_pts, batch_num
)
print(input_labels)
#print(input_labels)
for i in range(len(self.batch_input_pts)):
self.batch_input_pts[i].labels = input_labels
self.batch_output_pts[i].labels = output_labels
@@ -161,7 +162,6 @@ class SamplePointLoader:
self.batch_input_pts,
self.batch_output_pts,
self.batch_data_conditions)
print(batch.input.labels)
self.batches.append(batch)

View File

@@ -425,7 +425,7 @@ class LabelTensor(torch.Tensor):
raise NotImplementedError
labels = [tensor.labels for tensor in tensors]
print(labels)
def requires_grad_(self, mode=True):
lt = super().requires_grad_(mode)
@@ -436,7 +436,6 @@ class LabelTensor(torch.Tensor):
def dtype(self):
return super().dtype
def to(self, *args, **kwargs):
"""
Performs Tensor dtype and/or device conversion. For more details, see
@@ -447,7 +446,6 @@ class LabelTensor(torch.Tensor):
new.data = tmp.data
return new
def clone(self, *args, **kwargs):
"""
Clone the LabelTensor. For more details, see

View File

@@ -269,4 +269,7 @@ class FNO(KernelNeuralOperator):
:return: The output tensor obtained from FNO.
:rtype: torch.Tensor
"""
if isinstance(x, LabelTensor):
x = x.as_subclass(torch.Tensor)
return super().forward(x)

View File

@@ -17,3 +17,4 @@ from .pinns import *
from .supervised import SupervisedSolver
from .rom import ReducedOrderModelSolver
from .garom import GAROM
from .graph import GraphSupervisedSolver

34
pina/solvers/graph.py Normal file
View File

@@ -0,0 +1,34 @@
from .supervised import SupervisedSolver
from ..graph import Graph
class GraphSupervisedSolver(SupervisedSolver):
def __init__(
self,
problem,
model,
nodes_coordinates,
nodes_data,
loss=None,
optimizer=None,
scheduler=None):
super().__init__(problem, model, loss, optimizer, scheduler)
if isinstance(nodes_coordinates, str):
self._nodes_coordinates = [nodes_coordinates]
else:
self._nodes_coordinates = nodes_coordinates
if isinstance(nodes_data, str):
self._nodes_data = nodes_data
else:
self._nodes_data = nodes_data
def forward(self, input):
input_coords = input.extract(self._nodes_coordinates)
input_data = input.extract(self._nodes_data)
if not isinstance(input, Graph):
input = Graph.build('radius', nodes_coordinates=input_coords, nodes_data=input_data, radius=0.2)
g = self.model(input.data, edge_index=input.data.edge_index)
g.labels = {1: {'name': 'output', 'dof': ['u']}}
return g

View File

@@ -82,7 +82,10 @@ class SupervisedSolver(SolverInterface):
# check consistency
check_consistency(loss, (LossInterface, _Loss), subclass=False)
self.loss = loss
self._loss = loss
self._model = self._pina_model[0]
self._optimizer = self._pina_optimizer[0]
self._scheduler = self._pina_scheduler[0]
def forward(self, x):
"""Forward pass implementation for the solver.
@@ -92,7 +95,7 @@ class SupervisedSolver(SolverInterface):
:rtype: torch.Tensor
"""
output = self._pina_model[0](x)
output = self._model(x)
output.labels = {
1: {
@@ -108,11 +111,11 @@ class SupervisedSolver(SolverInterface):
:return: The optimizers and the schedulers
:rtype: tuple(list, list)
"""
self._pina_optimizer[0].hook(self._pina_model[0].parameters())
self._pina_scheduler[0].hook(self._pina_optimizer[0])
self._optimizer.hook(self._model.parameters())
self._scheduler.hook(self._optimizer)
return (
[self._pina_optimizer[0].optimizer_instance],
[self._pina_scheduler[0].scheduler_instance]
[self._optimizer.optimizer_instance],
[self._scheduler.scheduler_instance]
)
def training_step(self, batch, batch_idx):
@@ -170,28 +173,28 @@ class SupervisedSolver(SolverInterface):
:return: The residual loss averaged on the input coordinates
:rtype: torch.Tensor
"""
return self.loss(self.forward(input_pts), output_pts)
return self._loss(self.forward(input_pts), output_pts)
@property
def scheduler(self):
"""
Scheduler for training.
"""
return self._pina_scheduler
return self._scheduler
@property
def optimizer(self):
"""
Optimizer for training.
"""
return self._pina_optimizer
return self._optimizer
@property
def model(self):
"""
Neural network for training.
"""
return self._pina_model
return self._model
@property
def loss(self):