This reverts commit 69cd0ed8cda91c92dab6551a0c6dfd94d199cee7.
This commit is contained in:
committed by
Nicola Demo
parent
6da74cadd5
commit
3c95441aac
@@ -4,7 +4,7 @@ import torch
|
|||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
full_labels = False
|
full_labels = True
|
||||||
MATH_FUNCTIONS = {torch.sin, torch.cos}
|
MATH_FUNCTIONS = {torch.sin, torch.cos}
|
||||||
|
|
||||||
class LabelTensor(torch.Tensor):
|
class LabelTensor(torch.Tensor):
|
||||||
|
|||||||
@@ -17,3 +17,4 @@ from .pinns import *
|
|||||||
from .supervised import SupervisedSolver
|
from .supervised import SupervisedSolver
|
||||||
from .rom import ReducedOrderModelSolver
|
from .rom import ReducedOrderModelSolver
|
||||||
from .garom import GAROM
|
from .garom import GAROM
|
||||||
|
from .graph import GraphSupervisedSolver
|
||||||
|
|||||||
@@ -3,10 +3,12 @@
|
|||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.modules.loss import _Loss
|
from torch.nn.modules.loss import _Loss
|
||||||
|
from ...condition import InputOutputPointsCondition
|
||||||
from ...solvers.solver import SolverInterface
|
from ...solvers.solver import SolverInterface
|
||||||
from ...utils import check_consistency
|
from ...utils import check_consistency
|
||||||
from ...loss.loss_interface import LossInterface
|
from ...loss.loss_interface import LossInterface
|
||||||
from ...problem import InverseProblem
|
from ...problem import InverseProblem
|
||||||
|
from ...condition import DomainEquationCondition
|
||||||
from ...optim import TorchOptimizer, TorchScheduler
|
from ...optim import TorchOptimizer, TorchScheduler
|
||||||
|
|
||||||
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
|
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
|
||||||
@@ -24,7 +26,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
to the user to choose which problem the implemented solver inheriting from
|
to the user to choose which problem the implemented solver inheriting from
|
||||||
this class is suitable for.
|
this class is suitable for.
|
||||||
"""
|
"""
|
||||||
|
accepted_condition_types = [DomainEquationCondition.condition_type[0],
|
||||||
|
InputOutputPointsCondition.condition_type[0]]
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
models,
|
models,
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ except ImportError:
|
|||||||
|
|
||||||
|
|
||||||
from .basepinn import PINNInterface
|
from .basepinn import PINNInterface
|
||||||
from ...problem import InverseProblem
|
from pina.problem import InverseProblem
|
||||||
|
|
||||||
|
|
||||||
class PINN(PINNInterface):
|
class PINN(PINNInterface):
|
||||||
|
|||||||
@@ -134,18 +134,16 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
|||||||
return super().on_train_start()
|
return super().on_train_start()
|
||||||
|
|
||||||
def _check_solver_consistency(self, problem):
|
def _check_solver_consistency(self, problem):
|
||||||
pass
|
"""
|
||||||
#TODO : Implement this method for the conditions
|
TODO
|
||||||
'''
|
"""
|
||||||
|
|
||||||
|
|
||||||
for _, condition in problem.conditions.items():
|
for _, condition in problem.conditions.items():
|
||||||
if not set(condition.condition_type).issubset(
|
if not set(condition.condition_type).issubset(
|
||||||
set(self.accepted_condition_types)):
|
set(self.accepted_condition_types)):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f'{self.__name__} dose not support condition '
|
f'{self.__name__} dose not support condition '
|
||||||
f'{condition.condition_type}')
|
f'{condition.condition_type}')
|
||||||
'''
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_batch_size(batch):
|
def get_batch_size(batch):
|
||||||
# Assuming batch is your custom Batch object
|
# Assuming batch is your custom Batch object
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from .solver import SolverInterface
|
|||||||
from ..label_tensor import LabelTensor
|
from ..label_tensor import LabelTensor
|
||||||
from ..utils import check_consistency
|
from ..utils import check_consistency
|
||||||
from ..loss.loss_interface import LossInterface
|
from ..loss.loss_interface import LossInterface
|
||||||
|
from ..condition import InputOutputPointsCondition
|
||||||
|
|
||||||
|
|
||||||
class SupervisedSolver(SolverInterface):
|
class SupervisedSolver(SolverInterface):
|
||||||
@@ -37,6 +38,7 @@ class SupervisedSolver(SolverInterface):
|
|||||||
we are seeking to approximate multiple (discretised) functions given
|
we are seeking to approximate multiple (discretised) functions given
|
||||||
multiple (discretised) input functions.
|
multiple (discretised) input functions.
|
||||||
"""
|
"""
|
||||||
|
accepted_condition_types = [InputOutputPointsCondition.condition_type[0]]
|
||||||
__name__ = 'SupervisedSolver'
|
__name__ = 'SupervisedSolver'
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
""" Trainer module. """
|
""" Trainer module. """
|
||||||
|
import warnings
|
||||||
import torch
|
import torch
|
||||||
import lightning
|
import lightning
|
||||||
from .utils import check_consistency
|
from .utils import check_consistency
|
||||||
|
|||||||
Reference in New Issue
Block a user