Formatting
* Adding black as dev dependency * Formatting pina code * Formatting tests
This commit is contained in:
committed by
Nicola Demo
parent
4c4482b155
commit
42ab1a666b
@@ -1,4 +1,4 @@
|
||||
""" Module for Causal PINN. """
|
||||
"""Module for Causal PINN."""
|
||||
|
||||
import torch
|
||||
|
||||
@@ -67,14 +67,16 @@ class CausalPINN(PINN):
|
||||
:class:`~pina.problem.timedep_problem.TimeDependentProblem` class.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
problem,
|
||||
model,
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
weighting=None,
|
||||
loss=None,
|
||||
eps=100):
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
weighting=None,
|
||||
loss=None,
|
||||
eps=100,
|
||||
):
|
||||
"""
|
||||
:param torch.nn.Module model: The neural network model to use.
|
||||
:param AbstractProblem problem: The formulation of the problem.
|
||||
@@ -88,12 +90,14 @@ class CausalPINN(PINN):
|
||||
default `None`.
|
||||
:param float eps: The exponential decay parameter; default `100`.
|
||||
"""
|
||||
super().__init__(model=model,
|
||||
problem=problem,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
weighting=weighting,
|
||||
loss=loss)
|
||||
super().__init__(
|
||||
model=model,
|
||||
problem=problem,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
weighting=weighting,
|
||||
loss=loss,
|
||||
)
|
||||
|
||||
# checking consistency
|
||||
check_consistency(eps, (int, float))
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
""" Module for Competitive PINN. """
|
||||
"""Module for Competitive PINN."""
|
||||
|
||||
import torch
|
||||
import copy
|
||||
@@ -55,16 +55,18 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
``extra_feature``.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
problem,
|
||||
model,
|
||||
discriminator=None,
|
||||
optimizer_model=None,
|
||||
optimizer_discriminator=None,
|
||||
scheduler_model=None,
|
||||
scheduler_discriminator=None,
|
||||
weighting=None,
|
||||
loss=None):
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
discriminator=None,
|
||||
optimizer_model=None,
|
||||
optimizer_discriminator=None,
|
||||
scheduler_model=None,
|
||||
scheduler_discriminator=None,
|
||||
weighting=None,
|
||||
loss=None,
|
||||
):
|
||||
"""
|
||||
:param AbstractProblem problem: The formulation of the problem.
|
||||
:param torch.nn.Module model: The neural network model to use
|
||||
@@ -72,13 +74,13 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
:param torch.nn.Module discriminator: The neural network model to use
|
||||
for the discriminator. If ``None``, the discriminator network will
|
||||
have the same architecture as the model network.
|
||||
:param torch.optim.Optimizer optimizer_model: The neural network
|
||||
:param torch.optim.Optimizer optimizer_model: The neural network
|
||||
optimizer to use for the model network; default `None`.
|
||||
:param torch.optim.Optimizer optimizer_discriminator: The neural network
|
||||
optimizer to use for the discriminator network; default `None`.
|
||||
:param torch.optim.LRScheduler scheduler_model: Learning rate scheduler
|
||||
:param torch.optim.LRScheduler scheduler_model: Learning rate scheduler
|
||||
for the model; default `None`.
|
||||
:param torch.optim.LRScheduler scheduler_discriminator: Learning rate
|
||||
:param torch.optim.LRScheduler scheduler_discriminator: Learning rate
|
||||
scheduler for the discriminator; default `None`.
|
||||
:param WeightingInterface weighting: The weighting schema to use;
|
||||
default `None`.
|
||||
@@ -88,12 +90,14 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
if discriminator is None:
|
||||
discriminator = copy.deepcopy(model)
|
||||
|
||||
super().__init__(models=[model, discriminator],
|
||||
problem=problem,
|
||||
optimizers=[optimizer_model, optimizer_discriminator],
|
||||
schedulers=[scheduler_model, scheduler_discriminator],
|
||||
weighting=weighting,
|
||||
loss=loss)
|
||||
super().__init__(
|
||||
models=[model, discriminator],
|
||||
problem=problem,
|
||||
optimizers=[optimizer_model, optimizer_discriminator],
|
||||
schedulers=[scheduler_model, scheduler_discriminator],
|
||||
weighting=weighting,
|
||||
loss=loss,
|
||||
)
|
||||
|
||||
# Set automatic optimization to False
|
||||
self.automatic_optimization = False
|
||||
@@ -158,7 +162,7 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
|
||||
def loss_data(self, input_pts, output_pts):
|
||||
"""
|
||||
The data loss for the CompetitivePINN solver. It computes the loss
|
||||
The data loss for the CompetitivePINN solver. It computes the loss
|
||||
between the network output against the true solution.
|
||||
|
||||
:param LabelTensor input_tensor: The input to the neural networks.
|
||||
@@ -167,7 +171,7 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
:return: The computed data loss.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
loss_val = (super().loss_data(input_pts, output_pts))
|
||||
loss_val = super().loss_data(input_pts, output_pts)
|
||||
# prepare for optimizer step called in training step
|
||||
loss_val.backward()
|
||||
return loss_val
|
||||
@@ -195,10 +199,14 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
self.scheduler_model.hook(self.optimizer_model)
|
||||
self.scheduler_discriminator.hook(self.optimizer_discriminator)
|
||||
return (
|
||||
[self.optimizer_model.instance,
|
||||
self.optimizer_discriminator.instance],
|
||||
[self.scheduler_model.instance,
|
||||
self.scheduler_discriminator.instance]
|
||||
[
|
||||
self.optimizer_model.instance,
|
||||
self.optimizer_discriminator.instance,
|
||||
],
|
||||
[
|
||||
self.scheduler_model.instance,
|
||||
self.scheduler_discriminator.instance,
|
||||
],
|
||||
)
|
||||
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx):
|
||||
@@ -216,8 +224,7 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
"""
|
||||
# increase by one the counter of optimization to save loggers
|
||||
(
|
||||
self.trainer.fit_loop.epoch_loop.manual_optimization
|
||||
.optim_step_progress.total.completed
|
||||
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed
|
||||
) += 1
|
||||
|
||||
return super().on_train_batch_end(outputs, batch, batch_idx)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
""" Module for Gradient PINN. """
|
||||
"""Module for Gradient PINN."""
|
||||
|
||||
import torch
|
||||
|
||||
@@ -59,18 +59,20 @@ class GradientPINN(PINN):
|
||||
class.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
problem,
|
||||
model,
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
weighting=None,
|
||||
loss=None):
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
weighting=None,
|
||||
loss=None,
|
||||
):
|
||||
"""
|
||||
:param torch.nn.Module model: The neural network model to use.
|
||||
:param AbstractProblem problem: The formulation of the problem. It must
|
||||
inherit from at least
|
||||
:class:`~pina.problem.spatial_problem.SpatialProblem` to compute
|
||||
:class:`~pina.problem.spatial_problem.SpatialProblem` to compute
|
||||
the gradient of the loss.
|
||||
:param torch.optim.Optimizer optimizer: The neural network optimizer to
|
||||
use; default `None`.
|
||||
@@ -81,12 +83,14 @@ class GradientPINN(PINN):
|
||||
:param torch.nn.Module loss: The loss function to be minimized;
|
||||
default `None`.
|
||||
"""
|
||||
super().__init__(model=model,
|
||||
problem=problem,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
weighting=weighting,
|
||||
loss=loss)
|
||||
super().__init__(
|
||||
model=model,
|
||||
problem=problem,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
weighting=weighting,
|
||||
loss=loss,
|
||||
)
|
||||
|
||||
if not isinstance(self.problem, SpatialProblem):
|
||||
raise ValueError(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
""" Module for Physics Informed Neural Network. """
|
||||
"""Module for Physics Informed Neural Network."""
|
||||
|
||||
import torch
|
||||
|
||||
@@ -48,13 +48,15 @@ class PINN(PINNInterface, SingleSolverInterface):
|
||||
DOI: `10.1038 <https://doi.org/10.1038/s42254-021-00314-5>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
problem,
|
||||
model,
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
weighting=None,
|
||||
loss=None):
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
weighting=None,
|
||||
loss=None,
|
||||
):
|
||||
"""
|
||||
:param torch.nn.Module model: The neural network model to use.
|
||||
:param AbstractProblem problem: The formulation of the problem.
|
||||
@@ -67,12 +69,14 @@ class PINN(PINNInterface, SingleSolverInterface):
|
||||
:param torch.nn.Module loss: The loss function to be minimized;
|
||||
default `None`.
|
||||
"""
|
||||
super().__init__(model=model,
|
||||
problem=problem,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
weighting=weighting,
|
||||
loss=loss)
|
||||
super().__init__(
|
||||
model=model,
|
||||
problem=problem,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
weighting=weighting,
|
||||
loss=loss,
|
||||
)
|
||||
|
||||
def loss_phys(self, samples, equation):
|
||||
"""
|
||||
@@ -112,7 +116,4 @@ class PINN(PINNInterface, SingleSolverInterface):
|
||||
}
|
||||
)
|
||||
self.scheduler.hook(self.optimizer)
|
||||
return (
|
||||
[self.optimizer.instance],
|
||||
[self.scheduler.instance]
|
||||
)
|
||||
return ([self.optimizer.instance], [self.scheduler.instance])
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
""" Module for Physics Informed Neural Network Interface."""
|
||||
"""Module for Physics Informed Neural Network Interface."""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import torch
|
||||
@@ -11,7 +11,7 @@ from ...problem import InverseProblem
|
||||
from ...condition import (
|
||||
InputOutputPointsCondition,
|
||||
InputPointsEquationCondition,
|
||||
DomainEquationCondition
|
||||
DomainEquationCondition,
|
||||
)
|
||||
|
||||
|
||||
@@ -20,22 +20,20 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
Base PINN solver class. This class implements the Solver Interface
|
||||
for Physics Informed Neural Network solver.
|
||||
|
||||
This class can be used to define PINNs with multiple ``optimizers``,
|
||||
This class can be used to define PINNs with multiple ``optimizers``,
|
||||
and/or ``models``.
|
||||
By default it takes :class:`~pina.problem.abstract_problem.AbstractProblem`,
|
||||
so the user can choose what type of problem the implemented solver,
|
||||
inheriting from this class, is designed to solve.
|
||||
"""
|
||||
|
||||
accepted_conditions_types = (
|
||||
InputOutputPointsCondition,
|
||||
InputPointsEquationCondition,
|
||||
DomainEquationCondition
|
||||
DomainEquationCondition,
|
||||
)
|
||||
|
||||
def __init__(self,
|
||||
problem,
|
||||
loss=None,
|
||||
**kwargs):
|
||||
def __init__(self, problem, loss=None, **kwargs):
|
||||
"""
|
||||
:param AbstractProblem problem: A problem definition instance.
|
||||
:param torch.nn.Module loss: The loss function to be minimized,
|
||||
@@ -45,9 +43,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
if loss is None:
|
||||
loss = torch.nn.MSELoss()
|
||||
|
||||
super().__init__(problem=problem,
|
||||
use_lt=True,
|
||||
**kwargs)
|
||||
super().__init__(problem=problem, use_lt=True, **kwargs)
|
||||
|
||||
# check consistency
|
||||
check_consistency(loss, (LossInterface, _Loss), subclass=False)
|
||||
@@ -72,14 +68,14 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
def validation_step(self, batch):
|
||||
losses = self._run_optimization_cycle(batch, self._residual_loss)
|
||||
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
|
||||
self.store_log('val_loss', loss, self.get_batch_size(batch))
|
||||
self.store_log("val_loss", loss, self.get_batch_size(batch))
|
||||
return loss
|
||||
|
||||
@torch.set_grad_enabled(True)
|
||||
def test_step(self, batch):
|
||||
losses = self._run_optimization_cycle(batch, self._residual_loss)
|
||||
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
|
||||
self.store_log('test_loss', loss, self.get_batch_size(batch))
|
||||
self.store_log("test_loss", loss, self.get_batch_size(batch))
|
||||
return loss
|
||||
|
||||
def loss_data(self, input_pts, output_pts):
|
||||
@@ -129,42 +125,38 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
except TypeError:
|
||||
# this occurs when the function has three inputs (inverse problem)
|
||||
residual = equation.residual(
|
||||
samples,
|
||||
self.forward(samples),
|
||||
self._params
|
||||
samples, self.forward(samples), self._params
|
||||
)
|
||||
return residual
|
||||
|
||||
def _residual_loss(self, samples, equation):
|
||||
residuals = self.compute_residual(samples, equation)
|
||||
return self.loss(residuals, torch.zeros_like(residuals))
|
||||
|
||||
|
||||
def _run_optimization_cycle(self, batch, loss_residuals):
|
||||
condition_loss = {}
|
||||
for condition_name, points in batch:
|
||||
self.__metric = condition_name
|
||||
# if equations are passed
|
||||
if 'output_points' not in points:
|
||||
input_pts = points['input_points']
|
||||
if "output_points" not in points:
|
||||
input_pts = points["input_points"]
|
||||
condition = self.problem.conditions[condition_name]
|
||||
loss = loss_residuals(
|
||||
input_pts.requires_grad_(),
|
||||
condition.equation
|
||||
input_pts.requires_grad_(), condition.equation
|
||||
)
|
||||
# if data are passed
|
||||
else:
|
||||
input_pts = points['input_points']
|
||||
output_pts = points['output_points']
|
||||
input_pts = points["input_points"]
|
||||
output_pts = points["output_points"]
|
||||
loss = self.loss_data(
|
||||
input_pts=input_pts.requires_grad_(),
|
||||
output_pts=output_pts
|
||||
input_pts=input_pts.requires_grad_(), output_pts=output_pts
|
||||
)
|
||||
# append loss
|
||||
condition_loss[condition_name] = loss
|
||||
# clamp unknown parameters in InverseProblem (if needed)
|
||||
self._clamp_params()
|
||||
return condition_loss
|
||||
|
||||
|
||||
def _clamp_inverse_problem_params(self):
|
||||
"""
|
||||
Clamps the parameters of the inverse problem
|
||||
@@ -175,14 +167,14 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
self.problem.unknown_parameter_domain.range_[v][0],
|
||||
self.problem.unknown_parameter_domain.range_[v][1],
|
||||
)
|
||||
|
||||
|
||||
@property
|
||||
def loss(self):
|
||||
"""
|
||||
Loss used for training.
|
||||
"""
|
||||
return self._loss
|
||||
|
||||
|
||||
@property
|
||||
def current_condition_name(self):
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
""" Module for Residual-Based Attention PINN. """
|
||||
"""Module for Residual-Based Attention PINN."""
|
||||
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
@@ -66,15 +66,17 @@ class RBAPINN(PINN):
|
||||
j.cma.2024.116805 <https://doi.org/10.1016/j.cma.2024.116805>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
problem,
|
||||
model,
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
weighting=None,
|
||||
loss=None,
|
||||
eta=0.001,
|
||||
gamma=0.999):
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
weighting=None,
|
||||
loss=None,
|
||||
eta=0.001,
|
||||
gamma=0.999,
|
||||
):
|
||||
"""
|
||||
:param torch.nn.Module model: The neural network model to use.
|
||||
:param AbstractProblem problem: The formulation of the problem.
|
||||
@@ -86,17 +88,19 @@ class RBAPINN(PINN):
|
||||
default `None`.
|
||||
:param torch.nn.Module loss: The loss function to be minimized;
|
||||
default `None`.
|
||||
:param float | int eta: The learning rate for the weights of the
|
||||
:param float | int eta: The learning rate for the weights of the
|
||||
residual; default 0.001.
|
||||
:param float gamma: The decay parameter in the update of the weights
|
||||
of the residual. Must be between 0 and 1; default 0.999.
|
||||
"""
|
||||
super().__init__(model=model,
|
||||
problem=problem,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
weighting=weighting,
|
||||
loss=loss)
|
||||
super().__init__(
|
||||
model=model,
|
||||
problem=problem,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
weighting=weighting,
|
||||
loss=loss,
|
||||
)
|
||||
|
||||
# check consistency
|
||||
check_consistency(eta, (float, int))
|
||||
@@ -119,9 +123,11 @@ class RBAPINN(PINN):
|
||||
# for now RBAPINN is implemented only for batch_size = None
|
||||
def on_train_start(self):
|
||||
if self.trainer.batch_size is not None:
|
||||
raise NotImplementedError("RBAPINN only works with full batch "
|
||||
"size, set batch_size=None inside the "
|
||||
"Trainer to use the solver.")
|
||||
raise NotImplementedError(
|
||||
"RBAPINN only works with full batch "
|
||||
"size, set batch_size=None inside the "
|
||||
"Trainer to use the solver."
|
||||
)
|
||||
return super().on_train_start()
|
||||
|
||||
def _vect_to_scalar(self, loss_value):
|
||||
@@ -160,10 +166,11 @@ class RBAPINN(PINN):
|
||||
cond = self.current_condition_name
|
||||
|
||||
r_norm = (
|
||||
self.eta * torch.abs(residual)
|
||||
self.eta
|
||||
* torch.abs(residual)
|
||||
/ (torch.max(torch.abs(residual)) + 1e-12)
|
||||
)
|
||||
self.weights[cond] = (self.gamma*self.weights[cond] + r_norm).detach()
|
||||
self.weights[cond] = (self.gamma * self.weights[cond] + r_norm).detach()
|
||||
|
||||
loss_value = self._vectorial_loss(
|
||||
torch.zeros_like(residual, requires_grad=True), residual
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
""" Module for Self-Adaptive PINN. """
|
||||
"""Module for Self-Adaptive PINN."""
|
||||
|
||||
import torch
|
||||
from copy import deepcopy
|
||||
@@ -99,25 +99,27 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
|
||||
j.jcp.2022.111722 <https://doi.org/10.1016/j.jcp.2022.111722>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
problem,
|
||||
model,
|
||||
weight_function=torch.nn.Sigmoid(),
|
||||
optimizer_model=None,
|
||||
optimizer_weights=None,
|
||||
scheduler_model=None,
|
||||
scheduler_weights=None,
|
||||
weighting=None,
|
||||
loss=None):
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
weight_function=torch.nn.Sigmoid(),
|
||||
optimizer_model=None,
|
||||
optimizer_weights=None,
|
||||
scheduler_model=None,
|
||||
scheduler_weights=None,
|
||||
weighting=None,
|
||||
loss=None,
|
||||
):
|
||||
"""
|
||||
:param AbstractProblem problem: The formulation of the problem.
|
||||
:param torch.nn.Module model: The neural network model to use for
|
||||
:param torch.nn.Module model: The neural network model to use for
|
||||
the model.
|
||||
:param torch.nn.Module weight_function: The neural network model
|
||||
related to the Self-Adaptive PINN mask; default `torch.nn.Sigmoid()`
|
||||
:param torch.optim.Optimizer optimizer_model: The neural network
|
||||
:param torch.optim.Optimizer optimizer_model: The neural network
|
||||
optimizer to use for the model network; default `None`.
|
||||
:param torch.optim.Optimizer optimizer_weights: The neural network
|
||||
:param torch.optim.Optimizer optimizer_weights: The neural network
|
||||
optimizer to use for mask model; default `None`.
|
||||
:param torch.optim.LRScheduler scheduler_model: Learning rate scheduler
|
||||
for the model; default `None`.
|
||||
@@ -137,12 +139,14 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
|
||||
weights_dict[condition_name] = Weights(weight_function)
|
||||
weights_dict = torch.nn.ModuleDict(weights_dict)
|
||||
|
||||
super().__init__(models=[model, weights_dict],
|
||||
problem=problem,
|
||||
optimizers=[optimizer_model, optimizer_weights],
|
||||
schedulers=[scheduler_model, scheduler_weights],
|
||||
weighting=weighting,
|
||||
loss=loss)
|
||||
super().__init__(
|
||||
models=[model, weights_dict],
|
||||
problem=problem,
|
||||
optimizers=[optimizer_model, optimizer_weights],
|
||||
schedulers=[scheduler_model, scheduler_weights],
|
||||
weighting=weighting,
|
||||
loss=loss,
|
||||
)
|
||||
|
||||
# Set automatic optimization to False
|
||||
self.automatic_optimization = False
|
||||
@@ -202,7 +206,7 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
|
||||
# create a new one by setting requires_grad to True.
|
||||
# In alternative set `retain_graph=True`.
|
||||
samples = samples.detach()
|
||||
samples.requires_grad_()# = True
|
||||
samples.requires_grad_() # = True
|
||||
|
||||
# Train the model
|
||||
weighted_loss = self._loss_phys(samples, equation)
|
||||
@@ -244,20 +248,18 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
|
||||
self.optimizer_weights.hook(self.weights_dict.parameters())
|
||||
if isinstance(self.problem, InverseProblem):
|
||||
self.optimizer_model.instance.add_param_group(
|
||||
{
|
||||
"params": [
|
||||
self._params[var]
|
||||
for var in self.problem.unknown_variables
|
||||
]
|
||||
}
|
||||
)
|
||||
{
|
||||
"params": [
|
||||
self._params[var]
|
||||
for var in self.problem.unknown_variables
|
||||
]
|
||||
}
|
||||
)
|
||||
self.scheduler_model.hook(self.optimizer_model)
|
||||
self.scheduler_weights.hook(self.optimizer_weights)
|
||||
return (
|
||||
[self.optimizer_model.instance,
|
||||
self.optimizer_weights.instance],
|
||||
[self.scheduler_model.instance,
|
||||
self.scheduler_weights.instance]
|
||||
[self.optimizer_model.instance, self.optimizer_weights.instance],
|
||||
[self.scheduler_model.instance, self.scheduler_weights.instance],
|
||||
)
|
||||
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx):
|
||||
@@ -275,8 +277,7 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
|
||||
"""
|
||||
# increase by one the counter of optimization to save loggers
|
||||
(
|
||||
self.trainer.fit_loop.epoch_loop.manual_optimization
|
||||
.optim_step_progress.total.completed
|
||||
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed
|
||||
) += 1
|
||||
|
||||
return super().on_train_batch_end(outputs, batch, batch_idx)
|
||||
@@ -291,19 +292,22 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
|
||||
:rtype: Any
|
||||
"""
|
||||
if self.trainer.batch_size is not None:
|
||||
raise NotImplementedError("SelfAdaptivePINN only works with full "
|
||||
"batch size, set batch_size=None inside "
|
||||
"the Trainer to use the solver.")
|
||||
raise NotImplementedError(
|
||||
"SelfAdaptivePINN only works with full "
|
||||
"batch size, set batch_size=None inside "
|
||||
"the Trainer to use the solver."
|
||||
)
|
||||
device = torch.device(
|
||||
self.trainer._accelerator_connector._accelerator_flag
|
||||
)
|
||||
|
||||
# Initialize the self adaptive weights only for training points
|
||||
for condition_name, tensor in (
|
||||
self.trainer.data_module.train_dataset.input_points.items()
|
||||
):
|
||||
self.weights_dict[condition_name].sa_weights.data = (
|
||||
torch.rand((tensor.shape[0], 1), device=device)
|
||||
for (
|
||||
condition_name,
|
||||
tensor,
|
||||
) in self.trainer.data_module.train_dataset.input_points.items():
|
||||
self.weights_dict[condition_name].sa_weights.data = torch.rand(
|
||||
(tensor.shape[0], 1), device=device
|
||||
)
|
||||
return super().on_train_start()
|
||||
|
||||
@@ -318,11 +322,11 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
|
||||
# First initialize self-adaptive weights with correct shape,
|
||||
# then load the values from the checkpoint.
|
||||
for condition_name, _ in self.problem.input_pts.items():
|
||||
shape = checkpoint['state_dict'][
|
||||
shape = checkpoint["state_dict"][
|
||||
f"_pina_models.1.{condition_name}.sa_weights"
|
||||
].shape
|
||||
self.weights_dict[condition_name].sa_weights.data = (
|
||||
torch.rand(shape)
|
||||
self.weights_dict[condition_name].sa_weights.data = torch.rand(
|
||||
shape
|
||||
)
|
||||
return super().on_load_checkpoint(checkpoint)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user