Vectorial output
This commit is contained in:
15
pina/pinn.py
15
pina/pinn.py
@@ -65,8 +65,8 @@ class PINN(object):
|
||||
self.model = model
|
||||
self.model.to(dtype=self.dtype, device=self.device)
|
||||
|
||||
self.input_pts = {}
|
||||
self.truth_values = {}
|
||||
self.input_pts = {}
|
||||
|
||||
|
||||
self.trained_epoch = 0
|
||||
@@ -171,13 +171,15 @@ class PINN(object):
|
||||
except:
|
||||
pts = condition.input_points
|
||||
|
||||
print(location, pts)
|
||||
|
||||
self.input_pts[location] = pts
|
||||
print(pts.tensor.shape)
|
||||
self.input_pts[location].tensor.to(dtype=self.dtype, device=self.device)
|
||||
self.input_pts[location].tensor.requires_grad_(True)
|
||||
self.input_pts[location].tensor.retain_grad()
|
||||
|
||||
|
||||
|
||||
def plot_pts(self, locations='all'):
|
||||
import matplotlib
|
||||
matplotlib.use('GTK3Agg')
|
||||
@@ -209,8 +211,13 @@ class PINN(object):
|
||||
|
||||
predicted = self.model(pts)
|
||||
|
||||
residuals = condition.function(pts, predicted)
|
||||
losses.append(self._compute_norm(residuals))
|
||||
if isinstance(condition.function, list):
|
||||
for function in condition.function:
|
||||
residuals = function(pts, predicted)
|
||||
losses.append(self._compute_norm(residuals))
|
||||
else:
|
||||
residuals = condition.function(pts, predicted)
|
||||
losses.append(self._compute_norm(residuals))
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
sum(losses).backward()
|
||||
|
||||
Reference in New Issue
Block a user