clean logic, fix problems for tutorial1
This commit is contained in:
@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
|
||||
RandomSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from .dataset import PinaDatasetFactory
|
||||
|
||||
from ..collector import Collector
|
||||
|
||||
class DummyDataloader:
|
||||
def __init__(self, dataset, device):
|
||||
@@ -87,7 +87,7 @@ class PinaDataModule(LightningDataModule):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
collector,
|
||||
problem,
|
||||
train_size=.7,
|
||||
test_size=.2,
|
||||
val_size=.1,
|
||||
@@ -99,7 +99,6 @@ class PinaDataModule(LightningDataModule):
|
||||
):
|
||||
"""
|
||||
Initialize the object, creating dataset based on input problem
|
||||
:param Collector collector: PINA problem
|
||||
:param train_size: number/percentage of elements in train split
|
||||
:param test_size: number/percentage of elements in test split
|
||||
:param val_size: number/percentage of elements in evaluation split
|
||||
@@ -135,6 +134,10 @@ class PinaDataModule(LightningDataModule):
|
||||
self.predict_dataset = None
|
||||
else:
|
||||
self.predict_dataloader = super().predict_dataloader
|
||||
|
||||
collector = Collector(problem)
|
||||
collector.store_fixed_data()
|
||||
collector.store_sample_domains()
|
||||
self.collector_splits = self._create_splits(collector, splits_dict)
|
||||
self.transfer_batch_to_device = self._transfer_batch_to_device
|
||||
|
||||
|
||||
Reference in New Issue
Block a user