Implement Dataset, Dataloader and DataModule class and fix SupervisedSolver

This commit is contained in:
FilippoOlivo
2024-10-16 11:24:37 +02:00
committed by Nicola Demo
parent b9753c34b2
commit c9304fb9bb
30 changed files with 770 additions and 784 deletions

View File

@@ -1,36 +1,33 @@
"""
Batch management module
"""
from .pina_subset import PinaSubset
class Batch:
"""
This class is used to create a dataset of sample points.
"""
def __init__(self, dataset_dict, idx_dict):
def __init__(self, type_, idx, *args, **kwargs) -> None:
for k, v in dataset_dict.items():
setattr(self, k, v)
for k, v in idx_dict.items():
setattr(self, k + '_idx', v)
def __len__(self):
"""
Returns the number of elements in the batch
:return: number of elements in the batch
:rtype: int
"""
if type_ == "sample":
length = 0
for dataset in dir(self):
attribute = getattr(self, dataset)
if isinstance(attribute, list):
length += len(getattr(self, dataset))
return length
if len(args) != 2:
raise RuntimeError
input = args[0]
conditions = args[1]
self.input = input[idx]
self.condition = conditions[idx]
elif type_ == "data":
if len(args) != 3:
raise RuntimeError
input = args[0]
output = args[1]
conditions = args[2]
self.input = input[idx]
self.output = output[idx]
self.condition = conditions[idx]
else:
raise ValueError("Invalid number of arguments.")
def __getattr__(self, item):
if not item in dir(self):
raise AttributeError(f'Batch instance has no attribute {item}')
return PinaSubset(getattr(self, item).dataset,
getattr(self, item).indices[self.coordinates_dict[item]])