Formatting

* Adding black as dev dependency
* Formatting pina code
* Formatting tests
This commit is contained in:
Dario Coscia
2025-02-24 11:26:49 +01:00
committed by Nicola Demo
parent 4c4482b155
commit 42ab1a666b
77 changed files with 1170 additions and 924 deletions

View File

@@ -8,10 +8,11 @@ from .label_tensor import LabelTensor
def custom_warning_format(
message, category, filename, lineno, file=None, line=None
):
message, category, filename, lineno, file=None, line=None
):
return f"{filename}: {category.__name__}: {message}\n"
def check_consistency(object, object_instance, subclass=False):
"""Helper function to check object inheritance consistency.
Given a specific ``'object'`` we check if the object is
@@ -39,6 +40,7 @@ def check_consistency(object, object_instance, subclass=False):
except AssertionError:
raise ValueError(f"{type(obj).__name__} must be {object_instance}.")
def labelize_forward(forward, input_variables, output_variables):
"""
Wrapper decorator to allow users to enable or disable the use of
@@ -51,6 +53,7 @@ def labelize_forward(forward, input_variables, output_variables):
:param output_variables: The problem output variables.
:type output_variables: list[str] | tuple[str]
"""
def wrapper(x):
x = x.extract(input_variables)
output = forward(x)
@@ -59,8 +62,10 @@ def labelize_forward(forward, input_variables, output_variables):
output = output.as_subclass(LabelTensor)
output.labels = output_variables
return output
return wrapper
def merge_tensors(tensors): # name to be changed
if tensors:
return reduce(merge_two_tensors, tensors[1:], tensors[0])
@@ -72,8 +77,9 @@ def merge_two_tensors(tensor1, tensor2):
n2 = tensor2.shape[0]
tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0),
labels=tensor2.labels)
tensor2 = LabelTensor(
tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels
)
return tensor1.append(tensor2)