1 Commit

Author: FilippoOlivo
SHA1: 40747c56ff
Message: fix codacy warnings
Date: 2025-10-21 14:21:09 +02:00
37 changed files with 977 additions and 1004 deletions

View File

@@ -11,7 +11,7 @@ jobs:
strategy:
matrix:
os: [windows-latest, macos-latest, ubuntu-latest]
python-version: ['3.10', '3.11', '3.12', '3.13', '3.14']
python-version: [3.9, '3.10', '3.11', '3.12', '3.13']
steps:
- uses: actions/checkout@v2
- name: Set up Python

View File

@@ -1,7 +1,7 @@
name: "Testing Pull Request"
on:
pull_request:
pull_request_target:
branches:
- "master"
- "dev"
@@ -13,7 +13,7 @@ jobs:
fail-fast: false
matrix:
os: [windows-latest, macos-latest, ubuntu-latest]
python-version: ['3.10', '3.11', '3.12', '3.13', '3.14']
python-version: [3.9, '3.10', '3.11', '3.12', '3.13']
steps:
- uses: actions/checkout@v4

View File

@@ -23,7 +23,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
python-version: 3.9
- name: Install dependencies
run: |
@@ -91,7 +91,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
python-version: 3.9
- name: Install dependencies
run: |

View File

@@ -7,8 +7,8 @@ SPDX-License-Identifier: Apache-2.0
<table>
<tr>
<td>
<a href="readme/pina_logo.png">
<img src="readme/pina_logo.png"
<a href="https://github.com/mathLab/PINA/raw/master/readme/pina_logo.png">
<img src="https://github.com/mathLab/PINA/raw/master/readme/pina_logo.png"
alt="PINA logo"
style="width: 220px; aspect-ratio: 1 / 1; object-fit: contain;">
</a>

View File

@@ -26,7 +26,6 @@ Trainer, Dataset and Datamodule
Trainer <trainer.rst>
Dataset <data/dataset.rst>
DataModule <data/data_module.rst>
Dataloader <data/dataloader.rst>
Data Types
------------

View File

@@ -2,6 +2,14 @@ DataModule
======================
.. currentmodule:: pina.data.data_module
.. autoclass:: Collator
:members:
:show-inheritance:
.. autoclass:: PinaDataModule
:members:
:show-inheritance:
.. autoclass:: PinaSampler
:members:
:show-inheritance:

View File

@@ -1,11 +0,0 @@
Dataloader
======================
.. currentmodule:: pina.data.dataloader
.. autoclass:: PinaSampler
:members:
:show-inheritance:
.. autoclass:: PinaDataLoader
:members:
:show-inheritance:

View File

@@ -9,3 +9,11 @@ Dataset
.. autoclass:: PinaDatasetFactory
:members:
:show-inheritance:
.. autoclass:: PinaGraphDataset
:members:
:show-inheritance:
.. autoclass:: PinaTensorDataset
:members:
:show-inheritance:

Binary image file changed; preview not shown (before: 411 KiB, after: 177 KiB).

View File

@@ -5,6 +5,7 @@ from lightning.pytorch import Callback
from ..label_tensor import LabelTensor
from ..utils import check_consistency, is_function
from ..condition import InputTargetCondition
from ..data.dataset import PinaGraphDataset
class NormalizerDataCallback(Callback):
@@ -121,10 +122,7 @@ class NormalizerDataCallback(Callback):
"""
# Ensure datasets are not graph-based
if any(
ds.is_graph_dataset
for ds in trainer.datamodule.train_dataset.values()
):
if isinstance(trainer.datamodule.train_dataset, PinaGraphDataset):
raise NotImplementedError(
"NormalizerDataCallback is not compatible with "
"graph-based datasets."
@@ -166,8 +164,8 @@ class NormalizerDataCallback(Callback):
:param dataset: The `~pina.data.dataset.PinaDataset` dataset.
"""
for cond in conditions:
if cond in dataset:
data = dataset[cond].data[self.apply_to]
if cond in dataset.conditions_dict:
data = dataset.conditions_dict[cond][self.apply_to]
shift = self.shift_fn(data)
scale = self.scale_fn(data)
self._normalizer[cond] = {
@@ -199,20 +197,25 @@ class NormalizerDataCallback(Callback):
:param PinaDataset dataset: The dataset to be normalized.
"""
# Initialize update dictionary
update_dataset_dict = {}
# Iterate over conditions and apply normalization
for cond, norm_params in self.normalizer.items():
update_dataset_dict = {}
points = dataset[cond].data[self.apply_to]
points = dataset.conditions_dict[cond][self.apply_to]
scale = norm_params["scale"]
shift = norm_params["shift"]
normalized_points = self._norm_fn(points, scale, shift)
update_dataset_dict[self.apply_to] = (
LabelTensor(normalized_points, points.labels)
if isinstance(points, LabelTensor)
else normalized_points
)
dataset[cond].data.update(update_dataset_dict)
update_dataset_dict[cond] = {
self.apply_to: (
LabelTensor(normalized_points, points.labels)
if isinstance(points, LabelTensor)
else normalized_points
)
}
# Update the dataset in-place
dataset.update_data(update_dataset_dict)
@property
def normalizer(self):
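For reference, a minimal sketch of the nested update dictionary the callback now builds before calling ``dataset.update_data`` (plain tensors and made-up condition names; only ``conditions_dict`` and ``update_data`` come from the diff above):

import torch

# Stand-in for dataset.conditions_dict: one entry per condition, each mapping
# field names ("input", "target", ...) to tensors.
conditions_dict = {"data": {"input": torch.rand(8, 2), "target": torch.rand(8, 1)}}

# The callback builds {condition_name: {field_name: normalized_tensor}} ...
points = conditions_dict["data"]["input"]
shift, scale = points.mean(), points.std()
update_dataset_dict = {"data": {"input": (points - shift) / scale}}

# ... and dataset.update_data(update_dataset_dict) merges it in place, roughly:
for cond, new_data in update_dataset_dict.items():
    conditions_dict.setdefault(cond, {}).update(new_data)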

View File

@@ -133,12 +133,13 @@ class RefinementInterface(Callback, metaclass=ABCMeta):
:param PINNInterface solver: The solver object.
"""
new_points = {}
for name in self._condition_to_update:
new_points = {}
current_points = self.dataset[name].data["input"]
new_points["input"] = self.sample(current_points, name, solver)
self.dataset[name].update_data(new_points)
current_points = self.dataset.conditions_dict[name]["input"]
new_points[name] = {
"input": self.sample(current_points, name, solver)
}
self.dataset.update_data(new_points)
def _compute_population_size(self, conditions):
"""
@@ -149,5 +150,6 @@ class RefinementInterface(Callback, metaclass=ABCMeta):
:rtype: dict
"""
return {
cond: len(self.dataset[cond].data["input"]) for cond in conditions
cond: len(self.dataset.conditions_dict[cond]["input"])
for cond in conditions
}

View File

@@ -4,3 +4,4 @@ __all__ = ["PinaDataModule", "PinaDataset"]
from .data_module import PinaDataModule
from .dataset import PinaDataset

View File

@@ -7,9 +7,232 @@ different types of Datasets defined in PINA.
import warnings
from lightning.pytorch import LightningDataModule
import torch
from torch_geometric.data import Data
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from ..label_tensor import LabelTensor
from .dataset import PinaDatasetFactory
from .dataloader import PinaDataLoader
from .dataset import PinaDatasetFactory, PinaTensorDataset
class DummyDataloader:
def __init__(self, dataset):
"""
Prepare a dataloader object that returns the entire dataset in a single
batch. Depending on the number of GPUs, the dataset is managed
as follows:
- **Distributed Environment** (multiple GPUs): Divides dataset across
processes using the rank and world size. Fetches only the portion of
data corresponding to the current process.
- **Non-Distributed Environment** (single GPU): Fetches the entire
dataset.
:param PinaDataset dataset: The dataset object to be processed.
.. note::
This dataloader is used when the batch size is ``None``.
"""
if (
torch.distributed.is_available()
and torch.distributed.is_initialized()
):
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
if len(dataset) < world_size:
raise RuntimeError(
"Dimension of the dataset smaller than world size."
" Increase the size of the partition or use a single GPU"
)
idx, i = [], rank
while i < len(dataset):
idx.append(i)
i += world_size
self.dataset = dataset.fetch_from_idx_list(idx)
else:
self.dataset = dataset.get_all_data()
def __iter__(self):
return self
def __len__(self):
return 1
def __next__(self):
return self.dataset
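The distributed branch above simply strides the indices by the world size, starting at the process rank; a self-contained sketch with made-up numbers:

def strided_indices(dataset_len, rank, world_size):
    # Same selection as the while-loop above: rank, rank + world_size, ...
    return list(range(rank, dataset_len, world_size))

# With 10 samples split over 4 processes, rank 1 fetches indices [1, 5, 9].
print(strided_indices(10, rank=1, world_size=4))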
class Collator:
"""
This callable class is used to collate the data points fetched from the
dataset. The collation is performed based on the type of dataset used and
on the batching strategy.
"""
def __init__(
self, max_conditions_lengths, automatic_batching, dataset=None
):
"""
Initialize the object, setting the collate function based on whether
automatic batching is enabled or not.
:param dict max_conditions_lengths: ``dict`` containing the maximum
number of data points to consider in a single batch for
each condition.
:param bool automatic_batching: Whether automatic PyTorch batching is
enabled or not. For more information, see the
:class:`~pina.data.data_module.PinaDataModule` class.
:param PinaDataset dataset: The dataset where the data is stored.
"""
self.max_conditions_lengths = max_conditions_lengths
# Set the collate function based on the batching strategy
# collate_pina_dataloader is used when automatic batching is disabled
# collate_torch_dataloader is used when automatic batching is enabled
self.callable_function = (
self._collate_torch_dataloader
if automatic_batching
else (self._collate_pina_dataloader)
)
self.dataset = dataset
# Set the function which performs the actual collation
if isinstance(self.dataset, PinaTensorDataset):
# If the dataset is a PinaTensorDataset, use this collate function
self._collate = self._collate_tensor_dataset
else:
# If the dataset is a PinaDataset, use this collate function
self._collate = self._collate_graph_dataset
def _collate_pina_dataloader(self, batch):
"""
Function used to create a batch when automatic batching is disabled.
:param list[int] batch: List of integers representing the indices of
the data points to be fetched.
:return: Dictionary containing the data points fetched from the dataset.
:rtype: dict
"""
# Call the fetch_from_idx_list method of the dataset
return self.dataset.fetch_from_idx_list(batch)
def _collate_torch_dataloader(self, batch):
"""
Function used to collate the batch.
:param list[dict] batch: List of retrieved data.
:return: Dictionary containing the data points fetched from the dataset,
collated.
:rtype: dict
"""
batch_dict = {}
if isinstance(batch, dict):
return batch
conditions_names = batch[0].keys()
# Condition names
for condition_name in conditions_names:
single_cond_dict = {}
condition_args = batch[0][condition_name].keys()
for arg in condition_args:
data_list = [
batch[idx][condition_name][arg]
for idx in range(
min(
len(batch),
self.max_conditions_lengths[condition_name],
)
)
]
single_cond_dict[arg] = self._collate(data_list)
batch_dict[condition_name] = single_cond_dict
return batch_dict
@staticmethod
def _collate_tensor_dataset(data_list):
"""
Function used to collate the data when the dataset is a
:class:`~pina.data.dataset.PinaTensorDataset`.
:param data_list: Elements to be collated.
:type data_list: list[torch.Tensor] | list[LabelTensor]
:return: Batch of data.
:rtype: dict
:raises RuntimeError: If the data is not a :class:`torch.Tensor` or a
:class:`~pina.label_tensor.LabelTensor`.
"""
if isinstance(data_list[0], LabelTensor):
return LabelTensor.stack(data_list)
if isinstance(data_list[0], torch.Tensor):
return torch.stack(data_list)
raise RuntimeError("Data must be Tensors or LabelTensor ")
def _collate_graph_dataset(self, data_list):
"""
Function used to collate data when the dataset is a
:class:`~pina.data.dataset.PinaGraphDataset`.
:param data_list: Elements to be collated.
:type data_list: list[Data] | list[Graph]
:return: Batch of data.
:rtype: dict
:raises RuntimeError: If the data is not a
:class:`~torch_geometric.data.Data` or a :class:`~pina.graph.Graph`.
"""
if isinstance(data_list[0], LabelTensor):
return LabelTensor.cat(data_list)
if isinstance(data_list[0], torch.Tensor):
return torch.cat(data_list)
if isinstance(data_list[0], Data):
return self.dataset.create_batch(data_list)
raise RuntimeError(
"Data must be Tensors or LabelTensor or pyG "
"torch_geometric.data.Data"
)
def __call__(self, batch):
"""
Perform the collation of data fetched from the dataset. The behavior
of the function is set based on the batching strategy during class
initialization.
:param batch: List of retrieved data or sampled indices.
:type batch: list[int] | list[dict]
:return: Dictionary containing collected data fetched from the dataset.
:rtype: dict
"""
return self.callable_function(batch)
class PinaSampler:
"""
This class is used to create the sampler instance based on the shuffle
parameter and the environment in which the code is running.
"""
def __new__(cls, dataset):
"""
Instantiate and initialize the sampler.
:param PinaDataset dataset: The dataset from which to sample.
:return: The sampler instance.
:rtype: :class:`torch.utils.data.Sampler`
"""
if (
torch.distributed.is_available()
and torch.distributed.is_initialized()
):
sampler = DistributedSampler(dataset)
else:
sampler = SequentialSampler(dataset)
return sampler
class PinaDataModule(LightningDataModule):
@@ -27,7 +250,7 @@ class PinaDataModule(LightningDataModule):
val_size=0.1,
batch_size=None,
shuffle=True,
batching_mode="common_batch_size",
repeat=False,
automatic_batching=None,
num_workers=0,
pin_memory=False,
@@ -48,12 +271,11 @@ class PinaDataModule(LightningDataModule):
Default is ``None``.
:param bool shuffle: Whether to shuffle the dataset before splitting.
Default ``True``.
:param bool common_batch_size: If ``True``, the same batch size is used
for all conditions. If ``False``, each condition can have its own
batch size, proportional to the size of the dataset in that
condition. Default is ``True``.
:param bool separate_conditions: If ``True``, dataloaders for each
condition are iterated separately. Default is ``False``.
:param bool repeat: If ``True``, in case of batch size larger than the
number of elements in a specific condition, the elements are
repeated until the batch size is reached. If ``False``, the number
of elements in the batch is the minimum between the batch size and
the number of elements in the condition. Default is ``False``.
:param automatic_batching: If ``True``, automatic PyTorch batching
is performed, which consists of extracting one element at a time
from the dataset and collating them into a batch. This is useful
@@ -83,7 +305,7 @@ class PinaDataModule(LightningDataModule):
# Store fixed attributes
self.batch_size = batch_size
self.shuffle = shuffle
self.batching_mode = batching_mode
self.repeat = repeat
self.automatic_batching = automatic_batching
# If batch size is None, num_workers has no effect
@@ -154,16 +376,23 @@ class PinaDataModule(LightningDataModule):
if stage == "fit" or stage is None:
self.train_dataset = PinaDatasetFactory(
self.data_splits["train"],
max_conditions_lengths=self.find_max_conditions_lengths(
"train"
),
automatic_batching=self.automatic_batching,
)
if "val" in self.data_splits.keys():
self.val_dataset = PinaDatasetFactory(
self.data_splits["val"],
max_conditions_lengths=self.find_max_conditions_lengths(
"val"
),
automatic_batching=self.automatic_batching,
)
elif stage == "test":
self.test_dataset = PinaDatasetFactory(
self.data_splits["test"],
max_conditions_lengths=self.find_max_conditions_lengths("test"),
automatic_batching=self.automatic_batching,
)
else:
@@ -253,7 +482,7 @@ class PinaDataModule(LightningDataModule):
dataset_dict[key].update({condition_name: data})
return dataset_dict
def _create_dataloader(self, dataset):
def _create_dataloader(self, split, dataset):
""" "
Create the dataloader for the given split.
@@ -273,18 +502,53 @@ class PinaDataModule(LightningDataModule):
),
module="lightning.pytorch.trainer.connectors.data_connector",
)
dl = PinaDataLoader(
dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
batching_mode=self.batching_mode,
device=self.trainer.strategy.root_device,
# Use custom batching (good if batch size is large)
if self.batch_size is not None:
sampler = PinaSampler(dataset)
if self.automatic_batching:
collate = Collator(
self.find_max_conditions_lengths(split),
self.automatic_batching,
dataset=dataset,
)
else:
collate = Collator(
None, self.automatic_batching, dataset=dataset
)
return DataLoader(
dataset,
self.batch_size,
collate_fn=collate,
sampler=sampler,
num_workers=self.num_workers,
)
dataloader = DummyDataloader(dataset)
dataloader.dataset = self._transfer_batch_to_device(
dataloader.dataset, self.trainer.strategy.root_device, 0
)
if self.batch_size is None:
# Override the method to transfer the batch to the device
self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
return dl
self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
return dataloader
def find_max_conditions_lengths(self, split):
"""
Define the maximum length for each condition.
:param str split: The name of the dataset split.
:return: The maximum length per condition.
:rtype: dict
"""
max_conditions_lengths = {}
for k, v in self.data_splits[split].items():
if self.batch_size is None:
max_conditions_lengths[k] = len(v["input"])
elif self.repeat:
max_conditions_lengths[k] = self.batch_size
else:
max_conditions_lengths[k] = min(
len(v["input"]), self.batch_size
)
return max_conditions_lengths
def val_dataloader(self):
"""
@@ -293,7 +557,7 @@ class PinaDataModule(LightningDataModule):
:return: The validation dataloader
:rtype: torch.utils.data.DataLoader
"""
return self._create_dataloader(self.val_dataset)
return self._create_dataloader("val", self.val_dataset)
def train_dataloader(self):
"""
@@ -302,7 +566,7 @@ class PinaDataModule(LightningDataModule):
:return: The training dataloader
:rtype: torch.utils.data.DataLoader
"""
return self._create_dataloader(self.train_dataset)
return self._create_dataloader("train", self.train_dataset)
def test_dataloader(self):
"""
@@ -311,7 +575,7 @@ class PinaDataModule(LightningDataModule):
:return: The testing dataloader
:rtype: torch.utils.data.DataLoader
"""
return self._create_dataloader(self.test_dataset)
return self._create_dataloader("test", self.test_dataset)
@staticmethod
def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
@@ -327,7 +591,7 @@ class PinaDataModule(LightningDataModule):
:rtype: list[tuple]
"""
return list(batch.items())
return batch
def _transfer_batch_to_device(self, batch, device, dataloader_idx):
"""
@@ -385,15 +649,9 @@ class PinaDataModule(LightningDataModule):
to_return = {}
if hasattr(self, "train_dataset") and self.train_dataset is not None:
to_return["train"] = {
cond: data.input for cond, data in self.train_dataset.items()
}
to_return["train"] = self.train_dataset.input
if hasattr(self, "val_dataset") and self.val_dataset is not None:
to_return["val"] = {
cond: data.input for cond, data in self.val_dataset.items()
}
to_return["val"] = self.val_dataset.input
if hasattr(self, "test_dataset") and self.test_dataset is not None:
to_return["test"] = {
cond: data.input for cond, data in self.test_dataset.items()
}
to_return["test"] = self.test_dataset.input
return to_return
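To make the new ``repeat`` flag concrete, a rough re-implementation of the per-condition cap computed by ``find_max_conditions_lengths`` (condition sizes are invented for the example):

def max_conditions_lengths(condition_sizes, batch_size, repeat):
    # Mirrors the logic above: with repeat=True a short condition may fill the
    # whole batch (its indices are reused), otherwise it contributes at most
    # as many points as it owns; batch_size=None means "take everything".
    if batch_size is None:
        return dict(condition_sizes)
    return {
        name: batch_size if repeat else min(size, batch_size)
        for name, size in condition_sizes.items()
    }

sizes = {"physics": 1000, "boundary": 16}
print(max_conditions_lengths(sizes, batch_size=64, repeat=False))  # boundary -> 16
print(max_conditions_lengths(sizes, batch_size=64, repeat=True))   # boundary -> 64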

View File

@@ -1,347 +0,0 @@
"""DataLoader module for PinaDataset."""
import itertools
import random
from functools import partial
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import SequentialSampler
from .stacked_dataloader import StackedDataLoader
class DummyDataloader:
"""
DataLoader that returns the entire dataset in a single batch.
"""
def __init__(self, dataset, device=None):
"""
Prepare a dataloader object that returns the entire dataset in a single
batch. Depending on the number of GPUs, the dataset is managed
as follows:
- **Distributed Environment** (multiple GPUs): Divides dataset across
processes using the rank and world size. Fetches only portion of
data corresponding to the current process.
- **Non-Distributed Environment** (single GPU): Fetches the entire
dataset.
:param PinaDataset dataset: The dataset object to be processed.
.. note::
This dataloader is used when the batch size is ``None``.
"""
# Handle distributed environment
if PinaSampler.is_distributed():
# Get rank and world size
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
# Ensure dataset is large enough
if len(dataset) < world_size:
raise RuntimeError(
"Dimension of the dataset smaller than world size."
" Increase the size of the partition or use a single GPU"
)
# Split dataset among processes
idx, i = [], rank
while i < len(dataset):
idx.append(i)
i += world_size
else:
idx = list(range(len(dataset)))
self.dataset = dataset.getitem_from_list(idx)
self.device = device
self.dataset = (
{k: v.to(self.device) for k, v in self.dataset.items()}
if self.device
else self.dataset
)
def __iter__(self):
"""
Iterate over the dataloader.
"""
return self
def __len__(self):
"""
Return the length of the dataloader, which is always 1.
:return: The length of the dataloader.
:rtype: int
"""
return 1
def __next__(self):
"""
Return the entire dataset as a single batch.
:return: The entire dataset.
:rtype: dict
"""
return self.dataset
class PinaSampler:
"""
This class is used to create the sampler instance based on the shuffle
parameter and the environment in which the code is running.
"""
def __new__(cls, dataset, shuffle=True):
"""
Instantiate and initialize the sampler.
:param PinaDataset dataset: The dataset from which to sample.
:return: The sampler instance.
:rtype: :class:`torch.utils.data.Sampler`
"""
if cls.is_distributed():
sampler = DistributedSampler(dataset, shuffle=shuffle)
else:
if shuffle:
sampler = torch.utils.data.RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
return sampler
@staticmethod
def is_distributed():
"""
Check if the sampler is distributed.
:return: True if the sampler is distributed, False otherwise.
:rtype: bool
"""
return (
torch.distributed.is_available()
and torch.distributed.is_initialized()
)
def _collect_items(batch):
"""
Helper function to collect items from a batch of graph data samples.
:param batch: List of graph data samples.
"""
to_return = {name: [] for name in batch[0].keys()}
for sample in batch:
for k, v in sample.items():
to_return[k].append(v)
return to_return
def collate_fn_custom(batch, dataset):
"""
Override the default collate function to handle datasets without automatic
batching.
:param batch: List of indices from the dataset.
:param dataset: The PinaDataset instance (must be provided).
"""
return dataset.getitem_from_list(batch)
def collate_fn_default(batch, stack_fn):
"""
Default collate function that simply returns the batch as is.
:param batch: List of data samples.
"""
to_return = _collect_items(batch)
return {k: stack_fn[k](v) for k, v in to_return.items()}
class PinaDataLoader:
"""
Custom DataLoader for PinaDataset.
"""
def __new__(cls, *args, **kwargs):
batching_mode = kwargs.get("batching_mode", "common_batch_size").lower()
batch_size = kwargs.get("batch_size")
if batching_mode == "stacked" and batch_size is not None:
return StackedDataLoader(
args[0],
batch_size=batch_size,
shuffle=kwargs.get("shuffle", True),
)
elif batch_size is None:
kwargs["batching_mode"] = "proportional"
print(
"Using PinaDataLoader with batching mode:", kwargs["batching_mode"]
)
return super(PinaDataLoader, cls).__new__(cls)
def __init__(
self,
dataset_dict,
batch_size,
num_workers=0,
shuffle=False,
batching_mode="common_batch_size",
device=None,
):
"""
Initialize the PinaDataLoader.
:param dict dataset_dict: A dictionary mapping dataset names to their
respective PinaDataset instances.
:param int batch_size: The batch size for the dataloader.
:param int num_workers: Number of worker processes for data loading.
:param bool shuffle: Whether to shuffle the data at every epoch.
:param str batching_mode: The batching mode to use. Options are
"common_batch_size", "separate_conditions", and "proportional".
:param device: The device to which the data should be moved.
"""
self.dataset_dict = dataset_dict
self.batch_size = batch_size
self.num_workers = num_workers
self.shuffle = shuffle
self.batching_mode = batching_mode.lower()
self.device = device
# Batch size None means we want to load the entire dataset in a single
# batch
if batch_size is None:
batch_size_per_dataset = {
split: None for split in dataset_dict.keys()
}
else:
# Compute batch size per dataset
if batching_mode in ["common_batch_size", "separate_conditions"]:
# (the sum of the batch sizes is equal to
# n_conditions * batch_size)
batch_size_per_dataset = {
split: min(batch_size, len(ds))
for split, ds in dataset_dict.items()
}
elif batching_mode == "proportional":
# batch sizes is equal to the specified batch size)
batch_size_per_dataset = self._compute_batch_size()
# Creaete a dataloader per dataset
self.dataloaders = {
split: self._create_dataloader(
dataset, batch_size_per_dataset[split]
)
for split, dataset in dataset_dict.items()
}
def _compute_batch_size(self):
"""
Compute an appropriate batch size for the given dataset.
"""
# Compute number of elements per dataset
elements_per_dataset = {
dataset_name: len(dataset)
for dataset_name, dataset in self.dataset_dict.items()
}
# Compute the total number of elements
total_elements = sum(el for el in elements_per_dataset.values())
# Compute the portion of each dataset
portion_per_dataset = {
name: el / total_elements
for name, el in elements_per_dataset.items()
}
# Compute batch size per dataset. Ensure at least 1 element per
# dataset.
batch_size_per_dataset = {
name: max(1, int(portion * self.batch_size))
for name, portion in portion_per_dataset.items()
}
# Adjust batch sizes to match the specified total batch size
tot_el_per_batch = sum(el for el in batch_size_per_dataset.values())
if self.batch_size > tot_el_per_batch:
difference = self.batch_size - tot_el_per_batch
while difference > 0:
for k, v in batch_size_per_dataset.items():
if difference == 0:
break
if v > 1:
batch_size_per_dataset[k] += 1
difference -= 1
if self.batch_size < tot_el_per_batch:
difference = tot_el_per_batch - self.batch_size
while difference > 0:
for k, v in batch_size_per_dataset.items():
if difference == 0:
break
if v > 1:
batch_size_per_dataset[k] -= 1
difference -= 1
return batch_size_per_dataset
def _create_dataloader(self, dataset, batch_size):
"""
Create the dataloader for the given dataset.
:param PinaDataset dataset: The dataset for which to create the
dataloader.
:param int batch_size: The batch size for the dataloader.
:return: The created dataloader.
:rtype: :class:`torch.utils.data.DataLoader`
"""
# If batch size is None, use DummyDataloader
if batch_size is None or batch_size >= len(dataset):
return DummyDataloader(dataset, device=self.device)
# Determine the appropriate collate function
if not dataset.automatic_batching:
collate_fn = partial(collate_fn_custom, dataset=dataset)
else:
collate_fn = partial(collate_fn_default, stack_fn=dataset.stack_fn)
# Create and return the dataloader
return DataLoader(
dataset,
batch_size=batch_size,
collate_fn=collate_fn,
num_workers=self.num_workers,
sampler=PinaSampler(dataset, shuffle=self.shuffle),
)
def __len__(self):
"""
Return the length of the dataloader.
:return: The length of the dataloader.
:rtype: int
"""
# If separate conditions, return sum of lengths of all dataloaders
# else, return max length among dataloaders
if self.batching_mode == "separate_conditions":
return sum(len(dl) for dl in self.dataloaders.values())
return max(len(dl) for dl in self.dataloaders.values())
def __iter__(self):
"""
Iterate over the dataloader. Yields a dictionary mapping split name to batch.
The iteration logic for 'separate_conditions' is now iterative and memory-efficient.
"""
if self.batching_mode == "separate_conditions":
tmp = []
for split, dl in self.dataloaders.items():
len_split = len(dl)
for i, batch in enumerate(dl):
tmp.append({split: batch})
if i + 1 >= len_split:
break
random.shuffle(tmp)
for batch_dict in tmp:
yield batch_dict
return
# Common_batch_size or Proportional mode (round-robin sampling)
iterators = {
split: itertools.cycle(dl) for split, dl in self.dataloaders.items()
}
# Iterate for the length of the longest dataloader
for _ in range(len(self)):
batch_dict: BatchDict = {}
for split, it in iterators.items():
# Since we use itertools.cycle, next(it) will always yield a batch
# by repeating the dataset, so no need for the 'if batch is None: return' check.
batch_dict[split] = next(it)
yield batch_dict

View File

@@ -1,170 +1,326 @@
"""Module for the PINA dataset classes."""
import torch
from abc import abstractmethod, ABC
from torch.utils.data import Dataset
from torch_geometric.data import Data
from ..graph import Graph, LabelBatch
from ..label_tensor import LabelTensor
STACK_FN_MAP = {
"label_tensor": LabelTensor.stack,
"tensor": torch.stack,
"data": LabelBatch.from_data_list,
}
class PinaDatasetFactory:
"""
Factory class to create PINA datasets based on the provided conditions
dictionary.
Factory class for the PINA dataset.
Depending on the data type inside the conditions, it instantiates an object
belonging to the appropriate subclass of
:class:`~pina.data.dataset.PinaDataset`. The possible subclasses are:
- :class:`~pina.data.dataset.PinaTensorDataset`, for handling \
:class:`torch.Tensor` and :class:`~pina.label_tensor.LabelTensor` data.
- :class:`~pina.data.dataset.PinaGraphDataset`, for handling \
:class:`~pina.graph.Graph` and :class:`~torch_geometric.data.Data` data.
"""
def __new__(cls, conditions_dict, **kwargs):
"""
Create PINA dataset instances based on the provided conditions
dictionary.
Instantiate the appropriate subclass of
:class:`~pina.data.dataset.PinaDataset`.
:param dict conditions_dict: A dictionary where keys are condition names
and values are dictionaries containing the associated data.
:return: A dictionary mapping condition names to their respective
:class:`PinaDataset` instances.
If a graph is present in the conditions, returns a
:class:`~pina.data.dataset.PinaGraphDataset`, otherwise returns a
:class:`~pina.data.dataset.PinaTensorDataset`.
:param dict conditions_dict: Dictionary containing all the conditions
to be included in the dataset instance.
:return: A subclass of :class:`~pina.data.dataset.PinaDataset`.
:rtype: PinaTensorDataset | PinaGraphDataset
:raises ValueError: If an empty dictionary is provided.
"""
# Check if conditions_dict is empty
if len(conditions_dict) == 0:
raise ValueError("No conditions provided")
dataset_dict = {} # Dictionary to hold the created datasets
# Check if a Graph is present in the conditions
for name, data in conditions_dict.items():
# Validate that data is a dictionary
if not isinstance(data, dict):
raise ValueError(
f"Condition '{name}' data must be a dictionary"
)
# Create PinaDataset instance for each condition
dataset_dict[name] = PinaDataset(data, **kwargs)
return dataset_dict
is_graph = cls._is_graph_dataset(conditions_dict)
if is_graph:
# If a Graph is present, return a PinaGraphDataset
return PinaGraphDataset(conditions_dict, **kwargs)
# If no Graph is present, return a PinaTensorDataset
return PinaTensorDataset(conditions_dict, **kwargs)
class PinaDataset(Dataset):
"""
Dataset class for the PINA dataset with :class:`torch.Tensor` and
:class:`~pina.label_tensor.LabelTensor` data.
"""
def __init__(self, data_dict, automatic_batching=None):
@staticmethod
def _is_graph_dataset(conditions_dict):
"""
Initialize the instance by storing the conditions dictionary.
Check if a graph is present in the conditions (at least once).
:param conditions_dict: Dictionary containing the conditions.
:type conditions_dict: dict
:return: True if a graph is present in the conditions, False otherwise.
:rtype: bool
"""
# Iterate over the conditions dictionary
for v in conditions_dict.values():
# Iterate over the values of the current condition
for cond in v.values():
# Check if the current value is a list of Data objects
if isinstance(cond, (Data, Graph, list, tuple)):
return True
return False
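A hedged usage sketch of the factory (the constructor arguments follow the ``setup`` call in the data module; the tensors and the condition name are illustrative):

import torch
from pina.data.dataset import PinaDatasetFactory, PinaTensorDataset

conditions = {"data": {"input": torch.rand(4, 2), "target": torch.rand(4, 1)}}
dataset = PinaDatasetFactory(
    conditions,
    max_conditions_lengths={"data": 4},
    automatic_batching=False,
)
# No Graph/Data values in the conditions, so the factory returns a tensor dataset.
assert isinstance(dataset, PinaTensorDataset)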
class PinaDataset(Dataset, ABC):
"""
Abstract class for the PINA dataset which extends the PyTorch
:class:`~torch.utils.data.Dataset` class. It defines the common interface
for :class:`~pina.data.dataset.PinaTensorDataset` and
:class:`~pina.data.dataset.PinaGraphDataset` classes.
"""
def __init__(
self, conditions_dict, max_conditions_lengths, automatic_batching
):
"""
Initialize the instance by storing the conditions dictionary, the
maximum number of items per conditions to consider, and the automatic
batching flag.
:param dict conditions_dict: A dictionary mapping condition names to
their respective data. Each key represents a condition name, and the
corresponding value is a dictionary containing the associated data.
:param dict max_conditions_lengths: Maximum number of data points that
can be included in a single batch per condition.
:param bool automatic_batching: Indicates whether PyTorch automatic
batching is enabled in
:class:`~pina.data.data_module.PinaDataModule`.
"""
# Store the conditions dictionary
self.data = data_dict
self.automatic_batching = (
automatic_batching if automatic_batching is not None else True
)
self._stack_fn = {}
self.is_graph_dataset = False
# Determine stacking functions for each data type (used in collate_fn)
for k, v in data_dict.items():
if isinstance(v, LabelTensor):
self._stack_fn[k] = "label_tensor"
elif isinstance(v, torch.Tensor):
self._stack_fn[k] = "tensor"
elif isinstance(v, list) and all(
isinstance(item, (Data, Graph)) for item in v
):
self._stack_fn[k] = "data"
self.is_graph_dataset = True
else:
raise ValueError(
f"Unsupported data type for stacking: {type(v)}"
)
self.conditions_dict = conditions_dict
# Store the maximum number of conditions to consider
self.max_conditions_lengths = max_conditions_lengths
# Store length of each condition
self.conditions_length = {
k: len(v["input"]) for k, v in self.conditions_dict.items()
}
# Store the maximum length of the dataset
self.length = max(self.conditions_length.values())
# Dynamically set the getitem function based on automatic batching
if automatic_batching:
self._getitem_func = self._getitem_int
else:
self._getitem_func = self._getitem_dummy
def __len__(self):
def _get_max_len(self):
"""
Return the length of the dataset.
Returns the length of the longest condition in the dataset.
:return: The length of the dataset.
:return: Length of the longest condition in the dataset.
:rtype: int
"""
return len(next(iter(self.data.values())))
max_len = 0
for condition in self.conditions_dict.values():
max_len = max(max_len, len(condition["input"]))
return max_len
def __len__(self):
return self.length
def __getitem__(self, idx):
return self._getitem_func(idx)
def _getitem_dummy(self, idx):
"""
Return the data at the given index in the dataset.
Return the index itself. This is used when automatic batching is
disabled to postpone the data retrieval to the dataloader.
:param int idx: Index.
:return: Index.
:rtype: int
"""
# If automatic batching is disabled, return the data at the given index
return idx
def _getitem_int(self, idx):
"""
Return the data at the given index in the dataset. This is used when
automatic batching is enabled.
:param int idx: Index.
:return: A dictionary containing the data at the given index.
:rtype: dict
"""
if self.automatic_batching:
# Return the data at the given index
return {
field_name: data[idx] for field_name, data in self.data.items()
}
return idx
# If automatic batching is enabled, return the data at the given index
return {
k: {k_data: v[k_data][idx % len(v["input"])] for k_data in v.keys()}
for k, v in self.conditions_dict.items()
}
def getitem_from_list(self, idx_list):
def get_all_data(self):
"""
Return all data in the dataset.
:return: A dictionary containing all the data in the dataset.
:rtype: dict
"""
to_return_dict = {}
for condition, data in self.conditions_dict.items():
len_condition = len(
data["input"]
) # Length of the current condition
to_return_dict[condition] = self._retrive_data(
data, list(range(len_condition))
) # Retrieve the data from the current condition
return to_return_dict
def fetch_from_idx_list(self, idx):
"""
Return data from the dataset given a list of indices.
:param list[int] idx_list: List of indices.
:param list[int] idx: List of indices.
:return: A dictionary containing the data at the given indices.
:rtype: dict
"""
to_return = {}
for field_name, data in self.data.items():
if self._stack_fn[field_name] == "data":
fn = STACK_FN_MAP[self._stack_fn[field_name]]
to_return[field_name] = fn([data[i] for i in idx_list])
else:
to_return[field_name] = data[idx_list]
return to_return
to_return_dict = {}
for condition, data in self.conditions_dict.items():
# Get the indices for the current condition
cond_idx = idx[: self.max_conditions_lengths[condition]]
# Get the length of the current condition
condition_len = self.conditions_length[condition]
# If the length of the dataset is greater than the length of the
# current condition, repeat the indices
if self.length > condition_len:
cond_idx = [idx % condition_len for idx in cond_idx]
# Retrieve the data from the current condition
to_return_dict[condition] = self._retrive_data(data, cond_idx)
return to_return_dict
def update_data(self, update_dict):
@abstractmethod
def _retrive_data(self, data, idx_list):
"""
Abstract method to retrieve data from the dataset given a list of
indices.
"""
Update the dataset's data in-place.
:param dict update_dict: A dictionary where keys are condition names
and values are dictionaries with updated data for those conditions.
class PinaTensorDataset(PinaDataset):
"""
Dataset class for the PINA dataset with :class:`torch.Tensor` and
:class:`~pina.label_tensor.LabelTensor` data.
"""
# Override _retrive_data method for torch.Tensor data
def _retrive_data(self, data, idx_list):
"""
for field_name, updates in update_dict.items():
if field_name not in self.data:
raise KeyError(
f"Condition '{field_name}' not found in dataset."
)
if not isinstance(updates, (LabelTensor, torch.Tensor)):
raise ValueError(
f"Updates for condition '{field_name}' must be of type "
f"LabelTensor or torch.Tensor."
)
self.data[field_name] = updates
Retrieve data from the dataset given a list of indices.
:param dict data: Dictionary containing the data
(only :class:`torch.Tensor` or
:class:`~pina.label_tensor.LabelTensor`).
:param list[int] idx_list: indices to retrieve.
:return: Dictionary containing the data at the given indices.
:rtype: dict
"""
return {k: v[idx_list] for k, v in data.items()}
@property
def input(self):
"""
Get the input data from the dataset.
Return the input data for the dataset.
:return: The input data.
:rtype: torch.Tensor | LabelTensor | Data | Graph
"""
return self.data["input"]
@property
def stack_fn(self):
"""
Get the mapping of stacking functions for each data type in the dataset.
:return: A dictionary mapping condition names to their respective
stacking function identifiers.
:return: Dictionary containing the input points.
:rtype: dict
"""
return {k: STACK_FN_MAP[v] for k, v in self._stack_fn.items()}
return {k: v["input"] for k, v in self.conditions_dict.items()}
def update_data(self, new_conditions_dict):
"""
Update the dataset with new data, replacing the current data with the
data provided in the ``new_conditions_dict`` parameter.
:param dict new_conditions_dict: Dictionary containing the new data.
:return: None
"""
for condition, data in new_conditions_dict.items():
if condition in self.conditions_dict:
self.conditions_dict[condition].update(data)
else:
self.conditions_dict[condition] = data
class PinaGraphDataset(PinaDataset):
"""
Dataset class for the PINA dataset with :class:`~torch_geometric.data.Data`
and :class:`~pina.graph.Graph` data.
"""
def _create_graph_batch(self, data):
"""
Create a LabelBatch object from a list of
:class:`~torch_geometric.data.Data` objects.
:param data: List of items to collate in a single batch.
:type data: list[Data] | list[Graph]
:return: LabelBatch object with all the graphs collated into a single
batch of disconnected graphs.
:rtype: LabelBatch
"""
batch = LabelBatch.from_data_list(data)
return batch
def create_batch(self, data):
"""
Create a Batch object from a list of :class:`~torch_geometric.data.Data`
objects.
:param data: List of items to collate in a single batch.
:type data: list[Data] | list[Graph]
:return: Batch object.
:rtype: :class:`~torch_geometric.data.Batch`
| :class:`~pina.graph.LabelBatch`
"""
if isinstance(data[0], Data):
return self._create_graph_batch(data)
return self._create_tensor_batch(data)
# Override _retrive_data method for graph handling
def _retrive_data(self, data, idx_list):
"""
Retrieve data from the dataset given a list of indices.
:param dict data: Dictionary containing the data.
:param list[int] idx_list: List of indices to retrieve.
:return: Dictionary containing the data at the given indices.
:rtype: dict
"""
# Return the data from the current condition
# If the data is a list of Data objects, create a Batch object
# If the data is a list of torch.Tensor objects, create a torch.Tensor
return {
k: (
self._create_graph_batch([v[i] for i in idx_list])
if isinstance(v, list)
else v[idx_list]
)
for k, v in data.items()
}
@property
def input(self):
"""
Return the input data for the dataset.
:return: Dictionary containing the input points.
:rtype: dict
"""
return {k: v["input"] for k, v in self.conditions_dict.items()}

View File

@@ -1,53 +0,0 @@
import torch
from math import ceil
class StackedDataLoader:
def __init__(self, datasets, batch_size=32, shuffle=True):
for d in datasets.values():
if d.is_graph_dataset:
raise ValueError("Each dataset must be a dictionary")
self.chunks = {}
self.total_length = 0
self.indices = []
self._init_chunks(datasets)
self.indices = list(range(self.total_length))
self.batch_size = batch_size
self.shuffle = shuffle
if self.shuffle:
torch.random.manual_seed(42)
self.indices = torch.randperm(self.total_length).tolist()
self.datasets = datasets
def _init_chunks(self, datasets):
inc = 0
total_length = 0
for name, dataset in datasets.items():
self.chunks[name] = {"start": inc, "end": inc + len(dataset)}
inc += len(dataset)
self.total_length = inc
def __len__(self):
return ceil(self.total_length / self.batch_size)
def _build_batch_indices(self, batch_idx):
start = batch_idx * self.batch_size
end = min(start + self.batch_size, self.total_length)
return self.indices[start:end]
def __iter__(self):
for batch_idx in range(len(self)):
batch_indices = self._build_batch_indices(batch_idx)
batch_data = {}
for name, chunk in self.chunks.items():
local_indices = [
idx - chunk["start"]
for idx in batch_indices
if chunk["start"] <= idx < chunk["end"]
]
if local_indices:
batch_data[name] = self.datasets[name].getitem_from_list(
local_indices
)
yield batch_data

View File

@@ -1,6 +1,7 @@
"""Module for the Equation."""
import inspect
from .equation_interface import EquationInterface
@@ -48,10 +49,6 @@ class Equation(EquationInterface):
:raises RuntimeError: If the underlying equation signature length is not
2 (direct problem) or 3 (inverse problem).
"""
# Move the equation to the input_ device
self.to(input_.device)
# Call the underlying equation based on its signature length
if self.__len_sig == 2:
return self.__equation(input_, output_)
if self.__len_sig == 3:

View File

@@ -239,19 +239,19 @@ class Advection(Equation): # pylint: disable=R0903
)
# Ensure consistency of c length
if self.c.shape[-1] != len(input_lbl) - 1 and self.c.shape[-1] > 1:
if len(self.c) != (len(input_lbl) - 1) and len(self.c) > 1:
raise ValueError(
"If 'c' is passed as a list, its length must be equal to "
"the number of spatial dimensions."
)
# Repeat c to ensure consistent shape for advection
c = self.c.repeat(output_.shape[0], 1)
if c.shape[1] != (len(input_lbl) - 1):
c = c.repeat(1, len(input_lbl) - 1)
self.c = self.c.repeat(output_.shape[0], 1)
if self.c.shape[1] != (len(input_lbl) - 1):
self.c = self.c.repeat(1, len(input_lbl) - 1)
# Add a dimension to c for the following operations
c = c.unsqueeze(-1)
self.c = self.c.unsqueeze(-1)
# Compute the time derivative and the spatial gradient
time_der = grad(output_, input_, components=None, d="t")
@@ -262,7 +262,7 @@ class Advection(Equation): # pylint: disable=R0903
tmp = tmp.transpose(-1, -2)
# Compute advection term
adv = (tmp * c).sum(dim=tmp.tensor.ndim - 2)
adv = (tmp * self.c).sum(dim=tmp.tensor.ndim - 2)
return time_der + adv

View File

@@ -1,7 +1,6 @@
"""Module for the Equation Interface."""
from abc import ABCMeta, abstractmethod
import torch
class EquationInterface(metaclass=ABCMeta):
@@ -34,33 +33,3 @@ class EquationInterface(metaclass=ABCMeta):
:return: The computed residual of the equation.
:rtype: LabelTensor
"""
def to(self, device):
"""
Move all tensor attributes to the specified device.
:param torch.device device: The target device to move the tensors to.
:return: The instance moved to the specified device.
:rtype: EquationInterface
"""
# Iterate over all attributes of the Equation
for key, val in self.__dict__.items():
# Move tensors in dictionaries to the specified device
if isinstance(val, dict):
self.__dict__[key] = {
k: v.to(device) if torch.is_tensor(v) else v
for k, v in val.items()
}
# Move tensors in lists to the specified device
elif isinstance(val, list):
self.__dict__[key] = [
v.to(device) if torch.is_tensor(v) else v for v in val
]
# Move tensor attributes to the specified device
elif torch.is_tensor(val):
self.__dict__[key] = val.to(device)
return self

View File

@@ -101,10 +101,6 @@ class SystemEquation(EquationInterface):
:return: The aggregated residuals of the system of equations.
:rtype: LabelTensor
"""
# Move the equation to the input_ device
self.to(input_.device)
# Compute the residual for each equation
residual = torch.hstack(
[
equation.residual(input_, output_, params_)
@@ -112,7 +108,6 @@ class SystemEquation(EquationInterface):
]
)
# Skip reduction if not specified
if self.reduction is None:
return residual

View File

@@ -41,7 +41,7 @@ class EnEquivariantNetworkBlock(MessagePassing):
DOI: `<https://doi.org/10.48550/arXiv.2102.09844>`_.
"""
def __init__(
def __init__( # pylint: disable=R0913, R0917
self,
node_feature_dim,
edge_feature_dim,
@@ -143,7 +143,9 @@ class EnEquivariantNetworkBlock(MessagePassing):
func=activation,
)
def forward(self, x, pos, edge_index, edge_attr=None, vel=None):
def forward(
self, x, pos, edge_index, edge_attr=None, vel=None
): # pylint: disable=R0917
"""
Forward pass of the block, triggering the message-passing routine.
@@ -169,7 +171,9 @@ class EnEquivariantNetworkBlock(MessagePassing):
edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr, vel=vel
)
def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
def message(
self, x_i, x_j, pos_i, pos_j, edge_attr
): # pylint: disable=R0917
"""
Compute the message to be passed between nodes and edges.
@@ -234,7 +238,9 @@ class EnEquivariantNetworkBlock(MessagePassing):
return agg_message, agg_m_ij
def update(self, aggregated_inputs, x, pos, edge_index, vel):
def update(
self, aggregated_inputs, x, pos, edge_index, vel
): # pylint: disable=R0917
"""
Update node features, positions, and optionally velocities.

View File

@@ -23,7 +23,7 @@ class EquivariantGraphNeuralOperatorBlock(torch.nn.Module):
<https://arxiv.org/abs/2401.11037>`_
"""
def __init__(
def __init__( # pylint: disable=R0913, R0917
self,
node_feature_dim,
edge_feature_dim,
@@ -101,7 +101,9 @@ class EquivariantGraphNeuralOperatorBlock(torch.nn.Module):
flow=flow,
)
def forward(self, x, pos, vel, edge_index, edge_attr=None):
def forward( # pylint: disable=R0917
self, x, pos, vel, edge_index, edge_attr=None
):
"""
Forward pass of the Equivariant Graph Neural Operator block.
@@ -182,7 +184,11 @@ class EquivariantGraphNeuralOperatorBlock(torch.nn.Module):
weights = torch.complex(real[..., :modes], img[..., :modes])
# Convolution in Fourier space
fourier = torch.fft.rfftn(x, dim=[0])[:modes]
# torch.fft.rfftn and irfftn are callable functions, but pylint
# incorrectly flags them as E1102 (not callable).
fourier = torch.fft.rfftn(x, dim=[0])[:modes] # pylint: disable=E1102
out = torch.einsum(einsum_idx, fourier, weights)
return torch.fft.irfftn(out, s=x.shape[0], dim=0)
return torch.fft.irfftn( # pylint: disable=E1102
out, s=x.shape[0], dim=0
)

View File

@@ -5,7 +5,9 @@ from ..utils import check_positive_integer
from .block.message_passing import EquivariantGraphNeuralOperatorBlock
class EquivariantGraphNeuralOperator(torch.nn.Module):
# Disable pylint warnings for too few public methods (since this is a simple
# model class in a standard PyTorch style)
class EquivariantGraphNeuralOperator(torch.nn.Module): # pylint: disable=R0903
"""
Equivariant Graph Neural Operator (EGNO) for modeling 3D dynamics.
@@ -32,7 +34,9 @@ class EquivariantGraphNeuralOperator(torch.nn.Module):
<https://arxiv.org/abs/2401.11037>`_
"""
def __init__(
# Disable pylint warnings for too many arguments in init (since this is a
# model class with many configurable parameters)
def __init__( # pylint: disable=R0913, R0917, R0914
self,
n_egno_layers,
node_feature_dim,

View File

@@ -48,10 +48,11 @@ class HelmholtzProblem(SpatialProblem):
:type alpha: float | int
"""
super().__init__()
check_consistency(alpha, (int, float))
self.alpha = alpha
def forcing_term(input_):
self.alpha = alpha
check_consistency(alpha, (int, float))
def forcing_term(self, input_):
"""
Implementation of the forcing term.
"""

View File

@@ -71,7 +71,9 @@ class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
"""
# Override the compilation, compiling only for torch < 2.8, see
# related issue at https://github.com/mathLab/PINA/issues/621
if torch.__version__ >= "2.8":
if torch.__version__ < "2.8":
self.trainer.compile = True
else:
self.trainer.compile = False
warnings.warn(
"Compilation is disabled for torch >= 2.8. "

View File

@@ -174,7 +174,11 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
:return: The result of the parent class ``setup`` method.
:rtype: Any
"""
if self.trainer.compile and not self._is_compiled():
if stage == "fit" and self.trainer.compile:
self._setup_compile()
if stage == "test" and (
self.trainer.compile and not self._is_compiled()
):
self._setup_compile()
return super().setup(stage)

View File

@@ -1,17 +1,12 @@
"""Module for the Trainer."""
import sys
import warnings
import torch
import lightning
from .utils import check_consistency, custom_warning_format
from .utils import check_consistency
from .data import PinaDataModule
from .solver import SolverInterface, PINNInterface
# set the warning for compile options
warnings.formatwarning = custom_warning_format
warnings.filterwarnings("always", category=UserWarning)
class Trainer(lightning.pytorch.Trainer):
"""
@@ -31,7 +26,7 @@ class Trainer(lightning.pytorch.Trainer):
test_size=0.0,
val_size=0.0,
compile=None,
batching_mode="common_batch_size",
repeat=None,
automatic_batching=None,
num_workers=None,
pin_memory=None,
@@ -54,14 +49,11 @@ class Trainer(lightning.pytorch.Trainer):
:param float val_size: The percentage of elements to include in the
validation dataset. Default is ``0.0``.
:param bool compile: If ``True``, the model is compiled before training.
Default is ``False``. For Windows users, it is always disabled. Not
supported for python version greater or equal than 3.14.
:param bool common_batch_size: If ``True``, the same batch size is used
for all conditions. If ``False``, each condition can have its own
batch size, proportional to the size of the dataset in that
condition. Default is ``True``.
:param bool separate_conditions: If ``True``, dataloaders for each
condition are iterated separately. Default is ``False``.
Default is ``False``. For Windows users, it is always disabled.
:param bool repeat: Whether to repeat the dataset data in each
condition during training. For further details, see the
:class:`~pina.data.data_module.PinaDataModule` class. Default is
``False``.
:param bool automatic_batching: If ``True``, automatic PyTorch batching
is performed, otherwise the items are retrieved from the dataset
all at once. For further details, see the
@@ -84,7 +76,7 @@ class Trainer(lightning.pytorch.Trainer):
train_size=train_size,
test_size=test_size,
val_size=val_size,
batching_mode=batching_mode,
repeat=repeat,
automatic_batching=automatic_batching,
compile=compile,
)
@@ -112,17 +104,10 @@ class Trainer(lightning.pytorch.Trainer):
super().__init__(**kwargs)
# checking compilation and automatic batching
# compilation disabled for Windows and for Python 3.14+
if (
compile is None
or sys.platform == "win32"
or sys.version_info >= (3, 14)
):
if compile is None or sys.platform == "win32":
compile = False
warnings.warn(
"Compilation is disabled for Python 3.14+ and for Windows.",
UserWarning,
)
repeat = repeat if repeat is not None else False
automatic_batching = (
automatic_batching if automatic_batching is not None else False
@@ -139,7 +124,7 @@ class Trainer(lightning.pytorch.Trainer):
test_size=test_size,
val_size=val_size,
batch_size=batch_size,
batching_mode=batching_mode,
repeat=repeat,
automatic_batching=automatic_batching,
pin_memory=pin_memory,
num_workers=num_workers,
@@ -177,7 +162,7 @@ class Trainer(lightning.pytorch.Trainer):
test_size,
val_size,
batch_size,
batching_mode,
repeat,
automatic_batching,
pin_memory,
num_workers,
@@ -196,10 +181,8 @@ class Trainer(lightning.pytorch.Trainer):
:param float val_size: The percentage of elements to include in the
validation dataset.
:param int batch_size: The number of samples per batch to load.
:param bool common_batch_size: Whether to use the same batch size for
all conditions.
:param bool seperate_conditions: Whether to iterate dataloaders for
each condition separately.
:param bool repeat: Whether to repeat the dataset data in each
condition during training.
:param bool automatic_batching: Whether to perform automatic batching
with PyTorch.
:param bool pin_memory: Whether to use pinned memory for faster data
@@ -229,7 +212,7 @@ class Trainer(lightning.pytorch.Trainer):
test_size=test_size,
val_size=val_size,
batch_size=batch_size,
batching_mode=batching_mode,
repeat=repeat,
automatic_batching=automatic_batching,
num_workers=num_workers,
pin_memory=pin_memory,
@@ -281,7 +264,7 @@ class Trainer(lightning.pytorch.Trainer):
train_size,
test_size,
val_size,
batching_mode,
repeat,
automatic_batching,
compile,
):
@@ -295,10 +278,8 @@ class Trainer(lightning.pytorch.Trainer):
test dataset.
:param float val_size: The percentage of elements to include in the
validation dataset.
:param bool common_batch_size: Whether to use the same batch size for
all conditions.
:param bool seperate_conditions: Whether to iterate dataloaders for
each condition separately.
:param bool repeat: Whether to repeat the dataset data in each
condition during training.
:param bool automatic_batching: Whether to perform automatic batching
with PyTorch.
:param bool compile: If ``True``, the model is compiled before training.
@@ -308,7 +289,8 @@ class Trainer(lightning.pytorch.Trainer):
check_consistency(train_size, float)
check_consistency(test_size, float)
check_consistency(val_size, float)
check_consistency(batching_mode, str)
if repeat is not None:
check_consistency(repeat, bool)
if automatic_batching is not None:
check_consistency(automatic_batching, bool)
if compile is not None:
@@ -343,23 +325,3 @@ class Trainer(lightning.pytorch.Trainer):
if batch_size is not None:
check_consistency(batch_size, int)
return pin_memory, num_workers, shuffle, batch_size
@property
def compile(self):
"""
Whether compilation is required or not.
:return: ``True`` if compilation is required, ``False`` otherwise.
:rtype: bool
"""
return self._compile
@compile.setter
def compile(self, value):
"""
Setting the value of compile.
:param bool value: Whether compilation is required or not.
"""
check_consistency(value, bool)
self._compile = value
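A usage sketch of the renamed option (the solver construction is elided; apart from ``repeat``, ``batch_size``, ``automatic_batching`` and ``max_epochs``, nothing here is prescribed by the diff):

# `solver` is any PINA solver instance built elsewhere (hypothetical setup).
trainer = Trainer(
    solver,
    batch_size=32,
    repeat=True,               # short conditions are re-sampled to fill each batch
    automatic_batching=False,  # fetch index lists instead of per-item __getitem__
    max_epochs=10,
)
trainer.train()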

View File

@@ -1,6 +1,6 @@
[project]
name = "pina-mathlab"
version = "0.2.5"
version = "0.2.3"
description = "Physic Informed Neural networks for Advance modeling."
readme = "README.md"
authors = [
@@ -19,7 +19,7 @@ dependencies = [
"torch_geometric",
"matplotlib",
]
requires-python = ">=3.10"
requires-python = ">=3.9"
[project.optional-dependencies]
doc = [

Binary image file changed; preview not shown (before: 411 KiB, after: 51 KiB).

View File

@@ -51,7 +51,7 @@ def test_sample(condition_to_update):
}
trainer.train()
after_n_points = {
loc: len(trainer.data_module.train_dataset[loc].input)
loc: len(trainer.data_module.train_dataset.input[loc])
for loc in condition_to_update
}
assert before_n_points == trainer.callbacks[0].initial_population_size

View File

@@ -142,10 +142,14 @@ def test_setup(solver, fn, stage, apply_to):
for cond in ["data1", "data2"]:
scale = scale_fn(
trainer_copy.data_module.train_dataset[cond].data[apply_to]
trainer_copy.data_module.train_dataset.conditions_dict[cond][
apply_to
]
)
shift = shift_fn(
trainer_copy.data_module.train_dataset[cond].data[apply_to]
trainer_copy.data_module.train_dataset.conditions_dict[cond][
apply_to
]
)
assert "scale" in normalizer[cond]
assert "shift" in normalizer[cond]
@@ -154,8 +158,8 @@ def test_setup(solver, fn, stage, apply_to):
for ds_name in stage_map[stage]:
dataset = getattr(trainer.data_module, ds_name, None)
old_dataset = getattr(trainer_copy.data_module, ds_name, None)
current_points = dataset[cond].data[apply_to]
old_points = old_dataset[cond].data[apply_to]
current_points = dataset.conditions_dict[cond][apply_to]
old_points = old_dataset.conditions_dict[cond][apply_to]
expected = (old_points - shift) / scale
assert torch.allclose(current_points, expected)
@@ -200,10 +204,10 @@ def test_setup_pinn(fn, stage, apply_to):
cond = "data"
scale = scale_fn(
trainer_copy.data_module.train_dataset[cond].data[apply_to]
trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to]
)
shift = shift_fn(
trainer_copy.data_module.train_dataset[cond].data[apply_to]
trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to]
)
assert "scale" in normalizer[cond]
assert "shift" in normalizer[cond]
@@ -212,8 +216,8 @@ def test_setup_pinn(fn, stage, apply_to):
for ds_name in stage_map[stage]:
dataset = getattr(trainer.data_module, ds_name, None)
old_dataset = getattr(trainer_copy.data_module, ds_name, None)
current_points = dataset[cond].data[apply_to]
old_points = old_dataset[cond].data[apply_to]
current_points = dataset.conditions_dict[cond][apply_to]
old_points = old_dataset.conditions_dict[cond][apply_to]
expected = (old_points - shift) / scale
assert torch.allclose(current_points, expected)
@@ -238,7 +242,3 @@ def test_setup_graph_dataset():
)
with pytest.raises(NotImplementedError):
trainer.train()
# if __name__ == "__main__":
# test_setup(supervised_solver_lt, [torch.std, torch.mean], "all", "input")
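The expected values asserted above boil down to a shift/scale transform of each condition's points. A standalone sketch of that arithmetic, assuming torch.mean and torch.std as the shift and scale functions (just one of the parametrised combinations the test exercises):

import torch

# Stand-in for train_dataset.conditions_dict[cond][apply_to].
points = torch.rand(100, 3)
shift, scale = torch.mean(points), torch.std(points)

# What the callback is expected to write back for this condition.
normalized = (points - shift) / scale

# Round-trip check mirroring the torch.allclose assertions above.
assert torch.allclose(normalized * scale + shift, points, atol=1e-6)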

View File

@@ -1,11 +1,10 @@
import torch
import pytest
from pina.data import PinaDataModule
from pina.data.dataset import PinaDataset
from pina.data.dataset import PinaTensorDataset, PinaGraphDataset
from pina.problem.zoo import SupervisedProblem
from pina.graph import RadiusGraph
from pina.data.dataloader import DummyDataloader, PinaDataLoader
from pina.data.data_module import DummyDataloader
from pina import Trainer
from pina.solver import SupervisedSolver
from torch_geometric.data import Batch
@@ -45,33 +44,22 @@ def test_setup_train(input_, output_, train_size, val_size, test_size):
)
dm.setup()
assert hasattr(dm, "train_dataset")
assert isinstance(dm.train_dataset, dict)
assert all(
isinstance(dm.train_dataset[cond], PinaDataset)
for cond in dm.train_dataset
)
assert all(
dm.train_dataset[cond].is_graph_dataset == isinstance(input_, list)
for cond in dm.train_dataset
)
assert all(
len(dm.train_dataset[cond]) == int(len(input_) * train_size)
for cond in dm.train_dataset
)
if isinstance(input_, torch.Tensor):
assert isinstance(dm.train_dataset, PinaTensorDataset)
else:
assert isinstance(dm.train_dataset, PinaGraphDataset)
# assert len(dm.train_dataset) == int(len(input_) * train_size)
if test_size > 0:
assert hasattr(dm, "test_dataset")
assert dm.test_dataset is None
else:
assert not hasattr(dm, "test_dataset")
assert hasattr(dm, "val_dataset")
assert isinstance(dm.val_dataset, dict)
assert all(
isinstance(dm.val_dataset[cond], PinaDataset) for cond in dm.val_dataset
)
assert all(
isinstance(dm.val_dataset[cond], PinaDataset) for cond in dm.val_dataset
)
if isinstance(input_, torch.Tensor):
assert isinstance(dm.val_dataset, PinaTensorDataset)
else:
assert isinstance(dm.val_dataset, PinaGraphDataset)
# assert len(dm.val_dataset) == int(len(input_) * val_size)
@pytest.mark.parametrize(
@@ -99,59 +87,49 @@ def test_setup_test(input_, output_, train_size, val_size, test_size):
assert not hasattr(dm, "val_dataset")
assert hasattr(dm, "test_dataset")
assert all(
isinstance(dm.test_dataset[cond], PinaDataset)
for cond in dm.test_dataset
)
assert all(
dm.test_dataset[cond].is_graph_dataset == isinstance(input_, list)
for cond in dm.test_dataset
)
assert all(
len(dm.test_dataset[cond]) == int(len(input_) * test_size)
for cond in dm.test_dataset
)
if isinstance(input_, torch.Tensor):
assert isinstance(dm.test_dataset, PinaTensorDataset)
else:
assert isinstance(dm.test_dataset, PinaGraphDataset)
# assert len(dm.test_dataset) == int(len(input_) * test_size)
@pytest.mark.parametrize(
"input_, output_",
[(input_tensor, output_tensor), (input_graph, output_graph)],
)
def test_dummy_dataloader(input_, output_):
problem = SupervisedProblem(input_=input_, output_=output_)
solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
trainer = Trainer(
solver, batch_size=None, train_size=0.7, val_size=0.3, test_size=0.0
)
dm = trainer.data_module
dm.setup()
dm.trainer = trainer
dataloader = dm.train_dataloader()
assert isinstance(dataloader, DummyDataloader)
assert len(dataloader) == 1
data = next(dataloader)
assert isinstance(data, list)
assert isinstance(data[0], tuple)
if isinstance(input_, list):
assert isinstance(data[0][1]["input"], Batch)
else:
assert isinstance(data[0][1]["input"], torch.Tensor)
assert isinstance(data[0][1]["target"], torch.Tensor)
# @pytest.mark.parametrize(
# "input_, output_",
# [(input_tensor, output_tensor), (input_graph, output_graph)],
# )
# def test_dummy_dataloader(input_, output_):
# problem = SupervisedProblem(input_=input_, output_=output_)
# solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
# trainer = Trainer(
# solver, batch_size=None, train_size=0.7, val_size=0.3, test_size=0.0
# )
# dm = trainer.data_module
# dm.setup()
# dm.trainer = trainer
# dataloader = dm.train_dataloader()
# assert isinstance(dataloader, PinaDataLoader)
# print(dataloader.dataloaders)
# assert all([isinstance(ds, DummyDataloader) for ds in dataloader.dataloaders.values()])
# data = next(iter(dataloader))
# assert isinstance(data, list)
# assert isinstance(data[0], tuple)
# if isinstance(input_, list):
# assert isinstance(data[0][1]["input"], Batch)
# else:
# assert isinstance(data[0][1]["input"], torch.Tensor)
# assert isinstance(data[0][1]["target"], torch.Tensor)
# dataloader = dm.val_dataloader()
# assert isinstance(dataloader, DummyDataloader)
# assert len(dataloader) == 1
# data = next(dataloader)
# assert isinstance(data, list)
# assert isinstance(data[0], tuple)
# if isinstance(input_, list):
# assert isinstance(data[0][1]["input"], Batch)
# else:
# assert isinstance(data[0][1]["input"], torch.Tensor)
# assert isinstance(data[0][1]["target"], torch.Tensor)
dataloader = dm.val_dataloader()
assert isinstance(dataloader, DummyDataloader)
assert len(dataloader) == 1
data = next(dataloader)
assert isinstance(data, list)
assert isinstance(data[0], tuple)
if isinstance(input_, list):
assert isinstance(data[0][1]["input"], Batch)
else:
assert isinstance(data[0][1]["input"], torch.Tensor)
assert isinstance(data[0][1]["target"], torch.Tensor)
@pytest.mark.parametrize(
@@ -159,11 +137,7 @@ def test_setup_test(input_, output_, train_size, val_size, test_size):
[(input_tensor, output_tensor), (input_graph, output_graph)],
)
@pytest.mark.parametrize("automatic_batching", [True, False])
@pytest.mark.parametrize("batch_size", [None, 10])
@pytest.mark.parametrize("batching_mode", ["common_batch_size", "propotional"])
def test_dataloader(
input_, output_, automatic_batching, batch_size, batching_mode
):
def test_dataloader(input_, output_, automatic_batching):
problem = SupervisedProblem(input_=input_, output_=output_)
solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
trainer = Trainer(
@@ -173,13 +147,12 @@ def test_dataloader(
val_size=0.3,
test_size=0.0,
automatic_batching=automatic_batching,
batching_mode=batching_mode,
)
dm = trainer.data_module
dm.setup()
dm.trainer = trainer
dataloader = dm.train_dataloader()
assert isinstance(dataloader, PinaDataLoader)
assert isinstance(dataloader, DataLoader)
assert len(dataloader) == 7
data = next(iter(dataloader))
assert isinstance(data, dict)
@@ -190,8 +163,8 @@ def test_dataloader(
assert isinstance(data["data"]["target"], torch.Tensor)
dataloader = dm.val_dataloader()
assert isinstance(dataloader, PinaDataLoader)
assert len(dataloader) == 3 if batch_size is not None else 1
assert isinstance(dataloader, DataLoader)
assert len(dataloader) == 3
data = next(iter(dataloader))
assert isinstance(data, dict)
if isinstance(input_, list):
@@ -229,13 +202,12 @@ def test_dataloader_labels(input_, output_, automatic_batching):
val_size=0.3,
test_size=0.0,
automatic_batching=automatic_batching,
# common_batch_size=True,
)
dm = trainer.data_module
dm.setup()
dm.trainer = trainer
dataloader = dm.train_dataloader()
assert isinstance(dataloader, PinaDataLoader)
assert isinstance(dataloader, DataLoader)
assert len(dataloader) == 7
data = next(iter(dataloader))
assert isinstance(data, dict)
@@ -251,7 +223,7 @@ def test_dataloader_labels(input_, output_, automatic_batching):
assert data["data"]["target"].labels == ["u", "v", "w"]
dataloader = dm.val_dataloader()
assert isinstance(dataloader, PinaDataLoader)
assert isinstance(dataloader, DataLoader)
assert len(dataloader) == 3
data = next(iter(dataloader))
assert isinstance(data, dict)
@@ -268,6 +240,39 @@ def test_dataloader_labels(input_, output_, automatic_batching):
assert data["data"]["target"].labels == ["u", "v", "w"]
def test_get_all_data():
input = torch.stack([torch.zeros((1,)) + i for i in range(1000)])
target = input
problem = SupervisedProblem(input, target)
datamodule = PinaDataModule(
problem,
train_size=0.7,
test_size=0.2,
val_size=0.1,
batch_size=64,
shuffle=False,
repeat=False,
automatic_batching=None,
num_workers=0,
pin_memory=False,
)
datamodule.setup("fit")
datamodule.setup("test")
assert len(datamodule.train_dataset.get_all_data()["data"]["input"]) == 700
assert torch.isclose(
datamodule.train_dataset.get_all_data()["data"]["input"], input[:700]
).all()
assert len(datamodule.val_dataset.get_all_data()["data"]["input"]) == 100
assert torch.isclose(
datamodule.val_dataset.get_all_data()["data"]["input"], input[900:]
).all()
assert len(datamodule.test_dataset.get_all_data()["data"]["input"]) == 200
assert torch.isclose(
datamodule.test_dataset.get_all_data()["data"]["input"], input[700:900]
).all()
def test_input_propery_tensor():
input = torch.stack([torch.zeros((1,)) + i for i in range(1000)])
target = input
@@ -280,6 +285,7 @@ def test_input_propery_tensor():
val_size=0.1,
batch_size=64,
shuffle=False,
repeat=False,
automatic_batching=None,
num_workers=0,
pin_memory=False,
@@ -305,6 +311,7 @@ def test_input_propery_graph():
val_size=0.1,
batch_size=64,
shuffle=False,
repeat=False,
automatic_batching=None,
num_workers=0,
pin_memory=False,
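The labels assertions earlier in this file's diff expect LabelTensor metadata to survive collation. As a minimal sketch of such a labelled tensor, assuming the two-argument LabelTensor constructor and the extract method exported at the pina top level (both used here purely for illustration):

import torch
from pina import LabelTensor

# Three columns named like the targets checked in the test above.
target = LabelTensor(torch.rand(5, 3), ["u", "v", "w"])
print(target.labels)                 # ['u', 'v', 'w']
print(target.extract(["u"]).shape)   # the 'u' column only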

View File

@@ -1,138 +1,138 @@
# import torch
# import pytest
# from pina.data.dataset import PinaDatasetFactory, PinaGraphDataset
# from pina.graph import KNNGraph
# from torch_geometric.data import Data
import torch
import pytest
from pina.data.dataset import PinaDatasetFactory, PinaGraphDataset
from pina.graph import KNNGraph
from torch_geometric.data import Data
# x = torch.rand((100, 20, 10))
# pos = torch.rand((100, 20, 2))
# input_ = [
# KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
# for x_, pos_ in zip(x, pos)
# ]
# output_ = torch.rand((100, 20, 10))
x = torch.rand((100, 20, 10))
pos = torch.rand((100, 20, 2))
input_ = [
KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
for x_, pos_ in zip(x, pos)
]
output_ = torch.rand((100, 20, 10))
# x_2 = torch.rand((50, 20, 10))
# pos_2 = torch.rand((50, 20, 2))
# input_2_ = [
# KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
# for x_, pos_ in zip(x_2, pos_2)
# ]
# output_2_ = torch.rand((50, 20, 10))
x_2 = torch.rand((50, 20, 10))
pos_2 = torch.rand((50, 20, 2))
input_2_ = [
KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
for x_, pos_ in zip(x_2, pos_2)
]
output_2_ = torch.rand((50, 20, 10))
# # Problem with a single condition
# conditions_dict_single = {
# "data": {
# "input": input_,
# "target": output_,
# }
# }
# max_conditions_lengths_single = {"data": 100}
# Problem with a single condition
conditions_dict_single = {
"data": {
"input": input_,
"target": output_,
}
}
max_conditions_lengths_single = {"data": 100}
# # Problem with multiple conditions
# conditions_dict_multi = {
# "data_1": {
# "input": input_,
# "target": output_,
# },
# "data_2": {
# "input": input_2_,
# "target": output_2_,
# },
# }
# Problem with multiple conditions
conditions_dict_multi = {
"data_1": {
"input": input_,
"target": output_,
},
"data_2": {
"input": input_2_,
"target": output_2_,
},
}
# max_conditions_lengths_multi = {"data_1": 100, "data_2": 50}
max_conditions_lengths_multi = {"data_1": 100, "data_2": 50}
# @pytest.mark.parametrize(
# "conditions_dict, max_conditions_lengths",
# [
# (conditions_dict_single, max_conditions_lengths_single),
# (conditions_dict_multi, max_conditions_lengths_multi),
# ],
# )
# def test_constructor(conditions_dict, max_conditions_lengths):
# dataset = PinaDatasetFactory(
# conditions_dict,
# max_conditions_lengths=max_conditions_lengths,
# automatic_batching=True,
# )
# assert isinstance(dataset, PinaGraphDataset)
# assert len(dataset) == 100
@pytest.mark.parametrize(
"conditions_dict, max_conditions_lengths",
[
(conditions_dict_single, max_conditions_lengths_single),
(conditions_dict_multi, max_conditions_lengths_multi),
],
)
def test_constructor(conditions_dict, max_conditions_lengths):
dataset = PinaDatasetFactory(
conditions_dict,
max_conditions_lengths=max_conditions_lengths,
automatic_batching=True,
)
assert isinstance(dataset, PinaGraphDataset)
assert len(dataset) == 100
# @pytest.mark.parametrize(
# "conditions_dict, max_conditions_lengths",
# [
# (conditions_dict_single, max_conditions_lengths_single),
# (conditions_dict_multi, max_conditions_lengths_multi),
# ],
# )
# def test_getitem(conditions_dict, max_conditions_lengths):
# dataset = PinaDatasetFactory(
# conditions_dict,
# max_conditions_lengths=max_conditions_lengths,
# automatic_batching=True,
# )
# data = dataset[50]
# assert isinstance(data, dict)
# assert all([isinstance(d["input"], Data) for d in data.values()])
# assert all([isinstance(d["target"], torch.Tensor) for d in data.values()])
# assert all(
# [d["input"].x.shape == torch.Size((20, 10)) for d in data.values()]
# )
# assert all(
# [d["target"].shape == torch.Size((20, 10)) for d in data.values()]
# )
# assert all(
# [
# d["input"].edge_index.shape == torch.Size((2, 60))
# for d in data.values()
# ]
# )
# assert all([d["input"].edge_attr.shape[0] == 60 for d in data.values()])
@pytest.mark.parametrize(
"conditions_dict, max_conditions_lengths",
[
(conditions_dict_single, max_conditions_lengths_single),
(conditions_dict_multi, max_conditions_lengths_multi),
],
)
def test_getitem(conditions_dict, max_conditions_lengths):
dataset = PinaDatasetFactory(
conditions_dict,
max_conditions_lengths=max_conditions_lengths,
automatic_batching=True,
)
data = dataset[50]
assert isinstance(data, dict)
assert all([isinstance(d["input"], Data) for d in data.values()])
assert all([isinstance(d["target"], torch.Tensor) for d in data.values()])
assert all(
[d["input"].x.shape == torch.Size((20, 10)) for d in data.values()]
)
assert all(
[d["target"].shape == torch.Size((20, 10)) for d in data.values()]
)
assert all(
[
d["input"].edge_index.shape == torch.Size((2, 60))
for d in data.values()
]
)
assert all([d["input"].edge_attr.shape[0] == 60 for d in data.values()])
# data = dataset.fetch_from_idx_list([i for i in range(20)])
# assert isinstance(data, dict)
# assert all([isinstance(d["input"], Data) for d in data.values()])
# assert all([isinstance(d["target"], torch.Tensor) for d in data.values()])
# assert all(
# [d["input"].x.shape == torch.Size((400, 10)) for d in data.values()]
# )
# assert all(
# [d["target"].shape == torch.Size((20, 20, 10)) for d in data.values()]
# )
# assert all(
# [
# d["input"].edge_index.shape == torch.Size((2, 1200))
# for d in data.values()
# ]
# )
# assert all([d["input"].edge_attr.shape[0] == 1200 for d in data.values()])
data = dataset.fetch_from_idx_list([i for i in range(20)])
assert isinstance(data, dict)
assert all([isinstance(d["input"], Data) for d in data.values()])
assert all([isinstance(d["target"], torch.Tensor) for d in data.values()])
assert all(
[d["input"].x.shape == torch.Size((400, 10)) for d in data.values()]
)
assert all(
[d["target"].shape == torch.Size((20, 20, 10)) for d in data.values()]
)
assert all(
[
d["input"].edge_index.shape == torch.Size((2, 1200))
for d in data.values()
]
)
assert all([d["input"].edge_attr.shape[0] == 1200 for d in data.values()])
# def test_input_single_condition():
# dataset = PinaDatasetFactory(
# conditions_dict_single,
# max_conditions_lengths=max_conditions_lengths_single,
# automatic_batching=True,
# )
# input_ = dataset.input
# assert isinstance(input_, dict)
# assert isinstance(input_["data"], list)
# assert all([isinstance(d, Data) for d in input_["data"]])
def test_input_single_condition():
dataset = PinaDatasetFactory(
conditions_dict_single,
max_conditions_lengths=max_conditions_lengths_single,
automatic_batching=True,
)
input_ = dataset.input
assert isinstance(input_, dict)
assert isinstance(input_["data"], list)
assert all([isinstance(d, Data) for d in input_["data"]])
# def test_input_multi_condition():
# dataset = PinaDatasetFactory(
# conditions_dict_multi,
# max_conditions_lengths=max_conditions_lengths_multi,
# automatic_batching=True,
# )
# input_ = dataset.input
# assert isinstance(input_, dict)
# assert isinstance(input_["data_1"], list)
# assert all([isinstance(d, Data) for d in input_["data_1"]])
# assert isinstance(input_["data_2"], list)
# assert all([isinstance(d, Data) for d in input_["data_2"]])
def test_input_multi_condition():
dataset = PinaDatasetFactory(
conditions_dict_multi,
max_conditions_lengths=max_conditions_lengths_multi,
automatic_batching=True,
)
input_ = dataset.input
assert isinstance(input_, dict)
assert isinstance(input_["data_1"], list)
assert all([isinstance(d, Data) for d in input_["data_1"]])
assert isinstance(input_["data_2"], list)
assert all([isinstance(d, Data) for d in input_["data_2"]])

View File

@@ -1,86 +1,86 @@
# import torch
# import pytest
# from pina.data.dataset import PinaDatasetFactory, PinaTensorDataset
import torch
import pytest
from pina.data.dataset import PinaDatasetFactory, PinaTensorDataset
# input_tensor = torch.rand((100, 10))
# output_tensor = torch.rand((100, 2))
input_tensor = torch.rand((100, 10))
output_tensor = torch.rand((100, 2))
# input_tensor_2 = torch.rand((50, 10))
# output_tensor_2 = torch.rand((50, 2))
input_tensor_2 = torch.rand((50, 10))
output_tensor_2 = torch.rand((50, 2))
# conditions_dict_single = {
# "data": {
# "input": input_tensor,
# "target": output_tensor,
# }
# }
conditions_dict_single = {
"data": {
"input": input_tensor,
"target": output_tensor,
}
}
# conditions_dict_single_multi = {
# "data_1": {
# "input": input_tensor,
# "target": output_tensor,
# },
# "data_2": {
# "input": input_tensor_2,
# "target": output_tensor_2,
# },
# }
conditions_dict_single_multi = {
"data_1": {
"input": input_tensor,
"target": output_tensor,
},
"data_2": {
"input": input_tensor_2,
"target": output_tensor_2,
},
}
# max_conditions_lengths_single = {"data": 100}
max_conditions_lengths_single = {"data": 100}
# max_conditions_lengths_multi = {"data_1": 100, "data_2": 50}
max_conditions_lengths_multi = {"data_1": 100, "data_2": 50}
# @pytest.mark.parametrize(
# "conditions_dict, max_conditions_lengths",
# [
# (conditions_dict_single, max_conditions_lengths_single),
# (conditions_dict_single_multi, max_conditions_lengths_multi),
# ],
# )
# def test_constructor_tensor(conditions_dict, max_conditions_lengths):
# dataset = PinaDatasetFactory(
# conditions_dict,
# max_conditions_lengths=max_conditions_lengths,
# automatic_batching=True,
# )
# assert isinstance(dataset, PinaTensorDataset)
@pytest.mark.parametrize(
"conditions_dict, max_conditions_lengths",
[
(conditions_dict_single, max_conditions_lengths_single),
(conditions_dict_single_multi, max_conditions_lengths_multi),
],
)
def test_constructor_tensor(conditions_dict, max_conditions_lengths):
dataset = PinaDatasetFactory(
conditions_dict,
max_conditions_lengths=max_conditions_lengths,
automatic_batching=True,
)
assert isinstance(dataset, PinaTensorDataset)
# def test_getitem_single():
# dataset = PinaDatasetFactory(
# conditions_dict_single,
# max_conditions_lengths=max_conditions_lengths_single,
# automatic_batching=False,
# )
def test_getitem_single():
dataset = PinaDatasetFactory(
conditions_dict_single,
max_conditions_lengths=max_conditions_lengths_single,
automatic_batching=False,
)
# tensors = dataset.fetch_from_idx_list([i for i in range(70)])
# assert isinstance(tensors, dict)
# assert list(tensors.keys()) == ["data"]
# assert sorted(list(tensors["data"].keys())) == ["input", "target"]
# assert isinstance(tensors["data"]["input"], torch.Tensor)
# assert tensors["data"]["input"].shape == torch.Size((70, 10))
# assert isinstance(tensors["data"]["target"], torch.Tensor)
# assert tensors["data"]["target"].shape == torch.Size((70, 2))
tensors = dataset.fetch_from_idx_list([i for i in range(70)])
assert isinstance(tensors, dict)
assert list(tensors.keys()) == ["data"]
assert sorted(list(tensors["data"].keys())) == ["input", "target"]
assert isinstance(tensors["data"]["input"], torch.Tensor)
assert tensors["data"]["input"].shape == torch.Size((70, 10))
assert isinstance(tensors["data"]["target"], torch.Tensor)
assert tensors["data"]["target"].shape == torch.Size((70, 2))
# def test_getitem_multi():
# dataset = PinaDatasetFactory(
# conditions_dict_single_multi,
# max_conditions_lengths=max_conditions_lengths_multi,
# automatic_batching=False,
# )
# tensors = dataset.fetch_from_idx_list([i for i in range(70)])
# assert isinstance(tensors, dict)
# assert list(tensors.keys()) == ["data_1", "data_2"]
# assert sorted(list(tensors["data_1"].keys())) == ["input", "target"]
# assert isinstance(tensors["data_1"]["input"], torch.Tensor)
# assert tensors["data_1"]["input"].shape == torch.Size((70, 10))
# assert isinstance(tensors["data_1"]["target"], torch.Tensor)
# assert tensors["data_1"]["target"].shape == torch.Size((70, 2))
def test_getitem_multi():
dataset = PinaDatasetFactory(
conditions_dict_single_multi,
max_conditions_lengths=max_conditions_lengths_multi,
automatic_batching=False,
)
tensors = dataset.fetch_from_idx_list([i for i in range(70)])
assert isinstance(tensors, dict)
assert list(tensors.keys()) == ["data_1", "data_2"]
assert sorted(list(tensors["data_1"].keys())) == ["input", "target"]
assert isinstance(tensors["data_1"]["input"], torch.Tensor)
assert tensors["data_1"]["input"].shape == torch.Size((70, 10))
assert isinstance(tensors["data_1"]["target"], torch.Tensor)
assert tensors["data_1"]["target"].shape == torch.Size((70, 2))
# assert sorted(list(tensors["data_2"].keys())) == ["input", "target"]
# assert isinstance(tensors["data_2"]["input"], torch.Tensor)
# assert tensors["data_2"]["input"].shape == torch.Size((50, 10))
# assert isinstance(tensors["data_2"]["target"], torch.Tensor)
# assert tensors["data_2"]["target"].shape == torch.Size((50, 2))
assert sorted(list(tensors["data_2"].keys())) == ["input", "target"]
assert isinstance(tensors["data_2"]["input"], torch.Tensor)
assert tensors["data_2"]["input"].shape == torch.Size((50, 10))
assert isinstance(tensors["data_2"]["target"], torch.Tensor)
assert tensors["data_2"]["target"].shape == torch.Size((50, 2))

View File

@@ -104,7 +104,7 @@ def test_advection_equation(c):
# Should fail if c is a list and its length != spatial dimension
with pytest.raises(ValueError):
equation = Advection([1, 2, 3])
Advection([1, 2, 3])
residual = equation.residual(pts, u)

View File

@@ -117,10 +117,6 @@ def test_solver_train(use_lt, batch_size, compile):
assert isinstance(solver.model, OptimizedModule)
if __name__ == "__main__":
test_solver_train(use_lt=True, batch_size=20, compile=True)
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("use_lt", [True, False])
def test_solver_train_graph(batch_size, use_lt):

Binary file not shown.

Before

Width: | Height: | Size: 411 KiB

After

Width: | Height: | Size: 51 KiB