move to method to the interface (#694)

This commit is contained in:
Giovanni Canali
2025-11-04 10:56:27 +01:00
committed by GitHub
parent fca3db7926
commit d4fa3ea9df
3 changed files with 36 additions and 31 deletions

View File

@@ -1,7 +1,6 @@
"""Module for the Equation."""
import inspect
import torch
from .equation_interface import EquationInterface
@@ -61,33 +60,3 @@ class Equation(EquationInterface):
f"Unexpected number of arguments in equation: {self.__len_sig}. "
"Expected either 2 (direct problem) or 3 (inverse problem)."
)
def to(self, device):
"""
Move all tensor attributes of the Equation to the specified device.
:param torch.device device: The target device to move the tensors to.
:return: The Equation instance moved to the specified device.
:rtype: Equation
"""
# Iterate over all attributes of the Equation
for key, val in self.__dict__.items():
# Move tensors in dictionaries to the specified device
if isinstance(val, dict):
self.__dict__[key] = {
k: v.to(device) if torch.is_tensor(v) else v
for k, v in val.items()
}
# Move tensors in lists to the specified device
elif isinstance(val, list):
self.__dict__[key] = [
v.to(device) if torch.is_tensor(v) else v for v in val
]
# Move tensor attributes to the specified device
elif torch.is_tensor(val):
self.__dict__[key] = val.to(device)
return self

View File

@@ -1,6 +1,7 @@
"""Module for the Equation Interface."""
from abc import ABCMeta, abstractmethod
import torch
class EquationInterface(metaclass=ABCMeta):
@@ -33,3 +34,33 @@ class EquationInterface(metaclass=ABCMeta):
:return: The computed residual of the equation.
:rtype: LabelTensor
"""
def to(self, device):
"""
Move all tensor attributes to the specified device.
:param torch.device device: The target device to move the tensors to.
:return: The instance moved to the specified device.
:rtype: EquationInterface
"""
# Iterate over all attributes of the Equation
for key, val in self.__dict__.items():
# Move tensors in dictionaries to the specified device
if isinstance(val, dict):
self.__dict__[key] = {
k: v.to(device) if torch.is_tensor(v) else v
for k, v in val.items()
}
# Move tensors in lists to the specified device
elif isinstance(val, list):
self.__dict__[key] = [
v.to(device) if torch.is_tensor(v) else v for v in val
]
# Move tensor attributes to the specified device
elif torch.is_tensor(val):
self.__dict__[key] = val.to(device)
return self

View File

@@ -101,6 +101,10 @@ class SystemEquation(EquationInterface):
:return: The aggregated residuals of the system of equations.
:rtype: LabelTensor
"""
# Move the equation to the input_ device
self.to(input_.device)
# Compute the residual for each equation
residual = torch.hstack(
[
equation.residual(input_, output_, params_)
@@ -108,6 +112,7 @@ class SystemEquation(EquationInterface):
]
)
# Skip reduction if not specified
if self.reduction is None:
return residual