Add Graph support in Dataset and Dataloader
This commit is contained in:
committed by
Nicola Demo
parent
eb146ea2ea
commit
ccc5f5a322
@@ -1,13 +1,15 @@
|
||||
import math
|
||||
import torch
|
||||
from pina.data import SamplePointDataset, SupervisedDataset, PinaDataModule, UnsupervisedDataset, unsupervised_dataset
|
||||
from pina.data import SamplePointDataset, SupervisedDataset, PinaDataModule, \
|
||||
UnsupervisedDataset
|
||||
from pina.data import PinaDataLoader
|
||||
from pina import LabelTensor, Condition
|
||||
from pina.equation import Equation
|
||||
from pina.domain import CartesianDomain
|
||||
from pina.problem import SpatialProblem
|
||||
from pina.problem import SpatialProblem, AbstractProblem
|
||||
from pina.operators import laplacian
|
||||
from pina.equation.equation_factory import FixedValue
|
||||
from pina.graph import Graph
|
||||
|
||||
|
||||
def laplace_equation(input_, output_):
|
||||
@@ -30,49 +32,49 @@ class Poisson(SpatialProblem):
|
||||
|
||||
conditions = {
|
||||
'gamma1':
|
||||
Condition(domain=CartesianDomain({
|
||||
'x': [0, 1],
|
||||
'y': 1
|
||||
}),
|
||||
equation=FixedValue(0.0)),
|
||||
Condition(domain=CartesianDomain({
|
||||
'x': [0, 1],
|
||||
'y': 1
|
||||
}),
|
||||
equation=FixedValue(0.0)),
|
||||
'gamma2':
|
||||
Condition(domain=CartesianDomain({
|
||||
'x': [0, 1],
|
||||
'y': 0
|
||||
}),
|
||||
equation=FixedValue(0.0)),
|
||||
Condition(domain=CartesianDomain({
|
||||
'x': [0, 1],
|
||||
'y': 0
|
||||
}),
|
||||
equation=FixedValue(0.0)),
|
||||
'gamma3':
|
||||
Condition(domain=CartesianDomain({
|
||||
'x': 1,
|
||||
'y': [0, 1]
|
||||
}),
|
||||
equation=FixedValue(0.0)),
|
||||
Condition(domain=CartesianDomain({
|
||||
'x': 1,
|
||||
'y': [0, 1]
|
||||
}),
|
||||
equation=FixedValue(0.0)),
|
||||
'gamma4':
|
||||
Condition(domain=CartesianDomain({
|
||||
'x': 0,
|
||||
'y': [0, 1]
|
||||
}),
|
||||
equation=FixedValue(0.0)),
|
||||
Condition(domain=CartesianDomain({
|
||||
'x': 0,
|
||||
'y': [0, 1]
|
||||
}),
|
||||
equation=FixedValue(0.0)),
|
||||
'D':
|
||||
Condition(input_points=LabelTensor(torch.rand(size=(100, 2)),
|
||||
['x', 'y']),
|
||||
equation=my_laplace),
|
||||
Condition(input_points=LabelTensor(torch.rand(size=(100, 2)),
|
||||
['x', 'y']),
|
||||
equation=my_laplace),
|
||||
'data':
|
||||
Condition(input_points=in_, output_points=out_),
|
||||
Condition(input_points=in_, output_points=out_),
|
||||
'data2':
|
||||
Condition(input_points=in2_, output_points=out2_),
|
||||
Condition(input_points=in2_, output_points=out2_),
|
||||
'unsupervised':
|
||||
Condition(
|
||||
input_points=LabelTensor(torch.rand(size=(45, 2)), ['x', 'y']),
|
||||
conditional_variables=LabelTensor(torch.ones(size=(45, 1)),
|
||||
['alpha']),
|
||||
),
|
||||
Condition(
|
||||
input_points=LabelTensor(torch.rand(size=(45, 2)), ['x', 'y']),
|
||||
conditional_variables=LabelTensor(torch.ones(size=(45, 1)),
|
||||
['alpha']),
|
||||
),
|
||||
'unsupervised2':
|
||||
Condition(
|
||||
input_points=LabelTensor(torch.rand(size=(90, 2)), ['x', 'y']),
|
||||
conditional_variables=LabelTensor(torch.ones(size=(90, 1)),
|
||||
['alpha']),
|
||||
)
|
||||
Condition(
|
||||
input_points=LabelTensor(torch.rand(size=(90, 2)), ['x', 'y']),
|
||||
conditional_variables=LabelTensor(torch.ones(size=(90, 1)),
|
||||
['alpha']),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -98,8 +100,8 @@ def test_data():
|
||||
assert dataset.input_points.shape == (61, 2)
|
||||
assert dataset['input_points'].labels == ['x', 'y']
|
||||
assert dataset.input_points.labels == ['x', 'y']
|
||||
assert dataset['input_points', 3:].shape == (58, 2)
|
||||
assert dataset[3:][1].labels == ['u']
|
||||
assert dataset.input_points[3:].shape == (58, 2)
|
||||
assert dataset.output_points[:3].labels == ['u']
|
||||
assert dataset.output_points.shape == (61, 1)
|
||||
assert dataset.output_points.labels == ['u']
|
||||
assert dataset.condition_indices.dtype == torch.uint8
|
||||
@@ -193,4 +195,32 @@ def test_loader():
|
||||
assert i.unsupervised.input_points.requires_grad == True
|
||||
|
||||
|
||||
test_loader()
|
||||
coordinates = LabelTensor(torch.rand((100, 100, 2)), labels=['x', 'y'])
|
||||
data = LabelTensor(torch.rand((100, 100, 3)), labels=['ux', 'uy', 'p'])
|
||||
|
||||
|
||||
class GraphProblem(AbstractProblem):
|
||||
output = LabelTensor(torch.rand((100, 3)), labels=['ux', 'uy', 'p'])
|
||||
input = [Graph.build('radius',
|
||||
nodes_coordinates=coordinates[i, :, :],
|
||||
nodes_data=data[i, :, :], radius=0.2)
|
||||
for i in
|
||||
range(100)]
|
||||
output_variables = ['u']
|
||||
|
||||
conditions = {
|
||||
'graph_data': Condition(input_points=input, output_points=output)
|
||||
}
|
||||
|
||||
|
||||
graph_problem = GraphProblem()
|
||||
|
||||
|
||||
def test_loader_graph():
|
||||
data_module = PinaDataModule(graph_problem, device='cpu', batch_size=10)
|
||||
data_module.setup()
|
||||
loader = data_module.train_dataloader()
|
||||
for i in loader:
|
||||
assert len(i) <= 10
|
||||
assert isinstance(i.supervised.input_points, list)
|
||||
assert all(isinstance(x, Graph) for x in i.supervised.input_points)
|
||||
|
||||
Reference in New Issue
Block a user