Revert "Fix bugs (#385)" (#386)

This reverts commit 69cd0ed8cda91c92dab6551a0c6dfd94d199cee7.
This commit is contained in:
Dario Coscia
2024-11-27 19:58:42 +01:00
committed by Nicola Demo
parent 6da74cadd5
commit 3c95441aac
7 changed files with 14 additions and 9 deletions

View File

@@ -4,7 +4,7 @@ import torch
from torch import Tensor from torch import Tensor
full_labels = False full_labels = True
MATH_FUNCTIONS = {torch.sin, torch.cos} MATH_FUNCTIONS = {torch.sin, torch.cos}
class LabelTensor(torch.Tensor): class LabelTensor(torch.Tensor):

View File

@@ -17,3 +17,4 @@ from .pinns import *
from .supervised import SupervisedSolver from .supervised import SupervisedSolver
from .rom import ReducedOrderModelSolver from .rom import ReducedOrderModelSolver
from .garom import GAROM from .garom import GAROM
from .graph import GraphSupervisedSolver

View File

@@ -3,10 +3,12 @@
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import torch import torch
from torch.nn.modules.loss import _Loss from torch.nn.modules.loss import _Loss
from ...condition import InputOutputPointsCondition
from ...solvers.solver import SolverInterface from ...solvers.solver import SolverInterface
from ...utils import check_consistency from ...utils import check_consistency
from ...loss.loss_interface import LossInterface from ...loss.loss_interface import LossInterface
from ...problem import InverseProblem from ...problem import InverseProblem
from ...condition import DomainEquationCondition
from ...optim import TorchOptimizer, TorchScheduler from ...optim import TorchOptimizer, TorchScheduler
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732 torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
@@ -24,7 +26,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
to the user to choose which problem the implemented solver inheriting from to the user to choose which problem the implemented solver inheriting from
this class is suitable for. this class is suitable for.
""" """
accepted_condition_types = [DomainEquationCondition.condition_type[0],
InputOutputPointsCondition.condition_type[0]]
def __init__( def __init__(
self, self,
models, models,

View File

@@ -11,7 +11,7 @@ except ImportError:
from .basepinn import PINNInterface from .basepinn import PINNInterface
from ...problem import InverseProblem from pina.problem import InverseProblem
class PINN(PINNInterface): class PINN(PINNInterface):

View File

@@ -134,18 +134,16 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
return super().on_train_start() return super().on_train_start()
def _check_solver_consistency(self, problem): def _check_solver_consistency(self, problem):
pass """
#TODO : Implement this method for the conditions TODO
''' """
for _, condition in problem.conditions.items(): for _, condition in problem.conditions.items():
if not set(condition.condition_type).issubset( if not set(condition.condition_type).issubset(
set(self.accepted_condition_types)): set(self.accepted_condition_types)):
raise ValueError( raise ValueError(
f'{self.__name__} dose not support condition ' f'{self.__name__} dose not support condition '
f'{condition.condition_type}') f'{condition.condition_type}')
'''
@staticmethod @staticmethod
def get_batch_size(batch): def get_batch_size(batch):
# Assuming batch is your custom Batch object # Assuming batch is your custom Batch object

View File

@@ -7,6 +7,7 @@ from .solver import SolverInterface
from ..label_tensor import LabelTensor from ..label_tensor import LabelTensor
from ..utils import check_consistency from ..utils import check_consistency
from ..loss.loss_interface import LossInterface from ..loss.loss_interface import LossInterface
from ..condition import InputOutputPointsCondition
class SupervisedSolver(SolverInterface): class SupervisedSolver(SolverInterface):
@@ -37,6 +38,7 @@ class SupervisedSolver(SolverInterface):
we are seeking to approximate multiple (discretised) functions given we are seeking to approximate multiple (discretised) functions given
multiple (discretised) input functions. multiple (discretised) input functions.
""" """
accepted_condition_types = [InputOutputPointsCondition.condition_type[0]]
__name__ = 'SupervisedSolver' __name__ = 'SupervisedSolver'
def __init__(self, def __init__(self,

View File

@@ -1,4 +1,5 @@
""" Trainer module. """ """ Trainer module. """
import warnings
import torch import torch
import lightning import lightning
from .utils import check_consistency from .utils import check_consistency