minor fix, add few tests (#38)

This commit is contained in:
Nicola Demo
2022-11-29 12:42:01 +01:00
committed by GitHub
parent a1e947fede
commit 936f5e1043
4 changed files with 52 additions and 12 deletions

View File

@@ -45,7 +45,7 @@ class PINN(object):
self.device = torch.device(device) self.device = torch.device(device)
self.dtype = dtype self.dtype = dtype
self.history = [] self.history_loss = {}
self.model = model self.model = model
self.model.to(dtype=self.dtype, device=self.device) self.model.to(dtype=self.dtype, device=self.device)
@@ -92,7 +92,7 @@ class PINN(object):
'model_state': self.model.state_dict(), 'model_state': self.model.state_dict(),
'optimizer_state' : self.optimizer.state_dict(), 'optimizer_state' : self.optimizer.state_dict(),
'optimizer_class' : self.optimizer.__class__, 'optimizer_class' : self.optimizer.__class__,
'history' : self.history, 'history' : self.history_loss,
'input_points_dict' : self.input_pts, 'input_points_dict' : self.input_pts,
} }
@@ -113,7 +113,7 @@ class PINN(object):
self.optimizer.load_state_dict(checkpoint['optimizer_state']) self.optimizer.load_state_dict(checkpoint['optimizer_state'])
self.trained_epoch = checkpoint['epoch'] self.trained_epoch = checkpoint['epoch']
self.history = checkpoint['history'] self.history_loss = checkpoint['history']
self.input_pts = checkpoint['input_points_dict'] self.input_pts = checkpoint['input_points_dict']
@@ -184,7 +184,7 @@ class PINN(object):
self.input_pts[location].requires_grad_(True) self.input_pts[location].requires_grad_(True)
self.input_pts[location].retain_grad() self.input_pts[location].retain_grad()
def train(self, stop=100, frequency_print=2, trial=None): def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):
epoch = 0 epoch = 0
@@ -230,10 +230,9 @@ class PINN(object):
sum(losses).backward() sum(losses).backward()
self.optimizer.step() self.optimizer.step()
self.trained_epoch += 1 if save_loss and (epoch % save_loss == 0 or epoch == 0):
if epoch % 50 == 0: self.history_loss[epoch] = [
self.history.append([loss.detach().item() for loss in losses]) loss.detach().item() for loss in losses]
epoch += 1
if trial: if trial:
import optuna import optuna
@@ -245,7 +244,7 @@ class PINN(object):
if epoch == stop: if epoch == stop:
print('[epoch {:05d}] {:.6e} '.format(self.trained_epoch, sum(losses).item()), end='') print('[epoch {:05d}] {:.6e} '.format(self.trained_epoch, sum(losses).item()), end='')
for loss in losses: for loss in losses:
print('{:.6e} '.format(loss), end='') print('{:.6e} '.format(loss.item()), end='')
print() print()
break break
elif isinstance(stop, float): elif isinstance(stop, float):
@@ -260,9 +259,12 @@ class PINN(object):
print('[epoch {:05d}] {:.6e} '.format(self.trained_epoch, sum(losses).item()), end='') print('[epoch {:05d}] {:.6e} '.format(self.trained_epoch, sum(losses).item()), end='')
for loss in losses: for loss in losses:
print('{:.6e} '.format(loss), end='') print('{:.6e} '.format(loss.item()), end='')
print() print()
self.trained_epoch += 1
epoch += 1
return sum(losses).item() return sum(losses).item()

View File

@@ -131,3 +131,22 @@ class Plotter:
plt.savefig(filename) plt.savefig(filename)
else: else:
plt.show() plt.show()
def plot_loss(self, pinn, label=None, log_scale=True):
"""
Plot the loss trend
TODO
"""
if not label:
label = str(pinn)
epochs = list(pinn.history_loss.keys())
loss = np.array(list(pinn.history_loss.values()))
if loss.ndim != 1:
loss = loss[:, 0]
plt.plot(epochs, loss, label=label)
if log_scale:
plt.yscale('log')

View File

@@ -56,7 +56,6 @@ class Span(Location):
def sample(self, n, mode='random', variables='all'): def sample(self, n, mode='random', variables='all'):
"""TODO """TODO
""" """
def _1d_sampler(n, mode, variables): def _1d_sampler(n, mode, variables):
""" Sample independentely the variables and cross the results""" """ Sample independentely the variables and cross the results"""
tmp = [] tmp = []

View File

@@ -38,8 +38,8 @@ class Poisson(SpatialProblem):
truth_solution = poisson_sol truth_solution = poisson_sol
problem = Poisson() problem = Poisson()
model = FeedForward(2, 1)
model = FeedForward(problem.input_variables, problem.output_variables)
def test_constructor(): def test_constructor():
PINN(problem, model) PINN(problem, model)
@@ -59,3 +59,23 @@ def test_span_pts():
assert pinn.input_pts['D'].shape[0] == n**2 assert pinn.input_pts['D'].shape[0] == n**2
pinn.span_pts(n, 'random', locations=['D']) pinn.span_pts(n, 'random', locations=['D'])
assert pinn.input_pts['D'].shape[0] == n assert pinn.input_pts['D'].shape[0] == n
def test_train():
pinn = PINN(problem, model)
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
n = 10
pinn.span_pts(n, 'grid', boundaries)
pinn.span_pts(n, 'grid', locations=['D'])
pinn.train(5)
def test_train():
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
n = 10
expected_keys = [[], list(range(0, 50, 3))]
param = [0, 3]
for i, truth_key in zip(param, expected_keys):
pinn = PINN(problem, model)
pinn.span_pts(n, 'grid', boundaries)
pinn.span_pts(n, 'grid', locations=['D'])
pinn.train(50, save_loss=i)
assert list(pinn.history_loss.keys()) == truth_key