edited utils to take list (#115)

* enhanced difference domain
* refactored utils
* fixed typo
* added tests
---------

Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>
This commit is contained in:
Kush
2023-06-19 18:47:52 +02:00
committed by Nicola Demo
parent aaf2bed732
commit 62ec69ccac
9 changed files with 73 additions and 47 deletions

View File

@@ -1,4 +1,5 @@
"""Utils module"""
from torch.utils.data import Dataset, DataLoader
from functools import reduce
import types
@@ -10,14 +11,14 @@ from .label_tensor import LabelTensor
import torch
def check_consistency(object, object_instance, object_name, subclass=False):
def check_consistency(object, object_instance, subclass=False):
"""Helper function to check object inheritance consistency.
Given a specific ``'object'`` we check if the object is
instance of a specific ``'object_instance'``, or in case
``'subclass=True'`` we check if the object is subclass
if the ``'object_instance'``.
:param Object object: The object to check the inheritance
:param (iterable or class object) object: The object to check the inheritance
:param Object object_instance: The parent class from where the object
is expected to inherit
:param str object_name: The name of the object
@@ -25,12 +26,17 @@ def check_consistency(object, object_instance, object_name, subclass=False):
:raises ValueError: If the object does not inherit from the
specified class
"""
if not subclass:
if not isinstance(object, object_instance):
raise ValueError(f"{object_name} must be {object_instance}")
else:
if not issubclass(object, object_instance):
raise ValueError(f"{object_name} must be {object_instance}")
if not isinstance(object, (list, set, tuple)):
object = [object]
for obj in object:
try:
if not subclass:
assert isinstance(obj, object_instance)
else:
assert issubclass(obj, object_instance)
except AssertionError:
raise ValueError(f"{type(obj).__name__} must be {object_instance}.")
def number_parameters(model, aggregate=True, only_trainable=True): # TODO: check
@@ -180,13 +186,13 @@ def chebyshev_roots(n):
# def __len__(self):
# return self._len
from torch.utils.data import Dataset, DataLoader
class LabelTensorDataset(Dataset):
def __init__(self, d):
for k, v in d.items():
setattr(self, k, v)
self.labels = list(d.keys())
def __getitem__(self, index):
print(index)
result = {}
@@ -201,7 +207,7 @@ class LabelTensorDataset(Dataset):
result[label] = sample_tensor[index]
except IndexError:
result[label] = torch.tensor([])
print(result)
return result
@@ -229,13 +235,13 @@ class LabelTensorDataLoader(DataLoader):
# def __len__(self):
# return self._len
from torch.utils.data import Dataset, DataLoader
class LabelTensorDataset(Dataset):
def __init__(self, d):
for k, v in d.items():
setattr(self, k, v)
self.labels = list(d.keys())
def __getitem__(self, index):
print(index)
result = {}
@@ -250,7 +256,7 @@ class LabelTensorDataset(Dataset):
result[label] = sample_tensor[index]
except IndexError:
result[label] = torch.tensor([])
print(result)
return result
@@ -261,4 +267,4 @@ class LabelTensorDataset(Dataset):
class LabelTensorDataLoader(DataLoader):
def collate_fn(self, data):
pass
pass