🎨 Format Python code with psf/black

This commit is contained in:
ndem0
2024-02-09 11:25:00 +00:00
committed by Nicola Demo
parent 591aeeb02b
commit cbb43a5392
64 changed files with 1323 additions and 955 deletions

View File

@@ -5,6 +5,7 @@ from .utils import check_consistency
from .dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
from .solvers.solver import SolverInterface
class Trainer(pytorch_lightning.Trainer):
def __init__(self, solver, batch_size=None, **kwargs):
@@ -29,18 +30,20 @@ class Trainer(pytorch_lightning.Trainer):
check_consistency(solver, SolverInterface)
if batch_size is not None:
check_consistency(batch_size, int)
self._model = solver
self.batch_size = batch_size
# create dataloader
if solver.problem.have_sampled_points is False:
raise RuntimeError(f'Input points in {solver.problem.not_sampled_points} '
'training are None. Please '
'sample points in your problem by calling '
'discretise_domain function before train '
'in the provided locations.')
raise RuntimeError(
f"Input points in {solver.problem.not_sampled_points} "
"training are None. Please "
"sample points in your problem by calling "
"discretise_domain function before train "
"in the provided locations."
)
self._create_or_update_loader()
def _create_or_update_loader(self):
@@ -52,21 +55,23 @@ class Trainer(pytorch_lightning.Trainer):
devices = self._accelerator_connector._parallel_devices
if len(devices) > 1:
raise RuntimeError('Parallel training is not supported yet.')
raise RuntimeError("Parallel training is not supported yet.")
device = devices[0]
dataset_phys = SamplePointDataset(self._model.problem, device)
dataset_data = DataPointDataset(self._model.problem, device)
self._loader = SamplePointLoader(
dataset_phys, dataset_data, batch_size=self.batch_size,
shuffle=True)
dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True
)
def train(self, **kwargs):
"""
Train the solver method.
"""
return super().fit(self._model, train_dataloaders=self._loader, **kwargs)
return super().fit(
self._model, train_dataloaders=self._loader, **kwargs
)
@property
def solver(self):
"""