Update solvers (#434)

* Enable DDP training with batch_size=None and add validity check for split sizes
* Refactor SolverInterfaces (#435)
* Update solvers and add loss weighting (usage sketch after this list)
* Update PINN for 0.2
* Modify GAROM + tests
* Add more versatile loggers
* Disable compilation when running on Windows
* Fix tests
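
A minimal usage sketch of the reworked solver constructor, assuming the new keyword defaults shown in the diff below; the import path and the problem/network objects are assumptions, not taken from this commit:

    from pina.solvers import ReducedOrderModelSolver  # import path is an assumption

    solver = ReducedOrderModelSolver(
        problem=problem,                              # an AbstractProblem instance (assumed defined)
        reduction_network=reduction_network,          # any module exposing encode()/decode()
        interpolation_network=interpolation_network,  # maps inputs to the reduced coordinates
        loss=None,       # None falls back to the solver default (MSELoss in the old signature)
        optimizer=None,  # None falls back to the solver default (Adam in the old signature)
        scheduler=None,
        weighting=None,  # new: an optional WeightingInterface for loss weighting
        use_lt=True,     # feed LabelTensors during training
    )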

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: FilippoOlivo <filippo@filippoolivo.com>
Author:    Dario Coscia
Date:      2025-02-17 11:26:21 +01:00
Committer: Nicola Demo
Parent:    780c4921eb
Commit:    9cae9a438f

50 changed files with 2848 additions and 4187 deletions


@@ -88,11 +88,11 @@ class ReducedOrderModelSolver(SupervisedSolver):
         problem,
         reduction_network,
         interpolation_network,
-        loss=torch.nn.MSELoss(),
-        optimizer=torch.optim.Adam,
-        optimizer_kwargs={"lr": 0.001},
-        scheduler=torch.optim.lr_scheduler.ConstantLR,
-        scheduler_kwargs={"factor": 1, "total_iters": 0},
+        loss=None,
+        optimizer=None,
+        scheduler=None,
+        weighting=None,
         use_lt=True,
     ):
         """
         :param AbstractProblem problem: The formulation of the problem.
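
Replacing the concrete defaults with None is the usual idiom for avoiding mutable or eagerly constructed default arguments; a hedged sketch of how the base solver could resolve them, with fallback values taken from the old signature above (the resolution itself is an assumption about the base class, not part of this diff):

    # Plausible default resolution in the base solver (assumed, not shown here).
    loss = loss if loss is not None else torch.nn.MSELoss()
    optimizer = optimizer if optimizer is not None else torch.optim.Adam
    scheduler = scheduler if scheduler is not None else torch.optim.lr_scheduler.ConstantLR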
@@ -105,15 +105,12 @@ class ReducedOrderModelSolver(SupervisedSolver):
             the ``reduction_network`` encoding.
         :param torch.nn.Module loss: The loss function used as minimizer,
             default :class:`torch.nn.MSELoss`.
-        :param torch.nn.Module extra_features: The additional input
-            features to use as augmented input.
         :param torch.optim.Optimizer optimizer: The neural network optimizer to
             use; default is :class:`torch.optim.Adam`.
-        :param dict optimizer_kwargs: Optimizer constructor keyword args.
-        :param float lr: The learning rate; default is 0.001.
         :param torch.optim.LRScheduler scheduler: Learning
             rate scheduler.
-        :param dict scheduler_kwargs: LR scheduler constructor keyword args.
+        :param WeightingInterface weighting: The loss weighting to use.
         :param bool use_lt: Using LabelTensors as input during training.
         """
         model = torch.nn.ModuleDict(
             {
@@ -127,19 +124,19 @@ class ReducedOrderModelSolver(SupervisedSolver):
             problem=problem,
             loss=loss,
             optimizer=optimizer,
-            optimizer_kwargs=optimizer_kwargs,
             scheduler=scheduler,
-            scheduler_kwargs=scheduler_kwargs,
+            weighting=weighting,
             use_lt=use_lt
         )
         # assert reduction object contains encode/decode
-        if not hasattr(self.neural_net["reduction_network"], "encode"):
+        if not hasattr(self.model["reduction_network"], "encode"):
             raise SyntaxError(
                 "reduction_network must have encode method. "
                 "The encode method should return a lower "
                 "dimensional representation of the input."
             )
-        if not hasattr(self.neural_net["reduction_network"], "decode"):
+        if not hasattr(self.model["reduction_network"], "decode"):
             raise SyntaxError(
                 "reduction_network must have decode method. "
                 "The decode method should return a high "
@@ -157,8 +154,8 @@ class ReducedOrderModelSolver(SupervisedSolver):
         :return: Solver solution.
         :rtype: torch.Tensor
         """
-        reduction_network = self.neural_net["reduction_network"]
-        interpolation_network = self.neural_net["interpolation_network"]
+        reduction_network = self.model["reduction_network"]
+        interpolation_network = self.model["interpolation_network"]
         return reduction_network.decode(interpolation_network(x))

     def loss_data(self, input_pts, output_pts):
@@ -175,8 +172,8 @@ class ReducedOrderModelSolver(SupervisedSolver):
             :rtype: torch.Tensor
         """
         # extract networks
-        reduction_network = self.neural_net["reduction_network"]
-        interpolation_network = self.neural_net["interpolation_network"]
+        reduction_network = self.model["reduction_network"]
+        interpolation_network = self.model["interpolation_network"]
         # encoded representations loss
         encode_repr_inter_net = interpolation_network(input_pts)
         encode_repr_reduction_network = reduction_network.encode(output_pts)
@@ -188,12 +185,4 @@ class ReducedOrderModelSolver(SupervisedSolver):
             reduction_network.decode(encode_repr_reduction_network), output_pts
         )
         return loss_encode + loss_reconstruction
-
-    @property
-    def neural_net(self):
-        """
-        Neural network for training. It returns a :obj:`~torch.nn.ModuleDict`
-        containing the ``reduction_network`` and ``interpolation_network``.
-        """
-        return self._neural_net.torchmodel
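
Written out, loss_data combines an encoding term, which fits the interpolation network to the encoded snapshots, and a reconstruction term, which checks that the reduction network can rebuild the snapshots: loss = L(interpolation(x), encode(y)) + L(decode(encode(y)), y). A standalone restatement of the method above (function and argument names are assumptions for illustration):

    def rom_data_loss(loss_fn, reduction_network, interpolation_network, x, y):
        encoded_target = reduction_network.encode(y)
        # Encoding loss: the interpolation network should predict the reduced coordinates.
        loss_encode = loss_fn(interpolation_network(x), encoded_target)
        # Reconstruction loss: the autoencoder should rebuild the full-order snapshot.
        loss_reconstruction = loss_fn(reduction_network.decode(encoded_target), y)
        return loss_encode + loss_reconstruction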