PINA/pina/trainer.py
Dario Coscia e0429bb445 PINN variants addition and Solvers Update (#263)
* gpinn/basepinn new classes, pinn restructure
* codacy fix gpinn/basepinn/pinn
* inverse problem fix
* Causal PINN (#267)
* fix GPU training in inverse problem (#283)
* Create a `compute_residual` attribute for `PINNInterface`
* Modify dataloading in solvers (#286)
* Modify PINNInterface by removing _loss_phys, _loss_data
* Adding in PINNInterface a variable to track the current condition during training
* Modify GPINN,PINN,CausalPINN to match changes in PINNInterface
* Competitive Pinn Addition (#288)
* fixing after rebase/ fix loss
* fixing final issues

---------

Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>

* Modify min max formulation to max min for paper consistency
* Adding SAPINN solver (#291)
* rom solver
* fix import

---------

Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
Co-authored-by: Anna Ivagnes <75523024+annaivagnes@users.noreply.github.com>
Co-authored-by: valc89 <103250118+valc89@users.noreply.github.com>
Co-authored-by: Monthly Tag bot <mtbot@noreply.github.com>
Co-authored-by: Nicola Demo <demo.nicola@gmail.com>
2024-05-10 14:07:01 +02:00


""" Trainer module. """
import torch
import pytorch_lightning
from .utils import check_consistency
from .dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
from .solvers.solver import SolverInterface


class Trainer(pytorch_lightning.Trainer):
    def __init__(self, solver, batch_size=None, **kwargs):
        """
        PINA Trainer class for customizing every aspect of training via flags.

        :param solver: A pina:class:`SolverInterface` solver for the differential problem.
        :type solver: SolverInterface
        :param batch_size: How many samples per batch to load. If ``batch_size=None``,
            all samples are loaded and data are not batched, defaults to None.
        :type batch_size: int | None

        :Keyword Arguments:
            The additional keyword arguments specify the training setup
            and can be chosen from the `pytorch-lightning
            Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
        """
        super().__init__(**kwargs)

        # check inheritance consistency for solver and batch size
        check_consistency(solver, SolverInterface)
        if batch_size is not None:
            check_consistency(batch_size, int)

        self._model = solver
        self.batch_size = batch_size

        # refuse to build the dataloader if some locations have no sampled points
        if solver.problem.have_sampled_points is False:
            raise RuntimeError(
                f"Input points in {solver.problem.not_sampled_points} "
                "are None. Please sample points in your problem by "
                "calling the discretise_domain function on the listed "
                "locations before training."
            )

        # create dataloader
        self._create_or_update_loader()

    def _create_or_update_loader(self):
        """
        Create (or recreate) the training dataloader. The logic lives in a
        separate method so that, if resampling is needed during training,
        the trainer dataloader does not have to be modified by hand: just
        call this method again.
        """
        devices = self._accelerator_connector._parallel_devices
        if len(devices) > 1:
            raise RuntimeError("Parallel training is not supported yet.")

        device = devices[0]
        dataset_phys = SamplePointDataset(self._model.problem, device)
        dataset_data = DataPointDataset(self._model.problem, device)
        self._loader = SamplePointLoader(
            dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True
        )

        # inverse problems: move the unknown parameters to the training device
        pb = self._model.problem
        if hasattr(pb, "unknown_parameters"):
            for key in pb.unknown_parameters:
                pb.unknown_parameters[key] = torch.nn.Parameter(
                    pb.unknown_parameters[key].data.to(device)
                )

    def train(self, **kwargs):
        """
        Train the solver.

        Additional keyword arguments are forwarded to
        ``pytorch_lightning.Trainer.fit``.
        """
        return super().fit(
            self._model, train_dataloaders=self._loader, **kwargs
        )

    @property
    def solver(self):
        """
        Return the solver attached to the trainer.
        """
        return self._model
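The constructor, `_create_or_update_loader` and `train` in this file follow a small pattern: validate the inputs, build the batches once (a single full batch when `batch_size=None`), and let `train` delegate to the parent class's `fit`. The following is a minimal, stdlib-only sketch of that pattern, not PINA's implementation. `BaseTrainer`, `SketchTrainer`, `DummySolver` and `make_batches` are hypothetical stand-ins for `pytorch_lightning.Trainer`, the PINA `Trainer`, a solver and `SamplePointLoader`; the positive-integer check on `batch_size` is an added assumption.

```python
import random
from typing import List, Optional, Sequence


class BaseTrainer:
    """Hypothetical stand-in for pytorch_lightning.Trainer: only fit() is modelled."""

    def fit(self, model, train_dataloaders):
        # Feed every batch to the model once and report how many were seen.
        n_batches = 0
        for batch in train_dataloaders:
            model.training_step(batch)
            n_batches += 1
        return n_batches


def make_batches(samples: Sequence, batch_size: Optional[int], shuffle: bool = True) -> List[list]:
    """Hypothetical loader: batch_size=None puts all samples in a single batch."""
    items = list(samples)
    if shuffle:
        random.shuffle(items)
    if batch_size is None:
        return [items]
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


class DummySolver:
    """Hypothetical solver holding already-sampled points."""

    def __init__(self, points):
        self.points = list(points)
        self.batch_sizes = []

    def training_step(self, batch):
        # Record the size of each batch so the batching can be inspected.
        self.batch_sizes.append(len(batch))


class SketchTrainer(BaseTrainer):
    """Hypothetical mirror of the Trainer flow: validate, build loader, delegate to fit()."""

    def __init__(self, solver, batch_size=None):
        super().__init__()
        # Assumption: accept only a positive int or None (the file itself uses check_consistency).
        if batch_size is not None and (not isinstance(batch_size, int) or batch_size <= 0):
            raise ValueError("batch_size must be a positive int or None")
        if not solver.points:
            # Analogue of the "points not sampled" RuntimeError in __init__.
            raise RuntimeError("Sample points before training.")
        self._model = solver
        self.batch_size = batch_size
        self._create_or_update_loader()

    def _create_or_update_loader(self):
        # Call again after resampling to rebuild the batches.
        self._loader = make_batches(self._model.points, self.batch_size)

    def train(self):
        # Delegate to the parent fit(), passing the stored solver and loader.
        return super().fit(self._model, train_dataloaders=self._loader)


solver = DummySolver(range(10))
print(SketchTrainer(solver, batch_size=4).train())  # 3
print(solver.batch_sizes)  # [4, 4, 2]
```

Passing `batch_size=None` instead yields one batch containing all ten points, which is the "data are not batched" case described in the `Trainer` docstring. Keeping loader construction in its own method is what lets a resampling step rebuild the batches without touching the rest of the trainer.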