Implementation of DataLoader and DataModule (#383)

Refactoring for 0.2
* Data module, data loader and dataset
* Refactor LabelTensor
* Refactor solvers

Co-authored-by: dario-coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2024-11-27 16:01:39 +01:00
committed by Nicola Demo
parent dd43c8304c
commit a27bd35443
34 changed files with 827 additions and 1349 deletions

View File

@@ -1,20 +1,21 @@
""" Trainer module. """
import warnings
import torch
import pytorch_lightning
import lightning
from .utils import check_consistency
from .data import PinaDataModule
from .solvers.solver import SolverInterface
class Trainer(pytorch_lightning.Trainer):
class Trainer(lightning.pytorch.Trainer):
def __init__(self,
solver,
batch_size=None,
train_size=.7,
test_size=.2,
eval_size=.1,
val_size=.1,
predict_size=.0,
**kwargs):
"""
PINA Trainer class for costumizing every aspect of training via flags.
@@ -39,11 +40,13 @@ class Trainer(pytorch_lightning.Trainer):
check_consistency(batch_size, int)
self.train_size = train_size
self.test_size = test_size
self.eval_size = eval_size
self.val_size = val_size
self.predict_size = predict_size
self.solver = solver
self.batch_size = batch_size
self._create_loader()
self._move_to_device()
self.data_module = None
self._create_loader()
def _move_to_device(self):
device = self._accelerator_connector._parallel_devices[0]
@@ -64,34 +67,34 @@ class Trainer(pytorch_lightning.Trainer):
if not self.solver.problem.collector.full:
error_message = '\n'.join([
f"""{" " * 13} ---> Condition {key} {"sampled" if value else
"not sampled"}""" for key, value in
"not sampled"}""" for key, value in
self._solver.problem.collector._is_conditions_ready.items()
])
raise RuntimeError('Cannot create Trainer if not all conditions '
'are sampled. The Trainer got the following:\n'
f'{error_message}')
devices = self._accelerator_connector._parallel_devices
if len(devices) > 1:
raise RuntimeError("Parallel training is not supported yet.")
device = devices[0]
data_module = PinaDataModule(problem=self.solver.problem,
device=device,
train_size=self.train_size,
test_size=self.test_size,
val_size=self.eval_size)
data_module.setup()
self._loader = data_module.train_dataloader()
self.data_module = PinaDataModule(collector=self.solver.problem.collector,
train_size=self.train_size,
test_size=self.test_size,
val_size=self.val_size,
predict_size=self.predict_size,
batch_size=self.batch_size,)
def train(self, **kwargs):
"""
Train the solver method.
"""
return super().fit(self.solver,
train_dataloaders=self._loader,
**kwargs)
datamodule=self.data_module,
**kwargs)
def test(self, **kwargs):
"""
Test the solver method.
"""
return super().test(self.solver,
datamodule=self.data_module,
**kwargs)
@property
def solver(self):