Fix SupervisedSolver GPU bug and implement GraphSolver (#346)
* Fix some bugs
* Solve bug with GPU and model_summary parameters in SupervisedSolver class
* Implement GraphSolver class
* Fix Tutorial 5
committed by Nicola Demo
parent 30f865d912
commit 2be57944ba
@@ -4,6 +4,7 @@ from .sample_dataset import SamplePointDataset
from .data_dataset import DataPointDataset
from .pina_batch import Batch


class SamplePointLoader:
    """
    This class is used to create a dataloader to use during the training.
@@ -95,7 +96,7 @@ class SamplePointLoader:
         self.batch_output_pts = torch.tensor_split(
             dataset.output_pts, batch_num
         )
-        print(input_labels)
+        #print(input_labels)
         for i in range(len(self.batch_input_pts)):
             self.batch_input_pts[i].labels = input_labels
             self.batch_output_pts[i].labels = output_labels
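A side note on the batching above: `torch.tensor_split` splits a tensor into a requested number of chunks, returned as views without copying, and the chunks do not automatically carry PINA's label metadata, which is presumably why the loader reassigns `labels` on each chunk afterwards. A minimal plain-torch sketch of the splitting behaviour, with illustrative sizes:

import torch

pts = torch.rand(10, 2)                  # 10 sample points, 2 coordinates each
chunks = torch.tensor_split(pts, 3)      # 3 chunks with sizes 4, 3 and 3

for chunk in chunks:
    # each chunk is a view on `pts`; in the loader these are LabelTensors,
    # so the label dictionary is reattached chunk by chunk after the split
    print(chunk.shape)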
@@ -161,7 +162,6 @@ class SamplePointLoader:
                 self.batch_input_pts,
                 self.batch_output_pts,
                 self.batch_data_conditions)
-            print(batch.input.labels)

             self.batches.append(batch)

@@ -425,7 +425,7 @@ class LabelTensor(torch.Tensor):

            raise NotImplementedError
        labels = [tensor.labels for tensor in tensors]
        print(labels)


    def requires_grad_(self, mode=True):
        lt = super().requires_grad_(mode)
@@ -436,7 +436,6 @@ class LabelTensor(torch.Tensor):
    def dtype(self):
        return super().dtype


    def to(self, *args, **kwargs):
        """
        Performs Tensor dtype and/or device conversion. For more details, see
@@ -447,7 +446,6 @@ class LabelTensor(torch.Tensor):
        new.data = tmp.data
        return new


    def clone(self, *args, **kwargs):
        """
        Clone the LabelTensor. For more details, see
@@ -269,4 +269,7 @@ class FNO(KernelNeuralOperator):
         :return: The output tensor obtained from FNO.
         :rtype: torch.Tensor
         """
+
+        if isinstance(x, LabelTensor):
+            x = x.as_subclass(torch.Tensor)
         return super().forward(x)
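A note on the cast introduced above: `Tensor.as_subclass` reinterprets a tensor as another tensor class without copying its storage, so the FNO forward pass presumably runs its spectral operations on a plain `torch.Tensor` rather than a LabelTensor. A small standalone sketch of the mechanism, where `Tagged` is a hypothetical stand-in for LabelTensor:

import torch

class Tagged(torch.Tensor):
    # hypothetical labelled-tensor subclass, standing in for LabelTensor
    pass

t = torch.rand(4, 2).as_subclass(Tagged)   # view the data as the subclass
plain = t.as_subclass(torch.Tensor)        # drop the subclass again
assert type(plain) is torch.Tensor
assert plain.data_ptr() == t.data_ptr()    # same storage, no copy made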
@@ -17,3 +17,4 @@ from .pinns import *
 from .supervised import SupervisedSolver
 from .rom import ReducedOrderModelSolver
 from .garom import GAROM
+from .graph import GraphSupervisedSolver
pina/solvers/graph.py (new file, 34 lines)
@@ -0,0 +1,34 @@
from .supervised import SupervisedSolver
from ..graph import Graph


class GraphSupervisedSolver(SupervisedSolver):

    def __init__(
            self,
            problem,
            model,
            nodes_coordinates,
            nodes_data,
            loss=None,
            optimizer=None,
            scheduler=None):
        super().__init__(problem, model, loss, optimizer, scheduler)
        if isinstance(nodes_coordinates, str):
            self._nodes_coordinates = [nodes_coordinates]
        else:
            self._nodes_coordinates = nodes_coordinates
        if isinstance(nodes_data, str):
            # wrap a single label in a list, mirroring nodes_coordinates above
            self._nodes_data = [nodes_data]
        else:
            self._nodes_data = nodes_data

    def forward(self, input):
        input_coords = input.extract(self._nodes_coordinates)
        input_data = input.extract(self._nodes_data)

        if not isinstance(input, Graph):
            input = Graph.build(
                'radius',
                nodes_coordinates=input_coords,
                nodes_data=input_data,
                radius=0.2)
        g = self.model(input.data, edge_index=input.data.edge_index)
        g.labels = {1: {'name': 'output', 'dof': ['u']}}
        return g
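For intuition on the `Graph.build('radius', ...)` call in `forward`, a plain-torch sketch of what a radius graph is (an illustration of the concept, not PINA's implementation): nodes are connected whenever their coordinates lie within the given radius, 0.2 in the solver above.

import torch

coords = torch.rand(50, 2)                                   # 50 nodes in 2D
dist = torch.cdist(coords, coords)                           # pairwise distances, shape (50, 50)
adjacency = (dist < 0.2) & ~torch.eye(50, dtype=torch.bool)  # connect close pairs, no self-loops
edge_index = adjacency.nonzero().t()                         # (2, num_edges), PyG-style connectivity
print(edge_index.shape)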
@@ -82,7 +82,10 @@ class SupervisedSolver(SolverInterface):

         # check consistency
         check_consistency(loss, (LossInterface, _Loss), subclass=False)
-        self.loss = loss
+        self._loss = loss
+        self._model = self._pina_model[0]
+        self._optimizer = self._pina_optimizer[0]
+        self._scheduler = self._pina_scheduler[0]

     def forward(self, x):
         """Forward pass implementation for the solver.
@@ -92,7 +95,7 @@ class SupervisedSolver(SolverInterface):
         :rtype: torch.Tensor
         """

-        output = self._pina_model[0](x)
+        output = self._model(x)

         output.labels = {
             1: {
@@ -108,11 +111,11 @@ class SupervisedSolver(SolverInterface):
         :return: The optimizers and the schedulers
         :rtype: tuple(list, list)
         """
-        self._pina_optimizer[0].hook(self._pina_model[0].parameters())
-        self._pina_scheduler[0].hook(self._pina_optimizer[0])
+        self._optimizer.hook(self._model.parameters())
+        self._scheduler.hook(self._optimizer)
         return (
-            [self._pina_optimizer[0].optimizer_instance],
-            [self._pina_scheduler[0].scheduler_instance]
+            [self._optimizer.optimizer_instance],
+            [self._scheduler.scheduler_instance]
         )

     def training_step(self, batch, batch_idx):
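For reference, the value returned above follows the PyTorch Lightning convention for `configure_optimizers`: a list of optimizers and a list of schedulers. A rough plain-torch sketch of what the hooked `optimizer_instance` / `scheduler_instance` pair boils down to (illustrative names, not PINA's internals):

import torch

model = torch.nn.Linear(2, 1)

# concrete torch objects bound to the model parameters
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# returned as a tuple of lists, one entry per model
optimizers, schedulers = [optimizer], [scheduler]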
@@ -170,28 +173,28 @@ class SupervisedSolver(SolverInterface):
         :return: The residual loss averaged on the input coordinates
         :rtype: torch.Tensor
         """
-        return self.loss(self.forward(input_pts), output_pts)
+        return self._loss(self.forward(input_pts), output_pts)

     @property
     def scheduler(self):
         """
         Scheduler for training.
         """
-        return self._pina_scheduler
+        return self._scheduler

     @property
     def optimizer(self):
         """
         Optimizer for training.
         """
-        return self._pina_optimizer
+        return self._optimizer

     @property
     def model(self):
         """
         Neural network for training.
         """
-        return self._pina_model
+        return self._model

     @property
     def loss(self):