Add Supervised Problem (#451)

* Add SuperviedProblem class in problem zoo
This commit is contained in:
Filippo Olivo
2025-02-13 20:15:56 +01:00
committed by Nicola Demo
parent c6f1aafdec
commit 780c4921eb
3 changed files with 73 additions and 2 deletions

View File

@@ -1,5 +1,8 @@
__all__ = [
'Poisson2DSquareProblem'
'Poisson2DSquareProblem',
'SupervisedProblem'
]
from .poisson_2d_square import Poisson2DSquareProblem
from .supervised_problem import SupervisedProblem

View File

@@ -0,0 +1,37 @@
from pina.problem import AbstractProblem
from pina import Condition
from pina import Graph
class SupervisedProblem(AbstractProblem):
"""
A problem definition for supervised learning in PINA.
This class allows an easy and straightforward definition of a Supervised problem,
based on a single condition of type `InputOutputPointsCondition`
:Example:
>>> import torch
>>> input_data = torch.rand((100, 10))
>>> output_data = torch.rand((100, 10))
>>> problem = SupervisedProblem(input_data, output_data)
"""
conditions = dict()
output_variables = None
def __init__(self, input_, output_):
"""
Initialize the SupervisedProblem class
:param input_: Input data of the problem
:type input_: torch.Tensor | Graph
:param output_: Output data of the problem
:type output_: torch.Tensor
"""
if isinstance(input_, Graph):
input_ = input_.data
self.conditions['data'] = Condition(
input_points=input_,
output_points = output_
)
super().__init__()

View File

@@ -0,0 +1,31 @@
import torch
from pina.problem import AbstractProblem
from pina.condition import InputOutputPointsCondition
from pina.problem.zoo.supervised_problem import SupervisedProblem
from pina import RadiusGraph
def test_constructor():
input_ = torch.rand((100,10))
output_ = torch.rand((100,10))
problem = SupervisedProblem(input_=input_, output_=output_)
assert isinstance(problem, AbstractProblem)
assert hasattr(problem, "conditions")
assert isinstance(problem.conditions, dict)
assert list(problem.conditions.keys()) == ['data']
assert isinstance(problem.conditions['data'], InputOutputPointsCondition)
def test_constructor_graph():
x = torch.rand((20,100,10))
pos = torch.rand((20,100,2))
input_ = RadiusGraph(
x=x, pos=pos, r=.2, build_edge_attr=True
)
output_ = torch.rand((100,10))
problem = SupervisedProblem(input_=input_, output_=output_)
assert isinstance(problem, AbstractProblem)
assert hasattr(problem, "conditions")
assert isinstance(problem.conditions, dict)
assert list(problem.conditions.keys()) == ['data']
assert isinstance(problem.conditions['data'], InputOutputPointsCondition)
assert isinstance(problem.conditions['data'].input_points, list)
assert isinstance(problem.conditions['data'].output_points, torch.Tensor)