add train/eval
This commit is contained in:
committed by
Nicola Demo
parent
5c09ff626c
commit
53cbf3f22c
@@ -225,6 +225,7 @@ class PINN(object):
|
|||||||
|
|
||||||
def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):
|
def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):
|
||||||
|
|
||||||
|
self.model.train()
|
||||||
epoch = 0
|
epoch = 0
|
||||||
data_loader = self.data_set.dataloader
|
data_loader = self.data_set.dataloader
|
||||||
|
|
||||||
@@ -319,6 +320,8 @@ class PINN(object):
|
|||||||
self.trained_epoch += 1
|
self.trained_epoch += 1
|
||||||
epoch += 1
|
epoch += 1
|
||||||
|
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
return sum(losses).item()
|
return sum(losses).item()
|
||||||
|
|
||||||
def error(self, dtype='l2', res=100):
|
def error(self, dtype='l2', res=100):
|
||||||
|
|||||||
Reference in New Issue
Block a user