version 0.0.1
This commit is contained in:
31
pina/pinn.py
31
pina/pinn.py
@@ -163,17 +163,16 @@ class PINN(object):
|
||||
if locations == 'all':
|
||||
locations = [condition for condition in self.problem.conditions]
|
||||
|
||||
|
||||
for location in locations:
|
||||
manifold, func = self.problem.conditions[location].values()
|
||||
if torch.is_tensor(manifold):
|
||||
pts = manifold
|
||||
else:
|
||||
pts = manifold.discretize(n, mode)
|
||||
condition = self.problem.conditions[location]
|
||||
|
||||
pts = torch.from_numpy(pts)
|
||||
try:
|
||||
pts = condition.location.sample(n, mode)
|
||||
except:
|
||||
pts = condition.input_points
|
||||
|
||||
self.input_pts[location] = LabelTensor(pts, self.problem.input_variables)
|
||||
self.input_pts[location] = pts
|
||||
print(pts.tensor.shape)
|
||||
self.input_pts[location].tensor.to(dtype=self.dtype, device=self.device)
|
||||
self.input_pts[location].tensor.requires_grad_(True)
|
||||
self.input_pts[location].tensor.retain_grad()
|
||||
@@ -203,17 +202,15 @@ class PINN(object):
|
||||
while True:
|
||||
|
||||
losses = []
|
||||
|
||||
for condition_name in self.problem.conditions:
|
||||
condition = self.problem.conditions[condition_name]
|
||||
pts = self.input_pts[condition_name]
|
||||
predicted = self.model(pts.tensor)
|
||||
if isinstance(self.problem.conditions[condition_name]['func'], list):
|
||||
for func in self.problem.conditions[condition_name]['func']:
|
||||
residuals = func(pts, predicted)
|
||||
losses.append(self._compute_norm(residuals))
|
||||
else:
|
||||
residuals = self.problem.conditions[condition_name]['func'](pts, predicted)
|
||||
losses.append(self._compute_norm(residuals))
|
||||
#print(condition_name, losses[-1])
|
||||
|
||||
predicted = self.model(pts)
|
||||
|
||||
residuals = condition.function(pts, predicted)
|
||||
losses.append(self._compute_norm(residuals))
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
sum(losses).backward()
|
||||
|
||||
Reference in New Issue
Block a user