Update PINNInterface Inheritance (#542)

Dario Coscia
2025-04-14 09:45:01 +02:00
committed by FilippoOlivo
parent 88b27605f1
commit e1863d9318
2 changed files with 46 additions and 51 deletions

@@ -5,7 +5,7 @@ import lightning
import torch
from torch._dynamo import OptimizedModule
from ..problem import AbstractProblem
from ..problem import AbstractProblem, InverseProblem
from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler
from ..loss import WeightingInterface
from ..loss.scalar_weighting import _NoWeighting
@@ -64,6 +64,14 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
        self._pina_optimizers = None
        self._pina_schedulers = None

        # inverse problem handling
        if isinstance(self.problem, InverseProblem):
            self._params = self.problem.unknown_parameters
            self._clamp_params = self._clamp_inverse_problem_params
        else:
            self._params = None
            self._clamp_params = lambda: None

    @abstractmethod
    def forward(self, *args, **kwargs):
        """
@@ -231,14 +239,29 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
            containing the condition name and the associated scalar loss.
        :rtype: dict
        """
        # compute losses
        losses = self.optimization_cycle(batch)
        # clamp unknown parameters in InverseProblem (if needed)
        self._clamp_params()
        # store log
        for name, value in losses.items():
            self.store_log(
                f"{name}_loss", value.item(), self.get_batch_size(batch)
            )
        # aggregate
        loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
        return loss

    def _clamp_inverse_problem_params(self):
        """
        Clamp the unknown parameters of the inverse problem to their
        admissible domain ranges.
        """
        for v in self._params:
            self._params[v].data.clamp_(
                self.problem.unknown_parameter_domain.range_[v][0],
                self.problem.unknown_parameter_domain.range_[v][1],
            )

    @staticmethod
    def _compile_modules(model):
        """
@@ -405,6 +428,15 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
        :rtype: tuple[list[Optimizer], list[Scheduler]]
        """
        self.optimizer.hook(self.model.parameters())
        if isinstance(self.problem, InverseProblem):
            self.optimizer.instance.add_param_group(
                {
                    "params": [
                        self._params[var]
                        for var in self.problem.unknown_variables
                    ]
                }
            )
        self.scheduler.hook(self.optimizer)
        return ([self.optimizer.instance], [self.scheduler.instance])
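
Because the optimizer is first hooked to the network weights alone, the inverse-problem unknowns must be registered as an extra parameter group so a single optimizer step updates both. A minimal sketch of that `add_param_group` pattern with a plain torch optimizer (the names `unknowns` and `mu` are illustrative, not the PINA API):

import torch

model = torch.nn.Linear(2, 1)
unknowns = {"mu": torch.nn.Parameter(torch.tensor(0.1))}

# first group: network weights; second group: inverse-problem unknowns
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer.add_param_group({"params": [unknowns[v] for v in unknowns]})

# one step now updates the weights and "mu" together
loss = model(torch.randn(4, 2)).sum() + unknowns["mu"] ** 2
loss.backward()
optimizer.step()
print(len(optimizer.param_groups))  # 2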