Documentation for v0.1 version (#199)

* Adding Equations, solving typos
* improve _code.rst
* the team rst and restuctore index.rst
* fixing errors

---------

Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
This commit is contained in:
Dario Coscia
2023-11-08 14:39:00 +01:00
committed by Nicola Demo
parent 3f9305d475
commit 8b7b61b3bd
144 changed files with 2741 additions and 1766 deletions

View File

@@ -6,20 +6,53 @@ import copy
class MetricTracker(Callback):
"""
PINA implementation of a Lightining Callback to track relevant
metrics during training.
"""
def __init__(self):
"""
PINA Implementation of a Lightning Callback for Metric Tracking.
This class provides functionality to track relevant metrics during the training process.
:ivar _collection: A list to store collected metrics after each training epoch.
:param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer
:return: A dictionary containing aggregated metric values.
:rtype: dict
Example:
>>> tracker = MetricTracker()
>>> # ... Perform training ...
>>> metrics = tracker.metrics
"""
self._collection = []
def on_train_epoch_end(self, trainer, __):
self._collection.append(copy.deepcopy(trainer.logged_metrics)) # track them
"""
Collect and track metrics at the end of each training epoch.
:param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer
:param _: Placeholder argument.
:return: None
:rtype: None
"""
self._collection.append(copy.deepcopy(
trainer.logged_metrics)) # track them
@property
def metrics(self):
common_keys = set.intersection(*map(set, self._collection))
v = {k: torch.stack([dic[k] for dic in self._collection]) for k in common_keys}
return v
"""
Aggregate collected metrics during training.
:return: A dictionary containing aggregated metric values.
:rtype: dict
"""
common_keys = set.intersection(*map(set, self._collection))
v = {
k: torch.stack([dic[k] for dic in self._collection])
for k in common_keys
}
return v