Update Tutorials 0.2 (#490)

This commit is contained in:
Dario Coscia
2025-03-13 20:15:48 +01:00
committed by Nicola Demo
parent beee4cdc0b
commit 6ce0bafc2b
30 changed files with 1526 additions and 5227 deletions

View File

@@ -2,7 +2,7 @@
from ..abstract_problem import AbstractProblem
from ... import Condition
from ... import Graph
from ... import LabelTensor
class SupervisedProblem(AbstractProblem):
@@ -22,16 +22,22 @@ class SupervisedProblem(AbstractProblem):
conditions = {}
output_variables = None
input_variables = None
def __init__(self, input_, output_):
def __init__(
self, input_, output_, input_variables=None, output_variables=None
):
"""
Initialize the SupervisedProblem class.
:param input_: Input data of the problem.
:type input_: torch.Tensor | LabelTensor | Graph | Data
:param output_: Output data of the problem.
:type output_: torch.Tensor | Graph
:type output_: torch.Tensor | LabelTensor | Graph | Data
"""
if isinstance(input_, Graph):
input_ = input_.data
# Set input and output variables
self.input_variables = input_variables
self.output_variables = output_variables
# Set the condition
self.conditions["data"] = Condition(input=input_, target=output_)
super().__init__()