* Fixing mean tracked loss

* Adding a PINA progress bar
This commit is contained in:
dario-coscia
2024-08-06 12:31:01 +02:00
committed by Nicola Demo
parent 0fa4e1e58a
commit cce9876751
6 changed files with 194 additions and 36 deletions

View File

@@ -15,7 +15,7 @@ The pipeline to solve differential equations with PINA follows just five steps:
2. Generate data using built in `Geometries`_, or load high level simulation results as :doc:`LabelTensor <label_tensor>` 2. Generate data using built in `Geometries`_, or load high level simulation results as :doc:`LabelTensor <label_tensor>`
3. Choose or build one or more `Models`_ to solve the problem 3. Choose or build one or more `Models`_ to solve the problem
4. Choose a solver across PINA available `Solvers`_, or build one using the :doc:`SolverInterface <solvers/solver_interface>` 4. Choose a solver across PINA available `Solvers`_, or build one using the :doc:`SolverInterface <solvers/solver_interface>`
5. Train the model with the PINA :doc:`Trainer <solvers/solver_interface>`, enhance the train with `Callbacks_` 5. Train the model with the PINA :doc:`Trainer <solvers/solver_interface>`, enhance the train with `Callbacks`_
PINA Features PINA Features
-------------- --------------
@@ -155,9 +155,9 @@ Callbacks
.. toctree:: .. toctree::
:titlesonly: :titlesonly:
Metric tracking <callbacks/processing_callbacks.rst> Processing Callbacks <callbacks/processing_callbacks.rst>
Optimizer callbacks <callbacks/optimizer_callbacks.rst> Optimizer Callbacks <callbacks/optimizer_callbacks.rst>
Adaptive Refinments <callbacks/adaptive_refinment_callbacks.rst> Adaptive Refinment Callback <callbacks/adaptive_refinment_callbacks.rst>
Metrics and Losses Metrics and Losses
-------------------- --------------------

View File

@@ -5,3 +5,7 @@ Processing callbacks
.. autoclass:: MetricTracker .. autoclass:: MetricTracker
:members: :members:
:show-inheritance: :show-inheritance:
.. autoclass:: PINAProgressBar
:members:
:show-inheritance:

View File

@@ -1,5 +1,10 @@
__all__ = ["SwitchOptimizer", "R3Refinement", "MetricTracker"] __all__ = [
"SwitchOptimizer",
"R3Refinement",
"MetricTracker",
"PINAProgressBar"
]
from .optimizer_callbacks import SwitchOptimizer from .optimizer_callbacks import SwitchOptimizer
from .adaptive_refinment_callbacks import R3Refinement from .adaptive_refinment_callbacks import R3Refinement
from .processing_callbacks import MetricTracker from .processing_callbacks import MetricTracker, PINAProgressBar

View File

@@ -1,11 +1,13 @@
"""PINA Callbacks Implementations""" """PINA Callbacks Implementations"""
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.module import LightningModule from pytorch_lightning.core.module import LightningModule
from pytorch_lightning.trainer.trainer import Trainer from pytorch_lightning.trainer.trainer import Trainer
import torch import torch
import copy import copy
from pytorch_lightning.callbacks import Callback, TQDMProgressBar
from lightning.pytorch.callbacks.progress.progress_bar import get_standard_metrics
from pina.utils import check_consistency
class MetricTracker(Callback): class MetricTracker(Callback):
@@ -13,9 +15,11 @@ class MetricTracker(Callback):
""" """
PINA Implementation of a Lightning Callback for Metric Tracking. PINA Implementation of a Lightning Callback for Metric Tracking.
This class provides functionality to track relevant metrics during the training process. This class provides functionality to track relevant metrics during
the training process.
:ivar _collection: A list to store collected metrics after each training epoch. :ivar _collection: A list to store collected metrics after each
training epoch.
:param trainer: The trainer object managing the training process. :param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer :type trainer: pytorch_lightning.Trainer
@@ -28,20 +32,16 @@ class MetricTracker(Callback):
>>> # ... Perform training ... >>> # ... Perform training ...
>>> metrics = tracker.metrics >>> metrics = tracker.metrics
""" """
super().__init__()
self._collection = [] self._collection = []
def on_train_epoch_start(self, trainer, pl_module): def on_train_epoch_end(self, trainer, pl_module):
""" """
Collect and track metrics at the start of each training epoch. At epoch Collect and track metrics at the end of each training epoch.
zero the metric is not saved. At epoch ``k`` the metric which is tracked
is the one of epoch ``k-1``.
:param trainer: The trainer object managing the training process. :param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer :type trainer: pytorch_lightning.Trainer
:param pl_module: Placeholder argument. :param pl_module: Placeholder argument.
:return: None
:rtype: None
""" """
super().on_train_epoch_end(trainer, pl_module) super().on_train_epoch_end(trainer, pl_module)
if trainer.current_epoch > 0: if trainer.current_epoch > 0:
@@ -49,23 +49,6 @@ class MetricTracker(Callback):
copy.deepcopy(trainer.logged_metrics) copy.deepcopy(trainer.logged_metrics)
) # track them ) # track them
def on_train_end(self, trainer, pl_module):
"""
Collect and track metrics at the end of training.
:param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer
:param pl_module: Placeholder argument.
:return: None
:rtype: None
"""
super().on_train_end(trainer, pl_module)
if trainer.current_epoch > 0:
self._collection.append(
copy.deepcopy(trainer.logged_metrics)
) # track them
@property @property
def metrics(self): def metrics(self):
""" """
@@ -80,3 +63,91 @@ class MetricTracker(Callback):
for k in common_keys for k in common_keys
} }
return v return v
class PINAProgressBar(TQDMProgressBar):
BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
def __init__(self, metrics='mean', **kwargs):
"""
PINA Implementation of a Lightning Callback for enriching the progress
bar.
This class provides functionality to display only relevant metrics
during the training process.
:param metrics: Logged metrics to display during the training. It should
be a subset of the conditions keys defined in
:obj:`pina.condition.Condition`.
:type metrics: str | list(str) | tuple(str)
:Keyword Arguments:
The additional keyword arguments specify the progress bar
and can be choosen from the `pytorch-lightning
TQDMProgressBar API <https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/callbacks/progress/tqdm_progress.html#TQDMProgressBar>`_
Example:
>>> pbar = PINAProgressBar(['mean'])
>>> # ... Perform training ...
>>> trainer = Trainer(solver, callbacks=[pbar])
"""
super().__init__(**kwargs)
# check consistency
if not isinstance(metrics, (list, tuple)):
metrics = [metrics]
check_consistency(metrics, str)
self._sorted_metrics = metrics
def get_metrics(self, trainer, pl_module):
r"""Combines progress bar metrics collected from the trainer with
standard metrics from get_standard_metrics.
Implement this to override the items displayed in the progress bar.
The progress bar metrics are sorted according to ``metrics``.
Here is an example of how to override the defaults:
.. code-block:: python
def get_metrics(self, trainer, model):
# don't show the version number
items = super().get_metrics(trainer, model)
items.pop("v_num", None)
return items
:return: Dictionary with the items to be displayed in the progress bar.
:rtype: tuple(dict)
"""
standard_metrics = get_standard_metrics(trainer)
pbar_metrics = trainer.progress_bar_metrics
if pbar_metrics:
pbar_metrics = {
key: pbar_metrics[key] for key in self._sorted_metrics
}
duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
if duplicates:
rank_zero_warn(
f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"
f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value. "
" If this is undesired, change the name or override `get_metrics()` in the progress bar callback.",
)
return {**standard_metrics, **pbar_metrics}
def on_fit_start(self, trainer, pl_module):
"""
Check that the metrics defined in the initialization are available,
i.e. are correctly logged.
:param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer
:param pl_module: Placeholder argument.
"""
# Check if all keys in sort_keys are present in the dictionary
for key in self._sorted_metrics:
if key not in trainer.solver.problem.conditions.keys() and key != 'mean':
raise KeyError(f"Key '{key}' is not present in the dictionary")
# add the loss pedix
self._sorted_metrics = [
metric + '_loss' for metric in self._sorted_metrics]
return super().on_fit_start(trainer, pl_module)

View File

@@ -126,8 +126,9 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
# clamp unknown parameters in InverseProblem (if needed) # clamp unknown parameters in InverseProblem (if needed)
self._clamp_params() self._clamp_params()
# total loss (must be a torch.Tensor) # total loss (must be a torch.Tensor), and logs
total_loss = sum(condition_losses) total_loss = sum(condition_losses)
self.save_logs_and_release()
return total_loss.as_subclass(torch.Tensor) return total_loss.as_subclass(torch.Tensor)
def loss_data(self, input_tensor, output_tensor): def loss_data(self, input_tensor, output_tensor):
@@ -205,7 +206,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
) )
self.__logged_res_losses.append(loss_value) self.__logged_res_losses.append(loss_value)
def on_train_epoch_end(self): def save_logs_and_release(self):
""" """
At the end of each epoch we free the stored losses. This function At the end of each epoch we free the stored losses. This function
should not be override if not intentionally. should not be override if not intentionally.
@@ -218,7 +219,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
) )
# free the logged losses # free the logged losses
self.__logged_res_losses = [] self.__logged_res_losses = []
return super().on_train_epoch_end()
def _clamp_inverse_problem_params(self): def _clamp_inverse_problem_params(self):
""" """

View File

@@ -0,0 +1,78 @@
import torch
import pytest
from pina.problem import SpatialProblem
from pina.operators import laplacian
from pina.geometry import CartesianDomain
from pina import Condition, LabelTensor
from pina.solvers import PINN
from pina.trainer import Trainer
from pina.model import FeedForward
from pina.equation.equation import Equation
from pina.equation.equation_factory import FixedValue
from pina.callbacks.processing_callbacks import PINAProgressBar
def laplace_equation(input_, output_):
force_term = (torch.sin(input_.extract(['x']) * torch.pi) *
torch.sin(input_.extract(['y']) * torch.pi))
delta_u = laplacian(output_.extract(['u']), input_)
return delta_u - force_term
my_laplace = Equation(laplace_equation)
in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y'])
out_ = LabelTensor(torch.tensor([[0.]]), ['u'])
class Poisson(SpatialProblem):
output_variables = ['u']
spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})
conditions = {
'gamma1': Condition(
location=CartesianDomain({'x': [0, 1], 'y': 1}),
equation=FixedValue(0.0)),
'gamma2': Condition(
location=CartesianDomain({'x': [0, 1], 'y': 0}),
equation=FixedValue(0.0)),
'gamma3': Condition(
location=CartesianDomain({'x': 1, 'y': [0, 1]}),
equation=FixedValue(0.0)),
'gamma4': Condition(
location=CartesianDomain({'x': 0, 'y': [0, 1]}),
equation=FixedValue(0.0)),
'D': Condition(
input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']),
equation=my_laplace),
'data': Condition(
input_points=in_,
output_points=out_)
}
# make the problem
poisson_problem = Poisson()
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
n = 10
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
model = FeedForward(len(poisson_problem.input_variables),
len(poisson_problem.output_variables))
# make the solver
solver = PINN(problem=poisson_problem, model=model)
def test_progress_bar_constructor():
PINAProgressBar(['mean_loss'])
def test_progress_bar_routine():
# make the trainer
trainer = Trainer(solver=solver,
callbacks=[
PINAProgressBar(['mean', 'D'])
],
accelerator='cpu',
max_epochs=5)
trainer.train()
# TODO there should be a check that the correct metrics are displayed