minor changes/ trainer update
This commit is contained in:
committed by
Nicola Demo
parent
7528f6ef74
commit
b9753c34b2
@@ -32,29 +32,18 @@ class Trainer(pytorch_lightning.Trainer):
|
||||
if batch_size is not None:
|
||||
check_consistency(batch_size, int)
|
||||
|
||||
self._model = solver
|
||||
self.solver = solver
|
||||
self.batch_size = batch_size
|
||||
|
||||
self._create_loader()
|
||||
self._move_to_device()
|
||||
|
||||
# create dataloader
|
||||
# if solver.problem.have_sampled_points is False:
|
||||
# raise RuntimeError(
|
||||
# f"Input points in {solver.problem.not_sampled_points} "
|
||||
# "training are None. Please "
|
||||
# "sample points in your problem by calling "
|
||||
# "discretise_domain function before train "
|
||||
# "in the provided locations."
|
||||
# )
|
||||
|
||||
# self._create_or_update_loader()
|
||||
|
||||
def _move_to_device(self):
|
||||
device = self._accelerator_connector._parallel_devices[0]
|
||||
|
||||
# move parameters to device
|
||||
pb = self._model.problem
|
||||
pb = self.solver.problem
|
||||
if hasattr(pb, "unknown_parameters"):
|
||||
for key in pb.unknown_parameters:
|
||||
pb.unknown_parameters[key] = torch.nn.Parameter(
|
||||
@@ -67,14 +56,21 @@ class Trainer(pytorch_lightning.Trainer):
|
||||
during training, there is no need to define to touch the
|
||||
trainer dataloader, just call the method.
|
||||
"""
|
||||
if not self.solver.problem.collector.full:
|
||||
error_message = '\n'.join(
|
||||
[f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}'
|
||||
for key, value in self.solver.problem.collector._is_conditions_ready.items()])
|
||||
raise RuntimeError('Cannot create Trainer if not all conditions '
|
||||
'are sampled. The Trainer got the following:\n'
|
||||
f'{error_message}')
|
||||
devices = self._accelerator_connector._parallel_devices
|
||||
|
||||
if len(devices) > 1:
|
||||
raise RuntimeError("Parallel training is not supported yet.")
|
||||
|
||||
device = devices[0]
|
||||
dataset_phys = SamplePointDataset(self._model.problem, device)
|
||||
dataset_data = DataPointDataset(self._model.problem, device)
|
||||
dataset_phys = SamplePointDataset(self.solver.problem, device)
|
||||
dataset_data = DataPointDataset(self.solver.problem, device)
|
||||
self._loader = SamplePointLoader(
|
||||
dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True
|
||||
)
|
||||
@@ -84,7 +80,7 @@ class Trainer(pytorch_lightning.Trainer):
|
||||
Train the solver method.
|
||||
"""
|
||||
return super().fit(
|
||||
self._model, train_dataloaders=self._loader, **kwargs
|
||||
self.solver, train_dataloaders=self._loader, **kwargs
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -92,4 +88,4 @@ class Trainer(pytorch_lightning.Trainer):
|
||||
"""
|
||||
Returning trainer solver.
|
||||
"""
|
||||
return self._model
|
||||
return self._solver
|
||||
|
||||
Reference in New Issue
Block a user