Implement Dataset, Dataloader and DataModule class and fix SupervisedSolver
This commit is contained in:
committed by
Nicola Demo
parent
b9753c34b2
commit
c9304fb9bb
@@ -1,36 +1,33 @@
|
||||
"""
|
||||
Batch management module
|
||||
"""
|
||||
from .pina_subset import PinaSubset
|
||||
|
||||
|
||||
class Batch:
|
||||
"""
|
||||
This class is used to create a dataset of sample points.
|
||||
"""
|
||||
def __init__(self, dataset_dict, idx_dict):
|
||||
|
||||
def __init__(self, type_, idx, *args, **kwargs) -> None:
|
||||
for k, v in dataset_dict.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
for k, v in idx_dict.items():
|
||||
setattr(self, k + '_idx', v)
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Returns the number of elements in the batch
|
||||
:return: number of elements in the batch
|
||||
:rtype: int
|
||||
"""
|
||||
if type_ == "sample":
|
||||
length = 0
|
||||
for dataset in dir(self):
|
||||
attribute = getattr(self, dataset)
|
||||
if isinstance(attribute, list):
|
||||
length += len(getattr(self, dataset))
|
||||
return length
|
||||
|
||||
if len(args) != 2:
|
||||
raise RuntimeError
|
||||
|
||||
input = args[0]
|
||||
conditions = args[1]
|
||||
|
||||
self.input = input[idx]
|
||||
self.condition = conditions[idx]
|
||||
|
||||
elif type_ == "data":
|
||||
|
||||
if len(args) != 3:
|
||||
raise RuntimeError
|
||||
|
||||
input = args[0]
|
||||
output = args[1]
|
||||
conditions = args[2]
|
||||
|
||||
self.input = input[idx]
|
||||
self.output = output[idx]
|
||||
self.condition = conditions[idx]
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid number of arguments.")
|
||||
def __getattr__(self, item):
|
||||
if not item in dir(self):
|
||||
raise AttributeError(f'Batch instance has no attribute {item}')
|
||||
return PinaSubset(getattr(self, item).dataset,
|
||||
getattr(self, item).indices[self.coordinates_dict[item]])
|
||||
|
||||
Reference in New Issue
Block a user