* Fixing mean tracked loss
* Adding a PINA progress bar

committed by Nicola Demo
parent 0fa4e1e58a
commit cce9876751
@@ -15,7 +15,7 @@ The pipeline to solve differential equations with PINA follows just five steps:
 2. Generate data using built in `Geometries`_, or load high level simulation results as :doc:`LabelTensor <label_tensor>`
 3. Choose or build one or more `Models`_ to solve the problem
 4. Choose a solver across PINA available `Solvers`_, or build one using the :doc:`SolverInterface <solvers/solver_interface>`
-5. Train the model with the PINA :doc:`Trainer <solvers/solver_interface>`, enhance the train with `Callbacks_`
+5. Train the model with the PINA :doc:`Trainer <solvers/solver_interface>`, enhance the train with `Callbacks`_

 PINA Features
 --------------
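The five steps read as a recipe; for orientation, here is a minimal sketch of how they map to code, assuming the PINA API exercised by the new test file at the bottom of this commit (the `Poisson` problem class is defined there; names are illustrative, not canonical docs):

```python
from pina.model import FeedForward
from pina.solvers import PINN
from pina.trainer import Trainer

# 1.-2. formulate the problem and generate data on its domain
problem = Poisson()  # SpatialProblem subclass, as defined in the test file below
problem.discretise_domain(10, 'grid',
                          locations=['gamma1', 'gamma2', 'gamma3', 'gamma4'])

# 3. choose a model
model = FeedForward(len(problem.input_variables), len(problem.output_variables))

# 4. choose a solver
solver = PINN(problem=problem, model=model)

# 5. train, optionally enriching the run with callbacks
trainer = Trainer(solver=solver, max_epochs=5)
trainer.train()
```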
@@ -155,9 +155,9 @@ Callbacks
 .. toctree::
     :titlesonly:

-    Metric tracking <callbacks/processing_callbacks.rst>
-    Optimizer callbacks <callbacks/optimizer_callbacks.rst>
-    Adaptive Refinments <callbacks/adaptive_refinment_callbacks.rst>
+    Processing Callbacks <callbacks/processing_callbacks.rst>
+    Optimizer Callbacks <callbacks/optimizer_callbacks.rst>
+    Adaptive Refinment Callback <callbacks/adaptive_refinment_callbacks.rst>

 Metrics and Losses
 --------------------
@@ -3,5 +3,9 @@ Processing callbacks

 .. currentmodule:: pina.callbacks.processing_callbacks
 .. autoclass:: MetricTracker
+   :members:
+   :show-inheritance:
+
+.. autoclass:: PINAProgressBar
    :members:
    :show-inheritance:
@@ -1,5 +1,10 @@
-__all__ = ["SwitchOptimizer", "R3Refinement", "MetricTracker"]
+__all__ = [
+    "SwitchOptimizer",
+    "R3Refinement",
+    "MetricTracker",
+    "PINAProgressBar"
+]

 from .optimizer_callbacks import SwitchOptimizer
 from .adaptive_refinment_callbacks import R3Refinement
-from .processing_callbacks import MetricTracker
+from .processing_callbacks import MetricTracker, PINAProgressBar
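After this change `PINAProgressBar` is re-exported next to the existing callbacks, so the flat import path keeps working for the new class; a quick sanity check (a sketch, assuming the package layout shown in this diff):

```python
# all four callbacks are now importable from the subpackage root
from pina.callbacks import (MetricTracker, PINAProgressBar,
                            R3Refinement, SwitchOptimizer)
```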
@@ -1,11 +1,14 @@
 """PINA Callbacks Implementations"""

-from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.module import LightningModule
 from pytorch_lightning.trainer.trainer import Trainer
 import torch
 import copy

+from pytorch_lightning.callbacks import Callback, TQDMProgressBar
+from lightning.pytorch.callbacks.progress.progress_bar import get_standard_metrics
+from pytorch_lightning.utilities import rank_zero_warn  # used by PINAProgressBar.get_metrics
+from pina.utils import check_consistency

 class MetricTracker(Callback):

@@ -13,9 +16,11 @@ class MetricTracker(Callback):
     """
     PINA Implementation of a Lightning Callback for Metric Tracking.

-    This class provides functionality to track relevant metrics during the training process.
+    This class provides functionality to track relevant metrics during
+    the training process.

-    :ivar _collection: A list to store collected metrics after each training epoch.
+    :ivar _collection: A list to store collected metrics after each
+        training epoch.

     :param trainer: The trainer object managing the training process.
     :type trainer: pytorch_lightning.Trainer
@@ -28,20 +33,16 @@ class MetricTracker(Callback):
         >>> # ... Perform training ...
         >>> metrics = tracker.metrics
         """
+        super().__init__()
         self._collection = []

-    def on_train_epoch_start(self, trainer, pl_module):
+    def on_train_epoch_end(self, trainer, pl_module):
         """
-        Collect and track metrics at the start of each training epoch. At epoch
-        zero the metric is not saved. At epoch ``k`` the metric which is tracked
-        is the one of epoch ``k-1``.
+        Collect and track metrics at the end of each training epoch.

         :param trainer: The trainer object managing the training process.
         :type trainer: pytorch_lightning.Trainer
         :param pl_module: Placeholder argument.
-
-        :return: None
-        :rtype: None
         """
         super().on_train_epoch_end(trainer, pl_module)
         if trainer.current_epoch > 0:
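Moving the hook from `on_train_epoch_start` to `on_train_epoch_end` removes the one-epoch lag the old docstring had to apologize for: epoch ``k`` now stores the metrics of epoch ``k`` itself. Usage stays as in the docstring example; a sketch, reusing the `solver` built in the test file below:

```python
from pina.callbacks import MetricTracker
from pina.trainer import Trainer

tracker = MetricTracker()
trainer = Trainer(solver=solver, callbacks=[tracker],
                  accelerator='cpu', max_epochs=5)
trainer.train()
# one dictionary of tracked losses per epoch, no longer shifted by one
metrics = tracker.metrics
```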
@@ -49,23 +50,6 @@ class MetricTracker(Callback):
                 copy.deepcopy(trainer.logged_metrics)
             ) # track them

-    def on_train_end(self, trainer, pl_module):
-        """
-        Collect and track metrics at the end of training.
-
-        :param trainer: The trainer object managing the training process.
-        :type trainer: pytorch_lightning.Trainer
-        :param pl_module: Placeholder argument.
-
-        :return: None
-        :rtype: None
-        """
-        super().on_train_end(trainer, pl_module)
-        if trainer.current_epoch > 0:
-            self._collection.append(
-                copy.deepcopy(trainer.logged_metrics)
-            ) # track them
-
     @property
     def metrics(self):
         """
@@ -80,3 +65,91 @@ class MetricTracker(Callback):
             for k in common_keys
         }
         return v
+
+
+class PINAProgressBar(TQDMProgressBar):
+
+    BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
+    def __init__(self, metrics='mean', **kwargs):
+        """
+        PINA Implementation of a Lightning Callback for enriching the progress
+        bar.
+
+        This class provides functionality to display only relevant metrics
+        during the training process.
+
+        :param metrics: Logged metrics to display during the training. It should
+            be a subset of the conditions keys defined in
+            :obj:`pina.condition.Condition`.
+        :type metrics: str | list(str) | tuple(str)
+
+        :Keyword Arguments:
+            The additional keyword arguments specify the progress bar
+            and can be chosen from the `pytorch-lightning
+            TQDMProgressBar API <https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/callbacks/progress/tqdm_progress.html#TQDMProgressBar>`_
+
+        Example:
+            >>> pbar = PINAProgressBar(['mean'])
+            >>> # ... Perform training ...
+            >>> trainer = Trainer(solver, callbacks=[pbar])
+        """
+        super().__init__(**kwargs)
+        # check consistency
+        if not isinstance(metrics, (list, tuple)):
+            metrics = [metrics]
+        check_consistency(metrics, str)
+        self._sorted_metrics = metrics
+
+    def get_metrics(self, trainer, pl_module):
+        r"""Combines progress bar metrics collected from the trainer with
+        standard metrics from get_standard_metrics.
+        Implement this to override the items displayed in the progress bar.
+        The progress bar metrics are sorted according to ``metrics``.
+
+        Here is an example of how to override the defaults:
+
+        .. code-block:: python
+
+            def get_metrics(self, trainer, model):
+                # don't show the version number
+                items = super().get_metrics(trainer, model)
+                items.pop("v_num", None)
+                return items
+
+        :return: Dictionary with the items to be displayed in the progress bar.
+        :rtype: dict
+        """
+        standard_metrics = get_standard_metrics(trainer)
+        pbar_metrics = trainer.progress_bar_metrics
+        if pbar_metrics:
+            pbar_metrics = {
+                key: pbar_metrics[key] for key in self._sorted_metrics
+            }
+        duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
+        if duplicates:
+            rank_zero_warn(
+                f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"
+                f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value. "
+                " If this is undesired, change the name or override `get_metrics()` in the progress bar callback.",
+            )

+        return {**standard_metrics, **pbar_metrics}
+
+    def on_fit_start(self, trainer, pl_module):
+        """
+        Check that the metrics defined in the initialization are available,
+        i.e. are correctly logged.
+
+        :param trainer: The trainer object managing the training process.
+        :type trainer: pytorch_lightning.Trainer
+        :param pl_module: Placeholder argument.
+        """
+        # Check if all keys in sort_keys are present in the dictionary
+        for key in self._sorted_metrics:
+            if key not in trainer.solver.problem.conditions.keys() and key != 'mean':
+                raise KeyError(f"Key '{key}' is not present in the dictionary")
+        # add the loss suffix
+        self._sorted_metrics = [
+            metric + '_loss' for metric in self._sorted_metrics]
+        return super().on_fit_start(trainer, pl_module)
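Note how `on_fit_start` validates the requested keys against the problem's condition names (plus the special `'mean'`) and then appends the `_loss` suffix, so `'D'` selects the logged `'D_loss'` and `'mean'` the `'mean_loss'` fixed by the first part of this commit. A usage sketch mirroring the new test:

```python
from pina.callbacks.processing_callbacks import PINAProgressBar
from pina.trainer import Trainer

# display only the mean loss and the residual loss of condition 'D';
# any other key raises a KeyError at fit start
pbar = PINAProgressBar(metrics=['mean', 'D'])
trainer = Trainer(solver=solver, callbacks=[pbar],
                  accelerator='cpu', max_epochs=5)
trainer.train()
```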
@@ -126,8 +126,9 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
         # clamp unknown parameters in InverseProblem (if needed)
         self._clamp_params()

-        # total loss (must be a torch.Tensor)
+        # total loss (must be a torch.Tensor), and logs
         total_loss = sum(condition_losses)
+        self.save_logs_and_release()
         return total_loss.as_subclass(torch.Tensor)

     def loss_data(self, input_tensor, output_tensor):
@@ -205,7 +206,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
             )
             self.__logged_res_losses.append(loss_value)

-    def on_train_epoch_end(self):
+    def save_logs_and_release(self):
         """
         At the end of each epoch we free the stored losses. This function
         should not be override if not intentionally.
@@ -218,7 +219,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
             )
         # free the logged losses
         self.__logged_res_losses = []
-        return super().on_train_epoch_end()

     def _clamp_inverse_problem_params(self):
         """
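The three `PINNInterface` hunks are one refactor: the logging-and-release step no longer rides on the Lightning `on_train_epoch_end` hook (whose name it shadowed) but is called explicitly right after the total loss is assembled, which is what fixes the mean tracked loss. Schematically, the resulting flow inside the training step is:

```python
# schematic flow, condensed from the hunks above (not the full solver code)
total_loss = sum(condition_losses)  # one tracked loss per condition
self.save_logs_and_release()        # log aggregates (incl. the mean), free buffers
return total_loss.as_subclass(torch.Tensor)
```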
tests/test_callbacks/test_progress_bar.py (new file, 78 lines)
@@ -0,0 +1,78 @@
+import torch
+import pytest
+
+from pina.problem import SpatialProblem
+from pina.operators import laplacian
+from pina.geometry import CartesianDomain
+from pina import Condition, LabelTensor
+from pina.solvers import PINN
+from pina.trainer import Trainer
+from pina.model import FeedForward
+from pina.equation.equation import Equation
+from pina.equation.equation_factory import FixedValue
+from pina.callbacks.processing_callbacks import PINAProgressBar
+
+
+def laplace_equation(input_, output_):
+    force_term = (torch.sin(input_.extract(['x']) * torch.pi) *
+                  torch.sin(input_.extract(['y']) * torch.pi))
+    delta_u = laplacian(output_.extract(['u']), input_)
+    return delta_u - force_term
+
+
+my_laplace = Equation(laplace_equation)
+in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y'])
+out_ = LabelTensor(torch.tensor([[0.]]), ['u'])
+
+
+class Poisson(SpatialProblem):
+    output_variables = ['u']
+    spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})
+
+    conditions = {
+        'gamma1': Condition(
+            location=CartesianDomain({'x': [0, 1], 'y': 1}),
+            equation=FixedValue(0.0)),
+        'gamma2': Condition(
+            location=CartesianDomain({'x': [0, 1], 'y': 0}),
+            equation=FixedValue(0.0)),
+        'gamma3': Condition(
+            location=CartesianDomain({'x': 1, 'y': [0, 1]}),
+            equation=FixedValue(0.0)),
+        'gamma4': Condition(
+            location=CartesianDomain({'x': 0, 'y': [0, 1]}),
+            equation=FixedValue(0.0)),
+        'D': Condition(
+            input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']),
+            equation=my_laplace),
+        'data': Condition(
+            input_points=in_,
+            output_points=out_)
+    }
+
+
+# make the problem
+poisson_problem = Poisson()
+boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
+n = 10
+poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
+model = FeedForward(len(poisson_problem.input_variables),
+                    len(poisson_problem.output_variables))
+
+# make the solver
+solver = PINN(problem=poisson_problem, model=model)
+
+
+def test_progress_bar_constructor():
+    PINAProgressBar(['mean_loss'])
+
+def test_progress_bar_routine():
+    # make the trainer
+    trainer = Trainer(solver=solver,
+                      callbacks=[
+                          PINAProgressBar(['mean', 'D'])
+                      ],
+                      accelerator='cpu',
+                      max_epochs=5)
+    trainer.train()
+    # TODO there should be a check that the correct metrics are displayed
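The new test module is self-contained and can be run on its own; a standard invocation (an assumption about the local setup, nothing commit-specific):

```python
# equivalent to running `pytest tests/test_callbacks/test_progress_bar.py` from the repo root
import pytest

pytest.main(["tests/test_callbacks/test_progress_bar.py"])
```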