Bug fix and add additional tests for Dataset and DataModule (#517)
This commit is contained in:
committed by
FilippoOlivo
parent
79a7199985
commit
80c257da4d
@@ -167,9 +167,15 @@ class PinaDataset(Dataset, ABC):
|
||||
:return: A dictionary containing all the data in the dataset.
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
index = list(range(len(self)))
|
||||
return self.fetch_from_idx_list(index)
|
||||
to_return_dict = {}
|
||||
for condition, data in self.conditions_dict.items():
|
||||
len_condition = len(
|
||||
data["input"]
|
||||
) # Length of the current condition
|
||||
to_return_dict[condition] = self._retrive_data(
|
||||
data, list(range(len_condition))
|
||||
) # Retrieve the data from the current condition
|
||||
return to_return_dict
|
||||
|
||||
def fetch_from_idx_list(self, idx):
|
||||
"""
|
||||
@@ -306,3 +312,13 @@ class PinaGraphDataset(PinaDataset):
|
||||
)
|
||||
for k, v in data.items()
|
||||
}
|
||||
|
||||
@property
|
||||
def input(self):
|
||||
"""
|
||||
Return the input data for the dataset.
|
||||
|
||||
:return: Dictionary containing the input points.
|
||||
:rtype: dict
|
||||
"""
|
||||
return {k: v["input"] for k, v in self.conditions_dict.items()}
|
||||
|
||||
Reference in New Issue
Block a user