From 936f5e10430d25c60e81afc52a36ac3f7bcd0d52 Mon Sep 17 00:00:00 2001 From: Nicola Demo Date: Tue, 29 Nov 2022 12:42:01 +0100 Subject: [PATCH] minor fix, add few tests (#38) --- pina/pinn.py | 22 ++++++++++++---------- pina/plotter.py | 19 +++++++++++++++++++ pina/span.py | 1 - tests/test_pinn.py | 22 +++++++++++++++++++++- 4 files changed, 52 insertions(+), 12 deletions(-) diff --git a/pina/pinn.py b/pina/pinn.py index 03929cb..e6cecbf 100644 --- a/pina/pinn.py +++ b/pina/pinn.py @@ -45,7 +45,7 @@ class PINN(object): self.device = torch.device(device) self.dtype = dtype - self.history = [] + self.history_loss = {} self.model = model self.model.to(dtype=self.dtype, device=self.device) @@ -92,7 +92,7 @@ class PINN(object): 'model_state': self.model.state_dict(), 'optimizer_state' : self.optimizer.state_dict(), 'optimizer_class' : self.optimizer.__class__, - 'history' : self.history, + 'history' : self.history_loss, 'input_points_dict' : self.input_pts, } @@ -113,7 +113,7 @@ class PINN(object): self.optimizer.load_state_dict(checkpoint['optimizer_state']) self.trained_epoch = checkpoint['epoch'] - self.history = checkpoint['history'] + self.history_loss = checkpoint['history'] self.input_pts = checkpoint['input_points_dict'] @@ -184,7 +184,7 @@ class PINN(object): self.input_pts[location].requires_grad_(True) self.input_pts[location].retain_grad() - def train(self, stop=100, frequency_print=2, trial=None): + def train(self, stop=100, frequency_print=2, save_loss=1, trial=None): epoch = 0 @@ -230,10 +230,9 @@ class PINN(object): sum(losses).backward() self.optimizer.step() - self.trained_epoch += 1 - if epoch % 50 == 0: - self.history.append([loss.detach().item() for loss in losses]) - epoch += 1 + if save_loss and (epoch % save_loss == 0 or epoch == 0): + self.history_loss[epoch] = [ + loss.detach().item() for loss in losses] if trial: import optuna @@ -245,7 +244,7 @@ class PINN(object): if epoch == stop: print('[epoch {:05d}] {:.6e} '.format(self.trained_epoch, sum(losses).item()), end='') for loss in losses: - print('{:.6e} '.format(loss), end='') + print('{:.6e} '.format(loss.item()), end='') print() break elif isinstance(stop, float): @@ -260,9 +259,12 @@ class PINN(object): print('[epoch {:05d}] {:.6e} '.format(self.trained_epoch, sum(losses).item()), end='') for loss in losses: - print('{:.6e} '.format(loss), end='') + print('{:.6e} '.format(loss.item()), end='') print() + self.trained_epoch += 1 + epoch += 1 + return sum(losses).item() diff --git a/pina/plotter.py b/pina/plotter.py index 7881fda..6b22ded 100644 --- a/pina/plotter.py +++ b/pina/plotter.py @@ -131,3 +131,22 @@ class Plotter: plt.savefig(filename) else: plt.show() + + def plot_loss(self, pinn, label=None, log_scale=True): + """ + Plot the loss trend + + TODO + """ + + if not label: + label = str(pinn) + + epochs = list(pinn.history_loss.keys()) + loss = np.array(list(pinn.history_loss.values())) + if loss.ndim != 1: + loss = loss[:, 0] + + plt.plot(epochs, loss, label=label) + if log_scale: + plt.yscale('log') diff --git a/pina/span.py b/pina/span.py index 3ebab85..50083cc 100644 --- a/pina/span.py +++ b/pina/span.py @@ -56,7 +56,6 @@ class Span(Location): def sample(self, n, mode='random', variables='all'): """TODO """ - def _1d_sampler(n, mode, variables): """ Sample independentely the variables and cross the results""" tmp = [] diff --git a/tests/test_pinn.py b/tests/test_pinn.py index 821e375..f4b5ce8 100644 --- a/tests/test_pinn.py +++ b/tests/test_pinn.py @@ -38,8 +38,8 @@ class Poisson(SpatialProblem): truth_solution = poisson_sol problem = Poisson() -model = FeedForward(2, 1) +model = FeedForward(problem.input_variables, problem.output_variables) def test_constructor(): PINN(problem, model) @@ -59,3 +59,23 @@ def test_span_pts(): assert pinn.input_pts['D'].shape[0] == n**2 pinn.span_pts(n, 'random', locations=['D']) assert pinn.input_pts['D'].shape[0] == n + +def test_train(): + pinn = PINN(problem, model) + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + n = 10 + pinn.span_pts(n, 'grid', boundaries) + pinn.span_pts(n, 'grid', locations=['D']) + pinn.train(5) + +def test_train(): + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + n = 10 + expected_keys = [[], list(range(0, 50, 3))] + param = [0, 3] + for i, truth_key in zip(param, expected_keys): + pinn = PINN(problem, model) + pinn.span_pts(n, 'grid', boundaries) + pinn.span_pts(n, 'grid', locations=['D']) + pinn.train(50, save_loss=i) + assert list(pinn.history_loss.keys()) == truth_key \ No newline at end of file