Update solvers (#434)
* Enable DDP training with batch_size=None and add validity check for split sizes * Refactoring SolverInterfaces (#435) * Solver update + weighting * Updating PINN for 0.2 * Modify GAROM + tests * Adding more versatile loggers * Disable compilation when running on Windows * Fix tests --------- Co-authored-by: giovanni <giovanni.canali98@yahoo.it> Co-authored-by: FilippoOlivo <filippo@filippoolivo.com>
This commit is contained in:
committed by
Nicola Demo
parent
780c4921eb
commit
9cae9a438f
@@ -88,11 +88,11 @@ class ReducedOrderModelSolver(SupervisedSolver):
|
||||
problem,
|
||||
reduction_network,
|
||||
interpolation_network,
|
||||
loss=torch.nn.MSELoss(),
|
||||
optimizer=torch.optim.Adam,
|
||||
optimizer_kwargs={"lr": 0.001},
|
||||
scheduler=torch.optim.lr_scheduler.ConstantLR,
|
||||
scheduler_kwargs={"factor": 1, "total_iters": 0},
|
||||
loss=None,
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
weighting=None,
|
||||
use_lt=True,
|
||||
):
|
||||
"""
|
||||
:param AbstractProblem problem: The formualation of the problem.
|
||||
@@ -105,15 +105,12 @@ class ReducedOrderModelSolver(SupervisedSolver):
|
||||
the ``reduction_network`` encoding.
|
||||
:param torch.nn.Module loss: The loss function used as minimizer,
|
||||
default :class:`torch.nn.MSELoss`.
|
||||
:param torch.nn.Module extra_features: The additional input
|
||||
features to use as augmented input.
|
||||
:param torch.optim.Optimizer optimizer: The neural network optimizer to
|
||||
use; default is :class:`torch.optim.Adam`.
|
||||
:param dict optimizer_kwargs: Optimizer constructor keyword args.
|
||||
:param float lr: The learning rate; default is 0.001.
|
||||
:param torch.optim.LRScheduler scheduler: Learning
|
||||
rate scheduler.
|
||||
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
|
||||
:param WeightingInterface weighting: The loss weighting to use.
|
||||
:param bool use_lt: Using LabelTensors as input during training.
|
||||
"""
|
||||
model = torch.nn.ModuleDict(
|
||||
{
|
||||
@@ -127,19 +124,19 @@ class ReducedOrderModelSolver(SupervisedSolver):
|
||||
problem=problem,
|
||||
loss=loss,
|
||||
optimizer=optimizer,
|
||||
optimizer_kwargs=optimizer_kwargs,
|
||||
scheduler=scheduler,
|
||||
scheduler_kwargs=scheduler_kwargs,
|
||||
weighting=weighting,
|
||||
use_lt=use_lt
|
||||
)
|
||||
|
||||
# assert reduction object contains encode/ decode
|
||||
if not hasattr(self.neural_net["reduction_network"], "encode"):
|
||||
if not hasattr(self.model["reduction_network"], "encode"):
|
||||
raise SyntaxError(
|
||||
"reduction_network must have encode method. "
|
||||
"The encode method should return a lower "
|
||||
"dimensional representation of the input."
|
||||
)
|
||||
if not hasattr(self.neural_net["reduction_network"], "decode"):
|
||||
if not hasattr(self.model["reduction_network"], "decode"):
|
||||
raise SyntaxError(
|
||||
"reduction_network must have decode method. "
|
||||
"The decode method should return a high "
|
||||
@@ -157,8 +154,8 @@ class ReducedOrderModelSolver(SupervisedSolver):
|
||||
:return: Solver solution.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
reduction_network = self.neural_net["reduction_network"]
|
||||
interpolation_network = self.neural_net["interpolation_network"]
|
||||
reduction_network = self.model["reduction_network"]
|
||||
interpolation_network = self.model["interpolation_network"]
|
||||
return reduction_network.decode(interpolation_network(x))
|
||||
|
||||
def loss_data(self, input_pts, output_pts):
|
||||
@@ -175,8 +172,8 @@ class ReducedOrderModelSolver(SupervisedSolver):
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
# extract networks
|
||||
reduction_network = self.neural_net["reduction_network"]
|
||||
interpolation_network = self.neural_net["interpolation_network"]
|
||||
reduction_network = self.model["reduction_network"]
|
||||
interpolation_network = self.model["interpolation_network"]
|
||||
# encoded representations loss
|
||||
encode_repr_inter_net = interpolation_network(input_pts)
|
||||
encode_repr_reduction_network = reduction_network.encode(output_pts)
|
||||
@@ -188,12 +185,4 @@ class ReducedOrderModelSolver(SupervisedSolver):
|
||||
reduction_network.decode(encode_repr_reduction_network), output_pts
|
||||
)
|
||||
|
||||
return loss_encode + loss_reconstruction
|
||||
|
||||
@property
|
||||
def neural_net(self):
|
||||
"""
|
||||
Neural network for training. It returns a :obj:`~torch.nn.ModuleDict`
|
||||
containing the ``reduction_network`` and ``interpolation_network``.
|
||||
"""
|
||||
return self._neural_net.torchmodel
|
||||
return loss_encode + loss_reconstruction
|
||||
Reference in New Issue
Block a user