preliminary modifications for N-S

This commit is contained in:
Anna Ivagnes
2022-05-05 17:12:31 +02:00
parent d152fe67e3
commit 8130912926
13 changed files with 213 additions and 162 deletions

View File

@@ -5,6 +5,7 @@ import numpy as np
from pina.label_tensor import LabelTensor
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
class PINN(object):
def __init__(self,
@@ -13,7 +14,6 @@ class PINN(object):
optimizer=torch.optim.Adam,
lr=0.001,
regularizer=0.00001,
data_weight=1.,
dtype=torch.float32,
device='cpu',
error_norm='mse'):
@@ -53,13 +53,10 @@ class PINN(object):
self.truth_values = {}
self.input_pts = {}
self.trained_epoch = 0
self.optimizer = optimizer(
self.model.parameters(), lr=lr, weight_decay=regularizer)
self.data_weight = data_weight
@property
def problem(self):
return self._problem
@@ -96,6 +93,7 @@ class PINN(object):
'optimizer_state' : self.optimizer.state_dict(),
'optimizer_class' : self.optimizer.__class__,
'history' : self.history,
'input_points_dict' : self.input_pts,
}
# TODO save also architecture param?
@@ -117,22 +115,27 @@ class PINN(object):
self.trained_epoch = checkpoint['epoch']
self.history = checkpoint['history']
self.input_pts = checkpoint['input_points_dict']
return self
def span_pts(self, n, mode='grid', locations='all'):
def span_pts(self, n_spatial, n_params=0, mode_spatial='grid', mode_param='random', locations='all'):
if locations == 'all':
locations = [condition for condition in self.problem.conditions]
for location in locations:
condition = self.problem.conditions[location]
try:
pts = condition.location.sample(n, mode)
pts = condition.location.sample(n_spatial, mode_spatial, variables=self.problem.spatial_variables)
if n_params != 0:
pts_params = condition.location.sample(n_params, mode_param, variables=self.problem.parameters)
pts = LabelTensor(pts.repeat(n_params, 1), pts.labels)
pts_params = LabelTensor(pts_params.repeat_interleave(n_spatial).reshape((n_spatial*n_params, len(self.problem.parameters))), pts_params.labels)
pts = pts.append(pts_params)
except:
pts = condition.input_points
self.input_pts[location] = pts#.double() # TODO
self.input_pts[location] = pts #.double() # TODO
self.input_pts[location] = (
self.input_pts[location].to(dtype=self.dtype,
device=self.device))
@@ -140,19 +143,16 @@ class PINN(object):
self.input_pts[location].retain_grad()
def plot_pts(self, locations='all'):
import matplotlib
matplotlib.use('GTK3Agg')
# matplotlib.use('GTK3Agg')
if locations == 'all':
locations = [condition for condition in self.problem.conditions]
for location in locations:
x, y = self.input_pts[location].tensor.T
#plt.plot(x.detach(), y.detach(), 'o', label=location)
np.savetxt('burgers_{}_pts.txt'.format(location), self.input_pts[location].tensor.detach(), header='x y', delimiter=' ')
x = self.input_pts[location].extract(['x'])
y = self.input_pts[location].extract(['y'])
plt.plot(x.detach(), y.detach(), '.', label=location)
# np.savetxt('burgers_{}_pts.txt'.format(location), self.input_pts[location].tensor.detach(), header='x y', delimiter=' ')
plt.legend()
plt.show()
@@ -169,18 +169,23 @@ class PINN(object):
for condition_name in self.problem.conditions:
condition = self.problem.conditions[condition_name]
pts = self.input_pts[condition_name]
predicted = self.model(pts)
if isinstance(condition.function, list):
for function in condition.function:
residuals = function(pts, predicted)
losses.append(self._compute_norm(residuals))
else:
residuals = condition.function(pts, predicted)
losses.append(self._compute_norm(residuals))
if hasattr(condition, 'function'):
if isinstance(condition.function, list):
for function in condition.function:
residuals = function(pts, predicted)
local_loss = condition.data_weight*self._compute_norm(residuals)
losses.append(local_loss)
else:
residuals = condition.function(pts, predicted)
local_loss = condition.data_weight*self._compute_norm(residuals)
losses.append(local_loss)
elif hasattr(condition, 'output_points'):
residuals = predicted - condition.output_points
local_loss = condition.data_weight*self._compute_norm(residuals)
losses.append(local_loss)
self.optimizer.zero_grad()
sum(losses).backward()
self.optimizer.step()