Correct codacy warnings

This commit is contained in:
FilippoOlivo
2024-10-22 14:54:22 +02:00
committed by Nicola Demo
parent 1bc1b3a580
commit 3e30450e9a
10 changed files with 60 additions and 37 deletions

View File

@@ -26,7 +26,7 @@ class PinaDataModule(LightningDataModule):
eval_size=.1,
batch_size=None,
shuffle=True,
datasets = None):
datasets=None):
"""
Initialize the object, creating dataset based on input problem
:param AbstractProblem problem: PINA problem
@@ -38,9 +38,11 @@ class PinaDataModule(LightningDataModule):
:param datasets: list of datasets objects
"""
super().__init__()
dataset_classes = [SupervisedDataset, UnsupervisedDataset, SamplePointDataset]
dataset_classes = [SupervisedDataset, UnsupervisedDataset,
SamplePointDataset]
if datasets is None:
self.datasets = [DatasetClass(problem, device) for DatasetClass in dataset_classes]
self.datasets = [DatasetClass(problem, device) for DatasetClass in
dataset_classes]
else:
self.datasets = datasets
@@ -100,8 +102,6 @@ class PinaDataModule(LightningDataModule):
for key, value in dataset.condition_names.items()
}
def train_dataloader(self):
"""
Return the training dataloader for the dataset
@@ -158,11 +158,13 @@ class PinaDataModule(LightningDataModule):
if seed is not None:
generator = torch.Generator()
generator.manual_seed(seed)
indices = torch.randperm(sum(lengths), generator=generator).tolist()
indices = torch.randperm(sum(lengths),
generator=generator).tolist()
else:
indices = torch.arange(sum(lengths)).tolist()
else:
indices = torch.arange(0, sum(lengths), 1, dtype=torch.uint8).tolist()
indices = torch.arange(0, sum(lengths), 1,
dtype=torch.uint8).tolist()
offsets = [
sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths))
]