import torch import torch_geometric from .condition_interface import ConditionInterface from ..label_tensor import LabelTensor from ..graph import Graph from ..utils import check_consistency class InputOutputPointsCondition(ConditionInterface): """ Condition for domain/equation data. This condition must be used every time a Physics Informed or a Supervised Loss is needed in the Solver. """ __slots__ = ["input_points", "output_points"] def __init__(self, input_points, output_points): """ TODO : add docstring """ super().__init__() self.input_points = input_points self.output_points = output_points def __setattr__(self, key, value): if (key == 'input_points') or (key == 'output_points'): check_consistency(value, (LabelTensor, Graph, torch.Tensor, torch_geometric.data.Data)) InputOutputPointsCondition.__dict__[key].__set__(self, value) elif key in ('_problem', '_condition_type'): super().__setattr__(key, value)