Refactoring solvers (#541)
* Refactoring solvers * Simplify logic compile * Improve and update doc * Create SupervisedSolverInterface * Specialize SupervisedSolver and ReducedOrderModelSolver * Create EnsembleSolverInterface + EnsembleSupervisedSolver * Create tests ensemble solvers * formatter * codacy * fix issues + speedup test
This commit is contained in:
@@ -38,7 +38,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
|
||||
:param AbstractProblem problem: The problem to be solved.
|
||||
:param torch.nn.Module loss: The loss function to be minimized.
|
||||
If `None`, the :class:`torch.nn.MSELoss` loss is used.
|
||||
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
|
||||
Default is `None`.
|
||||
:param kwargs: Additional keyword arguments to be passed to the
|
||||
:class:`~pina.solver.solver.SolverInterface` class.
|
||||
@@ -53,7 +53,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
check_consistency(loss, (LossInterface, _Loss), subclass=False)
|
||||
|
||||
# assign variables
|
||||
self._loss = loss
|
||||
self._loss_fn = loss
|
||||
|
||||
# inverse problem handling
|
||||
if isinstance(self.problem, InverseProblem):
|
||||
@@ -65,7 +65,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
|
||||
self.__metric = None
|
||||
|
||||
def optimization_cycle(self, batch):
|
||||
def optimization_cycle(self, batch, loss_residuals=None):
|
||||
"""
|
||||
The optimization cycle for the PINN solver.
|
||||
|
||||
@@ -80,51 +80,74 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
containing the condition name and the associated scalar loss.
|
||||
:rtype: dict
|
||||
"""
|
||||
return self._run_optimization_cycle(batch, self.loss_phys)
|
||||
# which losses to use
|
||||
if loss_residuals is None:
|
||||
loss_residuals = self.loss_phys
|
||||
# compute optimization cycle
|
||||
condition_loss = {}
|
||||
for condition_name, points in batch:
|
||||
self.__metric = condition_name
|
||||
# if equations are passed
|
||||
if "target" not in points:
|
||||
input_pts = points["input"]
|
||||
condition = self.problem.conditions[condition_name]
|
||||
loss = loss_residuals(
|
||||
input_pts.requires_grad_(), condition.equation
|
||||
)
|
||||
# if data are passed
|
||||
else:
|
||||
input_pts = points["input"]
|
||||
output_pts = points["target"]
|
||||
loss = self.loss_data(
|
||||
input=input_pts.requires_grad_(), target=output_pts
|
||||
)
|
||||
# append loss
|
||||
condition_loss[condition_name] = loss
|
||||
# clamp unknown parameters in InverseProblem (if needed)
|
||||
self._clamp_params()
|
||||
return condition_loss
|
||||
|
||||
@torch.set_grad_enabled(True)
|
||||
def validation_step(self, batch):
|
||||
"""
|
||||
The validation step for the PINN solver.
|
||||
The validation step for the PINN solver. It returns the average residual
|
||||
computed with the ``loss`` function not aggregated.
|
||||
|
||||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
|
||||
tuple containing a condition name and a dictionary of points.
|
||||
:return: The loss of the validation step.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
losses = self._run_optimization_cycle(batch, self._residual_loss)
|
||||
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
|
||||
self.store_log("val_loss", loss, self.get_batch_size(batch))
|
||||
return loss
|
||||
return super().validation_step(
|
||||
batch, loss_residuals=self._residual_loss
|
||||
)
|
||||
|
||||
@torch.set_grad_enabled(True)
|
||||
def test_step(self, batch):
|
||||
"""
|
||||
The test step for the PINN solver.
|
||||
The test step for the PINN solver. It returns the average residual
|
||||
computed with the ``loss`` function not aggregated.
|
||||
|
||||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
|
||||
tuple containing a condition name and a dictionary of points.
|
||||
:return: The loss of the test step.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
losses = self._run_optimization_cycle(batch, self._residual_loss)
|
||||
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
|
||||
self.store_log("test_loss", loss, self.get_batch_size(batch))
|
||||
return loss
|
||||
return super().test_step(batch, loss_residuals=self._residual_loss)
|
||||
|
||||
def loss_data(self, input_pts, output_pts):
|
||||
@abstractmethod
|
||||
def loss_data(self, input, target):
|
||||
"""
|
||||
Compute the data loss for the PINN solver by evaluating the loss
|
||||
between the network's output and the true solution. This method should
|
||||
not be overridden, if not intentionally.
|
||||
be overridden by the derived class.
|
||||
|
||||
:param LabelTensor input_pts: The input points to the neural network.
|
||||
:param LabelTensor output_pts: The true solution to compare with the
|
||||
:param LabelTensor input: The input to the neural network.
|
||||
:param LabelTensor target: The target to compare with the
|
||||
network's output.
|
||||
:return: The supervised loss, averaged over the number of observations.
|
||||
:rtype: torch.Tensor
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
return self._loss(self.forward(input_pts), output_pts)
|
||||
|
||||
@abstractmethod
|
||||
def loss_phys(self, samples, equation):
|
||||
@@ -159,7 +182,11 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
|
||||
def _residual_loss(self, samples, equation):
|
||||
"""
|
||||
Compute the residual loss.
|
||||
Computes the physics loss for the physics-informed solver based on the
|
||||
provided samples and equation. This method should never be overridden
|
||||
by the user, if not intentionally,
|
||||
since it is used internally to compute validation loss.
|
||||
|
||||
|
||||
:param LabelTensor samples: The samples to evaluate the loss.
|
||||
:param EquationInterface equation: The governing equation.
|
||||
@@ -167,43 +194,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
residuals = self.compute_residual(samples, equation)
|
||||
return self.loss(residuals, torch.zeros_like(residuals))
|
||||
|
||||
def _run_optimization_cycle(self, batch, loss_residuals):
|
||||
"""
|
||||
Compute, given a batch, the loss for each condition and return a
|
||||
dictionary with the condition name as key and the loss as value.
|
||||
|
||||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
|
||||
tuple containing a condition name and a dictionary of points.
|
||||
:param function loss_residuals: The loss function to be minimized.
|
||||
:return: The losses computed for all conditions in the batch, casted
|
||||
to a subclass of :class:`torch.Tensor`. It should return a dict
|
||||
containing the condition name and the associated scalar loss.
|
||||
:rtype: dict
|
||||
"""
|
||||
condition_loss = {}
|
||||
for condition_name, points in batch:
|
||||
self.__metric = condition_name
|
||||
# if equations are passed
|
||||
if "target" not in points:
|
||||
input_pts = points["input"]
|
||||
condition = self.problem.conditions[condition_name]
|
||||
loss = loss_residuals(
|
||||
input_pts.requires_grad_(), condition.equation
|
||||
)
|
||||
# if data are passed
|
||||
else:
|
||||
input_pts = points["input"]
|
||||
output_pts = points["target"]
|
||||
loss = self.loss_data(
|
||||
input_pts=input_pts.requires_grad_(), output_pts=output_pts
|
||||
)
|
||||
# append loss
|
||||
condition_loss[condition_name] = loss
|
||||
# clamp unknown parameters in InverseProblem (if needed)
|
||||
self._clamp_params()
|
||||
return condition_loss
|
||||
return self._loss_fn(residuals, torch.zeros_like(residuals))
|
||||
|
||||
def _clamp_inverse_problem_params(self):
|
||||
"""
|
||||
@@ -223,7 +214,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
:return: The loss function used for training.
|
||||
:rtype: torch.nn.Module
|
||||
"""
|
||||
return self._loss
|
||||
return self._loss_fn
|
||||
|
||||
@property
|
||||
def current_condition_name(self):
|
||||
|
||||
Reference in New Issue
Block a user