Update RBAPINN

This commit is contained in:
Dario Coscia
2024-10-01 15:40:32 +02:00
committed by Nicola Demo
parent 5d2ca62e65
commit 801b6b8d34

View File

@@ -117,7 +117,7 @@ class RBAPINN(PINN):
# initialize weights
self.weights = {}
for condition_name in problem.conditions:
self.weights[condition_name] = 0
self.weights[condition_name] = 1
# define vectorial loss
self._vectorial_loss = deepcopy(loss)
@@ -158,7 +158,7 @@ class RBAPINN(PINN):
residual = self.compute_residual(samples=samples, equation=equation)
cond = self.current_condition_name
r_norm = self.eta * torch.abs(residual) / torch.max(torch.abs(residual))
r_norm = self.eta * torch.abs(residual) / (torch.max(torch.abs(residual))+1e-12)
self.weights[cond] = (self.gamma * self.weights[cond] + r_norm).detach()
loss_value = self._vectorial_loss(