Improve conditions and refactor dataset classes (#475)

* Reimplement conditions

* Refactor datasets and implement LabelBatch

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-07 11:24:09 +01:00
committed by Nicola Demo
parent bdad144461
commit a0cbf1c44a
40 changed files with 943 additions and 550 deletions

View File

@@ -1,10 +1,10 @@
"""Condition module."""
from .domain_equation_condition import DomainEquationCondition
from .input_equation_condition import InputPointsEquationCondition
from .input_output_condition import InputOutputPointsCondition
from .data_condition import DataConditionInterface
import warnings
from .data_condition import DataCondition
from .domain_equation_condition import DomainEquationCondition
from .input_equation_condition import InputEquationCondition
from .input_target_condition import InputTargetCondition
from ..utils import custom_warning_format
# Set the custom format for warnings
@@ -12,6 +12,21 @@ warnings.formatwarning = custom_warning_format
warnings.filterwarnings("always", category=DeprecationWarning)
def warning_function(new, old):
"""Handle the deprecation warning.
:param new: Object to use instead of the old one.
:type new: str
:param old: Object to deprecate.
:type old: str
"""
warnings.warn(
f"'{old}' is deprecated and will be removed "
f"in future versions. Please use '{new}' instead.",
DeprecationWarning,
)
class Condition:
"""
The class ``Condition`` is used to represent the constraints (physical
@@ -40,16 +55,32 @@ class Condition:
Example::
>>> TODO
>>> from pina import Condition
>>> condition = Condition(
... input=input,
... target=target
... )
>>> condition = Condition(
... domain=location,
... equation=equation
... )
>>> condition = Condition(
... input=input,
... equation=equation
... )
>>> condition = Condition(
... input=data,
... conditional_variables=conditional_variables
... )
"""
__slots__ = list(
set(
InputOutputPointsCondition.__slots__
+ InputPointsEquationCondition.__slots__
InputTargetCondition.__slots__
+ InputEquationCondition.__slots__
+ DomainEquationCondition.__slots__
+ DataConditionInterface.__slots__
+ DataCondition.__slots__
)
)
@@ -62,25 +93,30 @@ class Condition:
)
# back-compatibility 0.1
if "location" in kwargs.keys():
keys = list(kwargs.keys())
if "location" in keys:
kwargs["domain"] = kwargs.pop("location")
warnings.warn(
f"'location' is deprecated and will be removed "
f"in future versions. Please use 'domain' instead.",
DeprecationWarning,
)
warning_function(new="domain", old="location")
if "input_points" in keys:
kwargs["input"] = kwargs.pop("input_points")
warning_function(new="input", old="input_points")
if "output_points" in keys:
kwargs["target"] = kwargs.pop("output_points")
warning_function(new="target", old="output_points")
sorted_keys = sorted(kwargs.keys())
if sorted_keys == sorted(InputOutputPointsCondition.__slots__):
return InputOutputPointsCondition(**kwargs)
elif sorted_keys == sorted(InputPointsEquationCondition.__slots__):
return InputPointsEquationCondition(**kwargs)
elif sorted_keys == sorted(DomainEquationCondition.__slots__):
if sorted_keys == sorted(InputTargetCondition.__slots__):
return InputTargetCondition(**kwargs)
if sorted_keys == sorted(InputEquationCondition.__slots__):
return InputEquationCondition(**kwargs)
if sorted_keys == sorted(DomainEquationCondition.__slots__):
return DomainEquationCondition(**kwargs)
elif sorted_keys == sorted(DataConditionInterface.__slots__):
return DataConditionInterface(**kwargs)
elif sorted_keys == DataConditionInterface.__slots__[0]:
return DataConditionInterface(**kwargs)
else:
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
if (
sorted_keys == sorted(DataCondition.__slots__)
or sorted_keys[0] == DataCondition.__slots__[0]
):
return DataCondition(**kwargs)
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")