Vectorial output
This commit is contained in:
@@ -15,6 +15,9 @@ class Condition:
|
|||||||
elif isinstance(args[0], Location) and callable(args[1]):
|
elif isinstance(args[0], Location) and callable(args[1]):
|
||||||
self.location = args[0]
|
self.location = args[0]
|
||||||
self.function = args[1]
|
self.function = args[1]
|
||||||
|
elif isinstance(args[0], Location) and isinstance(args[1], list):
|
||||||
|
self.location = args[0]
|
||||||
|
self.function = args[1]
|
||||||
else:
|
else:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,9 @@ class LabelTensor():
|
|||||||
self.tensor = x
|
self.tensor = x
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
|
if isinstance(key, (tuple, list)):
|
||||||
|
indeces = [self.labels.index(k) for k in key]
|
||||||
|
return LabelTensor(self.tensor[:, indeces], [self.labels[idx] for idx in indeces])
|
||||||
if key in self.labels:
|
if key in self.labels:
|
||||||
return self.tensor[:, self.labels.index(key)]
|
return self.tensor[:, self.labels.index(key)]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -52,6 +52,8 @@ class FeedForward(torch.nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
x = x[self.input_variables]
|
||||||
nf = len(self.extra_features)
|
nf = len(self.extra_features)
|
||||||
if nf == 0:
|
if nf == 0:
|
||||||
return LabelTensor(self.model(x.tensor), self.output_variables)
|
return LabelTensor(self.model(x.tensor), self.output_variables)
|
||||||
|
|||||||
15
pina/pinn.py
15
pina/pinn.py
@@ -65,8 +65,8 @@ class PINN(object):
|
|||||||
self.model = model
|
self.model = model
|
||||||
self.model.to(dtype=self.dtype, device=self.device)
|
self.model.to(dtype=self.dtype, device=self.device)
|
||||||
|
|
||||||
self.input_pts = {}
|
|
||||||
self.truth_values = {}
|
self.truth_values = {}
|
||||||
|
self.input_pts = {}
|
||||||
|
|
||||||
|
|
||||||
self.trained_epoch = 0
|
self.trained_epoch = 0
|
||||||
@@ -171,13 +171,15 @@ class PINN(object):
|
|||||||
except:
|
except:
|
||||||
pts = condition.input_points
|
pts = condition.input_points
|
||||||
|
|
||||||
|
print(location, pts)
|
||||||
|
|
||||||
self.input_pts[location] = pts
|
self.input_pts[location] = pts
|
||||||
print(pts.tensor.shape)
|
|
||||||
self.input_pts[location].tensor.to(dtype=self.dtype, device=self.device)
|
self.input_pts[location].tensor.to(dtype=self.dtype, device=self.device)
|
||||||
self.input_pts[location].tensor.requires_grad_(True)
|
self.input_pts[location].tensor.requires_grad_(True)
|
||||||
self.input_pts[location].tensor.retain_grad()
|
self.input_pts[location].tensor.retain_grad()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def plot_pts(self, locations='all'):
|
def plot_pts(self, locations='all'):
|
||||||
import matplotlib
|
import matplotlib
|
||||||
matplotlib.use('GTK3Agg')
|
matplotlib.use('GTK3Agg')
|
||||||
@@ -209,8 +211,13 @@ class PINN(object):
|
|||||||
|
|
||||||
predicted = self.model(pts)
|
predicted = self.model(pts)
|
||||||
|
|
||||||
residuals = condition.function(pts, predicted)
|
if isinstance(condition.function, list):
|
||||||
losses.append(self._compute_norm(residuals))
|
for function in condition.function:
|
||||||
|
residuals = function(pts, predicted)
|
||||||
|
losses.append(self._compute_norm(residuals))
|
||||||
|
else:
|
||||||
|
residuals = condition.function(pts, predicted)
|
||||||
|
losses.append(self._compute_norm(residuals))
|
||||||
|
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
sum(losses).backward()
|
sum(losses).backward()
|
||||||
|
|||||||
@@ -84,11 +84,13 @@ class Plotter:
|
|||||||
"""
|
"""
|
||||||
res = 256
|
res = 256
|
||||||
pts = obj.problem.domain.sample(res, 'grid')
|
pts = obj.problem.domain.sample(res, 'grid')
|
||||||
|
print(pts)
|
||||||
grids_container = [
|
grids_container = [
|
||||||
pts[:, 0].reshape(res, res),
|
pts.tensor[:, 0].reshape(res, res),
|
||||||
pts[:, 1].reshape(res, res),
|
pts.tensor[:, 1].reshape(res, res),
|
||||||
]
|
]
|
||||||
predicted_output = obj.model(pts)
|
predicted_output = obj.model(pts)
|
||||||
|
predicted_output = predicted_output['p']
|
||||||
|
|
||||||
if hasattr(obj.problem, 'truth_solution'):
|
if hasattr(obj.problem, 'truth_solution'):
|
||||||
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
|
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
|
||||||
@@ -102,10 +104,56 @@ class Plotter:
|
|||||||
fig.colorbar(cb, ax=axes[2])
|
fig.colorbar(cb, ax=axes[2])
|
||||||
else:
|
else:
|
||||||
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
||||||
cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
|
# cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
|
||||||
|
cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach())
|
||||||
fig.colorbar(cb, ax=axes)
|
fig.colorbar(cb, ax=axes)
|
||||||
|
|
||||||
if filename:
|
if filename:
|
||||||
plt.savefig(filename)
|
plt.savefig(filename)
|
||||||
else:
|
else:
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def plot(self, obj, method='contourf', filename=None):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
res = 256
|
||||||
|
pts = obj.problem.domain.sample(res, 'grid')
|
||||||
|
print(pts)
|
||||||
|
grids_container = [
|
||||||
|
pts.tensor[:, 0].reshape(res, res),
|
||||||
|
pts.tensor[:, 1].reshape(res, res),
|
||||||
|
]
|
||||||
|
predicted_output = obj.model(pts)
|
||||||
|
predicted_output = predicted_output['ux']
|
||||||
|
|
||||||
|
if hasattr(obj.problem, 'truth_solution'):
|
||||||
|
truth_output = obj.problem.truth_solution(*pts.tensor.T).float()
|
||||||
|
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
|
||||||
|
|
||||||
|
cb = getattr(axes[0], method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
|
||||||
|
fig.colorbar(cb, ax=axes[0])
|
||||||
|
cb = getattr(axes[1], method)(*grids_container, truth_output.reshape(res, res).detach())
|
||||||
|
fig.colorbar(cb, ax=axes[1])
|
||||||
|
cb = getattr(axes[2], method)(*grids_container, (truth_output-predicted_output.tensor.float().flatten()).detach().reshape(res, res))
|
||||||
|
fig.colorbar(cb, ax=axes[2])
|
||||||
|
else:
|
||||||
|
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
||||||
|
# cb = getattr(axes, method)(*grids_container, predicted_output.tensor.reshape(res, res).detach())
|
||||||
|
cb = getattr(axes, method)(*grids_container, predicted_output.reshape(res, res).detach())
|
||||||
|
fig.colorbar(cb, ax=axes)
|
||||||
|
|
||||||
|
if filename:
|
||||||
|
plt.savefig(filename)
|
||||||
|
else:
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def plot_samples(self, obj):
|
||||||
|
|
||||||
|
for location in obj.input_pts:
|
||||||
|
plt.plot(*obj.input_pts[location].tensor.T.detach(), '.', label=location)
|
||||||
|
|
||||||
|
plt.legend()
|
||||||
|
plt.show()
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ class Span(Location):
|
|||||||
for _ in range(bounds.shape[0])])
|
for _ in range(bounds.shape[0])])
|
||||||
grids = np.meshgrid(*pts)
|
grids = np.meshgrid(*pts)
|
||||||
pts = np.hstack([grid.reshape(-1, 1) for grid in grids])
|
pts = np.hstack([grid.reshape(-1, 1) for grid in grids])
|
||||||
|
print(pts)
|
||||||
elif mode == 'lh' or mode == 'latin':
|
elif mode == 'lh' or mode == 'latin':
|
||||||
from scipy.stats import qmc
|
from scipy.stats import qmc
|
||||||
sampler = qmc.LatinHypercube(d=bounds.shape[0])
|
sampler = qmc.LatinHypercube(d=bounds.shape[0])
|
||||||
|
|||||||
Reference in New Issue
Block a user