Improve conditions and refactor dataset classes (#475)

* Reimplement conditions

* Refactor datasets and implement LabelBatch

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-07 11:24:09 +01:00
committed by Nicola Demo
parent bdad144461
commit a0cbf1c44a
40 changed files with 943 additions and 550 deletions

View File

@@ -285,7 +285,7 @@ class PinaDataModule(LightningDataModule):
@staticmethod
def _split_condition(condition_dict, splits_dict):
len_condition = len(condition_dict["input_points"])
len_condition = len(condition_dict["input"])
lengths = [
int(len_condition * length) for length in splits_dict.values()
@@ -343,7 +343,7 @@ class PinaDataModule(LightningDataModule):
condition_name,
condition_dict,
) in collector.data_collections.items():
len_data = len(condition_dict["input_points"])
len_data = len(condition_dict["input"])
if self.shuffle:
_apply_shuffle(condition_dict, len_data)
for key, data in self._split_condition(
@@ -390,12 +390,12 @@ class PinaDataModule(LightningDataModule):
max_conditions_lengths = {}
for k, v in self.collector_splits[split].items():
if self.batch_size is None:
max_conditions_lengths[k] = len(v["input_points"])
max_conditions_lengths[k] = len(v["input"])
elif self.repeat:
max_conditions_lengths[k] = self.batch_size
else:
max_conditions_lengths[k] = min(
len(v["input_points"]), self.batch_size
len(v["input"]), self.batch_size
)
return max_conditions_lengths
@@ -455,15 +455,15 @@ class PinaDataModule(LightningDataModule):
raise ValueError("The sum of the splits must be 1")
@property
def input_points(self):
def input(self):
"""
# TODO
"""
to_return = {}
if hasattr(self, "train_dataset") and self.train_dataset is not None:
to_return["train"] = self.train_dataset.input_points
to_return["train"] = self.train_dataset.input
if hasattr(self, "val_dataset") and self.val_dataset is not None:
to_return["val"] = self.val_dataset.input_points
to_return["val"] = self.val_dataset.input
if hasattr(self, "test_dataset") and self.test_dataset is not None:
to_return = self.test_dataset.input_points
to_return = self.test_dataset.input
return to_return