🎨 Format Python code with psf/black (#297)
Co-authored-by: dario-coscia <dario-coscia@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
e0429bb445
commit
9463ae4b15
@@ -97,25 +97,27 @@ class CausalPINN(PINN):
|
||||
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
|
||||
:param int | float eps: The exponential decay parameter. Note that this
|
||||
value is kept fixed during the training, but can be changed by means
|
||||
of a callback, e.g. for annealing.
|
||||
of a callback, e.g. for annealing.
|
||||
"""
|
||||
super().__init__(
|
||||
problem=problem,
|
||||
model=model,
|
||||
extra_features=extra_features,
|
||||
loss=loss,
|
||||
optimizer=optimizer,
|
||||
optimizer_kwargs=optimizer_kwargs,
|
||||
scheduler=scheduler,
|
||||
scheduler_kwargs=scheduler_kwargs,
|
||||
problem=problem,
|
||||
model=model,
|
||||
extra_features=extra_features,
|
||||
loss=loss,
|
||||
optimizer=optimizer,
|
||||
optimizer_kwargs=optimizer_kwargs,
|
||||
scheduler=scheduler,
|
||||
scheduler_kwargs=scheduler_kwargs,
|
||||
)
|
||||
|
||||
# checking consistency
|
||||
check_consistency(eps, (int,float))
|
||||
check_consistency(eps, (int, float))
|
||||
self._eps = eps
|
||||
if not isinstance(self.problem, TimeDependentProblem):
|
||||
raise ValueError('Casual PINN works only for problems'
|
||||
'inheritig from TimeDependentProblem.')
|
||||
raise ValueError(
|
||||
"Casual PINN works only for problems"
|
||||
"inheritig from TimeDependentProblem."
|
||||
)
|
||||
|
||||
def loss_phys(self, samples, equation):
|
||||
"""
|
||||
@@ -144,14 +146,14 @@ class CausalPINN(PINN):
|
||||
)
|
||||
time_loss.append(loss_val)
|
||||
# store results
|
||||
self.store_log(loss_value=float(sum(time_loss)/len(time_loss)))
|
||||
self.store_log(loss_value=float(sum(time_loss) / len(time_loss)))
|
||||
# concatenate residuals
|
||||
time_loss = torch.stack(time_loss)
|
||||
# compute weights (without the gradient storing)
|
||||
with torch.no_grad():
|
||||
weights = self._compute_weights(time_loss)
|
||||
return (weights * time_loss).mean()
|
||||
|
||||
|
||||
@property
|
||||
def eps(self):
|
||||
"""
|
||||
@@ -205,8 +207,8 @@ class CausalPINN(PINN):
|
||||
_, idx_split = time_tensor.unique(return_counts=True)
|
||||
# splitting
|
||||
chunks = torch.split(tensor, tuple(idx_split))
|
||||
return chunks, labels # return chunks
|
||||
|
||||
return chunks, labels # return chunks
|
||||
|
||||
def _compute_weights(self, loss):
|
||||
"""
|
||||
Computes the weights for the physics loss based on the cumulative loss.
|
||||
|
||||
Reference in New Issue
Block a user