add dataset and dataloader for sample points (#195)

* add dataset and dataloader for sample points
* unittests
This commit is contained in:
Nicola Demo
2023-11-07 11:34:44 +01:00
parent cd5bc9a558
commit d654259428
19 changed files with 581 additions and 196 deletions

View File

@@ -1,18 +1,19 @@
""" Solver module. """
import lightning.pytorch as pl
from pytorch_lightning import Trainer
from .utils import check_consistency
from .dataset import DummyLoader
from .dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
from .solvers.solver import SolverInterface
class Trainer(pl.Trainer):
class Trainer(Trainer):
def __init__(self, solver, **kwargs):
def __init__(self, solver, batch_size=None, **kwargs):
super().__init__(**kwargs)
# check inheritance consistency for solver
check_consistency(solver, SolverInterface)
self._model = solver
self.batch_size = batch_size
# create dataloader
if solver.problem.have_sampled_points is False:
@@ -22,19 +23,31 @@ class Trainer(pl.Trainer):
'discretise_domain function before train '
'in the provided locations.')
# TODO: make a better dataloader for train
self._create_or_update_loader()
# this method is used here because, if resampling is needed
# during training, there is no need to touch the trainer
# dataloader; just call this method.
def _create_or_update_loader(self):
# get accelerator
device = self._accelerator_connector._accelerator_flag
self._loader = DummyLoader(self._model.problem.input_pts, device)
"""
This method is used here because is resampling is needed
during training, there is no need to define to touch the
trainer dataloader, just call the method.
"""
devices = self._accelerator_connector._parallel_devices
def train(self, **kwargs): # TODO add kwargs and lightning capabilities
return super().fit(self._model, self._loader, **kwargs)
if len(devices) > 1:
raise RuntimeError('Parallel training is not supported yet.')
device = devices[0]
dataset_phys = SamplePointDataset(self._model.problem, device)
dataset_data = DataPointDataset(self._model.problem, device)
self._loader = SamplePointLoader(
dataset_phys, dataset_data, batch_size=self.batch_size,
shuffle=True)
def train(self, **kwargs):
"""
Train the solver.
"""
return super().fit(self._model, train_dataloaders=self._loader, **kwargs)
@property
def solver(self):