Formatting
* Adding black as dev dependency * Formatting pina code * Formatting tests
This commit is contained in:
committed by
Nicola Demo
parent
4c4482b155
commit
42ab1a666b
@@ -168,8 +168,9 @@ class CartesianDomain(DomainInterface):
|
||||
for variable in variables:
|
||||
if variable in self.fixed_.keys():
|
||||
value = self.fixed_[variable]
|
||||
pts_variable = torch.tensor([[value]
|
||||
]).repeat(result.shape[0], 1)
|
||||
pts_variable = torch.tensor([[value]]).repeat(
|
||||
result.shape[0], 1
|
||||
)
|
||||
pts_variable = pts_variable.as_subclass(LabelTensor)
|
||||
pts_variable.labels = [variable]
|
||||
|
||||
@@ -202,8 +203,9 @@ class CartesianDomain(DomainInterface):
|
||||
for variable in variables:
|
||||
if variable in self.fixed_.keys():
|
||||
value = self.fixed_[variable]
|
||||
pts_variable = torch.tensor([[value]
|
||||
]).repeat(result.shape[0], 1)
|
||||
pts_variable = torch.tensor([[value]]).repeat(
|
||||
result.shape[0], 1
|
||||
)
|
||||
pts_variable = pts_variable.as_subclass(LabelTensor)
|
||||
pts_variable.labels = [variable]
|
||||
|
||||
|
||||
@@ -36,9 +36,11 @@ class DomainInterface(metaclass=ABCMeta):
|
||||
values = [values]
|
||||
for value in values:
|
||||
if value not in DomainInterface.available_sampling_modes:
|
||||
raise TypeError(f"mode {value} not valid. Expected at least "
|
||||
"one in "
|
||||
f"{DomainInterface.available_sampling_modes}.")
|
||||
raise TypeError(
|
||||
f"mode {value} not valid. Expected at least "
|
||||
"one in "
|
||||
f"{DomainInterface.available_sampling_modes}."
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def sample(self):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Module for Exclusion class. """
|
||||
"""Module for Exclusion class."""
|
||||
|
||||
import torch
|
||||
from ..label_tensor import LabelTensor
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Module for Intersection class. """
|
||||
"""Module for Intersection class."""
|
||||
|
||||
import torch
|
||||
from ..label_tensor import LabelTensor
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
""" Module for OperationInterface class. """
|
||||
"""Module for OperationInterface class."""
|
||||
|
||||
from .domain_interface import DomainInterface
|
||||
from ..utils import check_consistency
|
||||
|
||||
@@ -144,7 +144,7 @@ class SimplexDomain(DomainInterface):
|
||||
return all(torch.gt(lambdas, 0.0)) and all(torch.lt(lambdas, 1.0))
|
||||
|
||||
return all(torch.ge(lambdas, 0)) and (
|
||||
any(torch.eq(lambdas, 0)) or any(torch.eq(lambdas, 1))
|
||||
any(torch.eq(lambdas, 0)) or any(torch.eq(lambdas, 1))
|
||||
)
|
||||
|
||||
def _sample_interior_randomly(self, n, variables):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Module for Union class. """
|
||||
"""Module for Union class."""
|
||||
|
||||
import torch
|
||||
from .operation_interface import OperationInterface
|
||||
|
||||
Reference in New Issue
Block a user