Implement Dataset, Dataloader and DataModule class and fix SupervisedSolver
This commit is contained in:
committed by
Nicola Demo
parent
b9753c34b2
commit
c9304fb9bb
@@ -3,13 +3,13 @@
|
||||
import torch
|
||||
import pytorch_lightning
|
||||
from .utils import check_consistency
|
||||
from .data import SamplePointDataset, SamplePointLoader, DataPointDataset
|
||||
from .data import PinaDataModule
|
||||
from .solvers.solver import SolverInterface
|
||||
|
||||
|
||||
class Trainer(pytorch_lightning.Trainer):
|
||||
|
||||
def __init__(self, solver, batch_size=None, **kwargs):
|
||||
def __init__(self, solver, batch_size=None, train_size=.7, test_size=.2, eval_size=.1, **kwargs):
|
||||
"""
|
||||
PINA Trainer class for costumizing every aspect of training via flags.
|
||||
|
||||
@@ -31,10 +31,11 @@ class Trainer(pytorch_lightning.Trainer):
|
||||
check_consistency(solver, SolverInterface)
|
||||
if batch_size is not None:
|
||||
check_consistency(batch_size, int)
|
||||
|
||||
self.train_size = train_size
|
||||
self.test_size = test_size
|
||||
self.eval_size = eval_size
|
||||
self.solver = solver
|
||||
self.batch_size = batch_size
|
||||
|
||||
self._create_loader()
|
||||
self._move_to_device()
|
||||
|
||||
@@ -69,11 +70,12 @@ class Trainer(pytorch_lightning.Trainer):
|
||||
raise RuntimeError("Parallel training is not supported yet.")
|
||||
|
||||
device = devices[0]
|
||||
dataset_phys = SamplePointDataset(self.solver.problem, device)
|
||||
dataset_data = DataPointDataset(self.solver.problem, device)
|
||||
self._loader = SamplePointLoader(
|
||||
dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True
|
||||
)
|
||||
|
||||
data_module = PinaDataModule(problem=self.solver.problem, device=device,
|
||||
train_size=self.train_size, test_size=self.test_size,
|
||||
eval_size=self.eval_size)
|
||||
data_module.setup()
|
||||
self._loader = data_module.train_dataloader()
|
||||
|
||||
def train(self, **kwargs):
|
||||
"""
|
||||
@@ -89,3 +91,7 @@ class Trainer(pytorch_lightning.Trainer):
|
||||
Returning trainer solver.
|
||||
"""
|
||||
return self._solver
|
||||
|
||||
@solver.setter
|
||||
def solver(self, solver):
|
||||
self._solver = solver
|
||||
|
||||
Reference in New Issue
Block a user