Add check conditions-solver consistency
This commit is contained in:
committed by
Nicola Demo
parent
f2340cd4ee
commit
a6f0336d06
@@ -8,6 +8,8 @@ from ...utils import check_consistency
|
|||||||
from ...loss.loss_interface import LossInterface
|
from ...loss.loss_interface import LossInterface
|
||||||
from ...problem import InverseProblem
|
from ...problem import InverseProblem
|
||||||
from ...optim import TorchOptimizer, TorchScheduler
|
from ...optim import TorchOptimizer, TorchScheduler
|
||||||
|
from ...condition import InputOutputPointsCondition, \
|
||||||
|
InputPointsEquationCondition, DomainEquationCondition
|
||||||
|
|
||||||
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
|
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
|
||||||
|
|
||||||
@@ -24,6 +26,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
to the user to choose which problem the implemented solver inheriting from
|
to the user to choose which problem the implemented solver inheriting from
|
||||||
this class is suitable for.
|
this class is suitable for.
|
||||||
"""
|
"""
|
||||||
|
accepted_conditions_types = (InputOutputPointsCondition,
|
||||||
|
InputPointsEquationCondition, DomainEquationCondition)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -97,7 +101,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
self._optimizer = self._pina_optimizers[0]
|
self._optimizer = self._pina_optimizers[0]
|
||||||
self._scheduler = self._pina_schedulers[0]
|
self._scheduler = self._pina_schedulers[0]
|
||||||
|
|
||||||
|
|
||||||
def training_step(self, batch):
|
def training_step(self, batch):
|
||||||
"""
|
"""
|
||||||
The Physics Informed Solver Training Step. This function takes care
|
The Physics Informed Solver Training Step. This function takes care
|
||||||
@@ -117,14 +120,16 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
if 'output_points' in points:
|
if 'output_points' in points:
|
||||||
input_pts, output_pts = points['input_points'], points['output_points']
|
input_pts, output_pts = points['input_points'], points['output_points']
|
||||||
|
|
||||||
loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
|
loss_ = self.loss_data(
|
||||||
|
input_pts=input_pts, output_pts=output_pts)
|
||||||
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
||||||
else:
|
else:
|
||||||
input_pts = points['input_points']
|
input_pts = points['input_points']
|
||||||
|
|
||||||
condition = self.problem.conditions[condition_name]
|
condition = self.problem.conditions[condition_name]
|
||||||
|
|
||||||
loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation)
|
loss_ = self.loss_phys(
|
||||||
|
input_pts.requires_grad_(), condition.equation)
|
||||||
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
||||||
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
||||||
# clamp unknown parameters in InverseProblem (if needed)
|
# clamp unknown parameters in InverseProblem (if needed)
|
||||||
@@ -144,14 +149,16 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
for condition_name, points in batch:
|
for condition_name, points in batch:
|
||||||
if 'output_points' in points:
|
if 'output_points' in points:
|
||||||
input_pts, output_pts = points['input_points'], points['output_points']
|
input_pts, output_pts = points['input_points'], points['output_points']
|
||||||
loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
|
loss_ = self.loss_data(
|
||||||
|
input_pts=input_pts, output_pts=output_pts)
|
||||||
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
||||||
else:
|
else:
|
||||||
input_pts = points['input_points']
|
input_pts = points['input_points']
|
||||||
|
|
||||||
condition = self.problem.conditions[condition_name]
|
condition = self.problem.conditions[condition_name]
|
||||||
with torch.set_grad_enabled(True):
|
with torch.set_grad_enabled(True):
|
||||||
loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation)
|
loss_ = self.loss_phys(
|
||||||
|
input_pts.requires_grad_(), condition.equation)
|
||||||
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
||||||
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
||||||
# clamp unknown parameters in InverseProblem (if needed)
|
# clamp unknown parameters in InverseProblem (if needed)
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import torch
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||||
"""
|
"""
|
||||||
Solver base class. This class inherits is a wrapper of
|
Solver base class. This class inherits is a wrapper of
|
||||||
@@ -133,23 +132,14 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
|||||||
|
|
||||||
return super().on_train_start()
|
return super().on_train_start()
|
||||||
|
|
||||||
def _check_solver_consistency(self, problem):
|
|
||||||
pass
|
|
||||||
#TODO : Implement this method for the conditions
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
for _, condition in problem.conditions.items():
|
|
||||||
if not set(condition.condition_type).issubset(
|
|
||||||
set(self.accepted_condition_types)):
|
|
||||||
raise ValueError(
|
|
||||||
f'{self.__name__} dose not support condition '
|
|
||||||
f'{condition.condition_type}')
|
|
||||||
'''
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_batch_size(batch):
|
def get_batch_size(batch):
|
||||||
# Assuming batch is your custom Batch object
|
# Assuming batch is your custom Batch object
|
||||||
batch_size = 0
|
batch_size = 0
|
||||||
for data in batch:
|
for data in batch:
|
||||||
batch_size += len(data[1]['input_points'])
|
batch_size += len(data[1]['input_points'])
|
||||||
return batch_size
|
return batch_size
|
||||||
|
|
||||||
|
def _check_solver_consistency(self, problem):
|
||||||
|
for condition in problem.conditions.values():
|
||||||
|
check_consistency(condition, self.accepted_conditions_types)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from .solver import SolverInterface
|
|||||||
from ..label_tensor import LabelTensor
|
from ..label_tensor import LabelTensor
|
||||||
from ..utils import check_consistency
|
from ..utils import check_consistency
|
||||||
from ..loss.loss_interface import LossInterface
|
from ..loss.loss_interface import LossInterface
|
||||||
|
from ..condition import InputOutputPointsCondition
|
||||||
|
|
||||||
|
|
||||||
class SupervisedSolver(SolverInterface):
|
class SupervisedSolver(SolverInterface):
|
||||||
@@ -37,7 +38,8 @@ class SupervisedSolver(SolverInterface):
|
|||||||
we are seeking to approximate multiple (discretised) functions given
|
we are seeking to approximate multiple (discretised) functions given
|
||||||
multiple (discretised) input functions.
|
multiple (discretised) input functions.
|
||||||
"""
|
"""
|
||||||
__name__ = 'SupervisedSolver'
|
|
||||||
|
accepted_conditions_types = InputOutputPointsCondition
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
problem,
|
problem,
|
||||||
@@ -143,7 +145,6 @@ class SupervisedSolver(SolverInterface):
|
|||||||
self.log('val_loss', loss, prog_bar=True, logger=True,
|
self.log('val_loss', loss, prog_bar=True, logger=True,
|
||||||
batch_size=self.get_batch_size(batch), sync_dist=True)
|
batch_size=self.get_batch_size(batch), sync_dist=True)
|
||||||
|
|
||||||
|
|
||||||
def test_step(self, batch, batch_idx) -> STEP_OUTPUT:
|
def test_step(self, batch, batch_idx) -> STEP_OUTPUT:
|
||||||
"""
|
"""
|
||||||
Solver test step.
|
Solver test step.
|
||||||
|
|||||||
Reference in New Issue
Block a user