Add check conditions-solver consistency
This commit is contained in:
committed by
Nicola Demo
parent
f2340cd4ee
commit
a6f0336d06
@@ -8,6 +8,8 @@ from ...utils import check_consistency
|
||||
from ...loss.loss_interface import LossInterface
|
||||
from ...problem import InverseProblem
|
||||
from ...optim import TorchOptimizer, TorchScheduler
|
||||
from ...condition import InputOutputPointsCondition, \
|
||||
InputPointsEquationCondition, DomainEquationCondition
|
||||
|
||||
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
|
||||
|
||||
@@ -24,6 +26,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
to the user to choose which problem the implemented solver inheriting from
|
||||
this class is suitable for.
|
||||
"""
|
||||
accepted_conditions_types = (InputOutputPointsCondition,
|
||||
InputPointsEquationCondition, DomainEquationCondition)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -97,7 +101,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
self._optimizer = self._pina_optimizers[0]
|
||||
self._scheduler = self._pina_schedulers[0]
|
||||
|
||||
|
||||
def training_step(self, batch):
|
||||
"""
|
||||
The Physics Informed Solver Training Step. This function takes care
|
||||
@@ -117,14 +120,16 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
if 'output_points' in points:
|
||||
input_pts, output_pts = points['input_points'], points['output_points']
|
||||
|
||||
loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
|
||||
loss_ = self.loss_data(
|
||||
input_pts=input_pts, output_pts=output_pts)
|
||||
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
||||
else:
|
||||
input_pts = points['input_points']
|
||||
|
||||
condition = self.problem.conditions[condition_name]
|
||||
|
||||
loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation)
|
||||
loss_ = self.loss_phys(
|
||||
input_pts.requires_grad_(), condition.equation)
|
||||
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
||||
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
||||
# clamp unknown parameters in InverseProblem (if needed)
|
||||
@@ -144,14 +149,16 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
for condition_name, points in batch:
|
||||
if 'output_points' in points:
|
||||
input_pts, output_pts = points['input_points'], points['output_points']
|
||||
loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
|
||||
loss_ = self.loss_data(
|
||||
input_pts=input_pts, output_pts=output_pts)
|
||||
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
||||
else:
|
||||
input_pts = points['input_points']
|
||||
|
||||
condition = self.problem.conditions[condition_name]
|
||||
with torch.set_grad_enabled(True):
|
||||
loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation)
|
||||
loss_ = self.loss_phys(
|
||||
input_pts.requires_grad_(), condition.equation)
|
||||
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
||||
condition_loss.append(loss_.as_subclass(torch.Tensor))
|
||||
# clamp unknown parameters in InverseProblem (if needed)
|
||||
|
||||
@@ -10,7 +10,6 @@ import torch
|
||||
import sys
|
||||
|
||||
|
||||
|
||||
class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
"""
|
||||
Solver base class. This class inherits is a wrapper of
|
||||
@@ -133,23 +132,14 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
|
||||
return super().on_train_start()
|
||||
|
||||
def _check_solver_consistency(self, problem):
|
||||
pass
|
||||
#TODO : Implement this method for the conditions
|
||||
'''
|
||||
|
||||
|
||||
for _, condition in problem.conditions.items():
|
||||
if not set(condition.condition_type).issubset(
|
||||
set(self.accepted_condition_types)):
|
||||
raise ValueError(
|
||||
f'{self.__name__} dose not support condition '
|
||||
f'{condition.condition_type}')
|
||||
'''
|
||||
@staticmethod
|
||||
def get_batch_size(batch):
|
||||
# Assuming batch is your custom Batch object
|
||||
batch_size = 0
|
||||
for data in batch:
|
||||
batch_size += len(data[1]['input_points'])
|
||||
return batch_size
|
||||
return batch_size
|
||||
|
||||
def _check_solver_consistency(self, problem):
|
||||
for condition in problem.conditions.values():
|
||||
check_consistency(condition, self.accepted_conditions_types)
|
||||
|
||||
@@ -7,6 +7,7 @@ from .solver import SolverInterface
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..utils import check_consistency
|
||||
from ..loss.loss_interface import LossInterface
|
||||
from ..condition import InputOutputPointsCondition
|
||||
|
||||
|
||||
class SupervisedSolver(SolverInterface):
|
||||
@@ -37,7 +38,8 @@ class SupervisedSolver(SolverInterface):
|
||||
we are seeking to approximate multiple (discretised) functions given
|
||||
multiple (discretised) input functions.
|
||||
"""
|
||||
__name__ = 'SupervisedSolver'
|
||||
|
||||
accepted_conditions_types = InputOutputPointsCondition
|
||||
|
||||
def __init__(self,
|
||||
problem,
|
||||
@@ -143,7 +145,6 @@ class SupervisedSolver(SolverInterface):
|
||||
self.log('val_loss', loss, prog_bar=True, logger=True,
|
||||
batch_size=self.get_batch_size(batch), sync_dist=True)
|
||||
|
||||
|
||||
def test_step(self, batch, batch_idx) -> STEP_OUTPUT:
|
||||
"""
|
||||
Solver test step.
|
||||
|
||||
Reference in New Issue
Block a user