PINA/pina/optim/torch_optimizer.py
Dario Coscia 9cae9a438f Update solvers (#434)
* Enable DDP training with batch_size=None and add validity check for split sizes
* Refactoring SolverInterfaces (#435)
* Solver update + weighting
* Updating PINN for 0.2
* Modify GAROM + tests
* Adding more versatile loggers
* Disable compilation when running on Windows
* Fix tests

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: FilippoOlivo <filippo@filippoolivo.com>
2025-03-19 17:46:35 +01:00


""" Module for PINA Torch Optimizer """
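import torch

from ..utils import check_consistency
from .optimizer_interface import Optimizer


class TorchOptimizer(Optimizer):
    """
    Wrapper around a ``torch.optim.Optimizer`` subclass whose instance
    is created lazily by :meth:`hook`.
    """

    def __init__(self, optimizer_class, **kwargs):
        # Ensure a torch optimizer *class* (a subclass, not an instance) was passed.
        check_consistency(optimizer_class, torch.optim.Optimizer, subclass=True)
        self.optimizer_class = optimizer_class
        # Keyword arguments (e.g. lr) forwarded to the optimizer constructor.
        self.kwargs = kwargs
        # No optimizer exists until hook() is called.
        self._optimizer_instance = None

    def hook(self, parameters):
        # Instantiate the optimizer on the given parameters with the stored kwargs.
        self._optimizer_instance = self.optimizer_class(
            parameters, **self.kwargs
        )

    @property
    def instance(self):
        """
        Optimizer instance, or ``None`` until :meth:`hook` is called.
        """
        return self._optimizer_instance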
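The class stores only the optimizer class and its keyword arguments at construction time; the optimizer itself is built in `hook`, once the parameters to optimize are known. The sketch below reproduces that deferred-construction pattern in plain Python so it runs without torch or PINA installed. `BaseOptimizer`, `ToySGD` and `DeferredOptimizer` are hypothetical names standing in for `torch.optim.Optimizer`, a concrete torch optimizer and `TorchOptimizer`; the `issubclass` test is an assumed stand-in for `check_consistency(..., subclass=True)`, and raising `TypeError` is this sketch's choice, not necessarily what PINA does.

```python
class BaseOptimizer:
    """Hypothetical stand-in for torch.optim.Optimizer."""

    def __init__(self, params, **defaults):
        self.params = list(params)  # materialise the parameter iterable
        self.defaults = defaults    # hyperparameters such as lr


class ToySGD(BaseOptimizer):
    """Hypothetical concrete optimizer with a single `lr` hyperparameter."""

    def __init__(self, params, lr=0.1):
        super().__init__(params, lr=lr)
        self.lr = lr


class DeferredOptimizer:
    """Hypothetical mirror of TorchOptimizer: keep the class, build it later."""

    def __init__(self, optimizer_class, **kwargs):
        # Assumed stand-in for check_consistency(..., subclass=True):
        # accept only classes derived from BaseOptimizer, not instances.
        if not (isinstance(optimizer_class, type)
                and issubclass(optimizer_class, BaseOptimizer)):
            raise TypeError(f"{optimizer_class!r} is not a BaseOptimizer subclass")
        self.optimizer_class = optimizer_class
        self.kwargs = kwargs             # forwarded to the constructor in hook()
        self._optimizer_instance = None  # nothing is built yet

    def hook(self, parameters):
        # Build the optimizer now that the parameters are known.
        self._optimizer_instance = self.optimizer_class(parameters, **self.kwargs)

    @property
    def instance(self):
        return self._optimizer_instance


wrapper = DeferredOptimizer(ToySGD, lr=0.01)
print(wrapper.instance)  # None: hook() has not run yet
wrapper.hook([1.0, 2.0])
print(type(wrapper.instance).__name__, wrapper.instance.lr)  # ToySGD 0.01
```

Because `hook` forwards `parameters` as the first positional argument of the optimizer constructor, with the real `TorchOptimizer` it should receive whatever a torch optimizer accepts as `params`, such as `model.parameters()`.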