🎨 Format Python code with psf/black
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""Utils module"""
|
||||
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from functools import reduce
|
||||
import types
|
||||
@@ -12,13 +13,13 @@ import torch
|
||||
|
||||
|
||||
def check_consistency(object, object_instance, subclass=False):
|
||||
"""Helper function to check object inheritance consistency.
|
||||
"""Helper function to check object inheritance consistency.
|
||||
Given a specific ``'object'`` we check if the object is
|
||||
instance of a specific ``'object_instance'``, or in case
|
||||
``'subclass=True'`` we check if the object is subclass
|
||||
if the ``'object_instance'``.
|
||||
|
||||
:param (iterable or class object) object: The object to check the inheritance
|
||||
:param (iterable or class object) object: The object to check the inheritance
|
||||
:param Object object_instance: The parent class from where the object
|
||||
is expected to inherit
|
||||
:param str object_name: The name of the object
|
||||
@@ -39,9 +40,9 @@ def check_consistency(object, object_instance, subclass=False):
|
||||
raise ValueError(f"{type(obj).__name__} must be {object_instance}.")
|
||||
|
||||
|
||||
def number_parameters(model,
|
||||
aggregate=True,
|
||||
only_trainable=True): # TODO: check
|
||||
def number_parameters(
|
||||
model, aggregate=True, only_trainable=True
|
||||
): # TODO: check
|
||||
"""
|
||||
Return the number of parameters of a given `model`.
|
||||
|
||||
@@ -79,8 +80,9 @@ def merge_two_tensors(tensor1, tensor2):
|
||||
n2 = tensor2.shape[0]
|
||||
|
||||
tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
|
||||
tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0),
|
||||
labels=tensor2.labels)
|
||||
tensor2 = LabelTensor(
|
||||
tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels
|
||||
)
|
||||
return tensor1.append(tensor2)
|
||||
|
||||
|
||||
@@ -95,13 +97,13 @@ def torch_lhs(n, dim):
|
||||
"""
|
||||
|
||||
if not isinstance(n, int):
|
||||
raise TypeError('number of point n must be int')
|
||||
raise TypeError("number of point n must be int")
|
||||
|
||||
if not isinstance(dim, int):
|
||||
raise TypeError('dim must be int')
|
||||
raise TypeError("dim must be int")
|
||||
|
||||
if dim < 1:
|
||||
raise ValueError('dim must be greater than one')
|
||||
raise ValueError("dim must be greater than one")
|
||||
|
||||
samples = torch.rand(size=(n, dim))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user