Vectorial output

This commit is contained in:
Nicola Demo
2022-03-07 10:09:40 +01:00
parent 1812ddb8d9
commit 8a1f07c8ae
6 changed files with 71 additions and 7 deletions

View File

@@ -65,8 +65,8 @@ class PINN(object):
self.model = model
self.model.to(dtype=self.dtype, device=self.device)
self.input_pts = {}
self.truth_values = {}
self.input_pts = {}
self.trained_epoch = 0
@@ -171,13 +171,15 @@ class PINN(object):
except:
pts = condition.input_points
print(location, pts)
self.input_pts[location] = pts
print(pts.tensor.shape)
self.input_pts[location].tensor.to(dtype=self.dtype, device=self.device)
self.input_pts[location].tensor.requires_grad_(True)
self.input_pts[location].tensor.retain_grad()
def plot_pts(self, locations='all'):
import matplotlib
matplotlib.use('GTK3Agg')
@@ -209,8 +211,13 @@ class PINN(object):
predicted = self.model(pts)
residuals = condition.function(pts, predicted)
losses.append(self._compute_norm(residuals))
if isinstance(condition.function, list):
for function in condition.function:
residuals = function(pts, predicted)
losses.append(self._compute_norm(residuals))
else:
residuals = condition.function(pts, predicted)
losses.append(self._compute_norm(residuals))
self.optimizer.zero_grad()
sum(losses).backward()