preliminary modifications for N-S (Navier-Stokes)
pina/pinn.py (59 changed lines)
@@ -5,6 +5,7 @@ import numpy as np
+from pina.label_tensor import LabelTensor
 
 torch.pi = torch.acos(torch.zeros(1)).item() * 2  # which is 3.1415927410125732
 
 
 class PINN(object):
 
     def __init__(self,
@@ -13,7 +14,6 @@ class PINN(object):
                  optimizer=torch.optim.Adam,
                  lr=0.001,
                  regularizer=0.00001,
-                 data_weight=1.,
                  dtype=torch.float32,
                  device='cpu',
                  error_norm='mse'):
@@ -53,13 +53,10 @@ class PINN(object):
         self.truth_values = {}
         self.input_pts = {}
 
-
         self.trained_epoch = 0
         self.optimizer = optimizer(
             self.model.parameters(), lr=lr, weight_decay=regularizer)
 
-        self.data_weight = data_weight
-
     @property
     def problem(self):
         return self._problem
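The global data_weight argument and attribute are dropped here; judging by the training loop later in this diff, the weight is instead read from each condition as condition.data_weight. A hypothetical minimal stand-in illustrating that per-condition layout (the real pina Condition class may differ):

```python
class Condition:
    """Hypothetical stand-in: every condition carries its own loss weight."""

    def __init__(self, data_weight=1.0, **terms):
        # Per-condition weight, read later as condition.data_weight in the training loop.
        self.data_weight = data_weight
        # e.g. function=<residual callable(s)> or output_points=<target tensor>;
        # only the attributes actually passed are set, so hasattr() checks still work.
        for name, value in terms.items():
            setattr(self, name, value)
```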
@@ -96,6 +93,7 @@ class PINN(object):
             'optimizer_state' : self.optimizer.state_dict(),
             'optimizer_class' : self.optimizer.__class__,
             'history' : self.history,
+            'input_points_dict' : self.input_pts,
         }
 
         # TODO save also architecture param?
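With 'input_points_dict' added to the saved state, the sampled points survive a save/load round trip. A minimal sketch of the intended usage, assuming `pinn` is a trained PINN instance, the checkpoint is the plain dict handled by torch.save / torch.load, and the file name is illustrative (keys not shown in this hunk, such as the model weights, are omitted):

```python
import torch

# Save side: keys mirror the dict in the hunk above; 'input_points_dict' is the new entry.
torch.save({
    'epoch': pinn.trained_epoch,
    'optimizer_state': pinn.optimizer.state_dict(),
    'optimizer_class': pinn.optimizer.__class__,
    'history': pinn.history,
    'input_points_dict': pinn.input_pts,
}, 'pinn_checkpoint.pt')

# Load side: restore the points alongside the rest of the state (see the next hunk).
checkpoint = torch.load('pinn_checkpoint.pt')
pinn.trained_epoch = checkpoint['epoch']
pinn.history = checkpoint['history']
pinn.input_pts = checkpoint['input_points_dict']
```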
@@ -117,22 +115,27 @@ class PINN(object):
         self.trained_epoch = checkpoint['epoch']
         self.history = checkpoint['history']
 
+        self.input_pts = checkpoint['input_points_dict']
+
         return self
 
-    def span_pts(self, n, mode='grid', locations='all'):
+    def span_pts(self, n_spatial, n_params=0, mode_spatial='grid', mode_param='random', locations='all'):
         if locations == 'all':
             locations = [condition for condition in self.problem.conditions]
 
         for location in locations:
             condition = self.problem.conditions[location]
 
             try:
-                pts = condition.location.sample(n, mode)
+                pts = condition.location.sample(n_spatial, mode_spatial, variables=self.problem.spatial_variables)
+                if n_params != 0:
+                    pts_params = condition.location.sample(n_params, mode_param, variables=self.problem.parameters)
+                    pts = LabelTensor(pts.repeat(n_params, 1), pts.labels)
+                    pts_params = LabelTensor(pts_params.repeat_interleave(n_spatial).reshape((n_spatial*n_params, len(self.problem.parameters))), pts_params.labels)
+                    pts = pts.append(pts_params)
             except:
                 pts = condition.input_points
 
-            self.input_pts[location] = pts#.double() # TODO
+            self.input_pts[location] = pts #.double() # TODO
             self.input_pts[location] = (
                 self.input_pts[location].to(dtype=self.dtype,
                                             device=self.device))
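The parametric branch above pairs every spatial sample with every parameter sample by tiling the two tensors before joining them column-wise. A plain-PyTorch sketch of that bookkeeping, with made-up sizes, a single scalar parameter, and torch.cat standing in for LabelTensor.append:

```python
import torch

# Hypothetical sizes: 4 spatial points in 2D, 3 samples of one scalar parameter.
n_spatial, n_params = 4, 3
pts = torch.rand(n_spatial, 2)        # spatial samples, e.g. columns (x, y)
pts_params = torch.rand(n_params, 1)  # parameter samples, e.g. column (mu)

# Tile the spatial block once per parameter value: [p0..p3, p0..p3, p0..p3]
pts_tiled = pts.repeat(n_params, 1)
# Repeat each parameter value once per spatial point: [m0 m0 m0 m0 m1 ... m2]
params_tiled = pts_params.repeat_interleave(n_spatial).reshape(n_spatial * n_params, 1)

# Column-wise concatenation (the role LabelTensor.append plays above).
combined = torch.cat([pts_tiled, params_tiled], dim=1)
print(combined.shape)  # torch.Size([12, 3]); row i pairs pts[i % 4] with pts_params[i // 4]
```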
@@ -140,19 +143,16 @@ class PINN(object):
             self.input_pts[location].retain_grad()
 
 
     def plot_pts(self, locations='all'):
         import matplotlib
-        matplotlib.use('GTK3Agg')
+        # matplotlib.use('GTK3Agg')
         if locations == 'all':
             locations = [condition for condition in self.problem.conditions]
 
         for location in locations:
-            x, y = self.input_pts[location].tensor.T
-            #plt.plot(x.detach(), y.detach(), 'o', label=location)
-            np.savetxt('burgers_{}_pts.txt'.format(location), self.input_pts[location].tensor.detach(), header='x y', delimiter=' ')
-
+            x = self.input_pts[location].extract(['x'])
+            y = self.input_pts[location].extract(['y'])
+            plt.plot(x.detach(), y.detach(), '.', label=location)
+            # np.savetxt('burgers_{}_pts.txt'.format(location), self.input_pts[location].tensor.detach(), header='x y', delimiter=' ')
         plt.legend()
         plt.show()
@@ -169,18 +169,23 @@ class PINN(object):
         for condition_name in self.problem.conditions:
             condition = self.problem.conditions[condition_name]
             pts = self.input_pts[condition_name]
 
             predicted = self.model(pts)
 
-            if isinstance(condition.function, list):
-                for function in condition.function:
-                    residuals = function(pts, predicted)
-                    losses.append(self._compute_norm(residuals))
-            else:
-                residuals = condition.function(pts, predicted)
-                losses.append(self._compute_norm(residuals))
+            if hasattr(condition, 'function'):
+                if isinstance(condition.function, list):
+                    for function in condition.function:
+                        residuals = function(pts, predicted)
+                        local_loss = condition.data_weight*self._compute_norm(residuals)
+                        losses.append(local_loss)
+                else:
+                    residuals = condition.function(pts, predicted)
+                    local_loss = condition.data_weight*self._compute_norm(residuals)
+                    losses.append(local_loss)
+            elif hasattr(condition, 'output_points'):
+                residuals = predicted - condition.output_points
+                local_loss = condition.data_weight*self._compute_norm(residuals)
+                losses.append(local_loss)
 
         self.optimizer.zero_grad()
 
         sum(losses).backward()
         self.optimizer.step()
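For reference, a condensed, self-contained sketch of the new loss assembly, written as a free function with hypothetical problem / model / input_pts arguments; residuals.pow(2).mean() stands in for self._compute_norm under the default 'mse' norm:

```python
import torch

def compute_weighted_losses(problem, model, input_pts):
    """One weighted loss term per condition: residual-based or data-based (sketch)."""
    losses = []
    for name in problem.conditions:
        condition = problem.conditions[name]
        pts = input_pts[name]
        predicted = model(pts)

        if hasattr(condition, 'function'):
            # Physics / boundary terms: one residual function or a list of them.
            functions = condition.function if isinstance(condition.function, list) \
                else [condition.function]
            for function in functions:
                residuals = function(pts, predicted)
                losses.append(condition.data_weight * residuals.pow(2).mean())
        elif hasattr(condition, 'output_points'):
            # Data term: compare predictions with the supplied target values.
            residuals = predicted - condition.output_points
            losses.append(condition.data_weight * residuals.pow(2).mean())
    return losses
```

A training step then sums these terms and backpropagates, exactly as in the hunk above: optimizer.zero_grad(); sum(losses).backward(); optimizer.step().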