add dataset and dataloader for sample points (#195)

* add dataset and dataloader for sample points
* unittests
This commit is contained in:
Nicola Demo
2023-11-07 11:34:44 +01:00
parent cd5bc9a558
commit d654259428
19 changed files with 581 additions and 196 deletions

View File

@@ -134,7 +134,7 @@ def test_train_cpu():
hidden_dimension=64)
)
trainer = Trainer(solver=solver, max_epochs=4, accelerator='cpu')
trainer = Trainer(solver=solver, max_epochs=4, accelerator='cpu', batch_size=20)
trainer.train()
def test_sample():

View File

@@ -22,6 +22,8 @@ def laplace_equation(input_, output_):
my_laplace = Equation(laplace_equation)
in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y'])
out_ = LabelTensor(torch.tensor([[0.]]), ['u'])
in2_ = LabelTensor(torch.rand(60, 2), ['x', 'y'])
out2_ = LabelTensor(torch.rand(60, 1), ['u'])
class Poisson(SpatialProblem):
output_variables = ['u']
@@ -45,7 +47,10 @@ class Poisson(SpatialProblem):
equation=my_laplace),
'data': Condition(
input_points=in_,
output_points=out_)
output_points=out_),
'data2': Condition(
input_points=in2_,
output_points=out2_)
}
def poisson_sol(self, pts):
@@ -92,7 +97,7 @@ def test_train_cpu():
n = 10
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
trainer = Trainer(solver=pinn, max_epochs=1, accelerator='cpu', batch_size=20)
trainer.train()
def test_train_restore():
@@ -106,7 +111,7 @@ def test_train_restore():
trainer.train()
ntrainer = Trainer(solver=pinn, max_epochs=15, accelerator='cpu')
t = ntrainer.train(
ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt')
ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=10.ckpt')
import shutil
shutil.rmtree(tmpdir)
@@ -121,7 +126,7 @@ def test_train_load():
default_root_dir=tmpdir)
trainer.train()
new_pinn = PINN.load_from_checkpoint(
f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=15.ckpt',
f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt',
problem = poisson_problem, model=model)
test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10)
assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1)