diff --git a/pina/optim/__init__.py b/pina/optim/__init__.py index 38301bb..8266c8c 100644 --- a/pina/optim/__init__.py +++ b/pina/optim/__init__.py @@ -1,4 +1,4 @@ -"""Module for Optimizer class.""" +"""Module for the Optimizers and Schedulers.""" __all__ = [ "Optimizer", diff --git a/pina/optim/optimizer_interface.py b/pina/optim/optimizer_interface.py index d61ef4b..5f2fbe6 100644 --- a/pina/optim/optimizer_interface.py +++ b/pina/optim/optimizer_interface.py @@ -1,24 +1,23 @@ -"""Module for PINA Optimizer.""" +"""Module for the PINA Optimizer.""" from abc import ABCMeta, abstractmethod class Optimizer(metaclass=ABCMeta): """ - TODO - :param metaclass: _description_, defaults to ABCMeta - :type metaclass: _type_, optional + Abstract base class for defining an optimizer. All specific optimizers + should inherit form this class and implement the required methods. """ @property @abstractmethod def instance(self): """ - TODO + Abstract property to retrieve the optimizer instance. """ @abstractmethod def hook(self): """ - TODO + Abstract method to define the hook logic for the optimizer. """ diff --git a/pina/optim/scheduler_interface.py b/pina/optim/scheduler_interface.py index ddb515c..5ae5d8b 100644 --- a/pina/optim/scheduler_interface.py +++ b/pina/optim/scheduler_interface.py @@ -1,25 +1,23 @@ -"""Module for PINA Scheduler.""" +"""Module for the PINA Scheduler.""" from abc import ABCMeta, abstractmethod class Scheduler(metaclass=ABCMeta): """ - TODO - - :param metaclass: _description_, defaults to ABCMeta - :type metaclass: _type_, optional + Abstract base class for defining a scheduler. All specific schedulers should + inherit form this class and implement the required methods. """ @property @abstractmethod def instance(self): """ - TODO + Abstract property to retrieve the scheduler instance. """ @abstractmethod def hook(self): """ - TODO + Abstract method to define the hook logic for the scheduler. """ diff --git a/pina/optim/torch_optimizer.py b/pina/optim/torch_optimizer.py index 74b5337..fc34296 100644 --- a/pina/optim/torch_optimizer.py +++ b/pina/optim/torch_optimizer.py @@ -1,4 +1,4 @@ -"""Module for PINA Torch Optimizer""" +"""Module for the PINA Torch Optimizer""" import torch @@ -8,18 +8,17 @@ from .optimizer_interface import Optimizer class TorchOptimizer(Optimizer): """ - TODO - - :param Optimizer: _description_ - :type Optimizer: _type_ + A wrapper class for using PyTorch optimizers. """ def __init__(self, optimizer_class, **kwargs): """ - TODO + Initialization of the :class:`TorchOptimizer` class. - :param optimizer_class: _description_ - :type optimizer_class: _type_ + :param torch.optim.Optimizer optimizer_class: The PyTorch optimizer + class. + :param dict kwargs: Additional parameters passed to `optimizer_class`, + see more: _. """ check_consistency(optimizer_class, torch.optim.Optimizer, subclass=True) @@ -29,10 +28,9 @@ class TorchOptimizer(Optimizer): def hook(self, parameters): """ - TODO + Initialize the optimizer instance with the given parameters. - :param parameters: _description_ - :type parameters: _type_ + :param dict parameters: The parameters of the model to be optimized. """ self._optimizer_instance = self.optimizer_class( parameters, **self.kwargs @@ -41,6 +39,9 @@ class TorchOptimizer(Optimizer): @property def instance(self): """ - Optimizer instance. + Get the optimizer instance. + + :return: The optimizer instance. + :rtype: torch.optim.Optimizer """ return self._optimizer_instance diff --git a/pina/optim/torch_scheduler.py b/pina/optim/torch_scheduler.py index 41c589c..c436e22 100644 --- a/pina/optim/torch_scheduler.py +++ b/pina/optim/torch_scheduler.py @@ -1,4 +1,4 @@ -"""Module for PINA Torch Optimizer""" +"""Module for the PINA Torch Optimizer""" try: from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0 @@ -14,18 +14,17 @@ from .scheduler_interface import Scheduler class TorchScheduler(Scheduler): """ - TODO - - :param Scheduler: _description_ - :type Scheduler: _type_ + A wrapper class for using PyTorch schedulers. """ def __init__(self, scheduler_class, **kwargs): """ - TODO + Initialization of the :class:`TorchScheduler` class. - :param scheduler_class: _description_ - :type scheduler_class: _type_ + :param torch.optim.LRScheduler scheduler_class: The PyTorch scheduler + class. + :param dict kwargs: Additional parameters passed to `scheduler_class`, + see more: _. """ check_consistency(scheduler_class, LRScheduler, subclass=True) @@ -35,10 +34,9 @@ class TorchScheduler(Scheduler): def hook(self, optimizer): """ - TODO + Initialize the scheduler instance with the given parameters. - :param optimizer: _description_ - :type optimizer: _type_ + :param dict parameters: The parameters of the optimizer. """ check_consistency(optimizer, Optimizer) self._scheduler_instance = self.scheduler_class( @@ -48,6 +46,9 @@ class TorchScheduler(Scheduler): @property def instance(self): """ - Scheduler instance. + Get the scheduler instance. + + :return: The scheduelr instance. + :rtype: torch.optim.LRScheduler """ return self._scheduler_instance