From f748b661942dbc779783393dd9d943bcd02d7d6b Mon Sep 17 00:00:00 2001 From: Filippo Olivo Date: Thu, 28 Nov 2024 11:06:38 +0100 Subject: [PATCH] Fix bugs (#387) --- pina/data/data_module.py | 24 +++++++++++++----------- pina/data/dataset.py | 10 ++++++---- pina/label_tensor.py | 2 +- pina/solvers/__init__.py | 1 - pina/solvers/pinns/basepinn.py | 5 +---- pina/solvers/pinns/pinn.py | 2 +- pina/solvers/solver.py | 10 ++++++---- pina/solvers/supervised.py | 2 -- pina/trainer.py | 1 - 9 files changed, 28 insertions(+), 29 deletions(-) diff --git a/pina/data/data_module.py b/pina/data/data_module.py index c9af8ae..4831e20 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -95,7 +95,7 @@ class PinaDataModule(LightningDataModule): logging.debug('Start initialization of Pina DataModule') logging.info('Start initialization of Pina DataModule') super().__init__() - self.default_batching = automatic_batching + self.automatic_batching = automatic_batching self.batch_size = batch_size self.shuffle = shuffle self.repeat = repeat @@ -133,24 +133,24 @@ class PinaDataModule(LightningDataModule): self.train_dataset = PinaDatasetFactory( self.collector_splits['train'], max_conditions_lengths=self.find_max_conditions_lengths( - 'train')) + 'train'), automatic_batching=self.automatic_batching) if 'val' in self.collector_splits.keys(): self.val_dataset = PinaDatasetFactory( self.collector_splits['val'], max_conditions_lengths=self.find_max_conditions_lengths( - 'val') + 'val'), automatic_batching=self.automatic_batching ) elif stage == 'test': self.test_dataset = PinaDatasetFactory( self.collector_splits['test'], max_conditions_lengths=self.find_max_conditions_lengths( - 'test') + 'test'), automatic_batching=self.automatic_batching ) elif stage == 'predict': self.predict_dataset = PinaDatasetFactory( self.collector_splits['predict'], max_conditions_lengths=self.find_max_conditions_lengths( - 'predict') + 'predict'), automatic_batching=self.automatic_batching ) else: raise ValueError( @@ -237,9 +237,9 @@ class PinaDataModule(LightningDataModule): self.val_dataset) # Use default batching in torch DataLoader (good is batch size is small) - if self.default_batching: + if self.automatic_batching: collate = Collator(self.find_max_conditions_lengths('val')) - return DataLoader(self.val_dataset, self.batch_size, + return DataLoader(self.val_dataset, batch_size, collate_fn=collate) collate = Collator(None) # Use custom batching (good if batch size is large) @@ -252,14 +252,16 @@ class PinaDataModule(LightningDataModule): Create the training dataloader """ # Use default batching in torch DataLoader (good is batch size is small) - if self.default_batching: + batch_size = self.batch_size if self.batch_size is not None else len( + self.train_dataset) + + if self.automatic_batching: collate = Collator(self.find_max_conditions_lengths('train')) - return DataLoader(self.train_dataset, self.batch_size, + return DataLoader(self.train_dataset, batch_size, collate_fn=collate) collate = Collator(None) # Use custom batching (good if batch size is large) - batch_size = self.batch_size if self.batch_size is not None else len( - self.train_dataset) + sampler = PinaBatchSampler(self.train_dataset, batch_size, shuffle=False) return DataLoader(self.train_dataset, sampler=sampler, diff --git a/pina/data/dataset.py b/pina/data/dataset.py index 0bc9237..e5685f1 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -51,8 +51,12 @@ class PinaDataset(Dataset): class PinaTensorDataset(PinaDataset): def __init__(self, conditions_dict, max_conditions_lengths, - ): + automatic_batching): super().__init__(conditions_dict, max_conditions_lengths) + if automatic_batching: + self._getitem_func = self._getitem_int + else: + self._getitem_func = self._getitem_list def _getitem_int(self, idx): return { @@ -72,9 +76,7 @@ class PinaTensorDataset(PinaDataset): return to_return_dict def __getitem__(self, idx): - if isinstance(idx, int): - return self._getitem_int(idx) - return self._getitem_list(idx) + return self._getitem_func(idx) class PinaGraphDataset(PinaDataset): pass diff --git a/pina/label_tensor.py b/pina/label_tensor.py index a3cf5d2..631c525 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -4,7 +4,7 @@ import torch from torch import Tensor -full_labels = True +full_labels = False MATH_FUNCTIONS = {torch.sin, torch.cos} class LabelTensor(torch.Tensor): diff --git a/pina/solvers/__init__.py b/pina/solvers/__init__.py index 59a1826..7bb988d 100644 --- a/pina/solvers/__init__.py +++ b/pina/solvers/__init__.py @@ -17,4 +17,3 @@ from .pinns import * from .supervised import SupervisedSolver from .rom import ReducedOrderModelSolver from .garom import GAROM -from .graph import GraphSupervisedSolver diff --git a/pina/solvers/pinns/basepinn.py b/pina/solvers/pinns/basepinn.py index 66f4d14..f44ea1d 100644 --- a/pina/solvers/pinns/basepinn.py +++ b/pina/solvers/pinns/basepinn.py @@ -3,12 +3,10 @@ from abc import ABCMeta, abstractmethod import torch from torch.nn.modules.loss import _Loss -from ...condition import InputOutputPointsCondition from ...solvers.solver import SolverInterface from ...utils import check_consistency from ...loss.loss_interface import LossInterface from ...problem import InverseProblem -from ...condition import DomainEquationCondition from ...optim import TorchOptimizer, TorchScheduler torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732 @@ -26,8 +24,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): to the user to choose which problem the implemented solver inheriting from this class is suitable for. """ - accepted_condition_types = [DomainEquationCondition.condition_type[0], - InputOutputPointsCondition.condition_type[0]] + def __init__( self, models, diff --git a/pina/solvers/pinns/pinn.py b/pina/solvers/pinns/pinn.py index 0888202..d1ab21d 100644 --- a/pina/solvers/pinns/pinn.py +++ b/pina/solvers/pinns/pinn.py @@ -11,7 +11,7 @@ except ImportError: from .basepinn import PINNInterface -from pina.problem import InverseProblem +from ...problem import InverseProblem class PINN(PINNInterface): diff --git a/pina/solvers/solver.py b/pina/solvers/solver.py index 3a8f400..8052b4b 100644 --- a/pina/solvers/solver.py +++ b/pina/solvers/solver.py @@ -134,16 +134,18 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): return super().on_train_start() def _check_solver_consistency(self, problem): - """ - TODO - """ + pass + #TODO : Implement this method for the conditions + ''' + + for _, condition in problem.conditions.items(): if not set(condition.condition_type).issubset( set(self.accepted_condition_types)): raise ValueError( f'{self.__name__} dose not support condition ' f'{condition.condition_type}') - + ''' @staticmethod def get_batch_size(batch): # Assuming batch is your custom Batch object diff --git a/pina/solvers/supervised.py b/pina/solvers/supervised.py index 947ab3b..bce4b31 100644 --- a/pina/solvers/supervised.py +++ b/pina/solvers/supervised.py @@ -7,7 +7,6 @@ from .solver import SolverInterface from ..label_tensor import LabelTensor from ..utils import check_consistency from ..loss.loss_interface import LossInterface -from ..condition import InputOutputPointsCondition class SupervisedSolver(SolverInterface): @@ -38,7 +37,6 @@ class SupervisedSolver(SolverInterface): we are seeking to approximate multiple (discretised) functions given multiple (discretised) input functions. """ - accepted_condition_types = [InputOutputPointsCondition.condition_type[0]] __name__ = 'SupervisedSolver' def __init__(self, diff --git a/pina/trainer.py b/pina/trainer.py index a7c5c35..f8bccd8 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -1,5 +1,4 @@ """ Trainer module. """ -import warnings import torch import lightning from .utils import check_consistency