move equation attributes to correct device

This commit is contained in:
GiovanniCanali
2025-10-28 13:07:52 +01:00
committed by Dario Coscia
parent 9c3e55da91
commit 24d806b262

View File

@@ -1,7 +1,7 @@
"""Module for the Equation.""" """Module for the Equation."""
import inspect import inspect
import torch
from .equation_interface import EquationInterface from .equation_interface import EquationInterface
@@ -49,6 +49,10 @@ class Equation(EquationInterface):
:raises RuntimeError: If the underlying equation signature length is not :raises RuntimeError: If the underlying equation signature length is not
2 (direct problem) or 3 (inverse problem). 2 (direct problem) or 3 (inverse problem).
""" """
# Move the equation to the input_ device
self.to(input_.device)
# Call the underlying equation based on its signature length
if self.__len_sig == 2: if self.__len_sig == 2:
return self.__equation(input_, output_) return self.__equation(input_, output_)
if self.__len_sig == 3: if self.__len_sig == 3:
@@ -57,3 +61,33 @@ class Equation(EquationInterface):
f"Unexpected number of arguments in equation: {self.__len_sig}. " f"Unexpected number of arguments in equation: {self.__len_sig}. "
"Expected either 2 (direct problem) or 3 (inverse problem)." "Expected either 2 (direct problem) or 3 (inverse problem)."
) )
def to(self, device):
"""
Move all tensor attributes of the Equation to the specified device.
:param torch.device device: The target device to move the tensors to.
:return: The Equation instance moved to the specified device.
:rtype: Equation
"""
# Iterate over all attributes of the Equation
for key, val in self.__dict__.items():
# Move tensors in dictionaries to the specified device
if isinstance(val, dict):
self.__dict__[key] = {
k: v.to(device) if torch.is_tensor(v) else v
for k, v in val.items()
}
# Move tensors in lists to the specified device
elif isinstance(val, list):
self.__dict__[key] = [
v.to(device) if torch.is_tensor(v) else v for v in val
]
# Move tensor attributes to the specified device
elif torch.is_tensor(val):
self.__dict__[key] = val.to(device)
return self