Fixing self adaptive pinns (#469)
* fix self adaptive pinn * clean competitive pinn
This commit is contained in:
committed by
Nicola Demo
parent
c3aaf5b1a0
commit
375f7f8e2d
@@ -160,22 +160,6 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
|||||||
self._train_discriminator(samples, equation, discriminator_bets)
|
self._train_discriminator(samples, equation, discriminator_bets)
|
||||||
return loss_val
|
return loss_val
|
||||||
|
|
||||||
def loss_data(self, input_pts, output_pts):
|
|
||||||
"""
|
|
||||||
The data loss for the CompetitivePINN solver. It computes the loss
|
|
||||||
between the network output against the true solution.
|
|
||||||
|
|
||||||
:param LabelTensor input_tensor: The input to the neural networks.
|
|
||||||
:param LabelTensor output_tensor: The true solution to compare the
|
|
||||||
network solution.
|
|
||||||
:return: The computed data loss.
|
|
||||||
:rtype: torch.Tensor
|
|
||||||
"""
|
|
||||||
loss_val = super().loss_data(input_pts, output_pts)
|
|
||||||
# prepare for optimizer step called in training step
|
|
||||||
loss_val.backward()
|
|
||||||
return loss_val
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
"""
|
"""
|
||||||
Optimizer configuration for the Competitive PINN solver.
|
Optimizer configuration for the Competitive PINN solver.
|
||||||
@@ -252,7 +236,6 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
|||||||
)
|
)
|
||||||
# prepare for optimizer step called in training step
|
# prepare for optimizer step called in training step
|
||||||
self.manual_backward(loss_val)
|
self.manual_backward(loss_val)
|
||||||
return
|
|
||||||
|
|
||||||
def _train_model(self, samples, equation, discriminator_bets):
|
def _train_model(self, samples, equation, discriminator_bets):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -178,63 +178,20 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
|
|||||||
:return: The sum of the loss functions.
|
:return: The sum of the loss functions.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
self.optimizer_model.instance.zero_grad()
|
# Weights optimization
|
||||||
self.optimizer_weights.instance.zero_grad()
|
self.optimizer_weights.instance.zero_grad()
|
||||||
loss = super().training_step(batch)
|
loss = super().training_step(batch)
|
||||||
self.optimizer_model.instance.step()
|
self.manual_backward(-loss)
|
||||||
self.optimizer_weights.instance.step()
|
self.optimizer_weights.instance.step()
|
||||||
|
|
||||||
|
# Model optimization
|
||||||
|
self.optimizer_model.instance.zero_grad()
|
||||||
|
loss = super().training_step(batch)
|
||||||
|
self.manual_backward(loss)
|
||||||
|
self.optimizer_model.instance.step()
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def loss_phys(self, samples, equation):
|
|
||||||
"""
|
|
||||||
Computes the physics loss for the SAPINN solver based on given
|
|
||||||
samples and equation.
|
|
||||||
|
|
||||||
:param LabelTensor samples: The samples to evaluate the physics loss.
|
|
||||||
:param EquationInterface equation: The governing equation
|
|
||||||
representing the physics.
|
|
||||||
:return: The physics loss calculated based on given
|
|
||||||
samples and equation.
|
|
||||||
:rtype: torch.Tensor
|
|
||||||
"""
|
|
||||||
# Train the weights
|
|
||||||
weighted_loss = self._loss_phys(samples, equation)
|
|
||||||
loss_value = -weighted_loss.as_subclass(torch.Tensor)
|
|
||||||
self.manual_backward(loss_value)
|
|
||||||
|
|
||||||
# Detach samples from the existing computational graph and
|
|
||||||
# create a new one by setting requires_grad to True.
|
|
||||||
# In alternative set `retain_graph=True`.
|
|
||||||
samples = samples.detach()
|
|
||||||
samples.requires_grad_() # = True
|
|
||||||
|
|
||||||
# Train the model
|
|
||||||
weighted_loss = self._loss_phys(samples, equation)
|
|
||||||
loss_value = weighted_loss.as_subclass(torch.Tensor)
|
|
||||||
self.manual_backward(loss_value)
|
|
||||||
|
|
||||||
return loss_value
|
|
||||||
|
|
||||||
def loss_data(self, input_pts, output_pts):
|
|
||||||
"""
|
|
||||||
Computes the data loss for the SAPINN solver based on input and
|
|
||||||
output. It computes the loss between the
|
|
||||||
network output against the true solution.
|
|
||||||
|
|
||||||
:param LabelTensor input_pts: The input to the neural networks.
|
|
||||||
:param LabelTensor output_pts: The true solution to compare the
|
|
||||||
network solution.
|
|
||||||
:return: The computed data loss.
|
|
||||||
:rtype: torch.Tensor
|
|
||||||
"""
|
|
||||||
residual = self.forward(input_pts) - output_pts
|
|
||||||
loss = self._vectorial_loss(
|
|
||||||
torch.zeros_like(residual, requires_grad=True), residual
|
|
||||||
)
|
|
||||||
loss_value = self._vect_to_scalar(loss).as_subclass(torch.Tensor)
|
|
||||||
self.manual_backward(loss_value)
|
|
||||||
return loss_value
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
"""
|
"""
|
||||||
Optimizer configuration for the SelfAdaptive PINN solver.
|
Optimizer configuration for the SelfAdaptive PINN solver.
|
||||||
@@ -330,7 +287,7 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
|
|||||||
)
|
)
|
||||||
return super().on_load_checkpoint(checkpoint)
|
return super().on_load_checkpoint(checkpoint)
|
||||||
|
|
||||||
def _loss_phys(self, samples, equation):
|
def loss_phys(self, samples, equation):
|
||||||
"""
|
"""
|
||||||
Computation of the physical loss for SelfAdaptive PINN solver.
|
Computation of the physical loss for SelfAdaptive PINN solver.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user