Improve conditions and refactor dataset classes (#475)
* Reimplement conditions * Refactor datasets and implement LabelBatch --------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
bdad144461
commit
a0cbf1c44a
@@ -285,7 +285,7 @@ class PinaDataModule(LightningDataModule):
|
||||
|
||||
@staticmethod
|
||||
def _split_condition(condition_dict, splits_dict):
|
||||
len_condition = len(condition_dict["input_points"])
|
||||
len_condition = len(condition_dict["input"])
|
||||
|
||||
lengths = [
|
||||
int(len_condition * length) for length in splits_dict.values()
|
||||
@@ -343,7 +343,7 @@ class PinaDataModule(LightningDataModule):
|
||||
condition_name,
|
||||
condition_dict,
|
||||
) in collector.data_collections.items():
|
||||
len_data = len(condition_dict["input_points"])
|
||||
len_data = len(condition_dict["input"])
|
||||
if self.shuffle:
|
||||
_apply_shuffle(condition_dict, len_data)
|
||||
for key, data in self._split_condition(
|
||||
@@ -390,12 +390,12 @@ class PinaDataModule(LightningDataModule):
|
||||
max_conditions_lengths = {}
|
||||
for k, v in self.collector_splits[split].items():
|
||||
if self.batch_size is None:
|
||||
max_conditions_lengths[k] = len(v["input_points"])
|
||||
max_conditions_lengths[k] = len(v["input"])
|
||||
elif self.repeat:
|
||||
max_conditions_lengths[k] = self.batch_size
|
||||
else:
|
||||
max_conditions_lengths[k] = min(
|
||||
len(v["input_points"]), self.batch_size
|
||||
len(v["input"]), self.batch_size
|
||||
)
|
||||
return max_conditions_lengths
|
||||
|
||||
@@ -455,15 +455,15 @@ class PinaDataModule(LightningDataModule):
|
||||
raise ValueError("The sum of the splits must be 1")
|
||||
|
||||
@property
|
||||
def input_points(self):
|
||||
def input(self):
|
||||
"""
|
||||
# TODO
|
||||
"""
|
||||
to_return = {}
|
||||
if hasattr(self, "train_dataset") and self.train_dataset is not None:
|
||||
to_return["train"] = self.train_dataset.input_points
|
||||
to_return["train"] = self.train_dataset.input
|
||||
if hasattr(self, "val_dataset") and self.val_dataset is not None:
|
||||
to_return["val"] = self.val_dataset.input_points
|
||||
to_return["val"] = self.val_dataset.input
|
||||
if hasattr(self, "test_dataset") and self.test_dataset is not None:
|
||||
to_return = self.test_dataset.input_points
|
||||
to_return = self.test_dataset.input
|
||||
return to_return
|
||||
|
||||
Reference in New Issue
Block a user