Author: Filippo Olivo
Date: 2024-11-28 11:06:38 +01:00
Committed by: Nicola Demo
Parent: 3c95441aac
Commit: f748b66194
9 changed files with 28 additions and 29 deletions

View File

@@ -95,7 +95,7 @@ class PinaDataModule(LightningDataModule):
- logging.debug('Start initialization of Pina DataModule')
+ logging.info('Start initialization of Pina DataModule')
super().__init__()
- self.default_batching = automatic_batching
+ self.automatic_batching = automatic_batching
self.batch_size = batch_size
self.shuffle = shuffle
self.repeat = repeat
@@ -133,24 +133,24 @@ class PinaDataModule(LightningDataModule):
self.train_dataset = PinaDatasetFactory(
self.collector_splits['train'],
max_conditions_lengths=self.find_max_conditions_lengths(
- 'train'))
+ 'train'), automatic_batching=self.automatic_batching)
if 'val' in self.collector_splits.keys():
self.val_dataset = PinaDatasetFactory(
self.collector_splits['val'],
max_conditions_lengths=self.find_max_conditions_lengths(
- 'val')
+ 'val'), automatic_batching=self.automatic_batching
)
elif stage == 'test':
self.test_dataset = PinaDatasetFactory(
self.collector_splits['test'],
max_conditions_lengths=self.find_max_conditions_lengths(
- 'test')
+ 'test'), automatic_batching=self.automatic_batching
)
elif stage == 'predict':
self.predict_dataset = PinaDatasetFactory(
self.collector_splits['predict'],
max_conditions_lengths=self.find_max_conditions_lengths(
- 'predict')
+ 'predict'), automatic_batching=self.automatic_batching
)
else:
raise ValueError(
@@ -237,9 +237,9 @@ class PinaDataModule(LightningDataModule):
self.val_dataset)
# Use default batching in torch DataLoader (good if batch size is small)
- if self.default_batching:
+ if self.automatic_batching:
collate = Collator(self.find_max_conditions_lengths('val'))
- return DataLoader(self.val_dataset, self.batch_size,
+ return DataLoader(self.val_dataset, batch_size,
collate_fn=collate)
collate = Collator(None)
# Use custom batching (good if batch size is large)
@@ -252,14 +252,16 @@ class PinaDataModule(LightningDataModule):
Create the training dataloader
"""
# Use default batching in torch DataLoader (good if batch size is small)
- if self.default_batching:
+ batch_size = self.batch_size if self.batch_size is not None else len(
+     self.train_dataset)
+ if self.automatic_batching:
collate = Collator(self.find_max_conditions_lengths('train'))
- return DataLoader(self.train_dataset, self.batch_size,
+ return DataLoader(self.train_dataset, batch_size,
collate_fn=collate)
collate = Collator(None)
# Use custom batching (good if batch size is large)
batch_size = self.batch_size if self.batch_size is not None else len(
self.train_dataset)
sampler = PinaBatchSampler(self.train_dataset, batch_size,
shuffle=False)
return DataLoader(self.train_dataset, sampler=sampler,
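The two code paths above trade off differently, as the comments note: under automatic batching the DataLoader draws one integer index per item and a collate function stacks the results, which is cheap for small batches, while the custom path hands the sampler's whole index list to a single __getitem__ call so the dataset can slice its tensors in one shot, which pays off for large batches. Below is a minimal, self-contained sketch of that pattern; ToyDataset and the plain torch BatchSampler are hypothetical stand-ins for PINA's PinaTensorDataset, Collator, and PinaBatchSampler, not the library's actual classes.

    import torch
    from torch.utils.data import (BatchSampler, DataLoader, Dataset,
                                  SequentialSampler)

    class ToyDataset(Dataset):
        """Hypothetical stand-in for PinaTensorDataset: a dict of tensors."""

        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(next(iter(self.data.values())))

        def __getitem__(self, idx):
            # idx is an int under automatic batching, a list of ints when
            # a batch sampler is used; tensor indexing handles both.
            return {k: v[idx] for k, v in self.data.items()}

    dataset = ToyDataset({'cond': torch.rand(100, 3)})

    # Path 1: automatic batching. The loader fetches items one index at a
    # time and its collate function stacks them into a batch.
    auto_loader = DataLoader(dataset, batch_size=10)

    # Path 2: custom batching. The sampler yields lists of indices, so
    # each __getitem__ call already returns a full batch; batch_size=None
    # keeps the DataLoader from batching a second time.
    sampler = BatchSampler(SequentialSampler(dataset), batch_size=10,
                           drop_last=False)
    custom_loader = DataLoader(dataset, sampler=sampler, batch_size=None)

    for batch in custom_loader:
        print(batch['cond'].shape)  # torch.Size([10, 3])
        break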

View File

@@ -51,8 +51,12 @@ class PinaDataset(Dataset):
class PinaTensorDataset(PinaDataset):
def __init__(self, conditions_dict, max_conditions_lengths,
- ):
+ automatic_batching):
super().__init__(conditions_dict, max_conditions_lengths)
+ if automatic_batching:
+     self._getitem_func = self._getitem_int
+ else:
+     self._getitem_func = self._getitem_list
def _getitem_int(self, idx):
return {
@@ -72,9 +76,7 @@ class PinaTensorDataset(PinaDataset):
return to_return_dict
def __getitem__(self, idx):
- if isinstance(idx, int):
-     return self._getitem_int(idx)
- return self._getitem_list(idx)
+ return self._getitem_func(idx)
class PinaGraphDataset(PinaDataset):
pass
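Binding _getitem_func once in __init__ removes the per-call isinstance branch that the second hunk deletes: the access pattern is fixed for the dataset's lifetime by the automatic_batching flag, so the dispatch can be decided a single time at construction. A small sketch of the same idea in isolation, with a hypothetical DispatchingDataset rather than PINA's actual class:

    import torch

    class DispatchingDataset:
        """Chooses its __getitem__ strategy once, at construction time."""

        def __init__(self, data, automatic_batching):
            self.data = data
            # Automatic batching requests one int index per call; a custom
            # batch sampler passes a whole list of indices instead.
            if automatic_batching:
                self._getitem_func = self._getitem_int
            else:
                self._getitem_func = self._getitem_list

        def _getitem_int(self, idx):
            return self.data[idx]

        def _getitem_list(self, indices):
            return self.data[indices]  # one fancy-indexing slice per batch

        def __getitem__(self, idx):
            return self._getitem_func(idx)

    data = torch.arange(12).reshape(6, 2)
    ds = DispatchingDataset(data, automatic_batching=False)
    print(ds[[0, 2, 4]].shape)  # torch.Size([3, 2]): a whole batch per call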

View File

@@ -4,7 +4,7 @@ import torch
from torch import Tensor
- full_labels = True
+ full_labels = False
MATH_FUNCTIONS = {torch.sin, torch.cos}
class LabelTensor(torch.Tensor):

View File

@@ -17,4 +17,3 @@ from .pinns import *
from .supervised import SupervisedSolver
from .rom import ReducedOrderModelSolver
from .garom import GAROM
- from .graph import GraphSupervisedSolver

View File

@@ -3,12 +3,10 @@
from abc import ABCMeta, abstractmethod
import torch
from torch.nn.modules.loss import _Loss
- from ...condition import InputOutputPointsCondition
from ...solvers.solver import SolverInterface
from ...utils import check_consistency
from ...loss.loss_interface import LossInterface
from ...problem import InverseProblem
- from ...condition import DomainEquationCondition
from ...optim import TorchOptimizer, TorchScheduler
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
@@ -26,8 +24,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
to the user to choose which problem the implemented solver inheriting from
this class is suitable for.
"""
- accepted_condition_types = [DomainEquationCondition.condition_type[0],
-     InputOutputPointsCondition.condition_type[0]]
def __init__(
self,
models,

View File

@@ -11,7 +11,7 @@ except ImportError:
from .basepinn import PINNInterface
- from pina.problem import InverseProblem
+ from ...problem import InverseProblem
class PINN(PINNInterface):

View File

@@ -134,16 +134,18 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
return super().on_train_start()
def _check_solver_consistency(self, problem):
"""
TODO
"""
+ pass
+ #TODO : Implement this method for the conditions
+ '''
for _, condition in problem.conditions.items():
if not set(condition.condition_type).issubset(
set(self.accepted_condition_types)):
raise ValueError(
f'{self.__name__} does not support condition '
f'{condition.condition_type}')
+ '''
@staticmethod
def get_batch_size(batch):
# Assuming batch is your custom Batch object

View File

@@ -7,7 +7,6 @@ from .solver import SolverInterface
from ..label_tensor import LabelTensor
from ..utils import check_consistency
from ..loss.loss_interface import LossInterface
- from ..condition import InputOutputPointsCondition
class SupervisedSolver(SolverInterface):
@@ -38,7 +37,6 @@ class SupervisedSolver(SolverInterface):
we are seeking to approximate multiple (discretised) functions given
multiple (discretised) input functions.
"""
- accepted_condition_types = [InputOutputPointsCondition.condition_type[0]]
__name__ = 'SupervisedSolver'
def __init__(self,

View File

@@ -1,5 +1,4 @@
""" Trainer module. """
- import warnings
import torch
import lightning
from .utils import check_consistency