Implement Dataset, Dataloader and DataModule class and fix SupervisedSolver

This commit is contained in:
FilippoOlivo
2024-10-16 11:24:37 +02:00
committed by Nicola Demo
parent b9753c34b2
commit c9304fb9bb
30 changed files with 770 additions and 784 deletions

View File

@@ -1,43 +1,12 @@
from torch.utils.data import Dataset
import torch
"""
Sample dataset module
"""
from .base_dataset import BaseDataset
from ..label_tensor import LabelTensor
class SamplePointDataset(Dataset):
class SamplePointDataset(BaseDataset):
"""
This class is used to create a dataset of sample points.
This class extends the BaseDataset to handle physical datasets
composed of only input points.
"""
def __init__(self, problem, device) -> None:
"""
:param dict input_pts: The input points.
"""
super().__init__()
pts_list = []
self.condition_names = []
for name, condition in problem.conditions.items():
if not hasattr(condition, "output_points"):
pts_list.append(problem.input_pts[name])
self.condition_names.append(name)
self.pts = LabelTensor.stack(pts_list)
if self.pts != []:
self.condition_indeces = torch.cat(
[
torch.tensor([i] * len(pts_list[i]))
for i in range(len(self.condition_names))
],
dim=0,
)
else: # if there are no sample points
self.condition_indeces = torch.tensor([])
self.pts = torch.tensor([])
self.pts = self.pts.to(device)
self.condition_indeces = self.condition_indeces.to(device)
def __len__(self):
return self.pts.shape[0]
data_type = 'physics'
__slots__ = ['input_points']