Improve conditions and refactor dataset classes (#475)
* Reimplement conditions * Refactor datasets and implement LabelBatch --------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
bdad144461
commit
a0cbf1c44a
@@ -1,10 +1,10 @@
|
||||
"""Condition module."""
|
||||
|
||||
from .domain_equation_condition import DomainEquationCondition
|
||||
from .input_equation_condition import InputPointsEquationCondition
|
||||
from .input_output_condition import InputOutputPointsCondition
|
||||
from .data_condition import DataConditionInterface
|
||||
import warnings
|
||||
from .data_condition import DataCondition
|
||||
from .domain_equation_condition import DomainEquationCondition
|
||||
from .input_equation_condition import InputEquationCondition
|
||||
from .input_target_condition import InputTargetCondition
|
||||
from ..utils import custom_warning_format
|
||||
|
||||
# Set the custom format for warnings
|
||||
@@ -12,6 +12,21 @@ warnings.formatwarning = custom_warning_format
|
||||
warnings.filterwarnings("always", category=DeprecationWarning)
|
||||
|
||||
|
||||
def warning_function(new, old):
|
||||
"""Handle the deprecation warning.
|
||||
|
||||
:param new: Object to use instead of the old one.
|
||||
:type new: str
|
||||
:param old: Object to deprecate.
|
||||
:type old: str
|
||||
"""
|
||||
warnings.warn(
|
||||
f"'{old}' is deprecated and will be removed "
|
||||
f"in future versions. Please use '{new}' instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
|
||||
class Condition:
|
||||
"""
|
||||
The class ``Condition`` is used to represent the constraints (physical
|
||||
@@ -40,16 +55,32 @@ class Condition:
|
||||
|
||||
Example::
|
||||
|
||||
>>> TODO
|
||||
>>> from pina import Condition
|
||||
>>> condition = Condition(
|
||||
... input=input,
|
||||
... target=target
|
||||
... )
|
||||
>>> condition = Condition(
|
||||
... domain=location,
|
||||
... equation=equation
|
||||
... )
|
||||
>>> condition = Condition(
|
||||
... input=input,
|
||||
... equation=equation
|
||||
... )
|
||||
>>> condition = Condition(
|
||||
... input=data,
|
||||
... conditional_variables=conditional_variables
|
||||
... )
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = list(
|
||||
set(
|
||||
InputOutputPointsCondition.__slots__
|
||||
+ InputPointsEquationCondition.__slots__
|
||||
InputTargetCondition.__slots__
|
||||
+ InputEquationCondition.__slots__
|
||||
+ DomainEquationCondition.__slots__
|
||||
+ DataConditionInterface.__slots__
|
||||
+ DataCondition.__slots__
|
||||
)
|
||||
)
|
||||
|
||||
@@ -62,25 +93,30 @@ class Condition:
|
||||
)
|
||||
|
||||
# back-compatibility 0.1
|
||||
if "location" in kwargs.keys():
|
||||
keys = list(kwargs.keys())
|
||||
if "location" in keys:
|
||||
kwargs["domain"] = kwargs.pop("location")
|
||||
warnings.warn(
|
||||
f"'location' is deprecated and will be removed "
|
||||
f"in future versions. Please use 'domain' instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
warning_function(new="domain", old="location")
|
||||
|
||||
if "input_points" in keys:
|
||||
kwargs["input"] = kwargs.pop("input_points")
|
||||
warning_function(new="input", old="input_points")
|
||||
|
||||
if "output_points" in keys:
|
||||
kwargs["target"] = kwargs.pop("output_points")
|
||||
warning_function(new="target", old="output_points")
|
||||
|
||||
sorted_keys = sorted(kwargs.keys())
|
||||
|
||||
if sorted_keys == sorted(InputOutputPointsCondition.__slots__):
|
||||
return InputOutputPointsCondition(**kwargs)
|
||||
elif sorted_keys == sorted(InputPointsEquationCondition.__slots__):
|
||||
return InputPointsEquationCondition(**kwargs)
|
||||
elif sorted_keys == sorted(DomainEquationCondition.__slots__):
|
||||
if sorted_keys == sorted(InputTargetCondition.__slots__):
|
||||
return InputTargetCondition(**kwargs)
|
||||
if sorted_keys == sorted(InputEquationCondition.__slots__):
|
||||
return InputEquationCondition(**kwargs)
|
||||
if sorted_keys == sorted(DomainEquationCondition.__slots__):
|
||||
return DomainEquationCondition(**kwargs)
|
||||
elif sorted_keys == sorted(DataConditionInterface.__slots__):
|
||||
return DataConditionInterface(**kwargs)
|
||||
elif sorted_keys == DataConditionInterface.__slots__[0]:
|
||||
return DataConditionInterface(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
|
||||
if (
|
||||
sorted_keys == sorted(DataCondition.__slots__)
|
||||
or sorted_keys[0] == DataCondition.__slots__[0]
|
||||
):
|
||||
return DataCondition(**kwargs)
|
||||
|
||||
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
|
||||
|
||||
Reference in New Issue
Block a user