minor changes/ trainer update

This commit is contained in:
Dario Coscia
2024-10-10 19:24:46 +02:00
committed by Nicola Demo
parent 7528f6ef74
commit b9753c34b2
8 changed files with 69 additions and 46 deletions

View File

@@ -32,29 +32,18 @@ class Trainer(pytorch_lightning.Trainer):
if batch_size is not None:
check_consistency(batch_size, int)
self._model = solver
self.solver = solver
self.batch_size = batch_size
self._create_loader()
self._move_to_device()
# create dataloader
# if solver.problem.have_sampled_points is False:
# raise RuntimeError(
# f"Input points in {solver.problem.not_sampled_points} "
# "training are None. Please "
# "sample points in your problem by calling "
# "discretise_domain function before train "
# "in the provided locations."
# )
# self._create_or_update_loader()
def _move_to_device(self):
device = self._accelerator_connector._parallel_devices[0]
# move parameters to device
pb = self._model.problem
pb = self.solver.problem
if hasattr(pb, "unknown_parameters"):
for key in pb.unknown_parameters:
pb.unknown_parameters[key] = torch.nn.Parameter(
@@ -67,14 +56,21 @@ class Trainer(pytorch_lightning.Trainer):
during training, there is no need to define to touch the
trainer dataloader, just call the method.
"""
if not self.solver.problem.collector.full:
error_message = '\n'.join(
[f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}'
for key, value in self.solver.problem.collector._is_conditions_ready.items()])
raise RuntimeError('Cannot create Trainer if not all conditions '
'are sampled. The Trainer got the following:\n'
f'{error_message}')
devices = self._accelerator_connector._parallel_devices
if len(devices) > 1:
raise RuntimeError("Parallel training is not supported yet.")
device = devices[0]
dataset_phys = SamplePointDataset(self._model.problem, device)
dataset_data = DataPointDataset(self._model.problem, device)
dataset_phys = SamplePointDataset(self.solver.problem, device)
dataset_data = DataPointDataset(self.solver.problem, device)
self._loader = SamplePointLoader(
dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True
)
@@ -84,7 +80,7 @@ class Trainer(pytorch_lightning.Trainer):
Train the solver method.
"""
return super().fit(
self._model, train_dataloaders=self._loader, **kwargs
self.solver, train_dataloaders=self._loader, **kwargs
)
@property
@@ -92,4 +88,4 @@ class Trainer(pytorch_lightning.Trainer):
"""
Returning trainer solver.
"""
return self._model
return self._solver