Minor cleanup (#45)
* cleanup
* codacy
* move merge_tensors to utils.py
commit e974801df2 (parent 399c72fc0e), committed via GitHub
@@ -1,9 +1,11 @@
"""Utils module"""

from functools import reduce


def number_parameters(model, aggregate=True, only_trainable=True): #TODO: check
    """
    Return the number of parameters of a given `model`.

    :param torch.nn.Module model: the torch module to inspect.
    :param bool aggregate: if True, the return value is an integer corresponding
        to the total number of parameters of the whole model. If False, it returns a
@@ -25,3 +27,19 @@ def number_parameters(model, aggregate=True, only_trainable=True): #TODO: check
        tmp = sum(tmp.values())

    return tmp


def merge_tensors(tensors): # name to be changed
    if tensors:
        return reduce(merge_two_tensors, tensors[1:], tensors[0])
    raise ValueError("Expected at least one tensor")


def merge_two_tensors(tensor1, tensor2):
    n1 = tensor1.shape[0]
    n2 = tensor2.shape[0]

    tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
    tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0),
                          labels=tensor2.labels)
    return tensor1.append(tensor2)
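For context, a minimal sketch of a parameter counter along the lines of number_parameters. The diff only shows the signature, part of the docstring, and the final aggregation step, so the per-parameter dictionary and the only_trainable filter below are assumptions, not the actual PINA implementation.

# Hedged sketch: the dict comprehension and the `only_trainable` filter are
# assumptions based on the visible signature, docstring, and the
# `tmp = sum(tmp.values())` aggregation step shown in the diff.
import torch

def count_parameters(model, aggregate=True, only_trainable=True):
    # One entry per named parameter, optionally skipping frozen parameters.
    tmp = {
        name: param.numel()
        for name, param in model.named_parameters()
        if param.requires_grad or not only_trainable
    }
    if aggregate:
        # Collapse the per-parameter counts into a single integer total.
        tmp = sum(tmp.values())
    return tmp

# Example: a Linear(3, 2) layer has 3*2 weights + 2 biases = 8 parameters.
print(count_parameters(torch.nn.Linear(3, 2)))                   # 8
print(count_parameters(torch.nn.Linear(3, 2), aggregate=False))  # {'weight': 6, 'bias': 2}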
Reference in New Issue
Block a user
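The moved merge_two_tensors builds a row-wise Cartesian product: every row of tensor1 is paired with every row of tensor2, and each pair is joined column-wise. A minimal sketch of the same pairing on plain torch tensors, where torch.cat(dim=1) stands in for the label-aware LabelTensor.append (an assumption about its behavior):

# Sketch of the pairing performed by merge_two_tensors, on plain tensors.
# torch.cat(..., dim=1) is assumed to mimic LabelTensor.append, minus the labels.
import torch

def cartesian_rows(tensor1, tensor2):
    n1, n2 = tensor1.shape[0], tensor2.shape[0]
    left = tensor1.repeat(n2, 1)                  # whole block of tensor1, tiled n2 times
    right = tensor2.repeat_interleave(n1, dim=0)  # each row of tensor2 repeated n1 times
    return torch.cat((left, right), dim=1)        # all (row of tensor1, row of tensor2) pairs

x = torch.tensor([[0.0], [1.0]])         # 2 samples
t = torch.tensor([[0.0], [0.5], [1.0]])  # 3 samples
print(cartesian_rows(x, t).shape)        # torch.Size([6, 2])

merge_tensors then reduces this pairwise merge over the whole list, so merging k tensors with n_1, ..., n_k rows yields a tensor with n_1 * ... * n_k rows.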