minor fix, add few tests (#38)
This commit is contained in:
22
pina/pinn.py
22
pina/pinn.py
@@ -45,7 +45,7 @@ class PINN(object):
|
||||
self.device = torch.device(device)
|
||||
|
||||
self.dtype = dtype
|
||||
self.history = []
|
||||
self.history_loss = {}
|
||||
|
||||
self.model = model
|
||||
self.model.to(dtype=self.dtype, device=self.device)
|
||||
@@ -92,7 +92,7 @@ class PINN(object):
|
||||
'model_state': self.model.state_dict(),
|
||||
'optimizer_state' : self.optimizer.state_dict(),
|
||||
'optimizer_class' : self.optimizer.__class__,
|
||||
'history' : self.history,
|
||||
'history' : self.history_loss,
|
||||
'input_points_dict' : self.input_pts,
|
||||
}
|
||||
|
||||
@@ -113,7 +113,7 @@ class PINN(object):
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer_state'])
|
||||
|
||||
self.trained_epoch = checkpoint['epoch']
|
||||
self.history = checkpoint['history']
|
||||
self.history_loss = checkpoint['history']
|
||||
|
||||
self.input_pts = checkpoint['input_points_dict']
|
||||
|
||||
@@ -184,7 +184,7 @@ class PINN(object):
|
||||
self.input_pts[location].requires_grad_(True)
|
||||
self.input_pts[location].retain_grad()
|
||||
|
||||
def train(self, stop=100, frequency_print=2, trial=None):
|
||||
def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):
|
||||
|
||||
epoch = 0
|
||||
|
||||
@@ -230,10 +230,9 @@ class PINN(object):
|
||||
sum(losses).backward()
|
||||
self.optimizer.step()
|
||||
|
||||
self.trained_epoch += 1
|
||||
if epoch % 50 == 0:
|
||||
self.history.append([loss.detach().item() for loss in losses])
|
||||
epoch += 1
|
||||
if save_loss and (epoch % save_loss == 0 or epoch == 0):
|
||||
self.history_loss[epoch] = [
|
||||
loss.detach().item() for loss in losses]
|
||||
|
||||
if trial:
|
||||
import optuna
|
||||
@@ -245,7 +244,7 @@ class PINN(object):
|
||||
if epoch == stop:
|
||||
print('[epoch {:05d}] {:.6e} '.format(self.trained_epoch, sum(losses).item()), end='')
|
||||
for loss in losses:
|
||||
print('{:.6e} '.format(loss), end='')
|
||||
print('{:.6e} '.format(loss.item()), end='')
|
||||
print()
|
||||
break
|
||||
elif isinstance(stop, float):
|
||||
@@ -260,9 +259,12 @@ class PINN(object):
|
||||
|
||||
print('[epoch {:05d}] {:.6e} '.format(self.trained_epoch, sum(losses).item()), end='')
|
||||
for loss in losses:
|
||||
print('{:.6e} '.format(loss), end='')
|
||||
print('{:.6e} '.format(loss.item()), end='')
|
||||
print()
|
||||
|
||||
self.trained_epoch += 1
|
||||
epoch += 1
|
||||
|
||||
return sum(losses).item()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user