Files
PINA/tests/test_dataset.py
Nicola Demo 5245a0b68c refact
2025-03-19 17:46:33 +01:00

123 lines
4.6 KiB
Python

import torch
import pytest
from pina.data.dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
from pina import LabelTensor, Condition
from pina.equation import Equation
from pina.geometry import CartesianDomain
from pina.problem import SpatialProblem
from pina.model import FeedForward
from pina.operators import laplacian
from pina.equation.equation_factory import FixedValue
def laplace_equation(input_, output_):
force_term = (torch.sin(input_.extract(['x'])*torch.pi) *
torch.sin(input_.extract(['y'])*torch.pi))
delta_u = laplacian(output_.extract(['u']), input_)
return delta_u - force_term
my_laplace = Equation(laplace_equation)
in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y'])
out_ = LabelTensor(torch.tensor([[0.]]), ['u'])
in2_ = LabelTensor(torch.rand(60, 2), ['x', 'y'])
out2_ = LabelTensor(torch.rand(60, 1), ['u'])
class Poisson(SpatialProblem):
output_variables = ['u']
spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})
conditions = {
'gamma1': Condition(
location=CartesianDomain({'x': [0, 1], 'y': 1}),
equation=FixedValue(0.0)),
'gamma2': Condition(
location=CartesianDomain({'x': [0, 1], 'y': 0}),
equation=FixedValue(0.0)),
'gamma3': Condition(
location=CartesianDomain({'x': 1, 'y': [0, 1]}),
equation=FixedValue(0.0)),
'gamma4': Condition(
location=CartesianDomain({'x': 0, 'y': [0, 1]}),
equation=FixedValue(0.0)),
'D': Condition(
input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']),
equation=my_laplace),
'data': Condition(
input_points=in_,
output_points=out_),
'data2': Condition(
input_points=in2_,
output_points=out2_)
}
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
poisson = Poisson()
poisson.discretise_domain(10, 'grid', locations=boundaries)
def test_sample():
sample_dataset = SamplePointDataset(poisson, device='cpu')
assert len(sample_dataset) == 140
assert sample_dataset.pts.shape == (140, 2)
assert sample_dataset.pts.labels == ['x', 'y']
assert sample_dataset.condition_indeces.dtype == torch.int64
assert sample_dataset.condition_indeces.max() == torch.tensor(4)
assert sample_dataset.condition_indeces.min() == torch.tensor(0)
def test_data():
dataset = DataPointDataset(poisson, device='cpu')
assert len(dataset) == 61
assert dataset.input_pts.shape == (61, 2)
assert dataset.input_pts.labels == ['x', 'y']
assert dataset.output_pts.shape == (61, 1 )
assert dataset.output_pts.labels == ['u']
assert dataset.condition_indeces.dtype == torch.int64
assert dataset.condition_indeces.max() == torch.tensor(1)
assert dataset.condition_indeces.min() == torch.tensor(0)
def test_loader():
sample_dataset = SamplePointDataset(poisson, device='cpu')
data_dataset = DataPointDataset(poisson, device='cpu')
loader = SamplePointLoader(sample_dataset, data_dataset, batch_size=10)
for batch in loader:
assert len(batch) in [2, 3]
assert batch['pts'].shape[0] <= 10
assert batch['pts'].requires_grad == True
assert batch['pts'].labels == ['x', 'y']
loader2 = SamplePointLoader(sample_dataset, data_dataset, batch_size=None)
assert len(list(loader2)) == 2
def test_loader2():
poisson2 = Poisson()
del poisson.conditions['data2']
del poisson2.conditions['data']
poisson2.discretise_domain(10, 'grid', locations=boundaries)
sample_dataset = SamplePointDataset(poisson, device='cpu')
data_dataset = DataPointDataset(poisson, device='cpu')
loader = SamplePointLoader(sample_dataset, data_dataset, batch_size=10)
for batch in loader:
assert len(batch) == 2 # only phys condtions
assert batch['pts'].shape[0] <= 10
assert batch['pts'].requires_grad == True
assert batch['pts'].labels == ['x', 'y']
def test_loader3():
poisson2 = Poisson()
del poisson.conditions['gamma1']
del poisson.conditions['gamma2']
del poisson.conditions['gamma3']
del poisson.conditions['gamma4']
del poisson.conditions['D']
sample_dataset = SamplePointDataset(poisson, device='cpu')
data_dataset = DataPointDataset(poisson, device='cpu')
loader = SamplePointLoader(sample_dataset, data_dataset, batch_size=10)
for batch in loader:
assert len(batch) == 2 # only phys condtions
assert batch['pts'].shape[0] <= 10
assert batch['pts'].requires_grad == True
assert batch['pts'].labels == ['x', 'y']