From a0cbf1c44a08f2f2f80e5a9f181a27bc96a1c8ee Mon Sep 17 00:00:00 2001 From: Filippo Olivo Date: Fri, 7 Mar 2025 11:24:09 +0100 Subject: [PATCH] Improve conditions and refactor dataset classes (#475) * Reimplement conditions * Refactor datasets and implement LabelBatch --------- Co-authored-by: Dario Coscia --- .pylintrc | 5 +- code_formatter.sh | 54 +--- pina/collector.py | 2 +- pina/condition/__init__.py | 37 ++- pina/condition/condition.py | 88 +++-- pina/condition/condition_interface.py | 82 ++++- pina/condition/data_condition.py | 76 ++++- pina/condition/domain_equation_condition.py | 12 +- pina/condition/input_equation_condition.py | 103 +++++- pina/condition/input_output_condition.py | 34 -- pina/condition/input_target_condition.py | 121 +++++++ pina/data/data_module.py | 16 +- pina/data/dataset.py | 303 +++++++----------- pina/graph.py | 45 ++- pina/label_tensor.py | 17 + pina/problem/abstract_problem.py | 5 +- .../problem/zoo/inverse_diffusion_reaction.py | 4 +- pina/problem/zoo/inverse_poisson_2d_square.py | 4 +- pina/problem/zoo/supervised_problem.py | 6 +- pina/solver/garom.py | 20 +- .../physic_informed_solver/pinn_interface.py | 16 +- .../self_adaptive_pinn.py | 2 +- pina/solver/reduced_order_model.py | 4 +- pina/solver/solver.py | 2 +- pina/solver/supervised.py | 8 +- tests/test_collector.py | 28 +- tests/test_condition.py | 147 +++++++-- tests/test_data/test_data_module.py | 58 ++-- tests/test_data/test_graph_dataset.py | 36 +-- tests/test_data/test_tensor_dataset.py | 42 +-- .../test_supervised_problem.py | 12 +- tests/test_solver/test_causal_pinn.py | 12 +- tests/test_solver/test_competitive_pinn.py | 12 +- tests/test_solver/test_garom.py | 8 +- tests/test_solver/test_gradient_pinn.py | 12 +- tests/test_solver/test_pinn.py | 12 +- tests/test_solver/test_rba_pinn.py | 12 +- .../test_reduced_order_model_solver.py | 12 +- tests/test_solver/test_self_adaptive_pinn.py | 12 +- tests/test_solver/test_supervised_solver.py | 12 +- 40 files changed, 943 insertions(+), 550 deletions(-) delete mode 100644 pina/condition/input_output_condition.py create mode 100644 pina/condition/input_target_condition.py diff --git a/.pylintrc b/.pylintrc index ba14ad8..b9702cc 100644 --- a/.pylintrc +++ b/.pylintrc @@ -214,7 +214,6 @@ logging-modules=logging [FORMAT] - # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. expected-line-ending-format= @@ -250,6 +249,8 @@ single-line-if-stmt=no [BASIC] +# Allow redefinition of input builtins +allowed-redefined-builtins=input # Naming hint for argument names argument-name-hint=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ @@ -401,7 +402,7 @@ max-returns=6 max-statements=50 # Minimum number of public methods for a class (see R0903). 
-min-public-methods=2 +min-public-methods=0 [CLASSES] diff --git a/code_formatter.sh b/code_formatter.sh index 6dacf15..d638d35 100644 --- a/code_formatter.sh +++ b/code_formatter.sh @@ -2,51 +2,19 @@ ####################################### -required_command="yapf unexpand" -code_directories="pina tests" +required_command="black" +code_directories=("pina" "tests") ####################################### -usage() { - echo - echo -e "\tUsage: $0 [files]" - echo - echo -e "\tIf not files are specified, script formats all ".py" files" - echo -e "\tin code directories ($code_directories); otherwise, formats" - echo -e "\tall given files" - echo - echo -e "\tRequired command: $required_command" - echo - exit 0 -} - - -[[ $1 == "-h" ]] && usage - # Test for required program -for comm in $required_command; do - command -v $comm >/dev/null 2>&1 || { - echo "I require $comm but it's not installed. Aborting." >&2; - exit 1 - } -done +if ! command -v $required_command >/dev/null 2>&1; then + echo "I require $required_command but it's not installed. Install dev dependencies." + echo "Aborting." >&2 + exit 1 +fi -# Find all python files in code directories -python_files="" -for dir in $code_directories; do - python_files="$python_files $(find $dir -name '*.py')" -done -[[ $# != 0 ]] && python_files=$@ - - -# Here the important part: yapf format the files. -for file in $python_files; do - echo "Making beatiful $file..." - [[ ! -f $file ]] && echo "$file does not exist; $0 -h for more info" && exit - - yapf --style='{ - based_on_style: pep8, - indent_width: 4, - column_limit: 80 - }' -i $file -done +# Run black formatter +for dir in "${code_directories[@]}"; do + python -m black --line-length 80 "$dir" +done \ No newline at end of file diff --git a/pina/collector.py b/pina/collector.py index c8e8160..b784421 100644 --- a/pina/collector.py +++ b/pina/collector.py @@ -76,6 +76,6 @@ class Collector: samples = self.problem.discretised_domains[condition.domain] self.data_collections[condition_name] = { - "input_points": samples, + "input": samples, "equation": condition.equation, } diff --git a/pina/condition/__init__.py b/pina/condition/__init__.py index 3893e34..36f4011 100644 --- a/pina/condition/__init__.py +++ b/pina/condition/__init__.py @@ -1,12 +1,41 @@ +""" +Module for conditions. 
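For orientation, a minimal sketch of how the reorganized condition classes exported below are meant to be used; this is an illustrative aside, not part of the patch, and it assumes only the post-refactor imports (`pina.Condition`, `pina.LabelTensor`, `pina.condition.TensorInputTensorTargetCondition`) that appear in this diff and in the tests:

import torch
from pina import Condition, LabelTensor
from pina.condition import TensorInputTensorTargetCondition

x = LabelTensor(torch.rand(10, 2), ["x", "y"])
y = LabelTensor(torch.rand(10, 1), ["u"])

# New keyword names: the Condition factory dispatches to the matching subclass.
cond = Condition(input=x, target=y)
assert isinstance(cond, TensorInputTensorTargetCondition)

# The 0.1-era keywords are still accepted for back-compatibility, but emit a
# DeprecationWarning ('input_points' -> 'input', 'output_points' -> 'target').
legacy = Condition(input_points=x, output_points=y)
assert isinstance(legacy, TensorInputTensorTargetCondition)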
+""" + __all__ = [ "Condition", "ConditionInterface", "DomainEquationCondition", - "InputPointsEquationCondition", - "InputOutputPointsCondition", + "InputTargetCondition", + "TensorInputTensorTargetCondition", + "TensorInputGraphTargetCondition", + "GraphInputTensorTargetCondition", + "GraphInputGraphTargetCondition", + "InputEquationCondition", + "InputTensorEquationCondition", + "InputGraphEquationCondition", + "DataCondition", + "GraphDataCondition", + "TensorDataCondition", ] from .condition_interface import ConditionInterface +from .condition import Condition from .domain_equation_condition import DomainEquationCondition -from .input_equation_condition import InputPointsEquationCondition -from .input_output_condition import InputOutputPointsCondition +from .input_target_condition import ( + InputTargetCondition, + TensorInputTensorTargetCondition, + TensorInputGraphTargetCondition, + GraphInputTensorTargetCondition, + GraphInputGraphTargetCondition, +) +from .input_equation_condition import ( + InputEquationCondition, + InputTensorEquationCondition, + InputGraphEquationCondition, +) +from .data_condition import ( + DataCondition, + GraphDataCondition, + TensorDataCondition, +) diff --git a/pina/condition/condition.py b/pina/condition/condition.py index e01db1f..53744b4 100644 --- a/pina/condition/condition.py +++ b/pina/condition/condition.py @@ -1,10 +1,10 @@ """Condition module.""" -from .domain_equation_condition import DomainEquationCondition -from .input_equation_condition import InputPointsEquationCondition -from .input_output_condition import InputOutputPointsCondition -from .data_condition import DataConditionInterface import warnings +from .data_condition import DataCondition +from .domain_equation_condition import DomainEquationCondition +from .input_equation_condition import InputEquationCondition +from .input_target_condition import InputTargetCondition from ..utils import custom_warning_format # Set the custom format for warnings @@ -12,6 +12,21 @@ warnings.formatwarning = custom_warning_format warnings.filterwarnings("always", category=DeprecationWarning) +def warning_function(new, old): + """Handle the deprecation warning. + + :param new: Object to use instead of the old one. + :type new: str + :param old: Object to deprecate. + :type old: str + """ + warnings.warn( + f"'{old}' is deprecated and will be removed " + f"in future versions. Please use '{new}' instead.", + DeprecationWarning, + ) + + class Condition: """ The class ``Condition`` is used to represent the constraints (physical @@ -40,16 +55,32 @@ class Condition: Example:: - >>> TODO + >>> from pina import Condition + >>> condition = Condition( + ... input=input, + ... target=target + ... ) + >>> condition = Condition( + ... domain=location, + ... equation=equation + ... ) + >>> condition = Condition( + ... input=input, + ... equation=equation + ... ) + >>> condition = Condition( + ... input=data, + ... conditional_variables=conditional_variables + ... 
) """ __slots__ = list( set( - InputOutputPointsCondition.__slots__ - + InputPointsEquationCondition.__slots__ + InputTargetCondition.__slots__ + + InputEquationCondition.__slots__ + DomainEquationCondition.__slots__ - + DataConditionInterface.__slots__ + + DataCondition.__slots__ ) ) @@ -62,25 +93,30 @@ class Condition: ) # back-compatibility 0.1 - if "location" in kwargs.keys(): + keys = list(kwargs.keys()) + if "location" in keys: kwargs["domain"] = kwargs.pop("location") - warnings.warn( - f"'location' is deprecated and will be removed " - f"in future versions. Please use 'domain' instead.", - DeprecationWarning, - ) + warning_function(new="domain", old="location") + + if "input_points" in keys: + kwargs["input"] = kwargs.pop("input_points") + warning_function(new="input", old="input_points") + + if "output_points" in keys: + kwargs["target"] = kwargs.pop("output_points") + warning_function(new="target", old="output_points") sorted_keys = sorted(kwargs.keys()) - - if sorted_keys == sorted(InputOutputPointsCondition.__slots__): - return InputOutputPointsCondition(**kwargs) - elif sorted_keys == sorted(InputPointsEquationCondition.__slots__): - return InputPointsEquationCondition(**kwargs) - elif sorted_keys == sorted(DomainEquationCondition.__slots__): + if sorted_keys == sorted(InputTargetCondition.__slots__): + return InputTargetCondition(**kwargs) + if sorted_keys == sorted(InputEquationCondition.__slots__): + return InputEquationCondition(**kwargs) + if sorted_keys == sorted(DomainEquationCondition.__slots__): return DomainEquationCondition(**kwargs) - elif sorted_keys == sorted(DataConditionInterface.__slots__): - return DataConditionInterface(**kwargs) - elif sorted_keys == DataConditionInterface.__slots__[0]: - return DataConditionInterface(**kwargs) - else: - raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.") + if ( + sorted_keys == sorted(DataCondition.__slots__) + or sorted_keys[0] == DataCondition.__slots__[0] + ): + return DataCondition(**kwargs) + + raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.") diff --git a/pina/condition/condition_interface.py b/pina/condition/condition_interface.py index a9d62fd..4d748c3 100644 --- a/pina/condition/condition_interface.py +++ b/pina/condition/condition_interface.py @@ -1,34 +1,84 @@ +""" +Module that defines the ConditionInterface class. +""" + from abc import ABCMeta +from torch_geometric.data import Data +from ..label_tensor import LabelTensor +from ..graph import Graph class ConditionInterface(metaclass=ABCMeta): + """ + Abstract class which defines a common interface for all the conditions. + """ - condition_types = ["physics", "supervised", "unsupervised"] - - def __init__(self, *args, **kwargs): - self._condition_type = None + def __init__(self): self._problem = None @property def problem(self): + """ + Return the problem to which the condition is associated. 
+ + :return: Problem to which the condition is associated + :rtype: pina.problem.AbstractProblem + """ return self._problem @problem.setter def problem(self, value): self._problem = value - @property - def condition_type(self): - return self._condition_type + @staticmethod + def _check_graph_list_consistency(data_list): - @condition_type.setter - def condition_type(self, values): - if not isinstance(values, (list, tuple)): - values = [values] - for value in values: - if value not in ConditionInterface.condition_types: + # If the data is a Graph or Data object, return (do not need to check + # anything) + if isinstance(data_list, (Graph, Data)): + return + + # check all elements in the list are of the same type + if not all(isinstance(i, (Graph, Data)) for i in data_list): + raise ValueError( + "Invalid input types. " + "Please provide either Data or Graph objects." + ) + data = data_list[0] + # Store the keys of the first element in the list + keys = sorted(list(data.keys())) + + # Store the type of each tensor inside first element Data/Graph object + data_types = {name: tensor.__class__ for name, tensor in data.items()} + + # Store the labels of each LabelTensor inside first element Data/Graph + # object + labels = { + name: tensor.labels + for name, tensor in data.items() + if isinstance(tensor, LabelTensor) + } + + # Iterate over the list of Data/Graph objects + for data in data_list[1:]: + # Check if the keys of the current element are the same as the first + # element + if sorted(list(data.keys())) != keys: raise ValueError( - "Unavailable type of condition, expected one of" - f" {ConditionInterface.condition_types}." + "All elements in the list must have the same keys." ) - self._condition_type = values + for name, tensor in data.items(): + # Check if the type of each tensor inside the current element + # is the same as the first element + if tensor.__class__ is not data_types[name]: + raise ValueError( + f"Data {name} must be a {data_types[name]}, got " + f"{tensor.__class__}" + ) + # If the tensor is a LabelTensor, check if the labels are the + # same as the first element + if isinstance(tensor, LabelTensor): + if tensor.labels != labels[name]: + raise ValueError( + "LabelTensor must have the same labels" + ) diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index ffd10f3..1560157 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -1,12 +1,15 @@ -import torch +""" +DataCondition class +""" -from . import ConditionInterface +import torch +from torch_geometric.data import Data +from .condition_interface import ConditionInterface from ..label_tensor import LabelTensor from ..graph import Graph -from ..utils import check_consistency -class DataConditionInterface(ConditionInterface): +class DataCondition(ConditionInterface): """ Condition for data. This condition must be used every time a Unsupervised Loss is needed in the Solver. 
The conditionalvariable @@ -14,19 +17,64 @@ class DataConditionInterface(ConditionInterface): distribution """ - __slots__ = ["input_points", "conditional_variables"] + __slots__ = ["input", "conditional_variables"] + _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) + _avail_conditional_variables_cls = (torch.Tensor, LabelTensor) - def __init__(self, input_points, conditional_variables=None): + def __new__(cls, input, conditional_variables=None): """ - TODO : add docstring + Instanciate the correct subclass of DataCondition by checking the type + of the input data (input and conditional_variables). + + :param input: torch.Tensor or Graph/Data object containing the input + data + :type input: torch.Tensor or Graph or Data + :param conditional_variables: torch.Tensor or LabelTensor containing + the conditional variables + :type conditional_variables: torch.Tensor or LabelTensor + :return: DataCondition subclass + :rtype: TensorDataCondition or GraphDataCondition + """ + if cls != DataCondition: + return super().__new__(cls) + if isinstance(input, (torch.Tensor, LabelTensor)): + subclass = TensorDataCondition + return subclass.__new__(subclass, input, conditional_variables) + + if isinstance(input, (Graph, Data, list, tuple)): + cls._check_graph_list_consistency(input) + subclass = GraphDataCondition + return subclass.__new__(subclass, input, conditional_variables) + + raise ValueError( + "Invalid input types. " + "Please provide either Data or Graph objects." + ) + + def __init__(self, input, conditional_variables=None): + """ + Initialize the DataCondition, storing the input and conditional + variables (if any). + + :param input: torch.Tensor or Graph/Data object containing the input + data + :type input: torch.Tensor or Graph or Data + :param conditional_variables: torch.Tensor or LabelTensor containing + the conditional variables + :type conditional_variables: torch.Tensor or LabelTensor """ super().__init__() - self.input_points = input_points + self.input = input self.conditional_variables = conditional_variables - def __setattr__(self, key, value): - if (key == "input_points") or (key == "conditional_variables"): - check_consistency(value, (LabelTensor, Graph, torch.Tensor)) - DataConditionInterface.__dict__[key].__set__(self, value) - elif key in ("_problem", "_condition_type"): - super().__setattr__(key, value) + +class TensorDataCondition(DataCondition): + """ + DataCondition for torch.Tensor input data + """ + + +class GraphDataCondition(DataCondition): + """ + DataCondition for Graph/Data input data + """ diff --git a/pina/condition/domain_equation_condition.py b/pina/condition/domain_equation_condition.py index 002a7c4..aad9d9f 100644 --- a/pina/condition/domain_equation_condition.py +++ b/pina/condition/domain_equation_condition.py @@ -1,4 +1,6 @@ -import torch +""" +DomainEquationCondition class definition. +""" from .condition_interface import ConditionInterface from ..utils import check_consistency @@ -16,7 +18,11 @@ class DomainEquationCondition(ConditionInterface): def __init__(self, domain, equation): """ - TODO : add docstring + Initialize the DomainEquationCondition, storing the domain and equation. 
+ + :param DomainInterface domain: Domain object containing the domain data + :param EquationInterface equation: Equation object containing the + equation data """ super().__init__() self.domain = domain @@ -29,5 +35,5 @@ class DomainEquationCondition(ConditionInterface): elif key == "equation": check_consistency(value, (EquationInterface)) DomainEquationCondition.__dict__[key].__set__(self, value) - elif key in ("_problem", "_condition_type"): + elif key in ("_problem"): super().__setattr__(key, value) diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index 061261f..9a267a3 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -1,5 +1,8 @@ -import torch +""" +Module to define InputEquationCondition class and its subclasses. +""" +from torch_geometric.data import Data from .condition_interface import ConditionInterface from ..label_tensor import LabelTensor from ..graph import Graph @@ -7,30 +10,100 @@ from ..utils import check_consistency from ..equation.equation_interface import EquationInterface -class InputPointsEquationCondition(ConditionInterface): +class InputEquationCondition(ConditionInterface): """ - Condition for input_points/equation data. This condition must be used every + Condition for input/equation data. This condition must be used every time a Physics Informed Loss is needed in the Solver. """ - __slots__ = ["input_points", "equation"] + __slots__ = ["input", "equation"] + _avail_input_cls = (LabelTensor, Graph, list, tuple) + _avail_equation_cls = EquationInterface - def __init__(self, input_points, equation): + def __new__(cls, input, equation): """ - TODO : add docstring + Instanciate the correct subclass of InputEquationCondition by checking + the type of the input data (only `input`). + + :param input: torch.Tensor or Graph/Data object containing the input + :type input: torch.Tensor or Graph or Data + :param EquationInterface equation: Equation object containing the + equation function + :return: InputEquationCondition subclass + :rtype: InputTensorEquationCondition or InputGraphEquationCondition + """ + + # If the class is already a subclass, return the instance + if cls != InputEquationCondition: + return super().__new__(cls) + + # Instanciate the correct subclass + if isinstance(input, (Graph, Data, list, tuple)): + subclass = InputGraphEquationCondition + cls._check_graph_list_consistency(input) + subclass._check_label_tensor(input) + return subclass.__new__(subclass, input, equation) + if isinstance(input, LabelTensor): + subclass = InputTensorEquationCondition + return subclass.__new__(subclass, input, equation) + + # If the input is not a LabelTensor or a Graph object raise an error + raise ValueError( + "The input data object must be a LabelTensor or a Graph object." + ) + + def __init__(self, input, equation): + """ + Initialize the InputEquationCondition by storing the input and equation. + + :param input: torch.Tensor or Graph/Data object containing the input + :type input: torch.Tensor or Graph or Data + :param EquationInterface equation: Equation object containing the + equation function """ super().__init__() - self.input_points = input_points + self.input = input self.equation = equation def __setattr__(self, key, value): - if key == "input_points": - check_consistency( - value, (LabelTensor) - ) # for now only labeltensors, we need labels for the operator! 
- InputPointsEquationCondition.__dict__[key].__set__(self, value) + if key == "input": + check_consistency(value, self._avail_input_cls) + InputEquationCondition.__dict__[key].__set__(self, value) elif key == "equation": - check_consistency(value, (EquationInterface)) - InputPointsEquationCondition.__dict__[key].__set__(self, value) - elif key in ("_problem", "_condition_type"): + check_consistency(value, self._avail_equation_cls) + InputEquationCondition.__dict__[key].__set__(self, value) + elif key in ("_problem"): super().__setattr__(key, value) + + +class InputTensorEquationCondition(InputEquationCondition): + """ + InputEquationCondition subclass for LabelTensor input data. + """ + + +class InputGraphEquationCondition(InputEquationCondition): + """ + InputEquationCondition subclass for Graph input data. + """ + + @staticmethod + def _check_label_tensor(input): + """ + Check if the input is a LabelTensor. + + :param input: input data + :type input: torch.Tensor or Graph or Data + """ + # Store the fist element of the list/tuple if input is a list/tuple + # it is anougth to check the first element because all elements must + # have the same type and structure (already checked) + data = input[0] if isinstance(input, (list, tuple)) else input + + # Check if the input data contains at least one LabelTensor + for v in data.values(): + if isinstance(v, LabelTensor): + return + raise ValueError( + "The input data object must contain at least one LabelTensor." + ) diff --git a/pina/condition/input_output_condition.py b/pina/condition/input_output_condition.py deleted file mode 100644 index 47f182a..0000000 --- a/pina/condition/input_output_condition.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch -import torch_geometric - -from .condition_interface import ConditionInterface -from ..label_tensor import LabelTensor -from ..graph import Graph -from ..utils import check_consistency - - -class InputOutputPointsCondition(ConditionInterface): - """ - Condition for domain/equation data. This condition must be used every - time a Physics Informed or a Supervised Loss is needed in the Solver. - """ - - __slots__ = ["input_points", "output_points"] - - def __init__(self, input_points, output_points): - """ - TODO : add docstring - """ - super().__init__() - self.input_points = input_points - self.output_points = output_points - - def __setattr__(self, key, value): - if (key == "input_points") or (key == "output_points"): - check_consistency( - value, - (LabelTensor, Graph, torch.Tensor, torch_geometric.data.Data), - ) - InputOutputPointsCondition.__dict__[key].__set__(self, value) - elif key in ("_problem", "_condition_type"): - super().__setattr__(key, value) diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py new file mode 100644 index 0000000..70e09bc --- /dev/null +++ b/pina/condition/input_target_condition.py @@ -0,0 +1,121 @@ +""" +This module contains condition classes for supervised learning tasks. +""" + +import torch +from torch_geometric.data import Data +from ..label_tensor import LabelTensor +from ..graph import Graph +from .condition_interface import ConditionInterface + + +class InputTargetCondition(ConditionInterface): + """ + Condition for domain/equation data. This condition must be used every + time a Physics Informed or a Supervised Loss is needed in the Solver. 
+ """ + + __slots__ = ["input", "target"] + _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) + _avail_output_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) + + def __new__(cls, input, target): + """ + Instanciate the correct subclass of InputTargetCondition by checking the + type of the input and target data. + + :param input: torch.Tensor or Graph/Data object containing the input + :type input: torch.Tensor or Graph or Data + :param target: torch.Tensor or Graph/Data object containing the target + :type target: torch.Tensor or Graph or Data + :return: InputTargetCondition subclass + :rtype: TensorInputTensorTargetCondition or + TensorInputGraphTargetCondition or GraphInputTensorTargetCondition + or GraphInputGraphTargetCondition + """ + if cls != InputTargetCondition: + return super().__new__(cls) + + if isinstance(input, (torch.Tensor, LabelTensor)) and isinstance( + target, (torch.Tensor, LabelTensor) + ): + subclass = TensorInputTensorTargetCondition + return subclass.__new__(subclass, input, target) + if isinstance(input, (torch.Tensor, LabelTensor)) and isinstance( + target, (Graph, Data, list, tuple) + ): + cls._check_graph_list_consistency(target) + subclass = TensorInputGraphTargetCondition + return subclass.__new__(subclass, input, target) + + if isinstance(input, (Graph, Data, list, tuple)) and isinstance( + target, (torch.Tensor, LabelTensor) + ): + cls._check_graph_list_consistency(input) + subclass = GraphInputTensorTargetCondition + return subclass.__new__(subclass, input, target) + + if isinstance(input, (Graph, Data, list, tuple)) and isinstance( + target, (Graph, Data, list, tuple) + ): + cls._check_graph_list_consistency(input) + cls._check_graph_list_consistency(target) + subclass = GraphInputGraphTargetCondition + return subclass.__new__(subclass, input, target) + + raise ValueError( + "Invalid input/target types. " + "Please provide either Data, Graph, LabelTensor or torch.Tensor " + "objects." + ) + + def __init__(self, input, target): + """ + Initialize the InputTargetCondition, storing the input and target data. + + :param input: torch.Tensor or Graph/Data object containing the input + :type input: torch.Tensor or Graph or Data + :param target: torch.Tensor or Graph/Data object containing the target + :type target: torch.Tensor or Graph or Data + """ + super().__init__() + self._check_input_target_len(input, target) + self.input = input + self.target = target + + @staticmethod + def _check_input_target_len(input, target): + if isinstance(input, (Graph, Data)) or isinstance( + target, (Graph, Data) + ): + return + if len(input) != len(target): + raise ValueError( + "The input and target lists must have the same length." + ) + + +class TensorInputTensorTargetCondition(InputTargetCondition): + """ + InputTargetCondition subclass for torch.Tensor input and target data. + """ + + +class TensorInputGraphTargetCondition(InputTargetCondition): + """ + InputTargetCondition subclass for torch.Tensor input and Graph/Data target + data. + """ + + +class GraphInputTensorTargetCondition(InputTargetCondition): + """ + InputTargetCondition subclass for Graph/Data input and torch.Tensor target + data. + """ + + +class GraphInputGraphTargetCondition(InputTargetCondition): + """ + InputTargetCondition subclass for Graph/Data input and target data. 
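As an illustrative aside (not part of the patch), the dispatch above can be exercised with the same graph constructions used later in the test suite; the shapes and the `radius` value are arbitrary:

import torch
from pina import Condition
from pina.condition import (
    GraphInputTensorTargetCondition,
    GraphInputGraphTargetCondition,
)
from pina.graph import RadiusGraph

x = torch.rand(10, 20, 2)
pos = torch.rand(10, 20, 2)
graphs = [
    RadiusGraph(x=x_, pos=pos_, radius=0.1) for x_, pos_ in zip(x, pos)
]
targets = torch.rand(10, 20, 1)

# Graph input with tensor targets selects the Graph/Tensor subclass ...
cond = Condition(input=graphs, target=targets)
assert isinstance(cond, GraphInputTensorTargetCondition)

# ... while graph input with graph targets selects the Graph/Graph one.
cond = Condition(input=graphs, target=graphs)
assert isinstance(cond, GraphInputGraphTargetCondition)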
+ """ diff --git a/pina/data/data_module.py b/pina/data/data_module.py index 288e00e..8157ea4 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -285,7 +285,7 @@ class PinaDataModule(LightningDataModule): @staticmethod def _split_condition(condition_dict, splits_dict): - len_condition = len(condition_dict["input_points"]) + len_condition = len(condition_dict["input"]) lengths = [ int(len_condition * length) for length in splits_dict.values() @@ -343,7 +343,7 @@ class PinaDataModule(LightningDataModule): condition_name, condition_dict, ) in collector.data_collections.items(): - len_data = len(condition_dict["input_points"]) + len_data = len(condition_dict["input"]) if self.shuffle: _apply_shuffle(condition_dict, len_data) for key, data in self._split_condition( @@ -390,12 +390,12 @@ class PinaDataModule(LightningDataModule): max_conditions_lengths = {} for k, v in self.collector_splits[split].items(): if self.batch_size is None: - max_conditions_lengths[k] = len(v["input_points"]) + max_conditions_lengths[k] = len(v["input"]) elif self.repeat: max_conditions_lengths[k] = self.batch_size else: max_conditions_lengths[k] = min( - len(v["input_points"]), self.batch_size + len(v["input"]), self.batch_size ) return max_conditions_lengths @@ -455,15 +455,15 @@ class PinaDataModule(LightningDataModule): raise ValueError("The sum of the splits must be 1") @property - def input_points(self): + def input(self): """ # TODO """ to_return = {} if hasattr(self, "train_dataset") and self.train_dataset is not None: - to_return["train"] = self.train_dataset.input_points + to_return["train"] = self.train_dataset.input if hasattr(self, "val_dataset") and self.val_dataset is not None: - to_return["val"] = self.val_dataset.input_points + to_return["val"] = self.val_dataset.input if hasattr(self, "test_dataset") and self.test_dataset is not None: - to_return = self.test_dataset.input_points + to_return = self.test_dataset.input return to_return diff --git a/pina/data/dataset.py b/pina/data/dataset.py index 3944ef4..3174b4b 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -2,12 +2,10 @@ This module provide basic data management functionalities """ -import functools -import torch -from torch.utils.data import Dataset from abc import abstractmethod -from torch_geometric.data import Batch, Data -from pina import LabelTensor +from torch.utils.data import Dataset +from torch_geometric.data import Data +from ..graph import Graph, LabelBatch class PinaDatasetFactory: @@ -19,25 +17,25 @@ class PinaDatasetFactory: """ def __new__(cls, conditions_dict, **kwargs): + # Check if conditions_dict is empty if len(conditions_dict) == 0: raise ValueError("No conditions provided") - if all( - [ - isinstance(v["input_points"], torch.Tensor) - for v in conditions_dict.values() - ] - ): - return PinaTensorDataset(conditions_dict, **kwargs) - elif all( - [ - isinstance(v["input_points"], list) - for v in conditions_dict.values() - ] - ): + + # Check is a Graph is present in the conditions + is_graph = cls._is_graph_dataset(conditions_dict) + if is_graph: + # If a Graph is present, return a PinaGraphDataset return PinaGraphDataset(conditions_dict, **kwargs) - raise ValueError( - "Conditions must be either torch.Tensor or list of Data " "objects." 
- ) + # If no Graph is present, return a PinaTensorDataset + return PinaTensorDataset(conditions_dict, **kwargs) + + @staticmethod + def _is_graph_dataset(conditions_dict): + for v in conditions_dict.values(): + for cond in v.values(): + if isinstance(cond, (Data, Graph, list)): + return True + return False class PinaDataset(Dataset): @@ -45,209 +43,140 @@ class PinaDataset(Dataset): Abstract class for the PINA dataset """ - def __init__(self, conditions_dict, max_conditions_lengths): + def __init__( + self, conditions_dict, max_conditions_lengths, automatic_batching + ): + # Store the conditions dictionary self.conditions_dict = conditions_dict + # Store the maximum number of conditions to consider self.max_conditions_lengths = max_conditions_lengths + # Store length of each condition self.conditions_length = { - k: len(v["input_points"]) for k, v in self.conditions_dict.items() + k: len(v["input"]) for k, v in self.conditions_dict.items() } + # Store the maximum length of the dataset self.length = max(self.conditions_length.values()) + # Dynamically set the getitem function based on automatic batching + if automatic_batching: + self._getitem_func = self._getitem_int + else: + self._getitem_func = self._getitem_dummy def _get_max_len(self): + """""" max_len = 0 for condition in self.conditions_dict.values(): - max_len = max(max_len, len(condition["input_points"])) + max_len = max(max_len, len(condition["input"])) return max_len def __len__(self): return self.length + def __getitem__(self, idx): + return self._getitem_func(idx) + + def _getitem_dummy(self, idx): + # If automatic batching is disabled, return the data at the given index + return idx + + def _getitem_int(self, idx): + # If automatic batching is enabled, return the data at the given index + return { + k: {k_data: v[k_data][idx % len(v["input"])] for k_data in v.keys()} + for k, v in self.conditions_dict.items() + } + + def get_all_data(self): + """ + Return all data in the dataset + + :return: All data in the dataset + :rtype: dict + """ + index = list(range(len(self))) + return self.fetch_from_idx_list(index) + + def fetch_from_idx_list(self, idx): + """ + Return data from the dataset given a list of indices + + :param idx: List of indices + :type idx: list + :return: Data from the dataset + :rtype: dict + """ + to_return_dict = {} + for condition, data in self.conditions_dict.items(): + # Get the indices for the current condition + cond_idx = idx[: self.max_conditions_lengths[condition]] + # Get the length of the current condition + condition_len = self.conditions_length[condition] + # If the length of the dataset is greater than the length of the + # current condition, repeat the indices + if self.length > condition_len: + cond_idx = [idx % condition_len for idx in cond_idx] + # Retrieve the data from the current condition + to_return_dict[condition] = self._retrive_data(data, cond_idx) + return to_return_dict + @abstractmethod - def __getitem__(self, item): + def _retrive_data(self, data, idx_list): pass class PinaTensorDataset(PinaDataset): - def __init__( - self, conditions_dict, max_conditions_lengths, automatic_batching - ): - super().__init__(conditions_dict, max_conditions_lengths) + """ + Class for the PINA dataset with torch.Tensor data + """ - if automatic_batching: - self._getitem_func = self._getitem_int - else: - self._getitem_func = self._getitem_dummy - - def _getitem_int(self, idx): - return { - k: { - k_data: v[k_data][idx % len(v["input_points"])] - for k_data in v.keys() - } - for k, v in 
self.conditions_dict.items() - } - - def fetch_from_idx_list(self, idx): - to_return_dict = {} - for condition, data in self.conditions_dict.items(): - cond_idx = idx[: self.max_conditions_lengths[condition]] - condition_len = self.conditions_length[condition] - if self.length > condition_len: - cond_idx = [idx % condition_len for idx in cond_idx] - to_return_dict[condition] = { - k: v[cond_idx] for k, v in data.items() - } - return to_return_dict - - @staticmethod - def _getitem_dummy(idx): - return idx - - def get_all_data(self): - index = [i for i in range(len(self))] - return self.fetch_from_idx_list(index) - - def __getitem__(self, idx): - return self._getitem_func(idx) + # Override _retrive_data method for torch.Tensor data + def _retrive_data(self, data, idx_list): + return {k: v[idx_list] for k, v in data.items()} @property - def input_points(self): + def input(self): """ Method to return input points for training. """ - return {k: v["input_points"] for k, v in self.conditions_dict.items()} - - -class PinaBatch(Batch): - """ - Add extract function to torch_geometric Batch object - """ - - def __init__(self): - - super().__init__(self) - - def extract(self, labels): - """ - Perform extraction of labels on node features (x) - - :param labels: Labels to extract - :type labels: list[str] | tuple[str] | str - :return: Batch object with extraction performed on x - :rtype: PinaBatch - """ - self.x = self.x.extract(labels) - return self + return {k: v["input"] for k, v in self.conditions_dict.items()} class PinaGraphDataset(PinaDataset): + """ + Class for the PINA dataset with torch_geometric.data.Data data + """ - def __init__( - self, conditions_dict, max_conditions_lengths, automatic_batching - ): - super().__init__(conditions_dict, max_conditions_lengths) - self.in_labels = {} - self.out_labels = None - if automatic_batching: - self._getitem_func = self._getitem_int - else: - self._getitem_func = self._getitem_dummy - - ex_data = conditions_dict[list(conditions_dict.keys())[0]][ - "input_points" - ][0] - for name, attr in ex_data.items(): - if isinstance(attr, LabelTensor): - self.in_labels[name] = attr.stored_labels - ex_data = conditions_dict[list(conditions_dict.keys())[0]][ - "output_points" - ][0] - if isinstance(ex_data, LabelTensor): - self.out_labels = ex_data.labels - - self._create_graph_batch_from_list = ( - self._labelise_batch(self._base_create_graph_batch_from_list) - if self.in_labels - else self._base_create_graph_batch_from_list - ) - - self._create_output_batch = ( - self._labelise_tensor(self._base_create_output_batch) - if self.out_labels is not None - else self._base_create_output_batch - ) - - def fetch_from_idx_list(self, idx): - to_return_dict = {} - for condition, data in self.conditions_dict.items(): - cond_idx = idx[: self.max_conditions_lengths[condition]] - condition_len = self.conditions_length[condition] - if self.length > condition_len: - cond_idx = [idx % condition_len for idx in cond_idx] - to_return_dict[condition] = { - k: ( - self._create_graph_batch_from_list([v[i] for i in idx]) - if isinstance(v, list) - else self._create_output_batch(v[idx]) - ) - for k, v in data.items() - } - - return to_return_dict - - def _base_create_graph_batch_from_list(self, data): - batch = PinaBatch.from_data_list(data) + def _create_graph_batch_from_list(self, data): + batch = LabelBatch.from_data_list(data) return batch - def _base_create_output_batch(self, data): + def _create_output_batch(self, data): out = data.reshape(-1, *data.shape[2:]) return out - def 
_getitem_dummy(self, idx): - return idx - - def _getitem_int(self, idx): - return { - k: { - k_data: v[k_data][idx % len(v["input_points"])] - for k_data in v.keys() - } - for k, v in self.conditions_dict.items() - } - - def get_all_data(self): - index = [i for i in range(len(self))] - return self.fetch_from_idx_list(index) - - def __getitem__(self, idx): - return self._getitem_func(idx) - - def _labelise_batch(self, func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - batch = func(*args, **kwargs) - for k, v in self.in_labels.items(): - tmp = batch[k] - tmp.labels = v - batch[k] = tmp - return batch - - return wrapper - - def _labelise_tensor(self, func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - out = func(*args, **kwargs) - if isinstance(out, LabelTensor): - out.labels = self.out_labels - return out - - return wrapper - def create_graph_batch(self, data): """ - # TODO + Create a Batch object from a list of Data objects. + + :param data: List of Data objects + :type data: list + :return: Batch object + :rtype: Batch or PinaBatch """ if isinstance(data[0], Data): return self._create_graph_batch_from_list(data) return self._create_output_batch(data) + + # Override _retrive_data method for graph handling + def _retrive_data(self, data, idx_list): + # Return the data from the current condition + # If the data is a list of Data objects, create a Batch object + # If the data is a list of torch.Tensor objects, create a torch.Tensor + return { + k: ( + self._create_graph_batch_from_list([v[i] for i in idx_list]) + if isinstance(v, list) + else self._create_output_batch(v[idx_list]) + ) + for k, v in data.items() + } diff --git a/pina/graph.py b/pina/graph.py index 77e426e..7d15769 100644 --- a/pina/graph.py +++ b/pina/graph.py @@ -3,7 +3,7 @@ This module provides an interface to build torch_geometric.data.Data objects. """ import torch -from torch_geometric.data import Data +from torch_geometric.data import Data, Batch from torch_geometric.utils import to_undirected from . import LabelTensor from .utils import check_consistency, is_function @@ -162,6 +162,21 @@ class Graph(Data): edge_index = to_undirected(edge_index) return edge_index + def extract(self, labels, attr="x"): + """ + Perform extraction of labels on node features (x) + + :param labels: Labels to extract + :type labels: list[str] | tuple[str] | str + :return: Batch object with extraction performed on x + :rtype: PinaBatch + """ + # Extract labels from LabelTensor object + tensor = getattr(self, attr).extract(labels) + # Set the extracted tensor as the new attribute + setattr(self, attr, tensor) + return self + class GraphBuilder: """ @@ -317,3 +332,31 @@ class KNNGraph(GraphBuilder): row = torch.arange(points.size(0)).repeat_interleave(k) col = knn_indices.flatten() return torch.stack([row, col], dim=0).as_subclass(torch.Tensor) + + +class LabelBatch(Batch): + """ + Add extract function to torch_geometric Batch object + """ + + @classmethod + def from_data_list(cls, data_list): + """ + Create a Batch object from a list of Data objects. 
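As an illustrative aside (not part of the patch), the point of the `LabelBatch` helper defined here is that batching a list of graphs whose attributes are `LabelTensor` objects keeps the labels on the batched tensors; the shapes and labels below are made up, mirroring the test suite:

import torch
from pina import LabelTensor
from pina.graph import RadiusGraph, LabelBatch

x = LabelTensor(torch.rand(10, 20, 2), ["u", "v"])
pos = LabelTensor(torch.rand(10, 20, 2), ["x", "y"])
graphs = [
    RadiusGraph(x=x[i], pos=pos[i], radius=0.1) for i in range(len(x))
]

batch = LabelBatch.from_data_list(graphs)
# Labels attached to the LabelTensor attributes survive batching.
assert batch.x.labels == ["u", "v"]
assert batch.pos.labels == ["x", "y"]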
+ """ + # Store the labels of Data/Graph objects (all data have the same labels) + # If the data do not contain labels, labels is an empty dictionary, + # therefore the labels are not stored + labels = { + k: v.labels + for k, v in data_list[0].items() + if isinstance(v, LabelTensor) + } + + # Create a Batch object from the list of Data objects + batch = super().from_data_list(data_list) + + # Put the labels back in the Batch object + for k, v in labels.items(): + batch[k].labels = v + return batch diff --git a/pina/label_tensor.py b/pina/label_tensor.py index 6448044..79313de 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -602,3 +602,20 @@ class LabelTensor(torch.Tensor): } return LabelTensor(data, labels) + + def reshape(self, *shape): + """ + Override the reshape method to update the labels of the tensor. + + :param shape: The new shape of the tensor. + :type shape: tuple + :return: A tensor-like object with updated labels. + :rtype: LabelTensor + """ + # As for now the reshape method is used only in the context of the + # dataset, the labels are not + tensor = super().reshape(*shape) + if not hasattr(self, "_labels") or shape != (-1, *self.shape[2:]): + return tensor + tensor.labels = self.labels + return tensor diff --git a/pina/problem/abstract_problem.py b/pina/problem/abstract_problem.py index 9cdef60..ddc98af 100644 --- a/pina/problem/abstract_problem.py +++ b/pina/problem/abstract_problem.py @@ -4,7 +4,6 @@ from abc import ABCMeta, abstractmethod from ..utils import check_consistency from ..domain import DomainInterface, CartesianDomain from ..condition.domain_equation_condition import DomainEquationCondition -from ..condition import InputPointsEquationCondition from copy import deepcopy from .. import LabelTensor from ..utils import merge_tensors @@ -55,8 +54,8 @@ class AbstractProblem(metaclass=ABCMeta): def input_pts(self): to_return = {} for cond_name, cond in self.conditions.items(): - if hasattr(cond, "input_points"): - to_return[cond_name] = cond.input_points + if hasattr(cond, "input"): + to_return[cond_name] = cond.input elif hasattr(cond, "domain"): to_return[cond_name] = self._discretised_domains[cond.domain] return to_return diff --git a/pina/problem/zoo/inverse_diffusion_reaction.py b/pina/problem/zoo/inverse_diffusion_reaction.py index 911f68e..0a05605 100644 --- a/pina/problem/zoo/inverse_diffusion_reaction.py +++ b/pina/problem/zoo/inverse_diffusion_reaction.py @@ -46,8 +46,8 @@ class InverseDiffusionReactionProblem( equation=Equation(diffusion_reaction), ), "data": Condition( - input_points=LabelTensor(torch.randn(10, 2), ["x", "t"]), - output_points=LabelTensor(torch.randn(10, 1), ["u"]), + input=LabelTensor(torch.randn(10, 2), ["x", "t"]), + target=LabelTensor(torch.randn(10, 1), ["u"]), ), } diff --git a/pina/problem/zoo/inverse_poisson_2d_square.py b/pina/problem/zoo/inverse_poisson_2d_square.py index 3a46334..2d9bbe5 100644 --- a/pina/problem/zoo/inverse_poisson_2d_square.py +++ b/pina/problem/zoo/inverse_poisson_2d_square.py @@ -50,7 +50,7 @@ class InversePoisson2DSquareProblem(SpatialProblem, InverseProblem): "nil_g4": Condition(domain="g4", equation=FixedValue(0.0)), "laplace_D": Condition(domain="D", equation=Equation(laplace_equation)), "data": Condition( - input_points=data_input.extract(["x", "y"]), - output_points=data_output, + input=data_input.extract(["x", "y"]), + target=data_output, ), } diff --git a/pina/problem/zoo/supervised_problem.py b/pina/problem/zoo/supervised_problem.py index ef04062..b45bc91 100644 --- 
a/pina/problem/zoo/supervised_problem.py +++ b/pina/problem/zoo/supervised_problem.py @@ -8,7 +8,7 @@ class SupervisedProblem(AbstractProblem): A problem definition for supervised learning in PINA. This class allows an easy and straightforward definition of a Supervised problem, - based on a single condition of type `InputOutputPointsCondition` + based on a single condition of type `InputTargetCondition` :Example: >>> import torch @@ -31,7 +31,5 @@ class SupervisedProblem(AbstractProblem): """ if isinstance(input_, Graph): input_ = input_.data - self.conditions["data"] = Condition( - input_points=input_, output_points=output_ - ) + self.conditions["data"] = Condition(input=input_, target=output_) super().__init__() diff --git a/pina/solver/garom.py b/pina/solver/garom.py index 930c144..8a6f5ff 100644 --- a/pina/solver/garom.py +++ b/pina/solver/garom.py @@ -5,7 +5,7 @@ import torch from .solver import MultiSolverInterface from ..utils import check_consistency from ..loss.loss_interface import LossInterface -from ..condition import InputOutputPointsCondition +from ..condition import InputTargetCondition from ..utils import check_consistency from ..loss import LossInterface, PowerLoss from torch.nn.modules.loss import _Loss @@ -25,7 +25,7 @@ class GAROM(MultiSolverInterface): `_. """ - accepted_conditions_types = InputOutputPointsCondition + accepted_conditions_types = InputTargetCondition def __init__( self, @@ -70,8 +70,8 @@ class GAROM(MultiSolverInterface): .. warning:: The algorithm works only for data-driven model. Hence in the ``problem`` definition - the codition must only contain ``input_points`` (e.g. coefficient parameters, time - parameters), and ``output_points``. + the condition must only contain ``input`` (e.g. coefficient parameters, time + parameters), and ``target``.
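To make the data-driven constraint concrete, here is a small illustrative problem definition in the style of the test suite (the class name and shapes are invented for the example); a GAROM-compatible problem holds a single input/target condition with parameters as input and snapshots as target:

import torch
from pina import Condition
from pina.problem import AbstractProblem


class SnapshotProblem(AbstractProblem):
    # A single data condition: coefficient/time parameters in, snapshots out.
    output_variables = None
    conditions = {
        "data": Condition(
            input=torch.rand(50, 2),
            target=torch.rand(50, 100),
        )
    }


problem = SnapshotProblem()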
""" # set loss @@ -233,8 +233,8 @@ class GAROM(MultiSolverInterface): condition_loss = {} for condition_name, points in batch: parameters, snapshots = ( - points["input_points"], - points["output_points"], + points["input"], + points["target"], ) d_loss_real, d_loss_fake, d_loss = self._train_discriminator( parameters, snapshots @@ -257,8 +257,8 @@ class GAROM(MultiSolverInterface): condition_loss = {} for condition_name, points in batch: parameters, snapshots = ( - points["input_points"], - points["output_points"], + points["input"], + points["target"], ) snapshots_gen = self.generator(parameters) condition_loss[condition_name] = self._loss( @@ -272,8 +272,8 @@ class GAROM(MultiSolverInterface): condition_loss = {} for condition_name, points in batch: parameters, snapshots = ( - points["input_points"], - points["output_points"], + points["input"], + points["target"], ) snapshots_gen = self.generator(parameters) condition_loss[condition_name] = self._loss( diff --git a/pina/solver/physic_informed_solver/pinn_interface.py b/pina/solver/physic_informed_solver/pinn_interface.py index 20ce4b2..d478c63 100644 --- a/pina/solver/physic_informed_solver/pinn_interface.py +++ b/pina/solver/physic_informed_solver/pinn_interface.py @@ -9,8 +9,8 @@ from ...utils import check_consistency from ...loss.loss_interface import LossInterface from ...problem import InverseProblem from ...condition import ( - InputOutputPointsCondition, - InputPointsEquationCondition, + InputTargetCondition, + InputEquationCondition, DomainEquationCondition, ) @@ -28,8 +28,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): """ accepted_conditions_types = ( - InputOutputPointsCondition, - InputPointsEquationCondition, + InputTargetCondition, + InputEquationCondition, DomainEquationCondition, ) @@ -138,16 +138,16 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): for condition_name, points in batch: self.__metric = condition_name # if equations are passed - if "output_points" not in points: - input_pts = points["input_points"] + if "target" not in points: + input_pts = points["input"] condition = self.problem.conditions[condition_name] loss = loss_residuals( input_pts.requires_grad_(), condition.equation ) # if data are passed else: - input_pts = points["input_points"] - output_pts = points["output_points"] + input_pts = points["input"] + output_pts = points["target"] loss = self.loss_data( input_pts=input_pts.requires_grad_(), output_pts=output_pts ) diff --git a/pina/solver/physic_informed_solver/self_adaptive_pinn.py b/pina/solver/physic_informed_solver/self_adaptive_pinn.py index c64c499..4e919b5 100644 --- a/pina/solver/physic_informed_solver/self_adaptive_pinn.py +++ b/pina/solver/physic_informed_solver/self_adaptive_pinn.py @@ -262,7 +262,7 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface): for ( condition_name, tensor, - ) in self.trainer.data_module.train_dataset.input_points.items(): + ) in self.trainer.data_module.train_dataset.input.items(): self.weights_dict[condition_name].sa_weights.data = torch.rand( (tensor.shape[0], 1), device=device ) diff --git a/pina/solver/reduced_order_model.py b/pina/solver/reduced_order_model.py index 54aa8a2..1cbce1a 100644 --- a/pina/solver/reduced_order_model.py +++ b/pina/solver/reduced_order_model.py @@ -75,8 +75,8 @@ class ReducedOrderModelSolver(SupervisedSolver): .. warning:: This solver works only for data-driven model. Hence in the ``problem`` - definition the codition must only contain ``input_points`` - (e.g. 
coefficient parameters, time parameters), and ``output_points``. + definition the codition must only contain ``input`` + (e.g. coefficient parameters, time parameters), and ``target``. .. warning:: This solver does not currently support the possibility to pass diff --git a/pina/solver/solver.py b/pina/solver/solver.py index 2ca7c1c..3509b34 100644 --- a/pina/solver/solver.py +++ b/pina/solver/solver.py @@ -172,7 +172,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): # assuming batch is a custom Batch object batch_size = 0 for data in batch: - batch_size += len(data[1]["input_points"]) + batch_size += len(data[1]["input"]) return batch_size @staticmethod diff --git a/pina/solver/supervised.py b/pina/solver/supervised.py index 56771b8..2bfa858 100644 --- a/pina/solver/supervised.py +++ b/pina/solver/supervised.py @@ -5,7 +5,7 @@ from torch.nn.modules.loss import _Loss from .solver import SingleSolverInterface from ..utils import check_consistency from ..loss.loss_interface import LossInterface -from ..condition import InputOutputPointsCondition +from ..condition import InputTargetCondition class SupervisedSolver(SingleSolverInterface): @@ -37,7 +37,7 @@ class SupervisedSolver(SingleSolverInterface): multiple (discretised) input functions. """ - accepted_conditions_types = InputOutputPointsCondition + accepted_conditions_types = InputTargetCondition def __init__( self, @@ -95,8 +95,8 @@ class SupervisedSolver(SingleSolverInterface): condition_loss = {} for condition_name, points in batch: input_pts, output_pts = ( - points["input_points"], - points["output_points"], + points["input"], + points["target"], ) condition_loss[condition_name] = self.loss_data( input_pts=input_pts, output_pts=output_pts diff --git a/tests/test_collector.py b/tests/test_collector.py index 565fed3..3119f9d 100644 --- a/tests/test_collector.py +++ b/tests/test_collector.py @@ -1,7 +1,7 @@ import torch import pytest from pina import Condition, LabelTensor, Graph -from pina.condition import InputOutputPointsCondition, DomainEquationCondition +from pina.condition import InputTargetCondition, DomainEquationCondition from pina.graph import RadiusGraph from pina.problem import AbstractProblem, SpatialProblem from pina.domain import CartesianDomain @@ -16,16 +16,16 @@ def test_supervised_tensor_collector(): output_variables = None conditions = { "data1": Condition( - input_points=torch.rand((10, 2)), - output_points=torch.rand((10, 2)), + input=torch.rand((10, 2)), + target=torch.rand((10, 2)), ), "data2": Condition( - input_points=torch.rand((20, 2)), - output_points=torch.rand((20, 2)), + input=torch.rand((20, 2)), + target=torch.rand((20, 2)), ), "data3": Condition( - input_points=torch.rand((30, 2)), - output_points=torch.rand((30, 2)), + input=torch.rand((30, 2)), + target=torch.rand((30, 2)), ), } @@ -74,7 +74,7 @@ def test_pinn_collector(): domain=CartesianDomain({"x": [0, 1], "y": [0, 1]}), equation=my_laplace, ), - "data": Condition(input_points=in_, output_points=out_), + "data": Condition(input=in_, target=out_), } def poisson_sol(self, pts): @@ -95,16 +95,16 @@ def test_pinn_collector(): collector.store_sample_domains() for k, v in problem.conditions.items(): - if isinstance(v, InputOutputPointsCondition): + if isinstance(v, InputTargetCondition): assert list(collector.data_collections[k].keys()) == [ - "input_points", - "output_points", + "input", + "target", ] for k, v in problem.conditions.items(): if isinstance(v, DomainEquationCondition): assert 
list(collector.data_collections[k].keys()) == [ - "input_points", + "input", "equation", ] @@ -123,8 +123,8 @@ def test_supervised_graph_collector(): class SupervisedProblem(AbstractProblem): output_variables = None conditions = { - "data1": Condition(input_points=graph_list_1, output_points=out_1), - "data2": Condition(input_points=graph_list_2, output_points=out_2), + "data1": Condition(input=graph_list_1, target=out_1), + "data2": Condition(input=graph_list_2, target=out_2), } problem = SupervisedProblem() diff --git a/tests/test_condition.py b/tests/test_condition.py index f5842b9..2596e5f 100644 --- a/tests/test_condition.py +++ b/tests/test_condition.py @@ -2,42 +2,151 @@ import torch import pytest from pina import LabelTensor, Condition +from pina.condition import ( + TensorInputGraphTargetCondition, + TensorInputTensorTargetCondition, + GraphInputGraphTargetCondition, + GraphInputTensorTargetCondition, +) +from pina.condition import ( + InputTensorEquationCondition, + InputGraphEquationCondition, + DomainEquationCondition, +) +from pina.condition import ( + TensorDataCondition, + GraphDataCondition, +) from pina.domain import CartesianDomain from pina.equation.equation_factory import FixedValue +from pina.graph import RadiusGraph -example_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]}) -example_input_pts = LabelTensor(torch.tensor([[0, 0, 0]]), ['x', 'y', 'z']) -example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ['a', 'b']) +example_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) + +input_tensor = torch.rand((10,3)) +target_tensor = torch.rand((10,2)) +input_lt = LabelTensor(torch.rand((10,3)), ["x", "y", "z"]) +target_lt = LabelTensor(torch.rand((10,2)), ["a", "b"]) + +x = torch.rand(10, 20, 2) +pos = torch.rand(10, 20, 2) +radius = 0.1 +input_graph = [ + RadiusGraph( + x=x_, + pos=pos_, + radius=radius, + ) + for x_, pos_ in zip(x, pos) +] +target_graph = [ + RadiusGraph( + x=x_, + pos=pos_, + radius=radius, + ) + for x_, pos_ in zip(x, pos) +] + +x = LabelTensor(torch.rand(10, 20, 2), ["u", "v"]) +pos = LabelTensor(torch.rand(10, 20, 2), ["x", "y"]) +radius = 0.1 +input_graph_lt = [ + RadiusGraph( + x=x[i], + pos=pos[i], + radius=radius, + ) + for i in range(len(x)) +] +target_graph_lt = [ + RadiusGraph( + x=x[i], + pos=pos[i], + radius=radius, + ) + for i in range(len(x)) +] + +input_single_graph = input_graph[0] +target_single_graph = target_graph[0] -def test_init_inputoutput(): - Condition(input_points=example_input_pts, output_points=example_output_pts) +def test_init_input_target(): + cond = Condition(input=input_tensor, target=target_tensor) + assert isinstance(cond, TensorInputTensorTargetCondition) + cond = Condition(input=input_tensor, target=target_tensor) + assert isinstance(cond, TensorInputTensorTargetCondition) + cond = Condition(input=input_tensor, target=target_graph) + assert isinstance(cond, TensorInputGraphTargetCondition) + cond = Condition(input=input_graph, target=target_tensor) + assert isinstance(cond, GraphInputTensorTargetCondition) + cond = Condition(input=input_graph, target=target_graph) + assert isinstance(cond, GraphInputGraphTargetCondition) + + cond = Condition(input=input_lt, target=input_single_graph) + assert isinstance(cond, TensorInputGraphTargetCondition) + cond = Condition(input=input_single_graph, target=target_lt) + assert isinstance(cond, GraphInputTensorTargetCondition) + cond = Condition(input=input_graph, target=target_graph) + assert isinstance(cond, GraphInputGraphTargetCondition) + cond = 
Condition(input=input_single_graph, target=target_single_graph) + assert isinstance(cond, GraphInputGraphTargetCondition) + with pytest.raises(ValueError): - Condition(example_input_pts, example_output_pts) + Condition(input_tensor, input_tensor) with pytest.raises(ValueError): - Condition(input_points=3., output_points='example') + Condition(input=3.0, target="example") with pytest.raises(ValueError): - Condition(input_points=example_domain, output_points=example_domain) + Condition(input=example_domain, target=example_domain) + # Test wrong graph condition initialisation + input = [input_graph[0], input_graph_lt[0]] + target = [target_graph[0], target_graph_lt[0]] + with pytest.raises(ValueError): + Condition(input=input, target=target) + + input_graph_lt[0].x.labels = ["a", "b"] + with pytest.raises(ValueError): + Condition(input=input_graph_lt, target=target_graph_lt) + input_graph_lt[0].x.labels = ["u", "v"] + -test_init_inputoutput() - - -def test_init_domainfunc(): - Condition(domain=example_domain, equation=FixedValue(0.0)) +def test_init_domain_equation(): + cond = Condition(domain=example_domain, equation=FixedValue(0.0)) + assert isinstance(cond, DomainEquationCondition) with pytest.raises(ValueError): Condition(example_domain, FixedValue(0.0)) with pytest.raises(ValueError): - Condition(domain=3., equation='example') + Condition(domain=3.0, equation="example") with pytest.raises(ValueError): - Condition(domain=example_input_pts, equation=example_output_pts) + Condition(domain=input_tensor, equation=input_graph) -def test_init_inputfunc(): - Condition(input_points=example_input_pts, equation=FixedValue(0.0)) +def test_init_input_equation(): + cond = Condition(input=input_lt, equation=FixedValue(0.0)) + assert isinstance(cond, InputTensorEquationCondition) + cond = Condition(input=input_graph_lt, equation=FixedValue(0.0)) + assert isinstance(cond, InputGraphEquationCondition) + with pytest.raises(ValueError): + cond = Condition(input=input_tensor, equation=FixedValue(0.0)) with pytest.raises(ValueError): Condition(example_domain, FixedValue(0.0)) with pytest.raises(ValueError): - Condition(input_points=3., equation='example') + Condition(input=3.0, equation="example") with pytest.raises(ValueError): - Condition(input_points=example_domain, equation=example_output_pts) + Condition(input=example_domain, equation=input_graph) +test_init_input_equation() + +def test_init_data_condition(): + cond = Condition(input=input_lt) + assert isinstance(cond, TensorDataCondition) + cond = Condition(input=input_tensor) + assert isinstance(cond, TensorDataCondition) + cond = Condition(input=input_tensor, conditional_variables=torch.tensor(1)) + assert isinstance(cond, TensorDataCondition) + cond = Condition(input=input_graph) + assert isinstance(cond, GraphDataCondition) + cond = Condition(input=input_graph, conditional_variables=torch.tensor(1)) + assert isinstance(cond, GraphDataCondition) + diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py index 2d7de9d..fe7b3eb 100644 --- a/tests/test_data/test_data_module.py +++ b/tests/test_data/test_data_module.py @@ -114,10 +114,10 @@ def test_dummy_dataloader(input_, output_): assert isinstance(data, list) assert isinstance(data[0], tuple) if isinstance(input_, list): - assert isinstance(data[0][1]["input_points"], Batch) + assert isinstance(data[0][1]["input"], Batch) else: - assert isinstance(data[0][1]["input_points"], torch.Tensor) - assert isinstance(data[0][1]["output_points"], torch.Tensor) + assert 
isinstance(data[0][1]["input"], torch.Tensor) + assert isinstance(data[0][1]["target"], torch.Tensor) dataloader = dm.val_dataloader() assert isinstance(dataloader, DummyDataloader) @@ -126,10 +126,10 @@ def test_dummy_dataloader(input_, output_): assert isinstance(data, list) assert isinstance(data[0], tuple) if isinstance(input_, list): - assert isinstance(data[0][1]["input_points"], Batch) + assert isinstance(data[0][1]["input"], Batch) else: - assert isinstance(data[0][1]["input_points"], torch.Tensor) - assert isinstance(data[0][1]["output_points"], torch.Tensor) + assert isinstance(data[0][1]["input"], torch.Tensor) + assert isinstance(data[0][1]["target"], torch.Tensor) @pytest.mark.parametrize( @@ -157,10 +157,10 @@ def test_dataloader(input_, output_, automatic_batching): data = next(iter(dataloader)) assert isinstance(data, dict) if isinstance(input_, list): - assert isinstance(data["data"]["input_points"], Batch) + assert isinstance(data["data"]["input"], Batch) else: - assert isinstance(data["data"]["input_points"], torch.Tensor) - assert isinstance(data["data"]["output_points"], torch.Tensor) + assert isinstance(data["data"]["input"], torch.Tensor) + assert isinstance(data["data"]["target"], torch.Tensor) dataloader = dm.val_dataloader() assert isinstance(dataloader, DataLoader) @@ -168,10 +168,10 @@ def test_dataloader(input_, output_, automatic_batching): data = next(iter(dataloader)) assert isinstance(data, dict) if isinstance(input_, list): - assert isinstance(data["data"]["input_points"], Batch) + assert isinstance(data["data"]["input"], Batch) else: - assert isinstance(data["data"]["input_points"], torch.Tensor) - assert isinstance(data["data"]["output_points"], torch.Tensor) + assert isinstance(data["data"]["input"], torch.Tensor) + assert isinstance(data["data"]["target"], torch.Tensor) from pina import LabelTensor @@ -212,15 +212,15 @@ def test_dataloader_labels(input_, output_, automatic_batching): data = next(iter(dataloader)) assert isinstance(data, dict) if isinstance(input_, list): - assert isinstance(data["data"]["input_points"], Batch) - assert isinstance(data["data"]["input_points"].x, LabelTensor) - assert data["data"]["input_points"].x.labels == ["u", "v", "w"] - assert data["data"]["input_points"].pos.labels == ["x", "y"] + assert isinstance(data["data"]["input"], Batch) + assert isinstance(data["data"]["input"].x, LabelTensor) + assert data["data"]["input"].x.labels == ["u", "v", "w"] + assert data["data"]["input"].pos.labels == ["x", "y"] else: - assert isinstance(data["data"]["input_points"], LabelTensor) - assert data["data"]["input_points"].labels == ["u", "v", "w"] - assert isinstance(data["data"]["output_points"], LabelTensor) - assert data["data"]["output_points"].labels == ["u", "v", "w"] + assert isinstance(data["data"]["input"], LabelTensor) + assert data["data"]["input"].labels == ["u", "v", "w"] + assert isinstance(data["data"]["target"], LabelTensor) + assert data["data"]["target"].labels == ["u", "v", "w"] dataloader = dm.val_dataloader() assert isinstance(dataloader, DataLoader) @@ -228,13 +228,13 @@ def test_dataloader_labels(input_, output_, automatic_batching): data = next(iter(dataloader)) assert isinstance(data, dict) if isinstance(input_, list): - assert isinstance(data["data"]["input_points"], Batch) - assert isinstance(data["data"]["input_points"].x, LabelTensor) - assert data["data"]["input_points"].x.labels == ["u", "v", "w"] - assert data["data"]["input_points"].pos.labels == ["x", "y"] + assert isinstance(data["data"]["input"], 
Batch) + assert isinstance(data["data"]["input"].x, LabelTensor) + assert data["data"]["input"].x.labels == ["u", "v", "w"] + assert data["data"]["input"].pos.labels == ["x", "y"] else: - assert isinstance(data["data"]["input_points"], torch.Tensor) - assert isinstance(data["data"]["input_points"], LabelTensor) - assert data["data"]["input_points"].labels == ["u", "v", "w"] - assert isinstance(data["data"]["output_points"], torch.Tensor) - assert data["data"]["output_points"].labels == ["u", "v", "w"] + assert isinstance(data["data"]["input"], torch.Tensor) + assert isinstance(data["data"]["input"], LabelTensor) + assert data["data"]["input"].labels == ["u", "v", "w"] + assert isinstance(data["data"]["target"], torch.Tensor) + assert data["data"]["target"].labels == ["u", "v", "w"] diff --git a/tests/test_data/test_graph_dataset.py b/tests/test_data/test_graph_dataset.py index 4acb19e..e50e2c4 100644 --- a/tests/test_data/test_graph_dataset.py +++ b/tests/test_data/test_graph_dataset.py @@ -24,8 +24,8 @@ output_2_ = torch.rand((50, 20, 10)) # Problem with a single condition conditions_dict_single = { "data": { - "input_points": input_, - "output_points": output_, + "input": input_, + "target": output_, } } max_conditions_lengths_single = {"data": 100} @@ -33,12 +33,12 @@ max_conditions_lengths_single = {"data": 100} # Problem with multiple conditions conditions_dict_single_multi = { "data_1": { - "input_points": input_, - "output_points": output_, + "input": input_, + "target": output_, }, "data_2": { - "input_points": input_2_, - "output_points": output_2_, + "input": input_2_, + "target": output_2_, }, } @@ -77,56 +77,56 @@ def test_getitem(conditions_dict, max_conditions_lengths): ) data = dataset[50] assert isinstance(data, dict) - assert all([isinstance(d["input_points"], Data) for d in data.values()]) + assert all([isinstance(d["input"], Data) for d in data.values()]) assert all( - [isinstance(d["output_points"], torch.Tensor) for d in data.values()] + [isinstance(d["target"], torch.Tensor) for d in data.values()] ) assert all( [ - d["input_points"].x.shape == torch.Size((20, 10)) + d["input"].x.shape == torch.Size((20, 10)) for d in data.values() ] ) assert all( [ - d["output_points"].shape == torch.Size((20, 10)) + d["target"].shape == torch.Size((20, 10)) for d in data.values() ] ) assert all( [ - d["input_points"].edge_index.shape == torch.Size((2, 60)) + d["input"].edge_index.shape == torch.Size((2, 60)) for d in data.values() ] ) assert all( - [d["input_points"].edge_attr.shape[0] == 60 for d in data.values()] + [d["input"].edge_attr.shape[0] == 60 for d in data.values()] ) data = dataset.fetch_from_idx_list([i for i in range(20)]) assert isinstance(data, dict) - assert all([isinstance(d["input_points"], Data) for d in data.values()]) + assert all([isinstance(d["input"], Data) for d in data.values()]) assert all( - [isinstance(d["output_points"], torch.Tensor) for d in data.values()] + [isinstance(d["target"], torch.Tensor) for d in data.values()] ) assert all( [ - d["input_points"].x.shape == torch.Size((400, 10)) + d["input"].x.shape == torch.Size((400, 10)) for d in data.values() ] ) assert all( [ - d["output_points"].shape == torch.Size((400, 10)) + d["target"].shape == torch.Size((400, 10)) for d in data.values() ] ) assert all( [ - d["input_points"].edge_index.shape == torch.Size((2, 1200)) + d["input"].edge_index.shape == torch.Size((2, 1200)) for d in data.values() ] ) assert all( - [d["input_points"].edge_attr.shape[0] == 1200 for d in data.values()] + 
[d["input"].edge_attr.shape[0] == 1200 for d in data.values()] ) diff --git a/tests/test_data/test_tensor_dataset.py b/tests/test_data/test_tensor_dataset.py index 230cae4..a340576 100644 --- a/tests/test_data/test_tensor_dataset.py +++ b/tests/test_data/test_tensor_dataset.py @@ -10,19 +10,19 @@ output_tensor_2 = torch.rand((50, 2)) conditions_dict_single = { 'data': { - 'input_points': input_tensor, - 'output_points': output_tensor, + 'input': input_tensor, + 'target': output_tensor, } } conditions_dict_single_multi = { 'data_1': { - 'input_points': input_tensor, - 'output_points': output_tensor, + 'input': input_tensor, + 'target': output_tensor, }, 'data_2': { - 'input_points': input_tensor_2, - 'output_points': output_tensor_2, + 'input': input_tensor_2, + 'target': output_tensor_2, } } @@ -59,11 +59,11 @@ def test_getitem_single(): assert isinstance(tensors, dict) assert list(tensors.keys()) == ['data'] assert sorted(list(tensors['data'].keys())) == [ - 'input_points', 'output_points'] - assert isinstance(tensors['data']['input_points'], torch.Tensor) - assert tensors['data']['input_points'].shape == torch.Size((70, 10)) - assert isinstance(tensors['data']['output_points'], torch.Tensor) - assert tensors['data']['output_points'].shape == torch.Size((70, 2)) + 'input', 'target'] + assert isinstance(tensors['data']['input'], torch.Tensor) + assert tensors['data']['input'].shape == torch.Size((70, 10)) + assert isinstance(tensors['data']['target'], torch.Tensor) + assert tensors['data']['target'].shape == torch.Size((70, 2)) def test_getitem_multi(): @@ -74,15 +74,15 @@ def test_getitem_multi(): assert isinstance(tensors, dict) assert list(tensors.keys()) == ['data_1', 'data_2'] assert sorted(list(tensors['data_1'].keys())) == [ - 'input_points', 'output_points'] - assert isinstance(tensors['data_1']['input_points'], torch.Tensor) - assert tensors['data_1']['input_points'].shape == torch.Size((70, 10)) - assert isinstance(tensors['data_1']['output_points'], torch.Tensor) - assert tensors['data_1']['output_points'].shape == torch.Size((70, 2)) + 'input', 'target'] + assert isinstance(tensors['data_1']['input'], torch.Tensor) + assert tensors['data_1']['input'].shape == torch.Size((70, 10)) + assert isinstance(tensors['data_1']['target'], torch.Tensor) + assert tensors['data_1']['target'].shape == torch.Size((70, 2)) assert sorted(list(tensors['data_2'].keys())) == [ - 'input_points', 'output_points'] - assert isinstance(tensors['data_2']['input_points'], torch.Tensor) - assert tensors['data_2']['input_points'].shape == torch.Size((50, 10)) - assert isinstance(tensors['data_2']['output_points'], torch.Tensor) - assert tensors['data_2']['output_points'].shape == torch.Size((50, 2)) + 'input', 'target'] + assert isinstance(tensors['data_2']['input'], torch.Tensor) + assert tensors['data_2']['input'].shape == torch.Size((50, 10)) + assert isinstance(tensors['data_2']['target'], torch.Tensor) + assert tensors['data_2']['target'].shape == torch.Size((50, 2)) diff --git a/tests/test_problem_zoo/test_supervised_problem.py b/tests/test_problem_zoo/test_supervised_problem.py index 06241fa..19b3920 100644 --- a/tests/test_problem_zoo/test_supervised_problem.py +++ b/tests/test_problem_zoo/test_supervised_problem.py @@ -1,6 +1,6 @@ import torch from pina.problem import AbstractProblem -from pina.condition import InputOutputPointsCondition +from pina.condition import InputTargetCondition from pina.problem.zoo.supervised_problem import SupervisedProblem from pina.graph import RadiusGraph @@ -13,7 
+13,7 @@ def test_constructor(): assert hasattr(problem, "conditions") assert isinstance(problem.conditions, dict) assert list(problem.conditions.keys()) == ["data"] - assert isinstance(problem.conditions["data"], InputOutputPointsCondition) + assert isinstance(problem.conditions["data"], InputTargetCondition) def test_constructor_graph(): @@ -23,12 +23,12 @@ def test_constructor_graph(): RadiusGraph(x=x_, pos=pos_, radius=0.2, edge_attr=True) for x_, pos_ in zip(x, pos) ] - output_ = torch.rand((100, 10)) + output_ = torch.rand((20, 100, 10)) problem = SupervisedProblem(input_=input_, output_=output_) assert isinstance(problem, AbstractProblem) assert hasattr(problem, "conditions") assert isinstance(problem.conditions, dict) assert list(problem.conditions.keys()) == ["data"] - assert isinstance(problem.conditions["data"], InputOutputPointsCondition) - assert isinstance(problem.conditions["data"].input_points, list) - assert isinstance(problem.conditions["data"].output_points, torch.Tensor) + assert isinstance(problem.conditions["data"], InputTargetCondition) + assert isinstance(problem.conditions["data"].input, list) + assert isinstance(problem.conditions["data"].target, torch.Tensor) diff --git a/tests/test_solver/test_causal_pinn.py b/tests/test_solver/test_causal_pinn.py index f051df9..c813b63 100644 --- a/tests/test_solver/test_causal_pinn.py +++ b/tests/test_solver/test_causal_pinn.py @@ -11,8 +11,8 @@ from pina.problem.zoo import ( InverseDiffusionReactionProblem ) from pina.condition import ( - InputOutputPointsCondition, - InputPointsEquationCondition, + InputTargetCondition, + InputEquationCondition, DomainEquationCondition ) from torch._dynamo.eval_frame import OptimizedModule @@ -43,8 +43,8 @@ input_pts = LabelTensor(input_pts, problem.input_variables) output_pts = torch.rand(50, len(problem.output_variables)) output_pts = LabelTensor(output_pts, problem.output_variables) problem.conditions['data'] = Condition( - input_points=input_pts, - output_points=output_pts + input=input_pts, + target=output_pts ) @@ -56,8 +56,8 @@ def test_constructor(problem, eps): solver = CausalPINN(model=model, problem=problem, eps=eps) assert solver.accepted_conditions_types == ( - InputOutputPointsCondition, - InputPointsEquationCondition, + InputTargetCondition, + InputEquationCondition, DomainEquationCondition ) diff --git a/tests/test_solver/test_competitive_pinn.py b/tests/test_solver/test_competitive_pinn.py index 4468210..4190edc 100644 --- a/tests/test_solver/test_competitive_pinn.py +++ b/tests/test_solver/test_competitive_pinn.py @@ -10,8 +10,8 @@ from pina.problem.zoo import ( InversePoisson2DSquareProblem as InversePoisson ) from pina.condition import ( - InputOutputPointsCondition, - InputPointsEquationCondition, + InputTargetCondition, + InputEquationCondition, DomainEquationCondition ) from torch._dynamo.eval_frame import OptimizedModule @@ -33,8 +33,8 @@ input_pts = LabelTensor(input_pts, problem.input_variables) output_pts = torch.rand(50, len(problem.output_variables)) output_pts = LabelTensor(output_pts, problem.output_variables) problem.conditions['data'] = Condition( - input_points=input_pts, - output_points=output_pts + input=input_pts, + target=output_pts ) @pytest.mark.parametrize("problem", [problem, inverse_problem]) @@ -44,8 +44,8 @@ def test_constructor(problem, discr): solver = CompPINN(problem=problem, model=model, discriminator=discr) assert solver.accepted_conditions_types == ( - InputOutputPointsCondition, - InputPointsEquationCondition, + InputTargetCondition, + 
InputEquationCondition, DomainEquationCondition ) diff --git a/tests/test_solver/test_garom.py b/tests/test_solver/test_garom.py index dfe9fc4..aef2286 100644 --- a/tests/test_solver/test_garom.py +++ b/tests/test_solver/test_garom.py @@ -4,7 +4,7 @@ import torch.nn as nn import pytest from pina import Condition, LabelTensor from pina.solver import GAROM -from pina.condition import InputOutputPointsCondition +from pina.condition import InputTargetCondition from pina.problem import AbstractProblem from pina.model import FeedForward from pina.trainer import Trainer @@ -16,8 +16,8 @@ class TensorProblem(AbstractProblem): output_variables = ['u'] conditions = { 'data': Condition( - output_points=torch.randn(50, 2), - input_points=torch.randn(50, 1)) + target=torch.randn(50, 2), + input=torch.randn(50, 1)) } @@ -74,7 +74,7 @@ def test_constructor(): generator=Generator(), discriminator=Discriminator()) assert GAROM.accepted_conditions_types == ( - InputOutputPointsCondition + InputTargetCondition ) diff --git a/tests/test_solver/test_gradient_pinn.py b/tests/test_solver/test_gradient_pinn.py index 0bab687..e7c5adb 100644 --- a/tests/test_solver/test_gradient_pinn.py +++ b/tests/test_solver/test_gradient_pinn.py @@ -11,8 +11,8 @@ from pina.problem.zoo import ( InversePoisson2DSquareProblem as InversePoisson ) from pina.condition import ( - InputOutputPointsCondition, - InputPointsEquationCondition, + InputTargetCondition, + InputEquationCondition, DomainEquationCondition ) from torch._dynamo.eval_frame import OptimizedModule @@ -43,8 +43,8 @@ input_pts = LabelTensor(input_pts, problem.input_variables) output_pts = torch.rand(50, len(problem.output_variables)) output_pts = LabelTensor(output_pts, problem.output_variables) problem.conditions['data'] = Condition( - input_points=input_pts, - output_points=output_pts + input=input_pts, + target=output_pts ) @@ -55,8 +55,8 @@ def test_constructor(problem): solver = GradientPINN(model=model, problem=problem) assert solver.accepted_conditions_types == ( - InputOutputPointsCondition, - InputPointsEquationCondition, + InputTargetCondition, + InputEquationCondition, DomainEquationCondition ) diff --git a/tests/test_solver/test_pinn.py b/tests/test_solver/test_pinn.py index 6475726..88a1d06 100644 --- a/tests/test_solver/test_pinn.py +++ b/tests/test_solver/test_pinn.py @@ -6,8 +6,8 @@ from pina.model import FeedForward from pina.trainer import Trainer from pina.solver import PINN from pina.condition import ( - InputOutputPointsCondition, - InputPointsEquationCondition, + InputTargetCondition, + InputEquationCondition, DomainEquationCondition ) from pina.problem.zoo import ( @@ -33,8 +33,8 @@ input_pts = LabelTensor(input_pts, problem.input_variables) output_pts = torch.rand(50, len(problem.output_variables)) output_pts = LabelTensor(output_pts, problem.output_variables) problem.conditions['data'] = Condition( - input_points=input_pts, - output_points=output_pts + input=input_pts, + target=output_pts ) @pytest.mark.parametrize("problem", [problem, inverse_problem]) @@ -42,8 +42,8 @@ def test_constructor(problem): solver = PINN(problem=problem, model=model) assert solver.accepted_conditions_types == ( - InputOutputPointsCondition, - InputPointsEquationCondition, + InputTargetCondition, + InputEquationCondition, DomainEquationCondition ) diff --git a/tests/test_solver/test_rba_pinn.py b/tests/test_solver/test_rba_pinn.py index 7777091..cb29084 100644 --- a/tests/test_solver/test_rba_pinn.py +++ b/tests/test_solver/test_rba_pinn.py @@ -6,8 +6,8 @@ from 
pina.model import FeedForward from pina.trainer import Trainer from pina.solver import RBAPINN from pina.condition import ( - InputOutputPointsCondition, - InputPointsEquationCondition, + InputTargetCondition, + InputEquationCondition, DomainEquationCondition ) from pina.problem.zoo import ( @@ -32,8 +32,8 @@ input_pts = LabelTensor(input_pts, problem.input_variables) output_pts = torch.rand(50, len(problem.output_variables)) output_pts = LabelTensor(output_pts, problem.output_variables) problem.conditions['data'] = Condition( - input_points=input_pts, - output_points=output_pts + input=input_pts, + target=output_pts ) @@ -46,8 +46,8 @@ def test_constructor(problem, eta, gamma): solver = RBAPINN(model=model, problem=problem, eta=eta, gamma=gamma) assert solver.accepted_conditions_types == ( - InputOutputPointsCondition, - InputPointsEquationCondition, + InputTargetCondition, + InputEquationCondition, DomainEquationCondition ) diff --git a/tests/test_solver/test_reduced_order_model_solver.py b/tests/test_solver/test_reduced_order_model_solver.py index a537192..e00b587 100644 --- a/tests/test_solver/test_reduced_order_model_solver.py +++ b/tests/test_solver/test_reduced_order_model_solver.py @@ -3,7 +3,7 @@ import pytest from pina import Condition, LabelTensor from pina.problem import AbstractProblem -from pina.condition import InputOutputPointsCondition +from pina.condition import InputTargetCondition from pina.solver import ReducedOrderModelSolver from pina.trainer import Trainer from pina.model import FeedForward @@ -16,8 +16,8 @@ class LabelTensorProblem(AbstractProblem): output_variables = ['u'] conditions = { 'data': Condition( - input_points=LabelTensor(torch.randn(20, 2), ['u_0', 'u_1']), - output_points=LabelTensor(torch.randn(20, 1), ['u'])), + input=LabelTensor(torch.randn(20, 2), ['u_0', 'u_1']), + target=LabelTensor(torch.randn(20, 1), ['u'])), } @@ -26,8 +26,8 @@ class TensorProblem(AbstractProblem): output_variables = ['u'] conditions = { 'data': Condition( - input_points=torch.randn(20, 2), - output_points=torch.randn(20, 1)) + input=torch.randn(20, 2), + target=torch.randn(20, 1)) } @@ -68,7 +68,7 @@ def test_constructor(): ReducedOrderModelSolver(problem=LabelTensorProblem(), reduction_network=reduction_net, interpolation_network=interpolation_net) - assert ReducedOrderModelSolver.accepted_conditions_types == InputOutputPointsCondition + assert ReducedOrderModelSolver.accepted_conditions_types == InputTargetCondition with pytest.raises(SyntaxError): ReducedOrderModelSolver(problem=problem, reduction_network=AE_missing_encode( diff --git a/tests/test_solver/test_self_adaptive_pinn.py b/tests/test_solver/test_self_adaptive_pinn.py index 348c63a..b2df9fa 100644 --- a/tests/test_solver/test_self_adaptive_pinn.py +++ b/tests/test_solver/test_self_adaptive_pinn.py @@ -10,8 +10,8 @@ from pina.problem.zoo import ( InversePoisson2DSquareProblem as InversePoisson ) from pina.condition import ( - InputOutputPointsCondition, - InputPointsEquationCondition, + InputTargetCondition, + InputEquationCondition, DomainEquationCondition ) from torch._dynamo.eval_frame import OptimizedModule @@ -33,8 +33,8 @@ input_pts = LabelTensor(input_pts, problem.input_variables) output_pts = torch.rand(50, len(problem.output_variables)) output_pts = LabelTensor(output_pts, problem.output_variables) problem.conditions['data'] = Condition( - input_points=input_pts, - output_points=output_pts + input=input_pts, + target=output_pts ) @@ -46,8 +46,8 @@ def test_constructor(problem, weight_fn): solver = 
SAPINN(problem=problem, model=model, weight_function=weight_fn) assert solver.accepted_conditions_types == ( - InputOutputPointsCondition, - InputPointsEquationCondition, + InputTargetCondition, + InputEquationCondition, DomainEquationCondition ) diff --git a/tests/test_solver/test_supervised_solver.py b/tests/test_solver/test_supervised_solver.py index 2b8d608..e8aad5b 100644 --- a/tests/test_solver/test_supervised_solver.py +++ b/tests/test_solver/test_supervised_solver.py @@ -1,7 +1,7 @@ import torch import pytest from pina import Condition, LabelTensor -from pina.condition import InputOutputPointsCondition +from pina.condition import InputTargetCondition from pina.problem import AbstractProblem from pina.solver import SupervisedSolver from pina.model import FeedForward @@ -14,8 +14,8 @@ class LabelTensorProblem(AbstractProblem): output_variables = ['u'] conditions = { 'data': Condition( - input_points=LabelTensor(torch.randn(20, 2), ['u_0', 'u_1']), - output_points=LabelTensor(torch.randn(20, 1), ['u'])), + input=LabelTensor(torch.randn(20, 2), ['u_0', 'u_1']), + target=LabelTensor(torch.randn(20, 1), ['u'])), } @@ -24,8 +24,8 @@ class TensorProblem(AbstractProblem): output_variables = ['u'] conditions = { 'data': Condition( - input_points=torch.randn(20, 2), - output_points=torch.randn(20, 1)) + input=torch.randn(20, 2), + target=torch.randn(20, 1)) } @@ -36,7 +36,7 @@ def test_constructor(): SupervisedSolver(problem=TensorProblem(), model=model) SupervisedSolver(problem=LabelTensorProblem(), model=model) assert SupervisedSolver.accepted_conditions_types == ( - InputOutputPointsCondition + InputTargetCondition )
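
For reference, a minimal usage sketch of the renamed keyword API exercised by the updated tests. This is illustrative only and not part of the patch: the Condition(input=..., target=...) call, the SupervisedProblem(input_=..., output_=...) constructor, the import paths, and the InputTargetCondition check are copied from the tests in this diff, while the tensor shapes are arbitrary.

    import torch
    from pina import Condition
    from pina.condition import InputTargetCondition
    from pina.problem.zoo.supervised_problem import SupervisedProblem

    # Plain tensors: per tests/test_condition.py the factory returns a
    # TensorInputTensorTargetCondition, a subclass of InputTargetCondition.
    cond = Condition(input=torch.rand((20, 3)), target=torch.rand((20, 1)))
    assert isinstance(cond, InputTargetCondition)

    # The zoo problem wraps the same pair into a single "data" condition,
    # now exposed through the renamed input/target keywords
    # (formerly input_points/output_points).
    problem = SupervisedProblem(
        input_=torch.rand((20, 3)), output_=torch.rand((20, 1))
    )
    assert isinstance(problem.conditions["data"], InputTargetCondition)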