Add Graph support in Dataset and Dataloader

This commit is contained in:
FilippoOlivo
2024-10-23 15:04:28 +02:00
committed by Nicola Demo
parent eb146ea2ea
commit ccc5f5a322
11 changed files with 125 additions and 75 deletions

View File

@@ -1,13 +1,15 @@
import math
import torch
from pina.data import SamplePointDataset, SupervisedDataset, PinaDataModule, UnsupervisedDataset, unsupervised_dataset
from pina.data import SamplePointDataset, SupervisedDataset, PinaDataModule, \
UnsupervisedDataset
from pina.data import PinaDataLoader
from pina import LabelTensor, Condition
from pina.equation import Equation
from pina.domain import CartesianDomain
from pina.problem import SpatialProblem
from pina.problem import SpatialProblem, AbstractProblem
from pina.operators import laplacian
from pina.equation.equation_factory import FixedValue
from pina.graph import Graph
def laplace_equation(input_, output_):
@@ -30,49 +32,49 @@ class Poisson(SpatialProblem):
conditions = {
'gamma1':
Condition(domain=CartesianDomain({
'x': [0, 1],
'y': 1
}),
equation=FixedValue(0.0)),
Condition(domain=CartesianDomain({
'x': [0, 1],
'y': 1
}),
equation=FixedValue(0.0)),
'gamma2':
Condition(domain=CartesianDomain({
'x': [0, 1],
'y': 0
}),
equation=FixedValue(0.0)),
Condition(domain=CartesianDomain({
'x': [0, 1],
'y': 0
}),
equation=FixedValue(0.0)),
'gamma3':
Condition(domain=CartesianDomain({
'x': 1,
'y': [0, 1]
}),
equation=FixedValue(0.0)),
Condition(domain=CartesianDomain({
'x': 1,
'y': [0, 1]
}),
equation=FixedValue(0.0)),
'gamma4':
Condition(domain=CartesianDomain({
'x': 0,
'y': [0, 1]
}),
equation=FixedValue(0.0)),
Condition(domain=CartesianDomain({
'x': 0,
'y': [0, 1]
}),
equation=FixedValue(0.0)),
'D':
Condition(input_points=LabelTensor(torch.rand(size=(100, 2)),
['x', 'y']),
equation=my_laplace),
Condition(input_points=LabelTensor(torch.rand(size=(100, 2)),
['x', 'y']),
equation=my_laplace),
'data':
Condition(input_points=in_, output_points=out_),
Condition(input_points=in_, output_points=out_),
'data2':
Condition(input_points=in2_, output_points=out2_),
Condition(input_points=in2_, output_points=out2_),
'unsupervised':
Condition(
input_points=LabelTensor(torch.rand(size=(45, 2)), ['x', 'y']),
conditional_variables=LabelTensor(torch.ones(size=(45, 1)),
['alpha']),
),
Condition(
input_points=LabelTensor(torch.rand(size=(45, 2)), ['x', 'y']),
conditional_variables=LabelTensor(torch.ones(size=(45, 1)),
['alpha']),
),
'unsupervised2':
Condition(
input_points=LabelTensor(torch.rand(size=(90, 2)), ['x', 'y']),
conditional_variables=LabelTensor(torch.ones(size=(90, 1)),
['alpha']),
)
Condition(
input_points=LabelTensor(torch.rand(size=(90, 2)), ['x', 'y']),
conditional_variables=LabelTensor(torch.ones(size=(90, 1)),
['alpha']),
)
}
@@ -98,8 +100,8 @@ def test_data():
assert dataset.input_points.shape == (61, 2)
assert dataset['input_points'].labels == ['x', 'y']
assert dataset.input_points.labels == ['x', 'y']
assert dataset['input_points', 3:].shape == (58, 2)
assert dataset[3:][1].labels == ['u']
assert dataset.input_points[3:].shape == (58, 2)
assert dataset.output_points[:3].labels == ['u']
assert dataset.output_points.shape == (61, 1)
assert dataset.output_points.labels == ['u']
assert dataset.condition_indices.dtype == torch.uint8
@@ -193,4 +195,32 @@ def test_loader():
assert i.unsupervised.input_points.requires_grad == True
test_loader()
coordinates = LabelTensor(torch.rand((100, 100, 2)), labels=['x', 'y'])
data = LabelTensor(torch.rand((100, 100, 3)), labels=['ux', 'uy', 'p'])
class GraphProblem(AbstractProblem):
output = LabelTensor(torch.rand((100, 3)), labels=['ux', 'uy', 'p'])
input = [Graph.build('radius',
nodes_coordinates=coordinates[i, :, :],
nodes_data=data[i, :, :], radius=0.2)
for i in
range(100)]
output_variables = ['u']
conditions = {
'graph_data': Condition(input_points=input, output_points=output)
}
graph_problem = GraphProblem()
def test_loader_graph():
data_module = PinaDataModule(graph_problem, device='cpu', batch_size=10)
data_module.setup()
loader = data_module.train_dataloader()
for i in loader:
assert len(i) <= 10
assert isinstance(i.supervised.input_points, list)
assert all(isinstance(x, Graph) for x in i.supervised.input_points)