Dev Update (#582)

* Fix adaptive refinement (#571)


---------

Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>

* Remove collector

* Fixes

* Fixes

* rm unnecessary comment

* fix advection (#581)

* Fix tutorial .html link (#580)

* fix problem data collection for v0.1 (#584)

* Message Passing Module (#516)

* add deep tensor network block

* add interaction network block

* add radial field network block

* add schnet block

* add equivariant network block

* fix + tests + doc files

* fix egnn + equivariance/invariance tests

Co-authored-by: Dario Coscia <dariocos99@gmail.com>

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>

* add type checker (#527)

---------

Co-authored-by: Filippo Olivo <filippo@filippoolivo.com>
Co-authored-by: Giovanni Canali <115086358+GiovanniCanali@users.noreply.github.com>
Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>
This commit is contained in:
Dario Coscia
2025-06-13 17:34:37 +02:00
committed by GitHub
parent 6b355b45de
commit 7bf7d34d0f
40 changed files with 1963 additions and 581 deletions

View File

@@ -12,7 +12,6 @@ from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from ..label_tensor import LabelTensor
from .dataset import PinaDatasetFactory, PinaTensorDataset
from ..collector import Collector
class DummyDataloader:
@@ -330,9 +329,7 @@ class PinaDataModule(LightningDataModule):
self.pin_memory = pin_memory
# Collect data
collector = Collector(problem)
collector.store_fixed_data()
collector.store_sample_domains()
problem.collect_data()
# Check if the splits are correct
self._check_slit_sizes(train_size, test_size, val_size)
@@ -361,7 +358,9 @@ class PinaDataModule(LightningDataModule):
# raises NotImplementedError
self.val_dataloader = super().val_dataloader
self.collector_splits = self._create_splits(collector, splits_dict)
self.data_splits = self._create_splits(
problem.collected_data, splits_dict
)
self.transfer_batch_to_device = self._transfer_batch_to_device
def setup(self, stage=None):
@@ -376,15 +375,15 @@ class PinaDataModule(LightningDataModule):
"""
if stage == "fit" or stage is None:
self.train_dataset = PinaDatasetFactory(
self.collector_splits["train"],
self.data_splits["train"],
max_conditions_lengths=self.find_max_conditions_lengths(
"train"
),
automatic_batching=self.automatic_batching,
)
if "val" in self.collector_splits.keys():
if "val" in self.data_splits.keys():
self.val_dataset = PinaDatasetFactory(
self.collector_splits["val"],
self.data_splits["val"],
max_conditions_lengths=self.find_max_conditions_lengths(
"val"
),
@@ -392,7 +391,7 @@ class PinaDataModule(LightningDataModule):
)
elif stage == "test":
self.test_dataset = PinaDatasetFactory(
self.collector_splits["test"],
self.data_splits["test"],
max_conditions_lengths=self.find_max_conditions_lengths("test"),
automatic_batching=self.automatic_batching,
)
@@ -473,7 +472,7 @@ class PinaDataModule(LightningDataModule):
for (
condition_name,
condition_dict,
) in collector.data_collections.items():
) in collector.items():
len_data = len(condition_dict["input"])
if self.shuffle:
_apply_shuffle(condition_dict, len_data)
@@ -540,7 +539,7 @@ class PinaDataModule(LightningDataModule):
"""
max_conditions_lengths = {}
for k, v in self.collector_splits[split].items():
for k, v in self.data_splits[split].items():
if self.batch_size is None:
max_conditions_lengths[k] = len(v["input"])
elif self.repeat:

View File

@@ -239,6 +239,22 @@ class PinaTensorDataset(PinaDataset):
"""
return {k: v["input"] for k, v in self.conditions_dict.items()}
def update_data(self, new_conditions_dict):
"""
Update the dataset with new data.
This method is used to update the dataset with new data. It replaces
the current data with the new data provided in the new_conditions_dict
parameter.
:param dict new_conditions_dict: Dictionary containing the new data.
:return: None
"""
for condition, data in new_conditions_dict.items():
if condition in self.conditions_dict:
self.conditions_dict[condition].update(data)
else:
self.conditions_dict[condition] = data
class PinaGraphDataset(PinaDataset):
"""