Improve conditions and refactor dataset classes (#475)

* Reimplement conditions

* Refactor datasets and implement LabelBatch

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-07 11:24:09 +01:00
committed by Nicola Demo
parent bdad144461
commit a0cbf1c44a
40 changed files with 943 additions and 550 deletions

View File

@@ -46,8 +46,8 @@ class InverseDiffusionReactionProblem(
equation=Equation(diffusion_reaction),
),
"data": Condition(
input_points=LabelTensor(torch.randn(10, 2), ["x", "t"]),
output_points=LabelTensor(torch.randn(10, 1), ["u"]),
input=LabelTensor(torch.randn(10, 2), ["x", "t"]),
target=LabelTensor(torch.randn(10, 1), ["u"]),
),
}

View File

@@ -50,7 +50,7 @@ class InversePoisson2DSquareProblem(SpatialProblem, InverseProblem):
"nil_g4": Condition(domain="g4", equation=FixedValue(0.0)),
"laplace_D": Condition(domain="D", equation=Equation(laplace_equation)),
"data": Condition(
input_points=data_input.extract(["x", "y"]),
output_points=data_output,
input=data_input.extract(["x", "y"]),
target=data_output,
),
}

View File

@@ -8,7 +8,7 @@ class SupervisedProblem(AbstractProblem):
A problem definition for supervised learning in PINA.
This class allows an easy and straightforward definition of a Supervised problem,
based on a single condition of type `InputOutputPointsCondition`
based on a single condition of type `InputTargetCondition`
:Example:
>>> import torch
@@ -31,7 +31,5 @@ class SupervisedProblem(AbstractProblem):
"""
if isinstance(input_, Graph):
input_ = input_.data
self.conditions["data"] = Condition(
input_points=input_, output_points=output_
)
self.conditions["data"] = Condition(input=input_, target=output_)
super().__init__()