Formatting
* Adding black as dev dependency * Formatting pina code * Formatting tests
This commit is contained in:
committed by
Nicola Demo
parent
4c4482b155
commit
42ab1a666b
@@ -1,12 +1,12 @@
|
||||
__all__ = [
|
||||
'Condition',
|
||||
'ConditionInterface',
|
||||
'DomainEquationCondition',
|
||||
'InputPointsEquationCondition',
|
||||
'InputOutputPointsCondition',
|
||||
"Condition",
|
||||
"ConditionInterface",
|
||||
"DomainEquationCondition",
|
||||
"InputPointsEquationCondition",
|
||||
"InputOutputPointsCondition",
|
||||
]
|
||||
|
||||
from .condition_interface import ConditionInterface
|
||||
from .domain_equation_condition import DomainEquationCondition
|
||||
from .input_equation_condition import InputPointsEquationCondition
|
||||
from .input_output_condition import InputOutputPointsCondition
|
||||
from .input_output_condition import InputOutputPointsCondition
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
""" Condition module. """
|
||||
"""Condition module."""
|
||||
|
||||
from .domain_equation_condition import DomainEquationCondition
|
||||
from .input_equation_condition import InputPointsEquationCondition
|
||||
@@ -11,6 +11,7 @@ from ..utils import custom_warning_format
|
||||
warnings.formatwarning = custom_warning_format
|
||||
warnings.filterwarnings("always", category=DeprecationWarning)
|
||||
|
||||
|
||||
class Condition:
|
||||
"""
|
||||
The class ``Condition`` is used to represent the constraints (physical
|
||||
@@ -44,24 +45,30 @@ class Condition:
|
||||
"""
|
||||
|
||||
__slots__ = list(
|
||||
set(InputOutputPointsCondition.__slots__ +
|
||||
InputPointsEquationCondition.__slots__ +
|
||||
DomainEquationCondition.__slots__ +
|
||||
DataConditionInterface.__slots__))
|
||||
set(
|
||||
InputOutputPointsCondition.__slots__
|
||||
+ InputPointsEquationCondition.__slots__
|
||||
+ DomainEquationCondition.__slots__
|
||||
+ DataConditionInterface.__slots__
|
||||
)
|
||||
)
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
|
||||
if len(args) != 0:
|
||||
raise ValueError("Condition takes only the following keyword "
|
||||
f"arguments: {Condition.__slots__}.")
|
||||
raise ValueError(
|
||||
"Condition takes only the following keyword "
|
||||
f"arguments: {Condition.__slots__}."
|
||||
)
|
||||
|
||||
# back-compatibility 0.1
|
||||
if 'location' in kwargs.keys():
|
||||
kwargs['domain'] = kwargs.pop('location')
|
||||
if "location" in kwargs.keys():
|
||||
kwargs["domain"] = kwargs.pop("location")
|
||||
warnings.warn(
|
||||
f"'location' is deprecated and will be removed "
|
||||
f"in future versions. Please use 'domain' instead.",
|
||||
DeprecationWarning)
|
||||
f"'location' is deprecated and will be removed "
|
||||
f"in future versions. Please use 'domain' instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
sorted_keys = sorted(kwargs.keys())
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from abc import ABCMeta
|
||||
|
||||
class ConditionInterface(metaclass=ABCMeta):
|
||||
|
||||
condition_types = ['physics', 'supervised', 'unsupervised']
|
||||
condition_types = ["physics", "supervised", "unsupervised"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._condition_type = None
|
||||
@@ -28,6 +28,7 @@ class ConditionInterface(metaclass=ABCMeta):
|
||||
for value in values:
|
||||
if value not in ConditionInterface.condition_types:
|
||||
raise ValueError(
|
||||
'Unavailable type of condition, expected one of'
|
||||
f' {ConditionInterface.condition_types}.')
|
||||
"Unavailable type of condition, expected one of"
|
||||
f" {ConditionInterface.condition_types}."
|
||||
)
|
||||
self._condition_type = values
|
||||
|
||||
@@ -25,8 +25,8 @@ class DataConditionInterface(ConditionInterface):
|
||||
self.conditional_variables = conditional_variables
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if (key == 'input_points') or (key == 'conditional_variables'):
|
||||
if (key == "input_points") or (key == "conditional_variables"):
|
||||
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
|
||||
DataConditionInterface.__dict__[key].__set__(self, value)
|
||||
elif key in ('_problem', '_condition_type'):
|
||||
elif key in ("_problem", "_condition_type"):
|
||||
super().__setattr__(key, value)
|
||||
|
||||
@@ -23,11 +23,11 @@ class DomainEquationCondition(ConditionInterface):
|
||||
self.equation = equation
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key == 'domain':
|
||||
if key == "domain":
|
||||
check_consistency(value, (DomainInterface, str))
|
||||
DomainEquationCondition.__dict__[key].__set__(self, value)
|
||||
elif key == 'equation':
|
||||
elif key == "equation":
|
||||
check_consistency(value, (EquationInterface))
|
||||
DomainEquationCondition.__dict__[key].__set__(self, value)
|
||||
elif key in ('_problem', '_condition_type'):
|
||||
elif key in ("_problem", "_condition_type"):
|
||||
super().__setattr__(key, value)
|
||||
|
||||
@@ -24,13 +24,13 @@ class InputPointsEquationCondition(ConditionInterface):
|
||||
self.equation = equation
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key == 'input_points':
|
||||
if key == "input_points":
|
||||
check_consistency(
|
||||
value, (LabelTensor)
|
||||
) # for now only labeltensors, we need labels for the operator!
|
||||
InputPointsEquationCondition.__dict__[key].__set__(self, value)
|
||||
elif key == 'equation':
|
||||
elif key == "equation":
|
||||
check_consistency(value, (EquationInterface))
|
||||
InputPointsEquationCondition.__dict__[key].__set__(self, value)
|
||||
elif key in ('_problem', '_condition_type'):
|
||||
elif key in ("_problem", "_condition_type"):
|
||||
super().__setattr__(key, value)
|
||||
|
||||
@@ -24,8 +24,11 @@ class InputOutputPointsCondition(ConditionInterface):
|
||||
self.output_points = output_points
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if (key == 'input_points') or (key == 'output_points'):
|
||||
check_consistency(value, (LabelTensor, Graph, torch.Tensor, torch_geometric.data.Data))
|
||||
if (key == "input_points") or (key == "output_points"):
|
||||
check_consistency(
|
||||
value,
|
||||
(LabelTensor, Graph, torch.Tensor, torch_geometric.data.Data),
|
||||
)
|
||||
InputOutputPointsCondition.__dict__[key].__set__(self, value)
|
||||
elif key in ('_problem', '_condition_type'):
|
||||
elif key in ("_problem", "_condition_type"):
|
||||
super().__setattr__(key, value)
|
||||
|
||||
Reference in New Issue
Block a user