clean logic, fix problems for tutorial1

Nicola Demo
2025-02-06 14:28:17 +01:00
parent 7702427e8d
commit effd1e83bb
5 changed files with 105 additions and 76 deletions


@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
    RandomSampler
from torch.utils.data.distributed import DistributedSampler
from .dataset import PinaDatasetFactory
from ..collector import Collector
class DummyDataloader:
    def __init__(self, dataset, device):
@@ -87,7 +87,7 @@ class PinaDataModule(LightningDataModule):
    """
    def __init__(self,
                 collector,
                 problem,
                 train_size=.7,
                 test_size=.2,
                 val_size=.1,
@@ -99,7 +99,6 @@ class PinaDataModule(LightningDataModule):
                 ):
        """
        Initialize the object, creating dataset based on input problem
        :param Collector collector: PINA problem
        :param train_size: number/percentage of elements in train split
        :param test_size: number/percentage of elements in test split
        :param val_size: number/percentage of elements in evaluation split
@@ -135,6 +134,10 @@ class PinaDataModule(LightningDataModule):
            self.predict_dataset = None
        else:
            self.predict_dataloader = super().predict_dataloader
        collector = Collector(problem)
        collector.store_fixed_data()
        collector.store_sample_domains()
        self.collector_splits = self._create_splits(collector, splits_dict)
        self.transfer_batch_to_device = self._transfer_batch_to_device
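Taken together, these hunks move Collector construction inside PinaDataModule: the constructor now receives the problem and builds and populates the collector itself, instead of expecting callers to pass a pre-filled Collector in. A minimal sketch of the resulting flow, with hypothetical stand-in bodies (only the names `Collector`, `store_fixed_data`, `store_sample_domains`, `PinaDataModule`, and `_create_splits` come from the diff; the attributes `fixed_data` and `sampled_domains` below are illustrative assumptions, not the real PINA internals):

```python
# Simplified stand-ins mirroring the diff, not the real PINA implementations.

class Collector:
    def __init__(self, problem):
        self.problem = problem
        self.data = {}

    def store_fixed_data(self):
        # hypothetical: gather conditions whose data the problem already supplies
        self.data.update(getattr(self.problem, "fixed_data", {}))

    def store_sample_domains(self):
        # hypothetical: gather conditions obtained by sampling the problem domains
        self.data.update(getattr(self.problem, "sampled_domains", {}))


class PinaDataModule:
    def __init__(self, problem, train_size=.7, test_size=.2, val_size=.1):
        # after this commit, callers pass the problem; the collector is built here
        collector = Collector(problem)
        collector.store_fixed_data()
        collector.store_sample_domains()
        splits_dict = {"train": train_size, "val": val_size, "test": test_size}
        self.collector_splits = self._create_splits(collector, splits_dict)

    @staticmethod
    def _create_splits(collector, splits_dict):
        # placeholder: the real method partitions each condition's data per split
        return {name: dict(collector.data) for name in splits_dict}
```

Under the old signature a caller had to construct and fill the Collector before instantiating the data module; as the diff reads, passing the problem is now enough and the collector becomes an internal detail.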


@@ -58,6 +58,7 @@ class PinaTensorDataset(PinaDataset):
    def __init__(self, conditions_dict, max_conditions_lengths,
                 automatic_batching):
        super().__init__(conditions_dict, max_conditions_lengths)
        if automatic_batching:
            self._getitem_func = self._getitem_int
        else:
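The new `automatic_batching` flag selects the item-retrieval function once, at construction time, rather than branching on every `__getitem__` call: with automatic batching the DataLoader asks for single integer indices, otherwise a batch sampler presumably hands over a whole list of indices. A self-contained sketch of that dispatch pattern, using a hypothetical `TensorDatasetSketch` (not the real `PinaTensorDataset`), might look like this:

```python
import torch
from torch.utils.data import Dataset


class TensorDatasetSketch(Dataset):
    """Hypothetical stand-in illustrating the dispatch-at-construction pattern."""

    def __init__(self, conditions_dict, automatic_batching):
        self.conditions_dict = conditions_dict
        # pick the retrieval function once, so no per-item branching is needed
        if automatic_batching:
            # automatic batching: the DataLoader passes single integer indices
            self._getitem_func = self._getitem_int
        else:
            # batch-sampler mode: a whole list of indices arrives at once
            self._getitem_func = self._getitem_list

    def _getitem_int(self, idx):
        return {cond: {k: v[idx] for k, v in data.items()}
                for cond, data in self.conditions_dict.items()}

    def _getitem_list(self, indices):
        return {cond: {k: v[indices] for k, v in data.items()}
                for cond, data in self.conditions_dict.items()}

    def __getitem__(self, idx):
        return self._getitem_func(idx)

    def __len__(self):
        return min(len(next(iter(data.values())))
                   for data in self.conditions_dict.values())


# usage sketch: one condition with 10 input/target pairs
data = {"poisson": {"input": torch.rand(10, 2), "target": torch.rand(10, 1)}}
ds = TensorDatasetSketch(data, automatic_batching=True)
sample = ds[3]  # single-index path selected at construction
```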