🎨 Format Python code with psf/black (#297)

Co-authored-by: dario-coscia <dario-coscia@users.noreply.github.com>
This commit is contained in:
github-actions[bot]
2024-05-10 14:08:01 +02:00
committed by GitHub
parent e0429bb445
commit 9463ae4b15
11 changed files with 169 additions and 160 deletions

View File

@@ -4,6 +4,7 @@ import torch
from pina.solvers import SupervisedSolver
class ReducedOrderModelSolver(SupervisedSolver):
r"""
ReducedOrderModelSolver solver class. This class implements a
@@ -114,10 +115,13 @@ class ReducedOrderModelSolver(SupervisedSolver):
rate scheduler.
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
"""
model = torch.nn.ModuleDict({
'reduction_network' : reduction_network,
'interpolation_network' : interpolation_network})
model = torch.nn.ModuleDict(
{
"reduction_network": reduction_network,
"interpolation_network": interpolation_network,
}
)
super().__init__(
model=model,
problem=problem,
@@ -125,18 +129,22 @@ class ReducedOrderModelSolver(SupervisedSolver):
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
scheduler=scheduler,
scheduler_kwargs=scheduler_kwargs
scheduler_kwargs=scheduler_kwargs,
)
# assert reduction object contains encode/ decode
if not hasattr(self.neural_net['reduction_network'], 'encode'):
raise SyntaxError('reduction_network must have encode method. '
'The encode method should return a lower '
'dimensional representation of the input.')
if not hasattr(self.neural_net['reduction_network'], 'decode'):
raise SyntaxError('reduction_network must have decode method. '
'The decode method should return a high '
'dimensional representation of the encoding.')
if not hasattr(self.neural_net["reduction_network"], "encode"):
raise SyntaxError(
"reduction_network must have encode method. "
"The encode method should return a lower "
"dimensional representation of the input."
)
if not hasattr(self.neural_net["reduction_network"], "decode"):
raise SyntaxError(
"reduction_network must have decode method. "
"The decode method should return a high "
"dimensional representation of the encoding."
)
def forward(self, x):
"""
@@ -149,8 +157,8 @@ class ReducedOrderModelSolver(SupervisedSolver):
:return: Solver solution.
:rtype: torch.Tensor
"""
reduction_network = self.neural_net['reduction_network']
interpolation_network = self.neural_net['interpolation_network']
reduction_network = self.neural_net["reduction_network"]
interpolation_network = self.neural_net["interpolation_network"]
return reduction_network.decode(interpolation_network(x))
def loss_data(self, input_pts, output_pts):
@@ -167,17 +175,18 @@ class ReducedOrderModelSolver(SupervisedSolver):
:rtype: torch.Tensor
"""
# extract networks
reduction_network = self.neural_net['reduction_network']
interpolation_network = self.neural_net['interpolation_network']
reduction_network = self.neural_net["reduction_network"]
interpolation_network = self.neural_net["interpolation_network"]
# encoded representations loss
encode_repr_inter_net = interpolation_network(input_pts)
encode_repr_reduction_network = reduction_network.encode(output_pts)
loss_encode = self.loss(encode_repr_inter_net,
encode_repr_reduction_network)
loss_encode = self.loss(
encode_repr_inter_net, encode_repr_reduction_network
)
# reconstruction loss
loss_reconstruction = self.loss(
reduction_network.decode(encode_repr_reduction_network),
output_pts)
reduction_network.decode(encode_repr_reduction_network), output_pts
)
return loss_encode + loss_reconstruction