Fix Codacy Warnings (#477)
--------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
e3790e049a
commit
4177bfbb50
@@ -1,4 +1,9 @@
|
||||
import logging
|
||||
"""
|
||||
This module contains the PinaDataModule class, which extends the
|
||||
LightningDataModule class to allow proper creation and management of
|
||||
different types of Datasets defined in PINA.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from lightning.pytorch import LightningDataModule
|
||||
import torch
|
||||
@@ -58,6 +63,10 @@ class DummyDataloader:
|
||||
|
||||
|
||||
class Collator:
|
||||
"""
|
||||
Class used to collate the batch
|
||||
"""
|
||||
|
||||
def __init__(self, max_conditions_lengths, dataset=None):
|
||||
self.max_conditions_lengths = max_conditions_lengths
|
||||
self.callable_function = (
|
||||
@@ -123,6 +132,10 @@ class Collator:
|
||||
|
||||
|
||||
class PinaSampler:
|
||||
"""
|
||||
Class used to create the sampler instance.
|
||||
"""
|
||||
|
||||
def __new__(cls, dataset, shuffle):
|
||||
|
||||
if (
|
||||
@@ -150,7 +163,6 @@ class PinaDataModule(LightningDataModule):
|
||||
train_size=0.7,
|
||||
test_size=0.2,
|
||||
val_size=0.1,
|
||||
predict_size=0.0,
|
||||
batch_size=None,
|
||||
shuffle=True,
|
||||
repeat=False,
|
||||
@@ -169,9 +181,8 @@ class PinaDataModule(LightningDataModule):
|
||||
:type test_size: float
|
||||
:param val_size: Fraction or number of elements in the validation split.
|
||||
:type val_size: float
|
||||
:param predict_size: Fraction or number of elements in the prediction split.
|
||||
:type predict_size: float
|
||||
:param batch_size: Batch size used for training. If None, the entire dataset is used per batch.
|
||||
:param batch_size: Batch size used for training. If None, the entire
|
||||
dataset is used per batch.
|
||||
:type batch_size: int or None
|
||||
:param shuffle: Whether to shuffle the dataset before splitting.
|
||||
:type shuffle: bool
|
||||
@@ -179,13 +190,13 @@ class PinaDataModule(LightningDataModule):
|
||||
:type repeat: bool
|
||||
:param automatic_batching: Whether to enable automatic batching.
|
||||
:type automatic_batching: bool
|
||||
:param num_workers: Number of worker threads for data loading. Default 0 (serial loading)
|
||||
:param num_workers: Number of worker threads for data loading.
|
||||
Default 0 (serial loading)
|
||||
:type num_workers: int
|
||||
:param pin_memory: Whether to use pinned memory for faster data transfer to GPU. (Default False)
|
||||
:param pin_memory: Whether to use pinned memory for faster data
|
||||
transfer to GPU. (Default False)
|
||||
:type pin_memory: bool
|
||||
"""
|
||||
logging.debug("Start initialization of Pina DataModule")
|
||||
logging.info("Start initialization of Pina DataModule")
|
||||
super().__init__()
|
||||
|
||||
# Store fixed attributes
|
||||
@@ -216,7 +227,7 @@ class PinaDataModule(LightningDataModule):
|
||||
collector.store_sample_domains()
|
||||
|
||||
# Check if the splits are correct
|
||||
self._check_slit_sizes(train_size, test_size, val_size, predict_size)
|
||||
self._check_slit_sizes(train_size, test_size, val_size)
|
||||
|
||||
# Split input data into subsets
|
||||
splits_dict = {}
|
||||
@@ -235,11 +246,7 @@ class PinaDataModule(LightningDataModule):
|
||||
self.val_dataset = None
|
||||
else:
|
||||
self.val_dataloader = super().val_dataloader
|
||||
if predict_size > 0:
|
||||
splits_dict["predict"] = predict_size
|
||||
self.predict_dataset = None
|
||||
else:
|
||||
self.predict_dataloader = super().predict_dataloader
|
||||
|
||||
self.collector_splits = self._create_splits(collector, splits_dict)
|
||||
self.transfer_batch_to_device = self._transfer_batch_to_device
|
||||
|
||||
@@ -247,7 +254,6 @@ class PinaDataModule(LightningDataModule):
|
||||
"""
|
||||
Perform the splitting of the dataset
|
||||
"""
|
||||
logging.debug("Start setup of Pina DataModule obj")
|
||||
if stage == "fit" or stage is None:
|
||||
self.train_dataset = PinaDatasetFactory(
|
||||
self.collector_splits["train"],
|
||||
@@ -270,18 +276,8 @@ class PinaDataModule(LightningDataModule):
|
||||
max_conditions_lengths=self.find_max_conditions_lengths("test"),
|
||||
automatic_batching=self.automatic_batching,
|
||||
)
|
||||
elif stage == "predict":
|
||||
self.predict_dataset = PinaDatasetFactory(
|
||||
self.collector_splits["predict"],
|
||||
max_conditions_lengths=self.find_max_conditions_lengths(
|
||||
"predict"
|
||||
),
|
||||
automatic_batching=self.automatic_batching,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"stage must be either 'fit' or 'test' or 'predict'."
|
||||
)
|
||||
raise ValueError("stage must be either 'fit' or 'test'.")
|
||||
|
||||
@staticmethod
|
||||
def _split_condition(condition_dict, splits_dict):
|
||||
@@ -336,7 +332,6 @@ class PinaDataModule(LightningDataModule):
|
||||
|
||||
# ----------- End auxiliary function ------------
|
||||
|
||||
logging.debug("Dataset creation in PinaDataModule obj")
|
||||
split_names = list(splits_dict.keys())
|
||||
dataset_dict = {name: {} for name in split_names}
|
||||
for (
|
||||
@@ -355,11 +350,13 @@ class PinaDataModule(LightningDataModule):
|
||||
def _create_dataloader(self, split, dataset):
|
||||
shuffle = self.shuffle if split == "train" else False
|
||||
# Suppress the warning about num_workers.
|
||||
# In many cases, especially for PINNs, serial data loading can outperform parallel data loading.
|
||||
# In many cases, especially for PINNs,
|
||||
# serial data loading can outperform parallel data loading.
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=(
|
||||
r"The '(train|val|test)_dataloader' does not have many workers which may be a bottleneck."
|
||||
"The '(train|val|test)_dataloader' does not have many workers "
|
||||
"which may be a bottleneck."
|
||||
),
|
||||
module="lightning.pytorch.trainer.connectors.data_connector",
|
||||
)
|
||||
@@ -387,6 +384,14 @@ class PinaDataModule(LightningDataModule):
|
||||
return dataloader
|
||||
|
||||
def find_max_conditions_lengths(self, split):
|
||||
"""
|
||||
Define the maximum length of the conditions.
|
||||
|
||||
:param split: The splits of the dataset.
|
||||
:type split: dict
|
||||
:return: The maximum length of the conditions.
|
||||
:rtype: dict
|
||||
"""
|
||||
max_conditions_lengths = {}
|
||||
for k, v in self.collector_splits[split].items():
|
||||
if self.batch_size is None:
|
||||
@@ -417,12 +422,6 @@ class PinaDataModule(LightningDataModule):
|
||||
"""
|
||||
return self._create_dataloader("test", self.test_dataset)
|
||||
|
||||
def predict_dataloader(self):
|
||||
"""
|
||||
Create the prediction dataloader
|
||||
"""
|
||||
raise NotImplementedError("Predict dataloader not implemented")
|
||||
|
||||
@staticmethod
|
||||
def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
|
||||
return batch
|
||||
@@ -445,13 +444,13 @@ class PinaDataModule(LightningDataModule):
|
||||
return batch
|
||||
|
||||
@staticmethod
|
||||
def _check_slit_sizes(train_size, test_size, val_size, predict_size):
|
||||
def _check_slit_sizes(train_size, test_size, val_size):
|
||||
"""
|
||||
Check if the splits are correct
|
||||
"""
|
||||
if train_size < 0 or test_size < 0 or val_size < 0 or predict_size < 0:
|
||||
if train_size < 0 or test_size < 0 or val_size < 0:
|
||||
raise ValueError("The splits must be positive")
|
||||
if abs(train_size + test_size + val_size + predict_size - 1) > 1e-6:
|
||||
if abs(train_size + test_size + val_size - 1) > 1e-6:
|
||||
raise ValueError("The sum of the splits must be 1")
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user