Correct codacy warnings
This commit is contained in:
committed by
Nicola Demo
parent
1bc1b3a580
commit
3e30450e9a
@@ -26,7 +26,7 @@ class PinaDataModule(LightningDataModule):
|
||||
eval_size=.1,
|
||||
batch_size=None,
|
||||
shuffle=True,
|
||||
datasets = None):
|
||||
datasets=None):
|
||||
"""
|
||||
Initialize the object, creating dataset based on input problem
|
||||
:param AbstractProblem problem: PINA problem
|
||||
@@ -38,9 +38,11 @@ class PinaDataModule(LightningDataModule):
|
||||
:param datasets: list of datasets objects
|
||||
"""
|
||||
super().__init__()
|
||||
dataset_classes = [SupervisedDataset, UnsupervisedDataset, SamplePointDataset]
|
||||
dataset_classes = [SupervisedDataset, UnsupervisedDataset,
|
||||
SamplePointDataset]
|
||||
if datasets is None:
|
||||
self.datasets = [DatasetClass(problem, device) for DatasetClass in dataset_classes]
|
||||
self.datasets = [DatasetClass(problem, device) for DatasetClass in
|
||||
dataset_classes]
|
||||
else:
|
||||
self.datasets = datasets
|
||||
|
||||
@@ -100,8 +102,6 @@ class PinaDataModule(LightningDataModule):
|
||||
for key, value in dataset.condition_names.items()
|
||||
}
|
||||
|
||||
|
||||
|
||||
def train_dataloader(self):
|
||||
"""
|
||||
Return the training dataloader for the dataset
|
||||
@@ -158,11 +158,13 @@ class PinaDataModule(LightningDataModule):
|
||||
if seed is not None:
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(seed)
|
||||
indices = torch.randperm(sum(lengths), generator=generator).tolist()
|
||||
indices = torch.randperm(sum(lengths),
|
||||
generator=generator).tolist()
|
||||
else:
|
||||
indices = torch.arange(sum(lengths)).tolist()
|
||||
else:
|
||||
indices = torch.arange(0, sum(lengths), 1, dtype=torch.uint8).tolist()
|
||||
indices = torch.arange(0, sum(lengths), 1,
|
||||
dtype=torch.uint8).tolist()
|
||||
offsets = [
|
||||
sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths))
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user