🎨 Format Python code with psf/black (#297)
Co-authored-by: dario-coscia <dario-coscia@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
e0429bb445
commit
9463ae4b15
@@ -90,22 +90,23 @@ class GPINN(PINN):
|
||||
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
|
||||
"""
|
||||
super().__init__(
|
||||
problem=problem,
|
||||
model=model,
|
||||
extra_features=extra_features,
|
||||
loss=loss,
|
||||
optimizer=optimizer,
|
||||
optimizer_kwargs=optimizer_kwargs,
|
||||
scheduler=scheduler,
|
||||
scheduler_kwargs=scheduler_kwargs,
|
||||
problem=problem,
|
||||
model=model,
|
||||
extra_features=extra_features,
|
||||
loss=loss,
|
||||
optimizer=optimizer,
|
||||
optimizer_kwargs=optimizer_kwargs,
|
||||
scheduler=scheduler,
|
||||
scheduler_kwargs=scheduler_kwargs,
|
||||
)
|
||||
if not isinstance(self.problem, SpatialProblem):
|
||||
raise ValueError('Gradient PINN computes the gradient of the '
|
||||
'PINN loss with respect to the spatial '
|
||||
'coordinates, thus the PINA problem must be '
|
||||
'a SpatialProblem.')
|
||||
raise ValueError(
|
||||
"Gradient PINN computes the gradient of the "
|
||||
"PINN loss with respect to the spatial "
|
||||
"coordinates, thus the PINA problem must be "
|
||||
"a SpatialProblem."
|
||||
)
|
||||
|
||||
|
||||
def loss_phys(self, samples, equation):
|
||||
"""
|
||||
Computes the physics loss for the GPINN solver based on given
|
||||
@@ -126,9 +127,9 @@ class GPINN(PINN):
|
||||
self.store_log(loss_value=float(loss_value))
|
||||
# gradient PINN loss
|
||||
loss_value = loss_value.reshape(-1, 1)
|
||||
loss_value.labels = ['__LOSS']
|
||||
loss_value.labels = ["__LOSS"]
|
||||
loss_grad = grad(loss_value, samples, d=self.problem.spatial_variables)
|
||||
g_loss_phys = self.loss(
|
||||
torch.zeros_like(loss_grad, requires_grad=True), loss_grad
|
||||
)
|
||||
return loss_value + g_loss_phys
|
||||
return loss_value + g_loss_phys
|
||||
|
||||
Reference in New Issue
Block a user