Update PINNInterface Inheritance (#542)
This commit is contained in:
committed by
FilippoOlivo
parent
88b27605f1
commit
e1863d9318
@@ -5,7 +5,7 @@ import lightning
|
||||
import torch
|
||||
|
||||
from torch._dynamo import OptimizedModule
|
||||
from ..problem import AbstractProblem
|
||||
from ..problem import AbstractProblem, InverseProblem
|
||||
from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler
|
||||
from ..loss import WeightingInterface
|
||||
from ..loss.scalar_weighting import _NoWeighting
|
||||
@@ -64,6 +64,14 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
self._pina_optimizers = None
|
||||
self._pina_schedulers = None
|
||||
|
||||
# inverse problem handling
|
||||
if isinstance(self.problem, InverseProblem):
|
||||
self._params = self.problem.unknown_parameters
|
||||
self._clamp_params = self._clamp_inverse_problem_params
|
||||
else:
|
||||
self._params = None
|
||||
self._clamp_params = lambda: None
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, *args, **kwargs):
|
||||
"""
|
||||
@@ -231,14 +239,29 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
containing the condition name and the associated scalar loss.
|
||||
:rtype: dict
|
||||
"""
|
||||
# compute losses
|
||||
losses = self.optimization_cycle(batch)
|
||||
# clamp unknown parameters in InverseProblem (if needed)
|
||||
self._clamp_params()
|
||||
# store log
|
||||
for name, value in losses.items():
|
||||
self.store_log(
|
||||
f"{name}_loss", value.item(), self.get_batch_size(batch)
|
||||
)
|
||||
# aggregate
|
||||
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
|
||||
return loss
|
||||
|
||||
def _clamp_inverse_problem_params(self):
|
||||
"""
|
||||
Clamps the parameters of the inverse problem solver to specified ranges.
|
||||
"""
|
||||
for v in self._params:
|
||||
self._params[v].data.clamp_(
|
||||
self.problem.unknown_parameter_domain.range_[v][0],
|
||||
self.problem.unknown_parameter_domain.range_[v][1],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _compile_modules(model):
|
||||
"""
|
||||
@@ -405,6 +428,15 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
|
||||
:rtype: tuple[list[Optimizer], list[Scheduler]]
|
||||
"""
|
||||
self.optimizer.hook(self.model.parameters())
|
||||
if isinstance(self.problem, InverseProblem):
|
||||
self.optimizer.instance.add_param_group(
|
||||
{
|
||||
"params": [
|
||||
self._params[var]
|
||||
for var in self.problem.unknown_variables
|
||||
]
|
||||
}
|
||||
)
|
||||
self.scheduler.hook(self.optimizer)
|
||||
return ([self.optimizer.instance], [self.scheduler.instance])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user