Correct codacy warnings
This commit is contained in:
committed by
Nicola Demo
parent
c9304fb9bb
commit
1bc1b3a580
@@ -2,13 +2,8 @@
|
||||
Import data classes
|
||||
"""
|
||||
__all__ = [
|
||||
'PinaDataLoader',
|
||||
'SupervisedDataset',
|
||||
'SamplePointDataset',
|
||||
'UnsupervisedDataset',
|
||||
'Batch',
|
||||
'PinaDataModule',
|
||||
'BaseDataset'
|
||||
'PinaDataLoader', 'SupervisedDataset', 'SamplePointDataset',
|
||||
'UnsupervisedDataset', 'Batch', 'PinaDataModule', 'BaseDataset'
|
||||
]
|
||||
|
||||
from .pina_dataloader import PinaDataLoader
|
||||
|
||||
@@ -22,10 +22,12 @@ class BaseDataset(Dataset):
|
||||
dataset will be loaded.
|
||||
"""
|
||||
if cls is BaseDataset:
|
||||
raise TypeError('BaseDataset cannot be instantiated directly. Use a subclass.')
|
||||
raise TypeError(
|
||||
'BaseDataset cannot be instantiated directly. Use a subclass.')
|
||||
if not hasattr(cls, '__slots__'):
|
||||
raise TypeError('Something is wrong, __slots__ must be defined in subclasses.')
|
||||
return super().__new__(cls)
|
||||
raise TypeError(
|
||||
'Something is wrong, __slots__ must be defined in subclasses.')
|
||||
return super(BaseDataset, cls).__new__(cls)
|
||||
|
||||
def __init__(self, problem, device):
|
||||
""""
|
||||
@@ -79,7 +81,8 @@ class BaseDataset(Dataset):
|
||||
|
||||
def __getattribute__(self, item):
|
||||
attribute = super().__getattribute__(item)
|
||||
if isinstance(attribute, LabelTensor) and attribute.dtype == torch.float32:
|
||||
if isinstance(attribute,
|
||||
LabelTensor) and attribute.dtype == torch.float32:
|
||||
attribute = attribute.to(device=self.device).requires_grad_()
|
||||
return attribute
|
||||
|
||||
@@ -101,7 +104,8 @@ class BaseDataset(Dataset):
|
||||
if all(isinstance(x, int) for x in idx):
|
||||
to_return_list = []
|
||||
for i in self.__slots__:
|
||||
to_return_list.append(getattr(self, i)[[idx]].to(self.device))
|
||||
to_return_list.append(
|
||||
getattr(self, i)[[idx]].to(self.device))
|
||||
return to_return_list
|
||||
|
||||
raise ValueError(f'Invalid index {idx}')
|
||||
|
||||
@@ -5,6 +5,7 @@ from .pina_subset import PinaSubset
|
||||
|
||||
|
||||
class Batch:
|
||||
|
||||
def __init__(self, dataset_dict, idx_dict):
|
||||
|
||||
for k, v in dataset_dict.items():
|
||||
@@ -29,5 +30,6 @@ class Batch:
|
||||
def __getattr__(self, item):
|
||||
if not item in dir(self):
|
||||
raise AttributeError(f'Batch instance has no attribute {item}')
|
||||
return PinaSubset(getattr(self, item).dataset,
|
||||
getattr(self, item).indices[self.coordinates_dict[item]])
|
||||
return PinaSubset(
|
||||
getattr(self, item).dataset,
|
||||
getattr(self, item).indices[self.coordinates_dict[item]])
|
||||
|
||||
@@ -50,7 +50,8 @@ class PinaDataLoader:
|
||||
temp_dict[k] = slice(i * v, (i + 1) * v)
|
||||
else:
|
||||
temp_dict[k] = slice(i * v, len(self.dataset_dict[k]))
|
||||
self.batches.append(Batch(idx_dict=temp_dict, dataset_dict=self.dataset_dict))
|
||||
self.batches.append(
|
||||
Batch(idx_dict=temp_dict, dataset_dict=self.dataset_dict))
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user