🎨 Format Python code with psf/black (#297)
Co-authored-by: dario-coscia <dario-coscia@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
e0429bb445
commit
9463ae4b15
@@ -4,6 +4,7 @@ import torch
|
||||
|
||||
from pina.solvers import SupervisedSolver
|
||||
|
||||
|
||||
class ReducedOrderModelSolver(SupervisedSolver):
|
||||
r"""
|
||||
ReducedOrderModelSolver solver class. This class implements a
|
||||
@@ -114,10 +115,13 @@ class ReducedOrderModelSolver(SupervisedSolver):
|
||||
rate scheduler.
|
||||
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
|
||||
"""
|
||||
model = torch.nn.ModuleDict({
|
||||
'reduction_network' : reduction_network,
|
||||
'interpolation_network' : interpolation_network})
|
||||
|
||||
model = torch.nn.ModuleDict(
|
||||
{
|
||||
"reduction_network": reduction_network,
|
||||
"interpolation_network": interpolation_network,
|
||||
}
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
problem=problem,
|
||||
@@ -125,18 +129,22 @@ class ReducedOrderModelSolver(SupervisedSolver):
|
||||
optimizer=optimizer,
|
||||
optimizer_kwargs=optimizer_kwargs,
|
||||
scheduler=scheduler,
|
||||
scheduler_kwargs=scheduler_kwargs
|
||||
scheduler_kwargs=scheduler_kwargs,
|
||||
)
|
||||
|
||||
# assert reduction object contains encode/ decode
|
||||
if not hasattr(self.neural_net['reduction_network'], 'encode'):
|
||||
raise SyntaxError('reduction_network must have encode method. '
|
||||
'The encode method should return a lower '
|
||||
'dimensional representation of the input.')
|
||||
if not hasattr(self.neural_net['reduction_network'], 'decode'):
|
||||
raise SyntaxError('reduction_network must have decode method. '
|
||||
'The decode method should return a high '
|
||||
'dimensional representation of the encoding.')
|
||||
if not hasattr(self.neural_net["reduction_network"], "encode"):
|
||||
raise SyntaxError(
|
||||
"reduction_network must have encode method. "
|
||||
"The encode method should return a lower "
|
||||
"dimensional representation of the input."
|
||||
)
|
||||
if not hasattr(self.neural_net["reduction_network"], "decode"):
|
||||
raise SyntaxError(
|
||||
"reduction_network must have decode method. "
|
||||
"The decode method should return a high "
|
||||
"dimensional representation of the encoding."
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
@@ -149,8 +157,8 @@ class ReducedOrderModelSolver(SupervisedSolver):
|
||||
:return: Solver solution.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
reduction_network = self.neural_net['reduction_network']
|
||||
interpolation_network = self.neural_net['interpolation_network']
|
||||
reduction_network = self.neural_net["reduction_network"]
|
||||
interpolation_network = self.neural_net["interpolation_network"]
|
||||
return reduction_network.decode(interpolation_network(x))
|
||||
|
||||
def loss_data(self, input_pts, output_pts):
|
||||
@@ -167,17 +175,18 @@ class ReducedOrderModelSolver(SupervisedSolver):
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
# extract networks
|
||||
reduction_network = self.neural_net['reduction_network']
|
||||
interpolation_network = self.neural_net['interpolation_network']
|
||||
reduction_network = self.neural_net["reduction_network"]
|
||||
interpolation_network = self.neural_net["interpolation_network"]
|
||||
# encoded representations loss
|
||||
encode_repr_inter_net = interpolation_network(input_pts)
|
||||
encode_repr_reduction_network = reduction_network.encode(output_pts)
|
||||
loss_encode = self.loss(encode_repr_inter_net,
|
||||
encode_repr_reduction_network)
|
||||
loss_encode = self.loss(
|
||||
encode_repr_inter_net, encode_repr_reduction_network
|
||||
)
|
||||
# reconstruction loss
|
||||
loss_reconstruction = self.loss(
|
||||
reduction_network.decode(encode_repr_reduction_network),
|
||||
output_pts)
|
||||
reduction_network.decode(encode_repr_reduction_network), output_pts
|
||||
)
|
||||
|
||||
return loss_encode + loss_reconstruction
|
||||
|
||||
|
||||
Reference in New Issue
Block a user