Add Supervised Problem (#451)
* Add SuperviedProblem class in problem zoo
This commit is contained in:
committed by
Nicola Demo
parent
c6f1aafdec
commit
780c4921eb
@@ -1,5 +1,8 @@
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
'Poisson2DSquareProblem'
|
'Poisson2DSquareProblem',
|
||||||
|
'SupervisedProblem'
|
||||||
|
|
||||||
]
|
]
|
||||||
|
|
||||||
from .poisson_2d_square import Poisson2DSquareProblem
|
from .poisson_2d_square import Poisson2DSquareProblem
|
||||||
|
from .supervised_problem import SupervisedProblem
|
||||||
37
pina/problem/zoo/supervised_problem.py
Normal file
37
pina/problem/zoo/supervised_problem.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
from pina.problem import AbstractProblem
|
||||||
|
from pina import Condition
|
||||||
|
from pina import Graph
|
||||||
|
|
||||||
|
class SupervisedProblem(AbstractProblem):
|
||||||
|
"""
|
||||||
|
A problem definition for supervised learning in PINA.
|
||||||
|
|
||||||
|
This class allows an easy and straightforward definition of a Supervised problem,
|
||||||
|
based on a single condition of type `InputOutputPointsCondition`
|
||||||
|
|
||||||
|
:Example:
|
||||||
|
>>> import torch
|
||||||
|
>>> input_data = torch.rand((100, 10))
|
||||||
|
>>> output_data = torch.rand((100, 10))
|
||||||
|
>>> problem = SupervisedProblem(input_data, output_data)
|
||||||
|
"""
|
||||||
|
conditions = dict()
|
||||||
|
output_variables = None
|
||||||
|
|
||||||
|
def __init__(self, input_, output_):
|
||||||
|
"""
|
||||||
|
Initialize the SupervisedProblem class
|
||||||
|
|
||||||
|
:param input_: Input data of the problem
|
||||||
|
:type input_: torch.Tensor | Graph
|
||||||
|
:param output_: Output data of the problem
|
||||||
|
:type output_: torch.Tensor
|
||||||
|
"""
|
||||||
|
if isinstance(input_, Graph):
|
||||||
|
input_ = input_.data
|
||||||
|
self.conditions['data'] = Condition(
|
||||||
|
input_points=input_,
|
||||||
|
output_points = output_
|
||||||
|
)
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
31
tests/test_problem_zoo/test_supervised_problem.py
Normal file
31
tests/test_problem_zoo/test_supervised_problem.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import torch
|
||||||
|
from pina.problem import AbstractProblem
|
||||||
|
from pina.condition import InputOutputPointsCondition
|
||||||
|
from pina.problem.zoo.supervised_problem import SupervisedProblem
|
||||||
|
from pina import RadiusGraph
|
||||||
|
|
||||||
|
def test_constructor():
|
||||||
|
input_ = torch.rand((100,10))
|
||||||
|
output_ = torch.rand((100,10))
|
||||||
|
problem = SupervisedProblem(input_=input_, output_=output_)
|
||||||
|
assert isinstance(problem, AbstractProblem)
|
||||||
|
assert hasattr(problem, "conditions")
|
||||||
|
assert isinstance(problem.conditions, dict)
|
||||||
|
assert list(problem.conditions.keys()) == ['data']
|
||||||
|
assert isinstance(problem.conditions['data'], InputOutputPointsCondition)
|
||||||
|
|
||||||
|
def test_constructor_graph():
|
||||||
|
x = torch.rand((20,100,10))
|
||||||
|
pos = torch.rand((20,100,2))
|
||||||
|
input_ = RadiusGraph(
|
||||||
|
x=x, pos=pos, r=.2, build_edge_attr=True
|
||||||
|
)
|
||||||
|
output_ = torch.rand((100,10))
|
||||||
|
problem = SupervisedProblem(input_=input_, output_=output_)
|
||||||
|
assert isinstance(problem, AbstractProblem)
|
||||||
|
assert hasattr(problem, "conditions")
|
||||||
|
assert isinstance(problem.conditions, dict)
|
||||||
|
assert list(problem.conditions.keys()) == ['data']
|
||||||
|
assert isinstance(problem.conditions['data'], InputOutputPointsCondition)
|
||||||
|
assert isinstance(problem.conditions['data'].input_points, list)
|
||||||
|
assert isinstance(problem.conditions['data'].output_points, torch.Tensor)
|
||||||
Reference in New Issue
Block a user