From 6d1d4ef423bf4b5c6d73bc71e5d69cbacc326205 Mon Sep 17 00:00:00 2001
From: Giovanni Canali
Date: Tue, 22 Jul 2025 14:45:43 +0200
Subject: [PATCH] add batching support for self-adaptive pinns

---
 .../self_adaptive_pinn.py                     | 307 +++++++++++-------
 tests/test_solver/test_self_adaptive_pinn.py  |  41 +--
 2 files changed, 208 insertions(+), 140 deletions(-)

diff --git a/pina/solver/physics_informed_solver/self_adaptive_pinn.py b/pina/solver/physics_informed_solver/self_adaptive_pinn.py
index 9521556..b1d2a2c 100644
--- a/pina/solver/physics_informed_solver/self_adaptive_pinn.py
+++ b/pina/solver/physics_informed_solver/self_adaptive_pinn.py
@@ -15,15 +15,20 @@ class Weights(torch.nn.Module):
     :class:`SelfAdaptivePINN` solver.
     """
 
-    def __init__(self, func):
+    def __init__(self, func, num_points):
         """
         Initialization of the :class:`Weights` class.
 
         :param torch.nn.Module func: the mask model.
+        :param int num_points: the number of input points.
         """
         super().__init__()
+
+        # Check consistency
         check_consistency(func, torch.nn.Module)
-        self.sa_weights = torch.nn.Parameter(torch.Tensor())
+
+        # Initialize the weights as a learnable parameter
+        self.sa_weights = torch.nn.Parameter(torch.zeros(num_points, 1))
         self.func = func
 
     def forward(self):
@@ -140,17 +145,17 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
             If ``None``, the :class:`torch.nn.MSELoss` loss is used.
             Default is `None`.
         """
-        # check consistency weitghs_function
+        # Check consistency
        check_consistency(weight_function, torch.nn.Module)
 
-        # create models for weights
-        weights_dict = {}
-        for condition_name in problem.conditions:
-            weights_dict[condition_name] = Weights(weight_function)
-        weights_dict = torch.nn.ModuleDict(weights_dict)
+        # Define a ModuleDict for the weights
+        weights = {}
+        for cond, data in problem.input_pts.items():
+            weights[cond] = Weights(func=weight_function, num_points=len(data))
+        weights = torch.nn.ModuleDict(weights)
 
         super().__init__(
-            models=[model, weights_dict],
+            models=[model, weights],
             problem=problem,
             optimizers=[optimizer_model, optimizer_weights],
             schedulers=[scheduler_model, scheduler_weights],
             weighting=weighting,
             loss=loss,
         )
 
@@ -158,116 +163,93 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
-        self._vectorial_loss = deepcopy(self.loss)
-        self._vectorial_loss.reduction = "none"
+        # Extract the reduction method from the loss function
+        self._reduction = self._loss_fn.reduction
 
-    def forward(self, x):
-        """
-        Forward pass.
+        # Set the loss function to return non-aggregated losses
+        self._loss_fn = type(self._loss_fn)(reduction="none")
 
-        :param LabelTensor x: Input tensor.
-        :return: The output of the neural network.
-        :rtype: LabelTensor
+    def training_step(self, batch, batch_idx, **kwargs):
         """
-        return self.model(x)
-
-    def training_step(self, batch):
-        """
-        Solver training step, overridden to perform manual optimization.
+        Solver training step. It runs the optimization cycle and aggregates
+        the losses using the ``weighting`` attribute.
 
         :param list[tuple[str, dict]] batch: A batch of data. Each element is a
            tuple containing a condition name and a dictionary of points.
-        :return: The aggregated loss.
-        :rtype: LabelTensor
+        :param int batch_idx: The index of the current batch.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the training step.
+        :rtype: torch.Tensor
         """
         # Weights optimization
         self.optimizer_weights.instance.zero_grad()
-        loss = super().training_step(batch)
+        loss = self._optimization_cycle(
+            batch=batch, batch_idx=batch_idx, **kwargs
+        )
         self.manual_backward(-loss)
         self.optimizer_weights.instance.step()
         self.scheduler_weights.instance.step()
 
         # Model optimization
         self.optimizer_model.instance.zero_grad()
-        loss = super().training_step(batch)
+        loss = self._optimization_cycle(
+            batch=batch, batch_idx=batch_idx, **kwargs
+        )
         self.manual_backward(loss)
         self.optimizer_model.instance.step()
         self.scheduler_model.instance.step()
 
+        # Log the loss
+        self.store_log("train_loss", loss, self.get_batch_size(batch))
+
         return loss
 
-    def configure_optimizers(self):
+    @torch.set_grad_enabled(True)
+    def validation_step(self, batch, **kwargs):
         """
-        Optimizer configuration.
+        The validation step for the Self-Adaptive PINN solver. It returns the
+        average residual computed with the non-aggregated ``loss`` function.
 
-        :return: The optimizers and the schedulers
-        :rtype: tuple[list[Optimizer], list[Scheduler]]
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the validation step.
+        :rtype: torch.Tensor
         """
-        # If the problem is an InverseProblem, add the unknown parameters
-        # to the parameters to be optimized
-        self.optimizer_model.hook(self.model.parameters())
-        self.optimizer_weights.hook(self.weights_dict.parameters())
-        if isinstance(self.problem, InverseProblem):
-            self.optimizer_model.instance.add_param_group(
-                {
-                    "params": [
-                        self._params[var]
-                        for var in self.problem.unknown_variables
-                    ]
-                }
-            )
-        self.scheduler_model.hook(self.optimizer_model)
-        self.scheduler_weights.hook(self.optimizer_weights)
-        return (
-            [self.optimizer_model.instance, self.optimizer_weights.instance],
-            [self.scheduler_model.instance, self.scheduler_weights.instance],
-        )
+        losses = self.optimization_cycle(batch=batch, **kwargs)
 
-    def on_train_start(self):
+        # Aggregate losses for each condition
+        for cond, loss in losses.items():
+            losses[cond] = self._apply_reduction(loss=losses[cond])
+
+        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
+        self.store_log("val_loss", loss, self.get_batch_size(batch))
+        return loss
+
+    @torch.set_grad_enabled(True)
+    def test_step(self, batch, **kwargs):
         """
-        This method is called at the start of the training process to set the
-        self-adaptive weights as parameters of the mask model.
+        The test step for the Self-Adaptive PINN solver. It returns the average
+        residual computed with the non-aggregated ``loss`` function.
 
-        :raises NotImplementedError: If the batch size is not ``None``.
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the test step.
+        :rtype: torch.Tensor
         """
-        if self.trainer.batch_size is not None:
-            raise NotImplementedError(
-                "SelfAdaptivePINN only works with full "
-                "batch size, set batch_size=None inside "
-                "the Trainer to use the solver."
-            )
-        device = torch.device(
-            self.trainer._accelerator_connector._accelerator_flag
-        )
+        losses = self.optimization_cycle(batch=batch, **kwargs)
 
-        # Initialize the self adaptive weights only for training points
-        for (
-            condition_name,
-            tensor,
-        ) in self.trainer.data_module.train_dataset.input.items():
-            self.weights_dict[condition_name].sa_weights.data = torch.rand(
-                (tensor.shape[0], 1), device=device
-            )
-        return super().on_train_start()
+        # Aggregate losses for each condition
+        for cond, loss in losses.items():
+            losses[cond] = self._apply_reduction(loss=losses[cond])
 
-    def on_load_checkpoint(self, checkpoint):
-        """
-        Override of the Pytorch Lightning ``on_load_checkpoint`` method to
-        handle checkpoints for Self-Adaptive Weights. This method should not be
-        overridden, if not intentionally.
-
-        :param dict checkpoint: Pytorch Lightning checkpoint dict.
-        """
-        # First initialize self-adaptive weights with correct shape,
-        # then load the values from the checkpoint.
-        for condition_name, _ in self.problem.input_pts.items():
-            shape = checkpoint["state_dict"][
-                f"_pina_models.1.{condition_name}.sa_weights"
-            ].shape
-            self.weights_dict[condition_name].sa_weights.data = torch.rand(
-                shape
-            )
-        return super().on_load_checkpoint(checkpoint)
+        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
+        self.store_log("test_loss", loss, self.get_batch_size(batch))
+        return loss
 
     def loss_phys(self, samples, equation):
         """
@@ -279,47 +261,138 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
         :return: The computed physics loss.
         :rtype: LabelTensor
         """
-        residual = self.compute_residual(samples, equation)
-        weights = self.weights_dict[self.current_condition_name].forward()
-        loss_value = self._vectorial_loss(
-            torch.zeros_like(residual, requires_grad=True), residual
-        )
-        return self._vect_to_scalar(weights * loss_value)
+        residuals = self.compute_residual(samples, equation)
+        return self._loss_fn(residuals, torch.zeros_like(residuals))
 
     def loss_data(self, input, target):
         """
-        Compute the data loss for the PINN solver by evaluating the loss
-        between the network's output and the true solution. This method should
-        not be overridden, if not intentionally.
+        Compute the data loss for the Self-Adaptive PINN solver by evaluating
+        the loss between the network's output and the true solution. This
+        method should not be overridden unless done intentionally.
 
         :param input: The input to the neural network.
-        :type input: LabelTensor
+        :type input: LabelTensor | torch.Tensor
         :param target: The target to compare with the network's output.
-        :type target: LabelTensor
+        :type target: LabelTensor | torch.Tensor
         :return: The supervised loss, averaged over the number of observations.
-        :rtype: LabelTensor
+        :rtype: LabelTensor | torch.Tensor
         """
         return self._loss_fn(self.forward(input), target)
 
-    def _vect_to_scalar(self, loss_value):
+    def forward(self, x):
         """
-        Computation of the scalar loss.
+        Forward pass.
 
-        :param LabelTensor loss_value: the tensor of pointwise losses.
-        :raises RuntimeError: If the loss reduction is not ``mean`` or ``sum``.
-        :return: The computed scalar loss.
-        :rtype: LabelTensor
+        :param x: Input tensor.
+        :type x: torch.Tensor | LabelTensor
+        :return: The output of the neural network.
+        :rtype: torch.Tensor | LabelTensor
         """
-        if self.loss.reduction == "mean":
-            ret = torch.mean(loss_value)
-        elif self.loss.reduction == "sum":
-            ret = torch.sum(loss_value)
-        else:
-            raise RuntimeError(
-                f"Invalid reduction, got {self.loss.reduction} "
-                "but expected mean or sum."
+        return self.model(x)
+
+    def configure_optimizers(self):
+        """
+        Optimizer configuration.
+
+        :return: The optimizers and the schedulers.
+        :rtype: tuple[list[Optimizer], list[Scheduler]]
+        """
+        # Hook the optimizers to the models
+        self.optimizer_model.hook(self.model.parameters())
+        self.optimizer_weights.hook(self.weights.parameters())
+
+        # Add unknown parameters to optimization list in case of InverseProblem
+        if isinstance(self.problem, InverseProblem):
+            self.optimizer_model.instance.add_param_group(
+                {
+                    "params": [
+                        self._params[var]
+                        for var in self.problem.unknown_variables
+                    ]
+                }
             )
-        return ret
+
+        # Hook the schedulers to the optimizers
+        self.scheduler_model.hook(self.optimizer_model)
+        self.scheduler_weights.hook(self.optimizer_weights)
+
+        return (
+            [self.optimizer_model.instance, self.optimizer_weights.instance],
+            [self.scheduler_model.instance, self.scheduler_weights.instance],
+        )
+
+    def _optimization_cycle(self, batch, batch_idx, **kwargs):
+        """
+        Aggregate the loss for each condition in the batch.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param int batch_idx: The index of the current batch.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss aggregated over all conditions in the batch,
+            cast to a subclass of :class:`torch.Tensor`.
+        :rtype: torch.Tensor
+        """
+        # Compute non-aggregated residuals
+        residuals = self.optimization_cycle(batch)
+
+        # Compute losses
+        losses = {}
+        for cond, res in residuals.items():
+
+            weight_tensor = self.weights[cond]()
+
+            # Get the correct indices for the weights. The modulo wraps the
+            # indices around the number of points stored for the condition,
+            # matching the batching behavior of the PinaDataset.
+            len_res = len(res)
+            idx = torch.arange(
+                batch_idx * len_res,
+                (batch_idx + 1) * len_res,
+                device=res.device,
+            ) % len(self.problem.input_pts[cond])
+
+            # Apply the weights to the residuals
+            losses[cond] = self._apply_reduction(
+                loss=(res * weight_tensor[idx])
+            )
+
+            # Store log
+            self.store_log(
+                f"{cond}_loss", losses[cond].item(), self.get_batch_size(batch)
+            )
+
+        # Clamp unknown parameters in InverseProblem (if needed)
+        self._clamp_params()
+
+        # Aggregate
+        loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
+
+        return loss
+
+    def _apply_reduction(self, loss):
+        """
+        Apply the specified reduction to the loss. The reduction is deferred
+        until the end of the optimization cycle to allow self-adaptive weights
+        to be applied to each point beforehand.
+
+        :param torch.Tensor loss: The loss tensor to be reduced.
+        :return: The reduced loss tensor.
+        :rtype: torch.Tensor
+        :raises ValueError: If the reduction method is neither "mean" nor "sum".
+        """
+        # Apply the specified reduction method
+        if self._reduction == "mean":
+            return loss.mean()
+        if self._reduction == "sum":
+            return loss.sum()
+
+        # Raise an error if the reduction method is not recognized
+        raise ValueError(
+            f"Unknown reduction: {self._reduction}."
+            " Supported reductions are 'mean' and 'sum'."
+ ) @property def model(self): @@ -332,7 +405,7 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface): return self.models[0] @property - def weights_dict(self): + def weights(self): """ The self-adaptive weights. diff --git a/tests/test_solver/test_self_adaptive_pinn.py b/tests/test_solver/test_self_adaptive_pinn.py index aba43da..f7ef7b5 100644 --- a/tests/test_solver/test_self_adaptive_pinn.py +++ b/tests/test_solver/test_self_adaptive_pinn.py @@ -42,9 +42,11 @@ model = FeedForward(len(problem.input_variables), len(problem.output_variables)) @pytest.mark.parametrize("problem", [problem, inverse_problem]) @pytest.mark.parametrize("weight_fn", [torch.nn.Sigmoid(), torch.nn.Tanh()]) def test_constructor(problem, weight_fn): + + solver = SAPINN(problem=problem, model=model, weight_function=weight_fn) + with pytest.raises(ValueError): SAPINN(model=model, problem=problem, weight_function=1) - solver = SAPINN(problem=problem, model=model, weight_function=weight_fn) assert solver.accepted_conditions_types == ( InputTargetCondition, @@ -53,26 +55,13 @@ def test_constructor(problem, weight_fn): ) -@pytest.mark.parametrize("problem", [problem, inverse_problem]) -def test_wrong_batch(problem): - with pytest.raises(NotImplementedError): - solver = SAPINN(model=model, problem=problem) - trainer = Trainer( - solver=solver, - max_epochs=2, - accelerator="cpu", - batch_size=10, - train_size=1.0, - val_size=0.0, - test_size=0.0, - ) - trainer.train() - - @pytest.mark.parametrize("problem", [problem, inverse_problem]) @pytest.mark.parametrize("compile", [True, False]) -def test_solver_train(problem, compile): - solver = SAPINN(model=model, problem=problem) +@pytest.mark.parametrize( + "loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()] +) +def test_solver_train(problem, compile, loss): + solver = SAPINN(model=model, problem=problem, loss=loss) trainer = Trainer( solver=solver, max_epochs=2, @@ -95,8 +84,11 @@ def test_solver_train(problem, compile): @pytest.mark.parametrize("problem", [problem, inverse_problem]) @pytest.mark.parametrize("compile", [True, False]) -def test_solver_validation(problem, compile): - solver = SAPINN(model=model, problem=problem) +@pytest.mark.parametrize( + "loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()] +) +def test_solver_validation(problem, compile, loss): + solver = SAPINN(model=model, problem=problem, loss=loss) trainer = Trainer( solver=solver, max_epochs=2, @@ -119,8 +111,11 @@ def test_solver_validation(problem, compile): @pytest.mark.parametrize("problem", [problem, inverse_problem]) @pytest.mark.parametrize("compile", [True, False]) -def test_solver_test(problem, compile): - solver = SAPINN(model=model, problem=problem) +@pytest.mark.parametrize( + "loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()] +) +def test_solver_test(problem, compile, loss): + solver = SAPINN(model=model, problem=problem, loss=loss) trainer = Trainer( solver=solver, max_epochs=2,