batching for rbapinns
@@ -1,6 +1,5 @@
 """Module for the Residual-Based Attention PINN solver."""
 
-from copy import deepcopy
 import torch
 
 from .pinn import PINN
@@ -98,6 +97,8 @@ class RBAPINN(PINN):
         :param float gamma: The decay parameter in the update of the weights
             of the residuals. Must be between ``0`` and ``1``.
             Default is ``0.999``.
+        :raises: ValueError if `gamma` is not in the range (0, 1).
+        :raises: ValueError if `eta` is not greater than 0.
         """
         super().__init__(
             model=model,
@@ -111,78 +112,201 @@ class RBAPINN(PINN):
         # check consistency
         check_consistency(eta, (float, int))
         check_consistency(gamma, float)
-        assert (
-            0 < gamma < 1
-        ), f"Invalid range: expected 0 < gamma < 1, got {gamma=}"
+
+        # Validate range for gamma
+        if not 0 < gamma < 1:
+            raise ValueError(
+                f"Invalid range: expected 0 < gamma < 1, but got {gamma}"
+            )
+
+        # Validate range for eta
+        if eta <= 0:
+            raise ValueError(f"Invalid range: expected eta > 0, but got {eta}")
+
+        # Initialize parameters
         self.eta = eta
         self.gamma = gamma
 
-        # initialize weights
+        # Initialize the weight of each point to 0
         self.weights = {}
-        for condition_name in problem.conditions:
-            self.weights[condition_name] = 0
+        for cond, data in self.problem.input_pts.items():
+            buffer_tensor = torch.zeros((len(data), 1), device=self.device)
+            self.register_buffer(f"weight_{cond}", buffer_tensor)
+            self.weights[cond] = getattr(self, f"weight_{cond}")
 
-        # define vectorial loss
-        self._vectorial_loss = deepcopy(self.loss)
-        self._vectorial_loss.reduction = "none"
+        # Extract the reduction method from the loss function
+        self._reduction = self._loss_fn.reduction
 
-    # for now RBAPINN is implemented only for batch_size = None
-    def on_train_start(self):
+        # Set the loss function to return non-aggregated losses
+        self._loss_fn = type(self._loss_fn)(reduction="none")
+
+    def training_step(self, batch, batch_idx, **kwargs):
         """
-        Hook method called at the beginning of training.
+        Solver training step. It computes the optimization cycle and aggregates
+        the losses using the ``weighting`` attribute.
 
-        :raises NotImplementedError: If the batch size is not ``None``.
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param int batch_idx: The index of the current batch.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the training step.
+        :rtype: torch.Tensor
         """
-        if self.trainer.batch_size is not None:
-            raise NotImplementedError(
-                "RBAPINN only works with full batch "
-                "size, set batch_size=None inside the "
-                "Trainer to use the solver."
-            )
-        return super().on_train_start()
+        loss = self._optimization_cycle(
+            batch=batch, batch_idx=batch_idx, **kwargs
+        )
+        self.store_log("train_loss", loss, self.get_batch_size(batch))
+        return loss
 
-    def _vect_to_scalar(self, loss_value):
+    @torch.set_grad_enabled(True)
+    def validation_step(self, batch, **kwargs):
         """
-        Computation of the scalar loss.
+        The validation step for the PINN solver. It returns the average residual
+        computed with the ``loss`` function not aggregated.
 
-        :param LabelTensor loss_value: the tensor of pointwise losses.
-        :raises RuntimeError: If the loss reduction is not ``mean`` or ``sum``.
-        :return: The computed scalar loss.
-        :rtype: LabelTensor
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the validation step.
+        :rtype: torch.Tensor
         """
-        if self.loss.reduction == "mean":
-            ret = torch.mean(loss_value)
-        elif self.loss.reduction == "sum":
-            ret = torch.sum(loss_value)
-        else:
-            raise RuntimeError(
-                f"Invalid reduction, got {self.loss.reduction} "
-                "but expected mean or sum."
-            )
-        return ret
+        losses = self.optimization_cycle(batch=batch, **kwargs)
 
-    def loss_phys(self, samples, equation):
+        # Aggregate losses for each condition
+        for cond, loss in losses.items():
+            losses[cond] = self._apply_reduction(loss=losses[cond])
+
+        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
+        self.store_log("val_loss", loss, self.get_batch_size(batch))
+        return loss
+
+    @torch.set_grad_enabled(True)
+    def test_step(self, batch, **kwargs):
         """
-        Computes the physics loss for the physics-informed solver based on the
-        provided samples and equation.
+        The test step for the PINN solver. It returns the average residual
+        computed with the ``loss`` function not aggregated.
 
-        :param LabelTensor samples: The samples to evaluate the physics loss.
-        :param EquationInterface equation: The governing equation.
-        :return: The computed physics loss.
-        :rtype: LabelTensor
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the test step.
+        :rtype: torch.Tensor
         """
-        residual = self.compute_residual(samples=samples, equation=equation)
-        cond = self.current_condition_name
-
-        r_norm = (
-            self.eta
-            * torch.abs(residual)
-            / (torch.max(torch.abs(residual)) + 1e-12)
-        )
-        self.weights[cond] = (self.gamma * self.weights[cond] + r_norm).detach()
-
-        loss_value = self._vectorial_loss(
-            torch.zeros_like(residual, requires_grad=True), residual
-        )
-
-        return self._vect_to_scalar(self.weights[cond] ** 2 * loss_value)
+        losses = self.optimization_cycle(batch=batch, **kwargs)
+
+        # Aggregate losses for each condition
+        for cond, loss in losses.items():
+            losses[cond] = self._apply_reduction(loss=losses[cond])
+
+        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
+        self.store_log("test_loss", loss, self.get_batch_size(batch))
+        return loss
+
+    def _optimization_cycle(self, batch, batch_idx, **kwargs):
+        """
+        Aggregate the loss for each condition in the batch.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param int batch_idx: The index of the current batch.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The losses computed for all conditions in the batch, casted
+            to a subclass of :class:`torch.Tensor`. It should return a dict
+            containing the condition name and the associated scalar loss.
+        :rtype: dict
+        """
+        # compute non-aggregated residuals
+        residuals = self.optimization_cycle(batch)
+
+        # update weights based on residuals
+        self._update_weights(batch, batch_idx, residuals)
+
+        # compute losses
+        losses = {}
+        for cond, res in residuals.items():
+
+            # Get the correct indices for the weights. Modulus is used according
+            # to the number of points in the condition, as in the PinaDataset.
+            len_res = len(res)
+            idx = torch.arange(
+                batch_idx * len_res,
+                (batch_idx + 1) * len_res,
+                device=res.device,
+            ) % len(self.problem.input_pts[cond])
+
+            losses[cond] = self._apply_reduction(
+                loss=(res * self.weights[cond][idx])
+            )
+
+            # store log
+            self.store_log(
+                f"{cond}_loss", losses[cond].item(), self.get_batch_size(batch)
+            )
+
+        # clamp unknown parameters in InverseProblem (if needed)
+        self._clamp_params()
+
+        # aggregate
+        loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
+
+        return loss
+
+    def _update_weights(self, batch, batch_idx, residuals):
+        """
+        Update weights based on residuals.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param int batch_idx: The index of the current batch.
+        :param dict residuals: A dictionary containing the residuals for each
+            condition. The keys are the condition names and the values are the
+            residuals as tensors.
+        """
+        # Iterate over each condition in the batch
+        for cond, data in batch:
+
+            # Compute normalized residuals
+            res = residuals[cond]
+            res_abs = res.abs()
+            r_norm = (self.eta * res_abs) / (res_abs.max() + 1e-12)
+
+            # Get the correct indices for the weights. Modulus is used according
+            # to the number of points in the condition, as in the PinaDataset.
+            len_pts = len(data["input"])
+            idx = torch.arange(
+                batch_idx * len_pts,
+                (batch_idx + 1) * len_pts,
+                device=res.device,
+            ) % len(self.problem.input_pts[cond])
+
+            # Update weights
+            weights = self.weights[cond]
+            update = self.gamma * weights[idx] + r_norm
+            weights[idx] = update.detach()
+
+    def _apply_reduction(self, loss):
+        """
+        Apply the specified reduction to the loss. The reduction is deferred
+        until the end of the optimization cycle to allow residual-based weights
+        to be applied to each point beforehand.
+
+        :param torch.Tensor loss: The loss tensor to be reduced.
+        :return: The reduced loss tensor.
+        :rtype: torch.Tensor
+        :raises ValueError: If the reduction method is neither "mean" nor "sum".
+        """
+        # Apply the specified reduction method
+        if self._reduction == "mean":
+            return loss.mean()
+        if self._reduction == "sum":
+            return loss.sum()
+
+        # Raise an error if the reduction method is not recognized
+        raise ValueError(
+            f"Unknown reduction: {self._reduction}."
+            " Supported reductions are 'mean' and 'sum'."
+        )
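For readers skimming the diff: below is a minimal, self-contained sketch of what the new `_update_weights` / `_optimization_cycle` pair computes for a single condition. It is not the library code; the names and sizes (`n_points`, `batch_size`, the random stand-in for the non-aggregated pointwise losses) are illustrative, and `gamma=0.999`, `eta=0.001` are the docstring defaults.

import torch

torch.manual_seed(0)
n_points, batch_size = 10, 4        # illustrative sizes
gamma, eta = 0.999, 0.001           # docstring defaults
weights = torch.zeros(n_points, 1)  # one weight per collocation point

for batch_idx in range(3):
    # stand-in for the non-aggregated pointwise losses of this batch
    res = torch.randn(batch_size, 1).abs()

    # map batch positions to global point indices; the modulus wraps
    # around the condition's dataset, as in the PinaDataset
    idx = torch.arange(
        batch_idx * batch_size, (batch_idx + 1) * batch_size
    ) % n_points

    # exponentially decayed, max-normalized residual update, detached so
    # the weights never enter the computational graph
    r_norm = eta * res.abs() / (res.abs().max() + 1e-12)
    weights[idx] = (gamma * weights[idx] + r_norm).detach()

    # weighting is applied pointwise; the reduction is deferred
    loss = (res * weights[idx]).mean()

The modulus indexing is what makes mini-batches work: each batch updates only the slice of the persistent per-point weight buffer that corresponds to its points, so weights accumulate correctly across epochs regardless of batch size.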
@@ -42,10 +42,14 @@ model = FeedForward(len(problem.input_variables), len(problem.output_variables))
 @pytest.mark.parametrize("eta", [1, 0.001])
 @pytest.mark.parametrize("gamma", [0.5, 0.9])
 def test_constructor(problem, eta, gamma):
-    with pytest.raises(AssertionError):
-        solver = RBAPINN(model=model, problem=problem, gamma=1.5)
     solver = RBAPINN(model=model, problem=problem, eta=eta, gamma=gamma)
 
+    with pytest.raises(ValueError):
+        solver = RBAPINN(model=model, problem=problem, gamma=1.5)
+
+    with pytest.raises(ValueError):
+        solver = RBAPINN(model=model, problem=problem, eta=-0.1)
+
     assert solver.accepted_conditions_types == (
         InputTargetCondition,
         InputEquationCondition,
@@ -54,30 +58,18 @@ def test_constructor(problem, eta, gamma):
 
 
 @pytest.mark.parametrize("problem", [problem, inverse_problem])
-def test_wrong_batch(problem):
-    with pytest.raises(NotImplementedError):
-        solver = RBAPINN(model=model, problem=problem)
-        trainer = Trainer(
-            solver=solver,
-            max_epochs=2,
-            accelerator="cpu",
-            batch_size=10,
-            train_size=1.0,
-            val_size=0.0,
-            test_size=0.0,
-        )
-        trainer.train()
-
-
-@pytest.mark.parametrize("problem", [problem, inverse_problem])
+@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
 @pytest.mark.parametrize("compile", [True, False])
-def test_solver_train(problem, compile):
-    solver = RBAPINN(model=model, problem=problem)
+@pytest.mark.parametrize(
+    "loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
+)
+def test_solver_train(problem, batch_size, loss, compile):
+    solver = RBAPINN(model=model, problem=problem, loss=loss)
     trainer = Trainer(
         solver=solver,
         max_epochs=2,
         accelerator="cpu",
-        batch_size=None,
+        batch_size=batch_size,
         train_size=1.0,
         val_size=0.0,
         test_size=0.0,
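The loss parametrization above exercises the new deferred-reduction path: the constructor remembers the user's reduction, swaps in a `reduction="none"` copy of the same loss class, and `_apply_reduction` re-applies the original reduction only after the pointwise weights have been multiplied in. A small standalone sketch of that mechanism (not library code):

import torch

user_loss = torch.nn.MSELoss(reduction="sum")
reduction = user_loss.reduction                # remembered: "sum"
pointwise_loss = type(user_loss)(reduction="none")

pred, target = torch.randn(5, 1), torch.zeros(5, 1)
per_point = pointwise_loss(pred, target)       # shape (5, 1), not aggregated

# ...residual-based weights would multiply per_point here...

# deferred reduction, as in _apply_reduction
loss = per_point.mean() if reduction == "mean" else per_point.sum()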
@@ -89,14 +81,18 @@ def test_solver_train(problem, compile):
 
 
 @pytest.mark.parametrize("problem", [problem, inverse_problem])
+@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
 @pytest.mark.parametrize("compile", [True, False])
-def test_solver_validation(problem, compile):
-    solver = RBAPINN(model=model, problem=problem)
+@pytest.mark.parametrize(
+    "loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
+)
+def test_solver_validation(problem, batch_size, loss, compile):
+    solver = RBAPINN(model=model, problem=problem, loss=loss)
     trainer = Trainer(
         solver=solver,
         max_epochs=2,
         accelerator="cpu",
-        batch_size=None,
+        batch_size=batch_size,
         train_size=0.9,
         val_size=0.1,
         test_size=0.0,
@@ -108,14 +104,18 @@ def test_solver_validation(problem, compile):
 
 
 @pytest.mark.parametrize("problem", [problem, inverse_problem])
+@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
 @pytest.mark.parametrize("compile", [True, False])
-def test_solver_test(problem, compile):
-    solver = RBAPINN(model=model, problem=problem)
+@pytest.mark.parametrize(
+    "loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
+)
+def test_solver_test(problem, batch_size, loss, compile):
+    solver = RBAPINN(model=model, problem=problem, loss=loss)
     trainer = Trainer(
         solver=solver,
         max_epochs=2,
         accelerator="cpu",
-        batch_size=None,
+        batch_size=batch_size,
        train_size=0.7,
         val_size=0.2,
         test_size=0.1,
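End to end, the commit enables mini-batch training, which `on_train_start` previously rejected. A usage sketch under the same setup as the tests (`model` and `problem` as defined at the top of the test module; the import paths are assumptions based on PINA's public API):

from pina import Trainer                # assumed import path
from pina.solver import RBAPINN         # assumed import path

solver = RBAPINN(model=model, problem=problem, eta=0.001, gamma=0.999)
trainer = Trainer(
    solver=solver,
    max_epochs=2,
    accelerator="cpu",
    batch_size=5,  # finite batch sizes now supported (was: None only)
    train_size=1.0,
    val_size=0.0,
    test_size=0.0,
)
trainer.train()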