This commit is contained in:
Filippo Olivo
2024-11-27 19:30:12 +01:00
committed by Nicola Demo
parent b52112e448
commit 6da74cadd5
7 changed files with 9 additions and 14 deletions

View File

@@ -4,7 +4,7 @@ import torch
from torch import Tensor from torch import Tensor
full_labels = True full_labels = False
MATH_FUNCTIONS = {torch.sin, torch.cos} MATH_FUNCTIONS = {torch.sin, torch.cos}
class LabelTensor(torch.Tensor): class LabelTensor(torch.Tensor):

View File

@@ -17,4 +17,3 @@ from .pinns import *
from .supervised import SupervisedSolver from .supervised import SupervisedSolver
from .rom import ReducedOrderModelSolver from .rom import ReducedOrderModelSolver
from .garom import GAROM from .garom import GAROM
from .graph import GraphSupervisedSolver

View File

@@ -3,12 +3,10 @@
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import torch import torch
from torch.nn.modules.loss import _Loss from torch.nn.modules.loss import _Loss
from ...condition import InputOutputPointsCondition
from ...solvers.solver import SolverInterface from ...solvers.solver import SolverInterface
from ...utils import check_consistency from ...utils import check_consistency
from ...loss.loss_interface import LossInterface from ...loss.loss_interface import LossInterface
from ...problem import InverseProblem from ...problem import InverseProblem
from ...condition import DomainEquationCondition
from ...optim import TorchOptimizer, TorchScheduler from ...optim import TorchOptimizer, TorchScheduler
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732 torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
@@ -26,8 +24,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
to the user to choose which problem the implemented solver inheriting from to the user to choose which problem the implemented solver inheriting from
this class is suitable for. this class is suitable for.
""" """
accepted_condition_types = [DomainEquationCondition.condition_type[0],
InputOutputPointsCondition.condition_type[0]]
def __init__( def __init__(
self, self,
models, models,

View File

@@ -11,7 +11,7 @@ except ImportError:
from .basepinn import PINNInterface from .basepinn import PINNInterface
from pina.problem import InverseProblem from ...problem import InverseProblem
class PINN(PINNInterface): class PINN(PINNInterface):

View File

@@ -134,16 +134,18 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
return super().on_train_start() return super().on_train_start()
def _check_solver_consistency(self, problem): def _check_solver_consistency(self, problem):
""" pass
TODO #TODO : Implement this method for the conditions
""" '''
for _, condition in problem.conditions.items(): for _, condition in problem.conditions.items():
if not set(condition.condition_type).issubset( if not set(condition.condition_type).issubset(
set(self.accepted_condition_types)): set(self.accepted_condition_types)):
raise ValueError( raise ValueError(
f'{self.__name__} dose not support condition ' f'{self.__name__} dose not support condition '
f'{condition.condition_type}') f'{condition.condition_type}')
'''
@staticmethod @staticmethod
def get_batch_size(batch): def get_batch_size(batch):
# Assuming batch is your custom Batch object # Assuming batch is your custom Batch object

View File

@@ -7,7 +7,6 @@ from .solver import SolverInterface
from ..label_tensor import LabelTensor from ..label_tensor import LabelTensor
from ..utils import check_consistency from ..utils import check_consistency
from ..loss.loss_interface import LossInterface from ..loss.loss_interface import LossInterface
from ..condition import InputOutputPointsCondition
class SupervisedSolver(SolverInterface): class SupervisedSolver(SolverInterface):
@@ -38,7 +37,6 @@ class SupervisedSolver(SolverInterface):
we are seeking to approximate multiple (discretised) functions given we are seeking to approximate multiple (discretised) functions given
multiple (discretised) input functions. multiple (discretised) input functions.
""" """
accepted_condition_types = [InputOutputPointsCondition.condition_type[0]]
__name__ = 'SupervisedSolver' __name__ = 'SupervisedSolver'
def __init__(self, def __init__(self,

View File

@@ -1,5 +1,4 @@
""" Trainer module. """ """ Trainer module. """
import warnings
import torch import torch
import lightning import lightning
from .utils import check_consistency from .utils import check_consistency