From 53cbf3f22cdc4baf01cff9fdacfa4f96dd8f5a21 Mon Sep 17 00:00:00 2001 From: Dario Coscia Date: Sun, 25 Dec 2022 18:26:46 +0100 Subject: [PATCH] add train/eval --- pina/pinn.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pina/pinn.py b/pina/pinn.py index 4ae9628..215b53d 100644 --- a/pina/pinn.py +++ b/pina/pinn.py @@ -225,6 +225,7 @@ class PINN(object): def train(self, stop=100, frequency_print=2, save_loss=1, trial=None): + self.model.train() epoch = 0 data_loader = self.data_set.dataloader @@ -319,6 +320,8 @@ class PINN(object): self.trained_epoch += 1 epoch += 1 + self.model.eval() + return sum(losses).item() def error(self, dtype='l2', res=100):