Tutorials v0.1 (#178)

Tutorial update and small fixes

* Update tutorials and add an FNO tutorial
* Create a MetricTracker callback
* Update PINN logging
* Update Plotter for loss plotting
* Small fix to LabelTensor
* Small fix to FNO

---------

Co-authored-by: Dario Coscia <dariocoscia@cli-10-110-13-250.WIFIeduroamSTUD.units.it>
Co-authored-by: Dario Coscia <dariocoscia@dhcp-176.eduroam.sissa.it>
Dario Coscia
2023-09-26 17:29:37 +02:00
committed by Nicola Demo
parent 939353f517
commit a9b1bd2826
45 changed files with 2760 additions and 1321 deletions

pina/callbacks/__init__.py
View File

@@ -1,7 +1,9 @@
 __all__ = [
     'SwitchOptimizer',
     'R3Refinement',
+    'MetricTracker'
 ]

 from .optimizer_callbacks import SwitchOptimizer
 from .adaptive_refinment_callbacks import R3Refinement
+from .processing_callbacks import MetricTracker

pina/callbacks/processing_callbacks.py
View File

@@ -0,0 +1,25 @@
'''PINA Callbacks Implementations'''

from lightning.pytorch.callbacks import Callback
import torch
import copy


class MetricTracker(Callback):
    """
    PINA implementation of a Lightning Callback to track relevant
    metrics during training.
    """

    def __init__(self):
        self._collection = []

    def on_train_epoch_end(self, trainer, __):
        # snapshot the metrics logged during this epoch
        self._collection.append(copy.deepcopy(trainer.logged_metrics))

    @property
    def metrics(self):
        # keep only the metrics logged at every epoch, stacked per epoch
        common_keys = set.intersection(*map(set, self._collection))
        v = {k: torch.stack([dic[k] for dic in self._collection])
             for k in common_keys}
        return v
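For context, a minimal usage sketch of the new callback; the Trainer import path, its constructor arguments, and the solver variable are assumptions based on the usual PINA and Lightning conventions, not part of this diff:

from pina import Trainer
from pina.callbacks import MetricTracker

# 'solver' is a placeholder for any PINA solver (e.g. a PINN instance)
tracker = MetricTracker()
trainer = Trainer(solver=solver, callbacks=[tracker], max_epochs=100)
trainer.train()

# dict mapping each metric logged at every epoch to a stacked tensor of
# per-epoch values, e.g. tracker.metrics['mean_loss']
print(tracker.metrics)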

pina/label_tensor.py
View File

@@ -63,7 +63,7 @@ class LabelTensor(torch.Tensor):
         if isinstance(labels, str):
             labels = [labels]

-        if len(labels) != x.shape[1]:
+        if len(labels) != x.shape[-1]:
             raise ValueError(
                 'the tensor has not the same number of columns of '
                 'the passed labels.'
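A quick sketch of what the relaxed check permits at construction time, assuming the usual LabelTensor(tensor, labels) constructor: labels are now validated against the last dimension, so higher-rank tensors pass too.

import torch
from pina import LabelTensor

# 2D case, unchanged behaviour: two columns, two labels
pts = LabelTensor(torch.rand(10, 2), labels=['x', 'y'])

# 3D case, enabled by this fix: labels match the last dimension
batched = LabelTensor(torch.rand(5, 10, 2), labels=['x', 'y'])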

pina/model/fno.py
View File

@@ -94,7 +94,9 @@ class FNO(torch.nn.Module):
         # 4. Build the FNO network
         tmp_layers = layers.copy()
-        out_feats = lifting_net(torch.rand(10, dimensions)).shape[-1]
+        first_parameter = next(lifting_net.parameters())
+        input_shape = first_parameter.size()
+        out_feats = lifting_net(torch.rand(size=input_shape)).shape[-1]
         tmp_layers.insert(0, out_feats)

         self._layers = []
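The replaced probe no longer assumes the lifting net takes a flat (batch, dimensions) input; it instead builds a random tensor shaped like the net's first parameter. A standalone illustration of the trick (the toy lifting_net here is hypothetical, not from this commit):

import torch

lifting_net = torch.nn.Sequential(torch.nn.Linear(3, 24))
first_parameter = next(lifting_net.parameters())  # Linear weight, shape (24, 3)
probe = torch.rand(size=first_parameter.size())   # last dim matches in_features
out_feats = lifting_net(probe).shape[-1]          # -> 24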

pina/plotter.py
View File

@@ -1,6 +1,7 @@
 """ Module for plotting. """

 import matplotlib.pyplot as plt
 import torch
+from pina.callbacks import MetricTracker
 from pina import LabelTensor
@@ -129,12 +130,12 @@ class Plotter:
             *grids, pred_output.cpu().detach(), **kwargs)
         fig.colorbar(cb, ax=ax)

-    def plot(self, solver, components=None, fixed_variables={}, method='contourf',
+    def plot(self, trainer, components=None, fixed_variables={}, method='contourf',
              res=256, filename=None, **kwargs):
         """
         Plot sample of SolverInterface output.

-        :param SolverInterface solver: the SolverInterface object.
+        :param Trainer trainer: the Trainer object.
         :param list(str) components: the output variable to plot. If None, all
             the output variables of the problem are selected. Default value is
             None.
@@ -149,6 +150,7 @@ class Plotter:
         :param str filename: the file name to save the plot. If None, the plot
             is shown using the set matplotlib frontend. Default is None.
         """
+        solver = trainer.solver
         if components is None:
             components = [solver.problem.output_variables]
         v = [
@@ -186,25 +188,38 @@ class Plotter:
         else:
             plt.show()

-    # TODO loss
-    # def plot_loss(self, solver, label=None, log_scale=True):
-    #     """
-    #     Plot the loss function values during traininig.
+    def plot_loss(self, trainer, metric=None, label=None, log_scale=True):
+        """
+        Plot the loss function values during training.

-    #     :param SolverInterface solver: the SolverInterface object.
-    #     :param str label: the label to use in the legend, defaults to None.
-    #     :param bool log_scale: If True, the y axis is in log scale. Default is
-    #         True.
-    #     """
+        :param Trainer trainer: the Trainer object.
+        :param str metric: the metric to use in the y axis.
+        :param str label: the label to use in the legend, defaults to None.
+        :param bool log_scale: If True, the y axis is in log scale. Default is
+            True.
+        """
-    #     if not label:
-    #         label = str(solver)
+        # check that MetricTracker has been used
+        list_ = [idx for idx, s in enumerate(trainer.callbacks) if isinstance(s, MetricTracker)]
+        if not bool(list_):
+            raise FileNotFoundError('MetricTracker should be used as a callback during training to'
+                                    ' use this method.')

-    #     epochs = list(solver.history_loss.keys())
-    #     loss = np.array(list(solver.history_loss.values()))
-    #     if loss.ndim != 1:
-    #         loss = loss[:, 0]
+        metrics = trainer.callbacks[list_[0]].metrics

-    #     plt.plot(epochs, loss, label=label)
-    #     if log_scale:
-    #         plt.yscale('log')
+        if not metric:
+            metric = 'mean_loss'
+        loss = metrics[metric]
+        epochs = range(len(loss))
+        if label is not None:
+            plt.plot(epochs, loss, label=label)
+            plt.legend()
+        else:
+            plt.plot(epochs, loss)
+        if log_scale:
+            plt.yscale('log')
+        plt.xlabel('epoch')
+        plt.ylabel(metric)
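A hedged usage sketch of the new method, assuming Plotter is importable from the top-level pina package and that training already ran with a MetricTracker attached:

from pina import Plotter

# 'trainer' is the Trainer used for training, with MetricTracker attached
plotter = Plotter()
plotter.plot_loss(trainer, metric='mean_loss', label='PINN', log_scale=True)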

pina/solvers/__init__.py
View File

@@ -5,3 +5,4 @@ __all__ = [
 from .garom import GAROM
 from .pinn import PINN
+from .supervised import SupervisedSolver

pina/solvers/pinn.py
View File

@@ -109,12 +109,14 @@ class PINN(SolverInterface):
         """
         condition_losses = []
+        condition_names = []

         for condition_name, samples in batch.items():

             if condition_name not in self.problem.conditions:
                 raise RuntimeError('Something wrong happened.')

+            condition_names.append(condition_name)
             condition = self.problem.conditions[condition_name]

             # PINN loss: equation evaluated on location or input_points
@@ -132,9 +134,9 @@ class PINN(SolverInterface):
         # we need to pass it as a torch tensor to make everything work
         total_loss = sum(condition_losses)

-        self.log('mean_loss', float(total_loss / len(condition_losses)), prog_bar=True, logger=False)
-        for condition_loss, loss in zip(self.problem.conditions, condition_losses):
-            self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=False)
+        self.log('mean_loss', float(total_loss / len(condition_losses)), prog_bar=True, logger=True)
+        for condition_name, loss in zip(condition_names, condition_losses):
+            self.log(condition_name + '_loss', float(loss), prog_bar=True, logger=True)

         return total_loss

     @property
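Flipping logger=False to logger=True is what routes these values into trainer.logged_metrics, the dictionary MetricTracker snapshots at the end of each epoch. A hypothetical check, reusing the tracker from the sketch above and assuming the problem defines a condition named 'gamma1':

# every metric logged with logger=True shows up in the tracker
assert 'mean_loss' in tracker.metrics
assert 'gamma1_loss' in tracker.metrics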

pina/solvers/supervised.py
View File

@@ -0,0 +1,134 @@
""" Module for SupervisedSolver """

import torch
try:
    from torch.optim.lr_scheduler import LRScheduler  # torch >= 2.0
except ImportError:
    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler  # torch < 2.0
from torch.optim.lr_scheduler import ConstantLR

from .solver import SolverInterface
from ..label_tensor import LabelTensor
from ..utils import check_consistency
from ..loss import LossInterface
from torch.nn.modules.loss import _Loss


class SupervisedSolver(SolverInterface):
    """
    SupervisedSolver solver class. This class implements a SupervisedSolver,
    using a user specified ``model`` to solve a specific ``problem``.
    """

    def __init__(self,
                 problem,
                 model,
                 extra_features=None,
                 loss=torch.nn.MSELoss(),
                 optimizer=torch.optim.Adam,
                 optimizer_kwargs={'lr': 0.001},
                 scheduler=ConstantLR,
                 scheduler_kwargs={"factor": 1, "total_iters": 0},
                 ):
        '''
        :param AbstractProblem problem: The formulation of the problem.
        :param torch.nn.Module model: The neural network model to use.
        :param torch.nn.Module loss: The loss function used as minimizer,
            default torch.nn.MSELoss().
        :param torch.nn.Module extra_features: The additional input
            features to use as augmented input.
        :param torch.optim.Optimizer optimizer: The neural network optimizer to
            use; default is `torch.optim.Adam`.
        :param dict optimizer_kwargs: Optimizer constructor keyword args;
            the default learning rate is 0.001.
        :param torch.optim.LRScheduler scheduler: Learning
            rate scheduler.
        :param dict scheduler_kwargs: LR scheduler constructor keyword args.
        '''
        super().__init__(models=[model],
                         problem=problem,
                         optimizers=[optimizer],
                         optimizers_kwargs=[optimizer_kwargs],
                         extra_features=extra_features)

        # check consistency
        check_consistency(scheduler, LRScheduler, subclass=True)
        check_consistency(scheduler_kwargs, dict)
        check_consistency(loss, (LossInterface, _Loss), subclass=False)

        # assign variables
        self._scheduler = scheduler(self.optimizers[0], **scheduler_kwargs)
        self._loss = loss
        self._neural_net = self.models[0]

    def forward(self, x):
        """Forward pass implementation for the solver.

        :param torch.Tensor x: Input data.
        :return: Solver solution.
        :rtype: torch.Tensor
        """
        # extract labels
        x = x.extract(self.problem.input_variables)
        # perform forward pass
        output = self.neural_net(x).as_subclass(LabelTensor)
        # set the labels
        output.labels = self.problem.output_variables
        return output

    def configure_optimizers(self):
        """Optimizer configuration for the solver.

        :return: The optimizers and the schedulers
        :rtype: tuple(list, list)
        """
        return self.optimizers, [self.scheduler]

    def training_step(self, batch, batch_idx):
        """Solver training step.

        :param batch: The batch element in the dataloader.
        :type batch: tuple
        :param batch_idx: The batch index.
        :type batch_idx: int
        :return: The sum of the loss functions.
        :rtype: LabelTensor
        """
        for condition_name, samples in batch.items():

            if condition_name not in self.problem.conditions:
                raise RuntimeError('Something wrong happened.')

            condition = self.problem.conditions[condition_name]

            # data loss
            if hasattr(condition, 'output_points'):
                input_pts, output_pts = samples
                loss = self.loss(self.forward(input_pts), output_pts) * condition.data_weight
            else:
                raise RuntimeError('Supervised solver works only in data-driven mode.')

        self.log('mean_loss', float(loss), prog_bar=True, logger=True)
        return loss

    @property
    def scheduler(self):
        """
        Scheduler for training.
        """
        return self._scheduler

    @property
    def neural_net(self):
        """
        Neural network for training.
        """
        return self._neural_net

    @property
    def loss(self):
        """
        Loss for training.
        """
        return self._loss
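Finally, a hedged instantiation sketch for the new solver; the problem object and the small network are placeholders, since this file defines only the solver class:

import torch
from pina.solvers import SupervisedSolver

net = torch.nn.Sequential(torch.nn.Linear(1, 20),
                          torch.nn.Tanh(),
                          torch.nn.Linear(20, 1))

# 'problem' stands for any data-driven AbstractProblem with output points
solver = SupervisedSolver(problem=problem,
                          model=net,
                          loss=torch.nn.MSELoss(),
                          optimizer=torch.optim.Adam,
                          optimizer_kwargs={'lr': 0.001})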