add batching support for self-adaptive PINNs
committed by Giovanni Canali
parent 1ed14916f1
commit 6d1d4ef423
@@ -15,15 +15,20 @@ class Weights(torch.nn.Module):
     :class:`SelfAdaptivePINN` solver.
     """
 
-    def __init__(self, func):
+    def __init__(self, func, num_points):
         """
         Initialization of the :class:`Weights` class.
 
         :param torch.nn.Module func: the mask model.
+        :param int num_points: the number of input points.
         """
         super().__init__()
+
+        # Check consistency
         check_consistency(func, torch.nn.Module)
-        self.sa_weights = torch.nn.Parameter(torch.Tensor())
+
+        # Initialize the weights as a learnable parameter
+        self.sa_weights = torch.nn.Parameter(torch.zeros(num_points, 1))
         self.func = func
 
     def forward(self):
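The change above sizes the self-adaptive weight tensor at construction time (`num_points`) instead of allocating an empty parameter and filling it in later, which is what makes per-batch indexing possible. A minimal sketch of the pattern, assuming only `torch`; the name `PointWeights` and the sigmoid mask are illustrative, not the library's API:

```python
import torch

class PointWeights(torch.nn.Module):
    """Illustrative stand-in for the patched Weights module."""

    def __init__(self, func, num_points):
        super().__init__()
        # One learnable weight per collocation point, allocated up front
        # so the tensor can later be indexed batch by batch.
        self.sa_weights = torch.nn.Parameter(torch.zeros(num_points, 1))
        self.func = func

    def forward(self):
        # The mask function maps raw weights into a useful range.
        return self.func(self.sa_weights)

w = PointWeights(func=torch.nn.Sigmoid(), num_points=100)
print(w().shape)  # torch.Size([100, 1])
```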
@@ -140,17 +145,17 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
             If ``None``, the :class:`torch.nn.MSELoss` loss is used.
             Default is `None`.
         """
-        # check consistency weitghs_function
+        # Check consistency
         check_consistency(weight_function, torch.nn.Module)
 
-        # create models for weights
-        weights_dict = {}
-        for condition_name in problem.conditions:
-            weights_dict[condition_name] = Weights(weight_function)
-        weights_dict = torch.nn.ModuleDict(weights_dict)
+        # Define a ModuleDict for the weights
+        weights = {}
+        for cond, data in problem.input_pts.items():
+            weights[cond] = Weights(func=weight_function, num_points=len(data))
+        weights = torch.nn.ModuleDict(weights)
 
         super().__init__(
-            models=[model, weights_dict],
+            models=[model, weights],
             problem=problem,
             optimizers=[optimizer_model, optimizer_weights],
             schedulers=[scheduler_model, scheduler_weights],
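The constructor now builds one such module per condition, sized to that condition's point count, and registers them in a `torch.nn.ModuleDict` so the weight optimizer sees every per-point parameter. A sketch, reusing the hypothetical `PointWeights` from above and a made-up `input_pts` dict:

```python
import torch

# Made-up {condition: points} dict standing in for problem.input_pts.
input_pts = {"D": torch.rand(100, 2), "gamma": torch.rand(40, 2)}

weights = torch.nn.ModuleDict({
    cond: PointWeights(func=torch.nn.Sigmoid(), num_points=len(data))
    for cond, data in input_pts.items()
})

# 100 + 40 learnable weights, one per training point.
print(sum(p.numel() for p in weights.parameters()))  # 140
```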
@@ -158,116 +163,93 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
             loss=loss,
         )
 
-        self._vectorial_loss = deepcopy(self.loss)
-        self._vectorial_loss.reduction = "none"
+        # Extract the reduction method from the loss function
+        self._reduction = self._loss_fn.reduction
 
-    def forward(self, x):
-        """
-        Forward pass.
-
-        :param LabelTensor x: Input tensor.
-        :return: The output of the neural network.
-        :rtype: LabelTensor
-        """
-        return self.model(x)
+        # Set the loss function to return non-aggregated losses
+        self._loss_fn = type(self._loss_fn)(reduction="none")
 
-    def training_step(self, batch):
+    def training_step(self, batch, batch_idx, **kwargs):
         """
-        Solver training step, overridden to perform manual optimization.
+        Solver training step. It computes the optimization cycle and aggregates
+        the losses using the ``weighting`` attribute.
 
         :param list[tuple[str, dict]] batch: A batch of data. Each element is a
             tuple containing a condition name and a dictionary of points.
-        :return: The aggregated loss.
-        :rtype: LabelTensor
+        :param int batch_idx: The index of the current batch.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the training step.
+        :rtype: torch.Tensor
         """
         # Weights optimization
         self.optimizer_weights.instance.zero_grad()
-        loss = super().training_step(batch)
+        loss = self._optimization_cycle(
+            batch=batch, batch_idx=batch_idx, **kwargs
+        )
         self.manual_backward(-loss)
         self.optimizer_weights.instance.step()
         self.scheduler_weights.instance.step()
 
         # Model optimization
         self.optimizer_model.instance.zero_grad()
-        loss = super().training_step(batch)
+        loss = self._optimization_cycle(
+            batch=batch, batch_idx=batch_idx, **kwargs
+        )
         self.manual_backward(loss)
         self.optimizer_model.instance.step()
         self.scheduler_model.instance.step()
 
+        # Log the loss
+        self.store_log("train_loss", loss, self.get_batch_size(batch))
+
         return loss
 
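The step keeps the SA-PINN saddle-point scheme: backward on `-loss` performs gradient ascent on the weights (hard points gain influence), then a fresh forward/backward takes a descent step on the model; the second `zero_grad` discards the model gradients left over from the ascent pass. A self-contained toy version of that two-optimizer pattern, assuming plain PyTorch and an invented quadratic loss:

```python
import torch

model = torch.nn.Linear(1, 1)
sa_weights = torch.nn.Parameter(torch.rand(8, 1))
opt_model = torch.optim.Adam(model.parameters(), lr=1e-3)
opt_weights = torch.optim.Adam([sa_weights], lr=1e-3)
x = torch.rand(8, 1)

def weighted_loss():
    # Stand-in for _optimization_cycle: per-point residuals times weights.
    residual = model(x)  # pretend this is a PDE residual
    return (torch.sigmoid(sa_weights) * residual.pow(2)).mean()

# Ascent on the weights: backward on -loss flips the gradient sign.
opt_weights.zero_grad()
(-weighted_loss()).backward()
opt_weights.step()

# Descent on the model, recomputing the loss with the updated weights.
opt_model.zero_grad()
weighted_loss().backward()
opt_model.step()
```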
-    def configure_optimizers(self):
+    @torch.set_grad_enabled(True)
+    def validation_step(self, batch, **kwargs):
         """
-        Optimizer configuration.
+        The validation step for the Self-Adaptive PINN solver. It returns the
+        average residual computed with the ``loss`` function not aggregated.
 
-        :return: The optimizers and the schedulers
-        :rtype: tuple[list[Optimizer], list[Scheduler]]
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the validation step.
+        :rtype: torch.Tensor
         """
-        # If the problem is an InverseProblem, add the unknown parameters
-        # to the parameters to be optimized
-        self.optimizer_model.hook(self.model.parameters())
-        self.optimizer_weights.hook(self.weights_dict.parameters())
-        if isinstance(self.problem, InverseProblem):
-            self.optimizer_model.instance.add_param_group(
-                {
-                    "params": [
-                        self._params[var]
-                        for var in self.problem.unknown_variables
-                    ]
-                }
-            )
-        self.scheduler_model.hook(self.optimizer_model)
-        self.scheduler_weights.hook(self.optimizer_weights)
-        return (
-            [self.optimizer_model.instance, self.optimizer_weights.instance],
-            [self.scheduler_model.instance, self.scheduler_weights.instance],
-        )
+        losses = self.optimization_cycle(batch=batch, **kwargs)
 
-    def on_train_start(self):
+        # Aggregate losses for each condition
+        for cond, loss in losses.items():
+            losses[cond] = self._apply_reduction(loss=losses[cond])
+
+        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
+        self.store_log("val_loss", loss, self.get_batch_size(batch))
+        return loss
+
+    @torch.set_grad_enabled(True)
+    def test_step(self, batch, **kwargs):
         """
-        This method is called at the start of the training process to set the
-        self-adaptive weights as parameters of the mask model.
+        The test step for the Self-Adaptive PINN solver. It returns the average
+        residual computed with the ``loss`` function not aggregated.
 
-        :raises NotImplementedError: If the batch size is not ``None``.
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the test step.
+        :rtype: torch.Tensor
         """
-        if self.trainer.batch_size is not None:
-            raise NotImplementedError(
-                "SelfAdaptivePINN only works with full "
-                "batch size, set batch_size=None inside "
-                "the Trainer to use the solver."
-            )
-        device = torch.device(
-            self.trainer._accelerator_connector._accelerator_flag
-        )
+        losses = self.optimization_cycle(batch=batch, **kwargs)
 
-        # Initialize the self adaptive weights only for training points
-        for (
-            condition_name,
-            tensor,
-        ) in self.trainer.data_module.train_dataset.input.items():
-            self.weights_dict[condition_name].sa_weights.data = torch.rand(
-                (tensor.shape[0], 1), device=device
-            )
-        return super().on_train_start()
+        # Aggregate losses for each condition
+        for cond, loss in losses.items():
+            losses[cond] = self._apply_reduction(loss=losses[cond])
 
-    def on_load_checkpoint(self, checkpoint):
-        """
-        Override of the Pytorch Lightning ``on_load_checkpoint`` method to
-        handle checkpoints for Self-Adaptive Weights. This method should not be
-        overridden, if not intentionally.
-
-        :param dict checkpoint: Pytorch Lightning checkpoint dict.
-        """
-        # First initialize self-adaptive weights with correct shape,
-        # then load the values from the checkpoint.
-        for condition_name, _ in self.problem.input_pts.items():
-            shape = checkpoint["state_dict"][
-                f"_pina_models.1.{condition_name}.sa_weights"
-            ].shape
-            self.weights_dict[condition_name].sa_weights.data = torch.rand(
-                shape
-            )
-        return super().on_load_checkpoint(checkpoint)
+        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
+        self.store_log("test_loss", loss, self.get_batch_size(batch))
+        return loss
 
     def loss_phys(self, samples, equation):
         """
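All three steps rely on the constructor change at the top of the hunk: `type(self._loss_fn)(reduction="none")` re-instantiates the user's loss with no reduction, so pointwise losses survive long enough to be weighted, and the remembered reduction is applied by hand afterwards. A minimal sketch of that deferred-reduction idea, assuming an MSE loss:

```python
import torch

loss_fn = torch.nn.MSELoss()               # user-configured loss
reduction = loss_fn.reduction              # remember "mean"
loss_fn = type(loss_fn)(reduction="none")  # same class, no aggregation

pred, target = torch.rand(5, 1), torch.zeros(5, 1)
pointwise = loss_fn(pred, target)          # shape (5, 1), one loss per point

sa_weights = torch.rand(5, 1)
weighted = sa_weights * pointwise          # weights applied pointwise
reduced = weighted.mean() if reduction == "mean" else weighted.sum()
print(pointwise.shape, reduced.shape)      # torch.Size([5, 1]) torch.Size([])
```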
@@ -279,47 +261,138 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
         :return: The computed physics loss.
         :rtype: LabelTensor
         """
-        residual = self.compute_residual(samples, equation)
-        weights = self.weights_dict[self.current_condition_name].forward()
-        loss_value = self._vectorial_loss(
-            torch.zeros_like(residual, requires_grad=True), residual
-        )
-        return self._vect_to_scalar(weights * loss_value)
+        residuals = self.compute_residual(samples, equation)
+        return self._loss_fn(residuals, torch.zeros_like(residuals))
 
     def loss_data(self, input, target):
         """
-        Compute the data loss for the PINN solver by evaluating the loss
-        between the network's output and the true solution. This method should
-        not be overridden, if not intentionally.
+        Compute the data loss for the Self-Adaptive PINN solver by evaluating
+        the loss between the network's output and the true solution. This method
+        should not be overridden, if not intentionally.
 
         :param input: The input to the neural network.
-        :type input: LabelTensor
+        :type input: LabelTensor | torch.Tensor
         :param target: The target to compare with the network's output.
-        :type target: LabelTensor
+        :type target: LabelTensor | torch.Tensor
         :return: The supervised loss, averaged over the number of observations.
-        :rtype: LabelTensor
+        :rtype: LabelTensor | torch.Tensor
         """
         return self._loss_fn(self.forward(input), target)
 
-    def _vect_to_scalar(self, loss_value):
+    def forward(self, x):
         """
-        Computation of the scalar loss.
+        Forward pass.
 
-        :param LabelTensor loss_value: the tensor of pointwise losses.
-        :raises RuntimeError: If the loss reduction is not ``mean`` or ``sum``.
-        :return: The computed scalar loss.
-        :rtype: LabelTensor
+        :param x: Input tensor.
+        :type x: torch.Tensor | LabelTensor
+        :return: The output of the neural network.
+        :rtype: torch.Tensor | LabelTensor
         """
-        if self.loss.reduction == "mean":
-            ret = torch.mean(loss_value)
-        elif self.loss.reduction == "sum":
-            ret = torch.sum(loss_value)
-        else:
-            raise RuntimeError(
-                f"Invalid reduction, got {self.loss.reduction} "
-                "but expected mean or sum."
-            )
-        return ret
+        return self.model(x)
+
+    def configure_optimizers(self):
+        """
+        Optimizer configuration.
+
+        :return: The optimizers and the schedulers
+        :rtype: tuple[list[Optimizer], list[Scheduler]]
+        """
+        # Hook the optimizers to the models
+        self.optimizer_model.hook(self.model.parameters())
+        self.optimizer_weights.hook(self.weights.parameters())
+
+        # Add unknown parameters to optimization list in case of InverseProblem
+        if isinstance(self.problem, InverseProblem):
+            self.optimizer_model.instance.add_param_group(
+                {
+                    "params": [
+                        self._params[var]
+                        for var in self.problem.unknown_variables
+                    ]
+                }
+            )
+
+        # Hook the schedulers to the optimizers
+        self.scheduler_model.hook(self.optimizer_model)
+        self.scheduler_weights.hook(self.optimizer_weights)
+
+        return (
+            [self.optimizer_model.instance, self.optimizer_weights.instance],
+            [self.scheduler_model.instance, self.scheduler_weights.instance],
+        )
+
+    def _optimization_cycle(self, batch, batch_idx, **kwargs):
+        """
+        Aggregate the loss for each condition in the batch.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param int batch_idx: The index of the current batch.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The losses computed for all conditions in the batch, casted
+            to a subclass of :class:`torch.Tensor`. It should return a dict
+            containing the condition name and the associated scalar loss.
+        :rtype: dict
+        """
+        # Compute non-aggregated residuals
+        residuals = self.optimization_cycle(batch)
+
+        # Compute losses
+        losses = {}
+        for cond, res in residuals.items():
+
+            weight_tensor = self.weights[cond]()
+
+            # Get the correct indices for the weights. Modulus is used according
+            # to the number of points in the condition, as in the PinaDataset.
+            len_res = len(res)
+            idx = torch.arange(
+                batch_idx * len_res,
+                (batch_idx + 1) * len_res,
+                device=res.device,
+            ) % len(self.problem.input_pts[cond])
+
+            # Apply the weights to the residuals
+            losses[cond] = self._apply_reduction(
+                loss=(res * weight_tensor[idx])
+            )
+
+            # Store log
+            self.store_log(
+                f"{cond}_loss", losses[cond].item(), self.get_batch_size(batch)
+            )
+
+        # Clamp unknown parameters in InverseProblem (if needed)
+        self._clamp_params()
+
+        # Aggregate
+        loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
+
+        return loss
+
+    def _apply_reduction(self, loss):
+        """
+        Apply the specified reduction to the loss. The reduction is deferred
+        until the end of the optimization cycle to allow self-adaptive weights
+        to be applied to each point beforehand.
+
+        :param torch.Tensor loss: The loss tensor to be reduced.
+        :return: The reduced loss tensor.
+        :rtype: torch.Tensor
+        :raises ValueError: If the reduction method is neither "mean" nor "sum".
+        """
+        # Apply the specified reduction method
+        if self._reduction == "mean":
+            return loss.mean()
+        if self._reduction == "sum":
+            return loss.sum()
+
+        # Raise an error if the reduction method is not recognized
+        raise ValueError(
+            f"Unknown reduction: {self._reduction}."
+            " Supported reductions are 'mean' and 'sum'."
+        )
 
     @property
     def model(self):
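`_optimization_cycle` is where batching meets the weights: `batch_idx` locates the slice of stored points the current batch came from, and the modulus wraps the index range when the sampler cycles past the end of a condition's points. A small numeric illustration with made-up sizes:

```python
import torch

num_points = 10  # points stored for one condition
batch_size = 4   # residuals delivered per batch

for batch_idx in range(3):
    idx = torch.arange(
        batch_idx * batch_size, (batch_idx + 1) * batch_size
    ) % num_points
    print(batch_idx, idx.tolist())

# 0 [0, 1, 2, 3]
# 1 [4, 5, 6, 7]
# 2 [8, 9, 0, 1]  <- wraps around, matching the dataset's cycling
```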
@@ -332,7 +405,7 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
         return self.models[0]
 
     @property
-    def weights_dict(self):
+    def weights(self):
         """
         The self-adaptive weights.