diff --git a/pina/solver/physics_informed_solver/rba_pinn.py b/pina/solver/physics_informed_solver/rba_pinn.py
index 808ac5a..831f7d4 100644
--- a/pina/solver/physics_informed_solver/rba_pinn.py
+++ b/pina/solver/physics_informed_solver/rba_pinn.py
@@ -1,6 +1,5 @@
 """Module for the Residual-Based Attention PINN solver."""
 
-from copy import deepcopy
 import torch
 
 from .pinn import PINN
@@ -98,6 +97,8 @@ class RBAPINN(PINN):
         :param float gamma: The decay parameter in the update of the weights
             of the residuals. Must be between ``0`` and ``1``.
             Default is ``0.999``.
+        :raises ValueError: If ``gamma`` is not in the range (0, 1).
+        :raises ValueError: If ``eta`` is not greater than 0.
         """
         super().__init__(
             model=model,
@@ -111,78 +112,201 @@ class RBAPINN(PINN):
 
         # check consistency
         check_consistency(eta, (float, int))
         check_consistency(gamma, float)
-        assert (
-            0 < gamma < 1
-        ), f"Invalid range: expected 0 < gamma < 1, got {gamma=}"
+
+        # Validate range for gamma
+        if not 0 < gamma < 1:
+            raise ValueError(
+                f"Invalid range: expected 0 < gamma < 1, but got {gamma}"
+            )
+
+        # Validate range for eta
+        if eta <= 0:
+            raise ValueError(f"Invalid range: expected eta > 0, but got {eta}")
+
+        # Initialize parameters
         self.eta = eta
         self.gamma = gamma
 
-        # initialize weights
+        # Initialize the weight of each point to 0
         self.weights = {}
-        for condition_name in problem.conditions:
-            self.weights[condition_name] = 0
+        for cond, data in self.problem.input_pts.items():
+            buffer_tensor = torch.zeros((len(data), 1), device=self.device)
+            self.register_buffer(f"weight_{cond}", buffer_tensor)
+            self.weights[cond] = getattr(self, f"weight_{cond}")
 
-        # define vectorial loss
-        self._vectorial_loss = deepcopy(self.loss)
-        self._vectorial_loss.reduction = "none"
+        # Extract the reduction method from the loss function
+        self._reduction = self._loss_fn.reduction
 
-    # for now RBAPINN is implemented only for batch_size = None
-    def on_train_start(self):
+        # Set the loss function to return non-aggregated losses
+        self._loss_fn = type(self._loss_fn)(reduction="none")
+
+    def training_step(self, batch, batch_idx, **kwargs):
         """
-        Hook method called at the beginning of training.
+        Solver training step. It computes the optimization cycle and aggregates
+        the losses using the ``weighting`` attribute.
 
-        :raises NotImplementedError: If the batch size is not ``None``.
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param int batch_idx: The index of the current batch.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the training step.
+        :rtype: torch.Tensor
         """
-        if self.trainer.batch_size is not None:
-            raise NotImplementedError(
-                "RBAPINN only works with full batch "
-                "size, set batch_size=None inside the "
-                "Trainer to use the solver."
-            )
-        return super().on_train_start()
-
-    def _vect_to_scalar(self, loss_value):
-        """
-        Computation of the scalar loss.
-
-        :param LabelTensor loss_value: the tensor of pointwise losses.
-        :raises RuntimeError: If the loss reduction is not ``mean`` or ``sum``.
-        :return: The computed scalar loss.
-        :rtype: LabelTensor
-        """
-        if self.loss.reduction == "mean":
-            ret = torch.mean(loss_value)
-        elif self.loss.reduction == "sum":
-            ret = torch.sum(loss_value)
-        else:
-            raise RuntimeError(
-                f"Invalid reduction, got {self.loss.reduction} "
-                "but expected mean or sum."
-            )
-        return ret
-
-    def loss_phys(self, samples, equation):
-        """
-        Computes the physics loss for the physics-informed solver based on the
-        provided samples and equation.
-
-        :param LabelTensor samples: The samples to evaluate the physics loss.
-        :param EquationInterface equation: The governing equation.
-        :return: The computed physics loss.
-        :rtype: LabelTensor
-        """
-        residual = self.compute_residual(samples=samples, equation=equation)
-        cond = self.current_condition_name
-
-        r_norm = (
-            self.eta
-            * torch.abs(residual)
-            / (torch.max(torch.abs(residual)) + 1e-12)
+        loss = self._optimization_cycle(
+            batch=batch, batch_idx=batch_idx, **kwargs
         )
-        self.weights[cond] = (self.gamma * self.weights[cond] + r_norm).detach()
+        self.store_log("train_loss", loss, self.get_batch_size(batch))
+        return loss
 
-        loss_value = self._vectorial_loss(
-            torch.zeros_like(residual, requires_grad=True), residual
+    @torch.set_grad_enabled(True)
+    def validation_step(self, batch, **kwargs):
+        """
+        The validation step for the PINN solver. It returns the average
+        residual computed with the non-aggregated ``loss`` function.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the validation step.
+        :rtype: torch.Tensor
+        """
+        losses = self.optimization_cycle(batch=batch, **kwargs)
+
+        # Aggregate losses for each condition
+        for cond, loss in losses.items():
+            losses[cond] = self._apply_reduction(loss=losses[cond])
+
+        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
+        self.store_log("val_loss", loss, self.get_batch_size(batch))
+        return loss
+
+    @torch.set_grad_enabled(True)
+    def test_step(self, batch, **kwargs):
+        """
+        The test step for the PINN solver. It returns the average residual
+        computed with the non-aggregated ``loss`` function.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the test step.
+        :rtype: torch.Tensor
+        """
+        losses = self.optimization_cycle(batch=batch, **kwargs)
+
+        # Aggregate losses for each condition
+        for cond, loss in losses.items():
+            losses[cond] = self._apply_reduction(loss=losses[cond])
+
+        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
+        self.store_log("test_loss", loss, self.get_batch_size(batch))
+        return loss
+
+    def _optimization_cycle(self, batch, batch_idx, **kwargs):
+        """
+        Aggregate the loss for each condition in the batch.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param int batch_idx: The index of the current batch.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The aggregated loss for all conditions in the batch, cast
+            to a subclass of :class:`torch.Tensor`. Each condition contributes
+            its weighted, reduced residual loss to the aggregation.
+        :rtype: torch.Tensor
+        """
+        # compute non-aggregated residuals
+        residuals = self.optimization_cycle(batch)
+
+        # update weights based on residuals
+        self._update_weights(batch, batch_idx, residuals)
+
+        # compute losses
+        losses = {}
+        for cond, res in residuals.items():
+
+            # Get the correct indices for the weights. The modulo is taken with
+            # respect to the number of points in the condition, as in the PinaDataset.
+            len_res = len(res)
+            idx = torch.arange(
+                batch_idx * len_res,
+                (batch_idx + 1) * len_res,
+                device=res.device,
+            ) % len(self.problem.input_pts[cond])
+
+            losses[cond] = self._apply_reduction(
+                loss=(res * self.weights[cond][idx])
+            )
+
+            # store log
+            self.store_log(
+                f"{cond}_loss", losses[cond].item(), self.get_batch_size(batch)
+            )
+
+        # clamp unknown parameters in InverseProblem (if needed)
+        self._clamp_params()
+
+        # aggregate
+        loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
+
+        return loss
+
+    def _update_weights(self, batch, batch_idx, residuals):
+        """
+        Update the weights based on the residuals.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param int batch_idx: The index of the current batch.
+        :param dict residuals: A dictionary containing the residuals for each
+            condition. The keys are the condition names and the values are the
+            residuals as tensors.
+        """
+        # Iterate over each condition in the batch
+        for cond, data in batch:
+
+            # Compute normalized residuals
+            res = residuals[cond]
+            res_abs = res.abs()
+            r_norm = (self.eta * res_abs) / (res_abs.max() + 1e-12)
+
+            # Get the correct indices for the weights. The modulo is taken with
+            # respect to the number of points in the condition, as in the PinaDataset.
+            len_pts = len(data["input"])
+            idx = torch.arange(
+                batch_idx * len_pts,
+                (batch_idx + 1) * len_pts,
+                device=res.device,
+            ) % len(self.problem.input_pts[cond])
+
+            # Update weights
+            weights = self.weights[cond]
+            update = self.gamma * weights[idx] + r_norm
+            weights[idx] = update.detach()
+
+    def _apply_reduction(self, loss):
+        """
+        Apply the specified reduction to the loss. The reduction is deferred
+        until the end of the optimization cycle to allow residual-based weights
+        to be applied to each point beforehand.
+
+        :param torch.Tensor loss: The loss tensor to be reduced.
+        :return: The reduced loss tensor.
+        :rtype: torch.Tensor
+        :raises ValueError: If the reduction method is neither "mean" nor "sum".
+        """
+        # Apply the specified reduction method
+        if self._reduction == "mean":
+            return loss.mean()
+        if self._reduction == "sum":
+            return loss.sum()
+
+        # Raise an error if the reduction method is not recognized
+        raise ValueError(
+            f"Unknown reduction: {self._reduction}."
+            " Supported reductions are 'mean' and 'sum'."
         )
-
-        return self._vect_to_scalar(self.weights[cond] ** 2 * loss_value)
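The update implemented in `_update_weights` is the residual-based attention rule: for each collocation point `i`, the weight evolves as `w_i <- gamma * w_i + eta * |r_i| / (max_j |r_j| + eps)`, so points whose residuals stay large relative to the batch maximum accumulate weight, while `gamma < 1` decays stale contributions. Below is a minimal standalone sketch of the rule and of the modulo indexing used for mini-batches; `rba_update` and the sizes are illustrative, not part of the PINA API.

```python
import torch

def rba_update(weights, residuals, idx, eta=0.001, gamma=0.999, eps=1e-12):
    # Normalize residuals so the largest one in the batch contributes eta.
    r_abs = residuals.abs()
    r_norm = eta * r_abs / (r_abs.max() + eps)
    # Decay the old weights and add the new contribution, detached so the
    # weights never enter the computational graph of the optimizer.
    weights[idx] = (gamma * weights[idx] + r_norm).detach()

# 10 collocation points, mini-batches of 4: the modulo wraps the batch
# indices around the condition size, as in _update_weights above.
weights = torch.zeros(10, 1)
for batch_idx in range(5):
    idx = torch.arange(batch_idx * 4, (batch_idx + 1) * 4) % 10
    residuals = torch.rand(4, 1)
    rba_update(weights, residuals, idx)
```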
diff --git a/tests/test_solver/test_rba_pinn.py b/tests/test_solver/test_rba_pinn.py
index 8eaf340..92ef539 100644
--- a/tests/test_solver/test_rba_pinn.py
+++ b/tests/test_solver/test_rba_pinn.py
@@ -42,10 +42,14 @@ model = FeedForward(len(problem.input_variables), len(problem.output_variables))
 @pytest.mark.parametrize("eta", [1, 0.001])
 @pytest.mark.parametrize("gamma", [0.5, 0.9])
 def test_constructor(problem, eta, gamma):
-    with pytest.raises(AssertionError):
-        solver = RBAPINN(model=model, problem=problem, gamma=1.5)
     solver = RBAPINN(model=model, problem=problem, eta=eta, gamma=gamma)
 
+    with pytest.raises(ValueError):
+        solver = RBAPINN(model=model, problem=problem, gamma=1.5)
+
+    with pytest.raises(ValueError):
+        solver = RBAPINN(model=model, problem=problem, eta=-0.1)
+
     assert solver.accepted_conditions_types == (
         InputTargetCondition,
         InputEquationCondition,
@@ -54,30 +58,18 @@
 
 
 @pytest.mark.parametrize("problem", [problem, inverse_problem])
-def test_wrong_batch(problem):
-    with pytest.raises(NotImplementedError):
-        solver = RBAPINN(model=model, problem=problem)
-        trainer = Trainer(
-            solver=solver,
-            max_epochs=2,
-            accelerator="cpu",
-            batch_size=10,
-            train_size=1.0,
-            val_size=0.0,
-            test_size=0.0,
-        )
-        trainer.train()
-
-
-@pytest.mark.parametrize("problem", [problem, inverse_problem])
+@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
 @pytest.mark.parametrize("compile", [True, False])
-def test_solver_train(problem, compile):
-    solver = RBAPINN(model=model, problem=problem)
+@pytest.mark.parametrize(
+    "loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
+)
+def test_solver_train(problem, batch_size, loss, compile):
+    solver = RBAPINN(model=model, problem=problem, loss=loss)
     trainer = Trainer(
         solver=solver,
         max_epochs=2,
         accelerator="cpu",
-        batch_size=None,
+        batch_size=batch_size,
         train_size=1.0,
         val_size=0.0,
         test_size=0.0,
@@ -89,14 +81,18 @@
 
 
 @pytest.mark.parametrize("problem", [problem, inverse_problem])
+@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
 @pytest.mark.parametrize("compile", [True, False])
-def test_solver_validation(problem, compile):
-    solver = RBAPINN(model=model, problem=problem)
+@pytest.mark.parametrize(
+    "loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
+)
+def test_solver_validation(problem, batch_size, loss, compile):
+    solver = RBAPINN(model=model, problem=problem, loss=loss)
     trainer = Trainer(
         solver=solver,
         max_epochs=2,
         accelerator="cpu",
-        batch_size=None,
+        batch_size=batch_size,
         train_size=0.9,
         val_size=0.1,
         test_size=0.0,
@@ -108,14 +104,18 @@
 
 
 @pytest.mark.parametrize("problem", [problem, inverse_problem])
+@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
 @pytest.mark.parametrize("compile", [True, False])
-def test_solver_test(problem, compile):
-    solver = RBAPINN(model=model, problem=problem)
+@pytest.mark.parametrize(
+    "loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
+)
+def test_solver_test(problem, batch_size, loss, compile):
+    solver = RBAPINN(model=model, problem=problem, loss=loss)
     trainer = Trainer(
         solver=solver,
         max_epochs=2,
         accelerator="cpu",
-        batch_size=None,
+        batch_size=batch_size,
         train_size=0.7,
         val_size=0.2,
         test_size=0.1,
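With the full-batch restriction removed, the tests now sweep batch sizes and loss reductions. For reference, a sketch of the corresponding user-facing call, reusing `problem` and `model` as defined at the top of the test module (import paths are assumed from the repository layout):

```python
import torch
from pina import Trainer
from pina.solver import RBAPINN

solver = RBAPINN(
    model=model,              # `model` and `problem` from the test module
    problem=problem,
    eta=0.001,
    gamma=0.999,
    loss=torch.nn.MSELoss(),  # any loss with a 'mean' or 'sum' reduction
)
trainer = Trainer(
    solver=solver,
    max_epochs=2,
    accelerator="cpu",
    batch_size=5,  # mini-batching now works; batch_size=None for full batch
    train_size=1.0,
    val_size=0.0,
    test_size=0.0,
)
trainer.train()
```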