Implementation of DataLoader and DataModule (#383)

Refactoring for 0.2
* Data module, data loader and dataset
* Refactor LabelTensor
* Refactor solvers

Co-authored-by: dario-coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2024-11-27 16:01:39 +01:00
committed by Nicola Demo
parent dd43c8304c
commit a27bd35443
34 changed files with 827 additions and 1349 deletions

View File

@@ -9,10 +9,8 @@ except ImportError:
_LRScheduler as LRScheduler,
) # torch < 2.0
from torch.optim.lr_scheduler import ConstantLR
from .basepinn import PINNInterface
from pina.utils import check_consistency
from pina.problem import InverseProblem
@@ -56,16 +54,16 @@ class PINN(PINNInterface):
DOI: `10.1038 <https://doi.org/10.1038/s42254-021-00314-5>`_.
"""
__name__ = 'PINN'
def __init__(
self,
problem,
model,
extra_features=None,
loss=torch.nn.MSELoss(),
optimizer=torch.optim.Adam,
optimizer_kwargs={"lr": 0.001},
scheduler=ConstantLR,
scheduler_kwargs={"factor": 1, "total_iters": 0},
loss=None,
optimizer=None,
scheduler=None,
):
"""
:param AbstractProblem problem: The formulation of the problem.
@@ -82,20 +80,15 @@ class PINN(PINNInterface):
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
"""
super().__init__(
models=[model],
models=model,
problem=problem,
optimizers=[optimizer],
optimizers_kwargs=[optimizer_kwargs],
optimizers=optimizer,
schedulers=scheduler,
extra_features=extra_features,
loss=loss,
)
# check consistency
check_consistency(scheduler, LRScheduler, subclass=True)
check_consistency(scheduler_kwargs, dict)
# assign variables
self._scheduler = scheduler(self.optimizers[0], **scheduler_kwargs)
self._neural_net = self.models[0]
def forward(self, x):
@@ -126,9 +119,8 @@ class PINN(PINNInterface):
"""
residual = self.compute_residual(samples=samples, equation=equation)
loss_value = self.loss(
torch.zeros_like(residual, requires_grad=True), residual
torch.zeros_like(residual), residual
)
self.store_log(loss_value=float(loss_value))
return loss_value
def configure_optimizers(self):
@@ -141,16 +133,21 @@ class PINN(PINNInterface):
"""
# if the problem is an InverseProblem, add the unknown parameters
# to the parameters that the optimizer needs to optimize
self._optimizer.hook(self._model.parameters())
if isinstance(self.problem, InverseProblem):
self.optimizers[0].add_param_group(
{
"params": [
self._params[var]
for var in self.problem.unknown_variables
]
}
)
return self.optimizers, [self.scheduler]
self._optimizer.optimizer_instance.add_param_group(
{
"params": [
self._params[var]
for var in self.problem.unknown_variables
]
}
)
self._scheduler.hook(self._optimizer)
return ([self._optimizer.optimizer_instance],
[self._scheduler.scheduler_instance])
@property
def scheduler(self):