From 6da74cadd5c5f0c1dd9880a33513c8735ecf2759 Mon Sep 17 00:00:00 2001 From: Filippo Olivo Date: Wed, 27 Nov 2024 19:30:12 +0100 Subject: [PATCH] Fix bugs (#385) --- pina/label_tensor.py | 2 +- pina/solvers/__init__.py | 1 - pina/solvers/pinns/basepinn.py | 5 +---- pina/solvers/pinns/pinn.py | 2 +- pina/solvers/solver.py | 10 ++++++---- pina/solvers/supervised.py | 2 -- pina/trainer.py | 1 - 7 files changed, 9 insertions(+), 14 deletions(-) diff --git a/pina/label_tensor.py b/pina/label_tensor.py index a3cf5d2..631c525 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -4,7 +4,7 @@ import torch from torch import Tensor -full_labels = True +full_labels = False MATH_FUNCTIONS = {torch.sin, torch.cos} class LabelTensor(torch.Tensor): diff --git a/pina/solvers/__init__.py b/pina/solvers/__init__.py index 59a1826..7bb988d 100644 --- a/pina/solvers/__init__.py +++ b/pina/solvers/__init__.py @@ -17,4 +17,3 @@ from .pinns import * from .supervised import SupervisedSolver from .rom import ReducedOrderModelSolver from .garom import GAROM -from .graph import GraphSupervisedSolver diff --git a/pina/solvers/pinns/basepinn.py b/pina/solvers/pinns/basepinn.py index 66f4d14..f44ea1d 100644 --- a/pina/solvers/pinns/basepinn.py +++ b/pina/solvers/pinns/basepinn.py @@ -3,12 +3,10 @@ from abc import ABCMeta, abstractmethod import torch from torch.nn.modules.loss import _Loss -from ...condition import InputOutputPointsCondition from ...solvers.solver import SolverInterface from ...utils import check_consistency from ...loss.loss_interface import LossInterface from ...problem import InverseProblem -from ...condition import DomainEquationCondition from ...optim import TorchOptimizer, TorchScheduler torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732 @@ -26,8 +24,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): to the user to choose which problem the implemented solver inheriting from this class is suitable for. """ - accepted_condition_types = [DomainEquationCondition.condition_type[0], - InputOutputPointsCondition.condition_type[0]] + def __init__( self, models, diff --git a/pina/solvers/pinns/pinn.py b/pina/solvers/pinns/pinn.py index 0888202..d1ab21d 100644 --- a/pina/solvers/pinns/pinn.py +++ b/pina/solvers/pinns/pinn.py @@ -11,7 +11,7 @@ except ImportError: from .basepinn import PINNInterface -from pina.problem import InverseProblem +from ...problem import InverseProblem class PINN(PINNInterface): diff --git a/pina/solvers/solver.py b/pina/solvers/solver.py index 3a8f400..8052b4b 100644 --- a/pina/solvers/solver.py +++ b/pina/solvers/solver.py @@ -134,16 +134,18 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): return super().on_train_start() def _check_solver_consistency(self, problem): - """ - TODO - """ + pass + #TODO : Implement this method for the conditions + ''' + + for _, condition in problem.conditions.items(): if not set(condition.condition_type).issubset( set(self.accepted_condition_types)): raise ValueError( f'{self.__name__} dose not support condition ' f'{condition.condition_type}') - + ''' @staticmethod def get_batch_size(batch): # Assuming batch is your custom Batch object diff --git a/pina/solvers/supervised.py b/pina/solvers/supervised.py index 947ab3b..bce4b31 100644 --- a/pina/solvers/supervised.py +++ b/pina/solvers/supervised.py @@ -7,7 +7,6 @@ from .solver import SolverInterface from ..label_tensor import LabelTensor from ..utils import check_consistency from ..loss.loss_interface import LossInterface -from ..condition import InputOutputPointsCondition class SupervisedSolver(SolverInterface): @@ -38,7 +37,6 @@ class SupervisedSolver(SolverInterface): we are seeking to approximate multiple (discretised) functions given multiple (discretised) input functions. """ - accepted_condition_types = [InputOutputPointsCondition.condition_type[0]] __name__ = 'SupervisedSolver' def __init__(self, diff --git a/pina/trainer.py b/pina/trainer.py index a7c5c35..f8bccd8 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -1,5 +1,4 @@ """ Trainer module. """ -import warnings import torch import lightning from .utils import check_consistency