Update RBAPINN
This commit is contained in:
committed by
Nicola Demo
parent
5d2ca62e65
commit
801b6b8d34
@@ -117,7 +117,7 @@ class RBAPINN(PINN):
|
||||
# initialize weights
|
||||
self.weights = {}
|
||||
for condition_name in problem.conditions:
|
||||
self.weights[condition_name] = 0
|
||||
self.weights[condition_name] = 1
|
||||
|
||||
# define vectorial loss
|
||||
self._vectorial_loss = deepcopy(loss)
|
||||
@@ -158,7 +158,7 @@ class RBAPINN(PINN):
|
||||
residual = self.compute_residual(samples=samples, equation=equation)
|
||||
cond = self.current_condition_name
|
||||
|
||||
r_norm = self.eta * torch.abs(residual) / torch.max(torch.abs(residual))
|
||||
r_norm = self.eta * torch.abs(residual) / (torch.max(torch.abs(residual))+1e-12)
|
||||
self.weights[cond] = (self.gamma * self.weights[cond] + r_norm).detach()
|
||||
|
||||
loss_value = self._vectorial_loss(
|
||||
|
||||
Reference in New Issue
Block a user