Fix Codacy Warnings (#477)

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-10 15:38:45 +01:00
committed by Nicola Demo
parent e3790e049a
commit 4177bfbb50
157 changed files with 3473 additions and 3839 deletions

View File

@@ -1,4 +1,9 @@
import logging
"""
This module contains the PinaDataModule class, which extends the
LightningDataModule class to allow proper creation and management of
different types of Datasets defined in PINA.
"""
import warnings
from lightning.pytorch import LightningDataModule
import torch
@@ -58,6 +63,10 @@ class DummyDataloader:
class Collator:
"""
Class used to collate the batch
"""
def __init__(self, max_conditions_lengths, dataset=None):
self.max_conditions_lengths = max_conditions_lengths
self.callable_function = (
@@ -123,6 +132,10 @@ class Collator:
class PinaSampler:
"""
Class used to create the sampler instance.
"""
def __new__(cls, dataset, shuffle):
if (
@@ -150,7 +163,6 @@ class PinaDataModule(LightningDataModule):
train_size=0.7,
test_size=0.2,
val_size=0.1,
predict_size=0.0,
batch_size=None,
shuffle=True,
repeat=False,
@@ -169,9 +181,8 @@ class PinaDataModule(LightningDataModule):
:type test_size: float
:param val_size: Fraction or number of elements in the validation split.
:type val_size: float
:param predict_size: Fraction or number of elements in the prediction split.
:type predict_size: float
:param batch_size: Batch size used for training. If None, the entire dataset is used per batch.
:param batch_size: Batch size used for training. If None, the entire
dataset is used per batch.
:type batch_size: int or None
:param shuffle: Whether to shuffle the dataset before splitting.
:type shuffle: bool
@@ -179,13 +190,13 @@ class PinaDataModule(LightningDataModule):
:type repeat: bool
:param automatic_batching: Whether to enable automatic batching.
:type automatic_batching: bool
:param num_workers: Number of worker threads for data loading. Default 0 (serial loading)
:param num_workers: Number of worker threads for data loading.
Default 0 (serial loading)
:type num_workers: int
:param pin_memory: Whether to use pinned memory for faster data transfer to GPU. (Default False)
:param pin_memory: Whether to use pinned memory for faster data
transfer to GPU. (Default False)
:type pin_memory: bool
"""
logging.debug("Start initialization of Pina DataModule")
logging.info("Start initialization of Pina DataModule")
super().__init__()
# Store fixed attributes
@@ -216,7 +227,7 @@ class PinaDataModule(LightningDataModule):
collector.store_sample_domains()
# Check if the splits are correct
self._check_slit_sizes(train_size, test_size, val_size, predict_size)
self._check_slit_sizes(train_size, test_size, val_size)
# Split input data into subsets
splits_dict = {}
@@ -235,11 +246,7 @@ class PinaDataModule(LightningDataModule):
self.val_dataset = None
else:
self.val_dataloader = super().val_dataloader
if predict_size > 0:
splits_dict["predict"] = predict_size
self.predict_dataset = None
else:
self.predict_dataloader = super().predict_dataloader
self.collector_splits = self._create_splits(collector, splits_dict)
self.transfer_batch_to_device = self._transfer_batch_to_device
@@ -247,7 +254,6 @@ class PinaDataModule(LightningDataModule):
"""
Create the datasets for the requested stage from the precomputed splits.
"""
logging.debug("Start setup of Pina DataModule obj")
if stage == "fit" or stage is None:
self.train_dataset = PinaDatasetFactory(
self.collector_splits["train"],
@@ -270,18 +276,8 @@ class PinaDataModule(LightningDataModule):
max_conditions_lengths=self.find_max_conditions_lengths("test"),
automatic_batching=self.automatic_batching,
)
elif stage == "predict":
self.predict_dataset = PinaDatasetFactory(
self.collector_splits["predict"],
max_conditions_lengths=self.find_max_conditions_lengths(
"predict"
),
automatic_batching=self.automatic_batching,
)
else:
raise ValueError(
"stage must be either 'fit' or 'test' or 'predict'."
)
raise ValueError("stage must be either 'fit' or 'test'.")
@staticmethod
def _split_condition(condition_dict, splits_dict):
@@ -336,7 +332,6 @@ class PinaDataModule(LightningDataModule):
# ----------- End auxiliary function ------------
logging.debug("Dataset creation in PinaDataModule obj")
split_names = list(splits_dict.keys())
dataset_dict = {name: {} for name in split_names}
for (
@@ -355,11 +350,13 @@ class PinaDataModule(LightningDataModule):
def _create_dataloader(self, split, dataset):
shuffle = self.shuffle if split == "train" else False
# Suppress the warning about num_workers.
# In many cases, especially for PINNs, serial data loading can outperform parallel data loading.
# In many cases, especially for PINNs,
# serial data loading can outperform parallel data loading.
warnings.filterwarnings(
"ignore",
message=(
r"The '(train|val|test)_dataloader' does not have many workers which may be a bottleneck."
"The '(train|val|test)_dataloader' does not have many workers "
"which may be a bottleneck."
),
module="lightning.pytorch.trainer.connectors.data_connector",
)
@@ -387,6 +384,14 @@ class PinaDataModule(LightningDataModule):
return dataloader
def find_max_conditions_lengths(self, split):
"""
Define the maximum length of the conditions.
:param split: The name of the dataset split (e.g. ``"test"``).
:type split: str
:return: The maximum length of the conditions.
:rtype: dict
"""
max_conditions_lengths = {}
for k, v in self.collector_splits[split].items():
if self.batch_size is None:
@@ -417,12 +422,6 @@ class PinaDataModule(LightningDataModule):
"""
return self._create_dataloader("test", self.test_dataset)
def predict_dataloader(self):
"""
Create the prediction dataloader
"""
raise NotImplementedError("Predict dataloader not implemented")
@staticmethod
def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
return batch
@@ -445,13 +444,13 @@ class PinaDataModule(LightningDataModule):
return batch
@staticmethod
def _check_slit_sizes(train_size, test_size, val_size, predict_size):
def _check_slit_sizes(train_size, test_size, val_size):
"""
Check that the split sizes are non-negative and sum to 1.
"""
if train_size < 0 or test_size < 0 or val_size < 0 or predict_size < 0:
if train_size < 0 or test_size < 0 or val_size < 0:
raise ValueError("The splits must be positive")
if abs(train_size + test_size + val_size + predict_size - 1) > 1e-6:
if abs(train_size + test_size + val_size - 1) > 1e-6:
raise ValueError("The sum of the splits must be 1")
@property