Implement Dataset, Dataloader and DataModule class and fix SupervisedSolver

This commit is contained in:
FilippoOlivo
2024-10-16 11:24:37 +02:00
committed by Nicola Demo
parent b9753c34b2
commit c9304fb9bb
30 changed files with 770 additions and 784 deletions

View File

@@ -3,13 +3,13 @@
import torch
import pytorch_lightning
from .utils import check_consistency
from .data import SamplePointDataset, SamplePointLoader, DataPointDataset
from .data import PinaDataModule
from .solvers.solver import SolverInterface
class Trainer(pytorch_lightning.Trainer):
def __init__(self, solver, batch_size=None, **kwargs):
def __init__(self, solver, batch_size=None, train_size=.7, test_size=.2, eval_size=.1, **kwargs):
"""
PINA Trainer class for costumizing every aspect of training via flags.
@@ -31,10 +31,11 @@ class Trainer(pytorch_lightning.Trainer):
check_consistency(solver, SolverInterface)
if batch_size is not None:
check_consistency(batch_size, int)
self.train_size = train_size
self.test_size = test_size
self.eval_size = eval_size
self.solver = solver
self.batch_size = batch_size
self._create_loader()
self._move_to_device()
@@ -69,11 +70,12 @@ class Trainer(pytorch_lightning.Trainer):
raise RuntimeError("Parallel training is not supported yet.")
device = devices[0]
dataset_phys = SamplePointDataset(self.solver.problem, device)
dataset_data = DataPointDataset(self.solver.problem, device)
self._loader = SamplePointLoader(
dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True
)
data_module = PinaDataModule(problem=self.solver.problem, device=device,
train_size=self.train_size, test_size=self.test_size,
eval_size=self.eval_size)
data_module.setup()
self._loader = data_module.train_dataloader()
def train(self, **kwargs):
"""
@@ -89,3 +91,7 @@ class Trainer(pytorch_lightning.Trainer):
Returning trainer solver.
"""
return self._solver
@solver.setter
def solver(self, solver):
self._solver = solver