Compare commits
1 Commits
refact-dat
...
fix-codacy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
40747c56ff |
2
.github/workflows/monthly-tagger.yml
vendored
2
.github/workflows/monthly-tagger.yml
vendored
@@ -11,7 +11,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, macos-latest, ubuntu-latest]
|
||||
python-version: ['3.10', '3.11', '3.12', '3.13', '3.14']
|
||||
python-version: [3.9, '3.10', '3.11', '3.12', '3.13']
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python
|
||||
|
||||
6
.github/workflows/tester.yml
vendored
6
.github/workflows/tester.yml
vendored
@@ -1,7 +1,7 @@
|
||||
name: "Testing Pull Request"
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
pull_request_target:
|
||||
branches:
|
||||
- "master"
|
||||
- "dev"
|
||||
@@ -13,7 +13,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [windows-latest, macos-latest, ubuntu-latest]
|
||||
python-version: ['3.10', '3.11', '3.12', '3.13', '3.14']
|
||||
python-version: [3.9, '3.10', '3.11', '3.12', '3.13']
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -75,4 +75,4 @@ jobs:
|
||||
threshold: 80.123
|
||||
fail: true
|
||||
publish: true
|
||||
coverage-summary-title: "Code Coverage Summary"
|
||||
coverage-summary-title: "Code Coverage Summary"
|
||||
|
||||
4
.github/workflows/tutorial_exporter.yml
vendored
4
.github/workflows/tutorial_exporter.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
python-version: 3.9
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
@@ -91,7 +91,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
python-version: 3.9
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
|
||||
@@ -7,8 +7,8 @@ SPDX-License-Identifier: Apache-2.0
|
||||
<table>
|
||||
<tr>
|
||||
<td>
|
||||
<a href="readme/pina_logo.png">
|
||||
<img src="readme/pina_logo.png"
|
||||
<a href="https://github.com/mathLab/PINA/raw/master/readme/pina_logo.png">
|
||||
<img src="https://github.com/mathLab/PINA/raw/master/readme/pina_logo.png"
|
||||
alt="PINA logo"
|
||||
style="width: 220px; aspect-ratio: 1 / 1; object-fit: contain;">
|
||||
</a>
|
||||
|
||||
@@ -26,7 +26,6 @@ Trainer, Dataset and Datamodule
|
||||
Trainer <trainer.rst>
|
||||
Dataset <data/dataset.rst>
|
||||
DataModule <data/data_module.rst>
|
||||
Dataloader <data/dataloader.rst>
|
||||
|
||||
Data Types
|
||||
------------
|
||||
|
||||
@@ -2,6 +2,14 @@ DataModule
|
||||
======================
|
||||
.. currentmodule:: pina.data.data_module
|
||||
|
||||
.. autoclass:: Collator
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: PinaDataModule
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: PinaSampler
|
||||
:members:
|
||||
:show-inheritance:
|
||||
@@ -1,11 +0,0 @@
|
||||
Dataloader
|
||||
======================
|
||||
.. currentmodule:: pina.data.dataloader
|
||||
|
||||
.. autoclass:: PinaSampler
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: PinaDataLoader
|
||||
:members:
|
||||
:show-inheritance:
|
||||
@@ -7,5 +7,13 @@ Dataset
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: PinaDatasetFactory
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: PinaGraphDataset
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: PinaTensorDataset
|
||||
:members:
|
||||
:show-inheritance:
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 411 KiB After Width: | Height: | Size: 177 KiB |
@@ -5,6 +5,7 @@ from lightning.pytorch import Callback
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..utils import check_consistency, is_function
|
||||
from ..condition import InputTargetCondition
|
||||
from ..data.dataset import PinaGraphDataset
|
||||
|
||||
|
||||
class NormalizerDataCallback(Callback):
|
||||
@@ -121,10 +122,7 @@ class NormalizerDataCallback(Callback):
|
||||
"""
|
||||
|
||||
# Ensure datasets are not graph-based
|
||||
if any(
|
||||
ds.is_graph_dataset
|
||||
for ds in trainer.datamodule.train_dataset.values()
|
||||
):
|
||||
if isinstance(trainer.datamodule.train_dataset, PinaGraphDataset):
|
||||
raise NotImplementedError(
|
||||
"NormalizerDataCallback is not compatible with "
|
||||
"graph-based datasets."
|
||||
@@ -166,8 +164,8 @@ class NormalizerDataCallback(Callback):
|
||||
:param dataset: The `~pina.data.dataset.PinaDataset` dataset.
|
||||
"""
|
||||
for cond in conditions:
|
||||
if cond in dataset:
|
||||
data = dataset[cond].data[self.apply_to]
|
||||
if cond in dataset.conditions_dict:
|
||||
data = dataset.conditions_dict[cond][self.apply_to]
|
||||
shift = self.shift_fn(data)
|
||||
scale = self.scale_fn(data)
|
||||
self._normalizer[cond] = {
|
||||
@@ -199,20 +197,25 @@ class NormalizerDataCallback(Callback):
|
||||
|
||||
:param PinaDataset dataset: The dataset to be normalized.
|
||||
"""
|
||||
# Initialize update dictionary
|
||||
update_dataset_dict = {}
|
||||
|
||||
# Iterate over conditions and apply normalization
|
||||
for cond, norm_params in self.normalizer.items():
|
||||
update_dataset_dict = {}
|
||||
points = dataset[cond].data[self.apply_to]
|
||||
points = dataset.conditions_dict[cond][self.apply_to]
|
||||
scale = norm_params["scale"]
|
||||
shift = norm_params["shift"]
|
||||
normalized_points = self._norm_fn(points, scale, shift)
|
||||
update_dataset_dict[self.apply_to] = (
|
||||
LabelTensor(normalized_points, points.labels)
|
||||
if isinstance(points, LabelTensor)
|
||||
else normalized_points
|
||||
)
|
||||
dataset[cond].data.update(update_dataset_dict)
|
||||
update_dataset_dict[cond] = {
|
||||
self.apply_to: (
|
||||
LabelTensor(normalized_points, points.labels)
|
||||
if isinstance(points, LabelTensor)
|
||||
else normalized_points
|
||||
)
|
||||
}
|
||||
|
||||
# Update the dataset in-place
|
||||
dataset.update_data(update_dataset_dict)
|
||||
|
||||
@property
|
||||
def normalizer(self):
|
||||
|
||||
@@ -133,12 +133,13 @@ class RefinementInterface(Callback, metaclass=ABCMeta):
|
||||
|
||||
:param PINNInterface solver: The solver object.
|
||||
"""
|
||||
new_points = {}
|
||||
for name in self._condition_to_update:
|
||||
new_points = {}
|
||||
current_points = self.dataset[name].data["input"]
|
||||
new_points["input"] = self.sample(current_points, name, solver)
|
||||
|
||||
self.dataset[name].update_data(new_points)
|
||||
current_points = self.dataset.conditions_dict[name]["input"]
|
||||
new_points[name] = {
|
||||
"input": self.sample(current_points, name, solver)
|
||||
}
|
||||
self.dataset.update_data(new_points)
|
||||
|
||||
def _compute_population_size(self, conditions):
|
||||
"""
|
||||
@@ -149,5 +150,6 @@ class RefinementInterface(Callback, metaclass=ABCMeta):
|
||||
:rtype: dict
|
||||
"""
|
||||
return {
|
||||
cond: len(self.dataset[cond].data["input"]) for cond in conditions
|
||||
cond: len(self.dataset.conditions_dict[cond]["input"])
|
||||
for cond in conditions
|
||||
}
|
||||
|
||||
@@ -4,3 +4,4 @@ __all__ = ["PinaDataModule", "PinaDataset"]
|
||||
|
||||
|
||||
from .data_module import PinaDataModule
|
||||
from .dataset import PinaDataset
|
||||
|
||||
@@ -7,9 +7,232 @@ different types of Datasets defined in PINA.
|
||||
import warnings
|
||||
from lightning.pytorch import LightningDataModule
|
||||
import torch
|
||||
from torch_geometric.data import Data
|
||||
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from ..label_tensor import LabelTensor
|
||||
from .dataset import PinaDatasetFactory
|
||||
from .dataloader import PinaDataLoader
|
||||
from .dataset import PinaDatasetFactory, PinaTensorDataset
|
||||
|
||||
|
||||
class DummyDataloader:
    """
    Loader that serves the data fetched at construction time as one batch.

    .. note::
        This dataloader is used when the batch size is ``None``.
    """

    def __init__(self, dataset):
        """
        Fetch the data that this loader will serve.

        - **Distributed environment** (``torch.distributed`` available and
          initialized): only the indices ``rank, rank + world_size, ...``
          are fetched, through ``dataset.fetch_from_idx_list``.
        - **Non-distributed environment**: everything is fetched through
          ``dataset.get_all_data``.

        :param PinaDataset dataset: The dataset object to be processed.
        :raises RuntimeError: If, in a distributed environment, the dataset
            has fewer elements than the world size.
        """
        running_distributed = (
            torch.distributed.is_available()
            and torch.distributed.is_initialized()
        )
        if not running_distributed:
            self.dataset = dataset.get_all_data()
            return

        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        if len(dataset) < world_size:
            raise RuntimeError(
                "Dimension of the dataset smaller than world size."
                " Increase the size of the partition or use a single GPU"
            )

        # Strided partition: this process takes every ``world_size``-th
        # index, starting from its own rank.
        own_indices = list(range(rank, len(dataset), world_size))
        self.dataset = dataset.fetch_from_idx_list(own_indices)

    def __iter__(self):
        """
        Return the loader itself as the iterator.
        """
        return self

    def __len__(self):
        """
        Return the number of batches, which is always 1.

        :return: The length of the dataloader.
        :rtype: int
        """
        return 1

    def __next__(self):
        """
        Return the data fetched at construction time.

        This method never raises ``StopIteration``; every call returns the
        same object.

        :return: The stored data.
        """
        return self.dataset
|
||||
|
||||
|
||||
class Collator:
    """
    Callable used to collate the data points fetched from the dataset.

    Two independent choices are made at construction time: how a batch is
    assembled (automatic PyTorch batching or index-based fetching) and how
    single items are merged (tensor dataset or graph dataset).
    """

    def __init__(
        self, max_conditions_lengths, automatic_batching, dataset=None
    ):
        """
        Select the batching strategy and the item-level collate function.

        :param dict max_conditions_lengths: ``dict`` containing the maximum
            number of data points to consider in a single batch for
            each condition.
        :param bool automatic_batching: Whether automatic PyTorch batching is
            enabled or not. For more information, see the
            :class:`~pina.data.data_module.PinaDataModule` class.
        :param PinaDataset dataset: The dataset where the data is stored.
        """
        self.max_conditions_lengths = max_conditions_lengths
        self.dataset = dataset

        # With automatic batching the loader hands over a list of samples;
        # without it, a list of indices that is resolved by the dataset.
        if automatic_batching:
            self.callable_function = self._collate_torch_dataloader
        else:
            self.callable_function = self._collate_pina_dataloader

        # Item-level collation depends on the kind of dataset: anything
        # that is not a PinaTensorDataset goes through the graph collation.
        if isinstance(self.dataset, PinaTensorDataset):
            self._collate = self._collate_tensor_dataset
        else:
            self._collate = self._collate_graph_dataset

    def _collate_pina_dataloader(self, batch):
        """
        Build a batch when automatic batching is disabled.

        :param list[int] batch: List of integers representing the indices of
            the data points to be fetched.
        :return: Whatever ``dataset.fetch_from_idx_list`` returns for the
            given indices.
        :rtype: dict
        """
        fetched = self.dataset.fetch_from_idx_list(batch)
        return fetched

    def _collate_torch_dataloader(self, batch):
        """
        Build a batch when automatic batching is enabled.

        A ``dict`` input is returned untouched. Otherwise the condition names
        and the argument names are read from the first sample, and for each
        pair at most ``max_conditions_lengths[condition]`` leading samples
        are collated.

        :param list[dict] batch: List of retrieved data.
        :return: Dictionary containing the data points fetched from the
            dataset, collated.
        :rtype: dict
        """
        if isinstance(batch, dict):
            return batch

        collated = {}
        reference = batch[0]
        for condition_name in reference.keys():
            collated_condition = {}
            for arg in reference[condition_name].keys():
                # Never take more samples than allowed for this condition
                n_items = min(
                    len(batch), self.max_conditions_lengths[condition_name]
                )
                items = [
                    batch[position][condition_name][arg]
                    for position in range(n_items)
                ]
                collated_condition[arg] = self._collate(items)
            collated[condition_name] = collated_condition
        return collated

    @staticmethod
    def _collate_tensor_dataset(data_list):
        """
        Collate items coming from a
        :class:`~pina.data.dataset.PinaTensorDataset` by stacking them.

        :param data_list: Elements to be collated.
        :type data_list: list[torch.Tensor] | list[LabelTensor]
        :return: The stacked data.
        :rtype: torch.Tensor | LabelTensor

        :raises RuntimeError: If the data is not a :class:`torch.Tensor` or a
            :class:`~pina.label_tensor.LabelTensor`.
        """
        head = data_list[0]
        # LabelTensor is tested first, as in the original ordering
        if isinstance(head, LabelTensor):
            return LabelTensor.stack(data_list)
        if isinstance(head, torch.Tensor):
            return torch.stack(data_list)
        raise RuntimeError("Data must be Tensors or LabelTensor ")

    def _collate_graph_dataset(self, data_list):
        """
        Collate items when the dataset is not a
        :class:`~pina.data.dataset.PinaTensorDataset`.

        Tensors are concatenated; :class:`~torch_geometric.data.Data`
        objects are batched through ``dataset.create_batch``.

        :param data_list: Elements to be collated.
        :type data_list: list[torch.Tensor] | list[LabelTensor] | list[Data]
        :return: The collated data.

        :raises RuntimeError: If the data is not a :class:`torch.Tensor`, a
            :class:`~pina.label_tensor.LabelTensor` or a
            :class:`~torch_geometric.data.Data`.
        """
        head = data_list[0]
        # LabelTensor is tested first, as in the original ordering
        if isinstance(head, LabelTensor):
            return LabelTensor.cat(data_list)
        if isinstance(head, torch.Tensor):
            return torch.cat(data_list)
        if isinstance(head, Data):
            return self.dataset.create_batch(data_list)
        raise RuntimeError(
            "Data must be Tensors or LabelTensor or pyG "
            "torch_geometric.data.Data"
        )

    def __call__(self, batch):
        """
        Collate the data fetched from the dataset, using the batching
        strategy selected during initialization.

        :param batch: List of retrieved data or sampled indices.
        :type batch: list[int] | list[dict]
        :return: Dictionary containing collated data fetched from the dataset.
        :rtype: dict
        """
        return self.callable_function(batch)
|
||||
|
||||
|
||||
class PinaSampler:
    """
    Factory for the sampler instance, chosen from the environment in which
    the code is running.
    """

    def __new__(cls, dataset):
        """
        Instantiate the sampler.

        A :class:`~torch.utils.data.distributed.DistributedSampler` is
        returned when ``torch.distributed`` is available and initialized,
        a :class:`~torch.utils.data.SequentialSampler` otherwise.

        :param PinaDataset dataset: The dataset from which to sample.
        :return: The sampler instance.
        :rtype: :class:`torch.utils.data.Sampler`
        """
        if (
            torch.distributed.is_available()
            and torch.distributed.is_initialized()
        ):
            return DistributedSampler(dataset)
        return SequentialSampler(dataset)
|
||||
|
||||
|
||||
class PinaDataModule(LightningDataModule):
|
||||
@@ -27,7 +250,7 @@ class PinaDataModule(LightningDataModule):
|
||||
val_size=0.1,
|
||||
batch_size=None,
|
||||
shuffle=True,
|
||||
batching_mode="common_batch_size",
|
||||
repeat=False,
|
||||
automatic_batching=None,
|
||||
num_workers=0,
|
||||
pin_memory=False,
|
||||
@@ -48,12 +271,11 @@ class PinaDataModule(LightningDataModule):
|
||||
Default is ``None``.
|
||||
:param bool shuffle: Whether to shuffle the dataset before splitting.
|
||||
Default ``True``.
|
||||
:param bool common_batch_size: If ``True``, the same batch size is used
|
||||
for all conditions. If ``False``, each condition can have its own
|
||||
batch size, proportional to the size of the dataset in that
|
||||
condition. Default is ``True``.
|
||||
:param bool separate_conditions: If ``True``, dataloaders for each
|
||||
condition are iterated separately. Default is ``False``.
|
||||
:param bool repeat: If ``True``, in case of batch size larger than the
|
||||
number of elements in a specific condition, the elements are
|
||||
repeated until the batch size is reached. If ``False``, the number
|
||||
of elements in the batch is the minimum between the batch size and
|
||||
the number of elements in the condition. Default is ``False``.
|
||||
:param automatic_batching: If ``True``, automatic PyTorch batching
|
||||
is performed, which consists of extracting one element at a time
|
||||
from the dataset and collating them into a batch. This is useful
|
||||
@@ -83,7 +305,7 @@ class PinaDataModule(LightningDataModule):
|
||||
# Store fixed attributes
|
||||
self.batch_size = batch_size
|
||||
self.shuffle = shuffle
|
||||
self.batching_mode = batching_mode
|
||||
self.repeat = repeat
|
||||
self.automatic_batching = automatic_batching
|
||||
|
||||
# If batch size is None, num_workers has no effect
|
||||
@@ -154,16 +376,23 @@ class PinaDataModule(LightningDataModule):
|
||||
if stage == "fit" or stage is None:
|
||||
self.train_dataset = PinaDatasetFactory(
|
||||
self.data_splits["train"],
|
||||
max_conditions_lengths=self.find_max_conditions_lengths(
|
||||
"train"
|
||||
),
|
||||
automatic_batching=self.automatic_batching,
|
||||
)
|
||||
if "val" in self.data_splits.keys():
|
||||
self.val_dataset = PinaDatasetFactory(
|
||||
self.data_splits["val"],
|
||||
max_conditions_lengths=self.find_max_conditions_lengths(
|
||||
"val"
|
||||
),
|
||||
automatic_batching=self.automatic_batching,
|
||||
)
|
||||
elif stage == "test":
|
||||
self.test_dataset = PinaDatasetFactory(
|
||||
self.data_splits["test"],
|
||||
max_conditions_lengths=self.find_max_conditions_lengths("test"),
|
||||
automatic_batching=self.automatic_batching,
|
||||
)
|
||||
else:
|
||||
@@ -253,7 +482,7 @@ class PinaDataModule(LightningDataModule):
|
||||
dataset_dict[key].update({condition_name: data})
|
||||
return dataset_dict
|
||||
|
||||
def _create_dataloader(self, dataset):
|
||||
def _create_dataloader(self, split, dataset):
|
||||
""" "
|
||||
Create the dataloader for the given split.
|
||||
|
||||
@@ -273,18 +502,53 @@ class PinaDataModule(LightningDataModule):
|
||||
),
|
||||
module="lightning.pytorch.trainer.connectors.data_connector",
|
||||
)
|
||||
dl = PinaDataLoader(
|
||||
dataset,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=self.shuffle,
|
||||
num_workers=self.num_workers,
|
||||
batching_mode=self.batching_mode,
|
||||
device=self.trainer.strategy.root_device,
|
||||
# Use custom batching (good if batch size is large)
|
||||
if self.batch_size is not None:
|
||||
sampler = PinaSampler(dataset)
|
||||
if self.automatic_batching:
|
||||
collate = Collator(
|
||||
self.find_max_conditions_lengths(split),
|
||||
self.automatic_batching,
|
||||
dataset=dataset,
|
||||
)
|
||||
else:
|
||||
collate = Collator(
|
||||
None, self.automatic_batching, dataset=dataset
|
||||
)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
self.batch_size,
|
||||
collate_fn=collate,
|
||||
sampler=sampler,
|
||||
num_workers=self.num_workers,
|
||||
)
|
||||
dataloader = DummyDataloader(dataset)
|
||||
dataloader.dataset = self._transfer_batch_to_device(
|
||||
dataloader.dataset, self.trainer.strategy.root_device, 0
|
||||
)
|
||||
if self.batch_size is None:
|
||||
# Override the method to transfer the batch to the device
|
||||
self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
|
||||
return dl
|
||||
self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
|
||||
return dataloader
|
||||
|
||||
def find_max_conditions_lengths(self, split):
    """
    Compute the maximum number of points per batch for each condition.

    For every condition in ``self.data_splits[split]`` the value is:

    - the number of ``"input"`` points, when ``self.batch_size`` is ``None``;
    - ``self.batch_size``, when ``self.repeat`` is true;
    - the minimum of the two otherwise.

    :param split: Key of ``self.data_splits`` selecting the split.
    :return: The maximum length per condition.
    :rtype: dict
    """
    lengths = {}
    for name, condition_data in self.data_splits[split].items():
        if self.batch_size is None:
            lengths[name] = len(condition_data["input"])
            continue
        if self.repeat:
            lengths[name] = self.batch_size
            continue
        lengths[name] = min(len(condition_data["input"]), self.batch_size)
    return lengths
|
||||
|
||||
def val_dataloader(self):
|
||||
"""
|
||||
@@ -293,7 +557,7 @@ class PinaDataModule(LightningDataModule):
|
||||
:return: The validation dataloader
|
||||
:rtype: torch.utils.data.DataLoader
|
||||
"""
|
||||
return self._create_dataloader(self.val_dataset)
|
||||
return self._create_dataloader("val", self.val_dataset)
|
||||
|
||||
def train_dataloader(self):
|
||||
"""
|
||||
@@ -302,7 +566,7 @@ class PinaDataModule(LightningDataModule):
|
||||
:return: The training dataloader
|
||||
:rtype: torch.utils.data.DataLoader
|
||||
"""
|
||||
return self._create_dataloader(self.train_dataset)
|
||||
return self._create_dataloader("train", self.train_dataset)
|
||||
|
||||
def test_dataloader(self):
|
||||
"""
|
||||
@@ -311,7 +575,7 @@ class PinaDataModule(LightningDataModule):
|
||||
:return: The testing dataloader
|
||||
:rtype: torch.utils.data.DataLoader
|
||||
"""
|
||||
return self._create_dataloader(self.test_dataset)
|
||||
return self._create_dataloader("test", self.test_dataset)
|
||||
|
||||
@staticmethod
|
||||
def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
|
||||
@@ -327,7 +591,7 @@ class PinaDataModule(LightningDataModule):
|
||||
:rtype: list[tuple]
|
||||
"""
|
||||
|
||||
return list(batch.items())
|
||||
return batch
|
||||
|
||||
def _transfer_batch_to_device(self, batch, device, dataloader_idx):
|
||||
"""
|
||||
@@ -385,15 +649,9 @@ class PinaDataModule(LightningDataModule):
|
||||
|
||||
to_return = {}
|
||||
if hasattr(self, "train_dataset") and self.train_dataset is not None:
|
||||
to_return["train"] = {
|
||||
cond: data.input for cond, data in self.train_dataset.items()
|
||||
}
|
||||
to_return["train"] = self.train_dataset.input
|
||||
if hasattr(self, "val_dataset") and self.val_dataset is not None:
|
||||
to_return["val"] = {
|
||||
cond: data.input for cond, data in self.val_dataset.items()
|
||||
}
|
||||
to_return["val"] = self.val_dataset.input
|
||||
if hasattr(self, "test_dataset") and self.test_dataset is not None:
|
||||
to_return["test"] = {
|
||||
cond: data.input for cond, data in self.test_dataset.items()
|
||||
}
|
||||
to_return["test"] = self.test_dataset.input
|
||||
return to_return
|
||||
|
||||
@@ -1,347 +0,0 @@
|
||||
"""DataLoader module for PinaDataset."""
|
||||
|
||||
import itertools
|
||||
import random
|
||||
from functools import partial
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data.sampler import SequentialSampler
|
||||
from .stacked_dataloader import StackedDataLoader
|
||||
|
||||
|
||||
class DummyDataloader:
    """
    DataLoader that returns the data fetched at construction time in a
    single batch.
    """

    def __init__(self, dataset, device=None):
        """
        Fetch the data that this loader will serve.

        - **Distributed environment** (``PinaSampler.is_distributed()`` is
          true): only the indices ``rank, rank + world_size, ...`` are
          fetched.
        - **Non-distributed environment**: all indices are fetched.

        When ``device`` is truthy, every value of the fetched dictionary is
        moved with ``.to(device)``.

        :param PinaDataset dataset: The dataset object to be processed.
        :param device: The device to which the data should be moved.
        :raises RuntimeError: If, in a distributed environment, the dataset
            has fewer elements than the world size.

        .. note::
            This dataloader is used when the batch size is ``None``.
        """
        if PinaSampler.is_distributed():
            rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()
            if len(dataset) < world_size:
                raise RuntimeError(
                    "Dimension of the dataset smaller than world size."
                    " Increase the size of the partition or use a single GPU"
                )
            # Strided partition of the indices among the processes
            indices = list(range(rank, len(dataset), world_size))
        else:
            indices = list(range(len(dataset)))

        fetched = dataset.getitem_from_list(indices)
        self.device = device
        if self.device:
            fetched = {
                name: value.to(self.device) for name, value in fetched.items()
            }
        self.dataset = fetched

    def __iter__(self):
        """
        Return the loader itself as the iterator.
        """
        return self

    def __len__(self):
        """
        Return the length of the dataloader, which is always 1.

        :return: The length of the dataloader.
        :rtype: int
        """
        return 1

    def __next__(self):
        """
        Return the stored data as a single batch.

        :return: The data fetched at construction time.
        :rtype: dict
        """
        return self.dataset
|
||||
|
||||
|
||||
class PinaSampler:
    """
    Factory for the sampler instance, chosen from the ``shuffle`` parameter
    and from the environment in which the code is running.
    """

    def __new__(cls, dataset, shuffle=True):
        """
        Instantiate the sampler.

        :param PinaDataset dataset: The dataset from which to sample.
        :param bool shuffle: Whether the sampler should shuffle the data.
            Default is ``True``.
        :return: A distributed sampler in a distributed environment;
            otherwise a random sampler if ``shuffle`` is true, and a
            sequential sampler if it is not.
        :rtype: :class:`torch.utils.data.Sampler`
        """
        if cls.is_distributed():
            return DistributedSampler(dataset, shuffle=shuffle)
        if shuffle:
            return torch.utils.data.RandomSampler(dataset)
        return SequentialSampler(dataset)

    @staticmethod
    def is_distributed():
        """
        Check whether ``torch.distributed`` is available and initialized.

        :return: True if the environment is distributed, False otherwise.
        :rtype: bool
        """
        if not torch.distributed.is_available():
            return False
        return torch.distributed.is_initialized()
|
||||
|
||||
|
||||
def _collect_items(batch):
    """
    Group the values of a list of dictionaries by key.

    The keys are taken from the first sample; the values of every sample
    are then appended, in order, to the list of the matching key.

    :param batch: List of samples, each one a dictionary.
    :return: Dictionary mapping each key to the list of its values.
    :rtype: dict
    """
    grouped = {}
    for key in batch[0].keys():
        grouped[key] = []
    for sample in batch:
        for key, value in sample.items():
            grouped[key].append(value)
    return grouped
|
||||
|
||||
|
||||
def collate_fn_custom(batch, dataset):
    """
    Collate function for datasets without automatic batching: the sampled
    indices are resolved by the dataset itself.

    :param batch: List of indices from the dataset.
    :param dataset: The PinaDataset instance (must be provided).
    :return: Whatever ``dataset.getitem_from_list`` returns for ``batch``.
    """
    fetched = dataset.getitem_from_list(batch)
    return fetched
|
||||
|
||||
|
||||
def collate_fn_default(batch, stack_fn):
    """
    Collate function used with automatic batching: the samples are grouped
    by key and each group is merged with the function registered for that
    key.

    :param batch: List of data samples.
    :param stack_fn: Mapping from each key to the callable that merges the
        list of values collected for that key.
    :return: Dictionary mapping each key to its merged values.
    :rtype: dict
    """
    grouped = _collect_items(batch)
    collated = {}
    for key, values in grouped.items():
        collated[key] = stack_fn[key](values)
    return collated
|
||||
|
||||
|
||||
class PinaDataLoader:
|
||||
"""
|
||||
Custom DataLoader for PinaDataset.
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
batching_mode = kwargs.get("batching_mode", "common_batch_size").lower()
|
||||
batch_size = kwargs.get("batch_size")
|
||||
if batching_mode == "stacked" and batch_size is not None:
|
||||
return StackedDataLoader(
|
||||
args[0],
|
||||
batch_size=batch_size,
|
||||
shuffle=kwargs.get("shuffle", True),
|
||||
)
|
||||
elif batch_size is None:
|
||||
kwargs["batching_mode"] = "proportional"
|
||||
print(
|
||||
"Using PinaDataLoader with batching mode:", kwargs["batching_mode"]
|
||||
)
|
||||
return super(PinaDataLoader, cls).__new__(cls)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_dict,
|
||||
batch_size,
|
||||
num_workers=0,
|
||||
shuffle=False,
|
||||
batching_mode="common_batch_size",
|
||||
device=None,
|
||||
):
|
||||
"""
|
||||
Initialize the PinaDataLoader.
|
||||
|
||||
:param dict dataset_dict: A dictionary mapping dataset names to their
|
||||
respective PinaDataset instances.
|
||||
:param int batch_size: The batch size for the dataloader.
|
||||
:param int num_workers: Number of worker processes for data loading.
|
||||
:param bool shuffle: Whether to shuffle the data at every epoch.
|
||||
:param str batching_mode: The batching mode to use. Options are
|
||||
"common_batch_size", "separate_conditions", and "proportional".
|
||||
:param device: The device to which the data should be moved.
|
||||
"""
|
||||
|
||||
self.dataset_dict = dataset_dict
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
self.shuffle = shuffle
|
||||
self.batching_mode = batching_mode.lower()
|
||||
self.device = device
|
||||
|
||||
# Batch size None means we want to load the entire dataset in a single
|
||||
# batch
|
||||
if batch_size is None:
|
||||
batch_size_per_dataset = {
|
||||
split: None for split in dataset_dict.keys()
|
||||
}
|
||||
else:
|
||||
# Compute batch size per dataset
|
||||
if batching_mode in ["common_batch_size", "separate_conditions"]:
|
||||
# (the sum of the batch sizes is equal to
|
||||
# n_conditions * batch_size)
|
||||
batch_size_per_dataset = {
|
||||
split: min(batch_size, len(ds))
|
||||
for split, ds in dataset_dict.items()
|
||||
}
|
||||
elif batching_mode == "proportional":
|
||||
# batch sizes is equal to the specified batch size)
|
||||
batch_size_per_dataset = self._compute_batch_size()
|
||||
|
||||
# Creaete a dataloader per dataset
|
||||
self.dataloaders = {
|
||||
split: self._create_dataloader(
|
||||
dataset, batch_size_per_dataset[split]
|
||||
)
|
||||
for split, dataset in dataset_dict.items()
|
||||
}
|
||||
|
||||
def _compute_batch_size(self):
|
||||
"""
|
||||
Compute an appropriate batch size for the given dataset.
|
||||
"""
|
||||
|
||||
# Compute number of elements per dataset
|
||||
elements_per_dataset = {
|
||||
dataset_name: len(dataset)
|
||||
for dataset_name, dataset in self.dataset_dict.items()
|
||||
}
|
||||
# Compute the total number of elements
|
||||
total_elements = sum(el for el in elements_per_dataset.values())
|
||||
# Compute the portion of each dataset
|
||||
portion_per_dataset = {
|
||||
name: el / total_elements
|
||||
for name, el in elements_per_dataset.items()
|
||||
}
|
||||
# Compute batch size per dataset. Ensure at least 1 element per
|
||||
# dataset.
|
||||
batch_size_per_dataset = {
|
||||
name: max(1, int(portion * self.batch_size))
|
||||
for name, portion in portion_per_dataset.items()
|
||||
}
|
||||
# Adjust batch sizes to match the specified total batch size
|
||||
tot_el_per_batch = sum(el for el in batch_size_per_dataset.values())
|
||||
if self.batch_size > tot_el_per_batch:
|
||||
difference = self.batch_size - tot_el_per_batch
|
||||
while difference > 0:
|
||||
for k, v in batch_size_per_dataset.items():
|
||||
if difference == 0:
|
||||
break
|
||||
if v > 1:
|
||||
batch_size_per_dataset[k] += 1
|
||||
difference -= 1
|
||||
if self.batch_size < tot_el_per_batch:
|
||||
difference = tot_el_per_batch - self.batch_size
|
||||
while difference > 0:
|
||||
for k, v in batch_size_per_dataset.items():
|
||||
if difference == 0:
|
||||
break
|
||||
if v > 1:
|
||||
batch_size_per_dataset[k] -= 1
|
||||
difference -= 1
|
||||
return batch_size_per_dataset
|
||||
|
||||
def _create_dataloader(self, dataset, batch_size):
    """
    Build the dataloader serving the given dataset.

    :param PinaDataset dataset: The dataset for which to create the
        dataloader.
    :param int batch_size: The batch size for the dataloader. If ``None``,
        or not smaller than the dataset length, a ``DummyDataloader`` is
        returned instead of a batched loader.
    :return: The created dataloader.
    :rtype: :class:`torch.utils.data.DataLoader`
    """
    # No batch size, or a batch at least as large as the dataset: no
    # batching is required, delegate to DummyDataloader
    needs_no_batching = batch_size is None or batch_size >= len(dataset)
    if needs_no_batching:
        return DummyDataloader(dataset, device=self.device)

    # Pick the collate function matching the dataset batching strategy
    if dataset.automatic_batching:
        collate_fn = partial(collate_fn_default, stack_fn=dataset.stack_fn)
    else:
        collate_fn = partial(collate_fn_custom, dataset=dataset)

    # Assemble the loader options, then build the dataloader
    loader_options = {
        "batch_size": batch_size,
        "collate_fn": collate_fn,
        "num_workers": self.num_workers,
        "sampler": PinaSampler(dataset, shuffle=self.shuffle),
    }
    return DataLoader(dataset, **loader_options)
|
||||
|
||||
def __len__(self):
    """
    Return the number of batches produced in one pass over the dataloader.

    In ``"separate_conditions"`` mode this is the sum of the lengths of all
    the underlying dataloaders; in every other mode it is the length of the
    longest one.

    :return: The length of the dataloader.
    :rtype: int
    """
    lengths = [len(loader) for loader in self.dataloaders.values()]
    if self.batching_mode == "separate_conditions":
        return sum(lengths)
    return max(lengths)
|
||||
|
||||
def __iter__(self):
    """
    Iterate over the dataloader, yielding dictionaries that map split names
    to batches.

    In ``"separate_conditions"`` mode, the batches of every dataloader are
    first collected in a list (at most ``len(dl)`` batches per dataloader,
    and at least one if the dataloader yields any), the list is shuffled,
    and each batch is then yielded in its own single-entry dictionary.

    In every other mode, ``len(self)`` dictionaries are yielded, each one
    holding one batch per split. The dataloaders are wrapped in
    :func:`itertools.cycle`, so the shorter ones start over once exhausted.
    """
    if self.batching_mode != "separate_conditions":
        # Common_batch_size or Proportional mode: endlessly cycle each
        # dataloader so that every split provides a batch at every step
        cyclers = {
            name: itertools.cycle(loader)
            for name, loader in self.dataloaders.items()
        }
        for _ in range(len(self)):
            batch_dict: BatchDict = {
                name: next(cycler) for name, cycler in cyclers.items()
            }
            yield batch_dict
        return

    # Separate conditions: gather every batch, tagged with its split
    collected = []
    for name, loader in self.dataloaders.items():
        n_batches = len(loader)
        for position, batch in enumerate(loader, start=1):
            collected.append({name: batch})
            # Never take more batches than the loader declares
            if position >= n_batches:
                break
    # Mix the batches of the different splits before yielding them
    random.shuffle(collected)
    yield from collected
|
||||
@@ -1,170 +1,326 @@
|
||||
"""Module for the PINA dataset classes."""
|
||||
|
||||
import torch
|
||||
from abc import abstractmethod, ABC
|
||||
from torch.utils.data import Dataset
|
||||
from torch_geometric.data import Data
|
||||
from ..graph import Graph, LabelBatch
|
||||
from ..label_tensor import LabelTensor
|
||||
|
||||
STACK_FN_MAP = {
|
||||
"label_tensor": LabelTensor.stack,
|
||||
"tensor": torch.stack,
|
||||
"data": LabelBatch.from_data_list,
|
||||
}
|
||||
|
||||
|
||||
class PinaDatasetFactory:
|
||||
"""
|
||||
Factory class to create PINA datasets based on the provided conditions
|
||||
dictionary.
|
||||
Factory class for the PINA dataset.
|
||||
|
||||
Depending on the data type inside the conditions, it instantiates an object
|
||||
belonging to the appropriate subclass of
|
||||
:class:`~pina.data.dataset.PinaDataset`. The possible subclasses are:
|
||||
|
||||
- :class:`~pina.data.dataset.PinaTensorDataset`, for handling \
|
||||
:class:`torch.Tensor` and :class:`~pina.label_tensor.LabelTensor` data.
|
||||
- :class:`~pina.data.dataset.PinaGraphDataset`, for handling \
|
||||
:class:`~pina.graph.Graph` and :class:`~torch_geometric.data.Data` data.
|
||||
"""
|
||||
|
||||
def __new__(cls, conditions_dict, **kwargs):
|
||||
"""
|
||||
Create PINA dataset instances based on the provided conditions
|
||||
dictionary.
|
||||
Instantiate the appropriate subclass of
|
||||
:class:`~pina.data.dataset.PinaDataset`.
|
||||
|
||||
:param dict conditions_dict: A dictionary where keys are condition names
|
||||
and values are dictionaries containing the associated data.
|
||||
:return: A dictionary mapping condition names to their respective
|
||||
:class:`PinaDataset` instances.
|
||||
If a graph is present in the conditions, returns a
|
||||
:class:`~pina.data.dataset.PinaGraphDataset`, otherwise returns a
|
||||
:class:`~pina.data.dataset.PinaTensorDataset`.
|
||||
|
||||
:param dict conditions_dict: Dictionary containing all the conditions
|
||||
to be included in the dataset instance.
|
||||
:return: A subclass of :class:`~pina.data.dataset.PinaDataset`.
|
||||
:rtype: PinaTensorDataset | PinaGraphDataset
|
||||
|
||||
:raises ValueError: If an empty dictionary is provided.
|
||||
"""
|
||||
|
||||
# Check if conditions_dict is empty
|
||||
if len(conditions_dict) == 0:
|
||||
raise ValueError("No conditions provided")
|
||||
|
||||
dataset_dict = {} # Dictionary to hold the created datasets
|
||||
|
||||
# Check if a Graph is present in the conditions
|
||||
for name, data in conditions_dict.items():
|
||||
# Validate that data is a dictionary
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(
|
||||
f"Condition '{name}' data must be a dictionary"
|
||||
)
|
||||
# Create PinaDataset instance for each condition
|
||||
dataset_dict[name] = PinaDataset(data, **kwargs)
|
||||
return dataset_dict
|
||||
is_graph = cls._is_graph_dataset(conditions_dict)
|
||||
if is_graph:
|
||||
# If a Graph is present, return a PinaGraphDataset
|
||||
return PinaGraphDataset(conditions_dict, **kwargs)
|
||||
# If no Graph is present, return a PinaTensorDataset
|
||||
return PinaTensorDataset(conditions_dict, **kwargs)
|
||||
|
||||
|
||||
class PinaDataset(Dataset):
|
||||
"""
|
||||
Dataset class for the PINA dataset with :class:`torch.Tensor` and
|
||||
:class:`~pina.label_tensor.LabelTensor` data.
|
||||
"""
|
||||
|
||||
def __init__(self, data_dict, automatic_batching=None):
|
||||
@staticmethod
|
||||
def _is_graph_dataset(conditions_dict):
|
||||
"""
|
||||
Initialize the instance by storing the conditions dictionary.
|
||||
Check if a graph is present in the conditions (at least one time).
|
||||
|
||||
:param conditions_dict: Dictionary containing the conditions.
|
||||
:type conditions_dict: dict
|
||||
:return: True if a graph is present in the conditions, False otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
|
||||
# Iterate over the conditions dictionary
|
||||
for v in conditions_dict.values():
|
||||
# Iterate over the values of the current condition
|
||||
for cond in v.values():
|
||||
# Check if the current value is a list of Data objects
|
||||
if isinstance(cond, (Data, Graph, list, tuple)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class PinaDataset(Dataset, ABC):
|
||||
"""
|
||||
Abstract class for the PINA dataset which extends the PyTorch
|
||||
:class:`~torch.utils.data.Dataset` class. It defines the common interface
|
||||
for :class:`~pina.data.dataset.PinaTensorDataset` and
|
||||
:class:`~pina.data.dataset.PinaGraphDataset` classes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, conditions_dict, max_conditions_lengths, automatic_batching
|
||||
):
|
||||
"""
|
||||
Initialize the instance by storing the conditions dictionary, the
|
||||
maximum number of items per conditions to consider, and the automatic
|
||||
batching flag.
|
||||
|
||||
:param dict conditions_dict: A dictionary mapping condition names to
|
||||
their respective data. Each key represents a condition name, and the
|
||||
corresponding value is a dictionary containing the associated data.
|
||||
:param dict max_conditions_lengths: Maximum number of data points that
|
||||
can be included in a single batch per condition.
|
||||
:param bool automatic_batching: Indicates whether PyTorch automatic
|
||||
batching is enabled in
|
||||
:class:`~pina.data.data_module.PinaDataModule`.
|
||||
"""
|
||||
|
||||
# Store the conditions dictionary
|
||||
self.data = data_dict
|
||||
self.automatic_batching = (
|
||||
automatic_batching if automatic_batching is not None else True
|
||||
)
|
||||
self._stack_fn = {}
|
||||
self.is_graph_dataset = False
|
||||
# Determine stacking functions for each data type (used in collate_fn)
|
||||
for k, v in data_dict.items():
|
||||
if isinstance(v, LabelTensor):
|
||||
self._stack_fn[k] = "label_tensor"
|
||||
elif isinstance(v, torch.Tensor):
|
||||
self._stack_fn[k] = "tensor"
|
||||
elif isinstance(v, list) and all(
|
||||
isinstance(item, (Data, Graph)) for item in v
|
||||
):
|
||||
self._stack_fn[k] = "data"
|
||||
self.is_graph_dataset = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported data type for stacking: {type(v)}"
|
||||
)
|
||||
self.conditions_dict = conditions_dict
|
||||
# Store the maximum number of conditions to consider
|
||||
self.max_conditions_lengths = max_conditions_lengths
|
||||
# Store length of each condition
|
||||
self.conditions_length = {
|
||||
k: len(v["input"]) for k, v in self.conditions_dict.items()
|
||||
}
|
||||
# Store the maximum length of the dataset
|
||||
self.length = max(self.conditions_length.values())
|
||||
# Dynamically set the getitem function based on automatic batching
|
||||
if automatic_batching:
|
||||
self._getitem_func = self._getitem_int
|
||||
else:
|
||||
self._getitem_func = self._getitem_dummy
|
||||
|
||||
def __len__(self):
|
||||
def _get_max_len(self):
|
||||
"""
|
||||
Return the length of the dataset.
|
||||
Returns the length of the longest condition in the dataset.
|
||||
|
||||
:return: The length of the dataset.
|
||||
:return: Length of the longest condition in the dataset.
|
||||
:rtype: int
|
||||
"""
|
||||
return len(next(iter(self.data.values())))
|
||||
|
||||
max_len = 0
|
||||
for condition in self.conditions_dict.values():
|
||||
max_len = max(max_len, len(condition["input"]))
|
||||
return max_len
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self._getitem_func(idx)
|
||||
|
||||
def _getitem_dummy(self, idx):
|
||||
"""
|
||||
Return the data at the given index in the dataset.
|
||||
Return the index itself. This is used when automatic batching is
|
||||
disabled to postpone the data retrieval to the dataloader.
|
||||
|
||||
:param int idx: Index.
|
||||
:return: Index.
|
||||
:rtype: int
|
||||
"""
|
||||
|
||||
# If automatic batching is disabled, return the data at the given index
|
||||
return idx
|
||||
|
||||
def _getitem_int(self, idx):
|
||||
"""
|
||||
Return the data at the given index in the dataset. This is used when
|
||||
automatic batching is enabled.
|
||||
|
||||
:param int idx: Index.
|
||||
:return: A dictionary containing the data at the given index.
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
if self.automatic_batching:
|
||||
# Return the data at the given index
|
||||
return {
|
||||
field_name: data[idx] for field_name, data in self.data.items()
|
||||
}
|
||||
return idx
|
||||
# If automatic batching is enabled, return the data at the given index
|
||||
return {
|
||||
k: {k_data: v[k_data][idx % len(v["input"])] for k_data in v.keys()}
|
||||
for k, v in self.conditions_dict.items()
|
||||
}
|
||||
|
||||
def getitem_from_list(self, idx_list):
|
||||
def get_all_data(self):
|
||||
"""
|
||||
Return all data in the dataset.
|
||||
|
||||
:return: A dictionary containing all the data in the dataset.
|
||||
:rtype: dict
|
||||
"""
|
||||
to_return_dict = {}
|
||||
for condition, data in self.conditions_dict.items():
|
||||
len_condition = len(
|
||||
data["input"]
|
||||
) # Length of the current condition
|
||||
to_return_dict[condition] = self._retrive_data(
|
||||
data, list(range(len_condition))
|
||||
) # Retrieve the data from the current condition
|
||||
return to_return_dict
|
||||
|
||||
def fetch_from_idx_list(self, idx):
|
||||
"""
|
||||
Return data from the dataset given a list of indices.
|
||||
|
||||
:param list[int] idx_list: List of indices.
|
||||
:param list[int] idx: List of indices.
|
||||
:return: A dictionary containing the data at the given indices.
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
to_return = {}
|
||||
for field_name, data in self.data.items():
|
||||
if self._stack_fn[field_name] == "data":
|
||||
fn = STACK_FN_MAP[self._stack_fn[field_name]]
|
||||
to_return[field_name] = fn([data[i] for i in idx_list])
|
||||
else:
|
||||
to_return[field_name] = data[idx_list]
|
||||
return to_return
|
||||
to_return_dict = {}
|
||||
for condition, data in self.conditions_dict.items():
|
||||
# Get the indices for the current condition
|
||||
cond_idx = idx[: self.max_conditions_lengths[condition]]
|
||||
# Get the length of the current condition
|
||||
condition_len = self.conditions_length[condition]
|
||||
# If the length of the dataset is greater than the length of the
|
||||
# current condition, repeat the indices
|
||||
if self.length > condition_len:
|
||||
cond_idx = [idx % condition_len for idx in cond_idx]
|
||||
# Retrieve the data from the current condition
|
||||
to_return_dict[condition] = self._retrive_data(data, cond_idx)
|
||||
return to_return_dict
|
||||
|
||||
def update_data(self, update_dict):
|
||||
@abstractmethod
|
||||
def _retrive_data(self, data, idx_list):
|
||||
"""
|
||||
Abstract method to retrieve data from the dataset given a list of
|
||||
indices.
|
||||
"""
|
||||
Update the dataset's data in-place.
|
||||
|
||||
:param dict update_dict: A dictionary where keys are condition names
|
||||
and values are dictionaries with updated data for those conditions.
|
||||
|
||||
class PinaTensorDataset(PinaDataset):
|
||||
"""
|
||||
Dataset class for the PINA dataset with :class:`torch.Tensor` and
|
||||
:class:`~pina.label_tensor.LabelTensor` data.
|
||||
"""
|
||||
|
||||
# Override _retrive_data method for torch.Tensor data
|
||||
def _retrive_data(self, data, idx_list):
|
||||
"""
|
||||
for field_name, updates in update_dict.items():
|
||||
if field_name not in self.data:
|
||||
raise KeyError(
|
||||
f"Condition '{field_name}' not found in dataset."
|
||||
)
|
||||
if not isinstance(updates, (LabelTensor, torch.Tensor)):
|
||||
raise ValueError(
|
||||
f"Updates for condition '{field_name}' must be of type "
|
||||
f"LabelTensor or torch.Tensor."
|
||||
)
|
||||
self.data[field_name] = updates
|
||||
Retrieve data from the dataset given a list of indices.
|
||||
|
||||
:param dict data: Dictionary containing the data
|
||||
(only :class:`torch.Tensor` or
|
||||
:class:`~pina.label_tensor.LabelTensor`).
|
||||
:param list[int] idx_list: indices to retrieve.
|
||||
:return: Dictionary containing the data at the given indices.
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
return {k: v[idx_list] for k, v in data.items()}
|
||||
|
||||
@property
|
||||
def input(self):
|
||||
"""
|
||||
Get the input data from the dataset.
|
||||
Return the input data for the dataset.
|
||||
|
||||
:return: The input data.
|
||||
:rtype: torch.Tensor | LabelTensor | Data | Graph
|
||||
"""
|
||||
return self.data["input"]
|
||||
|
||||
@property
|
||||
def stack_fn(self):
|
||||
"""
|
||||
Get the mapping of stacking functions for each data type in the dataset.
|
||||
|
||||
:return: A dictionary mapping condition names to their respective
|
||||
stacking function identifiers.
|
||||
:return: Dictionary containing the input points.
|
||||
:rtype: dict
|
||||
"""
|
||||
return {k: STACK_FN_MAP[v] for k, v in self._stack_fn.items()}
|
||||
return {k: v["input"] for k, v in self.conditions_dict.items()}
|
||||
|
||||
def update_data(self, new_conditions_dict):
|
||||
"""
|
||||
Update the dataset with new data.
|
||||
This method is used to update the dataset with new data. It replaces
|
||||
the current data with the new data provided in the new_conditions_dict
|
||||
parameter.
|
||||
|
||||
:param dict new_conditions_dict: Dictionary containing the new data.
|
||||
:return: None
|
||||
"""
|
||||
for condition, data in new_conditions_dict.items():
|
||||
if condition in self.conditions_dict:
|
||||
self.conditions_dict[condition].update(data)
|
||||
else:
|
||||
self.conditions_dict[condition] = data
|
||||
|
||||
|
||||
class PinaGraphDataset(PinaDataset):
    """
    Dataset class for the PINA dataset with :class:`~torch_geometric.data.Data`
    and :class:`~pina.graph.Graph` data.
    """

    def _create_graph_batch(self, data):
        """
        Collate a list of :class:`~torch_geometric.data.Data` objects into a
        single LabelBatch object.

        :param data: List of items to collate in a single batch.
        :type data: list[Data] | list[Graph]
        :return: The LabelBatch built from the given list through
            ``LabelBatch.from_data_list``.
        :rtype: LabelBatch
        """
        return LabelBatch.from_data_list(data)

    def create_batch(self, data):
        """
        Create a Batch object from a list of items.

        :param data: List of items to collate in a single batch.
        :type data: list[Data] | list[Graph]
        :return: Batch object.
        :rtype: :class:`~torch_geometric.data.Batch`
            | :class:`~pina.graph.LabelBatch`
        """
        # The type of the first item decides how the list is collated
        first_item = data[0]
        if not isinstance(first_item, Data):
            # NOTE(review): _create_tensor_batch is not defined in this
            # class; presumably provided elsewhere - confirm.
            return self._create_tensor_batch(data)
        return self._create_graph_batch(first_item if False else data)

    # Override _retrive_data method for graph handling
    def _retrive_data(self, data, idx_list):
        """
        Retrieve data from the dataset given a list of indices.

        :param dict data: Dictionary containing the data.
        :param list[int] idx_list: List of indices to retrieve.
        :return: Dictionary containing the data at the given indices.
        :rtype: dict
        """
        retrieved = {}
        for key, values in data.items():
            if isinstance(values, list):
                # Lists are gathered item by item and collated in a batch
                selected = [values[i] for i in idx_list]
                retrieved[key] = self._create_graph_batch(selected)
            else:
                # Any other container is indexed directly with the list
                retrieved[key] = values[idx_list]
        return retrieved

    @property
    def input(self):
        """
        Return the input data for the dataset.

        :return: Dictionary mapping each condition name to its ``"input"``
            entry.
        :rtype: dict
        """
        inputs = {}
        for name, condition in self.conditions_dict.items():
            inputs[name] = condition["input"]
        return inputs
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
import torch
|
||||
from math import ceil
|
||||
|
||||
|
||||
class StackedDataLoader:
    """
    Dataloader iterating over several datasets as if they were stacked one
    after the other in a single index space.

    Each dataset owns a contiguous chunk of global indices. A batch is a
    slice of the (optionally shuffled) global indices, and is returned as a
    dictionary mapping each dataset name to the items that fall in its
    chunk. Datasets with no index in the batch are omitted.
    """

    def __init__(self, datasets, batch_size=32, shuffle=True):
        """
        Initialize the dataloader.

        :param dict datasets: Dictionary mapping names to datasets. Each
            dataset must expose ``is_graph_dataset``, ``__len__`` and
            ``getitem_from_list``.
        :param int batch_size: Number of global indices per batch.
            Default is ``32``.
        :param bool shuffle: Whether to shuffle the global indices.
            Default is ``True``.

        :raises ValueError: If any of the datasets is a graph dataset.
        """
        for dataset in datasets.values():
            if dataset.is_graph_dataset:
                raise ValueError(
                    "StackedDataLoader does not support graph datasets"
                )
        self.chunks = {}
        self.total_length = 0

        self._init_chunks(datasets)
        self.indices = list(range(self.total_length))
        self.batch_size = batch_size
        self.shuffle = shuffle
        if self.shuffle:
            # NOTE(review): this reseeds the global torch RNG at every
            # construction, so the permutation is always the same and any
            # seed set by the caller is overwritten - confirm intended.
            torch.random.manual_seed(42)
            self.indices = torch.randperm(self.total_length).tolist()
        self.datasets = datasets

    def _init_chunks(self, datasets):
        """
        Assign to each dataset its ``[start, end)`` range of global indices
        and store the total number of items.

        :param dict datasets: Dictionary mapping names to datasets.
        """
        offset = 0
        for name, dataset in datasets.items():
            size = len(dataset)
            self.chunks[name] = {"start": offset, "end": offset + size}
            offset += size
        self.total_length = offset

    def __len__(self):
        """
        Return the number of batches, the last one possibly being smaller.

        :return: The number of batches.
        :rtype: int
        """
        return ceil(self.total_length / self.batch_size)

    def _build_batch_indices(self, batch_idx):
        """
        Return the global indices belonging to the given batch.

        :param int batch_idx: Position of the batch.
        :return: The global indices of the batch.
        :rtype: list[int]
        """
        start = batch_idx * self.batch_size
        end = min(start + self.batch_size, self.total_length)
        return self.indices[start:end]

    def __iter__(self):
        """
        Yield the batches one at a time.

        :return: A dictionary mapping each dataset name having at least one
            index in the batch to the result of ``getitem_from_list`` on
            its local indices.
        :rtype: dict
        """
        for batch_idx in range(len(self)):
            batch_indices = self._build_batch_indices(batch_idx)
            batch_data = {}
            for name, chunk in self.chunks.items():
                # Translate global indices into indices local to the dataset
                local_indices = [
                    idx - chunk["start"]
                    for idx in batch_indices
                    if chunk["start"] <= idx < chunk["end"]
                ]
                if local_indices:
                    batch_data[name] = self.datasets[name].getitem_from_list(
                        local_indices
                    )
            yield batch_data
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Module for the Equation."""
|
||||
|
||||
import inspect
|
||||
|
||||
from .equation_interface import EquationInterface
|
||||
|
||||
|
||||
@@ -48,10 +49,6 @@ class Equation(EquationInterface):
|
||||
:raises RuntimeError: If the underlying equation signature length is not
|
||||
2 (direct problem) or 3 (inverse problem).
|
||||
"""
|
||||
# Move the equation to the input_ device
|
||||
self.to(input_.device)
|
||||
|
||||
# Call the underlying equation based on its signature length
|
||||
if self.__len_sig == 2:
|
||||
return self.__equation(input_, output_)
|
||||
if self.__len_sig == 3:
|
||||
|
||||
@@ -239,19 +239,19 @@ class Advection(Equation): # pylint: disable=R0903
|
||||
)
|
||||
|
||||
# Ensure consistency of c length
|
||||
if self.c.shape[-1] != len(input_lbl) - 1 and self.c.shape[-1] > 1:
|
||||
if len(self.c) != (len(input_lbl) - 1) and len(self.c) > 1:
|
||||
raise ValueError(
|
||||
"If 'c' is passed as a list, its length must be equal to "
|
||||
"the number of spatial dimensions."
|
||||
)
|
||||
|
||||
# Repeat c to ensure consistent shape for advection
|
||||
c = self.c.repeat(output_.shape[0], 1)
|
||||
if c.shape[1] != (len(input_lbl) - 1):
|
||||
c = c.repeat(1, len(input_lbl) - 1)
|
||||
self.c = self.c.repeat(output_.shape[0], 1)
|
||||
if self.c.shape[1] != (len(input_lbl) - 1):
|
||||
self.c = self.c.repeat(1, len(input_lbl) - 1)
|
||||
|
||||
# Add a dimension to c for the following operations
|
||||
c = c.unsqueeze(-1)
|
||||
self.c = self.c.unsqueeze(-1)
|
||||
|
||||
# Compute the time derivative and the spatial gradient
|
||||
time_der = grad(output_, input_, components=None, d="t")
|
||||
@@ -262,7 +262,7 @@ class Advection(Equation): # pylint: disable=R0903
|
||||
tmp = tmp.transpose(-1, -2)
|
||||
|
||||
# Compute advection term
|
||||
adv = (tmp * c).sum(dim=tmp.tensor.ndim - 2)
|
||||
adv = (tmp * self.c).sum(dim=tmp.tensor.ndim - 2)
|
||||
|
||||
return time_der + adv
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Module for the Equation Interface."""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import torch
|
||||
|
||||
|
||||
class EquationInterface(metaclass=ABCMeta):
|
||||
@@ -34,33 +33,3 @@ class EquationInterface(metaclass=ABCMeta):
|
||||
:return: The computed residual of the equation.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
|
||||
def to(self, device):
|
||||
"""
|
||||
Move all tensor attributes to the specified device.
|
||||
|
||||
:param torch.device device: The target device to move the tensors to.
|
||||
:return: The instance moved to the specified device.
|
||||
:rtype: EquationInterface
|
||||
"""
|
||||
# Iterate over all attributes of the Equation
|
||||
for key, val in self.__dict__.items():
|
||||
|
||||
# Move tensors in dictionaries to the specified device
|
||||
if isinstance(val, dict):
|
||||
self.__dict__[key] = {
|
||||
k: v.to(device) if torch.is_tensor(v) else v
|
||||
for k, v in val.items()
|
||||
}
|
||||
|
||||
# Move tensors in lists to the specified device
|
||||
elif isinstance(val, list):
|
||||
self.__dict__[key] = [
|
||||
v.to(device) if torch.is_tensor(v) else v for v in val
|
||||
]
|
||||
|
||||
# Move tensor attributes to the specified device
|
||||
elif torch.is_tensor(val):
|
||||
self.__dict__[key] = val.to(device)
|
||||
|
||||
return self
|
||||
|
||||
@@ -101,10 +101,6 @@ class SystemEquation(EquationInterface):
|
||||
:return: The aggregated residuals of the system of equations.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
# Move the equation to the input_ device
|
||||
self.to(input_.device)
|
||||
|
||||
# Compute the residual for each equation
|
||||
residual = torch.hstack(
|
||||
[
|
||||
equation.residual(input_, output_, params_)
|
||||
@@ -112,7 +108,6 @@ class SystemEquation(EquationInterface):
|
||||
]
|
||||
)
|
||||
|
||||
# Skip reduction if not specified
|
||||
if self.reduction is None:
|
||||
return residual
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ class EnEquivariantNetworkBlock(MessagePassing):
|
||||
DOI: `<https://doi.org/10.48550/arXiv.2102.09844>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
def __init__( # pylint: disable=R0913, R0917
|
||||
self,
|
||||
node_feature_dim,
|
||||
edge_feature_dim,
|
||||
@@ -143,7 +143,9 @@ class EnEquivariantNetworkBlock(MessagePassing):
|
||||
func=activation,
|
||||
)
|
||||
|
||||
def forward(self, x, pos, edge_index, edge_attr=None, vel=None):
|
||||
def forward(
|
||||
self, x, pos, edge_index, edge_attr=None, vel=None
|
||||
): # pylint: disable=R0917
|
||||
"""
|
||||
Forward pass of the block, triggering the message-passing routine.
|
||||
|
||||
@@ -169,7 +171,9 @@ class EnEquivariantNetworkBlock(MessagePassing):
|
||||
edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr, vel=vel
|
||||
)
|
||||
|
||||
def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
|
||||
def message(
|
||||
self, x_i, x_j, pos_i, pos_j, edge_attr
|
||||
): # pylint: disable=R0917
|
||||
"""
|
||||
Compute the message to be passed between nodes and edges.
|
||||
|
||||
@@ -234,7 +238,9 @@ class EnEquivariantNetworkBlock(MessagePassing):
|
||||
|
||||
return agg_message, agg_m_ij
|
||||
|
||||
def update(self, aggregated_inputs, x, pos, edge_index, vel):
|
||||
def update(
|
||||
self, aggregated_inputs, x, pos, edge_index, vel
|
||||
): # pylint: disable=R0917
|
||||
"""
|
||||
Update node features, positions, and optionally velocities.
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ class EquivariantGraphNeuralOperatorBlock(torch.nn.Module):
|
||||
<https://arxiv.org/abs/2401.11037>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
def __init__( # pylint: disable=R0913, R0917
|
||||
self,
|
||||
node_feature_dim,
|
||||
edge_feature_dim,
|
||||
@@ -101,7 +101,9 @@ class EquivariantGraphNeuralOperatorBlock(torch.nn.Module):
|
||||
flow=flow,
|
||||
)
|
||||
|
||||
def forward(self, x, pos, vel, edge_index, edge_attr=None):
|
||||
def forward( # pylint: disable=R0917
|
||||
self, x, pos, vel, edge_index, edge_attr=None
|
||||
):
|
||||
"""
|
||||
Forward pass of the Equivariant Graph Neural Operator block.
|
||||
|
||||
@@ -182,7 +184,11 @@ class EquivariantGraphNeuralOperatorBlock(torch.nn.Module):
|
||||
weights = torch.complex(real[..., :modes], img[..., :modes])
|
||||
|
||||
# Convolution in Fourier space
|
||||
fourier = torch.fft.rfftn(x, dim=[0])[:modes]
|
||||
# torch.fft.rfftn and irfftn are callable functions, but pylint
|
||||
# incorrectly flags them as E1102 (not callable).
|
||||
fourier = torch.fft.rfftn(x, dim=[0])[:modes] # pylint: disable=E1102
|
||||
out = torch.einsum(einsum_idx, fourier, weights)
|
||||
|
||||
return torch.fft.irfftn(out, s=x.shape[0], dim=0)
|
||||
return torch.fft.irfftn( # pylint: disable=E1102
|
||||
out, s=x.shape[0], dim=0
|
||||
)
|
||||
|
||||
@@ -5,7 +5,9 @@ from ..utils import check_positive_integer
|
||||
from .block.message_passing import EquivariantGraphNeuralOperatorBlock
|
||||
|
||||
|
||||
class EquivariantGraphNeuralOperator(torch.nn.Module):
|
||||
# Disable pylint warnings for too few public methods (since this is a simple
|
||||
# model class in a standard PyTorch style)
|
||||
class EquivariantGraphNeuralOperator(torch.nn.Module): # pylint: disable=R0903
|
||||
"""
|
||||
Equivariant Graph Neural Operator (EGNO) for modeling 3D dynamics.
|
||||
|
||||
@@ -32,7 +34,9 @@ class EquivariantGraphNeuralOperator(torch.nn.Module):
|
||||
<https://arxiv.org/abs/2401.11037>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
# Disable pylint warnings for too many arguments in init (since this is a
|
||||
# model class with many configurable parameters)
|
||||
def __init__( # pylint: disable=R0913, R0917, R0914
|
||||
self,
|
||||
n_egno_layers,
|
||||
node_feature_dim,
|
||||
|
||||
@@ -48,10 +48,11 @@ class HelmholtzProblem(SpatialProblem):
|
||||
:type alpha: float | int
|
||||
"""
|
||||
super().__init__()
|
||||
check_consistency(alpha, (int, float))
|
||||
self.alpha = alpha
|
||||
|
||||
def forcing_term(input_):
|
||||
self.alpha = alpha
|
||||
check_consistency(alpha, (int, float))
|
||||
|
||||
def forcing_term(self, input_):
|
||||
"""
|
||||
Implementation of the forcing term.
|
||||
"""
|
||||
|
||||
@@ -71,7 +71,9 @@ class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
|
||||
"""
|
||||
# Override the compilation, compiling only for torch < 2.8, see
|
||||
# related issue at https://github.com/mathLab/PINA/issues/621
|
||||
if torch.__version__ >= "2.8":
|
||||
if torch.__version__ < "2.8":
|
||||
self.trainer.compile = True
|
||||
else:
|
||||
self.trainer.compile = False
|
||||
warnings.warn(
|
||||
"Compilation is disabled for torch >= 2.8. "
|
||||
|
||||
@@ -174,7 +174,11 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
:return: The result of the parent class ``setup`` method.
|
||||
:rtype: Any
|
||||
"""
|
||||
if self.trainer.compile and not self._is_compiled():
|
||||
if stage == "fit" and self.trainer.compile:
|
||||
self._setup_compile()
|
||||
if stage == "test" and (
|
||||
self.trainer.compile and not self._is_compiled()
|
||||
):
|
||||
self._setup_compile()
|
||||
return super().setup(stage)
|
||||
|
||||
|
||||
@@ -1,17 +1,12 @@
|
||||
"""Module for the Trainer."""
|
||||
|
||||
import sys
|
||||
import warnings
|
||||
import torch
|
||||
import lightning
|
||||
from .utils import check_consistency, custom_warning_format
|
||||
from .utils import check_consistency
|
||||
from .data import PinaDataModule
|
||||
from .solver import SolverInterface, PINNInterface
|
||||
|
||||
# set the warning for compile options
|
||||
warnings.formatwarning = custom_warning_format
|
||||
warnings.filterwarnings("always", category=UserWarning)
|
||||
|
||||
|
||||
class Trainer(lightning.pytorch.Trainer):
|
||||
"""
|
||||
@@ -31,7 +26,7 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
test_size=0.0,
|
||||
val_size=0.0,
|
||||
compile=None,
|
||||
batching_mode="common_batch_size",
|
||||
repeat=None,
|
||||
automatic_batching=None,
|
||||
num_workers=None,
|
||||
pin_memory=None,
|
||||
@@ -54,14 +49,11 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
:param float val_size: The percentage of elements to include in the
|
||||
validation dataset. Default is ``0.0``.
|
||||
:param bool compile: If ``True``, the model is compiled before training.
|
||||
Default is ``False``. For Windows users, it is always disabled. Not
|
||||
supported for python version greater or equal than 3.14.
|
||||
:param bool common_batch_size: If ``True``, the same batch size is used
|
||||
for all conditions. If ``False``, each condition can have its own
|
||||
batch size, proportional to the size of the dataset in that
|
||||
condition. Default is ``True``.
|
||||
:param bool separate_conditions: If ``True``, dataloaders for each
|
||||
condition are iterated separately. Default is ``False``.
|
||||
Default is ``False``. For Windows users, it is always disabled.
|
||||
:param bool repeat: Whether to repeat the dataset data in each
|
||||
condition during training. For further details, see the
|
||||
:class:`~pina.data.data_module.PinaDataModule` class. Default is
|
||||
``False``.
|
||||
:param bool automatic_batching: If ``True``, automatic PyTorch batching
|
||||
is performed, otherwise the items are retrieved from the dataset
|
||||
all at once. For further details, see the
|
||||
@@ -84,7 +76,7 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
train_size=train_size,
|
||||
test_size=test_size,
|
||||
val_size=val_size,
|
||||
batching_mode=batching_mode,
|
||||
repeat=repeat,
|
||||
automatic_batching=automatic_batching,
|
||||
compile=compile,
|
||||
)
|
||||
@@ -112,17 +104,10 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# checking compilation and automatic batching
|
||||
# compilation disabled for Windows and for Python 3.14+
|
||||
if (
|
||||
compile is None
|
||||
or sys.platform == "win32"
|
||||
or sys.version_info >= (3, 14)
|
||||
):
|
||||
if compile is None or sys.platform == "win32":
|
||||
compile = False
|
||||
warnings.warn(
|
||||
"Compilation is disabled for Python 3.14+ and for Windows.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
repeat = repeat if repeat is not None else False
|
||||
|
||||
automatic_batching = (
|
||||
automatic_batching if automatic_batching is not None else False
|
||||
@@ -139,7 +124,7 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
test_size=test_size,
|
||||
val_size=val_size,
|
||||
batch_size=batch_size,
|
||||
batching_mode=batching_mode,
|
||||
repeat=repeat,
|
||||
automatic_batching=automatic_batching,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
@@ -177,7 +162,7 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
test_size,
|
||||
val_size,
|
||||
batch_size,
|
||||
batching_mode,
|
||||
repeat,
|
||||
automatic_batching,
|
||||
pin_memory,
|
||||
num_workers,
|
||||
@@ -196,10 +181,8 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
:param float val_size: The percentage of elements to include in the
|
||||
validation dataset.
|
||||
:param int batch_size: The number of samples per batch to load.
|
||||
:param bool common_batch_size: Whether to use the same batch size for
|
||||
all conditions.
|
||||
:param bool seperate_conditions: Whether to iterate dataloaders for
|
||||
each condition separately.
|
||||
:param bool repeat: Whether to repeat the dataset data in each
|
||||
condition during training.
|
||||
:param bool automatic_batching: Whether to perform automatic batching
|
||||
with PyTorch.
|
||||
:param bool pin_memory: Whether to use pinned memory for faster data
|
||||
@@ -229,7 +212,7 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
test_size=test_size,
|
||||
val_size=val_size,
|
||||
batch_size=batch_size,
|
||||
batching_mode=batching_mode,
|
||||
repeat=repeat,
|
||||
automatic_batching=automatic_batching,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory,
|
||||
@@ -281,7 +264,7 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
train_size,
|
||||
test_size,
|
||||
val_size,
|
||||
batching_mode,
|
||||
repeat,
|
||||
automatic_batching,
|
||||
compile,
|
||||
):
|
||||
@@ -295,10 +278,8 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
test dataset.
|
||||
:param float val_size: The percentage of elements to include in the
|
||||
validation dataset.
|
||||
:param bool common_batch_size: Whether to use the same batch size for
|
||||
all conditions.
|
||||
:param bool seperate_conditions: Whether to iterate dataloaders for
|
||||
each condition separately.
|
||||
:param bool repeat: Whether to repeat the dataset data in each
|
||||
condition during training.
|
||||
:param bool automatic_batching: Whether to perform automatic batching
|
||||
with PyTorch.
|
||||
:param bool compile: If ``True``, the model is compiled before training.
|
||||
@@ -308,7 +289,8 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
check_consistency(train_size, float)
|
||||
check_consistency(test_size, float)
|
||||
check_consistency(val_size, float)
|
||||
check_consistency(batching_mode, str)
|
||||
if repeat is not None:
|
||||
check_consistency(repeat, bool)
|
||||
if automatic_batching is not None:
|
||||
check_consistency(automatic_batching, bool)
|
||||
if compile is not None:
|
||||
@@ -343,23 +325,3 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
if batch_size is not None:
|
||||
check_consistency(batch_size, int)
|
||||
return pin_memory, num_workers, shuffle, batch_size
|
||||
|
||||
@property
|
||||
def compile(self):
|
||||
"""
|
||||
Whether compilation is required or not.
|
||||
|
||||
:return: ``True`` if compilation is required, ``False`` otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
return self._compile
|
||||
|
||||
@compile.setter
|
||||
def compile(self, value):
|
||||
"""
|
||||
Setting the value of compile.
|
||||
|
||||
:param bool value: Whether compilation is required or not.
|
||||
"""
|
||||
check_consistency(value, bool)
|
||||
self._compile = value
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "pina-mathlab"
|
||||
version = "0.2.5"
|
||||
version = "0.2.3"
|
||||
description = "Physic Informed Neural networks for Advance modeling."
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
@@ -19,7 +19,7 @@ dependencies = [
|
||||
"torch_geometric",
|
||||
"matplotlib",
|
||||
]
|
||||
requires-python = ">=3.10"
|
||||
requires-python = ">=3.9"
|
||||
|
||||
[project.optional-dependencies]
|
||||
doc = [
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 411 KiB After Width: | Height: | Size: 51 KiB |
@@ -51,7 +51,7 @@ def test_sample(condition_to_update):
|
||||
}
|
||||
trainer.train()
|
||||
after_n_points = {
|
||||
loc: len(trainer.data_module.train_dataset[loc].input)
|
||||
loc: len(trainer.data_module.train_dataset.input[loc])
|
||||
for loc in condition_to_update
|
||||
}
|
||||
assert before_n_points == trainer.callbacks[0].initial_population_size
|
||||
|
||||
@@ -142,10 +142,14 @@ def test_setup(solver, fn, stage, apply_to):
|
||||
|
||||
for cond in ["data1", "data2"]:
|
||||
scale = scale_fn(
|
||||
trainer_copy.data_module.train_dataset[cond].data[apply_to]
|
||||
trainer_copy.data_module.train_dataset.conditions_dict[cond][
|
||||
apply_to
|
||||
]
|
||||
)
|
||||
shift = shift_fn(
|
||||
trainer_copy.data_module.train_dataset[cond].data[apply_to]
|
||||
trainer_copy.data_module.train_dataset.conditions_dict[cond][
|
||||
apply_to
|
||||
]
|
||||
)
|
||||
assert "scale" in normalizer[cond]
|
||||
assert "shift" in normalizer[cond]
|
||||
@@ -154,8 +158,8 @@ def test_setup(solver, fn, stage, apply_to):
|
||||
for ds_name in stage_map[stage]:
|
||||
dataset = getattr(trainer.data_module, ds_name, None)
|
||||
old_dataset = getattr(trainer_copy.data_module, ds_name, None)
|
||||
current_points = dataset[cond].data[apply_to]
|
||||
old_points = old_dataset[cond].data[apply_to]
|
||||
current_points = dataset.conditions_dict[cond][apply_to]
|
||||
old_points = old_dataset.conditions_dict[cond][apply_to]
|
||||
expected = (old_points - shift) / scale
|
||||
assert torch.allclose(current_points, expected)
|
||||
|
||||
@@ -200,10 +204,10 @@ def test_setup_pinn(fn, stage, apply_to):
|
||||
cond = "data"
|
||||
|
||||
scale = scale_fn(
|
||||
trainer_copy.data_module.train_dataset[cond].data[apply_to]
|
||||
trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to]
|
||||
)
|
||||
shift = shift_fn(
|
||||
trainer_copy.data_module.train_dataset[cond].data[apply_to]
|
||||
trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to]
|
||||
)
|
||||
assert "scale" in normalizer[cond]
|
||||
assert "shift" in normalizer[cond]
|
||||
@@ -212,8 +216,8 @@ def test_setup_pinn(fn, stage, apply_to):
|
||||
for ds_name in stage_map[stage]:
|
||||
dataset = getattr(trainer.data_module, ds_name, None)
|
||||
old_dataset = getattr(trainer_copy.data_module, ds_name, None)
|
||||
current_points = dataset[cond].data[apply_to]
|
||||
old_points = old_dataset[cond].data[apply_to]
|
||||
current_points = dataset.conditions_dict[cond][apply_to]
|
||||
old_points = old_dataset.conditions_dict[cond][apply_to]
|
||||
expected = (old_points - shift) / scale
|
||||
assert torch.allclose(current_points, expected)
|
||||
|
||||
@@ -238,7 +242,3 @@ def test_setup_graph_dataset():
|
||||
)
|
||||
with pytest.raises(NotImplementedError):
|
||||
trainer.train()
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# test_setup(supervised_solver_lt, [torch.std, torch.mean], "all", "input")
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import torch
|
||||
import pytest
|
||||
from pina.data import PinaDataModule
|
||||
from pina.data.dataset import PinaDataset
|
||||
from pina.data.dataset import PinaTensorDataset, PinaGraphDataset
|
||||
from pina.problem.zoo import SupervisedProblem
|
||||
from pina.graph import RadiusGraph
|
||||
|
||||
from pina.data.dataloader import DummyDataloader, PinaDataLoader
|
||||
from pina.data.data_module import DummyDataloader
|
||||
from pina import Trainer
|
||||
from pina.solver import SupervisedSolver
|
||||
from torch_geometric.data import Batch
|
||||
@@ -45,33 +44,22 @@ def test_setup_train(input_, output_, train_size, val_size, test_size):
|
||||
)
|
||||
dm.setup()
|
||||
assert hasattr(dm, "train_dataset")
|
||||
assert isinstance(dm.train_dataset, dict)
|
||||
assert all(
|
||||
isinstance(dm.train_dataset[cond], PinaDataset)
|
||||
for cond in dm.train_dataset
|
||||
)
|
||||
assert all(
|
||||
dm.train_dataset[cond].is_graph_dataset == isinstance(input_, list)
|
||||
for cond in dm.train_dataset
|
||||
)
|
||||
assert all(
|
||||
len(dm.train_dataset[cond]) == int(len(input_) * train_size)
|
||||
for cond in dm.train_dataset
|
||||
)
|
||||
if isinstance(input_, torch.Tensor):
|
||||
assert isinstance(dm.train_dataset, PinaTensorDataset)
|
||||
else:
|
||||
assert isinstance(dm.train_dataset, PinaGraphDataset)
|
||||
# assert len(dm.train_dataset) == int(len(input_) * train_size)
|
||||
if test_size > 0:
|
||||
assert hasattr(dm, "test_dataset")
|
||||
assert dm.test_dataset is None
|
||||
else:
|
||||
assert not hasattr(dm, "test_dataset")
|
||||
assert hasattr(dm, "val_dataset")
|
||||
|
||||
assert isinstance(dm.val_dataset, dict)
|
||||
assert all(
|
||||
isinstance(dm.val_dataset[cond], PinaDataset) for cond in dm.val_dataset
|
||||
)
|
||||
assert all(
|
||||
isinstance(dm.val_dataset[cond], PinaDataset) for cond in dm.val_dataset
|
||||
)
|
||||
if isinstance(input_, torch.Tensor):
|
||||
assert isinstance(dm.val_dataset, PinaTensorDataset)
|
||||
else:
|
||||
assert isinstance(dm.val_dataset, PinaGraphDataset)
|
||||
# assert len(dm.val_dataset) == int(len(input_) * val_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -99,59 +87,49 @@ def test_setup_test(input_, output_, train_size, val_size, test_size):
|
||||
assert not hasattr(dm, "val_dataset")
|
||||
|
||||
assert hasattr(dm, "test_dataset")
|
||||
assert all(
|
||||
isinstance(dm.test_dataset[cond], PinaDataset)
|
||||
for cond in dm.test_dataset
|
||||
)
|
||||
assert all(
|
||||
dm.test_dataset[cond].is_graph_dataset == isinstance(input_, list)
|
||||
for cond in dm.test_dataset
|
||||
)
|
||||
assert all(
|
||||
len(dm.test_dataset[cond]) == int(len(input_) * test_size)
|
||||
for cond in dm.test_dataset
|
||||
if isinstance(input_, torch.Tensor):
|
||||
assert isinstance(dm.test_dataset, PinaTensorDataset)
|
||||
else:
|
||||
assert isinstance(dm.test_dataset, PinaGraphDataset)
|
||||
# assert len(dm.test_dataset) == int(len(input_) * test_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_, output_",
|
||||
[(input_tensor, output_tensor), (input_graph, output_graph)],
|
||||
)
|
||||
def test_dummy_dataloader(input_, output_):
|
||||
problem = SupervisedProblem(input_=input_, output_=output_)
|
||||
solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
|
||||
trainer = Trainer(
|
||||
solver, batch_size=None, train_size=0.7, val_size=0.3, test_size=0.0
|
||||
)
|
||||
dm = trainer.data_module
|
||||
dm.setup()
|
||||
dm.trainer = trainer
|
||||
dataloader = dm.train_dataloader()
|
||||
assert isinstance(dataloader, DummyDataloader)
|
||||
assert len(dataloader) == 1
|
||||
data = next(dataloader)
|
||||
assert isinstance(data, list)
|
||||
assert isinstance(data[0], tuple)
|
||||
if isinstance(input_, list):
|
||||
assert isinstance(data[0][1]["input"], Batch)
|
||||
else:
|
||||
assert isinstance(data[0][1]["input"], torch.Tensor)
|
||||
assert isinstance(data[0][1]["target"], torch.Tensor)
|
||||
|
||||
|
||||
# @pytest.mark.parametrize(
|
||||
# "input_, output_",
|
||||
# [(input_tensor, output_tensor), (input_graph, output_graph)],
|
||||
# )
|
||||
# def test_dummy_dataloader(input_, output_):
|
||||
# problem = SupervisedProblem(input_=input_, output_=output_)
|
||||
# solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
|
||||
# trainer = Trainer(
|
||||
# solver, batch_size=None, train_size=0.7, val_size=0.3, test_size=0.0
|
||||
# )
|
||||
# dm = trainer.data_module
|
||||
# dm.setup()
|
||||
# dm.trainer = trainer
|
||||
# dataloader = dm.train_dataloader()
|
||||
# assert isinstance(dataloader, PinaDataLoader)
|
||||
# print(dataloader.dataloaders)
|
||||
# assert all([isinstance(ds, DummyDataloader) for ds in dataloader.dataloaders.values()])
|
||||
|
||||
# data = next(iter(dataloader))
|
||||
# assert isinstance(data, list)
|
||||
# assert isinstance(data[0], tuple)
|
||||
# if isinstance(input_, list):
|
||||
# assert isinstance(data[0][1]["input"], Batch)
|
||||
# else:
|
||||
# assert isinstance(data[0][1]["input"], torch.Tensor)
|
||||
# assert isinstance(data[0][1]["target"], torch.Tensor)
|
||||
|
||||
|
||||
# dataloader = dm.val_dataloader()
|
||||
# assert isinstance(dataloader, DummyDataloader)
|
||||
# assert len(dataloader) == 1
|
||||
# data = next(dataloader)
|
||||
# assert isinstance(data, list)
|
||||
# assert isinstance(data[0], tuple)
|
||||
# if isinstance(input_, list):
|
||||
# assert isinstance(data[0][1]["input"], Batch)
|
||||
# else:
|
||||
# assert isinstance(data[0][1]["input"], torch.Tensor)
|
||||
# assert isinstance(data[0][1]["target"], torch.Tensor)
|
||||
dataloader = dm.val_dataloader()
|
||||
assert isinstance(dataloader, DummyDataloader)
|
||||
assert len(dataloader) == 1
|
||||
data = next(dataloader)
|
||||
assert isinstance(data, list)
|
||||
assert isinstance(data[0], tuple)
|
||||
if isinstance(input_, list):
|
||||
assert isinstance(data[0][1]["input"], Batch)
|
||||
else:
|
||||
assert isinstance(data[0][1]["input"], torch.Tensor)
|
||||
assert isinstance(data[0][1]["target"], torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -159,11 +137,7 @@ def test_setup_test(input_, output_, train_size, val_size, test_size):
|
||||
[(input_tensor, output_tensor), (input_graph, output_graph)],
|
||||
)
|
||||
@pytest.mark.parametrize("automatic_batching", [True, False])
|
||||
@pytest.mark.parametrize("batch_size", [None, 10])
|
||||
@pytest.mark.parametrize("batching_mode", ["common_batch_size", "propotional"])
|
||||
def test_dataloader(
|
||||
input_, output_, automatic_batching, batch_size, batching_mode
|
||||
):
|
||||
def test_dataloader(input_, output_, automatic_batching):
|
||||
problem = SupervisedProblem(input_=input_, output_=output_)
|
||||
solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
|
||||
trainer = Trainer(
|
||||
@@ -173,13 +147,12 @@ def test_dataloader(
|
||||
val_size=0.3,
|
||||
test_size=0.0,
|
||||
automatic_batching=automatic_batching,
|
||||
batching_mode=batching_mode,
|
||||
)
|
||||
dm = trainer.data_module
|
||||
dm.setup()
|
||||
dm.trainer = trainer
|
||||
dataloader = dm.train_dataloader()
|
||||
assert isinstance(dataloader, PinaDataLoader)
|
||||
assert isinstance(dataloader, DataLoader)
|
||||
assert len(dataloader) == 7
|
||||
data = next(iter(dataloader))
|
||||
assert isinstance(data, dict)
|
||||
@@ -190,8 +163,8 @@ def test_dataloader(
|
||||
assert isinstance(data["data"]["target"], torch.Tensor)
|
||||
|
||||
dataloader = dm.val_dataloader()
|
||||
assert isinstance(dataloader, PinaDataLoader)
|
||||
assert len(dataloader) == 3 if batch_size is not None else 1
|
||||
assert isinstance(dataloader, DataLoader)
|
||||
assert len(dataloader) == 3
|
||||
data = next(iter(dataloader))
|
||||
assert isinstance(data, dict)
|
||||
if isinstance(input_, list):
|
||||
@@ -229,13 +202,12 @@ def test_dataloader_labels(input_, output_, automatic_batching):
|
||||
val_size=0.3,
|
||||
test_size=0.0,
|
||||
automatic_batching=automatic_batching,
|
||||
# common_batch_size=True,
|
||||
)
|
||||
dm = trainer.data_module
|
||||
dm.setup()
|
||||
dm.trainer = trainer
|
||||
dataloader = dm.train_dataloader()
|
||||
assert isinstance(dataloader, PinaDataLoader)
|
||||
assert isinstance(dataloader, DataLoader)
|
||||
assert len(dataloader) == 7
|
||||
data = next(iter(dataloader))
|
||||
assert isinstance(data, dict)
|
||||
@@ -251,7 +223,7 @@ def test_dataloader_labels(input_, output_, automatic_batching):
|
||||
assert data["data"]["target"].labels == ["u", "v", "w"]
|
||||
|
||||
dataloader = dm.val_dataloader()
|
||||
assert isinstance(dataloader, PinaDataLoader)
|
||||
assert isinstance(dataloader, DataLoader)
|
||||
assert len(dataloader) == 3
|
||||
data = next(iter(dataloader))
|
||||
assert isinstance(data, dict)
|
||||
@@ -268,6 +240,39 @@ def test_dataloader_labels(input_, output_, automatic_batching):
|
||||
assert data["data"]["target"].labels == ["u", "v", "w"]
|
||||
|
||||
|
||||
def test_get_all_data():
|
||||
input = torch.stack([torch.zeros((1,)) + i for i in range(1000)])
|
||||
target = input
|
||||
|
||||
problem = SupervisedProblem(input, target)
|
||||
datamodule = PinaDataModule(
|
||||
problem,
|
||||
train_size=0.7,
|
||||
test_size=0.2,
|
||||
val_size=0.1,
|
||||
batch_size=64,
|
||||
shuffle=False,
|
||||
repeat=False,
|
||||
automatic_batching=None,
|
||||
num_workers=0,
|
||||
pin_memory=False,
|
||||
)
|
||||
datamodule.setup("fit")
|
||||
datamodule.setup("test")
|
||||
assert len(datamodule.train_dataset.get_all_data()["data"]["input"]) == 700
|
||||
assert torch.isclose(
|
||||
datamodule.train_dataset.get_all_data()["data"]["input"], input[:700]
|
||||
).all()
|
||||
assert len(datamodule.val_dataset.get_all_data()["data"]["input"]) == 100
|
||||
assert torch.isclose(
|
||||
datamodule.val_dataset.get_all_data()["data"]["input"], input[900:]
|
||||
).all()
|
||||
assert len(datamodule.test_dataset.get_all_data()["data"]["input"]) == 200
|
||||
assert torch.isclose(
|
||||
datamodule.test_dataset.get_all_data()["data"]["input"], input[700:900]
|
||||
).all()
|
||||
|
||||
|
||||
def test_input_propery_tensor():
|
||||
input = torch.stack([torch.zeros((1,)) + i for i in range(1000)])
|
||||
target = input
|
||||
@@ -280,6 +285,7 @@ def test_input_propery_tensor():
|
||||
val_size=0.1,
|
||||
batch_size=64,
|
||||
shuffle=False,
|
||||
repeat=False,
|
||||
automatic_batching=None,
|
||||
num_workers=0,
|
||||
pin_memory=False,
|
||||
@@ -305,6 +311,7 @@ def test_input_propery_graph():
|
||||
val_size=0.1,
|
||||
batch_size=64,
|
||||
shuffle=False,
|
||||
repeat=False,
|
||||
automatic_batching=None,
|
||||
num_workers=0,
|
||||
pin_memory=False,
|
||||
|
||||
@@ -1,138 +1,138 @@
|
||||
# import torch
|
||||
# import pytest
|
||||
# from pina.data.dataset import PinaDatasetFactory, PinaGraphDataset
|
||||
# from pina.graph import KNNGraph
|
||||
# from torch_geometric.data import Data
|
||||
import torch
|
||||
import pytest
|
||||
from pina.data.dataset import PinaDatasetFactory, PinaGraphDataset
|
||||
from pina.graph import KNNGraph
|
||||
from torch_geometric.data import Data
|
||||
|
||||
# x = torch.rand((100, 20, 10))
|
||||
# pos = torch.rand((100, 20, 2))
|
||||
# input_ = [
|
||||
# KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
|
||||
# for x_, pos_ in zip(x, pos)
|
||||
# ]
|
||||
# output_ = torch.rand((100, 20, 10))
|
||||
x = torch.rand((100, 20, 10))
|
||||
pos = torch.rand((100, 20, 2))
|
||||
input_ = [
|
||||
KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
|
||||
for x_, pos_ in zip(x, pos)
|
||||
]
|
||||
output_ = torch.rand((100, 20, 10))
|
||||
|
||||
# x_2 = torch.rand((50, 20, 10))
|
||||
# pos_2 = torch.rand((50, 20, 2))
|
||||
# input_2_ = [
|
||||
# KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
|
||||
# for x_, pos_ in zip(x_2, pos_2)
|
||||
# ]
|
||||
# output_2_ = torch.rand((50, 20, 10))
|
||||
x_2 = torch.rand((50, 20, 10))
|
||||
pos_2 = torch.rand((50, 20, 2))
|
||||
input_2_ = [
|
||||
KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
|
||||
for x_, pos_ in zip(x_2, pos_2)
|
||||
]
|
||||
output_2_ = torch.rand((50, 20, 10))
|
||||
|
||||
|
||||
# # Problem with a single condition
|
||||
# conditions_dict_single = {
|
||||
# "data": {
|
||||
# "input": input_,
|
||||
# "target": output_,
|
||||
# }
|
||||
# }
|
||||
# max_conditions_lengths_single = {"data": 100}
|
||||
# Problem with a single condition
|
||||
conditions_dict_single = {
|
||||
"data": {
|
||||
"input": input_,
|
||||
"target": output_,
|
||||
}
|
||||
}
|
||||
max_conditions_lengths_single = {"data": 100}
|
||||
|
||||
# # Problem with multiple conditions
|
||||
# conditions_dict_multi = {
|
||||
# "data_1": {
|
||||
# "input": input_,
|
||||
# "target": output_,
|
||||
# },
|
||||
# "data_2": {
|
||||
# "input": input_2_,
|
||||
# "target": output_2_,
|
||||
# },
|
||||
# }
|
||||
# Problem with multiple conditions
|
||||
conditions_dict_multi = {
|
||||
"data_1": {
|
||||
"input": input_,
|
||||
"target": output_,
|
||||
},
|
||||
"data_2": {
|
||||
"input": input_2_,
|
||||
"target": output_2_,
|
||||
},
|
||||
}
|
||||
|
||||
# max_conditions_lengths_multi = {"data_1": 100, "data_2": 50}
|
||||
max_conditions_lengths_multi = {"data_1": 100, "data_2": 50}
|
||||
|
||||
|
||||
# @pytest.mark.parametrize(
|
||||
# "conditions_dict, max_conditions_lengths",
|
||||
# [
|
||||
# (conditions_dict_single, max_conditions_lengths_single),
|
||||
# (conditions_dict_multi, max_conditions_lengths_multi),
|
||||
# ],
|
||||
# )
|
||||
# def test_constructor(conditions_dict, max_conditions_lengths):
|
||||
# dataset = PinaDatasetFactory(
|
||||
# conditions_dict,
|
||||
# max_conditions_lengths=max_conditions_lengths,
|
||||
# automatic_batching=True,
|
||||
# )
|
||||
# assert isinstance(dataset, PinaGraphDataset)
|
||||
# assert len(dataset) == 100
|
||||
@pytest.mark.parametrize(
|
||||
"conditions_dict, max_conditions_lengths",
|
||||
[
|
||||
(conditions_dict_single, max_conditions_lengths_single),
|
||||
(conditions_dict_multi, max_conditions_lengths_multi),
|
||||
],
|
||||
)
|
||||
def test_constructor(conditions_dict, max_conditions_lengths):
|
||||
dataset = PinaDatasetFactory(
|
||||
conditions_dict,
|
||||
max_conditions_lengths=max_conditions_lengths,
|
||||
automatic_batching=True,
|
||||
)
|
||||
assert isinstance(dataset, PinaGraphDataset)
|
||||
assert len(dataset) == 100
|
||||
|
||||
|
||||
# @pytest.mark.parametrize(
|
||||
# "conditions_dict, max_conditions_lengths",
|
||||
# [
|
||||
# (conditions_dict_single, max_conditions_lengths_single),
|
||||
# (conditions_dict_multi, max_conditions_lengths_multi),
|
||||
# ],
|
||||
# )
|
||||
# def test_getitem(conditions_dict, max_conditions_lengths):
|
||||
# dataset = PinaDatasetFactory(
|
||||
# conditions_dict,
|
||||
# max_conditions_lengths=max_conditions_lengths,
|
||||
# automatic_batching=True,
|
||||
# )
|
||||
# data = dataset[50]
|
||||
# assert isinstance(data, dict)
|
||||
# assert all([isinstance(d["input"], Data) for d in data.values()])
|
||||
# assert all([isinstance(d["target"], torch.Tensor) for d in data.values()])
|
||||
# assert all(
|
||||
# [d["input"].x.shape == torch.Size((20, 10)) for d in data.values()]
|
||||
# )
|
||||
# assert all(
|
||||
# [d["target"].shape == torch.Size((20, 10)) for d in data.values()]
|
||||
# )
|
||||
# assert all(
|
||||
# [
|
||||
# d["input"].edge_index.shape == torch.Size((2, 60))
|
||||
# for d in data.values()
|
||||
# ]
|
||||
# )
|
||||
# assert all([d["input"].edge_attr.shape[0] == 60 for d in data.values()])
|
||||
@pytest.mark.parametrize(
|
||||
"conditions_dict, max_conditions_lengths",
|
||||
[
|
||||
(conditions_dict_single, max_conditions_lengths_single),
|
||||
(conditions_dict_multi, max_conditions_lengths_multi),
|
||||
],
|
||||
)
|
||||
def test_getitem(conditions_dict, max_conditions_lengths):
|
||||
dataset = PinaDatasetFactory(
|
||||
conditions_dict,
|
||||
max_conditions_lengths=max_conditions_lengths,
|
||||
automatic_batching=True,
|
||||
)
|
||||
data = dataset[50]
|
||||
assert isinstance(data, dict)
|
||||
assert all([isinstance(d["input"], Data) for d in data.values()])
|
||||
assert all([isinstance(d["target"], torch.Tensor) for d in data.values()])
|
||||
assert all(
|
||||
[d["input"].x.shape == torch.Size((20, 10)) for d in data.values()]
|
||||
)
|
||||
assert all(
|
||||
[d["target"].shape == torch.Size((20, 10)) for d in data.values()]
|
||||
)
|
||||
assert all(
|
||||
[
|
||||
d["input"].edge_index.shape == torch.Size((2, 60))
|
||||
for d in data.values()
|
||||
]
|
||||
)
|
||||
assert all([d["input"].edge_attr.shape[0] == 60 for d in data.values()])
|
||||
|
||||
# data = dataset.fetch_from_idx_list([i for i in range(20)])
|
||||
# assert isinstance(data, dict)
|
||||
# assert all([isinstance(d["input"], Data) for d in data.values()])
|
||||
# assert all([isinstance(d["target"], torch.Tensor) for d in data.values()])
|
||||
# assert all(
|
||||
# [d["input"].x.shape == torch.Size((400, 10)) for d in data.values()]
|
||||
# )
|
||||
# assert all(
|
||||
# [d["target"].shape == torch.Size((20, 20, 10)) for d in data.values()]
|
||||
# )
|
||||
# assert all(
|
||||
# [
|
||||
# d["input"].edge_index.shape == torch.Size((2, 1200))
|
||||
# for d in data.values()
|
||||
# ]
|
||||
# )
|
||||
# assert all([d["input"].edge_attr.shape[0] == 1200 for d in data.values()])
|
||||
data = dataset.fetch_from_idx_list([i for i in range(20)])
|
||||
assert isinstance(data, dict)
|
||||
assert all([isinstance(d["input"], Data) for d in data.values()])
|
||||
assert all([isinstance(d["target"], torch.Tensor) for d in data.values()])
|
||||
assert all(
|
||||
[d["input"].x.shape == torch.Size((400, 10)) for d in data.values()]
|
||||
)
|
||||
assert all(
|
||||
[d["target"].shape == torch.Size((20, 20, 10)) for d in data.values()]
|
||||
)
|
||||
assert all(
|
||||
[
|
||||
d["input"].edge_index.shape == torch.Size((2, 1200))
|
||||
for d in data.values()
|
||||
]
|
||||
)
|
||||
assert all([d["input"].edge_attr.shape[0] == 1200 for d in data.values()])
|
||||
|
||||
|
||||
# def test_input_single_condition():
|
||||
# dataset = PinaDatasetFactory(
|
||||
# conditions_dict_single,
|
||||
# max_conditions_lengths=max_conditions_lengths_single,
|
||||
# automatic_batching=True,
|
||||
# )
|
||||
# input_ = dataset.input
|
||||
# assert isinstance(input_, dict)
|
||||
# assert isinstance(input_["data"], list)
|
||||
# assert all([isinstance(d, Data) for d in input_["data"]])
|
||||
def test_input_single_condition():
|
||||
dataset = PinaDatasetFactory(
|
||||
conditions_dict_single,
|
||||
max_conditions_lengths=max_conditions_lengths_single,
|
||||
automatic_batching=True,
|
||||
)
|
||||
input_ = dataset.input
|
||||
assert isinstance(input_, dict)
|
||||
assert isinstance(input_["data"], list)
|
||||
assert all([isinstance(d, Data) for d in input_["data"]])
|
||||
|
||||
|
||||
# def test_input_multi_condition():
|
||||
# dataset = PinaDatasetFactory(
|
||||
# conditions_dict_multi,
|
||||
# max_conditions_lengths=max_conditions_lengths_multi,
|
||||
# automatic_batching=True,
|
||||
# )
|
||||
# input_ = dataset.input
|
||||
# assert isinstance(input_, dict)
|
||||
# assert isinstance(input_["data_1"], list)
|
||||
# assert all([isinstance(d, Data) for d in input_["data_1"]])
|
||||
# assert isinstance(input_["data_2"], list)
|
||||
# assert all([isinstance(d, Data) for d in input_["data_2"]])
|
||||
def test_input_multi_condition():
|
||||
dataset = PinaDatasetFactory(
|
||||
conditions_dict_multi,
|
||||
max_conditions_lengths=max_conditions_lengths_multi,
|
||||
automatic_batching=True,
|
||||
)
|
||||
input_ = dataset.input
|
||||
assert isinstance(input_, dict)
|
||||
assert isinstance(input_["data_1"], list)
|
||||
assert all([isinstance(d, Data) for d in input_["data_1"]])
|
||||
assert isinstance(input_["data_2"], list)
|
||||
assert all([isinstance(d, Data) for d in input_["data_2"]])
|
||||
|
||||
@@ -1,86 +1,86 @@
|
||||
# import torch
|
||||
# import pytest
|
||||
# from pina.data.dataset import PinaDatasetFactory, PinaTensorDataset
|
||||
import torch
|
||||
import pytest
|
||||
from pina.data.dataset import PinaDatasetFactory, PinaTensorDataset
|
||||
|
||||
# input_tensor = torch.rand((100, 10))
|
||||
# output_tensor = torch.rand((100, 2))
|
||||
input_tensor = torch.rand((100, 10))
|
||||
output_tensor = torch.rand((100, 2))
|
||||
|
||||
# input_tensor_2 = torch.rand((50, 10))
|
||||
# output_tensor_2 = torch.rand((50, 2))
|
||||
input_tensor_2 = torch.rand((50, 10))
|
||||
output_tensor_2 = torch.rand((50, 2))
|
||||
|
||||
# conditions_dict_single = {
|
||||
# "data": {
|
||||
# "input": input_tensor,
|
||||
# "target": output_tensor,
|
||||
# }
|
||||
# }
|
||||
conditions_dict_single = {
|
||||
"data": {
|
||||
"input": input_tensor,
|
||||
"target": output_tensor,
|
||||
}
|
||||
}
|
||||
|
||||
# conditions_dict_single_multi = {
|
||||
# "data_1": {
|
||||
# "input": input_tensor,
|
||||
# "target": output_tensor,
|
||||
# },
|
||||
# "data_2": {
|
||||
# "input": input_tensor_2,
|
||||
# "target": output_tensor_2,
|
||||
# },
|
||||
# }
|
||||
conditions_dict_single_multi = {
|
||||
"data_1": {
|
||||
"input": input_tensor,
|
||||
"target": output_tensor,
|
||||
},
|
||||
"data_2": {
|
||||
"input": input_tensor_2,
|
||||
"target": output_tensor_2,
|
||||
},
|
||||
}
|
||||
|
||||
# max_conditions_lengths_single = {"data": 100}
|
||||
max_conditions_lengths_single = {"data": 100}
|
||||
|
||||
# max_conditions_lengths_multi = {"data_1": 100, "data_2": 50}
|
||||
max_conditions_lengths_multi = {"data_1": 100, "data_2": 50}
|
||||
|
||||
|
||||
# @pytest.mark.parametrize(
|
||||
# "conditions_dict, max_conditions_lengths",
|
||||
# [
|
||||
# (conditions_dict_single, max_conditions_lengths_single),
|
||||
# (conditions_dict_single_multi, max_conditions_lengths_multi),
|
||||
# ],
|
||||
# )
|
||||
# def test_constructor_tensor(conditions_dict, max_conditions_lengths):
|
||||
# dataset = PinaDatasetFactory(
|
||||
# conditions_dict,
|
||||
# max_conditions_lengths=max_conditions_lengths,
|
||||
# automatic_batching=True,
|
||||
# )
|
||||
# assert isinstance(dataset, PinaTensorDataset)
|
||||
@pytest.mark.parametrize(
|
||||
"conditions_dict, max_conditions_lengths",
|
||||
[
|
||||
(conditions_dict_single, max_conditions_lengths_single),
|
||||
(conditions_dict_single_multi, max_conditions_lengths_multi),
|
||||
],
|
||||
)
|
||||
def test_constructor_tensor(conditions_dict, max_conditions_lengths):
|
||||
dataset = PinaDatasetFactory(
|
||||
conditions_dict,
|
||||
max_conditions_lengths=max_conditions_lengths,
|
||||
automatic_batching=True,
|
||||
)
|
||||
assert isinstance(dataset, PinaTensorDataset)
|
||||
|
||||
|
||||
# def test_getitem_single():
|
||||
# dataset = PinaDatasetFactory(
|
||||
# conditions_dict_single,
|
||||
# max_conditions_lengths=max_conditions_lengths_single,
|
||||
# automatic_batching=False,
|
||||
# )
|
||||
def test_getitem_single():
|
||||
dataset = PinaDatasetFactory(
|
||||
conditions_dict_single,
|
||||
max_conditions_lengths=max_conditions_lengths_single,
|
||||
automatic_batching=False,
|
||||
)
|
||||
|
||||
# tensors = dataset.fetch_from_idx_list([i for i in range(70)])
|
||||
# assert isinstance(tensors, dict)
|
||||
# assert list(tensors.keys()) == ["data"]
|
||||
# assert sorted(list(tensors["data"].keys())) == ["input", "target"]
|
||||
# assert isinstance(tensors["data"]["input"], torch.Tensor)
|
||||
# assert tensors["data"]["input"].shape == torch.Size((70, 10))
|
||||
# assert isinstance(tensors["data"]["target"], torch.Tensor)
|
||||
# assert tensors["data"]["target"].shape == torch.Size((70, 2))
|
||||
tensors = dataset.fetch_from_idx_list([i for i in range(70)])
|
||||
assert isinstance(tensors, dict)
|
||||
assert list(tensors.keys()) == ["data"]
|
||||
assert sorted(list(tensors["data"].keys())) == ["input", "target"]
|
||||
assert isinstance(tensors["data"]["input"], torch.Tensor)
|
||||
assert tensors["data"]["input"].shape == torch.Size((70, 10))
|
||||
assert isinstance(tensors["data"]["target"], torch.Tensor)
|
||||
assert tensors["data"]["target"].shape == torch.Size((70, 2))
|
||||
|
||||
|
||||
# def test_getitem_multi():
|
||||
# dataset = PinaDatasetFactory(
|
||||
# conditions_dict_single_multi,
|
||||
# max_conditions_lengths=max_conditions_lengths_multi,
|
||||
# automatic_batching=False,
|
||||
# )
|
||||
# tensors = dataset.fetch_from_idx_list([i for i in range(70)])
|
||||
# assert isinstance(tensors, dict)
|
||||
# assert list(tensors.keys()) == ["data_1", "data_2"]
|
||||
# assert sorted(list(tensors["data_1"].keys())) == ["input", "target"]
|
||||
# assert isinstance(tensors["data_1"]["input"], torch.Tensor)
|
||||
# assert tensors["data_1"]["input"].shape == torch.Size((70, 10))
|
||||
# assert isinstance(tensors["data_1"]["target"], torch.Tensor)
|
||||
# assert tensors["data_1"]["target"].shape == torch.Size((70, 2))
|
||||
def test_getitem_multi():
|
||||
dataset = PinaDatasetFactory(
|
||||
conditions_dict_single_multi,
|
||||
max_conditions_lengths=max_conditions_lengths_multi,
|
||||
automatic_batching=False,
|
||||
)
|
||||
tensors = dataset.fetch_from_idx_list([i for i in range(70)])
|
||||
assert isinstance(tensors, dict)
|
||||
assert list(tensors.keys()) == ["data_1", "data_2"]
|
||||
assert sorted(list(tensors["data_1"].keys())) == ["input", "target"]
|
||||
assert isinstance(tensors["data_1"]["input"], torch.Tensor)
|
||||
assert tensors["data_1"]["input"].shape == torch.Size((70, 10))
|
||||
assert isinstance(tensors["data_1"]["target"], torch.Tensor)
|
||||
assert tensors["data_1"]["target"].shape == torch.Size((70, 2))
|
||||
|
||||
# assert sorted(list(tensors["data_2"].keys())) == ["input", "target"]
|
||||
# assert isinstance(tensors["data_2"]["input"], torch.Tensor)
|
||||
# assert tensors["data_2"]["input"].shape == torch.Size((50, 10))
|
||||
# assert isinstance(tensors["data_2"]["target"], torch.Tensor)
|
||||
# assert tensors["data_2"]["target"].shape == torch.Size((50, 2))
|
||||
assert sorted(list(tensors["data_2"].keys())) == ["input", "target"]
|
||||
assert isinstance(tensors["data_2"]["input"], torch.Tensor)
|
||||
assert tensors["data_2"]["input"].shape == torch.Size((50, 10))
|
||||
assert isinstance(tensors["data_2"]["target"], torch.Tensor)
|
||||
assert tensors["data_2"]["target"].shape == torch.Size((50, 2))
|
||||
|
||||
@@ -104,7 +104,7 @@ def test_advection_equation(c):
|
||||
|
||||
# Should fail if c is a list and its length != spatial dimension
|
||||
with pytest.raises(ValueError):
|
||||
equation = Advection([1, 2, 3])
|
||||
Advection([1, 2, 3])
|
||||
residual = equation.residual(pts, u)
|
||||
|
||||
|
||||
|
||||
@@ -117,10 +117,6 @@ def test_solver_train(use_lt, batch_size, compile):
|
||||
assert isinstance(solver.model, OptimizedModule)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_solver_train(use_lt=True, batch_size=20, compile=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
|
||||
@pytest.mark.parametrize("use_lt", [True, False])
|
||||
def test_solver_train_graph(batch_size, use_lt):
|
||||
|
||||
BIN
tutorials/static/pina_logo.png
vendored
BIN
tutorials/static/pina_logo.png
vendored
Binary file not shown.
|
Before Width: | Height: | Size: 411 KiB After Width: | Height: | Size: 51 KiB |
Reference in New Issue
Block a user