Fixing self adaptive pinns (#469)
* fix self adaptive pinn * clean competitive pinn
This commit is contained in:
committed by
Nicola Demo
parent
c3aaf5b1a0
commit
375f7f8e2d
@@ -160,22 +160,6 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
self._train_discriminator(samples, equation, discriminator_bets)
|
||||
return loss_val
|
||||
|
||||
def loss_data(self, input_pts, output_pts):
|
||||
"""
|
||||
The data loss for the CompetitivePINN solver. It computes the loss
|
||||
between the network output against the true solution.
|
||||
|
||||
:param LabelTensor input_tensor: The input to the neural networks.
|
||||
:param LabelTensor output_tensor: The true solution to compare the
|
||||
network solution.
|
||||
:return: The computed data loss.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
loss_val = super().loss_data(input_pts, output_pts)
|
||||
# prepare for optimizer step called in training step
|
||||
loss_val.backward()
|
||||
return loss_val
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""
|
||||
Optimizer configuration for the Competitive PINN solver.
|
||||
@@ -252,7 +236,6 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
)
|
||||
# prepare for optimizer step called in training step
|
||||
self.manual_backward(loss_val)
|
||||
return
|
||||
|
||||
def _train_model(self, samples, equation, discriminator_bets):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user