Implementation of DataLoader and DataModule (#383)
Refactoring for 0.2:

* Data module, data loader and dataset
* Refactor LabelTensor
* Refactor solvers

Co-authored-by: dario-coscia <dariocos99@gmail.com>
Committed by Nicola Demo
parent dd43c8304c
commit a27bd35443
@@ -1,20 +1,21 @@
 """ Trainer module. """
 import warnings
 import torch
-import pytorch_lightning
+import lightning
 from .utils import check_consistency
 from .data import PinaDataModule
 from .solvers.solver import SolverInterface


-class Trainer(pytorch_lightning.Trainer):
+class Trainer(lightning.pytorch.Trainer):

     def __init__(self,
                  solver,
                  batch_size=None,
                  train_size=.7,
                  test_size=.2,
-                 eval_size=.1,
+                 val_size=.1,
+                 predict_size=.0,
                  **kwargs):
         """
         PINA Trainer class for costumizing every aspect of training via flags.
@@ -39,11 +40,13 @@ class Trainer(pytorch_lightning.Trainer):
             check_consistency(batch_size, int)
         self.train_size = train_size
         self.test_size = test_size
-        self.eval_size = eval_size
+        self.val_size = val_size
+        self.predict_size = predict_size
         self.solver = solver
         self.batch_size = batch_size
-        self._create_loader()
         self._move_to_device()
+        self.data_module = None
+        self._create_loader()

     def _move_to_device(self):
         device = self._accelerator_connector._parallel_devices[0]
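check_consistency is imported from .utils and its body is not part of this diff; the stand-in below only illustrates the kind of type guard applied to batch_size in the hunk above, and is not PINA's actual implementation.

# Hypothetical stand-in for pina.utils.check_consistency, for illustration only.
def check_consistency(obj, obj_type):
    # Raise if obj is not an instance of the expected type.
    if not isinstance(obj, obj_type):
        raise ValueError(f"expected {obj_type.__name__}, got {type(obj).__name__}")

check_consistency(16, int)      # passes
# check_consistency(16.0, int)  # would raise ValueError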
@@ -64,34 +67,34 @@ class Trainer(pytorch_lightning.Trainer):
         if not self.solver.problem.collector.full:
             error_message = '\n'.join([
                 f"""{" " * 13} ---> Condition {key} {"sampled" if value else
-                "not sampled"}""" for key, value in
+                    "not sampled"}""" for key, value in
                 self._solver.problem.collector._is_conditions_ready.items()
             ])
             raise RuntimeError('Cannot create Trainer if not all conditions '
                                'are sampled. The Trainer got the following:\n'
                                f'{error_message}')
-        devices = self._accelerator_connector._parallel_devices
-
-        if len(devices) > 1:
-            raise RuntimeError("Parallel training is not supported yet.")
-
-        device = devices[0]
-
-        data_module = PinaDataModule(problem=self.solver.problem,
-                                     device=device,
-                                     train_size=self.train_size,
-                                     test_size=self.test_size,
-                                     val_size=self.eval_size)
-        data_module.setup()
-        self._loader = data_module.train_dataloader()
+        self.data_module = PinaDataModule(collector=self.solver.problem.collector,
+                                          train_size=self.train_size,
+                                          test_size=self.test_size,
+                                          val_size=self.val_size,
+                                          predict_size=self.predict_size,
+                                          batch_size=self.batch_size,)

     def train(self, **kwargs):
         """
         Train the solver method.
         """
         return super().fit(self.solver,
-                           train_dataloaders=self._loader,
+                           datamodule=self.data_module,
                            **kwargs)

+    def test(self, **kwargs):
+        """
+        Test the solver method.
+        """
+        return super().test(self.solver,
+                            datamodule=self.data_module,
+                            **kwargs)
+
     @property
     def solver(self):
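PinaDataModule itself lives in pina.data and is not shown in this diff, so the skeleton below is only an assumption about the LightningDataModule interface the Trainer now relies on: the constructor arguments mirror what _create_loader passes, while the dataset construction and splitting logic are invented so the sketch runs.

# Illustrative skeleton only; the real PinaDataModule is in pina.data and is not
# part of this diff. Constructor arguments mirror the _create_loader call above;
# the dataset and split logic are made up for the sake of a runnable example.
import torch
import lightning
from torch.utils.data import DataLoader, TensorDataset, random_split


class DataModuleSketch(lightning.pytorch.LightningDataModule):

    def __init__(self, collector, train_size, test_size, val_size,
                 predict_size=0., batch_size=None):
        super().__init__()
        self.collector = collector
        self.fractions = [train_size, test_size, val_size, predict_size]
        self.batch_size = batch_size

    def setup(self, stage=None):
        # The real module would build datasets from the collector's sampled
        # points; a random tensor stands in here so the sketch is runnable.
        full = TensorDataset(torch.rand(100, 2))
        (self.train_ds, self.test_ds,
         self.val_ds, self.predict_ds) = random_split(full, self.fractions)

    def train_dataloader(self):
        # batch_size=None is read here as "one full batch" (assumption).
        return DataLoader(self.train_ds,
                          batch_size=self.batch_size or len(self.train_ds))

    def val_dataloader(self):
        return DataLoader(self.val_ds,
                          batch_size=self.batch_size or len(self.val_ds))

    def test_dataloader(self):
        return DataLoader(self.test_ds,
                          batch_size=self.batch_size or len(self.test_ds))

Passing datamodule=self.data_module to fit()/test() lets Lightning call setup() and select the appropriate *_dataloader() per stage, which is why the refactored _create_loader no longer calls data_module.setup() and train_dataloader() by hand.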