Lightining update (#104)

* multiple functions for version 0.0
* lightining update
* minor changes
* data pinn  loss added
---------

Co-authored-by: Nicola Demo <demo.nicola@gmail.com>
Co-authored-by: Dario Coscia <dariocoscia@cli-10-110-3-125.WIFIeduroamSTUD.units.it>
Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.station>
Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
Co-authored-by: Dario Coscia <dariocoscia@192.168.1.38>
This commit is contained in:
Dario Coscia
2023-06-07 15:34:43 +02:00
committed by Nicola Demo
parent 0e3625de80
commit 63fd068988
16 changed files with 710 additions and 603 deletions

View File

@@ -11,11 +11,11 @@ class Plotter:
Implementation of a plotter class, for easy visualizations.
"""
def plot_samples(self, pinn, variables=None):
def plot_samples(self, solver, variables=None):
"""
Plot a sample of solution.
Plot the training grid samples.
:param PINN pinn: the PINN object.
:param SolverInterface solver: the SolverInterface object.
:param list(str) variables: variables to plot. If None, all variables
are plotted. If 'spatial', only spatial variables are plotted. If
'temporal', only temporal variables are plotted. Defaults to None.
@@ -26,15 +26,15 @@ class Plotter:
:Example:
>>> plotter = Plotter()
>>> plotter.plot_samples(pinn=pinn, variables='spatial')
>>> plotter.plot_samples(solver=solver, variables='spatial')
"""
if variables is None:
variables = pinn.problem.domain.variables
variables = solver.problem.domain.variables
elif variables == 'spatial':
variables = pinn.problem.spatial_domain.variables
variables = solver.problem.spatial_domain.variables
elif variables == 'temporal':
variables = pinn.problem.temporal_domain.variables
variables = solver.problem.temporal_domain.variables
if len(variables) not in [1, 2, 3]:
raise ValueError
@@ -42,8 +42,8 @@ class Plotter:
fig = plt.figure()
proj = '3d' if len(variables) == 3 else None
ax = fig.add_subplot(projection=proj)
for location in pinn.input_pts:
coords = pinn.input_pts[location].extract(variables).T.detach()
for location in solver.problem.input_pts:
coords = solver.problem.input_pts[location].extract(variables).T.detach()
if coords.shape[0] == 1: # 1D samples
ax.plot(coords[0], torch.zeros(coords[0].shape), '.',
label=location)
@@ -69,7 +69,7 @@ class Plotter:
:param pts: Points to plot the solution.
:type pts: torch.Tensor
:param pred: PINN solution evaluated at 'pts'.
:param pred: SolverInterface solution evaluated at 'pts'.
:type pred: torch.Tensor
:param method: not used, kept for code compatibility
:type method: None
@@ -95,7 +95,7 @@ class Plotter:
:param pts: Points to plot the solution.
:type pts: torch.Tensor
:param pred: PINN solution evaluated at 'pts'.
:param pred: SolverInterface solution evaluated at 'pts'.
:type pred: torch.Tensor
:param method: matplotlib method to plot 2-dimensional data,
see https://matplotlib.org/stable/api/axes_api.html for
@@ -129,12 +129,12 @@ class Plotter:
*grids, pred_output.cpu().detach(), **kwargs)
fig.colorbar(cb, ax=ax)
def plot(self, pinn, components=None, fixed_variables={}, method='contourf',
def plot(self, solver, components=None, fixed_variables={}, method='contourf',
res=256, filename=None, **kwargs):
"""
Plot sample of PINN output.
Plot sample of SolverInterface output.
:param PINN pinn: the PINN object.
:param SolverInterface solver: the SolverInterface object.
:param list(str) components: the output variable to plot. If None, all
the output variables of the problem are selected. Default value is
None.
@@ -150,12 +150,12 @@ class Plotter:
is shown using the setted matplotlib frontend. Default is None.
"""
if components is None:
components = [pinn.problem.output_variables]
components = [solver.problem.output_variables]
v = [
var for var in pinn.problem.input_variables
var for var in solver.problem.input_variables
if var not in fixed_variables.keys()
]
pts = pinn.problem.domain.sample(res, 'grid', variables=v)
pts = solver.problem.domain.sample(res, 'grid', variables=v)
fixed_pts = torch.ones(pts.shape[0], len(fixed_variables))
fixed_pts *= torch.tensor(list(fixed_variables.values()))
@@ -163,15 +163,15 @@ class Plotter:
fixed_pts.labels = list(fixed_variables.keys())
pts = pts.append(fixed_pts)
pts = pts.to(device=pinn.device)
pts = pts.to(device=solver.device)
predicted_output = pinn.model(pts)
predicted_output = solver.forward(pts)
if isinstance(components, str):
predicted_output = predicted_output.extract(components)
elif callable(components):
predicted_output = components(predicted_output)
truth_solution = getattr(pinn.problem, 'truth_solution', None)
truth_solution = getattr(solver.problem, 'truth_solution', None)
if len(v) == 1:
self._1d_plot(pts, predicted_output, method, truth_solution,
**kwargs)
@@ -186,37 +186,25 @@ class Plotter:
else:
plt.show()
def plot_loss(self, pinn, label=None, log_scale=True, filename=None):
"""
Plot the loss function values during traininig.
# TODO loss
# def plot_loss(self, solver, label=None, log_scale=True):
# """
# Plot the loss function values during traininig.
:param PINN pinn: the PINN object.
:param str label: the label to use in the legend, defaults to None.
:param bool log_scale: If True, the y axis is in log scale. Default is
True.
:param str filename: the file name to save the plot. If None, the plot
is not saved. Default is None.
"""
# :param SolverInterface solver: the SolverInterface object.
# :param str label: the label to use in the legend, defaults to None.
# :param bool log_scale: If True, the y axis is in log scale. Default is
# True.
# """
if not label:
label = str(pinn)
# if not label:
# label = str(solver)
epochs = list(pinn.history_loss.keys())
loss = np.array(list(pinn.history_loss.values()))
# epochs = list(solver.history_loss.keys())
# loss = np.array(list(solver.history_loss.values()))
# if loss.ndim != 1:
# loss = loss[:, 0]
# if multiple outputs, sum the loss
if loss.ndim != 1:
loss = np.sum(loss, axis=1)
# plot loss
plt.plot(epochs, loss, label=label)
plt.legend()
if log_scale:
plt.yscale('log')
plt.title('Loss function')
plt.xlabel('Epochs')
plt.ylabel('Loss')
# save plot
if filename:
plt.savefig(filename)
# plt.plot(epochs, loss, label=label)
# if log_scale:
# plt.yscale('log')