Fix bugs (#387)
This commit is contained in:
committed by
Nicola Demo
parent
3c95441aac
commit
f748b66194
@@ -95,7 +95,7 @@ class PinaDataModule(LightningDataModule):
|
||||
logging.debug('Start initialization of Pina DataModule')
|
||||
logging.info('Start initialization of Pina DataModule')
|
||||
super().__init__()
|
||||
self.default_batching = automatic_batching
|
||||
self.automatic_batching = automatic_batching
|
||||
self.batch_size = batch_size
|
||||
self.shuffle = shuffle
|
||||
self.repeat = repeat
|
||||
@@ -133,24 +133,24 @@ class PinaDataModule(LightningDataModule):
|
||||
self.train_dataset = PinaDatasetFactory(
|
||||
self.collector_splits['train'],
|
||||
max_conditions_lengths=self.find_max_conditions_lengths(
|
||||
'train'))
|
||||
'train'), automatic_batching=self.automatic_batching)
|
||||
if 'val' in self.collector_splits.keys():
|
||||
self.val_dataset = PinaDatasetFactory(
|
||||
self.collector_splits['val'],
|
||||
max_conditions_lengths=self.find_max_conditions_lengths(
|
||||
'val')
|
||||
'val'), automatic_batching=self.automatic_batching
|
||||
)
|
||||
elif stage == 'test':
|
||||
self.test_dataset = PinaDatasetFactory(
|
||||
self.collector_splits['test'],
|
||||
max_conditions_lengths=self.find_max_conditions_lengths(
|
||||
'test')
|
||||
'test'), automatic_batching=self.automatic_batching
|
||||
)
|
||||
elif stage == 'predict':
|
||||
self.predict_dataset = PinaDatasetFactory(
|
||||
self.collector_splits['predict'],
|
||||
max_conditions_lengths=self.find_max_conditions_lengths(
|
||||
'predict')
|
||||
'predict'), automatic_batching=self.automatic_batching
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -237,9 +237,9 @@ class PinaDataModule(LightningDataModule):
|
||||
self.val_dataset)
|
||||
|
||||
# Use default batching in torch DataLoader (good is batch size is small)
|
||||
if self.default_batching:
|
||||
if self.automatic_batching:
|
||||
collate = Collator(self.find_max_conditions_lengths('val'))
|
||||
return DataLoader(self.val_dataset, self.batch_size,
|
||||
return DataLoader(self.val_dataset, batch_size,
|
||||
collate_fn=collate)
|
||||
collate = Collator(None)
|
||||
# Use custom batching (good if batch size is large)
|
||||
@@ -252,14 +252,16 @@ class PinaDataModule(LightningDataModule):
|
||||
Create the training dataloader
|
||||
"""
|
||||
# Use default batching in torch DataLoader (good is batch size is small)
|
||||
if self.default_batching:
|
||||
batch_size = self.batch_size if self.batch_size is not None else len(
|
||||
self.train_dataset)
|
||||
|
||||
if self.automatic_batching:
|
||||
collate = Collator(self.find_max_conditions_lengths('train'))
|
||||
return DataLoader(self.train_dataset, self.batch_size,
|
||||
return DataLoader(self.train_dataset, batch_size,
|
||||
collate_fn=collate)
|
||||
collate = Collator(None)
|
||||
# Use custom batching (good if batch size is large)
|
||||
batch_size = self.batch_size if self.batch_size is not None else len(
|
||||
self.train_dataset)
|
||||
|
||||
sampler = PinaBatchSampler(self.train_dataset, batch_size,
|
||||
shuffle=False)
|
||||
return DataLoader(self.train_dataset, sampler=sampler,
|
||||
|
||||
@@ -51,8 +51,12 @@ class PinaDataset(Dataset):
|
||||
|
||||
class PinaTensorDataset(PinaDataset):
|
||||
def __init__(self, conditions_dict, max_conditions_lengths,
|
||||
):
|
||||
automatic_batching):
|
||||
super().__init__(conditions_dict, max_conditions_lengths)
|
||||
if automatic_batching:
|
||||
self._getitem_func = self._getitem_int
|
||||
else:
|
||||
self._getitem_func = self._getitem_list
|
||||
|
||||
def _getitem_int(self, idx):
|
||||
return {
|
||||
@@ -72,9 +76,7 @@ class PinaTensorDataset(PinaDataset):
|
||||
return to_return_dict
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, int):
|
||||
return self._getitem_int(idx)
|
||||
return self._getitem_list(idx)
|
||||
return self._getitem_func(idx)
|
||||
|
||||
class PinaGraphDataset(PinaDataset):
|
||||
pass
|
||||
|
||||
@@ -4,7 +4,7 @@ import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
full_labels = True
|
||||
full_labels = False
|
||||
MATH_FUNCTIONS = {torch.sin, torch.cos}
|
||||
|
||||
class LabelTensor(torch.Tensor):
|
||||
|
||||
@@ -17,4 +17,3 @@ from .pinns import *
|
||||
from .supervised import SupervisedSolver
|
||||
from .rom import ReducedOrderModelSolver
|
||||
from .garom import GAROM
|
||||
from .graph import GraphSupervisedSolver
|
||||
|
||||
@@ -3,12 +3,10 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import torch
|
||||
from torch.nn.modules.loss import _Loss
|
||||
from ...condition import InputOutputPointsCondition
|
||||
from ...solvers.solver import SolverInterface
|
||||
from ...utils import check_consistency
|
||||
from ...loss.loss_interface import LossInterface
|
||||
from ...problem import InverseProblem
|
||||
from ...condition import DomainEquationCondition
|
||||
from ...optim import TorchOptimizer, TorchScheduler
|
||||
|
||||
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
|
||||
@@ -26,8 +24,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
to the user to choose which problem the implemented solver inheriting from
|
||||
this class is suitable for.
|
||||
"""
|
||||
accepted_condition_types = [DomainEquationCondition.condition_type[0],
|
||||
InputOutputPointsCondition.condition_type[0]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
models,
|
||||
|
||||
@@ -11,7 +11,7 @@ except ImportError:
|
||||
|
||||
|
||||
from .basepinn import PINNInterface
|
||||
from pina.problem import InverseProblem
|
||||
from ...problem import InverseProblem
|
||||
|
||||
|
||||
class PINN(PINNInterface):
|
||||
|
||||
@@ -134,16 +134,18 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
return super().on_train_start()
|
||||
|
||||
def _check_solver_consistency(self, problem):
|
||||
"""
|
||||
TODO
|
||||
"""
|
||||
pass
|
||||
#TODO : Implement this method for the conditions
|
||||
'''
|
||||
|
||||
|
||||
for _, condition in problem.conditions.items():
|
||||
if not set(condition.condition_type).issubset(
|
||||
set(self.accepted_condition_types)):
|
||||
raise ValueError(
|
||||
f'{self.__name__} dose not support condition '
|
||||
f'{condition.condition_type}')
|
||||
|
||||
'''
|
||||
@staticmethod
|
||||
def get_batch_size(batch):
|
||||
# Assuming batch is your custom Batch object
|
||||
|
||||
@@ -7,7 +7,6 @@ from .solver import SolverInterface
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..utils import check_consistency
|
||||
from ..loss.loss_interface import LossInterface
|
||||
from ..condition import InputOutputPointsCondition
|
||||
|
||||
|
||||
class SupervisedSolver(SolverInterface):
|
||||
@@ -38,7 +37,6 @@ class SupervisedSolver(SolverInterface):
|
||||
we are seeking to approximate multiple (discretised) functions given
|
||||
multiple (discretised) input functions.
|
||||
"""
|
||||
accepted_condition_types = [InputOutputPointsCondition.condition_type[0]]
|
||||
__name__ = 'SupervisedSolver'
|
||||
|
||||
def __init__(self,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
""" Trainer module. """
|
||||
import warnings
|
||||
import torch
|
||||
import lightning
|
||||
from .utils import check_consistency
|
||||
|
||||
Reference in New Issue
Block a user