🎨 Format Python code with psf/black (#297)
Co-authored-by: dario-coscia <dario-coscia@users.noreply.github.com>
committed by GitHub
parent e0429bb445
commit 9463ae4b15
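For context on what the tool changes: psf/black is normally run from the command line as `black .`, but it also exposes a small Python API. A minimal sketch, not part of this commit, using its documented `format_str` entry point on a deliberately ugly snippet:

```python
import black

ugly = "x = {  'a':37,'b':42,\n'c':927}\n"

# format_str applies the same rules the CLI applies to files: double
# quotes, normalized spacing, 88-column lines -- the changes in this diff.
print(black.format_str(ugly, mode=black.Mode()))
# -> x = {"a": 37, "b": 42, "c": 927}
```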
@@ -16,4 +16,3 @@ from .pinns import *
 from .supervised import SupervisedSolver
 from .rom import ReducedOrderModelSolver
 from .garom import GAROM
-
@@ -12,6 +12,7 @@ from torch.nn.modules.loss import _Loss

 torch.pi = torch.acos(torch.zeros(1)).item() * 2  # which is 3.1415927410125732

+
 class PINNInterface(SolverInterface, metaclass=ABCMeta):
     """
     Base PINN solver class. This class implements the Solver Interface
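An aside on the `torch.pi` context line above: torch historically ships no pi constant, so the codebase derives it from arccos. A quick check, assuming any recent torch, that the idiom matches `math.pi`:

```python
import math

import torch

# acos(0) = pi/2 in float32, so doubling it recovers pi at that precision.
pi_from_torch = torch.acos(torch.zeros(1)).item() * 2
assert abs(pi_from_torch - math.pi) < 1e-6
print(pi_from_torch)  # 3.1415927410125732, as the inline comment says
```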
@@ -195,7 +196,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
         :param torch.Tensor loss_value: The value of the loss.
         """
         self.log(
-            self.__logged_metric+'_loss',
+            self.__logged_metric + "_loss",
             loss_value,
             prog_bar=True,
             logger=True,
@@ -211,7 +212,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
         """
         if self.__logged_res_losses:
             # storing mean loss
-            self.__logged_metric = 'mean'
+            self.__logged_metric = "mean"
             self.store_log(
                 sum(self.__logged_res_losses) / len(self.__logged_res_losses)
             )
@@ -114,8 +114,10 @@ class CausalPINN(PINN):
         check_consistency(eps, (int, float))
         self._eps = eps
         if not isinstance(self.problem, TimeDependentProblem):
-            raise ValueError('Casual PINN works only for problems'
-                             'inheritig from TimeDependentProblem.')
+            raise ValueError(
+                "Casual PINN works only for problems"
+                "inheritig from TimeDependentProblem."
+            )

     def loss_phys(self, samples, equation):
         """
@@ -117,7 +117,7 @@ class CompetitivePINN(PINNInterface):
                 optimizer_discriminator_kwargs,
             ],
             extra_features=None,  # CompetitivePINN doesn't take extra features
-            loss=loss
+            loss=loss,
         )

         # set automatic optimization for GANs
@@ -131,9 +131,7 @@ class CompetitivePINN(PINNInterface):

         # assign schedulers
         self._schedulers = [
-            scheduler_model(
-                self.optimizers[0], **scheduler_model_kwargs
-            ),
+            scheduler_model(self.optimizers[0], **scheduler_model_kwargs),
             scheduler_discriminator(
                 self.optimizers[1], **scheduler_discriminator_kwargs
             ),
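The collapse in this hunk is black's standard behavior: a call is joined onto one line when it fits in 88 columns and carries no "magic trailing comma" inside its own parentheses. A self-contained, illustrative demonstration (not code from the commit) using black's API:

```python
import black

SRC = (
    "self._schedulers = [\n"
    "    scheduler_model(\n"
    "        self.optimizers[0], **scheduler_model_kwargs\n"
    "    ),\n"
    "]\n"
)

# The inner call has no trailing comma inside its parentheses, so black
# collapses it onto one line; the outer list keeps its trailing comma,
# so it stays exploded -- matching the hunk above.
print(black.format_str(SRC, mode=black.Mode()))
```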
@@ -195,8 +193,11 @@ class CompetitivePINN(PINNInterface):
         :rtype: torch.Tensor
         """
         self.optimizer_model.zero_grad()
-        loss_val = super().loss_data(
-            input_tensor, output_tensor).as_subclass(torch.Tensor)
+        loss_val = (
+            super()
+            .loss_data(input_tensor, output_tensor)
+            .as_subclass(torch.Tensor)
+        )
         loss_val.backward()
         self.optimizer_model.step()
         return loss_val
@@ -235,7 +236,9 @@ class CompetitivePINN(PINNInterface):
         :rtype: Any
         """
         # increase by one the counter of optimization to save loggers
-        self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed += 1
+        self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed += (
+            1
+        )
         return super().on_train_batch_end(outputs, batch, batch_idx)

     def _train_discriminator(self, samples, equation, discriminator_bets):
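The counter bump above exists because CompetitivePINN drives its two optimizers by hand rather than letting Lightning step them. A hedged sketch of the manual-optimization pattern this solver relies on (generic Lightning, not PINA code; names are illustrative):

```python
import pytorch_lightning as pl
import torch


class TwoOptimizerModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # enable manual optimization
        self.net_a = torch.nn.Linear(1, 1)
        self.net_b = torch.nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        opt_a, opt_b = self.optimizers()
        loss = self.net_a(batch).mean() + self.net_b(batch).mean()
        opt_a.zero_grad()
        opt_b.zero_grad()
        self.manual_backward(loss)  # replaces loss.backward()
        opt_a.step()
        opt_b.step()

    def configure_optimizers(self):
        return (
            torch.optim.Adam(self.net_a.parameters()),
            torch.optim.Adam(self.net_b.parameters()),
        )
```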
@@ -252,13 +255,14 @@ class CompetitivePINN(PINNInterface):
         self.optimizer_discriminator.zero_grad()
         # compute residual, we detach because the weights of the generator
         # model are fixed
-        residual = self.compute_residual(samples=samples,
-                                         equation=equation).detach()
+        residual = self.compute_residual(
+            samples=samples, equation=equation
+        ).detach()
         # compute competitive residual, the minus is because we maximise
         competitive_residual = residual * discriminator_bets
         loss_val = -self.loss(
             torch.zeros_like(competitive_residual, requires_grad=True),
-            competitive_residual
+            competitive_residual,
         ).as_subclass(torch.Tensor)
         # backprop
         self.manual_backward(loss_val)
@@ -283,16 +287,13 @@ class CompetitivePINN(PINNInterface):
         residual = self.compute_residual(samples=samples, equation=equation)
         # store logging
         with torch.no_grad():
-            loss_residual = self.loss(
-                torch.zeros_like(residual),
-                residual
-            )
+            loss_residual = self.loss(torch.zeros_like(residual), residual)
         # compute competitive residual, discriminator_bets are detached becase
         # we optimize only the generator model
         competitive_residual = residual * discriminator_bets.detach()
         loss_val = self.loss(
             torch.zeros_like(competitive_residual, requires_grad=True),
-            competitive_residual
+            competitive_residual,
         ).as_subclass(torch.Tensor)
         # backprop
         self.manual_backward(loss_val)
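The two hunks above alternate which side of `residual * discriminator_bets` is detached: gradients only flow to whichever network is being trained. A hedged illustration in plain torch (assumption, not code from the commit):

```python
import torch

residual = torch.randn(4, 1, requires_grad=True)            # from the model
discriminator_bets = torch.randn(4, 1, requires_grad=True)  # from the critic

# Training the discriminator: detach the residual so the model weights
# receive no gradient from this loss.
d_loss = -(residual.detach() * discriminator_bets).pow(2).mean()
d_loss.backward()
assert residual.grad is None and discriminator_bets.grad is not None

# Training the model: detach the bets instead.
g_loss = (residual * discriminator_bets.detach()).pow(2).mean()
g_loss.backward()
assert residual.grad is not None
```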
@@ -100,11 +100,12 @@ class GPINN(PINN):
             scheduler_kwargs=scheduler_kwargs,
         )
         if not isinstance(self.problem, SpatialProblem):
-            raise ValueError('Gradient PINN computes the gradient of the '
-                             'PINN loss with respect to the spatial '
-                             'coordinates, thus the PINA problem must be '
-                             'a SpatialProblem.')
-
+            raise ValueError(
+                "Gradient PINN computes the gradient of the "
+                "PINN loss with respect to the spatial "
+                "coordinates, thus the PINA problem must be "
+                "a SpatialProblem."
+            )

     def loss_phys(self, samples, equation):
         """
@@ -126,7 +127,7 @@ class GPINN(PINN):
         self.store_log(loss_value=float(loss_value))
         # gradient PINN loss
         loss_value = loss_value.reshape(-1, 1)
-        loss_value.labels = ['__LOSS']
+        loss_value.labels = ["__LOSS"]
         loss_grad = grad(loss_value, samples, d=self.problem.spatial_variables)
         g_loss_phys = self.loss(
             torch.zeros_like(loss_grad, requires_grad=True), loss_grad
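The quantity penalized here is the derivative of the pointwise residual loss with respect to the spatial coordinates. A hedged sketch of the same computation in plain torch (a stand-in for PINA's `grad`, with a toy per-point loss):

```python
import torch

samples = torch.rand(16, 2, requires_grad=True)         # (x, y) points
loss_value = samples.pow(2).sum(dim=1, keepdim=True)    # stand-in per-point loss

# d(loss)/d(spatial coords); create_graph keeps the penalty differentiable
# so it can itself be minimized during training.
loss_grad, = torch.autograd.grad(
    loss_value,
    samples,
    grad_outputs=torch.ones_like(loss_value),
    create_graph=True,
)
g_loss_phys = torch.nn.functional.mse_loss(
    torch.zeros_like(loss_grad), loss_grad
)
print(g_loss_phys)
```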
@@ -87,7 +87,7 @@ class PINN(PINNInterface):
             optimizers=[optimizer],
             optimizers_kwargs=[optimizer_kwargs],
             extra_features=extra_features,
-            loss=loss
+            loss=loss,
         )

         # check consistency
@@ -131,7 +131,6 @@ class PINN(PINNInterface):
         self.store_log(loss_value=float(loss_value))
         return loss_value

-
     def configure_optimizers(self):
         """
         Optimizer configuration for the PINN
@@ -153,7 +152,6 @@ class PINN(PINNInterface):
         )
         return self.optimizers, [self.scheduler]

-
     @property
     def scheduler(self):
         """
@@ -161,7 +159,6 @@ class PINN(PINNInterface):
         """
         return self._scheduler

-
     @property
     def neural_net(self):
         """
@@ -14,6 +14,7 @@ from pina.problem import InverseProblem

 from torch.optim.lr_scheduler import ConstantLR

+
 class Weights(torch.nn.Module):
     """
     This class aims to implements the mask model for
@@ -27,9 +28,7 @@ class Weights(torch.nn.Module):
         """
         super().__init__()
         check_consistency(func, torch.nn.Module)
-        self.sa_weights = torch.nn.Parameter(
-            torch.Tensor()
-        )
+        self.sa_weights = torch.nn.Parameter(torch.Tensor())
         self.func = func

     def forward(self):
@@ -43,6 +42,7 @@ class Weights(torch.nn.Module):
         """
         return self.func(self.sa_weights)

+
 class SAPINN(PINNInterface):
     r"""
     Self Adaptive Physics Informed Neural Network (SAPINN) solver class.
@@ -121,7 +121,7 @@ class SAPINN(PINNInterface):
         scheduler_model=ConstantLR,
         scheduler_model_kwargs={"factor": 1, "total_iters": 0},
         scheduler_weights=ConstantLR,
-        scheduler_weights_kwargs={"factor" : 1, "total_iters" : 0}
+        scheduler_weights_kwargs={"factor": 1, "total_iters": 0},
     ):
         """
         :param AbstractProblem problem: The formualation of the problem.
@@ -167,17 +167,16 @@ class SAPINN(PINNInterface):
             weights_dict[condition_name] = Weights(weights_function)
         weights_dict = torch.nn.ModuleDict(weights_dict)

-
         super().__init__(
             models=[model, weights_dict],
             problem=problem,
             optimizers=[optimizer_model, optimizer_weights],
             optimizers_kwargs=[
                 optimizer_model_kwargs,
-                optimizer_weights_kwargs
+                optimizer_weights_kwargs,
             ],
             extra_features=extra_features,
-            loss=loss
+            loss=loss,
         )

         # set automatic optimization
@@ -191,12 +190,8 @@ class SAPINN(PINNInterface):

         # assign schedulers
         self._schedulers = [
-            scheduler_model(
-                self.optimizers[0], **scheduler_model_kwargs
-            ),
-            scheduler_weights(
-                self.optimizers[1], **scheduler_weights_kwargs
-            ),
+            scheduler_model(self.optimizers[0], **scheduler_model_kwargs),
+            scheduler_weights(self.optimizers[1], **scheduler_weights_kwargs),
         ]

         self._model = self.models[0]
@@ -327,7 +322,9 @@ class SAPINN(PINNInterface):
         :rtype: Any
         """
         # increase by one the counter of optimization to save loggers
-        self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed += 1
+        self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed += (
+            1
+        )
         return super().on_train_batch_end(outputs, batch, batch_idx)

     def on_train_start(self):
@@ -343,9 +340,8 @@ class SAPINN(PINNInterface):
             self.trainer._accelerator_connector._accelerator_flag
         )
         for condition_name, tensor in self.problem.input_pts.items():
-            self.weights_dict.torchmodel[condition_name].sa_weights.data = torch.rand(
-                (tensor.shape[0], 1),
-                device = device
+            self.weights_dict.torchmodel[condition_name].sa_weights.data = (
+                torch.rand((tensor.shape[0], 1), device=device)
             )
         return super().on_train_start()
@@ -358,8 +354,8 @@ class SAPINN(PINNInterface):
         :param dict checkpoint: Pytorch Lightning checkpoint dict.
         """
         for condition_name, tensor in self.problem.input_pts.items():
-            self.weights_dict.torchmodel[condition_name].sa_weights.data = torch.rand(
-                (tensor.shape[0], 1)
+            self.weights_dict.torchmodel[condition_name].sa_weights.data = (
+                torch.rand((tensor.shape[0], 1))
             )
         return super().on_load_checkpoint(checkpoint)
@@ -403,12 +399,14 @@ class SAPINN(PINNInterface):
         :rtype List[LabelTensor, LabelTensor]
         """
         weights = self.weights_dict.torchmodel[
-            self.current_condition_name].forward()
-        loss_value = self._vectorial_loss(torch.zeros_like(
-            residual, requires_grad=True), residual)
+            self.current_condition_name
+        ].forward()
+        loss_value = self._vectorial_loss(
+            torch.zeros_like(residual, requires_grad=True), residual
+        )
         return (
             self._vect_to_scalar(weights * loss_value),
-            self._vect_to_scalar(loss_value)
+            self._vect_to_scalar(loss_value),
         )

     def _vect_to_scalar(self, loss_value):
@@ -426,11 +424,12 @@ class SAPINN(PINNInterface):
         elif self.loss.reduction == "sum":
             ret = torch.sum(loss_value)
         else:
-            raise RuntimeError(f"Invalid reduction, got {self.loss.reduction} "
-                               "but expected mean or sum.")
+            raise RuntimeError(
+                f"Invalid reduction, got {self.loss.reduction} "
+                "but expected mean or sum."
+            )
         return ret

-
     @property
     def neural_net(self):
         """
@@ -4,6 +4,7 @@ import torch

 from pina.solvers import SupervisedSolver

+
 class ReducedOrderModelSolver(SupervisedSolver):
     r"""
     ReducedOrderModelSolver solver class. This class implements a
@@ -114,9 +115,12 @@ class ReducedOrderModelSolver(SupervisedSolver):
             rate scheduler.
         :param dict scheduler_kwargs: LR scheduler constructor keyword args.
         """
-        model = torch.nn.ModuleDict({
-            'reduction_network' : reduction_network,
-            'interpolation_network' : interpolation_network})
+        model = torch.nn.ModuleDict(
+            {
+                "reduction_network": reduction_network,
+                "interpolation_network": interpolation_network,
+            }
+        )

         super().__init__(
             model=model,
@@ -125,18 +129,22 @@ class ReducedOrderModelSolver(SupervisedSolver):
             optimizer=optimizer,
             optimizer_kwargs=optimizer_kwargs,
             scheduler=scheduler,
-            scheduler_kwargs=scheduler_kwargs
+            scheduler_kwargs=scheduler_kwargs,
         )

         # assert reduction object contains encode/ decode
-        if not hasattr(self.neural_net['reduction_network'], 'encode'):
-            raise SyntaxError('reduction_network must have encode method. '
-                              'The encode method should return a lower '
-                              'dimensional representation of the input.')
-        if not hasattr(self.neural_net['reduction_network'], 'decode'):
-            raise SyntaxError('reduction_network must have decode method. '
-                              'The decode method should return a high '
-                              'dimensional representation of the encoding.')
+        if not hasattr(self.neural_net["reduction_network"], "encode"):
+            raise SyntaxError(
+                "reduction_network must have encode method. "
+                "The encode method should return a lower "
+                "dimensional representation of the input."
+            )
+        if not hasattr(self.neural_net["reduction_network"], "decode"):
+            raise SyntaxError(
+                "reduction_network must have decode method. "
+                "The decode method should return a high "
+                "dimensional representation of the encoding."
+            )

     def forward(self, x):
         """
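Any `torch.nn.Module` exposing those two methods passes the `hasattr` checks above. A minimal sketch (class and dimensions are assumed, not from the commit) of a `reduction_network` satisfying the encode/decode contract:

```python
import torch


class LinearReduction(torch.nn.Module):
    def __init__(self, full_dim=100, latent_dim=10):
        super().__init__()
        self.encoder = torch.nn.Linear(full_dim, latent_dim)
        self.decoder = torch.nn.Linear(latent_dim, full_dim)

    def encode(self, x):
        # lower-dimensional representation of the input
        return self.encoder(x)

    def decode(self, z):
        # high-dimensional reconstruction of the encoding
        return self.decoder(z)
```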
@@ -149,8 +157,8 @@ class ReducedOrderModelSolver(SupervisedSolver):
         :return: Solver solution.
         :rtype: torch.Tensor
         """
-        reduction_network = self.neural_net['reduction_network']
-        interpolation_network = self.neural_net['interpolation_network']
+        reduction_network = self.neural_net["reduction_network"]
+        interpolation_network = self.neural_net["interpolation_network"]
         return reduction_network.decode(interpolation_network(x))

     def loss_data(self, input_pts, output_pts):
@@ -167,17 +175,18 @@ class ReducedOrderModelSolver(SupervisedSolver):
         :rtype: torch.Tensor
         """
         # extract networks
-        reduction_network = self.neural_net['reduction_network']
-        interpolation_network = self.neural_net['interpolation_network']
+        reduction_network = self.neural_net["reduction_network"]
+        interpolation_network = self.neural_net["interpolation_network"]
         # encoded representations loss
         encode_repr_inter_net = interpolation_network(input_pts)
         encode_repr_reduction_network = reduction_network.encode(output_pts)
-        loss_encode = self.loss(encode_repr_inter_net,
-                                encode_repr_reduction_network)
+        loss_encode = self.loss(
+            encode_repr_inter_net, encode_repr_reduction_network
+        )
         # reconstruction loss
         loss_reconstruction = self.loss(
-            reduction_network.decode(encode_repr_reduction_network),
-            output_pts)
+            reduction_network.decode(encode_repr_reduction_network), output_pts
+        )

         return loss_encode + loss_reconstruction
@@ -67,9 +67,9 @@ class Trainer(pytorch_lightning.Trainer):
         pb = self._model.problem
         if hasattr(pb, "unknown_parameters"):
             for key in pb.unknown_parameters:
-                pb.unknown_parameters[key] = torch.nn.Parameter(pb.unknown_parameters[key].data.to(device))
-
-
+                pb.unknown_parameters[key] = torch.nn.Parameter(
+                    pb.unknown_parameters[key].data.to(device)
+                )

     def train(self, **kwargs):
         """