🎨 Format Python code with psf/black (#297)

Co-authored-by: dario-coscia <dario-coscia@users.noreply.github.com>
github-actions[bot]
2024-05-10 14:08:01 +02:00
committed by GitHub
parent e0429bb445
commit 9463ae4b15
11 changed files with 169 additions and 160 deletions

@@ -97,25 +97,27 @@ class CausalPINN(PINN):
         :param dict scheduler_kwargs: LR scheduler constructor keyword args.
         :param int | float eps: The exponential decay parameter. Note that this
             value is kept fixed during the training, but can be changed by means
-            of a callback, e.g. for annealing.
+            of a callback, e.g. for annealing.
         """
         super().__init__(
-            problem=problem,
-            model=model,
-            extra_features=extra_features,
-            loss=loss,
-            optimizer=optimizer,
-            optimizer_kwargs=optimizer_kwargs,
-            scheduler=scheduler,
-            scheduler_kwargs=scheduler_kwargs,
+            problem=problem,
+            model=model,
+            extra_features=extra_features,
+            loss=loss,
+            optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
+            scheduler=scheduler,
+            scheduler_kwargs=scheduler_kwargs,
         )

         # checking consistency
-        check_consistency(eps, (int,float))
+        check_consistency(eps, (int, float))
         self._eps = eps
         if not isinstance(self.problem, TimeDependentProblem):
-            raise ValueError('Causal PINN works only for problems '
-                             'inheriting from TimeDependentProblem.')
+            raise ValueError(
+                "Causal PINN works only for problems "
+                "inheriting from TimeDependentProblem."
+            )

     def loss_phys(self, samples, equation):
         """
@@ -144,14 +146,14 @@ class CausalPINN(PINN):
             )
             time_loss.append(loss_val)
         # store results
-        self.store_log(loss_value=float(sum(time_loss)/len(time_loss)))
+        self.store_log(loss_value=float(sum(time_loss) / len(time_loss)))
         # concatenate residuals
         time_loss = torch.stack(time_loss)
         # compute weights (without the gradient storing)
         with torch.no_grad():
             weights = self._compute_weights(time_loss)
         return (weights * time_loss).mean()

     @property
     def eps(self):
         """
@@ -205,8 +207,8 @@ class CausalPINN(PINN):
         _, idx_split = time_tensor.unique(return_counts=True)
         # splitting
         chunks = torch.split(tensor, tuple(idx_split))
-        return chunks, labels # return chunks
+        return chunks, labels  # return chunks

     def _compute_weights(self, loss):
         """
         Computes the weights for the physics loss based on the cumulative loss.
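The _compute_weights docstring above says the weights are based on the cumulative loss. In the causal-training scheme for PINNs this usually means w_i = exp(-eps * sum of the residual losses of the time steps up to i), so later times only start to matter once earlier times are well resolved. The snippet below is a minimal sketch of that standard rule, written for illustration rather than copied from the file.

import torch

def compute_causal_weights(loss, eps):
    # loss: 1-D tensor of per-time-step residual losses, ordered in time
    # accumulate the losses along the time axis and scale by eps
    cumulative = eps * torch.cumsum(loss, dim=0)
    # exponential decay: a large accumulated residual drives later weights toward zero
    return torch.exp(-cumulative)

# example: badly fitted early steps leave almost no weight for later steps
w = compute_causal_weights(torch.tensor([1.0, 0.5, 0.1]), eps=5.0)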