Improve conditions and refactor dataset classes (#475)
* Reimplement conditions * Refactor datasets and implement LabelBatch --------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
bdad144461
commit
a0cbf1c44a
@@ -4,7 +4,6 @@ from abc import ABCMeta, abstractmethod
|
||||
from ..utils import check_consistency
|
||||
from ..domain import DomainInterface, CartesianDomain
|
||||
from ..condition.domain_equation_condition import DomainEquationCondition
|
||||
from ..condition import InputPointsEquationCondition
|
||||
from copy import deepcopy
|
||||
from .. import LabelTensor
|
||||
from ..utils import merge_tensors
|
||||
@@ -55,8 +54,8 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
def input_pts(self):
|
||||
to_return = {}
|
||||
for cond_name, cond in self.conditions.items():
|
||||
if hasattr(cond, "input_points"):
|
||||
to_return[cond_name] = cond.input_points
|
||||
if hasattr(cond, "input"):
|
||||
to_return[cond_name] = cond.input
|
||||
elif hasattr(cond, "domain"):
|
||||
to_return[cond_name] = self._discretised_domains[cond.domain]
|
||||
return to_return
|
||||
|
||||
@@ -46,8 +46,8 @@ class InverseDiffusionReactionProblem(
|
||||
equation=Equation(diffusion_reaction),
|
||||
),
|
||||
"data": Condition(
|
||||
input_points=LabelTensor(torch.randn(10, 2), ["x", "t"]),
|
||||
output_points=LabelTensor(torch.randn(10, 1), ["u"]),
|
||||
input=LabelTensor(torch.randn(10, 2), ["x", "t"]),
|
||||
target=LabelTensor(torch.randn(10, 1), ["u"]),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ class InversePoisson2DSquareProblem(SpatialProblem, InverseProblem):
|
||||
"nil_g4": Condition(domain="g4", equation=FixedValue(0.0)),
|
||||
"laplace_D": Condition(domain="D", equation=Equation(laplace_equation)),
|
||||
"data": Condition(
|
||||
input_points=data_input.extract(["x", "y"]),
|
||||
output_points=data_output,
|
||||
input=data_input.extract(["x", "y"]),
|
||||
target=data_output,
|
||||
),
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ class SupervisedProblem(AbstractProblem):
|
||||
A problem definition for supervised learning in PINA.
|
||||
|
||||
This class allows an easy and straightforward definition of a Supervised problem,
|
||||
based on a single condition of type `InputOutputPointsCondition`
|
||||
based on a single condition of type `InputTargetCondition`
|
||||
|
||||
:Example:
|
||||
>>> import torch
|
||||
@@ -31,7 +31,5 @@ class SupervisedProblem(AbstractProblem):
|
||||
"""
|
||||
if isinstance(input_, Graph):
|
||||
input_ = input_.data
|
||||
self.conditions["data"] = Condition(
|
||||
input_points=input_, output_points=output_
|
||||
)
|
||||
self.conditions["data"] = Condition(input=input_, target=output_)
|
||||
super().__init__()
|
||||
|
||||
Reference in New Issue
Block a user