Documentation for v0.1 version (#199)
* Adding Equations, solving typos * improve _code.rst * the team rst and restuctore index.rst * fixing errors --------- Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
This commit is contained in:
committed by
Nicola Demo
parent
3f9305d475
commit
8b7b61b3bd
@@ -6,20 +6,53 @@ import copy
|
||||
|
||||
|
||||
class MetricTracker(Callback):
|
||||
"""
|
||||
PINA implementation of a Lightining Callback to track relevant
|
||||
metrics during training.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
PINA Implementation of a Lightning Callback for Metric Tracking.
|
||||
|
||||
This class provides functionality to track relevant metrics during the training process.
|
||||
|
||||
:ivar _collection: A list to store collected metrics after each training epoch.
|
||||
|
||||
:param trainer: The trainer object managing the training process.
|
||||
:type trainer: pytorch_lightning.Trainer
|
||||
|
||||
:return: A dictionary containing aggregated metric values.
|
||||
:rtype: dict
|
||||
|
||||
Example:
|
||||
>>> tracker = MetricTracker()
|
||||
>>> # ... Perform training ...
|
||||
>>> metrics = tracker.metrics
|
||||
"""
|
||||
self._collection = []
|
||||
|
||||
def on_train_epoch_end(self, trainer, __):
|
||||
self._collection.append(copy.deepcopy(trainer.logged_metrics)) # track them
|
||||
"""
|
||||
Collect and track metrics at the end of each training epoch.
|
||||
|
||||
:param trainer: The trainer object managing the training process.
|
||||
:type trainer: pytorch_lightning.Trainer
|
||||
:param _: Placeholder argument.
|
||||
|
||||
:return: None
|
||||
:rtype: None
|
||||
"""
|
||||
self._collection.append(copy.deepcopy(
|
||||
trainer.logged_metrics)) # track them
|
||||
|
||||
@property
|
||||
def metrics(self):
|
||||
common_keys = set.intersection(*map(set, self._collection))
|
||||
v = {k: torch.stack([dic[k] for dic in self._collection]) for k in common_keys}
|
||||
return v
|
||||
"""
|
||||
Aggregate collected metrics during training.
|
||||
|
||||
|
||||
:return: A dictionary containing aggregated metric values.
|
||||
:rtype: dict
|
||||
"""
|
||||
common_keys = set.intersection(*map(set, self._collection))
|
||||
v = {
|
||||
k: torch.stack([dic[k] for dic in self._collection])
|
||||
for k in common_keys
|
||||
}
|
||||
return v
|
||||
|
||||
Reference in New Issue
Block a user