This commit is contained in:
Filippo Olivo
2024-11-27 19:30:12 +01:00
committed by Nicola Demo
parent b52112e448
commit 6da74cadd5
7 changed files with 9 additions and 14 deletions

View File

@@ -4,7 +4,7 @@ import torch
from torch import Tensor
full_labels = True
full_labels = False
MATH_FUNCTIONS = {torch.sin, torch.cos}
class LabelTensor(torch.Tensor):

View File

@@ -17,4 +17,3 @@ from .pinns import *
from .supervised import SupervisedSolver
from .rom import ReducedOrderModelSolver
from .garom import GAROM
from .graph import GraphSupervisedSolver

View File

@@ -3,12 +3,10 @@
from abc import ABCMeta, abstractmethod
import torch
from torch.nn.modules.loss import _Loss
from ...condition import InputOutputPointsCondition
from ...solvers.solver import SolverInterface
from ...utils import check_consistency
from ...loss.loss_interface import LossInterface
from ...problem import InverseProblem
from ...condition import DomainEquationCondition
from ...optim import TorchOptimizer, TorchScheduler
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
@@ -26,8 +24,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
to the user to choose which problem the implemented solver inheriting from
this class is suitable for.
"""
accepted_condition_types = [DomainEquationCondition.condition_type[0],
InputOutputPointsCondition.condition_type[0]]
def __init__(
self,
models,

View File

@@ -11,7 +11,7 @@ except ImportError:
from .basepinn import PINNInterface
from pina.problem import InverseProblem
from ...problem import InverseProblem
class PINN(PINNInterface):

View File

@@ -134,16 +134,18 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
return super().on_train_start()
def _check_solver_consistency(self, problem):
"""
TODO
"""
pass
#TODO : Implement this method for the conditions
'''
for _, condition in problem.conditions.items():
if not set(condition.condition_type).issubset(
set(self.accepted_condition_types)):
raise ValueError(
f'{self.__name__} dose not support condition '
f'{condition.condition_type}')
'''
@staticmethod
def get_batch_size(batch):
# Assuming batch is your custom Batch object

View File

@@ -7,7 +7,6 @@ from .solver import SolverInterface
from ..label_tensor import LabelTensor
from ..utils import check_consistency
from ..loss.loss_interface import LossInterface
from ..condition import InputOutputPointsCondition
class SupervisedSolver(SolverInterface):
@@ -38,7 +37,6 @@ class SupervisedSolver(SolverInterface):
we are seeking to approximate multiple (discretised) functions given
multiple (discretised) input functions.
"""
accepted_condition_types = [InputOutputPointsCondition.condition_type[0]]
__name__ = 'SupervisedSolver'
def __init__(self,

View File

@@ -1,5 +1,4 @@
""" Trainer module. """
import warnings
import torch
import lightning
from .utils import check_consistency