Dev Update (#582)
* Fix adaptive refinement (#571) --------- Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> * Remove collector * Fixes * Fixes * rm unnecessary comment * fix advection (#581) * Fix tutorial .html link (#580) * fix problem data collection for v0.1 (#584) * Message Passing Module (#516) * add deep tensor network block * add interaction network block * add radial field network block * add schnet block * add equivariant network block * fix + tests + doc files * fix egnn + equivariance/invariance tests Co-authored-by: Dario Coscia <dariocos99@gmail.com> --------- Co-authored-by: giovanni <giovanni.canali98@yahoo.it> Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it> * add type checker (#527) --------- Co-authored-by: Filippo Olivo <filippo@filippoolivo.com> Co-authored-by: Giovanni Canali <115086358+GiovanniCanali@users.noreply.github.com> Co-authored-by: giovanni <giovanni.canali98@yahoo.it> Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>
This commit is contained in:
@@ -12,7 +12,6 @@ from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from ..label_tensor import LabelTensor
|
||||
from .dataset import PinaDatasetFactory, PinaTensorDataset
|
||||
from ..collector import Collector
|
||||
|
||||
|
||||
class DummyDataloader:
|
||||
@@ -330,9 +329,7 @@ class PinaDataModule(LightningDataModule):
|
||||
self.pin_memory = pin_memory
|
||||
|
||||
# Collect data
|
||||
collector = Collector(problem)
|
||||
collector.store_fixed_data()
|
||||
collector.store_sample_domains()
|
||||
problem.collect_data()
|
||||
|
||||
# Check if the splits are correct
|
||||
self._check_slit_sizes(train_size, test_size, val_size)
|
||||
@@ -361,7 +358,9 @@ class PinaDataModule(LightningDataModule):
|
||||
# raises NotImplementedError
|
||||
self.val_dataloader = super().val_dataloader
|
||||
|
||||
self.collector_splits = self._create_splits(collector, splits_dict)
|
||||
self.data_splits = self._create_splits(
|
||||
problem.collected_data, splits_dict
|
||||
)
|
||||
self.transfer_batch_to_device = self._transfer_batch_to_device
|
||||
|
||||
def setup(self, stage=None):
|
||||
@@ -376,15 +375,15 @@ class PinaDataModule(LightningDataModule):
|
||||
"""
|
||||
if stage == "fit" or stage is None:
|
||||
self.train_dataset = PinaDatasetFactory(
|
||||
self.collector_splits["train"],
|
||||
self.data_splits["train"],
|
||||
max_conditions_lengths=self.find_max_conditions_lengths(
|
||||
"train"
|
||||
),
|
||||
automatic_batching=self.automatic_batching,
|
||||
)
|
||||
if "val" in self.collector_splits.keys():
|
||||
if "val" in self.data_splits.keys():
|
||||
self.val_dataset = PinaDatasetFactory(
|
||||
self.collector_splits["val"],
|
||||
self.data_splits["val"],
|
||||
max_conditions_lengths=self.find_max_conditions_lengths(
|
||||
"val"
|
||||
),
|
||||
@@ -392,7 +391,7 @@ class PinaDataModule(LightningDataModule):
|
||||
)
|
||||
elif stage == "test":
|
||||
self.test_dataset = PinaDatasetFactory(
|
||||
self.collector_splits["test"],
|
||||
self.data_splits["test"],
|
||||
max_conditions_lengths=self.find_max_conditions_lengths("test"),
|
||||
automatic_batching=self.automatic_batching,
|
||||
)
|
||||
@@ -473,7 +472,7 @@ class PinaDataModule(LightningDataModule):
|
||||
for (
|
||||
condition_name,
|
||||
condition_dict,
|
||||
) in collector.data_collections.items():
|
||||
) in collector.items():
|
||||
len_data = len(condition_dict["input"])
|
||||
if self.shuffle:
|
||||
_apply_shuffle(condition_dict, len_data)
|
||||
@@ -540,7 +539,7 @@ class PinaDataModule(LightningDataModule):
|
||||
"""
|
||||
|
||||
max_conditions_lengths = {}
|
||||
for k, v in self.collector_splits[split].items():
|
||||
for k, v in self.data_splits[split].items():
|
||||
if self.batch_size is None:
|
||||
max_conditions_lengths[k] = len(v["input"])
|
||||
elif self.repeat:
|
||||
|
||||
Reference in New Issue
Block a user