edited utils to take list (#115)
* enhanced difference domain
* refactored utils
* fixed typo
* added tests

---------

Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>
@@ -1,4 +1,5 @@
 """Utils module"""
+from torch.utils.data import Dataset, DataLoader
 from functools import reduce
 import types
 
@@ -10,14 +11,14 @@ from .label_tensor import LabelTensor
 import torch
 
 
-def check_consistency(object, object_instance, object_name, subclass=False):
+def check_consistency(object, object_instance, subclass=False):
     """Helper function to check object inheritance consistency.
     Given a specific ``'object'`` we check if the object is an
     instance of a specific ``'object_instance'``, or in case
     ``'subclass=True'`` we check if the object is a subclass
     of the ``'object_instance'``.
 
-    :param Object object: The object to check the inheritance
+    :param (iterable or class object) object: The object to check the inheritance
     :param Object object_instance: The parent class from where the object
         is expected to inherit
     :param str object_name: The name of the object
@@ -25,12 +26,17 @@ def check_consistency(object, object_instance, object_name, subclass=False):
     :raises ValueError: If the object does not inherit from the
         specified class
     """
-    if not subclass:
-        if not isinstance(object, object_instance):
-            raise ValueError(f"{object_name} must be {object_instance}")
-    else:
-        if not issubclass(object, object_instance):
-            raise ValueError(f"{object_name} must be {object_instance}")
+    if not isinstance(object, (list, set, tuple)):
+        object = [object]
+
+    for obj in object:
+        try:
+            if not subclass:
+                assert isinstance(obj, object_instance)
+            else:
+                assert issubclass(obj, object_instance)
+        except AssertionError:
+            raise ValueError(f"{type(obj).__name__} must be {object_instance}.")
 
 
 def number_parameters(model, aggregate=True, only_trainable=True): # TODO: check
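
In practice, the hunk above lets a caller validate a single object or a whole list/tuple/set of objects with one call, failing on the first element that does not match. A small self-contained sketch of the new behaviour; the helper simply repeats the added lines, and ``torch.Tensor`` is used here only as a convenient class to check against, not because the PR checks tensors:

import torch


def check_consistency(object, object_instance, subclass=False):
    # mirror of the refactored helper: wrap a lone object in a list,
    # then validate every element against object_instance
    if not isinstance(object, (list, set, tuple)):
        object = [object]

    for obj in object:
        try:
            if not subclass:
                assert isinstance(obj, object_instance)
            else:
                assert issubclass(obj, object_instance)
        except AssertionError:
            raise ValueError(f"{type(obj).__name__} must be {object_instance}.")


# a single tensor and a list of tensors pass the same check
check_consistency(torch.rand(3), torch.Tensor)
check_consistency([torch.rand(3), torch.rand(2)], torch.Tensor)

# subclass=True validates class objects rather than instances
check_consistency([torch.Tensor], torch.Tensor, subclass=True)

# a non-matching element raises, naming the offending type:
# ValueError: str must be <class 'torch.Tensor'>.
try:
    check_consistency([torch.rand(3), "not a tensor"], torch.Tensor)
except ValueError as err:
    print(err)
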
@@ -180,13 +186,13 @@ def chebyshev_roots(n):
# def __len__(self):
# return self._len

from torch.utils.data import Dataset, DataLoader

class LabelTensorDataset(Dataset):
    def __init__(self, d):
        for k, v in d.items():
            setattr(self, k, v)
        self.labels = list(d.keys())


    def __getitem__(self, index):
        print(index)
        result = {}
@@ -201,7 +207,7 @@ class LabelTensorDataset(Dataset):
                result[label] = sample_tensor[index]
            except IndexError:
                result[label] = torch.tensor([])


        print(result)
        return result
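
To make the intent of ``__getitem__`` concrete, here is a minimal standalone sketch of how this Dataset is meant to be driven. It assumes plain ``torch`` tensors stand in for ``LabelTensor`` objects, drops the debug ``print`` calls, and adds a guessed ``__len__`` (the commit itself keeps ``__len__`` commented out):

import torch
from torch.utils.data import Dataset


class LabelTensorDataset(Dataset):
    """Wrap a dict of tensors; each key becomes a labelled attribute."""

    def __init__(self, d):
        for k, v in d.items():
            setattr(self, k, v)
        self.labels = list(d.keys())

    def __getitem__(self, index):
        # one sample per label; an empty tensor pads labels that have
        # fewer samples than the requested index
        result = {}
        for label in self.labels:
            sample_tensor = getattr(self, label)
            try:
                result[label] = sample_tensor[index]
            except IndexError:
                result[label] = torch.tensor([])
        return result

    def __len__(self):
        # assumption: length of the longest label (not in the commit)
        return max(getattr(self, label).shape[0] for label in self.labels)


data = {"x": torch.rand(10, 2), "y": torch.rand(8, 1)}
dataset = LabelTensorDataset(data)
print(dataset[0])  # {'x': tensor of shape (2,), 'y': tensor of shape (1,)}
print(dataset[9])  # 'y' is exhausted at index 9, so it comes back as tensor([])
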
@@ -229,13 +235,13 @@ class LabelTensorDataLoader(DataLoader):
# def __len__(self):
# return self._len

from torch.utils.data import Dataset, DataLoader

class LabelTensorDataset(Dataset):
    def __init__(self, d):
        for k, v in d.items():
            setattr(self, k, v)
        self.labels = list(d.keys())


    def __getitem__(self, index):
        print(index)
        result = {}
@@ -250,7 +256,7 @@ class LabelTensorDataset(Dataset):
                result[label] = sample_tensor[index]
            except IndexError:
                result[label] = torch.tensor([])


        print(result)
        return result
@@ -261,4 +267,4 @@ class LabelTensorDataset(Dataset):
 class LabelTensorDataLoader(DataLoader):
 
     def collate_fn(self, data):
-        pass
+        pass
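
``collate_fn`` remains a stub in this commit. Purely as orientation, and not code from the PR, one plausible job for it is to merge the per-label dicts produced by ``LabelTensorDataset`` into batched tensors. The sketch below does that with a plain ``torch.utils.data.DataLoader`` and its standard ``collate_fn`` argument; every name in it is illustrative:

import torch
from torch.utils.data import DataLoader

# toy samples shaped like LabelTensorDataset.__getitem__ output
samples = [{"x": torch.rand(2), "y": torch.rand(1)} for _ in range(6)]
samples += [{"x": torch.rand(2), "y": torch.tensor([])} for _ in range(2)]


def label_collate(batch):
    # merge a list of {label: tensor} dicts into one dict of stacked
    # tensors, dropping the empty tensors that pad exhausted labels
    out = {}
    for label in batch[0]:
        kept = [item[label] for item in batch if item[label].numel() > 0]
        out[label] = torch.stack(kept) if kept else torch.tensor([])
    return out


loader = DataLoader(samples, batch_size=4, collate_fn=label_collate)
for batch in loader:
    print({k: tuple(v.shape) for k, v in batch.items()})
# first batch  -> {'x': (4, 2), 'y': (4, 1)}
# second batch -> {'x': (4, 2), 'y': (2, 1)}   two 'y' entries were empty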