diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 290629b..05c62a3 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -68,6 +68,8 @@ Solvers SolverInterface SingleSolverInterface MultiSolverInterface + SupervisedSolverInterface + DeepEnsembleSolverInterface PINNInterface PINN GradientPINN @@ -75,8 +77,10 @@ Solvers CompetitivePINN SelfAdaptivePINN RBAPINN - SupervisedSolver - ReducedOrderModelSolver + DeepEnsemblePINN + SupervisedSolver + DeepEnsembleSupervisedSolver + ReducedOrderModelSolver GAROM diff --git a/docs/source/_rst/solver/ensemble_solver/ensemble_pinn.rst b/docs/source/_rst/solver/ensemble_solver/ensemble_pinn.rst new file mode 100644 index 0000000..2e42dcf --- /dev/null +++ b/docs/source/_rst/solver/ensemble_solver/ensemble_pinn.rst @@ -0,0 +1,8 @@ +DeepEnsemblePINN +================== +.. currentmodule:: pina.solver.ensemble_solver.ensemble_pinn + +.. autoclass:: DeepEnsemblePINN + :show-inheritance: + :members: + diff --git a/docs/source/_rst/solver/ensemble_solver/ensemble_solver_interface.rst b/docs/source/_rst/solver/ensemble_solver/ensemble_solver_interface.rst new file mode 100644 index 0000000..664bb8c --- /dev/null +++ b/docs/source/_rst/solver/ensemble_solver/ensemble_solver_interface.rst @@ -0,0 +1,8 @@ +DeepEnsembleSolverInterface +============================= +.. currentmodule:: pina.solver.ensemble_solver.ensemble_solver_interface + +.. autoclass:: DeepEnsembleSolverInterface + :show-inheritance: + :members: + diff --git a/docs/source/_rst/solver/ensemble_solver/ensemble_supervised.rst b/docs/source/_rst/solver/ensemble_solver/ensemble_supervised.rst new file mode 100644 index 0000000..575b285 --- /dev/null +++ b/docs/source/_rst/solver/ensemble_solver/ensemble_supervised.rst @@ -0,0 +1,8 @@ +DeepEnsembleSupervisedSolver +============================= +.. currentmodule:: pina.solver.ensemble_solver.ensemble_supervised + +.. autoclass:: DeepEnsembleSupervisedSolver + :show-inheritance: + :members: + diff --git a/docs/source/_rst/solver/reduced_order_model.rst b/docs/source/_rst/solver/supervised_solver/reduced_order_model.rst similarity index 64% rename from docs/source/_rst/solver/reduced_order_model.rst rename to docs/source/_rst/solver/supervised_solver/reduced_order_model.rst index 33a9095..878014c 100644 --- a/docs/source/_rst/solver/reduced_order_model.rst +++ b/docs/source/_rst/solver/supervised_solver/reduced_order_model.rst @@ -1,6 +1,6 @@ ReducedOrderModelSolver ========================== -.. currentmodule:: pina.solver.reduced_order_model +.. currentmodule:: pina.solver.supervised_solver.reduced_order_model .. autoclass:: ReducedOrderModelSolver :members: diff --git a/docs/source/_rst/solver/supervised.rst b/docs/source/_rst/solver/supervised_solver/supervised.rst similarity index 63% rename from docs/source/_rst/solver/supervised.rst rename to docs/source/_rst/solver/supervised_solver/supervised.rst index 19978f9..60ffdf8 100644 --- a/docs/source/_rst/solver/supervised.rst +++ b/docs/source/_rst/solver/supervised_solver/supervised.rst @@ -1,6 +1,6 @@ SupervisedSolver =================== -.. currentmodule:: pina.solver.supervised +.. currentmodule:: pina.solver.supervised_solver.supervised .. 
autoclass:: SupervisedSolver :members: diff --git a/docs/source/_rst/solver/supervised_solver/supervised_solver_interface.rst b/docs/source/_rst/solver/supervised_solver/supervised_solver_interface.rst new file mode 100644 index 0000000..4903a18 --- /dev/null +++ b/docs/source/_rst/solver/supervised_solver/supervised_solver_interface.rst @@ -0,0 +1,8 @@ +SupervisedSolverInterface +========================== +.. currentmodule:: pina.solver.supervised_solver.supervised_solver_interface + +.. autoclass:: SupervisedSolverInterface + :show-inheritance: + :members: + diff --git a/pina/solver/__init__.py b/pina/solver/__init__.py index c89c626..43f1807 100644 --- a/pina/solver/__init__.py +++ b/pina/solver/__init__.py @@ -11,13 +11,33 @@ __all__ = [ "CompetitivePINN", "SelfAdaptivePINN", "RBAPINN", + "SupervisedSolverInterface", "SupervisedSolver", "ReducedOrderModelSolver", + "DeepEnsembleSolverInterface", + "DeepEnsembleSupervisedSolver", + "DeepEnsemblePINN", "GAROM", ] from .solver import SolverInterface, SingleSolverInterface, MultiSolverInterface -from .physics_informed_solver import * -from .supervised import SupervisedSolver -from .reduced_order_model import ReducedOrderModelSolver +from .physics_informed_solver import ( + PINNInterface, + PINN, + GradientPINN, + CausalPINN, + CompetitivePINN, + SelfAdaptivePINN, + RBAPINN, +) +from .supervised_solver import ( + SupervisedSolverInterface, + SupervisedSolver, + ReducedOrderModelSolver, +) +from .ensemble_solver import ( + DeepEnsembleSolverInterface, + DeepEnsembleSupervisedSolver, + DeepEnsemblePINN, +) from .garom import GAROM diff --git a/pina/solver/ensemble_solver/__init__.py b/pina/solver/ensemble_solver/__init__.py new file mode 100644 index 0000000..0e4eab5 --- /dev/null +++ b/pina/solver/ensemble_solver/__init__.py @@ -0,0 +1,11 @@ +"""Module for the Ensemble solver classes.""" + +__all__ = [ + "DeepEnsembleSolverInterface", + "DeepEnsembleSupervisedSolver", + "DeepEnsemblePINN", +] + +from .ensemble_solver_interface import DeepEnsembleSolverInterface +from .ensemble_supervised import DeepEnsembleSupervisedSolver +from .ensemble_pinn import DeepEnsemblePINN diff --git a/pina/solver/ensemble_solver/ensemble_pinn.py b/pina/solver/ensemble_solver/ensemble_pinn.py new file mode 100644 index 0000000..33d929a --- /dev/null +++ b/pina/solver/ensemble_solver/ensemble_pinn.py @@ -0,0 +1,170 @@ +"""Module for the DeepEnsemble physics solver.""" + +import torch + +from .ensemble_solver_interface import DeepEnsembleSolverInterface +from ..physics_informed_solver import PINNInterface +from ...problem import InverseProblem + + +class DeepEnsemblePINN(PINNInterface, DeepEnsembleSolverInterface): + r""" + Deep Ensemble Physics Informed Solver class. This class implements a + Deep Ensemble for Physics Informed Neural Networks using user + specified ``model``s to solve a specific ``problem``. + + An ensemble model is constructed by combining multiple models that solve + the same type of problem. Mathematically, this creates an implicit + distribution :math:`p(\mathbf{u} \mid \mathbf{s})` over the possible + outputs :math:`\mathbf{u}`, given the original input :math:`\mathbf{s}`. + The models :math:`\mathcal{M}_{i\in (1,\dots,r)}` in + the ensemble work collaboratively to capture different + aspects of the data or task, with each model contributing a distinct + prediction :math:`\mathbf{y}_{i}=\mathcal{M}_i(\mathbf{u} \mid \mathbf{s})`. 
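As a worked illustration of the ensemble statistics used by the new solvers (see the mean and variance formulas below), here is a minimal plain-``torch`` sketch; the tensor shapes and the choice ``ensemble_dim=0`` are illustrative and not part of the PINA API::

    import torch

    # toy stand-ins for the predictions y_i of r ensemble members on 100 points
    r, n_points, out_dim = 5, 100, 1
    predictions = torch.stack(
        [torch.randn(n_points, out_dim) for _ in range(r)], dim=0
    )  # shape (r, n_points, out_dim), i.e. models stacked along ensemble_dim=0

    ensemble_mean = predictions.mean(dim=0)                     # mu
    ensemble_variance = predictions.var(dim=0, unbiased=False)  # sigma^2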
+ By aggregating these predictions, the ensemble + model can achieve greater robustness and accuracy compared to individual + models, leveraging the diversity of the models to reduce overfitting and + improve generalization. Furthermore, statistical metrics can + be computed, e.g. the ensemble mean and variance: + + .. math:: + \mathbf{\mu} = \frac{1}{N}\sum_{i=1}^r \mathbf{y}_{i} + + .. math:: + \mathbf{\sigma^2} = \frac{1}{N}\sum_{i=1}^r + (\mathbf{y}_{i} - \mathbf{\mu})^2 + + During training, the PINN loss is minimized by each ensemble model: + + .. math:: + \mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N + \mathcal{L}(\mathcal{A}[\mathbf{u}](\mathbf{x}_i)) + + \frac{1}{N}\sum_{i=1}^N + \mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i)), + + for the differential system: + + .. math:: + + \begin{cases} + \mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\ + \mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad, + \mathbf{x}\in\partial\Omega + \end{cases} + + :math:`\mathcal{L}` indicates a specific loss function, typically the MSE: + + .. math:: + \mathcal{L}(v) = \| v \|^2_2. + + .. seealso:: + + **Original reference**: Zou, Z., Wang, Z., & Karniadakis, G. E. (2025). + *Learning and discovering multiple solutions using physics-informed + neural networks with random initialization and deep ensemble*. + DOI: `arXiv:2503.06320 <https://arxiv.org/abs/2503.06320>`_. + + .. warning:: + This solver does not work with inverse problems. Hence, the ``problem`` + definition must not inherit from + :class:`~pina.problem.inverse_problem.InverseProblem`. + """ + + def __init__( + self, + problem, + models, + loss=None, + optimizers=None, + schedulers=None, + weighting=None, + ensemble_dim=0, + ): + """ + Initialization of the :class:`DeepEnsemblePINN` class. + + :param AbstractProblem problem: The problem to be solved. + :param torch.nn.Module models: The neural network models to be used. + :param torch.nn.Module loss: The loss function to be minimized. + If ``None``, the :class:`torch.nn.MSELoss` loss is used. + Default is ``None``. + :param Optimizer optimizers: The optimizers to be used. + If ``None``, the :class:`torch.optim.Adam` optimizer is used. + Default is ``None``. + :param Scheduler schedulers: Learning rate schedulers. + If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` + scheduler is used. Default is ``None``. + :param WeightingInterface weighting: The weighting schema to be used. + If ``None``, no weighting schema is used. Default is ``None``. + :param int ensemble_dim: The dimension along which the ensemble + outputs are stacked. Default is 0. + :raises NotImplementedError: If an inverse problem is passed. + """ + if isinstance(problem, InverseProblem): + raise NotImplementedError( + "DeepEnsemblePINN cannot be used to solve inverse problems." + ) + super().__init__( + problem=problem, + models=models, + loss=loss, + optimizers=optimizers, + schedulers=schedulers, + weighting=weighting, + ensemble_dim=ensemble_dim, + ) + + def loss_data(self, input, target): + """ + Compute the data loss for the ensemble PINN solver by evaluating + the loss between the network's output and the true solution for each + model. This method should not be overridden, if not intentionally. + + :param input: The input to the neural network. + :type input: LabelTensor | torch.Tensor | Graph | Data + :param target: The target to compare with the network's output. + :type target: LabelTensor | torch.Tensor | Graph | Data + :return: The supervised loss, averaged over the number of observations. 
+ :rtype: torch.Tensor + """ + predictions = self.forward(input) + loss = sum( + self._loss_fn(predictions[idx], target) + for idx in range(self.num_ensemble) + ) + return loss / self.num_ensemble + + def loss_phys(self, samples, equation): + """ + Computes the physics loss for the ensemble PINN solver by evaluating + the loss between the network's output and the true solution for each + model. This method should not be overridden, if not intentionally. + + :param LabelTensor samples: The samples to evaluate the physics loss. + :param EquationInterface equation: The governing equation. + :return: The computed physics loss. + :rtype: LabelTensor + """ + return self._residual_loss(samples, equation) + + def _residual_loss(self, samples, equation): + """ + Computes the physics loss for the physics-informed solver based on the + provided samples and equation. This method should never be overridden + by the user, if not intentionally, + since it is used internally to compute validation loss. It overrides the + :obj:`~pina.solver.physics_informed_solver.PINNInterface._residual_loss` + method. + + :param LabelTensor samples: The samples to evaluate the loss. + :param EquationInterface equation: The governing equation. + :return: The residual loss. + :rtype: torch.Tensor + """ + loss = 0 + predictions = self.forward(samples) + for idx in range(self.num_ensemble): + residuals = equation.residual(samples, predictions[idx]) + target = torch.zeros_like(residuals, requires_grad=True) + loss = loss + self._loss_fn(residuals, target) + return loss / self.num_ensemble diff --git a/pina/solver/ensemble_solver/ensemble_solver_interface.py b/pina/solver/ensemble_solver/ensemble_solver_interface.py new file mode 100644 index 0000000..6d874e1 --- /dev/null +++ b/pina/solver/ensemble_solver/ensemble_solver_interface.py @@ -0,0 +1,152 @@ +"""Module for the DeepEnsemble solver interface.""" + +import torch +from ..solver import MultiSolverInterface +from ...utils import check_consistency + + +class DeepEnsembleSolverInterface(MultiSolverInterface): + r""" + A class for handling ensemble models in a multi-solver training framework. + It allows for manual optimization, as well as the ability to train, + validate, and test multiple models as part of an ensemble. + The ensemble dimension can be customized to control how outputs are stacked. + + By default, it is compatible with problems defined by + :class:`~pina.problem.abstract_problem.AbstractProblem`, + and users can choose the problem type the solver is meant to address. + + An ensemble model is constructed by combining multiple models that solve + the same type of problem. Mathematically, this creates an implicit + distribution :math:`p(\mathbf{u} \mid \mathbf{s})` over the possible + outputs :math:`\mathbf{u}`, given the original input :math:`\mathbf{s}`. + The models :math:`\mathcal{M}_{i\in (1,\dots,r)}` in + the ensemble work collaboratively to capture different + aspects of the data or task, with each model contributing a distinct + prediction :math:`\mathbf{y}_{i}=\mathcal{M}_i(\mathbf{u} \mid \mathbf{s})`. + By aggregating these predictions, the ensemble + model can achieve greater robustness and accuracy compared to individual + models, leveraging the diversity of the models to reduce overfitting and + improve generalization. Furthemore, statistical metrics can + be computed, e.g. the ensemble mean and variance: + + .. math:: + \mathbf{\mu} = \frac{1}{N}\sum_{i=1}^r \mathbf{y}_{i} + + .. 
math:: + \mathbf{\sigma^2} = \frac{1}{N}\sum_{i=1}^r + (\mathbf{y}_{i} - \mathbf{\mu})^2 + + .. seealso:: + + **Original reference**: Lakshminarayanan, B., Pritzel, A., & Blundell, + C. (2017). *Simple and scalable predictive uncertainty estimation + using deep ensembles*. Advances in neural information + processing systems, 30. + DOI: `arXiv:1612.01474 `_. + """ + + def __init__( + self, + problem, + models, + optimizers=None, + schedulers=None, + weighting=None, + use_lt=True, + ensemble_dim=0, + ): + """ + Initialization of the :class:`DeepEnsembleSolverInterface` class. + + :param AbstractProblem problem: The problem to be solved. + :param torch.nn.Module models: The neural network models to be used. + :param Optimizer optimizer: The optimizer to be used. + If ``None``, the :class:`torch.optim.Adam` optimizer is used. + Default is ``None``. + :param Scheduler scheduler: Learning rate scheduler. + If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` + scheduler is used. Default is ``None``. + :param WeightingInterface weighting: The weighting schema to be used. + If ``None``, no weighting schema is used. Default is ``None``. + :param bool use_lt: If ``True``, the solver uses LabelTensors as input. + Default is ``True``. + :param int ensemble_dim: The dimension along which the ensemble + outputs are stacked. Default is 0. + """ + super().__init__( + problem, models, optimizers, schedulers, weighting, use_lt + ) + # check consistency + check_consistency(ensemble_dim, int) + self._ensemble_dim = ensemble_dim + + def forward(self, x, ensemble_idx=None): + """ + Forward pass through the ensemble models. If an `ensemble_idx` is + provided, it returns the output of the specific model + corresponding to that index. If no index is given, it stacks the outputs + of all models along the ensemble dimension. + + :param LabelTensor x: The input tensor to the models. + :param int ensemble_idx: Optional index to select a specific + model from the ensemble. If ``None`` results for all models are + stacked in ``ensemble_dim`` dimension. Default is ``None``. + :return: The output of the selected model or the stacked + outputs from all models. + :rtype: LabelTensor + """ + # if an index is passed, return the specific model output for that index + if ensemble_idx is not None: + return self.models[ensemble_idx].forward(x) + # otherwise return the stacked output + return torch.stack( + [self.forward(x, idx) for idx in range(self.num_ensemble)], + dim=self.ensemble_dim, + ) + + def training_step(self, batch): + """ + Training step for the solver, overridden for manual optimization. + This method performs a forward pass, calculates the loss, and applies + manual backward propagation and optimization steps for each model in + the ensemble. + + :param list[tuple[str, dict]] batch: A batch of training data. + Each element is a tuple containing a condition name and a + dictionary of points. + :return: The aggregated loss after the training step. + :rtype: torch.Tensor + """ + # zero grad for optimizer + for opt in self.optimizers: + opt.instance.zero_grad() + # perform forward passes and aggregate losses + loss = super().training_step(batch) + # perform backpropagation + self.manual_backward(loss) + # optimize + for opt, sched in zip(self.optimizers, self.schedulers): + opt.instance.step() + sched.instance.step() + return loss + + @property + def ensemble_dim(self): + """ + The dimension along which the ensemble outputs are stacked. + + :return: The ensemble dimension. 
+ :rtype: int + """ + return len(self.models) diff --git a/pina/solver/ensemble_solver/ensemble_supervised.py b/pina/solver/ensemble_solver/ensemble_supervised.py new file mode 100644 index 0000000..e4837cc --- /dev/null +++ b/pina/solver/ensemble_solver/ensemble_supervised.py @@ -0,0 +1,122 @@ +"""Module for the DeepEnsemble supervised solver.""" + +from .ensemble_solver_interface import DeepEnsembleSolverInterface +from ..supervised_solver import SupervisedSolverInterface + + +class DeepEnsembleSupervisedSolver( + SupervisedSolverInterface, DeepEnsembleSolverInterface +): + r""" + Deep Ensemble Supervised Solver class. This class implements a + Deep Ensemble Supervised Solver using user-specified ``models`` to solve + a specific ``problem``. + + An ensemble model is constructed by combining multiple models that solve + the same type of problem. Mathematically, this creates an implicit + distribution :math:`p(\mathbf{u} \mid \mathbf{s})` over the possible + outputs :math:`\mathbf{u}`, given the original input :math:`\mathbf{s}`. + The models :math:`\mathcal{M}_{i\in (1,\dots,r)}` in + the ensemble work collaboratively to capture different + aspects of the data or task, with each model contributing a distinct + prediction :math:`\mathbf{y}_{i}=\mathcal{M}_i(\mathbf{u} \mid \mathbf{s})`. + By aggregating these predictions, the ensemble + model can achieve greater robustness and accuracy compared to individual + models, leveraging the diversity of the models to reduce overfitting and + improve generalization. Furthermore, statistical metrics can + be computed, e.g. the ensemble mean and variance: + + .. math:: + \mathbf{\mu} = \frac{1}{N}\sum_{i=1}^r \mathbf{y}_{i} + + .. math:: + \mathbf{\sigma^2} = \frac{1}{N}\sum_{i=1}^r + (\mathbf{y}_{i} - \mathbf{\mu})^2 + + During training, the supervised loss is minimized by each ensemble model: + + .. math:: + \mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N + \mathcal{L}(\mathbf{u}_i - \mathcal{M}_{j}(\mathbf{s}_i)), + \quad j \in (1,\dots,N_{ensemble}) + + where :math:`\mathcal{L}` is a specific loss function, typically the MSE: + + .. math:: + \mathcal{L}(v) = \| v \|^2_2. + + In this context, :math:`\mathbf{u}_i` and :math:`\mathbf{s}_i` indicate + that the goal is to approximate multiple (discretised) functions given + multiple (discretised) input functions. + + .. seealso:: + + **Original reference**: Lakshminarayanan, B., Pritzel, A., & Blundell, + C. (2017). *Simple and scalable predictive uncertainty estimation + using deep ensembles*. Advances in neural information + processing systems, 30. + DOI: `arXiv:1612.01474 <https://arxiv.org/abs/1612.01474>`_. + """ + + def __init__( + self, + problem, + models, + loss=None, + optimizers=None, + schedulers=None, + weighting=None, + use_lt=False, + ensemble_dim=0, + ): + """ + Initialization of the :class:`DeepEnsembleSupervisedSolver` class. + + :param AbstractProblem problem: The problem to be solved. + :param torch.nn.Module models: The neural network models to be used. + :param torch.nn.Module loss: The loss function to be minimized. + If ``None``, the :class:`torch.nn.MSELoss` loss is used. + Default is ``None``. + :param Optimizer optimizers: The optimizers to be used. + If ``None``, the :class:`torch.optim.Adam` optimizer is used. + Default is ``None``. + :param Scheduler schedulers: Learning rate schedulers. 
+ If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` + scheduler is used. Default is ``None``. + :param WeightingInterface weighting: The weighting schema to be used. + If ``None``, no weighting schema is used. Default is ``None``. + :param bool use_lt: If ``True``, the solver uses LabelTensors as input. + Default is ``False``. + :param int ensemble_dim: The dimension along which the ensemble + outputs are stacked. Default is 0. + """ + super().__init__( + problem=problem, + models=models, + loss=loss, + optimizers=optimizers, + schedulers=schedulers, + weighting=weighting, + use_lt=use_lt, + ensemble_dim=ensemble_dim, + ) + + def loss_data(self, input, target): + """ + Compute the data loss for the DeepEnsembleSupervisedSolver by evaluating + the loss between the network's output and the true solution for each + model. This method should not be overridden, if not intentionally. + + :param input: The input to the neural network. + :type input: LabelTensor | torch.Tensor | Graph | Data + :param target: The target to compare with the network's output. + :type target: LabelTensor | torch.Tensor | Graph | Data + :return: The supervised loss, averaged over the number of observations. + :rtype: torch.Tensor + """ + predictions = self.forward(input) + loss = sum( + self._loss_fn(predictions[idx], target) + for idx in range(self.num_ensemble) + ) + return loss / self.num_ensemble diff --git a/pina/solver/garom.py b/pina/solver/garom.py index b854ce7..372eedd 100644 --- a/pina/solver/garom.py +++ b/pina/solver/garom.py @@ -48,18 +48,18 @@ class GAROM(MultiSolverInterface): If ``None``, :class:`~pina.loss.power_loss.PowerLoss` with ``p=1`` is used. Default is ``None``. :param Optimizer optimizer_generator: The optimizer for the generator. - If `None`, the :class:`torch.optim.Adam` optimizer is used. + If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. :param Optimizer optimizer_discriminator: The optimizer for the - discriminator. If `None`, the :class:`torch.optim.Adam` optimizer is - used. Default is ``None``. + discriminator. If ``None``, the :class:`torch.optim.Adam` + optimizer is used. Default is ``None``. :param Scheduler scheduler_generator: The learning rate scheduler for the generator. - If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR` + If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param Scheduler scheduler_discriminator: The learning rate scheduler for the discriminator. - If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR` + If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param float gamma: Ratio of expected loss for generator and discriminator. Default is ``0.3``. 
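A minimal usage sketch for the ``DeepEnsembleSupervisedSolver`` introduced above; the ``problem`` instance, the evaluation tensor ``x_new``, and the ``FeedForward`` constructor arguments are assumptions made for illustration, not prescribed by this patch::

    from pina import Trainer
    from pina.model import FeedForward
    from pina.solver import DeepEnsembleSupervisedSolver

    # an ensemble of independently initialized regressors (sizes are illustrative)
    models = [
        FeedForward(input_dimensions=1, output_dimensions=1) for _ in range(4)
    ]

    # `problem` is assumed to be a supervised (input/target) problem defined elsewhere
    solver = DeepEnsembleSupervisedSolver(problem=problem, models=models)
    Trainer(solver=solver, max_epochs=500).train()

    # stacked predictions along ensemble_dim=0: shape (num_ensemble, *output_shape)
    predictions = solver(x_new)
    mean, std = predictions.mean(dim=0), predictions.std(dim=0)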
@@ -88,7 +88,7 @@ class GAROM(MultiSolverInterface): check_consistency( loss, (LossInterface, _Loss, torch.nn.Module), subclass=False ) - self._loss = loss + self._loss_fn = loss # set automatic optimization for GANs self.automatic_optimization = False @@ -157,10 +157,11 @@ class GAROM(MultiSolverInterface): generated_snapshots = self.sample(parameters) # generator loss - r_loss = self._loss(snapshots, generated_snapshots) + r_loss = self._loss_fn(snapshots, generated_snapshots) d_fake = self.discriminator([generated_snapshots, parameters]) g_loss = ( - self._loss(d_fake, generated_snapshots) + self.regularizer * r_loss + self._loss_fn(d_fake, generated_snapshots) + + self.regularizer * r_loss ) # backward step @@ -170,24 +171,6 @@ class GAROM(MultiSolverInterface): return r_loss, g_loss - def on_train_batch_end(self, outputs, batch, batch_idx): - """ - This method is called at the end of each training batch and overrides - the PyTorch Lightning implementation to log checkpoints. - - :param torch.Tensor outputs: The ``model``'s output for the current - batch. - :param list[tuple[str, dict]] batch: A batch of data. Each element is a - tuple containing a condition name and a dictionary of points. - :param int batch_idx: The index of the current batch. - """ - # increase by one the counter of optimization to save loggers - ( - self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed - ) += 1 - - return super().on_train_batch_end(outputs, batch, batch_idx) - def _train_discriminator(self, parameters, snapshots): """ Train the discriminator model. @@ -207,8 +190,8 @@ class GAROM(MultiSolverInterface): d_fake = self.discriminator([generated_snapshots, parameters]) # evaluate loss - d_loss_real = self._loss(d_real, snapshots) - d_loss_fake = self._loss(d_fake, generated_snapshots.detach()) + d_loss_real = self._loss_fn(d_real, snapshots) + d_loss_fake = self._loss_fn(d_fake, generated_snapshots.detach()) d_loss = d_loss_real - self.k * d_loss_fake # backward step @@ -288,7 +271,7 @@ class GAROM(MultiSolverInterface): points["target"], ) snapshots_gen = self.generator(parameters) - condition_loss[condition_name] = self._loss( + condition_loss[condition_name] = self._loss_fn( snapshots, snapshots_gen ) loss = self.weighting.aggregate(condition_loss) @@ -311,7 +294,7 @@ class GAROM(MultiSolverInterface): points["target"], ) snapshots_gen = self.generator(parameters) - condition_loss[condition_name] = self._loss( + condition_loss[condition_name] = self._loss_fn( snapshots, snapshots_gen ) loss = self.weighting.aggregate(condition_loss) diff --git a/pina/solver/physics_informed_solver/causal_pinn.py b/pina/solver/physics_informed_solver/causal_pinn.py index 1fb102a..ab085be 100644 --- a/pina/solver/physics_informed_solver/causal_pinn.py +++ b/pina/solver/physics_informed_solver/causal_pinn.py @@ -83,15 +83,15 @@ class CausalPINN(PINN): :class:`~pina.problem.time_dependent_problem.TimeDependentProblem`. :param torch.nn.Module model: The neural network model to be used. :param Optimizer optimizer: The optimizer to be used. - If `None`, the :class:`torch.optim.Adam` optimizer is used. + If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. :param torch.optim.LRScheduler scheduler: Learning rate scheduler. - If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR` + If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. 
- If `None`, no weighting schema is used. Default is ``None``. + If ``None``, no weighting schema is used. Default is ``None``. :param torch.nn.Module loss: The loss function to be minimized. - If `None`, the :class:`torch.nn.MSELoss` loss is used. + If ``None``, the :class:`torch.nn.MSELoss` loss is used. Default is `None`. :param float eps: The exponential decay parameter. Default is ``100``. :raises ValueError: If the problem is not a TimeDependentProblem. @@ -134,7 +134,7 @@ class CausalPINN(PINN): chunk.labels = labels # classical PINN loss residual = self.compute_residual(samples=chunk, equation=equation) - loss_val = self.loss( + loss_val = self._loss_fn( torch.zeros_like(residual, requires_grad=True), residual ) time_loss.append(loss_val) diff --git a/pina/solver/physics_informed_solver/competitive_pinn.py b/pina/solver/physics_informed_solver/competitive_pinn.py index e294c70..5375efb 100644 --- a/pina/solver/physics_informed_solver/competitive_pinn.py +++ b/pina/solver/physics_informed_solver/competitive_pinn.py @@ -69,26 +69,26 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface): :param AbstractProblem problem: The problem to be solved. :param torch.nn.Module model: The neural network model to be used. :param torch.nn.Module discriminator: The discriminator to be used. - If `None`, the discriminator is a deepcopy of the ``model``. + If ``None``, the discriminator is a deepcopy of the ``model``. Default is ``None``. :param torch.optim.Optimizer optimizer_model: The optimizer of the - ``model``. If `None`, the :class:`torch.optim.Adam` optimizer is + ``model``. If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. :param torch.optim.Optimizer optimizer_discriminator: The optimizer of - the ``discriminator``. If `None`, the :class:`torch.optim.Adam` + the ``discriminator``. If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. :param Scheduler scheduler_model: Learning rate scheduler for the ``model``. - If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR` + If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param Scheduler scheduler_discriminator: Learning rate scheduler for the ``discriminator``. - If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR` + If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. - If `None`, no weighting schema is used. Default is ``None``. + If ``None``, no weighting schema is used. Default is ``None``. :param torch.nn.Module loss: The loss function to be minimized. - If `None`, the :class:`torch.nn.MSELoss` loss is used. + If ``None``, the :class:`torch.nn.MSELoss` loss is used. Default is `None`. """ if discriminator is None: @@ -156,12 +156,27 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface): residual = residual * discriminator_bets # Compute competitive residual. - loss_val = self.loss( + loss_val = self._loss_fn( torch.zeros_like(residual, requires_grad=True), residual, ) return loss_val + def loss_data(self, input, target): + """ + Compute the data loss for the PINN solver by evaluating the loss + between the network's output and the true solution. This method should + not be overridden, if not intentionally. + + :param input: The input to the neural network. + :type input: LabelTensor + :param target: The target to compare with the network's output. 
+ :type target: LabelTensor + :return: The supervised loss, averaged over the number of observations. + :rtype: LabelTensor + """ + return self._loss_fn(self.forward(input), target) + def configure_optimizers(self): """ Optimizer configuration. @@ -195,24 +210,6 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface): ], ) - def on_train_batch_end(self, outputs, batch, batch_idx): - """ - This method is called at the end of each training batch and overrides - the PyTorch Lightning implementation to log checkpoints. - - :param torch.Tensor outputs: The ``model``'s output for the current - batch. - :param list[tuple[str, dict]] batch: A batch of data. Each element is a - tuple containing a condition name and a dictionary of points. - :param int batch_idx: The index of the current batch. - """ - # increase by one the counter of optimization to save loggers - ( - self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed - ) += 1 - - return super().on_train_batch_end(outputs, batch, batch_idx) - @property def neural_net(self): """ diff --git a/pina/solver/physics_informed_solver/gradient_pinn.py b/pina/solver/physics_informed_solver/gradient_pinn.py index 4ac2b4c..0de431c 100644 --- a/pina/solver/physics_informed_solver/gradient_pinn.py +++ b/pina/solver/physics_informed_solver/gradient_pinn.py @@ -75,15 +75,15 @@ class GradientPINN(PINN): gradient of the loss. :param torch.nn.Module model: The neural network model to be used. :param Optimizer optimizer: The optimizer to be used. - If `None`, the :class:`torch.optim.Adam` optimizer is used. + If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. :param Scheduler scheduler: Learning rate scheduler. - If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR` + If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. - If `None`, no weighting schema is used. Default is ``None``. + If ``None``, no weighting schema is used. Default is ``None``. :param torch.nn.Module loss: The loss function to be minimized. - If `None`, the :class:`torch.nn.MSELoss` loss is used. + If ``None``, the :class:`torch.nn.MSELoss` loss is used. Default is `None`. :raises ValueError: If the problem is not a SpatialProblem. """ @@ -116,7 +116,7 @@ class GradientPINN(PINN): """ # classical PINN loss residual = self.compute_residual(samples=samples, equation=equation) - loss_value = self.loss( + loss_value = self._loss_fn( torch.zeros_like(residual, requires_grad=True), residual ) @@ -124,7 +124,7 @@ class GradientPINN(PINN): loss_value = loss_value.reshape(-1, 1) loss_value.labels = ["__loss"] loss_grad = grad(loss_value, samples, d=self.problem.spatial_variables) - g_loss_phys = self.loss( + g_loss_phys = self._loss_fn( torch.zeros_like(loss_grad, requires_grad=True), loss_grad ) return loss_value + g_loss_phys diff --git a/pina/solver/physics_informed_solver/pinn.py b/pina/solver/physics_informed_solver/pinn.py index 6d92d9c..914d014 100644 --- a/pina/solver/physics_informed_solver/pinn.py +++ b/pina/solver/physics_informed_solver/pinn.py @@ -62,15 +62,15 @@ class PINN(PINNInterface, SingleSolverInterface): :param AbstractProblem problem: The problem to be solved. :param torch.nn.Module model: The neural network model to be used. :param Optimizer optimizer: The optimizer to be used. - If `None`, the :class:`torch.optim.Adam` optimizer is used. 
+ If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. :param Scheduler scheduler: Learning rate scheduler. - If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR` + If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. - If `None`, no weighting schema is used. Default is ``None``. + If ``None``, no weighting schema is used. Default is ``None``. :param torch.nn.Module loss: The loss function to be minimized. - If `None`, the :class:`torch.nn.MSELoss` loss is used. + If ``None``, the :class:`torch.nn.MSELoss` loss is used. Default is `None`. """ super().__init__( @@ -82,6 +82,21 @@ class PINN(PINNInterface, SingleSolverInterface): loss=loss, ) + def loss_data(self, input, target): + """ + Compute the data loss for the PINN solver by evaluating the loss + between the network's output and the true solution. This method should + not be overridden, if not intentionally. + + :param input: The input to the neural network. + :type input: LabelTensor + :param target: The target to compare with the network's output. + :type target: LabelTensor + :return: The supervised loss, averaged over the number of observations. + :rtype: LabelTensor + """ + return self._loss_fn(self.forward(input), target) + def loss_phys(self, samples, equation): """ Computes the physics loss for the physics-informed solver based on the @@ -92,11 +107,8 @@ class PINN(PINNInterface, SingleSolverInterface): :return: The computed physics loss. :rtype: LabelTensor """ - residual = self.compute_residual(samples=samples, equation=equation) - loss_value = self.loss( - torch.zeros_like(residual, requires_grad=True), residual - ) - return loss_value + residuals = self.compute_residual(samples, equation) + return self._loss_fn(residuals, torch.zeros_like(residuals)) def configure_optimizers(self): """ diff --git a/pina/solver/physics_informed_solver/pinn_interface.py b/pina/solver/physics_informed_solver/pinn_interface.py index 09e152f..c53e123 100644 --- a/pina/solver/physics_informed_solver/pinn_interface.py +++ b/pina/solver/physics_informed_solver/pinn_interface.py @@ -38,7 +38,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): :param AbstractProblem problem: The problem to be solved. :param torch.nn.Module loss: The loss function to be minimized. - If `None`, the :class:`torch.nn.MSELoss` loss is used. + If ``None``, the :class:`torch.nn.MSELoss` loss is used. Default is `None`. :param kwargs: Additional keyword arguments to be passed to the :class:`~pina.solver.solver.SolverInterface` class. @@ -53,7 +53,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): check_consistency(loss, (LossInterface, _Loss), subclass=False) # assign variables - self._loss = loss + self._loss_fn = loss # inverse problem handling if isinstance(self.problem, InverseProblem): @@ -65,7 +65,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): self.__metric = None - def optimization_cycle(self, batch): + def optimization_cycle(self, batch, loss_residuals=None): """ The optimization cycle for the PINN solver. @@ -80,51 +80,74 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): containing the condition name and the associated scalar loss. 
:rtype: dict """ - return self._run_optimization_cycle(batch, self.loss_phys) + # which losses to use + if loss_residuals is None: + loss_residuals = self.loss_phys + # compute optimization cycle + condition_loss = {} + for condition_name, points in batch: + self.__metric = condition_name + # if equations are passed + if "target" not in points: + input_pts = points["input"] + condition = self.problem.conditions[condition_name] + loss = loss_residuals( + input_pts.requires_grad_(), condition.equation + ) + # if data are passed + else: + input_pts = points["input"] + output_pts = points["target"] + loss = self.loss_data( + input=input_pts.requires_grad_(), target=output_pts + ) + # append loss + condition_loss[condition_name] = loss + # clamp unknown parameters in InverseProblem (if needed) + self._clamp_params() + return condition_loss @torch.set_grad_enabled(True) def validation_step(self, batch): """ - The validation step for the PINN solver. + The validation step for the PINN solver. It returns the average residual + computed with the ``loss`` function not aggregated. :param list[tuple[str, dict]] batch: A batch of data. Each element is a tuple containing a condition name and a dictionary of points. :return: The loss of the validation step. :rtype: torch.Tensor """ - losses = self._run_optimization_cycle(batch, self._residual_loss) - loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor) - self.store_log("val_loss", loss, self.get_batch_size(batch)) - return loss + return super().validation_step( + batch, loss_residuals=self._residual_loss + ) @torch.set_grad_enabled(True) def test_step(self, batch): """ - The test step for the PINN solver. + The test step for the PINN solver. It returns the average residual + computed with the ``loss`` function not aggregated. :param list[tuple[str, dict]] batch: A batch of data. Each element is a tuple containing a condition name and a dictionary of points. :return: The loss of the test step. :rtype: torch.Tensor """ - losses = self._run_optimization_cycle(batch, self._residual_loss) - loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor) - self.store_log("test_loss", loss, self.get_batch_size(batch)) - return loss + return super().test_step(batch, loss_residuals=self._residual_loss) - def loss_data(self, input_pts, output_pts): + @abstractmethod + def loss_data(self, input, target): """ Compute the data loss for the PINN solver by evaluating the loss between the network's output and the true solution. This method should - not be overridden, if not intentionally. + be overridden by the derived class. - :param LabelTensor input_pts: The input points to the neural network. - :param LabelTensor output_pts: The true solution to compare with the + :param LabelTensor input: The input to the neural network. + :param LabelTensor target: The target to compare with the network's output. :return: The supervised loss, averaged over the number of observations. - :rtype: torch.Tensor + :rtype: LabelTensor """ - return self._loss(self.forward(input_pts), output_pts) @abstractmethod def loss_phys(self, samples, equation): @@ -159,7 +182,11 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): def _residual_loss(self, samples, equation): """ - Compute the residual loss. + Computes the physics loss for the physics-informed solver based on the + provided samples and equation. This method should never be overridden + by the user, if not intentionally, + since it is used internally to compute validation loss. 
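The per-condition routing introduced in ``optimization_cycle`` above can be summarised with a standalone sketch; the names are illustrative, and the ``equations`` mapping stands in for the lookup ``self.problem.conditions[name].equation`` performed in the real method::

    def optimization_cycle_sketch(batch, equations, loss_phys, loss_data):
        # batch is a list of (condition_name, points) pairs
        condition_loss = {}
        for condition_name, points in batch:
            if "target" not in points:
                # physics condition: evaluate the residual of the governing equation
                loss = loss_phys(points["input"], equations[condition_name])
            else:
                # data condition: supervised loss against the provided target
                loss = loss_data(points["input"], points["target"])
            condition_loss[condition_name] = loss
        return condition_loss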
+ :param LabelTensor samples: The samples to evaluate the loss. :param EquationInterface equation: The governing equation. @@ -167,43 +194,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): :rtype: torch.Tensor """ residuals = self.compute_residual(samples, equation) - return self.loss(residuals, torch.zeros_like(residuals)) - - def _run_optimization_cycle(self, batch, loss_residuals): - """ - Compute, given a batch, the loss for each condition and return a - dictionary with the condition name as key and the loss as value. - - :param list[tuple[str, dict]] batch: A batch of data. Each element is a - tuple containing a condition name and a dictionary of points. - :param function loss_residuals: The loss function to be minimized. - :return: The losses computed for all conditions in the batch, casted - to a subclass of :class:`torch.Tensor`. It should return a dict - containing the condition name and the associated scalar loss. - :rtype: dict - """ - condition_loss = {} - for condition_name, points in batch: - self.__metric = condition_name - # if equations are passed - if "target" not in points: - input_pts = points["input"] - condition = self.problem.conditions[condition_name] - loss = loss_residuals( - input_pts.requires_grad_(), condition.equation - ) - # if data are passed - else: - input_pts = points["input"] - output_pts = points["target"] - loss = self.loss_data( - input_pts=input_pts.requires_grad_(), output_pts=output_pts - ) - # append loss - condition_loss[condition_name] = loss - # clamp unknown parameters in InverseProblem (if needed) - self._clamp_params() - return condition_loss + return self._loss_fn(residuals, torch.zeros_like(residuals)) def _clamp_inverse_problem_params(self): """ @@ -223,7 +214,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): :return: The loss function used for training. :rtype: torch.nn.Module """ - return self._loss + return self._loss_fn @property def current_condition_name(self): diff --git a/pina/solver/physics_informed_solver/rba_pinn.py b/pina/solver/physics_informed_solver/rba_pinn.py index feeb5c8..808ac5a 100644 --- a/pina/solver/physics_informed_solver/rba_pinn.py +++ b/pina/solver/physics_informed_solver/rba_pinn.py @@ -83,15 +83,15 @@ class RBAPINN(PINN): :param AbstractProblem problem: The problem to be solved. :param torch.nn.Module model: The neural network model to be used. :param Optimizer optimizer: The optimizer to be used. - If `None`, the :class:`torch.optim.Adam` optimizer is used. + If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. :param Scheduler scheduler: Learning rate scheduler. - If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR` + If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. - If `None`, no weighting schema is used. Default is ``None``. + If ``None``, no weighting schema is used. Default is ``None``. :param torch.nn.Module loss: The loss function to be minimized. - If `None`, the :class:`torch.nn.MSELoss` loss is used. + If ``None``, the :class:`torch.nn.MSELoss` loss is used. Default is `None`. :param float | int eta: The learning rate for the weights of the residuals. Default is ``0.001``. 
diff --git a/pina/solver/physics_informed_solver/self_adaptive_pinn.py b/pina/solver/physics_informed_solver/self_adaptive_pinn.py index 78dd1ce..9521556 100644 --- a/pina/solver/physics_informed_solver/self_adaptive_pinn.py +++ b/pina/solver/physics_informed_solver/self_adaptive_pinn.py @@ -120,24 +120,24 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface): :param torch.nn.Module weight_function: The Self-Adaptive mask model. Default is ``torch.nn.Sigmoid()``. :param Optimizer optimizer_model: The optimizer of the ``model``. - If `None`, the :class:`torch.optim.Adam` optimizer is used. + If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. :param Optimizer optimizer_weights: The optimizer of the ``weight_function``. - If `None`, the :class:`torch.optim.Adam` optimizer is used. + If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. :param Scheduler scheduler_model: Learning rate scheduler for the ``model``. - If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR` + If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param Scheduler scheduler_weights: Learning rate scheduler for the ``weight_function``. - If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR` + If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. - If `None`, no weighting schema is used. Default is ``None``. + If ``None``, no weighting schema is used. Default is ``None``. :param torch.nn.Module loss: The loss function to be minimized. - If `None`, the :class:`torch.nn.MSELoss` loss is used. + If ``None``, the :class:`torch.nn.MSELoss` loss is used. Default is `None`. """ # check consistency weitghs_function @@ -223,24 +223,6 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface): [self.scheduler_model.instance, self.scheduler_weights.instance], ) - def on_train_batch_end(self, outputs, batch, batch_idx): - """ - This method is called at the end of each training batch and overrides - the PyTorch Lightning implementation to log checkpoints. - - :param torch.Tensor outputs: The ``model``'s output for the current - batch. - :param list[tuple[str, dict]] batch: A batch of data. Each element is a - tuple containing a condition name and a dictionary of points. - :param int batch_idx: The index of the current batch. - """ - # increase by one the counter of optimization to save loggers - ( - self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed - ) += 1 - - return super().on_train_batch_end(outputs, batch, batch_idx) - def on_train_start(self): """ This method is called at the start of the training process to set the @@ -304,6 +286,21 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface): ) return self._vect_to_scalar(weights * loss_value) + def loss_data(self, input, target): + """ + Compute the data loss for the PINN solver by evaluating the loss + between the network's output and the true solution. This method should + not be overridden, if not intentionally. + + :param input: The input to the neural network. + :type input: LabelTensor + :param target: The target to compare with the network's output. + :type target: LabelTensor + :return: The supervised loss, averaged over the number of observations. 
+ :rtype: LabelTensor + """ + return self._loss_fn(self.forward(input), target) + def _vect_to_scalar(self, loss_value): """ Computation of the scalar loss. diff --git a/pina/solver/solver.py b/pina/solver/solver.py index 6776fea..99de1df 100644 --- a/pina/solver/solver.py +++ b/pina/solver/solver.py @@ -4,7 +4,7 @@ from abc import ABCMeta, abstractmethod import lightning import torch -from torch._dynamo.eval_frame import OptimizedModule +from torch._dynamo import OptimizedModule from ..problem import AbstractProblem from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler from ..loss import WeightingInterface @@ -29,7 +29,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): :param AbstractProblem problem: The problem to be solved. :param WeightingInterface weighting: The weighting schema to be used. - If `None`, no weighting schema is used. Default is ``None``. + If ``None``, no weighting schema is used. Default is ``None``. :param bool use_lt: If ``True``, the solver uses LabelTensors as input. """ super().__init__() @@ -64,18 +64,20 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): self._pina_optimizers = None self._pina_schedulers = None - def _check_solver_consistency(self, problem): + @abstractmethod + def forward(self, *args, **kwargs): """ - Check the consistency of the solver with the problem formulation. + Abstract method for the forward pass implementation. - :param AbstractProblem problem: The problem to be solved. + :param args: The input tensor. + :type args: torch.Tensor | LabelTensor | Data | Graph + :param dict kwargs: Additional keyword arguments. """ - for condition in problem.conditions.values(): - check_consistency(condition, self.accepted_conditions_types) - def _optimization_cycle(self, batch): + @abstractmethod + def optimization_cycle(self, batch): """ - Aggregate the loss for each condition in the batch. + The optimization cycle for the solvers. :param list[tuple[str, dict]] batch: A batch of data. Each element is a tuple containing a condition name and a dictionary of points. @@ -84,46 +86,58 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): containing the condition name and the associated scalar loss. :rtype: dict """ - losses = self.optimization_cycle(batch) - for name, value in losses.items(): - self.store_log( - f"{name}_loss", value.item(), self.get_batch_size(batch) - ) - loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor) - return loss - def training_step(self, batch): + def training_step(self, batch, **kwargs): """ - Solver training step. + Solver training step. It computes the optimization cycle and aggregates + the losses using the ``weighting`` attribute. :param list[tuple[str, dict]] batch: A batch of data. Each element is a tuple containing a condition name and a dictionary of points. + :param dict kwargs: Additional keyword arguments passed to + ``optimization_cycle``. :return: The loss of the training step. - :rtype: LabelTensor + :rtype: torch.Tensor """ - loss = self._optimization_cycle(batch=batch) + loss = self._optimization_cycle(batch=batch, **kwargs) self.store_log("train_loss", loss, self.get_batch_size(batch)) return loss - def validation_step(self, batch): + def validation_step(self, batch, **kwargs): """ - Solver validation step. + Solver validation step. It computes the optimization cycle and + averages the losses. No aggregation using the ``weighting`` attribute is + performed. 
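The distinction drawn here between training aggregation and validation/test averaging (implemented just below) amounts to the following; the per-condition loss values are made up for illustration, and the plain sum stands in for whatever ``weighting`` schema is configured::

    import torch

    losses = {"physics": torch.tensor(0.12), "data": torch.tensor(0.03)}

    # validation/test: plain average over the conditions, no weighting schema
    val_loss = sum(losses.values()) / len(losses)

    # training: the configured weighting schema aggregates the per-condition
    # losses (shown here as a simple sum purely for illustration)
    train_loss = sum(losses.values())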
:param list[tuple[str, dict]] batch: A batch of data. Each element is a tuple containing a condition name and a dictionary of points. + :param dict kwargs: Additional keyword arguments passed to + ``optimization_cycle``. + :return: The loss of the training step. + :rtype: torch.Tensor """ - loss = self._optimization_cycle(batch=batch) + losses = self.optimization_cycle(batch=batch, **kwargs) + loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor) self.store_log("val_loss", loss, self.get_batch_size(batch)) + return loss - def test_step(self, batch): + def test_step(self, batch, **kwargs): """ - Solver test step. + Solver test step. It computes the optimization cycle and + averages the losses. No aggregation using the ``weighting`` attribute is + performed. :param list[tuple[str, dict]] batch: A batch of data. Each element is a tuple containing a condition name and a dictionary of points. + :param dict kwargs: Additional keyword arguments passed to + ``optimization_cycle``. + :return: The loss of the training step. + :rtype: torch.Tensor """ - loss = self._optimization_cycle(batch=batch) + losses = self.optimization_cycle(batch=batch, **kwargs) + loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor) self.store_log("test_loss", loss, self.get_batch_size(batch)) + return loss def store_log(self, name, value, batch_size): """ @@ -141,58 +155,118 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): **self.trainer.logging_kwargs, ) - @abstractmethod - def forward(self, *args, **kwargs): + def setup(self, stage): """ - Abstract method for the forward pass implementation. + This method is called at the start of the train and test process to + compile the model if the :class:`~pina.trainer.Trainer` + ``compile`` is ``True``. - :param args: The input tensor. - :type args: torch.Tensor | LabelTensor - :param dict kwargs: Additional keyword arguments. - """ - @abstractmethod - def optimization_cycle(self, batch): """ - The optimization cycle for the solvers. + if stage == "fit" and self.trainer.compile: + self._setup_compile() + if stage == "test" and ( + self.trainer.compile and not self._is_compiled() + ): + self._setup_compile() + return super().setup(stage) + + def _is_compiled(self): + """ + Check if the model is compiled. + + :return: ``True`` if the model is compiled, ``False`` otherwise. + :rtype: bool + """ + for model in self._pina_models: + if not isinstance(model, OptimizedModule): + return False + return True + + def _setup_compile(self): + """ + Compile all models in the solver using ``torch.compile``. + + This method iterates through each model stored in the solver + list and attempts to compile them for optimized execution. It supports + models of type `torch.nn.Module` and `torch.nn.ModuleDict`. For models + stored in a `ModuleDict`, each submodule is compiled individually. + Models on Apple Silicon (MPS) use the 'eager' backend, + while others use 'inductor'. + + :raises RuntimeError: If a model is neither `torch.nn.Module` + nor `torch.nn.ModuleDict`. + """ + for i, model in enumerate(self._pina_models): + if isinstance(model, torch.nn.ModuleDict): + for name, module in model.items(): + self._pina_models[i][name] = self._compile_modules(module) + elif isinstance(model, torch.nn.Module): + self._pina_models[i] = self._compile_modules(model) + else: + raise RuntimeError( + "Compilation available only for " + "torch.nn.Module or torch.nn.ModuleDict." 
+ ) + + def _check_solver_consistency(self, problem): + """ + Check the consistency of the solver with the problem formulation. + + :param AbstractProblem problem: The problem to be solved. + """ + for condition in problem.conditions.values(): + check_consistency(condition, self.accepted_conditions_types) + + def _optimization_cycle(self, batch, **kwargs): + """ + Aggregate the loss for each condition in the batch. :param list[tuple[str, dict]] batch: A batch of data. Each element is a tuple containing a condition name and a dictionary of points. + :param dict kwargs: Additional keyword arguments passed to + ``optimization_cycle``. :return: The losses computed for all conditions in the batch, casted to a subclass of :class:`torch.Tensor`. It should return a dict containing the condition name and the associated scalar loss. :rtype: dict """ + losses = self.optimization_cycle(batch) + for name, value in losses.items(): + self.store_log( + f"{name}_loss", value.item(), self.get_batch_size(batch) + ) + loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor) + return loss - @property - def problem(self): + @staticmethod + def _compile_modules(model): """ - The problem instance. + Perform the compilation of the model. - :return: The problem instance. - :rtype: :class:`~pina.problem.abstract_problem.AbstractProblem` - """ - return self._pina_problem + This method attempts to compile the given PyTorch model + using ``torch.compile`` to improve execution performance. The + backend is selected based on the device on which the model resides: + ``eager`` is used for MPS devices (Apple Silicon), and ``inductor`` + is used for all others. - @property - def use_lt(self): - """ - Using LabelTensors as input during training. + If compilation fails, the method prints the error and returns the + original, uncompiled model. - :return: The use_lt attribute. - :rtype: bool + :param torch.nn.Module model: The model to compile. + :raises Exception: If the compilation fails. + :return: The compiled model. + :rtype: torch.nn.Module """ - return self._use_lt - - @property - def weighting(self): - """ - The weighting schema. - - :return: The weighting schema. - :rtype: :class:`~pina.loss.weighting_interface.WeightingInterface` - """ - return self._pina_weighting + model_device = next(model.parameters()).device + try: + if model_device == torch.device("mps:0"): + model = torch.compile(model, backend="eager") + else: + model = torch.compile(model, backend="inductor") + except Exception as e: + print("Compilation failed, running in normal mode.:\n", e) + return model @staticmethod def get_batch_size(batch): @@ -232,62 +306,35 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): return TorchScheduler(torch.optim.lr_scheduler.ConstantLR) - def on_train_start(self): + @property + def problem(self): """ - This method is called at the start of the training process to compile - the model if the :class:`~pina.trainer.Trainer` ``compile`` is ``True``. - """ - super().on_train_start() - if self.trainer.compile: - self._compile_model() + The problem instance. - def on_test_start(self): + :return: The problem instance. + :rtype: :class:`~pina.problem.abstract_problem.AbstractProblem` """ - This method is called at the start of the test process to compile - the model if the :class:`~pina.trainer.Trainer` ``compile`` is ``True``. 
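The backend choice made in ``_compile_modules`` above follows this pattern; a standalone ``torch`` sketch, with the model and layer sizes made up for illustration::

    import torch

    model = torch.nn.Linear(8, 1)
    device = next(model.parameters()).device

    # MPS (Apple Silicon) lacks full inductor support, so fall back to eager
    backend = "eager" if device.type == "mps" else "inductor"
    try:
        model = torch.compile(model, backend=backend)
    except Exception as err:
        # keep the uncompiled model if compilation fails
        print("Compilation failed, running in normal mode:", err)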
- """ - super().on_train_start() - if self.trainer.compile and not self._check_already_compiled(): - self._compile_model() + return self._pina_problem - def _check_already_compiled(self): + @property + def use_lt(self): """ - Check if the model is already compiled. + Using LabelTensors as input during training. - :return: ``True`` if the model is already compiled, ``False`` otherwise. + :return: The use_lt attribute. :rtype: bool """ + return self._use_lt - models = self._pina_models - if len(models) == 1 and isinstance( - self._pina_models[0], torch.nn.ModuleDict - ): - models = list(self._pina_models.values()) - for model in models: - if not isinstance(model, (OptimizedModule, torch.nn.ModuleDict)): - return False - return True - - @staticmethod - def _perform_compilation(model): + @property + def weighting(self): """ - Perform the compilation of the model. + The weighting schema. - :param torch.nn.Module model: The model to compile. - :raises Exception: If the compilation fails. - :return: The compiled model. - :rtype: torch.nn.Module + :return: The weighting schema. + :rtype: :class:`~pina.loss.weighting_interface.WeightingInterface` """ - - model_device = next(model.parameters()).device - try: - if model_device == torch.device("mps:0"): - model = torch.compile(model, backend="eager") - else: - model = torch.compile(model, backend="inductor") - except Exception as e: - print("Compilation failed, running in normal mode.:\n", e) - return model + return self._pina_weighting class SingleSolverInterface(SolverInterface, metaclass=ABCMeta): @@ -310,13 +357,13 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta): :param AbstractProblem problem: The problem to be solved. :param torch.nn.Module model: The neural network model to be used. :param Optimizer optimizer: The optimizer to be used. - If `None`, the :class:`torch.optim.Adam` optimizer is + If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. :param Scheduler scheduler: The scheduler to be used. - If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR` + If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. - If `None`, no weighting schema is used. Default is ``None``. + If ``None``, no weighting schema is used. Default is ``None``. :param bool use_lt: If ``True``, the solver uses LabelTensors as input. """ if optimizer is None: @@ -344,12 +391,11 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta): Forward pass implementation. :param x: Input tensor. - :type x: torch.Tensor | LabelTensor + :type x: torch.Tensor | LabelTensor | Graph | Data :return: Solver solution. - :rtype: torch.Tensor | LabelTensor + :rtype: torch.Tensor | LabelTensor | Graph | Data """ - x = self.model(x) - return x + return self.model(x) def configure_optimizers(self): """ @@ -362,28 +408,6 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta): self.scheduler.hook(self.optimizer) return ([self.optimizer.instance], [self.scheduler.instance]) - def _compile_model(self): - """ - Compile the model. - """ - if isinstance(self._pina_models[0], torch.nn.ModuleDict): - self._compile_module_dict() - else: - self._compile_single_model() - - def _compile_module_dict(self): - """ - Compile the model if it is a :class:`torch.nn.ModuleDict`. 
- """ - for name, model in self._pina_models[0].items(): - self._pina_models[0][name] = self._perform_compilation(model) - - def _compile_single_model(self): - """ - Compile the model if it is a single :class:`torch.nn.Module`. - """ - self._pina_models[0] = self._perform_compilation(self._pina_models[0]) - @property def model(self): """ @@ -436,13 +460,13 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta): :param models: The neural network models to be used. :type model: list[torch.nn.Module] | tuple[torch.nn.Module] :param list[Optimizer] optimizers: The optimizers to be used. - If `None`, the :class:`torch.optim.Adam` optimizer is used for all + If ``None``, the :class:`torch.optim.Adam` optimizer is used for all models. Default is ``None``. :param list[Scheduler] schedulers: The schedulers to be used. - If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR` + If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used for all the models. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. - If `None`, no weighting schema is used. Default is ``None``. + If ``None``, no weighting schema is used. Default is ``None``. :param bool use_lt: If ``True``, the solver uses LabelTensors as input. :raises ValueError: If the models are not a list or tuple with length greater than one. @@ -519,6 +543,22 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta): # http://lightning.ai/docs/pytorch/stable/model/manual_optimization.html self.automatic_optimization = False + def on_train_batch_end(self, outputs, batch, batch_idx): + """ + This method is called at the end of each training batch and overrides + the PyTorch Lightning implementation to log checkpoints. + + :param torch.Tensor outputs: The ``model``'s output for the current + batch. + :param list[tuple[str, dict]] batch: A batch of data. Each element is a + tuple containing a condition name and a dictionary of points. + :param int batch_idx: The index of the current batch. + """ + # increase by one the counter of optimization to save loggers + epoch_loop = self.trainer.fit_loop.epoch_loop + epoch_loop.manual_optimization.optim_step_progress.total.completed += 1 + return super().on_train_batch_end(outputs, batch, batch_idx) + def configure_optimizers(self): """ Optimizer configuration for the solver. @@ -537,14 +577,6 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta): [scheduler.instance for scheduler in self.schedulers], ) - def _compile_model(self): - """ - Compile the model. - """ - for i, model in enumerate(self._pina_models): - if not isinstance(model, torch.nn.ModuleDict): - self._pina_models[i] = self._perform_compilation(model) - @property def models(self): """ diff --git a/pina/solver/supervised.py b/pina/solver/supervised.py deleted file mode 100644 index 9a5a5f4..0000000 --- a/pina/solver/supervised.py +++ /dev/null @@ -1,132 +0,0 @@ -"""Module for the Supervised solver.""" - -import torch -from torch.nn.modules.loss import _Loss -from .solver import SingleSolverInterface -from ..utils import check_consistency -from ..loss.loss_interface import LossInterface -from ..condition import InputTargetCondition - - -class SupervisedSolver(SingleSolverInterface): - r""" - Supervised Solver solver class. This class implements a Supervised Solver, - using a user specified ``model`` to solve a specific ``problem``. 
- - The Supervised Solver class aims to find a map between the input - :math:`\mathbf{s}:\Omega\rightarrow\mathbb{R}^m` and the output - :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`. - - Given a model :math:`\mathcal{M}`, the following loss function is - minimized during training: - - .. math:: - \mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N - \mathcal{L}(\mathbf{u}_i - \mathcal{M}(\mathbf{v}_i)), - - where :math:`\mathcal{L}` is a specific loss function, typically the MSE: - - .. math:: - \mathcal{L}(v) = \| v \|^2_2. - - In this context, :math:`\mathbf{u}_i` and :math:`\mathbf{v}_i` indicates - the will to approximate multiple (discretised) functions given multiple - (discretised) input functions. - """ - - accepted_conditions_types = InputTargetCondition - - def __init__( - self, - problem, - model, - loss=None, - optimizer=None, - scheduler=None, - weighting=None, - use_lt=True, - ): - """ - Initialization of the :class:`SupervisedSolver` class. - - :param AbstractProblem problem: The problem to be solved. - :param torch.nn.Module model: The neural network model to be used. - :param torch.nn.Module loss: The loss function to be minimized. - If `None`, the :class:`torch.nn.MSELoss` loss is used. - Default is `None`. - :param Optimizer optimizer: The optimizer to be used. - If `None`, the :class:`torch.optim.Adam` optimizer is used. - Default is ``None``. - :param Scheduler scheduler: Learning rate scheduler. - If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR` - scheduler is used. Default is ``None``. - :param WeightingInterface weighting: The weighting schema to be used. - If `None`, no weighting schema is used. Default is ``None``. - :param bool use_lt: If ``True``, the solver uses LabelTensors as input. - Default is ``True``. - """ - if loss is None: - loss = torch.nn.MSELoss() - - super().__init__( - model=model, - problem=problem, - optimizer=optimizer, - scheduler=scheduler, - weighting=weighting, - use_lt=use_lt, - ) - - # check consistency - check_consistency( - loss, (LossInterface, _Loss, torch.nn.Module), subclass=False - ) - self._loss = loss - - def optimization_cycle(self, batch): - """ - The optimization cycle for the solvers. - - :param list[tuple[str, dict]] batch: A batch of data. Each element is a - tuple containing a condition name and a dictionary of points. - :return: The losses computed for all conditions in the batch, casted - to a subclass of :class:`torch.Tensor`. It should return a dict - containing the condition name and the associated scalar loss. - :rtype: dict - """ - condition_loss = {} - for condition_name, points in batch: - input_pts, output_pts = ( - points["input"], - points["target"], - ) - condition_loss[condition_name] = self.loss_data( - input_pts=input_pts, output_pts=output_pts - ) - return condition_loss - - def loss_data(self, input_pts, output_pts): - """ - Compute the data loss for the Supervised solver by evaluating the loss - between the network's output and the true solution. This method should - not be overridden, if not intentionally. - - :param input_pts: The input points to the neural network. - :type input_pts: LabelTensor | torch.Tensor - :param output_pts: The true solution to compare with the network's - output. - :type output_pts: LabelTensor | torch.Tensor - :return: The supervised loss, averaged over the number of observations. - :rtype: torch.Tensor - """ - return self._loss(self.forward(input_pts), output_pts) - - @property - def loss(self): - """ - The loss function to be minimized. 
- - :return: The loss function to be minimized. - :rtype: torch.nn.Module - """ - return self._loss diff --git a/pina/solver/supervised_solver/__init__.py b/pina/solver/supervised_solver/__init__.py new file mode 100644 index 0000000..f681d2d --- /dev/null +++ b/pina/solver/supervised_solver/__init__.py @@ -0,0 +1,11 @@ +"""Module for the Supervised solvers.""" + +__all__ = [ + "SupervisedSolverInterface", + "SupervisedSolver", + "ReducedOrderModelSolver", +] + +from .supervised_solver_interface import SupervisedSolverInterface +from .supervised import SupervisedSolver +from .reduced_order_model import ReducedOrderModelSolver diff --git a/pina/solver/reduced_order_model.py b/pina/solver/supervised_solver/reduced_order_model.py similarity index 83% rename from pina/solver/reduced_order_model.py rename to pina/solver/supervised_solver/reduced_order_model.py index 949cb01..727f438 100644 --- a/pina/solver/reduced_order_model.py +++ b/pina/solver/supervised_solver/reduced_order_model.py @@ -1,10 +1,11 @@ """Module for the Reduced Order Model solver""" import torch -from .supervised import SupervisedSolver +from .supervised_solver_interface import SupervisedSolverInterface +from ..solver import SingleSolverInterface -class ReducedOrderModelSolver(SupervisedSolver): +class ReducedOrderModelSolver(SupervisedSolverInterface, SingleSolverInterface): r""" Reduced Order Model solver class. This class implements the Reduced Order Model solver, using user specified ``reduction_network`` and @@ -50,6 +51,14 @@ class ReducedOrderModelSolver(SupervisedSolver): Journal of Computational Physics 363 (2018): 55-78. DOI `10.1016/j.jcp.2018.02.037 `_. + + Pichi, Federico, Beatriz Moya, and Jan S. + Hesthaven. + *A graph convolutional autoencoder approach to model order reduction + for parametrized PDEs.* + Journal of Computational Physics 501 (2024): 112762. + DOI `10.1016/j.jcp.2024.112762 + `_. .. note:: The specified ``reduction_network`` must contain two methods, namely @@ -63,15 +72,6 @@ class ReducedOrderModelSolver(SupervisedSolver): ``reduction_network`` and ``interpolation_network`` are trained simultaneously. For reference on this trainig strategy look at the following: - - ..seealso:: - **Original reference**: Pichi, Federico, Beatriz Moya, and Jan S. - Hesthaven. - *A graph convolutional autoencoder approach to model order reduction - for parametrized PDEs.* - Journal of Computational Physics 501 (2024): 112762. - DOI `10.1016/j.jcp.2024.112762 - `_. .. warning:: This solver works only for data-driven model. Hence in the ``problem`` @@ -102,16 +102,16 @@ class ReducedOrderModelSolver(SupervisedSolver): for interpolating the control parameters to latent space obtained by the ``reduction_network`` encoding. :param torch.nn.Module loss: The loss function to be minimized. - If `None`, the :class:`torch.nn.MSELoss` loss is used. + If ``None``, the :class:`torch.nn.MSELoss` loss is used. Default is `None`. :param Optimizer optimizer: The optimizer to be used. - If `None`, the :class:`torch.optim.Adam` optimizer is used. + If ``None``, the :class:`torch.optim.Adam` optimizer is used. Default is ``None``. :param Scheduler scheduler: Learning rate scheduler. - If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR` + If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used. Default is ``None``. :param WeightingInterface weighting: The weighting schema to be used. - If `None`, no weighting schema is used. Default is ``None``. + If ``None``, no weighting schema is used. 
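# Hedged sketch of the reduction/interpolation pair described in the docstring
# above. The ``encode``/``decode`` methods and the ``reduction_network``/
# ``interpolation_network`` keywords follow that docstring; layer sizes, the
# latent dimension and the (commented-out) problem are illustrative
# assumptions only.
import torch
from pina.solver import ReducedOrderModelSolver


class ToyAutoencoder(torch.nn.Module):
    # minimal reduction network exposing the required encode/decode methods
    def __init__(self, field_dim=64, latent_dim=4):
        super().__init__()
        self.encoder = torch.nn.Linear(field_dim, latent_dim)
        self.decoder = torch.nn.Linear(latent_dim, field_dim)

    def encode(self, snapshot):
        return self.encoder(snapshot)

    def decode(self, latent):
        return self.decoder(latent)


interpolation = torch.nn.Linear(2, 4)  # maps parameters to latent coordinates
# solver = ReducedOrderModelSolver(
#     problem=snapshot_problem,  # hypothetical data-driven (InputTargetCondition) problem
#     reduction_network=ToyAutoencoder(),
#     interpolation_network=interpolation,
#     use_lt=False,
# )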
Default is ``None``. :param bool use_lt: If ``True``, the solver uses LabelTensors as input. Default is ``True``. """ @@ -153,39 +153,38 @@ class ReducedOrderModelSolver(SupervisedSolver): of the ``interpolation_network`` on the input, and maps it to output space by calling the decode methode of the ``reduction_network``. - :param x: Input tensor. - :type x: torch.Tensor | LabelTensor - :return: Solver solution. - :rtype: torch.Tensor | LabelTensor + :param x: The input to the neural network. + :type x: LabelTensor | torch.Tensor | Graph | Data + :return: The solver solution. + :rtype: LabelTensor | torch.Tensor | Graph | Data """ reduction_network = self.model["reduction_network"] interpolation_network = self.model["interpolation_network"] return reduction_network.decode(interpolation_network(x)) - def loss_data(self, input_pts, output_pts): + def loss_data(self, input, target): """ Compute the data loss by evaluating the loss between the network's output and the true solution. This method should not be overridden, if not intentionally. - :param LabelTensor input_pts: The input points to the neural network. - :param LabelTensor output_pts: The true solution to compare with the - network's output. + :param input: The input to the neural network. + :type input: LabelTensor | torch.Tensor | Graph | Data + :param target: The target to compare with the network's output. + :type target: LabelTensor | torch.Tensor | Graph | Data :return: The supervised loss, averaged over the number of observations. - :rtype: torch.Tensor + :rtype: LabelTensor | torch.Tensor | Graph | Data """ # extract networks reduction_network = self.model["reduction_network"] interpolation_network = self.model["interpolation_network"] # encoded representations loss - encode_repr_inter_net = interpolation_network(input_pts) - encode_repr_reduction_network = reduction_network.encode(output_pts) - loss_encode = self.loss( + encode_repr_inter_net = interpolation_network(input) + encode_repr_reduction_network = reduction_network.encode(target) + loss_encode = self._loss_fn( encode_repr_inter_net, encode_repr_reduction_network ) # reconstruction loss - loss_reconstruction = self.loss( - reduction_network.decode(encode_repr_reduction_network), output_pts - ) - + decode = reduction_network.decode(encode_repr_reduction_network) + loss_reconstruction = self._loss_fn(decode, target) return loss_encode + loss_reconstruction diff --git a/pina/solver/supervised_solver/supervised.py b/pina/solver/supervised_solver/supervised.py new file mode 100644 index 0000000..70cd8fe --- /dev/null +++ b/pina/solver/supervised_solver/supervised.py @@ -0,0 +1,85 @@ +"""Module for the Supervised solver.""" + +from .supervised_solver_interface import SupervisedSolverInterface +from ..solver import SingleSolverInterface + + +class SupervisedSolver(SupervisedSolverInterface, SingleSolverInterface): + r""" + Supervised Solver solver class. This class implements a Supervised Solver, + using a user specified ``model`` to solve a specific ``problem``. + + The Supervised Solver class aims to find a map between the input + :math:`\mathbf{s}:\Omega\rightarrow\mathbb{R}^m` and the output + :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`. + + Given a model :math:`\mathcal{M}`, the following loss function is + minimized during training: + + .. math:: + \mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N + \mathcal{L}(\mathbf{u}_i - \mathcal{M}(\mathbf{s}_i)), + + where :math:`\mathcal{L}` is a specific loss function, typically the MSE: + + .. 
math:: + \mathcal{L}(v) = \| v \|^2_2. + + In this context, :math:`\mathbf{u}_i` and :math:`\mathbf{s}_i` indicates + the will to approximate multiple (discretised) functions given multiple + (discretised) input functions. + """ + + def __init__( + self, + problem, + model, + loss=None, + optimizer=None, + scheduler=None, + weighting=None, + use_lt=True, + ): + """ + Initialization of the :class:`SupervisedSolver` class. + + :param AbstractProblem problem: The problem to be solved. + :param torch.nn.Module model: The neural network model to be used. + :param torch.nn.Module loss: The loss function to be minimized. + If ``None``, the :class:`torch.nn.MSELoss` loss is used. + Default is `None`. + :param Optimizer optimizer: The optimizer to be used. + If ``None``, the :class:`torch.optim.Adam` optimizer is used. + Default is ``None``. + :param Scheduler scheduler: Learning rate scheduler. + If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` + scheduler is used. Default is ``None``. + :param WeightingInterface weighting: The weighting schema to be used. + If ``None``, no weighting schema is used. Default is ``None``. + :param bool use_lt: If ``True``, the solver uses LabelTensors as input. + Default is ``True``. + """ + super().__init__( + model=model, + problem=problem, + loss=loss, + optimizer=optimizer, + scheduler=scheduler, + weighting=weighting, + use_lt=use_lt, + ) + + def loss_data(self, input, target): + """ + Compute the data loss for the Supervised solver by evaluating the loss + between the network's output and the true solution. This method should + not be overridden, if not intentionally. + + :param input: The input to the neural network. + :type input: LabelTensor | torch.Tensor | Graph | Data + :param target: The target to compare with the network's output. + :type target: LabelTensor | torch.Tensor | Graph | Data + :return: The supervised loss, averaged over the number of observations. + :rtype: LabelTensor | torch.Tensor | Graph | Data + """ + return self._loss_fn(self.forward(input), target) diff --git a/pina/solver/supervised_solver/supervised_solver_interface.py b/pina/solver/supervised_solver/supervised_solver_interface.py new file mode 100644 index 0000000..97070ce --- /dev/null +++ b/pina/solver/supervised_solver/supervised_solver_interface.py @@ -0,0 +1,90 @@ +"""Module for the Supervised solver interface.""" + +from abc import abstractmethod + +import torch + +from torch.nn.modules.loss import _Loss +from ..solver import SolverInterface +from ...utils import check_consistency +from ...loss.loss_interface import LossInterface +from ...condition import InputTargetCondition + + +class SupervisedSolverInterface(SolverInterface): + r""" + Base class for Supervised solvers. This class implements a Supervised Solver + , using a user specified ``model`` to solve a specific ``problem``. + + The ``SupervisedSolverInterface`` class can be used to define + Supervised solvers that work with one or multiple optimizers and/or models. + By default, it is compatible with problems defined by + :class:`~pina.problem.abstract_problem.AbstractProblem`, + and users can choose the problem type the solver is meant to address. + """ + + accepted_conditions_types = InputTargetCondition + + def __init__(self, loss=None, **kwargs): + """ + Initialization of the :class:`SupervisedSolver` class. + + :param AbstractProblem problem: The problem to be solved. + :param torch.nn.Module loss: The loss function to be minimized. + If ``None``, the :class:`torch.nn.MSELoss` loss is used. 
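# Minimal end-to-end sketch of the relocated SupervisedSolver, following the
# patterns used by the tests in this patch; data sizes and the two-epoch run
# are illustrative only.
import torch
from pina import Condition, LabelTensor
from pina.model import FeedForward
from pina.problem import AbstractProblem
from pina.solver import SupervisedSolver
from pina.trainer import Trainer


class DataProblem(AbstractProblem):
    input_variables = ["u_0", "u_1"]
    output_variables = ["u"]
    conditions = {
        "data": Condition(
            input=LabelTensor(torch.randn(20, 2), ["u_0", "u_1"]),
            target=LabelTensor(torch.randn(20, 1), ["u"]),
        )
    }


solver = SupervisedSolver(problem=DataProblem(), model=FeedForward(2, 1))
trainer = Trainer(
    solver=solver,
    max_epochs=2,
    accelerator="cpu",
    batch_size=None,
    train_size=1.0,
    val_size=0.0,
    test_size=0.0,
)
trainer.train()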
+ Default is `None`. + :param kwargs: Additional keyword arguments to be passed to the + :class:`~pina.solver.solver.SolverInterface` class. + """ + if loss is None: + loss = torch.nn.MSELoss() + + super().__init__(**kwargs) + + # check consistency + check_consistency(loss, (LossInterface, _Loss), subclass=False) + + # assign variables + self._loss_fn = loss + + def optimization_cycle(self, batch): + """ + The optimization cycle for the solvers. + + :param list[tuple[str, dict]] batch: A batch of data. Each element is a + tuple containing a condition name and a dictionary of points. + :return: The losses computed for all conditions in the batch, casted + to a subclass of :class:`torch.Tensor`. It should return a dict + containing the condition name and the associated scalar loss. + :rtype: dict + """ + condition_loss = {} + for condition_name, points in batch: + condition_loss[condition_name] = self.loss_data( + input=points["input"], target=points["target"] + ) + return condition_loss + + @abstractmethod + def loss_data(self, input, target): + """ + Compute the data loss for the Supervised. This method is abstract and + should be override by derived classes. + + :param input: The input to the neural network. + :type input: LabelTensor | torch.Tensor | Graph | Data + :param target: The target to compare with the network's output. + :type target: LabelTensor | torch.Tensor | Graph | Data + :return: The supervised loss, averaged over the number of observations. + :rtype: LabelTensor | torch.Tensor | Graph | Data + """ + + @property + def loss(self): + """ + The loss function to be minimized. + + :return: The loss function to be minimized. + :rtype: torch.nn.Module + """ + return self._loss_fn diff --git a/pina/utils.py b/pina/utils.py index 56b329b..e3126de 100644 --- a/pina/utils.py +++ b/pina/utils.py @@ -72,18 +72,22 @@ def labelize_forward(forward, input_variables, output_variables): :rtype: Callable """ - def wrapper(x): + def wrapper(x, *args, **kwargs): """ Decorated forward function. :param LabelTensor x: The labelized input of the forward pass of an instance of :class:`torch.nn.Module`. + :param Iterable args: Additional positional arguments passed to + ``forward`` method. + :param dict kwargs: Additional keyword arguments passed to + ``forward`` method. :return: The labelized output of the forward pass of an instance of :class:`torch.nn.Module`. :rtype: LabelTensor """ x = x.extract(input_variables) - output = forward(x) + output = forward(x, *args, **kwargs) # keep it like this, directly using LabelTensor(...) 
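# Sketch of the extended ``labelize_forward`` behaviour changed below in
# pina/utils.py: extra positional/keyword arguments are now forwarded to the
# wrapped ``forward``. The toy forward and its ``scale`` keyword are
# assumptions for the example, not part of the library.
import torch
from pina import LabelTensor
from pina.utils import labelize_forward


def forward(x, scale=1.0):
    return scale * x


wrapped = labelize_forward(
    forward, input_variables=["a", "b"], output_variables=["u", "v"]
)
x = LabelTensor(torch.rand(4, 2), ["a", "b"])
y = wrapped(x, scale=2.0)  # the keyword now reaches ``forward``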
raises errors # when compiling the code output = output.as_subclass(LabelTensor) diff --git a/tests/test_solver/test_causal_pinn.py b/tests/test_solver/test_causal_pinn.py index 4e72732..82e61ed 100644 --- a/tests/test_solver/test_causal_pinn.py +++ b/tests/test_solver/test_causal_pinn.py @@ -27,12 +27,12 @@ class DummySpatialProblem(SpatialProblem): # define problems problem = DiffusionReactionProblem() -problem.discretise_domain(50) +problem.discretise_domain(10) # add input-output condition to test supervised learning -input_pts = torch.rand(50, len(problem.input_variables)) +input_pts = torch.rand(10, len(problem.input_variables)) input_pts = LabelTensor(input_pts, problem.input_variables) -output_pts = torch.rand(50, len(problem.output_variables)) +output_pts = torch.rand(10, len(problem.output_variables)) output_pts = LabelTensor(output_pts, problem.output_variables) problem.conditions["data"] = Condition(input=input_pts, target=output_pts) diff --git a/tests/test_solver/test_competitive_pinn.py b/tests/test_solver/test_competitive_pinn.py index 64fb280..741390e 100644 --- a/tests/test_solver/test_competitive_pinn.py +++ b/tests/test_solver/test_competitive_pinn.py @@ -19,9 +19,9 @@ from torch._dynamo.eval_frame import OptimizedModule # define problems problem = Poisson() -problem.discretise_domain(50) +problem.discretise_domain(10) inverse_problem = InversePoisson() -inverse_problem.discretise_domain(50) +inverse_problem.discretise_domain(10) # reduce the number of data points to speed up testing data_condition = inverse_problem.conditions["data"] @@ -29,9 +29,9 @@ data_condition.input = data_condition.input[:10] data_condition.target = data_condition.target[:10] # add input-output condition to test supervised learning -input_pts = torch.rand(50, len(problem.input_variables)) +input_pts = torch.rand(10, len(problem.input_variables)) input_pts = LabelTensor(input_pts, problem.input_variables) -output_pts = torch.rand(50, len(problem.output_variables)) +output_pts = torch.rand(10, len(problem.output_variables)) output_pts = LabelTensor(output_pts, problem.output_variables) problem.conditions["data"] = Condition(input=input_pts, target=output_pts) diff --git a/tests/test_solver/test_ensemble_pinn.py b/tests/test_solver/test_ensemble_pinn.py new file mode 100644 index 0000000..50669f0 --- /dev/null +++ b/tests/test_solver/test_ensemble_pinn.py @@ -0,0 +1,149 @@ +import pytest +import torch + +from pina import LabelTensor, Condition +from pina.model import FeedForward +from pina.trainer import Trainer +from pina.solver import DeepEnsemblePINN +from pina.condition import ( + InputTargetCondition, + InputEquationCondition, + DomainEquationCondition, +) +from pina.problem.zoo import Poisson2DSquareProblem as Poisson +from torch._dynamo.eval_frame import OptimizedModule + + +# define problems +problem = Poisson() +problem.discretise_domain(10) + +# add input-output condition to test supervised learning +input_pts = torch.rand(10, len(problem.input_variables)) +input_pts = LabelTensor(input_pts, problem.input_variables) +output_pts = torch.rand(10, len(problem.output_variables)) +output_pts = LabelTensor(output_pts, problem.output_variables) +problem.conditions["data"] = Condition(input=input_pts, target=output_pts) + +# define models +models = [ + FeedForward( + len(problem.input_variables), len(problem.output_variables), n_layers=1 + ) + for _ in range(5) +] + + +def test_constructor(): + solver = DeepEnsemblePINN(problem=problem, models=models) + + assert solver.accepted_conditions_types 
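# Compact sketch of the new DeepEnsemblePINN on the Poisson problem, mirroring
# the test setup around this hunk; the five one-layer networks and the short
# run are illustrative. The last lines show the kind of ensemble statistics
# such a solver enables, computed here directly from the member networks.
import torch
from pina.model import FeedForward
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
from pina.solver import DeepEnsemblePINN
from pina.trainer import Trainer

problem = Poisson()
problem.discretise_domain(10)
models = [
    FeedForward(
        len(problem.input_variables), len(problem.output_variables), n_layers=1
    )
    for _ in range(5)
]
solver = DeepEnsemblePINN(problem=problem, models=models)
Trainer(
    solver=solver, max_epochs=2, accelerator="cpu", batch_size=None,
    train_size=1.0, val_size=0.0, test_size=0.0,
).train()

pts = torch.rand(20, 2)
member_preds = torch.stack([m(pts).as_subclass(torch.Tensor) for m in models])
mean, std = member_preds.mean(dim=0), member_preds.std(dim=0)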
== ( + InputTargetCondition, + InputEquationCondition, + DomainEquationCondition, + ) + assert solver.num_ensemble == 5 + + +@pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) +@pytest.mark.parametrize("compile", [True, False]) +def test_solver_train(batch_size, compile): + solver = DeepEnsemblePINN(models=models, problem=problem) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=batch_size, + train_size=1.0, + val_size=0.0, + test_size=0.0, + compile=compile, + ) + trainer.train() + if trainer.compile: + assert all( + [isinstance(model, OptimizedModule) for model in solver.models] + ) + + +@pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) +@pytest.mark.parametrize("compile", [True, False]) +def test_solver_validation(batch_size, compile): + solver = DeepEnsemblePINN(models=models, problem=problem) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=batch_size, + train_size=0.9, + val_size=0.1, + test_size=0.0, + compile=compile, + ) + trainer.train() + if trainer.compile: + assert all( + [isinstance(model, OptimizedModule) for model in solver.models] + ) + + +@pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) +@pytest.mark.parametrize("compile", [True, False]) +def test_solver_test(batch_size, compile): + solver = DeepEnsemblePINN(models=models, problem=problem) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=batch_size, + train_size=0.7, + val_size=0.2, + test_size=0.1, + compile=compile, + ) + trainer.test() + if trainer.compile: + assert all( + [isinstance(model, OptimizedModule) for model in solver.models] + ) + + +def test_train_load_restore(): + dir = "tests/test_solver/tmp" + solver = DeepEnsemblePINN(models=models, problem=problem) + trainer = Trainer( + solver=solver, + max_epochs=5, + accelerator="cpu", + batch_size=None, + train_size=0.7, + val_size=0.2, + test_size=0.1, + default_root_dir=dir, + ) + trainer.train() + + # restore + new_trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu") + new_trainer.train( + ckpt_path=f"{dir}/lightning_logs/version_0/checkpoints/" + + "epoch=4-step=5.ckpt" + ) + + # loading + new_solver = DeepEnsemblePINN.load_from_checkpoint( + f"{dir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt", + problem=problem, + models=models, + ) + + test_pts = LabelTensor(torch.rand(20, 2), problem.input_variables) + assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape + torch.testing.assert_close( + new_solver.forward(test_pts), solver.forward(test_pts) + ) + + # rm directories + import shutil + + shutil.rmtree("tests/test_solver/tmp") diff --git a/tests/test_solver/test_ensemble_supervised_solver.py b/tests/test_solver/test_ensemble_supervised_solver.py new file mode 100644 index 0000000..45d853f --- /dev/null +++ b/tests/test_solver/test_ensemble_supervised_solver.py @@ -0,0 +1,275 @@ +import torch +import pytest +from torch._dynamo.eval_frame import OptimizedModule +from torch_geometric.nn import GCNConv +from pina import Condition, LabelTensor +from pina.condition import InputTargetCondition +from pina.problem import AbstractProblem +from pina.solver import DeepEnsembleSupervisedSolver +from pina.model import FeedForward +from pina.trainer import Trainer +from pina.graph import KNNGraph + + +class LabelTensorProblem(AbstractProblem): + input_variables = ["u_0", "u_1"] + output_variables = ["u"] + conditions = { + "data": Condition( + input=LabelTensor(torch.randn(20, 2), ["u_0", 
"u_1"]), + target=LabelTensor(torch.randn(20, 1), ["u"]), + ), + } + + +class TensorProblem(AbstractProblem): + input_variables = ["u_0", "u_1"] + output_variables = ["u"] + conditions = { + "data": Condition(input=torch.randn(20, 2), target=torch.randn(20, 1)) + } + + +x = torch.rand((15, 20, 5)) +pos = torch.rand((15, 20, 2)) +output_ = torch.rand((15, 20, 1)) +input_ = [ + KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True) + for x_, pos_ in zip(x, pos) +] + + +class GraphProblem(AbstractProblem): + output_variables = None + conditions = {"data": Condition(input=input_, target=output_)} + + +x = LabelTensor(torch.rand((15, 20, 5)), ["a", "b", "c", "d", "e"]) +pos = LabelTensor(torch.rand((15, 20, 2)), ["x", "y"]) +output_ = LabelTensor(torch.rand((15, 20, 1)), ["u"]) +input_ = [ + KNNGraph(x=x[i], pos=pos[i], neighbours=3, edge_attr=True) + for i in range(len(x)) +] + + +class GraphProblemLT(AbstractProblem): + output_variables = ["u"] + input_variables = ["a", "b", "c", "d", "e"] + conditions = {"data": Condition(input=input_, target=output_)} + + +models = [FeedForward(2, 1) for i in range(10)] + + +class Models(torch.nn.Module): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.lift = torch.nn.Linear(5, 10) + self.activation = torch.nn.Tanh() + self.output = torch.nn.Linear(10, 1) + + self.conv = GCNConv(10, 10) + + def forward(self, batch): + + x = batch.x + edge_index = batch.edge_index + for _ in range(1): + y = self.lift(x) + y = self.activation(y) + y = self.conv(y, edge_index) + y = self.activation(y) + y = self.output(y) + return y + + +graph_models = [Models() for i in range(10)] + + +def test_constructor(): + solver = DeepEnsembleSupervisedSolver( + problem=TensorProblem(), models=models + ) + DeepEnsembleSupervisedSolver(problem=LabelTensorProblem(), models=models) + assert DeepEnsembleSupervisedSolver.accepted_conditions_types == ( + InputTargetCondition + ) + assert solver.num_ensemble == 10 + + +@pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) +@pytest.mark.parametrize("use_lt", [True, False]) +@pytest.mark.parametrize("compile", [True, False]) +def test_solver_train(use_lt, batch_size, compile): + problem = LabelTensorProblem() if use_lt else TensorProblem() + solver = DeepEnsembleSupervisedSolver( + problem=problem, models=models, use_lt=use_lt + ) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=batch_size, + train_size=1.0, + test_size=0.0, + val_size=0.0, + compile=compile, + ) + + trainer.train() + if trainer.compile: + assert all( + [isinstance(model, OptimizedModule) for model in solver.models] + ) + + +@pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) +@pytest.mark.parametrize("use_lt", [True, False]) +def test_solver_train_graph(batch_size, use_lt): + problem = GraphProblemLT() if use_lt else GraphProblem() + solver = DeepEnsembleSupervisedSolver( + problem=problem, models=graph_models, use_lt=use_lt + ) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=batch_size, + train_size=1.0, + test_size=0.0, + val_size=0.0, + ) + + trainer.train() + + +@pytest.mark.parametrize("use_lt", [True, False]) +@pytest.mark.parametrize("compile", [True, False]) +def test_solver_validation(use_lt, compile): + problem = LabelTensorProblem() if use_lt else TensorProblem() + solver = DeepEnsembleSupervisedSolver( + problem=problem, models=models, use_lt=use_lt + ) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=None, 
+ train_size=0.9, + val_size=0.1, + test_size=0.0, + compile=compile, + ) + trainer.train() + if trainer.compile: + assert all( + [isinstance(model, OptimizedModule) for model in solver.models] + ) + + +@pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) +@pytest.mark.parametrize("use_lt", [True, False]) +def test_solver_validation_graph(batch_size, use_lt): + problem = GraphProblemLT() if use_lt else GraphProblem() + solver = DeepEnsembleSupervisedSolver( + problem=problem, models=graph_models, use_lt=use_lt + ) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=batch_size, + train_size=0.9, + val_size=0.1, + test_size=0.0, + ) + + trainer.train() + + +@pytest.mark.parametrize("use_lt", [True, False]) +@pytest.mark.parametrize("compile", [True, False]) +def test_solver_test(use_lt, compile): + problem = LabelTensorProblem() if use_lt else TensorProblem() + solver = DeepEnsembleSupervisedSolver( + problem=problem, models=models, use_lt=use_lt + ) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=None, + train_size=0.8, + val_size=0.1, + test_size=0.1, + compile=compile, + ) + trainer.test() + if trainer.compile: + assert all( + [isinstance(model, OptimizedModule) for model in solver.models] + ) + + +@pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) +@pytest.mark.parametrize("use_lt", [True, False]) +def test_solver_test_graph(batch_size, use_lt): + problem = GraphProblemLT() if use_lt else GraphProblem() + solver = DeepEnsembleSupervisedSolver( + problem=problem, models=graph_models, use_lt=use_lt + ) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=batch_size, + train_size=0.8, + val_size=0.1, + test_size=0.1, + ) + + trainer.test() + + +def test_train_load_restore(): + dir = "tests/test_solver/tmp/" + problem = LabelTensorProblem() + solver = DeepEnsembleSupervisedSolver(problem=problem, models=models) + trainer = Trainer( + solver=solver, + max_epochs=5, + accelerator="cpu", + batch_size=None, + train_size=0.9, + test_size=0.1, + val_size=0.0, + default_root_dir=dir, + ) + trainer.train() + + # restore + new_trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu") + new_trainer.train( + ckpt_path=f"{dir}/lightning_logs/version_0/checkpoints/" + + "epoch=4-step=5.ckpt" + ) + + # loading + new_solver = DeepEnsembleSupervisedSolver.load_from_checkpoint( + f"{dir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt", + problem=problem, + models=models, + ) + + test_pts = LabelTensor(torch.rand(20, 2), problem.input_variables) + assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape + torch.testing.assert_close( + new_solver.forward(test_pts), solver.forward(test_pts) + ) + + # rm directories + import shutil + + shutil.rmtree("tests/test_solver/tmp") diff --git a/tests/test_solver/test_garom.py b/tests/test_solver/test_garom.py index ed147c8..6257582 100644 --- a/tests/test_solver/test_garom.py +++ b/tests/test_solver/test_garom.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn import pytest -from pina import Condition, LabelTensor +from pina import Condition from pina.solver import GAROM from pina.condition import InputTargetCondition from pina.problem import AbstractProblem @@ -15,7 +15,7 @@ class TensorProblem(AbstractProblem): input_variables = ["u_0", "u_1"] output_variables = ["u"] conditions = { - "data": Condition(target=torch.randn(50, 2), input=torch.randn(50, 1)) + "data": Condition(target=torch.randn(10, 2), 
input=torch.randn(10, 1)) } diff --git a/tests/test_solver/test_gradient_pinn.py b/tests/test_solver/test_gradient_pinn.py index 31666db..6e6c76c 100644 --- a/tests/test_solver/test_gradient_pinn.py +++ b/tests/test_solver/test_gradient_pinn.py @@ -30,9 +30,9 @@ class DummyTimeProblem(TimeDependentProblem): # define problems problem = Poisson() -problem.discretise_domain(50) +problem.discretise_domain(10) inverse_problem = InversePoisson() -inverse_problem.discretise_domain(50) +inverse_problem.discretise_domain(10) # reduce the number of data points to speed up testing data_condition = inverse_problem.conditions["data"] @@ -40,9 +40,9 @@ data_condition.input = data_condition.input[:10] data_condition.target = data_condition.target[:10] # add input-output condition to test supervised learning -input_pts = torch.rand(50, len(problem.input_variables)) +input_pts = torch.rand(10, len(problem.input_variables)) input_pts = LabelTensor(input_pts, problem.input_variables) -output_pts = torch.rand(50, len(problem.output_variables)) +output_pts = torch.rand(10, len(problem.output_variables)) output_pts = LabelTensor(output_pts, problem.output_variables) problem.conditions["data"] = Condition(input=input_pts, target=output_pts) diff --git a/tests/test_solver/test_pinn.py b/tests/test_solver/test_pinn.py index 97511cb..ee501d8 100644 --- a/tests/test_solver/test_pinn.py +++ b/tests/test_solver/test_pinn.py @@ -19,9 +19,9 @@ from torch._dynamo.eval_frame import OptimizedModule # define problems problem = Poisson() -problem.discretise_domain(50) +problem.discretise_domain(10) inverse_problem = InversePoisson() -inverse_problem.discretise_domain(50) +inverse_problem.discretise_domain(10) # reduce the number of data points to speed up testing data_condition = inverse_problem.conditions["data"] @@ -29,9 +29,9 @@ data_condition.input = data_condition.input[:10] data_condition.target = data_condition.target[:10] # add input-output condition to test supervised learning -input_pts = torch.rand(50, len(problem.input_variables)) +input_pts = torch.rand(10, len(problem.input_variables)) input_pts = LabelTensor(input_pts, problem.input_variables) -output_pts = torch.rand(50, len(problem.output_variables)) +output_pts = torch.rand(10, len(problem.output_variables)) output_pts = LabelTensor(output_pts, problem.output_variables) problem.conditions["data"] = Condition(input=input_pts, target=output_pts) diff --git a/tests/test_solver/test_rba_pinn.py b/tests/test_solver/test_rba_pinn.py index f355aab..8eaf340 100644 --- a/tests/test_solver/test_rba_pinn.py +++ b/tests/test_solver/test_rba_pinn.py @@ -18,9 +18,9 @@ from torch._dynamo.eval_frame import OptimizedModule # define problems problem = Poisson() -problem.discretise_domain(50) +problem.discretise_domain(10) inverse_problem = InversePoisson() -inverse_problem.discretise_domain(50) +inverse_problem.discretise_domain(10) # reduce the number of data points to speed up testing data_condition = inverse_problem.conditions["data"] @@ -28,9 +28,9 @@ data_condition.input = data_condition.input[:10] data_condition.target = data_condition.target[:10] # add input-output condition to test supervised learning -input_pts = torch.rand(50, len(problem.input_variables)) +input_pts = torch.rand(10, len(problem.input_variables)) input_pts = LabelTensor(input_pts, problem.input_variables) -output_pts = torch.rand(50, len(problem.output_variables)) +output_pts = torch.rand(10, len(problem.output_variables)) output_pts = LabelTensor(output_pts, problem.output_variables) 
problem.conditions["data"] = Condition(input=input_pts, target=output_pts) diff --git a/tests/test_solver/test_self_adaptive_pinn.py b/tests/test_solver/test_self_adaptive_pinn.py index 48e3d9f..aba43da 100644 --- a/tests/test_solver/test_self_adaptive_pinn.py +++ b/tests/test_solver/test_self_adaptive_pinn.py @@ -19,9 +19,9 @@ from torch._dynamo.eval_frame import OptimizedModule # define problems problem = Poisson() -problem.discretise_domain(50) +problem.discretise_domain(10) inverse_problem = InversePoisson() -inverse_problem.discretise_domain(50) +inverse_problem.discretise_domain(10) # reduce the number of data points to speed up testing data_condition = inverse_problem.conditions["data"] @@ -29,9 +29,9 @@ data_condition.input = data_condition.input[:10] data_condition.target = data_condition.target[:10] # add input-output condition to test supervised learning -input_pts = torch.rand(50, len(problem.input_variables)) +input_pts = torch.rand(10, len(problem.input_variables)) input_pts = LabelTensor(input_pts, problem.input_variables) -output_pts = torch.rand(50, len(problem.output_variables)) +output_pts = torch.rand(10, len(problem.output_variables)) output_pts = LabelTensor(output_pts, problem.output_variables) problem.conditions["data"] = Condition(input=input_pts, target=output_pts) diff --git a/tests/test_solver/test_supervised_solver.py b/tests/test_solver/test_supervised_solver.py index 30ae080..7578ace 100644 --- a/tests/test_solver/test_supervised_solver.py +++ b/tests/test_solver/test_supervised_solver.py @@ -30,9 +30,9 @@ class TensorProblem(AbstractProblem): } -x = torch.rand((100, 20, 5)) -pos = torch.rand((100, 20, 2)) -output_ = torch.rand((100, 20, 1)) +x = torch.rand((15, 20, 5)) +pos = torch.rand((15, 20, 2)) +output_ = torch.rand((15, 20, 1)) input_ = [ KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True) for x_, pos_ in zip(x, pos) @@ -44,9 +44,9 @@ class GraphProblem(AbstractProblem): conditions = {"data": Condition(input=input_, target=output_)} -x = LabelTensor(torch.rand((100, 20, 5)), ["a", "b", "c", "d", "e"]) -pos = LabelTensor(torch.rand((100, 20, 2)), ["x", "y"]) -output_ = LabelTensor(torch.rand((100, 20, 1)), ["u"]) +x = LabelTensor(torch.rand((15, 20, 5)), ["a", "b", "c", "d", "e"]) +pos = LabelTensor(torch.rand((15, 20, 2)), ["x", "y"]) +output_ = LabelTensor(torch.rand((15, 20, 1)), ["u"]) input_ = [ KNNGraph(x=x[i], pos=pos[i], neighbours=3, edge_attr=True) for i in range(len(x))
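# Sketch of the graph-data path exercised by the tests in this file: build
# KNNGraph samples and train a plain SupervisedSolver on them. The GNN layout
# mirrors the test model above; sizes and the one-epoch run are illustrative.
import torch
from pina import Condition
from pina.graph import KNNGraph
from pina.problem import AbstractProblem
from pina.solver import SupervisedSolver
from pina.trainer import Trainer
from torch_geometric.nn import GCNConv

x = torch.rand((15, 20, 5))
pos = torch.rand((15, 20, 2))
target = torch.rand((15, 20, 1))
graphs = [
    KNNGraph(x=xi, pos=pi, neighbours=3, edge_attr=True)
    for xi, pi in zip(x, pos)
]


class GraphData(AbstractProblem):
    output_variables = None
    conditions = {"data": Condition(input=graphs, target=target)}


class GNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lift = torch.nn.Linear(5, 10)
        self.conv = GCNConv(10, 10)
        self.out = torch.nn.Linear(10, 1)

    def forward(self, batch):
        y = torch.tanh(self.lift(batch.x))
        y = torch.tanh(self.conv(y, batch.edge_index))
        return self.out(y)


solver = SupervisedSolver(problem=GraphData(), model=GNN(), use_lt=False)
Trainer(
    solver=solver, max_epochs=1, accelerator="cpu", batch_size=5,
    train_size=1.0, val_size=0.0, test_size=0.0,
).train()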