From d4fa3ea9df780e6f94ab48a4615c7e25c3f4bbbb Mon Sep 17 00:00:00 2001 From: Giovanni Canali <115086358+GiovanniCanali@users.noreply.github.com> Date: Tue, 4 Nov 2025 10:56:27 +0100 Subject: [PATCH] move to method to the interface (#694) --- pina/equation/equation.py | 31 ----------------------------- pina/equation/equation_interface.py | 31 +++++++++++++++++++++++++++++ pina/equation/system_equation.py | 5 +++++ 3 files changed, 36 insertions(+), 31 deletions(-) diff --git a/pina/equation/equation.py b/pina/equation/equation.py index 677b0e5..057c6bc 100644 --- a/pina/equation/equation.py +++ b/pina/equation/equation.py @@ -1,7 +1,6 @@ """Module for the Equation.""" import inspect -import torch from .equation_interface import EquationInterface @@ -61,33 +60,3 @@ class Equation(EquationInterface): f"Unexpected number of arguments in equation: {self.__len_sig}. " "Expected either 2 (direct problem) or 3 (inverse problem)." ) - - def to(self, device): - """ - Move all tensor attributes of the Equation to the specified device. - - :param torch.device device: The target device to move the tensors to. - :return: The Equation instance moved to the specified device. - :rtype: Equation - """ - # Iterate over all attributes of the Equation - for key, val in self.__dict__.items(): - - # Move tensors in dictionaries to the specified device - if isinstance(val, dict): - self.__dict__[key] = { - k: v.to(device) if torch.is_tensor(v) else v - for k, v in val.items() - } - - # Move tensors in lists to the specified device - elif isinstance(val, list): - self.__dict__[key] = [ - v.to(device) if torch.is_tensor(v) else v for v in val - ] - - # Move tensor attributes to the specified device - elif torch.is_tensor(val): - self.__dict__[key] = val.to(device) - - return self diff --git a/pina/equation/equation_interface.py b/pina/equation/equation_interface.py index f1cc747..82b86db 100644 --- a/pina/equation/equation_interface.py +++ b/pina/equation/equation_interface.py @@ -1,6 +1,7 @@ """Module for the Equation Interface.""" from abc import ABCMeta, abstractmethod +import torch class EquationInterface(metaclass=ABCMeta): @@ -33,3 +34,33 @@ class EquationInterface(metaclass=ABCMeta): :return: The computed residual of the equation. :rtype: LabelTensor """ + + def to(self, device): + """ + Move all tensor attributes to the specified device. + + :param torch.device device: The target device to move the tensors to. + :return: The instance moved to the specified device. + :rtype: EquationInterface + """ + # Iterate over all attributes of the Equation + for key, val in self.__dict__.items(): + + # Move tensors in dictionaries to the specified device + if isinstance(val, dict): + self.__dict__[key] = { + k: v.to(device) if torch.is_tensor(v) else v + for k, v in val.items() + } + + # Move tensors in lists to the specified device + elif isinstance(val, list): + self.__dict__[key] = [ + v.to(device) if torch.is_tensor(v) else v for v in val + ] + + # Move tensor attributes to the specified device + elif torch.is_tensor(val): + self.__dict__[key] = val.to(device) + + return self diff --git a/pina/equation/system_equation.py b/pina/equation/system_equation.py index 21cb271..3e8550d 100644 --- a/pina/equation/system_equation.py +++ b/pina/equation/system_equation.py @@ -101,6 +101,10 @@ class SystemEquation(EquationInterface): :return: The aggregated residuals of the system of equations. :rtype: LabelTensor """ + # Move the equation to the input_ device + self.to(input_.device) + + # Compute the residual for each equation residual = torch.hstack( [ equation.residual(input_, output_, params_) @@ -108,6 +112,7 @@ class SystemEquation(EquationInterface): ] ) + # Skip reduction if not specified if self.reduction is None: return residual