Update Tutorials 0.2 (#490)
This commit is contained in:
committed by
Nicola Demo
parent
beee4cdc0b
commit
6ce0bafc2b
@@ -2,7 +2,7 @@
|
||||
|
||||
from ..abstract_problem import AbstractProblem
|
||||
from ... import Condition
|
||||
from ... import Graph
|
||||
from ... import LabelTensor
|
||||
|
||||
|
||||
class SupervisedProblem(AbstractProblem):
|
||||
@@ -22,16 +22,22 @@ class SupervisedProblem(AbstractProblem):
|
||||
|
||||
conditions = {}
|
||||
output_variables = None
|
||||
input_variables = None
|
||||
|
||||
def __init__(self, input_, output_):
|
||||
def __init__(
|
||||
self, input_, output_, input_variables=None, output_variables=None
|
||||
):
|
||||
"""
|
||||
Initialize the SupervisedProblem class.
|
||||
|
||||
:param input_: Input data of the problem.
|
||||
:type input_: torch.Tensor | LabelTensor | Graph | Data
|
||||
:param output_: Output data of the problem.
|
||||
:type output_: torch.Tensor | Graph
|
||||
:type output_: torch.Tensor | LabelTensor | Graph | Data
|
||||
"""
|
||||
if isinstance(input_, Graph):
|
||||
input_ = input_.data
|
||||
# Set input and output variables
|
||||
self.input_variables = input_variables
|
||||
self.output_variables = output_variables
|
||||
# Set the condition
|
||||
self.conditions["data"] = Condition(input=input_, target=output_)
|
||||
super().__init__()
|
||||
|
||||
Reference in New Issue
Block a user