Refactoring solvers (#541)
* Refactoring solvers * Simplify logic compile * Improve and update doc * Create SupervisedSolverInterface * Specialize SupervisedSolver and ReducedOrderModelSolver * Create EnsembleSolverInterface + EnsembleSupervisedSolver * Create tests ensemble solvers * formatter * codacy * fix issues + speedup test
This commit is contained in:
committed by
FilippoOlivo
parent
fa6fda0bd5
commit
1bb3c125ac
@@ -83,15 +83,15 @@ class CausalPINN(PINN):
|
||||
:class:`~pina.problem.time_dependent_problem.TimeDependentProblem`.
|
||||
:param torch.nn.Module model: The neural network model to be used.
|
||||
:param Optimizer optimizer: The optimizer to be used.
|
||||
If `None`, the :class:`torch.optim.Adam` optimizer is used.
|
||||
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
|
||||
Default is ``None``.
|
||||
:param torch.optim.LRScheduler scheduler: Learning rate scheduler.
|
||||
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
scheduler is used. Default is ``None``.
|
||||
:param WeightingInterface weighting: The weighting schema to be used.
|
||||
If `None`, no weighting schema is used. Default is ``None``.
|
||||
If ``None``, no weighting schema is used. Default is ``None``.
|
||||
:param torch.nn.Module loss: The loss function to be minimized.
|
||||
If `None`, the :class:`torch.nn.MSELoss` loss is used.
|
||||
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
|
||||
Default is `None`.
|
||||
:param float eps: The exponential decay parameter. Default is ``100``.
|
||||
:raises ValueError: If the problem is not a TimeDependentProblem.
|
||||
@@ -134,7 +134,7 @@ class CausalPINN(PINN):
|
||||
chunk.labels = labels
|
||||
# classical PINN loss
|
||||
residual = self.compute_residual(samples=chunk, equation=equation)
|
||||
loss_val = self.loss(
|
||||
loss_val = self._loss_fn(
|
||||
torch.zeros_like(residual, requires_grad=True), residual
|
||||
)
|
||||
time_loss.append(loss_val)
|
||||
|
||||
@@ -69,26 +69,26 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
:param AbstractProblem problem: The problem to be solved.
|
||||
:param torch.nn.Module model: The neural network model to be used.
|
||||
:param torch.nn.Module discriminator: The discriminator to be used.
|
||||
If `None`, the discriminator is a deepcopy of the ``model``.
|
||||
If ``None``, the discriminator is a deepcopy of the ``model``.
|
||||
Default is ``None``.
|
||||
:param torch.optim.Optimizer optimizer_model: The optimizer of the
|
||||
``model``. If `None`, the :class:`torch.optim.Adam` optimizer is
|
||||
``model``. If ``None``, the :class:`torch.optim.Adam` optimizer is
|
||||
used. Default is ``None``.
|
||||
:param torch.optim.Optimizer optimizer_discriminator: The optimizer of
|
||||
the ``discriminator``. If `None`, the :class:`torch.optim.Adam`
|
||||
the ``discriminator``. If ``None``, the :class:`torch.optim.Adam`
|
||||
optimizer is used. Default is ``None``.
|
||||
:param Scheduler scheduler_model: Learning rate scheduler for the
|
||||
``model``.
|
||||
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
scheduler is used. Default is ``None``.
|
||||
:param Scheduler scheduler_discriminator: Learning rate scheduler for
|
||||
the ``discriminator``.
|
||||
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
scheduler is used. Default is ``None``.
|
||||
:param WeightingInterface weighting: The weighting schema to be used.
|
||||
If `None`, no weighting schema is used. Default is ``None``.
|
||||
If ``None``, no weighting schema is used. Default is ``None``.
|
||||
:param torch.nn.Module loss: The loss function to be minimized.
|
||||
If `None`, the :class:`torch.nn.MSELoss` loss is used.
|
||||
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
|
||||
Default is `None`.
|
||||
"""
|
||||
if discriminator is None:
|
||||
@@ -156,12 +156,27 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
residual = residual * discriminator_bets
|
||||
|
||||
# Compute competitive residual.
|
||||
loss_val = self.loss(
|
||||
loss_val = self._loss_fn(
|
||||
torch.zeros_like(residual, requires_grad=True),
|
||||
residual,
|
||||
)
|
||||
return loss_val
|
||||
|
||||
def loss_data(self, input, target):
|
||||
"""
|
||||
Compute the data loss for the PINN solver by evaluating the loss
|
||||
between the network's output and the true solution. This method should
|
||||
not be overridden, if not intentionally.
|
||||
|
||||
:param input: The input to the neural network.
|
||||
:type input: LabelTensor
|
||||
:param target: The target to compare with the network's output.
|
||||
:type target: LabelTensor
|
||||
:return: The supervised loss, averaged over the number of observations.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
return self._loss_fn(self.forward(input), target)
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""
|
||||
Optimizer configuration.
|
||||
@@ -195,24 +210,6 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
],
|
||||
)
|
||||
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx):
|
||||
"""
|
||||
This method is called at the end of each training batch and overrides
|
||||
the PyTorch Lightning implementation to log checkpoints.
|
||||
|
||||
:param torch.Tensor outputs: The ``model``'s output for the current
|
||||
batch.
|
||||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
|
||||
tuple containing a condition name and a dictionary of points.
|
||||
:param int batch_idx: The index of the current batch.
|
||||
"""
|
||||
# increase by one the counter of optimization to save loggers
|
||||
(
|
||||
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed
|
||||
) += 1
|
||||
|
||||
return super().on_train_batch_end(outputs, batch, batch_idx)
|
||||
|
||||
@property
|
||||
def neural_net(self):
|
||||
"""
|
||||
|
||||
@@ -75,15 +75,15 @@ class GradientPINN(PINN):
|
||||
gradient of the loss.
|
||||
:param torch.nn.Module model: The neural network model to be used.
|
||||
:param Optimizer optimizer: The optimizer to be used.
|
||||
If `None`, the :class:`torch.optim.Adam` optimizer is used.
|
||||
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
|
||||
Default is ``None``.
|
||||
:param Scheduler scheduler: Learning rate scheduler.
|
||||
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
scheduler is used. Default is ``None``.
|
||||
:param WeightingInterface weighting: The weighting schema to be used.
|
||||
If `None`, no weighting schema is used. Default is ``None``.
|
||||
If ``None``, no weighting schema is used. Default is ``None``.
|
||||
:param torch.nn.Module loss: The loss function to be minimized.
|
||||
If `None`, the :class:`torch.nn.MSELoss` loss is used.
|
||||
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
|
||||
Default is `None`.
|
||||
:raises ValueError: If the problem is not a SpatialProblem.
|
||||
"""
|
||||
@@ -116,7 +116,7 @@ class GradientPINN(PINN):
|
||||
"""
|
||||
# classical PINN loss
|
||||
residual = self.compute_residual(samples=samples, equation=equation)
|
||||
loss_value = self.loss(
|
||||
loss_value = self._loss_fn(
|
||||
torch.zeros_like(residual, requires_grad=True), residual
|
||||
)
|
||||
|
||||
@@ -124,7 +124,7 @@ class GradientPINN(PINN):
|
||||
loss_value = loss_value.reshape(-1, 1)
|
||||
loss_value.labels = ["__loss"]
|
||||
loss_grad = grad(loss_value, samples, d=self.problem.spatial_variables)
|
||||
g_loss_phys = self.loss(
|
||||
g_loss_phys = self._loss_fn(
|
||||
torch.zeros_like(loss_grad, requires_grad=True), loss_grad
|
||||
)
|
||||
return loss_value + g_loss_phys
|
||||
|
||||
@@ -62,15 +62,15 @@ class PINN(PINNInterface, SingleSolverInterface):
|
||||
:param AbstractProblem problem: The problem to be solved.
|
||||
:param torch.nn.Module model: The neural network model to be used.
|
||||
:param Optimizer optimizer: The optimizer to be used.
|
||||
If `None`, the :class:`torch.optim.Adam` optimizer is used.
|
||||
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
|
||||
Default is ``None``.
|
||||
:param Scheduler scheduler: Learning rate scheduler.
|
||||
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
scheduler is used. Default is ``None``.
|
||||
:param WeightingInterface weighting: The weighting schema to be used.
|
||||
If `None`, no weighting schema is used. Default is ``None``.
|
||||
If ``None``, no weighting schema is used. Default is ``None``.
|
||||
:param torch.nn.Module loss: The loss function to be minimized.
|
||||
If `None`, the :class:`torch.nn.MSELoss` loss is used.
|
||||
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
|
||||
Default is `None`.
|
||||
"""
|
||||
super().__init__(
|
||||
@@ -82,6 +82,21 @@ class PINN(PINNInterface, SingleSolverInterface):
|
||||
loss=loss,
|
||||
)
|
||||
|
||||
def loss_data(self, input, target):
|
||||
"""
|
||||
Compute the data loss for the PINN solver by evaluating the loss
|
||||
between the network's output and the true solution. This method should
|
||||
not be overridden, if not intentionally.
|
||||
|
||||
:param input: The input to the neural network.
|
||||
:type input: LabelTensor
|
||||
:param target: The target to compare with the network's output.
|
||||
:type target: LabelTensor
|
||||
:return: The supervised loss, averaged over the number of observations.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
return self._loss_fn(self.forward(input), target)
|
||||
|
||||
def loss_phys(self, samples, equation):
|
||||
"""
|
||||
Computes the physics loss for the physics-informed solver based on the
|
||||
@@ -92,11 +107,8 @@ class PINN(PINNInterface, SingleSolverInterface):
|
||||
:return: The computed physics loss.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
residual = self.compute_residual(samples=samples, equation=equation)
|
||||
loss_value = self.loss(
|
||||
torch.zeros_like(residual, requires_grad=True), residual
|
||||
)
|
||||
return loss_value
|
||||
residuals = self.compute_residual(samples, equation)
|
||||
return self._loss_fn(residuals, torch.zeros_like(residuals))
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""
|
||||
|
||||
@@ -38,7 +38,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
|
||||
:param AbstractProblem problem: The problem to be solved.
|
||||
:param torch.nn.Module loss: The loss function to be minimized.
|
||||
If `None`, the :class:`torch.nn.MSELoss` loss is used.
|
||||
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
|
||||
Default is `None`.
|
||||
:param kwargs: Additional keyword arguments to be passed to the
|
||||
:class:`~pina.solver.solver.SolverInterface` class.
|
||||
@@ -53,7 +53,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
check_consistency(loss, (LossInterface, _Loss), subclass=False)
|
||||
|
||||
# assign variables
|
||||
self._loss = loss
|
||||
self._loss_fn = loss
|
||||
|
||||
# inverse problem handling
|
||||
if isinstance(self.problem, InverseProblem):
|
||||
@@ -65,7 +65,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
|
||||
self.__metric = None
|
||||
|
||||
def optimization_cycle(self, batch):
|
||||
def optimization_cycle(self, batch, loss_residuals=None):
|
||||
"""
|
||||
The optimization cycle for the PINN solver.
|
||||
|
||||
@@ -80,51 +80,74 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
containing the condition name and the associated scalar loss.
|
||||
:rtype: dict
|
||||
"""
|
||||
return self._run_optimization_cycle(batch, self.loss_phys)
|
||||
# which losses to use
|
||||
if loss_residuals is None:
|
||||
loss_residuals = self.loss_phys
|
||||
# compute optimization cycle
|
||||
condition_loss = {}
|
||||
for condition_name, points in batch:
|
||||
self.__metric = condition_name
|
||||
# if equations are passed
|
||||
if "target" not in points:
|
||||
input_pts = points["input"]
|
||||
condition = self.problem.conditions[condition_name]
|
||||
loss = loss_residuals(
|
||||
input_pts.requires_grad_(), condition.equation
|
||||
)
|
||||
# if data are passed
|
||||
else:
|
||||
input_pts = points["input"]
|
||||
output_pts = points["target"]
|
||||
loss = self.loss_data(
|
||||
input=input_pts.requires_grad_(), target=output_pts
|
||||
)
|
||||
# append loss
|
||||
condition_loss[condition_name] = loss
|
||||
# clamp unknown parameters in InverseProblem (if needed)
|
||||
self._clamp_params()
|
||||
return condition_loss
|
||||
|
||||
@torch.set_grad_enabled(True)
|
||||
def validation_step(self, batch):
|
||||
"""
|
||||
The validation step for the PINN solver.
|
||||
The validation step for the PINN solver. It returns the average residual
|
||||
computed with the ``loss`` function not aggregated.
|
||||
|
||||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
|
||||
tuple containing a condition name and a dictionary of points.
|
||||
:return: The loss of the validation step.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
losses = self._run_optimization_cycle(batch, self._residual_loss)
|
||||
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
|
||||
self.store_log("val_loss", loss, self.get_batch_size(batch))
|
||||
return loss
|
||||
return super().validation_step(
|
||||
batch, loss_residuals=self._residual_loss
|
||||
)
|
||||
|
||||
@torch.set_grad_enabled(True)
|
||||
def test_step(self, batch):
|
||||
"""
|
||||
The test step for the PINN solver.
|
||||
The test step for the PINN solver. It returns the average residual
|
||||
computed with the ``loss`` function not aggregated.
|
||||
|
||||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
|
||||
tuple containing a condition name and a dictionary of points.
|
||||
:return: The loss of the test step.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
losses = self._run_optimization_cycle(batch, self._residual_loss)
|
||||
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
|
||||
self.store_log("test_loss", loss, self.get_batch_size(batch))
|
||||
return loss
|
||||
return super().test_step(batch, loss_residuals=self._residual_loss)
|
||||
|
||||
def loss_data(self, input_pts, output_pts):
|
||||
@abstractmethod
|
||||
def loss_data(self, input, target):
|
||||
"""
|
||||
Compute the data loss for the PINN solver by evaluating the loss
|
||||
between the network's output and the true solution. This method should
|
||||
not be overridden, if not intentionally.
|
||||
be overridden by the derived class.
|
||||
|
||||
:param LabelTensor input_pts: The input points to the neural network.
|
||||
:param LabelTensor output_pts: The true solution to compare with the
|
||||
:param LabelTensor input: The input to the neural network.
|
||||
:param LabelTensor target: The target to compare with the
|
||||
network's output.
|
||||
:return: The supervised loss, averaged over the number of observations.
|
||||
:rtype: torch.Tensor
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
return self._loss(self.forward(input_pts), output_pts)
|
||||
|
||||
@abstractmethod
|
||||
def loss_phys(self, samples, equation):
|
||||
@@ -159,7 +182,11 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
|
||||
def _residual_loss(self, samples, equation):
|
||||
"""
|
||||
Compute the residual loss.
|
||||
Computes the physics loss for the physics-informed solver based on the
|
||||
provided samples and equation. This method should never be overridden
|
||||
by the user, if not intentionally,
|
||||
since it is used internally to compute validation loss.
|
||||
|
||||
|
||||
:param LabelTensor samples: The samples to evaluate the loss.
|
||||
:param EquationInterface equation: The governing equation.
|
||||
@@ -167,43 +194,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
residuals = self.compute_residual(samples, equation)
|
||||
return self.loss(residuals, torch.zeros_like(residuals))
|
||||
|
||||
def _run_optimization_cycle(self, batch, loss_residuals):
|
||||
"""
|
||||
Compute, given a batch, the loss for each condition and return a
|
||||
dictionary with the condition name as key and the loss as value.
|
||||
|
||||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
|
||||
tuple containing a condition name and a dictionary of points.
|
||||
:param function loss_residuals: The loss function to be minimized.
|
||||
:return: The losses computed for all conditions in the batch, casted
|
||||
to a subclass of :class:`torch.Tensor`. It should return a dict
|
||||
containing the condition name and the associated scalar loss.
|
||||
:rtype: dict
|
||||
"""
|
||||
condition_loss = {}
|
||||
for condition_name, points in batch:
|
||||
self.__metric = condition_name
|
||||
# if equations are passed
|
||||
if "target" not in points:
|
||||
input_pts = points["input"]
|
||||
condition = self.problem.conditions[condition_name]
|
||||
loss = loss_residuals(
|
||||
input_pts.requires_grad_(), condition.equation
|
||||
)
|
||||
# if data are passed
|
||||
else:
|
||||
input_pts = points["input"]
|
||||
output_pts = points["target"]
|
||||
loss = self.loss_data(
|
||||
input_pts=input_pts.requires_grad_(), output_pts=output_pts
|
||||
)
|
||||
# append loss
|
||||
condition_loss[condition_name] = loss
|
||||
# clamp unknown parameters in InverseProblem (if needed)
|
||||
self._clamp_params()
|
||||
return condition_loss
|
||||
return self._loss_fn(residuals, torch.zeros_like(residuals))
|
||||
|
||||
def _clamp_inverse_problem_params(self):
|
||||
"""
|
||||
@@ -223,7 +214,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
:return: The loss function used for training.
|
||||
:rtype: torch.nn.Module
|
||||
"""
|
||||
return self._loss
|
||||
return self._loss_fn
|
||||
|
||||
@property
|
||||
def current_condition_name(self):
|
||||
|
||||
@@ -83,15 +83,15 @@ class RBAPINN(PINN):
|
||||
:param AbstractProblem problem: The problem to be solved.
|
||||
:param torch.nn.Module model: The neural network model to be used.
|
||||
:param Optimizer optimizer: The optimizer to be used.
|
||||
If `None`, the :class:`torch.optim.Adam` optimizer is used.
|
||||
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
|
||||
Default is ``None``.
|
||||
:param Scheduler scheduler: Learning rate scheduler.
|
||||
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
scheduler is used. Default is ``None``.
|
||||
:param WeightingInterface weighting: The weighting schema to be used.
|
||||
If `None`, no weighting schema is used. Default is ``None``.
|
||||
If ``None``, no weighting schema is used. Default is ``None``.
|
||||
:param torch.nn.Module loss: The loss function to be minimized.
|
||||
If `None`, the :class:`torch.nn.MSELoss` loss is used.
|
||||
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
|
||||
Default is `None`.
|
||||
:param float | int eta: The learning rate for the weights of the
|
||||
residuals. Default is ``0.001``.
|
||||
|
||||
@@ -120,24 +120,24 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
|
||||
:param torch.nn.Module weight_function: The Self-Adaptive mask model.
|
||||
Default is ``torch.nn.Sigmoid()``.
|
||||
:param Optimizer optimizer_model: The optimizer of the ``model``.
|
||||
If `None`, the :class:`torch.optim.Adam` optimizer is used.
|
||||
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
|
||||
Default is ``None``.
|
||||
:param Optimizer optimizer_weights: The optimizer of the
|
||||
``weight_function``.
|
||||
If `None`, the :class:`torch.optim.Adam` optimizer is used.
|
||||
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
|
||||
Default is ``None``.
|
||||
:param Scheduler scheduler_model: Learning rate scheduler for the
|
||||
``model``.
|
||||
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
scheduler is used. Default is ``None``.
|
||||
:param Scheduler scheduler_weights: Learning rate scheduler for the
|
||||
``weight_function``.
|
||||
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
scheduler is used. Default is ``None``.
|
||||
:param WeightingInterface weighting: The weighting schema to be used.
|
||||
If `None`, no weighting schema is used. Default is ``None``.
|
||||
If ``None``, no weighting schema is used. Default is ``None``.
|
||||
:param torch.nn.Module loss: The loss function to be minimized.
|
||||
If `None`, the :class:`torch.nn.MSELoss` loss is used.
|
||||
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
|
||||
Default is `None`.
|
||||
"""
|
||||
# check consistency weitghs_function
|
||||
@@ -223,24 +223,6 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
|
||||
[self.scheduler_model.instance, self.scheduler_weights.instance],
|
||||
)
|
||||
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx):
|
||||
"""
|
||||
This method is called at the end of each training batch and overrides
|
||||
the PyTorch Lightning implementation to log checkpoints.
|
||||
|
||||
:param torch.Tensor outputs: The ``model``'s output for the current
|
||||
batch.
|
||||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
|
||||
tuple containing a condition name and a dictionary of points.
|
||||
:param int batch_idx: The index of the current batch.
|
||||
"""
|
||||
# increase by one the counter of optimization to save loggers
|
||||
(
|
||||
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed
|
||||
) += 1
|
||||
|
||||
return super().on_train_batch_end(outputs, batch, batch_idx)
|
||||
|
||||
def on_train_start(self):
|
||||
"""
|
||||
This method is called at the start of the training process to set the
|
||||
@@ -304,6 +286,21 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
|
||||
)
|
||||
return self._vect_to_scalar(weights * loss_value)
|
||||
|
||||
def loss_data(self, input, target):
|
||||
"""
|
||||
Compute the data loss for the PINN solver by evaluating the loss
|
||||
between the network's output and the true solution. This method should
|
||||
not be overridden, if not intentionally.
|
||||
|
||||
:param input: The input to the neural network.
|
||||
:type input: LabelTensor
|
||||
:param target: The target to compare with the network's output.
|
||||
:type target: LabelTensor
|
||||
:return: The supervised loss, averaged over the number of observations.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
return self._loss_fn(self.forward(input), target)
|
||||
|
||||
def _vect_to_scalar(self, loss_value):
|
||||
"""
|
||||
Computation of the scalar loss.
|
||||
|
||||
Reference in New Issue
Block a user