minor fix, add few tests (#38)
This commit is contained in:
22
pina/pinn.py
22
pina/pinn.py
@@ -45,7 +45,7 @@ class PINN(object):
|
|||||||
self.device = torch.device(device)
|
self.device = torch.device(device)
|
||||||
|
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.history = []
|
self.history_loss = {}
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.model.to(dtype=self.dtype, device=self.device)
|
self.model.to(dtype=self.dtype, device=self.device)
|
||||||
@@ -92,7 +92,7 @@ class PINN(object):
|
|||||||
'model_state': self.model.state_dict(),
|
'model_state': self.model.state_dict(),
|
||||||
'optimizer_state' : self.optimizer.state_dict(),
|
'optimizer_state' : self.optimizer.state_dict(),
|
||||||
'optimizer_class' : self.optimizer.__class__,
|
'optimizer_class' : self.optimizer.__class__,
|
||||||
'history' : self.history,
|
'history' : self.history_loss,
|
||||||
'input_points_dict' : self.input_pts,
|
'input_points_dict' : self.input_pts,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,7 +113,7 @@ class PINN(object):
|
|||||||
self.optimizer.load_state_dict(checkpoint['optimizer_state'])
|
self.optimizer.load_state_dict(checkpoint['optimizer_state'])
|
||||||
|
|
||||||
self.trained_epoch = checkpoint['epoch']
|
self.trained_epoch = checkpoint['epoch']
|
||||||
self.history = checkpoint['history']
|
self.history_loss = checkpoint['history']
|
||||||
|
|
||||||
self.input_pts = checkpoint['input_points_dict']
|
self.input_pts = checkpoint['input_points_dict']
|
||||||
|
|
||||||
@@ -184,7 +184,7 @@ class PINN(object):
|
|||||||
self.input_pts[location].requires_grad_(True)
|
self.input_pts[location].requires_grad_(True)
|
||||||
self.input_pts[location].retain_grad()
|
self.input_pts[location].retain_grad()
|
||||||
|
|
||||||
def train(self, stop=100, frequency_print=2, trial=None):
|
def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):
|
||||||
|
|
||||||
epoch = 0
|
epoch = 0
|
||||||
|
|
||||||
@@ -230,10 +230,9 @@ class PINN(object):
|
|||||||
sum(losses).backward()
|
sum(losses).backward()
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
|
|
||||||
self.trained_epoch += 1
|
if save_loss and (epoch % save_loss == 0 or epoch == 0):
|
||||||
if epoch % 50 == 0:
|
self.history_loss[epoch] = [
|
||||||
self.history.append([loss.detach().item() for loss in losses])
|
loss.detach().item() for loss in losses]
|
||||||
epoch += 1
|
|
||||||
|
|
||||||
if trial:
|
if trial:
|
||||||
import optuna
|
import optuna
|
||||||
@@ -245,7 +244,7 @@ class PINN(object):
|
|||||||
if epoch == stop:
|
if epoch == stop:
|
||||||
print('[epoch {:05d}] {:.6e} '.format(self.trained_epoch, sum(losses).item()), end='')
|
print('[epoch {:05d}] {:.6e} '.format(self.trained_epoch, sum(losses).item()), end='')
|
||||||
for loss in losses:
|
for loss in losses:
|
||||||
print('{:.6e} '.format(loss), end='')
|
print('{:.6e} '.format(loss.item()), end='')
|
||||||
print()
|
print()
|
||||||
break
|
break
|
||||||
elif isinstance(stop, float):
|
elif isinstance(stop, float):
|
||||||
@@ -260,9 +259,12 @@ class PINN(object):
|
|||||||
|
|
||||||
print('[epoch {:05d}] {:.6e} '.format(self.trained_epoch, sum(losses).item()), end='')
|
print('[epoch {:05d}] {:.6e} '.format(self.trained_epoch, sum(losses).item()), end='')
|
||||||
for loss in losses:
|
for loss in losses:
|
||||||
print('{:.6e} '.format(loss), end='')
|
print('{:.6e} '.format(loss.item()), end='')
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
self.trained_epoch += 1
|
||||||
|
epoch += 1
|
||||||
|
|
||||||
return sum(losses).item()
|
return sum(losses).item()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -131,3 +131,22 @@ class Plotter:
|
|||||||
plt.savefig(filename)
|
plt.savefig(filename)
|
||||||
else:
|
else:
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
def plot_loss(self, pinn, label=None, log_scale=True):
|
||||||
|
"""
|
||||||
|
Plot the loss trend
|
||||||
|
|
||||||
|
TODO
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not label:
|
||||||
|
label = str(pinn)
|
||||||
|
|
||||||
|
epochs = list(pinn.history_loss.keys())
|
||||||
|
loss = np.array(list(pinn.history_loss.values()))
|
||||||
|
if loss.ndim != 1:
|
||||||
|
loss = loss[:, 0]
|
||||||
|
|
||||||
|
plt.plot(epochs, loss, label=label)
|
||||||
|
if log_scale:
|
||||||
|
plt.yscale('log')
|
||||||
|
|||||||
@@ -56,7 +56,6 @@ class Span(Location):
|
|||||||
def sample(self, n, mode='random', variables='all'):
|
def sample(self, n, mode='random', variables='all'):
|
||||||
"""TODO
|
"""TODO
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _1d_sampler(n, mode, variables):
|
def _1d_sampler(n, mode, variables):
|
||||||
""" Sample independentely the variables and cross the results"""
|
""" Sample independentely the variables and cross the results"""
|
||||||
tmp = []
|
tmp = []
|
||||||
|
|||||||
@@ -38,8 +38,8 @@ class Poisson(SpatialProblem):
|
|||||||
truth_solution = poisson_sol
|
truth_solution = poisson_sol
|
||||||
|
|
||||||
problem = Poisson()
|
problem = Poisson()
|
||||||
model = FeedForward(2, 1)
|
|
||||||
|
|
||||||
|
model = FeedForward(problem.input_variables, problem.output_variables)
|
||||||
|
|
||||||
def test_constructor():
|
def test_constructor():
|
||||||
PINN(problem, model)
|
PINN(problem, model)
|
||||||
@@ -59,3 +59,23 @@ def test_span_pts():
|
|||||||
assert pinn.input_pts['D'].shape[0] == n**2
|
assert pinn.input_pts['D'].shape[0] == n**2
|
||||||
pinn.span_pts(n, 'random', locations=['D'])
|
pinn.span_pts(n, 'random', locations=['D'])
|
||||||
assert pinn.input_pts['D'].shape[0] == n
|
assert pinn.input_pts['D'].shape[0] == n
|
||||||
|
|
||||||
|
def test_train():
|
||||||
|
pinn = PINN(problem, model)
|
||||||
|
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||||
|
n = 10
|
||||||
|
pinn.span_pts(n, 'grid', boundaries)
|
||||||
|
pinn.span_pts(n, 'grid', locations=['D'])
|
||||||
|
pinn.train(5)
|
||||||
|
|
||||||
|
def test_train():
|
||||||
|
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||||
|
n = 10
|
||||||
|
expected_keys = [[], list(range(0, 50, 3))]
|
||||||
|
param = [0, 3]
|
||||||
|
for i, truth_key in zip(param, expected_keys):
|
||||||
|
pinn = PINN(problem, model)
|
||||||
|
pinn.span_pts(n, 'grid', boundaries)
|
||||||
|
pinn.span_pts(n, 'grid', locations=['D'])
|
||||||
|
pinn.train(50, save_loss=i)
|
||||||
|
assert list(pinn.history_loss.keys()) == truth_key
|
||||||
Reference in New Issue
Block a user