Formatting

* Adding black as dev dependency
* Formatting pina code
* Formatting tests
This commit is contained in:
Dario Coscia
2025-02-24 11:26:49 +01:00
committed by Nicola Demo
parent 4c4482b155
commit 42ab1a666b
77 changed files with 1170 additions and 924 deletions

View File

@@ -168,8 +168,9 @@ class CartesianDomain(DomainInterface):
for variable in variables:
if variable in self.fixed_.keys():
value = self.fixed_[variable]
pts_variable = torch.tensor([[value]
]).repeat(result.shape[0], 1)
pts_variable = torch.tensor([[value]]).repeat(
result.shape[0], 1
)
pts_variable = pts_variable.as_subclass(LabelTensor)
pts_variable.labels = [variable]
@@ -202,8 +203,9 @@ class CartesianDomain(DomainInterface):
for variable in variables:
if variable in self.fixed_.keys():
value = self.fixed_[variable]
pts_variable = torch.tensor([[value]
]).repeat(result.shape[0], 1)
pts_variable = torch.tensor([[value]]).repeat(
result.shape[0], 1
)
pts_variable = pts_variable.as_subclass(LabelTensor)
pts_variable.labels = [variable]

View File

@@ -36,9 +36,11 @@ class DomainInterface(metaclass=ABCMeta):
values = [values]
for value in values:
if value not in DomainInterface.available_sampling_modes:
raise TypeError(f"mode {value} not valid. Expected at least "
"one in "
f"{DomainInterface.available_sampling_modes}.")
raise TypeError(
f"mode {value} not valid. Expected at least "
"one in "
f"{DomainInterface.available_sampling_modes}."
)
@abstractmethod
def sample(self):

View File

@@ -1,4 +1,4 @@
"""Module for Exclusion class. """
"""Module for Exclusion class."""
import torch
from ..label_tensor import LabelTensor

View File

@@ -1,4 +1,4 @@
"""Module for Intersection class. """
"""Module for Intersection class."""
import torch
from ..label_tensor import LabelTensor

View File

@@ -1,4 +1,4 @@
""" Module for OperationInterface class. """
"""Module for OperationInterface class."""
from .domain_interface import DomainInterface
from ..utils import check_consistency

View File

@@ -144,7 +144,7 @@ class SimplexDomain(DomainInterface):
return all(torch.gt(lambdas, 0.0)) and all(torch.lt(lambdas, 1.0))
return all(torch.ge(lambdas, 0)) and (
any(torch.eq(lambdas, 0)) or any(torch.eq(lambdas, 1))
any(torch.eq(lambdas, 0)) or any(torch.eq(lambdas, 1))
)
def _sample_interior_randomly(self, n, variables):

View File

@@ -1,4 +1,4 @@
"""Module for Union class. """
"""Module for Union class."""
import torch
from .operation_interface import OperationInterface