add batching support for self-adaptive pinns

Giovanni Canali
2025-07-22 14:45:43 +02:00
committed by Giovanni Canali
parent 1ed14916f1
commit 6d1d4ef423
2 changed files with 208 additions and 140 deletions


@@ -15,15 +15,20 @@ class Weights(torch.nn.Module):
:class:`SelfAdaptivePINN` solver.
"""
def __init__(self, func):
def __init__(self, func, num_points):
"""
Initialization of the :class:`Weights` class.
:param torch.nn.Module func: the mask model.
:param int num_points: the number of input points.
"""
super().__init__()
# Check consistency
check_consistency(func, torch.nn.Module)
self.sa_weights = torch.nn.Parameter(torch.Tensor())
# Initialize the weights as a learnable parameter
self.sa_weights = torch.nn.Parameter(torch.zeros(num_points, 1))
self.func = func
def forward(self):
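
For context, a minimal standalone sketch of how the new per-point weights appear to behave, based on this hunk. The body of ``forward`` is not shown above, so applying the mask function to the raw weights is an assumption, and the condition names and point counts below are illustrative only (consistency checks are omitted):

    import torch

    class Weights(torch.nn.Module):
        """Per-point self-adaptive weights, masked by a user-supplied function."""

        def __init__(self, func, num_points):
            super().__init__()
            self.func = func
            # One learnable weight per training point of the condition.
            self.sa_weights = torch.nn.Parameter(torch.zeros(num_points, 1))

        def forward(self):
            # Assumption: the mask function is applied to the raw weights.
            return self.func(self.sa_weights)

    # One Weights module per problem condition, sized to its number of points,
    # mirroring the ModuleDict built in the solver constructor below.
    weights = torch.nn.ModuleDict(
        {"bc": Weights(torch.nn.Sigmoid(), num_points=50),
         "pde": Weights(torch.nn.Sigmoid(), num_points=200)}
    )
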
@@ -140,17 +145,17 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is `None`.
"""
# check consistency weitghs_function
# Check consistency
check_consistency(weight_function, torch.nn.Module)
# create models for weights
weights_dict = {}
for condition_name in problem.conditions:
weights_dict[condition_name] = Weights(weight_function)
weights_dict = torch.nn.ModuleDict(weights_dict)
# Define a ModuleDict for the weights
weights = {}
for cond, data in problem.input_pts.items():
weights[cond] = Weights(func=weight_function, num_points=len(data))
weights = torch.nn.ModuleDict(weights)
super().__init__(
models=[model, weights_dict],
models=[model, weights],
problem=problem,
optimizers=[optimizer_model, optimizer_weights],
schedulers=[scheduler_model, scheduler_weights],
@@ -158,116 +163,93 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
loss=loss,
)
self._vectorial_loss = deepcopy(self.loss)
self._vectorial_loss.reduction = "none"
# Extract the reduction method from the loss function
self._reduction = self._loss_fn.reduction
def forward(self, x):
"""
Forward pass.
# Set the loss function to return non-aggregated losses
self._loss_fn = type(self._loss_fn)(reduction="none")
:param LabelTensor x: Input tensor.
:return: The output of the neural network.
:rtype: LabelTensor
def training_step(self, batch, batch_idx, **kwargs):
"""
return self.model(x)
def training_step(self, batch):
"""
Solver training step, overridden to perform manual optimization.
Solver training step. It computes the optimization cycle and aggregates
the losses using the ``weighting`` attribute.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:return: The aggregated loss.
:rtype: LabelTensor
:param int batch_idx: The index of the current batch.
:param dict kwargs: Additional keyword arguments passed to
``optimization_cycle``.
:return: The loss of the training step.
:rtype: torch.Tensor
"""
# Weights optimization
self.optimizer_weights.instance.zero_grad()
loss = super().training_step(batch)
loss = self._optimization_cycle(
batch=batch, batch_idx=batch_idx, **kwargs
)
self.manual_backward(-loss)
self.optimizer_weights.instance.step()
self.scheduler_weights.instance.step()
# Model optimization
self.optimizer_model.instance.zero_grad()
loss = super().training_step(batch)
loss = self._optimization_cycle(
batch=batch, batch_idx=batch_idx, **kwargs
)
self.manual_backward(loss)
self.optimizer_model.instance.step()
self.scheduler_model.instance.step()
# Log the loss
self.store_log("train_loss", loss, self.get_batch_size(batch))
return loss
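
The training step above keeps the usual SA-PINN min-max update under manual optimization: the same loss is backpropagated twice, once negated to maximize it with respect to the self-adaptive weights, and once to minimize it with respect to the model. A generic PyTorch sketch of that pattern (names are illustrative, not the solver's API; scheduler steps are omitted):

    def saddle_point_step(compute_loss, model_opt, weight_opt):
        # Ascent step on the self-adaptive weights: backpropagate the negated loss.
        weight_opt.zero_grad()
        (-compute_loss()).backward()
        weight_opt.step()

        # Descent step on the model parameters: backpropagate the loss as usual.
        model_opt.zero_grad()
        loss = compute_loss()
        loss.backward()
        model_opt.step()
        return loss
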
def configure_optimizers(self):
@torch.set_grad_enabled(True)
def validation_step(self, batch, **kwargs):
"""
Optimizer configuration.
The validation step for the Self-Adaptive PINN solver. It returns the
average residual computed with the ``loss`` function not aggregated.
:return: The optimizers and the schedulers
:rtype: tuple[list[Optimizer], list[Scheduler]]
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param dict kwargs: Additional keyword arguments passed to
``optimization_cycle``.
:return: The loss of the validation step.
:rtype: torch.Tensor
"""
# If the problem is an InverseProblem, add the unknown parameters
# to the parameters to be optimized
self.optimizer_model.hook(self.model.parameters())
self.optimizer_weights.hook(self.weights_dict.parameters())
if isinstance(self.problem, InverseProblem):
self.optimizer_model.instance.add_param_group(
{
"params": [
self._params[var]
for var in self.problem.unknown_variables
]
}
)
self.scheduler_model.hook(self.optimizer_model)
self.scheduler_weights.hook(self.optimizer_weights)
return (
[self.optimizer_model.instance, self.optimizer_weights.instance],
[self.scheduler_model.instance, self.scheduler_weights.instance],
)
losses = self.optimization_cycle(batch=batch, **kwargs)
def on_train_start(self):
# Aggregate losses for each condition
for cond, loss in losses.items():
losses[cond] = self._apply_reduction(loss=losses[cond])
loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
self.store_log("val_loss", loss, self.get_batch_size(batch))
return loss
@torch.set_grad_enabled(True)
def test_step(self, batch, **kwargs):
"""
This method is called at the start of the training process to set the
self-adaptive weights as parameters of the mask model.
The test step for the Self-Adaptive PINN solver. It returns the average
residual computed with the ``loss`` function not aggregated.
:raises NotImplementedError: If the batch size is not ``None``.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param dict kwargs: Additional keyword arguments passed to
``optimization_cycle``.
:return: The loss of the test step.
:rtype: torch.Tensor
"""
if self.trainer.batch_size is not None:
raise NotImplementedError(
"SelfAdaptivePINN only works with full "
"batch size, set batch_size=None inside "
"the Trainer to use the solver."
)
device = torch.device(
self.trainer._accelerator_connector._accelerator_flag
)
losses = self.optimization_cycle(batch=batch, **kwargs)
# Initialize the self adaptive weights only for training points
for (
condition_name,
tensor,
) in self.trainer.data_module.train_dataset.input.items():
self.weights_dict[condition_name].sa_weights.data = torch.rand(
(tensor.shape[0], 1), device=device
)
return super().on_train_start()
# Aggregate losses for each condition
for cond, loss in losses.items():
losses[cond] = self._apply_reduction(loss=losses[cond])
def on_load_checkpoint(self, checkpoint):
"""
Override of the Pytorch Lightning ``on_load_checkpoint`` method to
handle checkpoints for Self-Adaptive Weights. This method should not be
overridden, if not intentionally.
:param dict checkpoint: Pytorch Lightning checkpoint dict.
"""
# First initialize self-adaptive weights with correct shape,
# then load the values from the checkpoint.
for condition_name, _ in self.problem.input_pts.items():
shape = checkpoint["state_dict"][
f"_pina_models.1.{condition_name}.sa_weights"
].shape
self.weights_dict[condition_name].sa_weights.data = torch.rand(
shape
)
return super().on_load_checkpoint(checkpoint)
loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
self.store_log("test_loss", loss, self.get_batch_size(batch))
return loss
def loss_phys(self, samples, equation):
"""
@@ -279,47 +261,138 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
:return: The computed physics loss.
:rtype: LabelTensor
"""
residual = self.compute_residual(samples, equation)
weights = self.weights_dict[self.current_condition_name].forward()
loss_value = self._vectorial_loss(
torch.zeros_like(residual, requires_grad=True), residual
)
return self._vect_to_scalar(weights * loss_value)
residuals = self.compute_residual(samples, equation)
return self._loss_fn(residuals, torch.zeros_like(residuals))
def loss_data(self, input, target):
"""
Compute the data loss for the PINN solver by evaluating the loss
between the network's output and the true solution. This method should
not be overridden, if not intentionally.
Compute the data loss for the Self-Adaptive PINN solver by evaluating
the loss between the network's output and the true solution. This method
should not be overridden, if not intentionally.
:param input: The input to the neural network.
:type input: LabelTensor
:type input: LabelTensor | torch.Tensor
:param target: The target to compare with the network's output.
:type target: LabelTensor
:type target: LabelTensor | torch.Tensor
:return: The supervised loss, averaged over the number of observations.
:rtype: LabelTensor
:rtype: LabelTensor | torch.Tensor
"""
return self._loss_fn(self.forward(input), target)
def _vect_to_scalar(self, loss_value):
def forward(self, x):
"""
Computation of the scalar loss.
Forward pass.
:param LabelTensor loss_value: the tensor of pointwise losses.
:raises RuntimeError: If the loss reduction is not ``mean`` or ``sum``.
:return: The computed scalar loss.
:rtype: LabelTensor
:param x: Input tensor.
:type x: torch.Tensor | LabelTensor
:return: The output of the neural network.
:rtype: torch.Tensor | LabelTensor
"""
if self.loss.reduction == "mean":
ret = torch.mean(loss_value)
elif self.loss.reduction == "sum":
ret = torch.sum(loss_value)
else:
raise RuntimeError(
f"Invalid reduction, got {self.loss.reduction} "
"but expected mean or sum."
return self.model(x)
def configure_optimizers(self):
"""
Optimizer configuration.
:return: The optimizers and the schedulers
:rtype: tuple[list[Optimizer], list[Scheduler]]
"""
# Hook the optimizers to the models
self.optimizer_model.hook(self.model.parameters())
self.optimizer_weights.hook(self.weights.parameters())
# Add unknown parameters to optimization list in case of InverseProblem
if isinstance(self.problem, InverseProblem):
self.optimizer_model.instance.add_param_group(
{
"params": [
self._params[var]
for var in self.problem.unknown_variables
]
}
)
return ret
# Hook the schedulers to the optimizers
self.scheduler_model.hook(self.optimizer_model)
self.scheduler_weights.hook(self.optimizer_weights)
return (
[self.optimizer_model.instance, self.optimizer_weights.instance],
[self.scheduler_model.instance, self.scheduler_weights.instance],
)
def _optimization_cycle(self, batch, batch_idx, **kwargs):
"""
Aggregate the loss for each condition in the batch.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param int batch_idx: The index of the current batch.
:param dict kwargs: Additional keyword arguments passed to
``optimization_cycle``.
:return: The losses computed for all conditions in the batch, casted
to a subclass of :class:`torch.Tensor`. It should return a dict
containing the condition name and the associated scalar loss.
:rtype: dict
"""
# Compute non-aggregated residuals
residuals = self.optimization_cycle(batch)
# Compute losses
losses = {}
for cond, res in residuals.items():
weight_tensor = self.weights[cond]()
# Get the correct indices for the weights. Modulus is used according
# to the number of points in the condition, as in the PinaDataset.
len_res = len(res)
idx = torch.arange(
batch_idx * len_res,
(batch_idx + 1) * len_res,
device=res.device,
) % len(self.problem.input_pts[cond])
# Apply the weights to the residuals
losses[cond] = self._apply_reduction(
loss=(res * weight_tensor[idx])
)
# Store log
self.store_log(
f"{cond}_loss", losses[cond].item(), self.get_batch_size(batch)
)
# Clamp unknown parameters in InverseProblem (if needed)
self._clamp_params()
# Aggregate
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
return loss
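
The modulus indexing above recovers which self-adaptive weights belong to the residuals of the current batch: the batch covers a contiguous range of points offset by ``batch_idx``, wrapped modulo the number of points sampled for that condition, matching how the dataset cycles through them. A small worked illustration with made-up sizes:

    import torch

    num_points = 100   # points sampled for this condition (illustrative)
    batch_size = 32    # residuals in the current batch (illustrative)
    batch_idx = 3

    idx = torch.arange(batch_idx * batch_size,
                       (batch_idx + 1) * batch_size) % num_points
    # tensor([96, 97, 98, 99, 0, 1, ..., 27]): the weight indices wrap around,
    # so each residual is paired with the weight of its own training point.
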
def _apply_reduction(self, loss):
"""
Apply the specified reduction to the loss. The reduction is deferred
until the end of the optimization cycle to allow self-adaptive weights
to be applied to each point beforehand.
:param torch.Tensor loss: The loss tensor to be reduced.
:return: The reduced loss tensor.
:rtype: torch.Tensor
:raises ValueError: If the reduction method is neither "mean" nor "sum".
"""
# Apply the specified reduction method
if self._reduction == "mean":
return loss.mean()
if self._reduction == "sum":
return loss.sum()
# Raise an error if the reduction method is not recognized
raise ValueError(
f"Unknown reduction: {self._reduction}."
" Supported reductions are 'mean' and 'sum'."
)
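
Deferring the reduction is what makes the per-point weighting possible: the constructor remembers the reduction requested by the user, rebuilds the loss with ``reduction="none"`` so it returns pointwise values, and only after the weights have been applied is the stored reduction carried out here. A compressed sketch of the idea, assuming an MSE-style loss and illustrative tensors:

    import torch

    loss_fn = torch.nn.MSELoss()                # user-provided, reduction="mean"
    reduction = loss_fn.reduction               # remember the requested reduction
    loss_fn = type(loss_fn)(reduction="none")   # same loss class, pointwise output

    residuals = torch.randn(8, 1)               # illustrative residuals
    sa_weights = torch.rand(8, 1)               # illustrative self-adaptive weights

    pointwise = loss_fn(residuals, torch.zeros_like(residuals))
    weighted = pointwise * sa_weights           # weight each point individually
    loss = weighted.mean() if reduction == "mean" else weighted.sum()
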
@property
def model(self):
@@ -332,7 +405,7 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
return self.models[0]
@property
def weights_dict(self):
def weights(self):
"""
The self-adaptive weights.


@@ -42,9 +42,11 @@ model = FeedForward(len(problem.input_variables), len(problem.output_variables))
@pytest.mark.parametrize("problem", [problem, inverse_problem])
@pytest.mark.parametrize("weight_fn", [torch.nn.Sigmoid(), torch.nn.Tanh()])
def test_constructor(problem, weight_fn):
solver = SAPINN(problem=problem, model=model, weight_function=weight_fn)
with pytest.raises(ValueError):
SAPINN(model=model, problem=problem, weight_function=1)
solver = SAPINN(problem=problem, model=model, weight_function=weight_fn)
assert solver.accepted_conditions_types == (
InputTargetCondition,
@@ -53,26 +55,13 @@ def test_constructor(problem, weight_fn):
)
@pytest.mark.parametrize("problem", [problem, inverse_problem])
def test_wrong_batch(problem):
with pytest.raises(NotImplementedError):
solver = SAPINN(model=model, problem=problem)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=10,
train_size=1.0,
val_size=0.0,
test_size=0.0,
)
trainer.train()
@pytest.mark.parametrize("problem", [problem, inverse_problem])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_train(problem, compile):
solver = SAPINN(model=model, problem=problem)
@pytest.mark.parametrize(
"loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
)
def test_solver_train(problem, compile, loss):
solver = SAPINN(model=model, problem=problem, loss=loss)
trainer = Trainer(
solver=solver,
max_epochs=2,
@@ -95,8 +84,11 @@ def test_solver_train(problem, compile):
@pytest.mark.parametrize("problem", [problem, inverse_problem])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_validation(problem, compile):
solver = SAPINN(model=model, problem=problem)
@pytest.mark.parametrize(
"loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
)
def test_solver_validation(problem, compile, loss):
solver = SAPINN(model=model, problem=problem, loss=loss)
trainer = Trainer(
solver=solver,
max_epochs=2,
@@ -119,8 +111,11 @@ def test_solver_validation(problem, compile):
@pytest.mark.parametrize("problem", [problem, inverse_problem])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_test(problem, compile):
solver = SAPINN(model=model, problem=problem)
@pytest.mark.parametrize(
"loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
)
def test_solver_test(problem, compile, loss):
solver = SAPINN(model=model, problem=problem, loss=loss)
trainer = Trainer(
solver=solver,
max_epochs=2,