diff --git a/pina/pinn.py b/pina/pinn.py index 4ae9628..215b53d 100644 --- a/pina/pinn.py +++ b/pina/pinn.py @@ -225,6 +225,7 @@ class PINN(object): def train(self, stop=100, frequency_print=2, save_loss=1, trial=None): + self.model.train() epoch = 0 data_loader = self.data_set.dataloader @@ -319,6 +320,8 @@ class PINN(object): self.trained_epoch += 1 epoch += 1 + self.model.eval() + return sum(losses).item() def error(self, dtype='l2', res=100):