add batching support for self-adaptive pinns

Author: Giovanni Canali
Date: 2025-07-22 14:45:43 +02:00
Committed by: Giovanni Canali
Parent: 1ed14916f1
Commit: 6d1d4ef423
2 changed files with 208 additions and 140 deletions


@@ -15,15 +15,20 @@ class Weights(torch.nn.Module):
     :class:`SelfAdaptivePINN` solver.
     """

-    def __init__(self, func):
+    def __init__(self, func, num_points):
         """
         Initialization of the :class:`Weights` class.

         :param torch.nn.Module func: the mask model.
+        :param int num_points: the number of input points.
         """
         super().__init__()
+
+        # Check consistency
         check_consistency(func, torch.nn.Module)
-        self.sa_weights = torch.nn.Parameter(torch.Tensor())
+
+        # Initialize the weights as a learnable parameter
+        self.sa_weights = torch.nn.Parameter(torch.zeros(num_points, 1))
         self.func = func

     def forward(self):
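For orientation, here is a minimal standalone sketch of the new per-condition weights. The consistency checks are omitted, and the mask application inside forward is an assumption consistent with the class docstring ("the mask model"); the condition names and point counts are illustrative only.

import torch


class Weights(torch.nn.Module):
    """Mirror of the Weights class from the diff above (checks omitted)."""

    def __init__(self, func, num_points):
        super().__init__()
        # One learnable self-adaptive weight per training point.
        self.sa_weights = torch.nn.Parameter(torch.zeros(num_points, 1))
        self.func = func

    def forward(self):
        # Assumption: the mask function is applied to the raw weights.
        return self.func(self.sa_weights)


# One Weights module per problem condition, sized from its number of input
# points, as done in the solver constructor shown in the next hunk.
mask = torch.nn.Sigmoid()
weights = torch.nn.ModuleDict(
    {
        "physics": Weights(func=mask, num_points=100),
        "boundary": Weights(func=mask, num_points=40),
    }
)
print(weights["physics"]().shape)  # torch.Size([100, 1])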
@@ -140,17 +145,17 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
             If ``None``, the :class:`torch.nn.MSELoss` loss is used.
             Default is `None`.
         """
-        # check consistency weitghs_function
+        # Check consistency
         check_consistency(weight_function, torch.nn.Module)

-        # create models for weights
-        weights_dict = {}
-        for condition_name in problem.conditions:
-            weights_dict[condition_name] = Weights(weight_function)
-        weights_dict = torch.nn.ModuleDict(weights_dict)
+        # Define a ModuleDict for the weights
+        weights = {}
+        for cond, data in problem.input_pts.items():
+            weights[cond] = Weights(func=weight_function, num_points=len(data))
+        weights = torch.nn.ModuleDict(weights)

         super().__init__(
-            models=[model, weights_dict],
+            models=[model, weights],
             problem=problem,
             optimizers=[optimizer_model, optimizer_weights],
             schedulers=[scheduler_model, scheduler_weights],
@@ -158,116 +163,93 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
             loss=loss,
         )
-        self._vectorial_loss = deepcopy(self.loss)
-        self._vectorial_loss.reduction = "none"
-
-    def forward(self, x):
-        """
-        Forward pass.
-
-        :param LabelTensor x: Input tensor.
-        :return: The output of the neural network.
-        :rtype: LabelTensor
-        """
-        return self.model(x)
-
-    def training_step(self, batch):
-        """
-        Solver training step, overridden to perform manual optimization.
+        # Extract the reduction method from the loss function
+        self._reduction = self._loss_fn.reduction
+
+        # Set the loss function to return non-aggregated losses
+        self._loss_fn = type(self._loss_fn)(reduction="none")
+
+    def training_step(self, batch, batch_idx, **kwargs):
+        """
+        Solver training step. It computes the optimization cycle and aggregates
+        the losses using the ``weighting`` attribute.

         :param list[tuple[str, dict]] batch: A batch of data. Each element is a
             tuple containing a condition name and a dictionary of points.
-        :return: The aggregated loss.
-        :rtype: LabelTensor
+        :param int batch_idx: The index of the current batch.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the training step.
+        :rtype: torch.Tensor
         """
         # Weights optimization
         self.optimizer_weights.instance.zero_grad()
-        loss = super().training_step(batch)
+        loss = self._optimization_cycle(
+            batch=batch, batch_idx=batch_idx, **kwargs
+        )
         self.manual_backward(-loss)
         self.optimizer_weights.instance.step()
         self.scheduler_weights.instance.step()

         # Model optimization
         self.optimizer_model.instance.zero_grad()
-        loss = super().training_step(batch)
+        loss = self._optimization_cycle(
+            batch=batch, batch_idx=batch_idx, **kwargs
+        )
         self.manual_backward(loss)
         self.optimizer_model.instance.step()
         self.scheduler_model.instance.step()

+        # Log the loss
+        self.store_log("train_loss", loss, self.get_batch_size(batch))
+
         return loss

-    def configure_optimizers(self):
-        """
-        Optimizer configuration.
-
-        :return: The optimizers and the schedulers
-        :rtype: tuple[list[Optimizer], list[Scheduler]]
-        """
-        # If the problem is an InverseProblem, add the unknown parameters
-        # to the parameters to be optimized
-        self.optimizer_model.hook(self.model.parameters())
-        self.optimizer_weights.hook(self.weights_dict.parameters())
-        if isinstance(self.problem, InverseProblem):
-            self.optimizer_model.instance.add_param_group(
-                {
-                    "params": [
-                        self._params[var]
-                        for var in self.problem.unknown_variables
-                    ]
-                }
-            )
-        self.scheduler_model.hook(self.optimizer_model)
-        self.scheduler_weights.hook(self.optimizer_weights)
-        return (
-            [self.optimizer_model.instance, self.optimizer_weights.instance],
-            [self.scheduler_model.instance, self.scheduler_weights.instance],
-        )
-
-    def on_train_start(self):
-        """
-        This method is called at the start of the training process to set the
-        self-adaptive weights as parameters of the mask model.
-
-        :raises NotImplementedError: If the batch size is not ``None``.
-        """
-        if self.trainer.batch_size is not None:
-            raise NotImplementedError(
-                "SelfAdaptivePINN only works with full "
-                "batch size, set batch_size=None inside "
-                "the Trainer to use the solver."
-            )
-        device = torch.device(
-            self.trainer._accelerator_connector._accelerator_flag
-        )
-
-        # Initialize the self adaptive weights only for training points
-        for (
-            condition_name,
-            tensor,
-        ) in self.trainer.data_module.train_dataset.input.items():
-            self.weights_dict[condition_name].sa_weights.data = torch.rand(
-                (tensor.shape[0], 1), device=device
-            )
-        return super().on_train_start()
-
-    def on_load_checkpoint(self, checkpoint):
-        """
-        Override of the Pytorch Lightning ``on_load_checkpoint`` method to
-        handle checkpoints for Self-Adaptive Weights. This method should not be
-        overridden, if not intentionally.
-
-        :param dict checkpoint: Pytorch Lightning checkpoint dict.
-        """
-        # First initialize self-adaptive weights with correct shape,
-        # then load the values from the checkpoint.
-        for condition_name, _ in self.problem.input_pts.items():
-            shape = checkpoint["state_dict"][
-                f"_pina_models.1.{condition_name}.sa_weights"
-            ].shape
-            self.weights_dict[condition_name].sa_weights.data = torch.rand(
-                shape
-            )
-        return super().on_load_checkpoint(checkpoint)
+    @torch.set_grad_enabled(True)
+    def validation_step(self, batch, **kwargs):
+        """
+        The validation step for the Self-Adaptive PINN solver. It returns the
+        average residual computed with the ``loss`` function not aggregated.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the validation step.
+        :rtype: torch.Tensor
+        """
+        losses = self.optimization_cycle(batch=batch, **kwargs)
+
+        # Aggregate losses for each condition
+        for cond, loss in losses.items():
+            losses[cond] = self._apply_reduction(loss=losses[cond])
+
+        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
+        self.store_log("val_loss", loss, self.get_batch_size(batch))
+        return loss
+
+    @torch.set_grad_enabled(True)
+    def test_step(self, batch, **kwargs):
+        """
+        The test step for the Self-Adaptive PINN solver. It returns the average
+        residual computed with the ``loss`` function not aggregated.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the test step.
+        :rtype: torch.Tensor
+        """
+        losses = self.optimization_cycle(batch=batch, **kwargs)
+
+        # Aggregate losses for each condition
+        for cond, loss in losses.items():
+            losses[cond] = self._apply_reduction(loss=losses[cond])
+
+        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
+        self.store_log("test_loss", loss, self.get_batch_size(batch))
+        return loss

     def loss_phys(self, samples, equation):
         """
@@ -279,47 +261,138 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
         :return: The computed physics loss.
         :rtype: LabelTensor
         """
-        residual = self.compute_residual(samples, equation)
-        weights = self.weights_dict[self.current_condition_name].forward()
-        loss_value = self._vectorial_loss(
-            torch.zeros_like(residual, requires_grad=True), residual
-        )
-        return self._vect_to_scalar(weights * loss_value)
+        residuals = self.compute_residual(samples, equation)
+        return self._loss_fn(residuals, torch.zeros_like(residuals))

     def loss_data(self, input, target):
         """
-        Compute the data loss for the PINN solver by evaluating the loss
-        between the network's output and the true solution. This method should
-        not be overridden, if not intentionally.
+        Compute the data loss for the Self-Adaptive PINN solver by evaluating
+        the loss between the network's output and the true solution. This method
+        should not be overridden, if not intentionally.

         :param input: The input to the neural network.
-        :type input: LabelTensor
+        :type input: LabelTensor | torch.Tensor
         :param target: The target to compare with the network's output.
-        :type target: LabelTensor
+        :type target: LabelTensor | torch.Tensor
         :return: The supervised loss, averaged over the number of observations.
-        :rtype: LabelTensor
+        :rtype: LabelTensor | torch.Tensor
         """
         return self._loss_fn(self.forward(input), target)

-    def _vect_to_scalar(self, loss_value):
+    def forward(self, x):
         """
-        Computation of the scalar loss.
+        Forward pass.

-        :param LabelTensor loss_value: the tensor of pointwise losses.
-        :raises RuntimeError: If the loss reduction is not ``mean`` or ``sum``.
-        :return: The computed scalar loss.
-        :rtype: LabelTensor
+        :param x: Input tensor.
+        :type x: torch.Tensor | LabelTensor
+        :return: The output of the neural network.
+        :rtype: torch.Tensor | LabelTensor
         """
-        if self.loss.reduction == "mean":
-            ret = torch.mean(loss_value)
-        elif self.loss.reduction == "sum":
-            ret = torch.sum(loss_value)
-        else:
-            raise RuntimeError(
-                f"Invalid reduction, got {self.loss.reduction} "
-                "but expected mean or sum."
-            )
-        return ret
+        return self.model(x)
+
+    def configure_optimizers(self):
+        """
+        Optimizer configuration.
+
+        :return: The optimizers and the schedulers
+        :rtype: tuple[list[Optimizer], list[Scheduler]]
+        """
+        # Hook the optimizers to the models
+        self.optimizer_model.hook(self.model.parameters())
+        self.optimizer_weights.hook(self.weights.parameters())
+
+        # Add unknown parameters to optimization list in case of InverseProblem
+        if isinstance(self.problem, InverseProblem):
+            self.optimizer_model.instance.add_param_group(
+                {
+                    "params": [
+                        self._params[var]
+                        for var in self.problem.unknown_variables
+                    ]
+                }
+            )
+
+        # Hook the schedulers to the optimizers
+        self.scheduler_model.hook(self.optimizer_model)
+        self.scheduler_weights.hook(self.optimizer_weights)
+
+        return (
+            [self.optimizer_model.instance, self.optimizer_weights.instance],
+            [self.scheduler_model.instance, self.scheduler_weights.instance],
+        )
+
+    def _optimization_cycle(self, batch, batch_idx, **kwargs):
+        """
+        Aggregate the loss for each condition in the batch.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param int batch_idx: The index of the current batch.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The losses computed for all conditions in the batch, casted
+            to a subclass of :class:`torch.Tensor`. It should return a dict
+            containing the condition name and the associated scalar loss.
+        :rtype: dict
+        """
+        # Compute non-aggregated residuals
+        residuals = self.optimization_cycle(batch)
+
+        # Compute losses
+        losses = {}
+        for cond, res in residuals.items():
+            weight_tensor = self.weights[cond]()
+
+            # Get the correct indices for the weights. Modulus is used
+            # according to the number of points in the condition, as in
+            # the PinaDataset.
+            len_res = len(res)
+            idx = torch.arange(
+                batch_idx * len_res,
+                (batch_idx + 1) * len_res,
+                device=res.device,
+            ) % len(self.problem.input_pts[cond])
+
+            # Apply the weights to the residuals
+            losses[cond] = self._apply_reduction(
+                loss=(res * weight_tensor[idx])
+            )
+
+            # Store log
+            self.store_log(
+                f"{cond}_loss", losses[cond].item(), self.get_batch_size(batch)
+            )
+
+        # Clamp unknown parameters in InverseProblem (if needed)
+        self._clamp_params()
+
+        # Aggregate
+        loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
+
+        return loss
+
+    def _apply_reduction(self, loss):
+        """
+        Apply the specified reduction to the loss. The reduction is deferred
+        until the end of the optimization cycle to allow self-adaptive weights
+        to be applied to each point beforehand.
+
+        :param torch.Tensor loss: The loss tensor to be reduced.
+        :return: The reduced loss tensor.
+        :rtype: torch.Tensor
+        :raises ValueError: If the reduction method is neither "mean" nor "sum".
+        """
+        # Apply the specified reduction method
+        if self._reduction == "mean":
+            return loss.mean()
+        if self._reduction == "sum":
+            return loss.sum()
+
+        # Raise an error if the reduction method is not recognized
+        raise ValueError(
+            f"Unknown reduction: {self._reduction}."
+            " Supported reductions are 'mean' and 'sum'."
+        )

     @property
     def model(self):
@@ -332,7 +405,7 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
         return self.models[0]

     @property
-    def weights_dict(self):
+    def weights(self):
        """
        The self-adaptive weights.

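A minimal sketch of the batching logic introduced in _optimization_cycle above: each mini-batch of point-wise residuals is matched to its slice of the self-adaptive weights via modular indexing, so the pairing also works when a batch wraps past the end of the dataset. All sizes and tensors here are illustrative, and the mask/weight setup mimics (but does not import) the solver code. The test-file diff follows below.

import torch

# Illustrative sizes: 100 collocation points for one condition, batch size 32.
num_points, batch_size = 100, 32
sa_weights = torch.nn.Parameter(torch.zeros(num_points, 1))
mask = torch.nn.Sigmoid()


def weighted_batch_loss(residuals, batch_idx, reduction="mean"):
    # Indices of the points in this batch, wrapped with a modulus as in the
    # PinaDataset, so a batch running past the end reuses the first points.
    idx = torch.arange(
        batch_idx * len(residuals), (batch_idx + 1) * len(residuals)
    ) % num_points
    # Point-wise residuals scaled by their self-adaptive weights, then reduced.
    weighted = residuals * mask(sa_weights)[idx]
    return weighted.mean() if reduction == "mean" else weighted.sum()


# The fourth batch (index 3) covers points 96..99 and wraps around to 0..27.
res = torch.rand(batch_size, 1)
print(weighted_batch_loss(res, batch_idx=3))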

@@ -42,9 +42,11 @@ model = FeedForward(len(problem.input_variables), len(problem.output_variables))

 @pytest.mark.parametrize("problem", [problem, inverse_problem])
 @pytest.mark.parametrize("weight_fn", [torch.nn.Sigmoid(), torch.nn.Tanh()])
 def test_constructor(problem, weight_fn):
-    solver = SAPINN(problem=problem, model=model, weight_function=weight_fn)
     with pytest.raises(ValueError):
         SAPINN(model=model, problem=problem, weight_function=1)

+    solver = SAPINN(problem=problem, model=model, weight_function=weight_fn)
+
     assert solver.accepted_conditions_types == (
         InputTargetCondition,
@@ -53,26 +55,13 @@ def test_constructor(problem, weight_fn):
     )


-@pytest.mark.parametrize("problem", [problem, inverse_problem])
-def test_wrong_batch(problem):
-    with pytest.raises(NotImplementedError):
-        solver = SAPINN(model=model, problem=problem)
-        trainer = Trainer(
-            solver=solver,
-            max_epochs=2,
-            accelerator="cpu",
-            batch_size=10,
-            train_size=1.0,
-            val_size=0.0,
-            test_size=0.0,
-        )
-        trainer.train()
-
-
 @pytest.mark.parametrize("problem", [problem, inverse_problem])
 @pytest.mark.parametrize("compile", [True, False])
-def test_solver_train(problem, compile):
-    solver = SAPINN(model=model, problem=problem)
+@pytest.mark.parametrize(
+    "loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
+)
+def test_solver_train(problem, compile, loss):
+    solver = SAPINN(model=model, problem=problem, loss=loss)
     trainer = Trainer(
         solver=solver,
         max_epochs=2,
@@ -95,8 +84,11 @@ def test_solver_train(problem, compile):

 @pytest.mark.parametrize("problem", [problem, inverse_problem])
 @pytest.mark.parametrize("compile", [True, False])
-def test_solver_validation(problem, compile):
-    solver = SAPINN(model=model, problem=problem)
+@pytest.mark.parametrize(
+    "loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
+)
+def test_solver_validation(problem, compile, loss):
+    solver = SAPINN(model=model, problem=problem, loss=loss)
     trainer = Trainer(
         solver=solver,
         max_epochs=2,
@@ -119,8 +111,11 @@ def test_solver_validation(problem, compile):

 @pytest.mark.parametrize("problem", [problem, inverse_problem])
 @pytest.mark.parametrize("compile", [True, False])
-def test_solver_test(problem, compile):
-    solver = SAPINN(model=model, problem=problem)
+@pytest.mark.parametrize(
+    "loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
+)
+def test_solver_test(problem, compile, loss):
+    solver = SAPINN(model=model, problem=problem, loss=loss)
     trainer = Trainer(
         solver=solver,
         max_epochs=2,