supervised working

This commit is contained in:
Nicola Demo
2024-08-08 16:19:52 +02:00
parent 5245a0b68c
commit 9d9c2aa23e
61 changed files with 375 additions and 262 deletions

View File

@@ -3,7 +3,7 @@
import torch
import pytorch_lightning
from .utils import check_consistency
from .data.dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
from .data import SamplePointDataset, SamplePointLoader, DataPointDataset
from .solvers.solver import SolverInterface
@@ -35,19 +35,33 @@ class Trainer(pytorch_lightning.Trainer):
self._model = solver
self.batch_size = batch_size
self._create_loader()
self._move_to_device()
# create dataloader
if solver.problem.have_sampled_points is False:
raise RuntimeError(
f"Input points in {solver.problem.not_sampled_points} "
"training are None. Please "
"sample points in your problem by calling "
"discretise_domain function before train "
"in the provided locations."
)
# if solver.problem.have_sampled_points is False:
# raise RuntimeError(
# f"Input points in {solver.problem.not_sampled_points} "
# "training are None. Please "
# "sample points in your problem by calling "
# "discretise_domain function before train "
# "in the provided locations."
# )
self._create_or_update_loader()
# self._create_or_update_loader()
def _create_or_update_loader(self):
def _move_to_device(self):
device = self._accelerator_connector._parallel_devices[0]
# move parameters to device
pb = self._model.problem
if hasattr(pb, "unknown_parameters"):
for key in pb.unknown_parameters:
pb.unknown_parameters[key] = torch.nn.Parameter(
pb.unknown_parameters[key].data.to(device)
)
def _create_loader(self):
"""
This method is used here because is resampling is needed
during training, there is no need to define to touch the
@@ -64,12 +78,6 @@ class Trainer(pytorch_lightning.Trainer):
self._loader = SamplePointLoader(
dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True
)
pb = self._model.problem
if hasattr(pb, "unknown_parameters"):
for key in pb.unknown_parameters:
pb.unknown_parameters[key] = torch.nn.Parameter(
pb.unknown_parameters[key].data.to(device)
)
def train(self, **kwargs):
"""