From 780c4921ebbe778462993461c61fa1db9585cad2 Mon Sep 17 00:00:00 2001 From: Filippo Olivo Date: Thu, 13 Feb 2025 20:15:56 +0100 Subject: [PATCH] Add Supervised Problem (#451) * Add SuperviedProblem class in problem zoo --- pina/problem/zoo/__init__.py | 7 +++- pina/problem/zoo/supervised_problem.py | 37 +++++++++++++++++++ .../test_supervised_problem.py | 31 ++++++++++++++++ 3 files changed, 73 insertions(+), 2 deletions(-) create mode 100644 pina/problem/zoo/supervised_problem.py create mode 100644 tests/test_problem_zoo/test_supervised_problem.py diff --git a/pina/problem/zoo/__init__.py b/pina/problem/zoo/__init__.py index ea9aa78..3e3333d 100644 --- a/pina/problem/zoo/__init__.py +++ b/pina/problem/zoo/__init__.py @@ -1,5 +1,8 @@ __all__ = [ - 'Poisson2DSquareProblem' + 'Poisson2DSquareProblem', + 'SupervisedProblem' + ] -from .poisson_2d_square import Poisson2DSquareProblem \ No newline at end of file +from .poisson_2d_square import Poisson2DSquareProblem +from .supervised_problem import SupervisedProblem \ No newline at end of file diff --git a/pina/problem/zoo/supervised_problem.py b/pina/problem/zoo/supervised_problem.py new file mode 100644 index 0000000..6acac7a --- /dev/null +++ b/pina/problem/zoo/supervised_problem.py @@ -0,0 +1,37 @@ +from pina.problem import AbstractProblem +from pina import Condition +from pina import Graph + +class SupervisedProblem(AbstractProblem): + """ + A problem definition for supervised learning in PINA. + + This class allows an easy and straightforward definition of a Supervised problem, + based on a single condition of type `InputOutputPointsCondition` + + :Example: + >>> import torch + >>> input_data = torch.rand((100, 10)) + >>> output_data = torch.rand((100, 10)) + >>> problem = SupervisedProblem(input_data, output_data) + """ + conditions = dict() + output_variables = None + + def __init__(self, input_, output_): + """ + Initialize the SupervisedProblem class + + :param input_: Input data of the problem + :type input_: torch.Tensor | Graph + :param output_: Output data of the problem + :type output_: torch.Tensor + """ + if isinstance(input_, Graph): + input_ = input_.data + self.conditions['data'] = Condition( + input_points=input_, + output_points = output_ + ) + super().__init__() + \ No newline at end of file diff --git a/tests/test_problem_zoo/test_supervised_problem.py b/tests/test_problem_zoo/test_supervised_problem.py new file mode 100644 index 0000000..b9c7950 --- /dev/null +++ b/tests/test_problem_zoo/test_supervised_problem.py @@ -0,0 +1,31 @@ +import torch +from pina.problem import AbstractProblem +from pina.condition import InputOutputPointsCondition +from pina.problem.zoo.supervised_problem import SupervisedProblem +from pina import RadiusGraph + +def test_constructor(): + input_ = torch.rand((100,10)) + output_ = torch.rand((100,10)) + problem = SupervisedProblem(input_=input_, output_=output_) + assert isinstance(problem, AbstractProblem) + assert hasattr(problem, "conditions") + assert isinstance(problem.conditions, dict) + assert list(problem.conditions.keys()) == ['data'] + assert isinstance(problem.conditions['data'], InputOutputPointsCondition) + +def test_constructor_graph(): + x = torch.rand((20,100,10)) + pos = torch.rand((20,100,2)) + input_ = RadiusGraph( + x=x, pos=pos, r=.2, build_edge_attr=True + ) + output_ = torch.rand((100,10)) + problem = SupervisedProblem(input_=input_, output_=output_) + assert isinstance(problem, AbstractProblem) + assert hasattr(problem, "conditions") + assert isinstance(problem.conditions, dict) + assert list(problem.conditions.keys()) == ['data'] + assert isinstance(problem.conditions['data'], InputOutputPointsCondition) + assert isinstance(problem.conditions['data'].input_points, list) + assert isinstance(problem.conditions['data'].output_points, torch.Tensor)