supervised working
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
import torch
|
||||
import pytorch_lightning
|
||||
from .utils import check_consistency
|
||||
from .data.dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
|
||||
from .data import SamplePointDataset, SamplePointLoader, DataPointDataset
|
||||
from .solvers.solver import SolverInterface
|
||||
|
||||
|
||||
@@ -35,19 +35,33 @@ class Trainer(pytorch_lightning.Trainer):
|
||||
self._model = solver
|
||||
self.batch_size = batch_size
|
||||
|
||||
self._create_loader()
|
||||
self._move_to_device()
|
||||
|
||||
# create dataloader
|
||||
if solver.problem.have_sampled_points is False:
|
||||
raise RuntimeError(
|
||||
f"Input points in {solver.problem.not_sampled_points} "
|
||||
"training are None. Please "
|
||||
"sample points in your problem by calling "
|
||||
"discretise_domain function before train "
|
||||
"in the provided locations."
|
||||
)
|
||||
# if solver.problem.have_sampled_points is False:
|
||||
# raise RuntimeError(
|
||||
# f"Input points in {solver.problem.not_sampled_points} "
|
||||
# "training are None. Please "
|
||||
# "sample points in your problem by calling "
|
||||
# "discretise_domain function before train "
|
||||
# "in the provided locations."
|
||||
# )
|
||||
|
||||
self._create_or_update_loader()
|
||||
# self._create_or_update_loader()
|
||||
|
||||
def _create_or_update_loader(self):
|
||||
def _move_to_device(self):
|
||||
device = self._accelerator_connector._parallel_devices[0]
|
||||
|
||||
# move parameters to device
|
||||
pb = self._model.problem
|
||||
if hasattr(pb, "unknown_parameters"):
|
||||
for key in pb.unknown_parameters:
|
||||
pb.unknown_parameters[key] = torch.nn.Parameter(
|
||||
pb.unknown_parameters[key].data.to(device)
|
||||
)
|
||||
|
||||
def _create_loader(self):
|
||||
"""
|
||||
This method is used here because is resampling is needed
|
||||
during training, there is no need to define to touch the
|
||||
@@ -64,12 +78,6 @@ class Trainer(pytorch_lightning.Trainer):
|
||||
self._loader = SamplePointLoader(
|
||||
dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True
|
||||
)
|
||||
pb = self._model.problem
|
||||
if hasattr(pb, "unknown_parameters"):
|
||||
for key in pb.unknown_parameters:
|
||||
pb.unknown_parameters[key] = torch.nn.Parameter(
|
||||
pb.unknown_parameters[key].data.to(device)
|
||||
)
|
||||
|
||||
def train(self, **kwargs):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user