Formatting
* Adding black as dev dependency * Formatting pina code * Formatting tests
This commit is contained in:
committed by
Nicola Demo
parent
4c4482b155
commit
42ab1a666b
@@ -8,10 +8,11 @@ from .label_tensor import LabelTensor
|
||||
|
||||
|
||||
def custom_warning_format(
|
||||
message, category, filename, lineno, file=None, line=None
|
||||
):
|
||||
message, category, filename, lineno, file=None, line=None
|
||||
):
|
||||
return f"{filename}: {category.__name__}: {message}\n"
|
||||
|
||||
|
||||
def check_consistency(object, object_instance, subclass=False):
|
||||
"""Helper function to check object inheritance consistency.
|
||||
Given a specific ``'object'`` we check if the object is
|
||||
@@ -39,6 +40,7 @@ def check_consistency(object, object_instance, subclass=False):
|
||||
except AssertionError:
|
||||
raise ValueError(f"{type(obj).__name__} must be {object_instance}.")
|
||||
|
||||
|
||||
def labelize_forward(forward, input_variables, output_variables):
|
||||
"""
|
||||
Wrapper decorator to allow users to enable or disable the use of
|
||||
@@ -51,6 +53,7 @@ def labelize_forward(forward, input_variables, output_variables):
|
||||
:param output_variables: The problem output variables.
|
||||
:type output_variables: list[str] | tuple[str]
|
||||
"""
|
||||
|
||||
def wrapper(x):
|
||||
x = x.extract(input_variables)
|
||||
output = forward(x)
|
||||
@@ -59,8 +62,10 @@ def labelize_forward(forward, input_variables, output_variables):
|
||||
output = output.as_subclass(LabelTensor)
|
||||
output.labels = output_variables
|
||||
return output
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def merge_tensors(tensors): # name to be changed
|
||||
if tensors:
|
||||
return reduce(merge_two_tensors, tensors[1:], tensors[0])
|
||||
@@ -72,8 +77,9 @@ def merge_two_tensors(tensor1, tensor2):
|
||||
n2 = tensor2.shape[0]
|
||||
|
||||
tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
|
||||
tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0),
|
||||
labels=tensor2.labels)
|
||||
tensor2 = LabelTensor(
|
||||
tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels
|
||||
)
|
||||
return tensor1.append(tensor2)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user