Trainer train simplified, tests for load (#168)
- the arguments of Trainer.train now are passed to the fit - unittest for load/restoring from checkpoint
This commit is contained in:
@@ -7,7 +7,7 @@ from .solvers.solver import SolverInterface
|
|||||||
|
|
||||||
class Trainer(pl.Trainer):
|
class Trainer(pl.Trainer):
|
||||||
|
|
||||||
def __init__(self, solver, kwargs={}):
|
def __init__(self, solver, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
# get accellerator
|
# get accellerator
|
||||||
@@ -29,6 +29,6 @@ class Trainer(pl.Trainer):
|
|||||||
self._loader = DummyLoader(solver.problem.input_pts, device)
|
self._loader = DummyLoader(solver.problem.input_pts, device)
|
||||||
|
|
||||||
|
|
||||||
def train(self): # TODO add kwargs and lightining capabilities
|
def train(self, **kwargs): # TODO add kwargs and lightining capabilities
|
||||||
return super().fit(self._model, self._loader)
|
return super().fit(self._model, self._loader, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -134,7 +134,7 @@ def test_train_cpu():
|
|||||||
hidden_dimension=64)
|
hidden_dimension=64)
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(solver=solver, kwargs={'max_epochs' : 4, 'accelerator': 'cpu'})
|
trainer = Trainer(solver=solver, max_epochs=4, accelerator='cpu')
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
def test_sample():
|
def test_sample():
|
||||||
|
|||||||
@@ -56,12 +56,12 @@ class Poisson(SpatialProblem):
|
|||||||
|
|
||||||
truth_solution = poisson_sol
|
truth_solution = poisson_sol
|
||||||
|
|
||||||
|
|
||||||
class myFeature(torch.nn.Module):
|
class myFeature(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
Feature: sin(x)
|
Feature: sin(x)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(myFeature, self).__init__()
|
super(myFeature, self).__init__()
|
||||||
|
|
||||||
@@ -92,9 +92,46 @@ def test_train_cpu():
|
|||||||
n = 10
|
n = 10
|
||||||
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
|
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
|
||||||
pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
|
pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
|
||||||
trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'cpu'})
|
trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
|
def test_train_restore():
|
||||||
|
tmpdir = "tests/tmp_restore"
|
||||||
|
poisson_problem = Poisson()
|
||||||
|
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||||
|
n = 10
|
||||||
|
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
|
||||||
|
pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
|
||||||
|
trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu', default_root_dir=tmpdir)
|
||||||
|
trainer.train()
|
||||||
|
print('ggg')
|
||||||
|
ntrainer = Trainer(solver=pinn, max_epochs=15, accelerator='cpu')
|
||||||
|
t = ntrainer.train(
|
||||||
|
ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt')
|
||||||
|
import shutil
|
||||||
|
shutil.rmtree(tmpdir)
|
||||||
|
|
||||||
|
def test_train_load():
|
||||||
|
tmpdir = "tests/tmp_load"
|
||||||
|
poisson_problem = Poisson()
|
||||||
|
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||||
|
n = 10
|
||||||
|
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
|
||||||
|
pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
|
||||||
|
trainer = Trainer(solver=pinn, max_epochs=15, accelerator='cpu',
|
||||||
|
default_root_dir=tmpdir)
|
||||||
|
trainer.train()
|
||||||
|
new_pinn = PINN.load_from_checkpoint(
|
||||||
|
f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=15.ckpt',
|
||||||
|
problem = poisson_problem, model=model)
|
||||||
|
test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10)
|
||||||
|
assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1)
|
||||||
|
assert new_pinn.forward(test_pts).extract(['u']).shape == pinn.forward(test_pts).extract(['u']).shape
|
||||||
|
torch.testing.assert_close(new_pinn.forward(test_pts).extract(['u']), pinn.forward(test_pts).extract(['u']))
|
||||||
|
import shutil
|
||||||
|
shutil.rmtree(tmpdir)
|
||||||
|
|
||||||
|
|
||||||
# # TODO fix asap. Basically sampling few variables
|
# # TODO fix asap. Basically sampling few variables
|
||||||
# # works only if both variables are in a range.
|
# # works only if both variables are in a range.
|
||||||
# # if one is fixed and the other not, this will
|
# # if one is fixed and the other not, this will
|
||||||
@@ -118,7 +155,7 @@ def test_train_extra_feats_cpu():
|
|||||||
n = 10
|
n = 10
|
||||||
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
|
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
|
||||||
pinn = PINN(problem = poisson_problem, model=model_extra_feats, extra_features=extra_feats)
|
pinn = PINN(problem = poisson_problem, model=model_extra_feats, extra_features=extra_feats)
|
||||||
trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'cpu'})
|
trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
# TODO, fix GitHub actions to run also on GPU
|
# TODO, fix GitHub actions to run also on GPU
|
||||||
|
|||||||
Reference in New Issue
Block a user