Formatting

* Adding black as dev dependency
* Formatting pina code
* Formatting tests
This commit is contained in:
Dario Coscia
2025-02-24 11:26:49 +01:00
committed by Nicola Demo
parent 4c4482b155
commit 42ab1a666b
77 changed files with 1170 additions and 924 deletions

View File

@@ -1,12 +1,12 @@
__all__ = [
'Condition',
'ConditionInterface',
'DomainEquationCondition',
'InputPointsEquationCondition',
'InputOutputPointsCondition',
"Condition",
"ConditionInterface",
"DomainEquationCondition",
"InputPointsEquationCondition",
"InputOutputPointsCondition",
]
from .condition_interface import ConditionInterface
from .domain_equation_condition import DomainEquationCondition
from .input_equation_condition import InputPointsEquationCondition
from .input_output_condition import InputOutputPointsCondition
from .input_output_condition import InputOutputPointsCondition

View File

@@ -1,4 +1,4 @@
""" Condition module. """
"""Condition module."""
from .domain_equation_condition import DomainEquationCondition
from .input_equation_condition import InputPointsEquationCondition
@@ -11,6 +11,7 @@ from ..utils import custom_warning_format
warnings.formatwarning = custom_warning_format
warnings.filterwarnings("always", category=DeprecationWarning)
class Condition:
"""
The class ``Condition`` is used to represent the constraints (physical
@@ -44,24 +45,30 @@ class Condition:
"""
__slots__ = list(
set(InputOutputPointsCondition.__slots__ +
InputPointsEquationCondition.__slots__ +
DomainEquationCondition.__slots__ +
DataConditionInterface.__slots__))
set(
InputOutputPointsCondition.__slots__
+ InputPointsEquationCondition.__slots__
+ DomainEquationCondition.__slots__
+ DataConditionInterface.__slots__
)
)
def __new__(cls, *args, **kwargs):
if len(args) != 0:
raise ValueError("Condition takes only the following keyword "
f"arguments: {Condition.__slots__}.")
raise ValueError(
"Condition takes only the following keyword "
f"arguments: {Condition.__slots__}."
)
# back-compatibility 0.1
if 'location' in kwargs.keys():
kwargs['domain'] = kwargs.pop('location')
if "location" in kwargs.keys():
kwargs["domain"] = kwargs.pop("location")
warnings.warn(
f"'location' is deprecated and will be removed "
f"in future versions. Please use 'domain' instead.",
DeprecationWarning)
f"'location' is deprecated and will be removed "
f"in future versions. Please use 'domain' instead.",
DeprecationWarning,
)
sorted_keys = sorted(kwargs.keys())

View File

@@ -3,7 +3,7 @@ from abc import ABCMeta
class ConditionInterface(metaclass=ABCMeta):
condition_types = ['physics', 'supervised', 'unsupervised']
condition_types = ["physics", "supervised", "unsupervised"]
def __init__(self, *args, **kwargs):
self._condition_type = None
@@ -28,6 +28,7 @@ class ConditionInterface(metaclass=ABCMeta):
for value in values:
if value not in ConditionInterface.condition_types:
raise ValueError(
'Unavailable type of condition, expected one of'
f' {ConditionInterface.condition_types}.')
"Unavailable type of condition, expected one of"
f" {ConditionInterface.condition_types}."
)
self._condition_type = values

View File

@@ -25,8 +25,8 @@ class DataConditionInterface(ConditionInterface):
self.conditional_variables = conditional_variables
def __setattr__(self, key, value):
if (key == 'input_points') or (key == 'conditional_variables'):
if (key == "input_points") or (key == "conditional_variables"):
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
DataConditionInterface.__dict__[key].__set__(self, value)
elif key in ('_problem', '_condition_type'):
elif key in ("_problem", "_condition_type"):
super().__setattr__(key, value)

View File

@@ -23,11 +23,11 @@ class DomainEquationCondition(ConditionInterface):
self.equation = equation
def __setattr__(self, key, value):
if key == 'domain':
if key == "domain":
check_consistency(value, (DomainInterface, str))
DomainEquationCondition.__dict__[key].__set__(self, value)
elif key == 'equation':
elif key == "equation":
check_consistency(value, (EquationInterface))
DomainEquationCondition.__dict__[key].__set__(self, value)
elif key in ('_problem', '_condition_type'):
elif key in ("_problem", "_condition_type"):
super().__setattr__(key, value)

View File

@@ -24,13 +24,13 @@ class InputPointsEquationCondition(ConditionInterface):
self.equation = equation
def __setattr__(self, key, value):
if key == 'input_points':
if key == "input_points":
check_consistency(
value, (LabelTensor)
) # for now only labeltensors, we need labels for the operator!
InputPointsEquationCondition.__dict__[key].__set__(self, value)
elif key == 'equation':
elif key == "equation":
check_consistency(value, (EquationInterface))
InputPointsEquationCondition.__dict__[key].__set__(self, value)
elif key in ('_problem', '_condition_type'):
elif key in ("_problem", "_condition_type"):
super().__setattr__(key, value)

View File

@@ -24,8 +24,11 @@ class InputOutputPointsCondition(ConditionInterface):
self.output_points = output_points
def __setattr__(self, key, value):
if (key == 'input_points') or (key == 'output_points'):
check_consistency(value, (LabelTensor, Graph, torch.Tensor, torch_geometric.data.Data))
if (key == "input_points") or (key == "output_points"):
check_consistency(
value,
(LabelTensor, Graph, torch.Tensor, torch_geometric.data.Data),
)
InputOutputPointsCondition.__dict__[key].__set__(self, value)
elif key in ('_problem', '_condition_type'):
elif key in ("_problem", "_condition_type"):
super().__setattr__(key, value)