* Fixing mean tracked loss

* Adding a PINA progress bar
This commit is contained in:
dario-coscia
2024-08-06 12:31:01 +02:00
committed by Nicola Demo
parent 0fa4e1e58a
commit cce9876751
6 changed files with 194 additions and 36 deletions

View File

@@ -126,8 +126,9 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
# clamp unknown parameters in InverseProblem (if needed)
self._clamp_params()
# total loss (must be a torch.Tensor)
# total loss (must be a torch.Tensor), and logs
total_loss = sum(condition_losses)
self.save_logs_and_release()
return total_loss.as_subclass(torch.Tensor)
def loss_data(self, input_tensor, output_tensor):
@@ -205,7 +206,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
)
self.__logged_res_losses.append(loss_value)
def on_train_epoch_end(self):
def save_logs_and_release(self):
"""
At the end of each epoch we free the stored losses. This function
should not be override if not intentionally.
@@ -218,7 +219,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
)
# free the logged losses
self.__logged_res_losses = []
return super().on_train_epoch_end()
def _clamp_inverse_problem_params(self):
"""