add train/eval

This commit is contained in:
Dario Coscia
2022-12-25 18:26:46 +01:00
committed by Nicola Demo
parent 5c09ff626c
commit 53cbf3f22c

View File

@@ -225,6 +225,7 @@ class PINN(object):
def train(self, stop=100, frequency_print=2, save_loss=1, trial=None): def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):
self.model.train()
epoch = 0 epoch = 0
data_loader = self.data_set.dataloader data_loader = self.data_set.dataloader
@@ -319,6 +320,8 @@ class PINN(object):
self.trained_epoch += 1 self.trained_epoch += 1
epoch += 1 epoch += 1
self.model.eval()
return sum(losses).item() return sum(losses).item()
def error(self, dtype='l2', res=100): def error(self, dtype='l2', res=100):