🎨 Format Python code with psf/black

This commit is contained in:
ndem0
2024-02-09 11:25:00 +00:00
committed by Nicola Demo
parent 591aeeb02b
commit cbb43a5392
64 changed files with 1323 additions and 955 deletions

View File

@@ -1,4 +1,5 @@
"""Utils module"""
from torch.utils.data import Dataset, DataLoader
from functools import reduce
import types
@@ -12,13 +13,13 @@ import torch
def check_consistency(object, object_instance, subclass=False):
"""Helper function to check object inheritance consistency.
"""Helper function to check object inheritance consistency.
Given a specific ``'object'`` we check if the object is
instance of a specific ``'object_instance'``, or in case
``'subclass=True'`` we check if the object is subclass
if the ``'object_instance'``.
:param (iterable or class object) object: The object to check the inheritance
:param (iterable or class object) object: The object to check the inheritance
:param Object object_instance: The parent class from where the object
is expected to inherit
:param str object_name: The name of the object
@@ -39,9 +40,9 @@ def check_consistency(object, object_instance, subclass=False):
raise ValueError(f"{type(obj).__name__} must be {object_instance}.")
def number_parameters(model,
aggregate=True,
only_trainable=True): # TODO: check
def number_parameters(
model, aggregate=True, only_trainable=True
): # TODO: check
"""
Return the number of parameters of a given `model`.
@@ -79,8 +80,9 @@ def merge_two_tensors(tensor1, tensor2):
n2 = tensor2.shape[0]
tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0),
labels=tensor2.labels)
tensor2 = LabelTensor(
tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels
)
return tensor1.append(tensor2)
@@ -95,13 +97,13 @@ def torch_lhs(n, dim):
"""
if not isinstance(n, int):
raise TypeError('number of point n must be int')
raise TypeError("number of point n must be int")
if not isinstance(dim, int):
raise TypeError('dim must be int')
raise TypeError("dim must be int")
if dim < 1:
raise ValueError('dim must be greater than one')
raise ValueError("dim must be greater than one")
samples = torch.rand(size=(n, dim))