diff --git a/pina/equation/equation.py b/pina/equation/equation.py index 1e4622d..677b0e5 100644 --- a/pina/equation/equation.py +++ b/pina/equation/equation.py @@ -1,7 +1,7 @@ """Module for the Equation.""" import inspect - +import torch from .equation_interface import EquationInterface @@ -49,6 +49,10 @@ class Equation(EquationInterface): :raises RuntimeError: If the underlying equation signature length is not 2 (direct problem) or 3 (inverse problem). """ + # Move the equation to the input_ device + self.to(input_.device) + + # Call the underlying equation based on its signature length if self.__len_sig == 2: return self.__equation(input_, output_) if self.__len_sig == 3: @@ -57,3 +61,33 @@ class Equation(EquationInterface): f"Unexpected number of arguments in equation: {self.__len_sig}. " "Expected either 2 (direct problem) or 3 (inverse problem)." ) + + def to(self, device): + """ + Move all tensor attributes of the Equation to the specified device. + + :param torch.device device: The target device to move the tensors to. + :return: The Equation instance moved to the specified device. + :rtype: Equation + """ + # Iterate over all attributes of the Equation + for key, val in self.__dict__.items(): + + # Move tensors in dictionaries to the specified device + if isinstance(val, dict): + self.__dict__[key] = { + k: v.to(device) if torch.is_tensor(v) else v + for k, v in val.items() + } + + # Move tensors in lists to the specified device + elif isinstance(val, list): + self.__dict__[key] = [ + v.to(device) if torch.is_tensor(v) else v for v in val + ] + + # Move tensor attributes to the specified device + elif torch.is_tensor(val): + self.__dict__[key] = val.to(device) + + return self