37 lines
1.1 KiB
Python
37 lines
1.1 KiB
Python
from pina.problem import AbstractProblem
|
|
from pina import Condition
|
|
from pina import Graph
|
|
|
|
class SupervisedProblem(AbstractProblem):
    """
    A problem definition for supervised learning in PINA.

    This class allows an easy and straightforward definition of a supervised
    problem, based on a single condition of type
    `InputOutputPointsCondition`, stored under the ``'data'`` key of
    ``conditions``.

    :Example:
        >>> import torch
        >>> input_data = torch.rand((100, 10))
        >>> output_data = torch.rand((100, 10))
        >>> problem = SupervisedProblem(input_data, output_data)
    """

    # Class-level default only; __init__ gives every instance its own copy so
    # problems built from different data never share (and overwrite) it.
    conditions = dict()
    output_variables = None

    def __init__(self, input_, output_):
        """
        Initialize the SupervisedProblem class.

        :param input_: Input data of the problem. If a ``Graph`` is given,
            its ``data`` attribute is used as the input points.
        :type input_: torch.Tensor | Graph
        :param output_: Output data of the problem.
        :type output_: torch.Tensor
        """
        if isinstance(input_, Graph):
            input_ = input_.data
        # Copy before writing: assigning into the class-level dict directly
        # would mutate state shared by all SupervisedProblem instances, so a
        # second instance would silently replace the first one's data.
        self.conditions = dict(self.conditions)
        self.conditions['data'] = Condition(
            input_points=input_,
            output_points=output_,
        )
        super().__init__()
|
|
|