Lightining update (#104)
* multiple functions for version 0.0 * lightining update * minor changes * data pinn loss added --------- Co-authored-by: Nicola Demo <demo.nicola@gmail.com> Co-authored-by: Dario Coscia <dariocoscia@cli-10-110-3-125.WIFIeduroamSTUD.units.it> Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.station> Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local> Co-authored-by: Dario Coscia <dariocoscia@192.168.1.38>
This commit is contained in:
committed by
Nicola Demo
parent
0e3625de80
commit
63fd068988
@@ -11,11 +11,11 @@ class Plotter:
|
||||
Implementation of a plotter class, for easy visualizations.
|
||||
"""
|
||||
|
||||
def plot_samples(self, pinn, variables=None):
|
||||
def plot_samples(self, solver, variables=None):
|
||||
"""
|
||||
Plot a sample of solution.
|
||||
Plot the training grid samples.
|
||||
|
||||
:param PINN pinn: the PINN object.
|
||||
:param SolverInterface solver: the SolverInterface object.
|
||||
:param list(str) variables: variables to plot. If None, all variables
|
||||
are plotted. If 'spatial', only spatial variables are plotted. If
|
||||
'temporal', only temporal variables are plotted. Defaults to None.
|
||||
@@ -26,15 +26,15 @@ class Plotter:
|
||||
|
||||
:Example:
|
||||
>>> plotter = Plotter()
|
||||
>>> plotter.plot_samples(pinn=pinn, variables='spatial')
|
||||
>>> plotter.plot_samples(solver=solver, variables='spatial')
|
||||
"""
|
||||
|
||||
if variables is None:
|
||||
variables = pinn.problem.domain.variables
|
||||
variables = solver.problem.domain.variables
|
||||
elif variables == 'spatial':
|
||||
variables = pinn.problem.spatial_domain.variables
|
||||
variables = solver.problem.spatial_domain.variables
|
||||
elif variables == 'temporal':
|
||||
variables = pinn.problem.temporal_domain.variables
|
||||
variables = solver.problem.temporal_domain.variables
|
||||
|
||||
if len(variables) not in [1, 2, 3]:
|
||||
raise ValueError
|
||||
@@ -42,8 +42,8 @@ class Plotter:
|
||||
fig = plt.figure()
|
||||
proj = '3d' if len(variables) == 3 else None
|
||||
ax = fig.add_subplot(projection=proj)
|
||||
for location in pinn.input_pts:
|
||||
coords = pinn.input_pts[location].extract(variables).T.detach()
|
||||
for location in solver.problem.input_pts:
|
||||
coords = solver.problem.input_pts[location].extract(variables).T.detach()
|
||||
if coords.shape[0] == 1: # 1D samples
|
||||
ax.plot(coords[0], torch.zeros(coords[0].shape), '.',
|
||||
label=location)
|
||||
@@ -69,7 +69,7 @@ class Plotter:
|
||||
|
||||
:param pts: Points to plot the solution.
|
||||
:type pts: torch.Tensor
|
||||
:param pred: PINN solution evaluated at 'pts'.
|
||||
:param pred: SolverInterface solution evaluated at 'pts'.
|
||||
:type pred: torch.Tensor
|
||||
:param method: not used, kept for code compatibility
|
||||
:type method: None
|
||||
@@ -95,7 +95,7 @@ class Plotter:
|
||||
|
||||
:param pts: Points to plot the solution.
|
||||
:type pts: torch.Tensor
|
||||
:param pred: PINN solution evaluated at 'pts'.
|
||||
:param pred: SolverInterface solution evaluated at 'pts'.
|
||||
:type pred: torch.Tensor
|
||||
:param method: matplotlib method to plot 2-dimensional data,
|
||||
see https://matplotlib.org/stable/api/axes_api.html for
|
||||
@@ -129,12 +129,12 @@ class Plotter:
|
||||
*grids, pred_output.cpu().detach(), **kwargs)
|
||||
fig.colorbar(cb, ax=ax)
|
||||
|
||||
def plot(self, pinn, components=None, fixed_variables={}, method='contourf',
|
||||
def plot(self, solver, components=None, fixed_variables={}, method='contourf',
|
||||
res=256, filename=None, **kwargs):
|
||||
"""
|
||||
Plot sample of PINN output.
|
||||
Plot sample of SolverInterface output.
|
||||
|
||||
:param PINN pinn: the PINN object.
|
||||
:param SolverInterface solver: the SolverInterface object.
|
||||
:param list(str) components: the output variable to plot. If None, all
|
||||
the output variables of the problem are selected. Default value is
|
||||
None.
|
||||
@@ -150,12 +150,12 @@ class Plotter:
|
||||
is shown using the setted matplotlib frontend. Default is None.
|
||||
"""
|
||||
if components is None:
|
||||
components = [pinn.problem.output_variables]
|
||||
components = [solver.problem.output_variables]
|
||||
v = [
|
||||
var for var in pinn.problem.input_variables
|
||||
var for var in solver.problem.input_variables
|
||||
if var not in fixed_variables.keys()
|
||||
]
|
||||
pts = pinn.problem.domain.sample(res, 'grid', variables=v)
|
||||
pts = solver.problem.domain.sample(res, 'grid', variables=v)
|
||||
|
||||
fixed_pts = torch.ones(pts.shape[0], len(fixed_variables))
|
||||
fixed_pts *= torch.tensor(list(fixed_variables.values()))
|
||||
@@ -163,15 +163,15 @@ class Plotter:
|
||||
fixed_pts.labels = list(fixed_variables.keys())
|
||||
|
||||
pts = pts.append(fixed_pts)
|
||||
pts = pts.to(device=pinn.device)
|
||||
pts = pts.to(device=solver.device)
|
||||
|
||||
predicted_output = pinn.model(pts)
|
||||
predicted_output = solver.forward(pts)
|
||||
if isinstance(components, str):
|
||||
predicted_output = predicted_output.extract(components)
|
||||
elif callable(components):
|
||||
predicted_output = components(predicted_output)
|
||||
|
||||
truth_solution = getattr(pinn.problem, 'truth_solution', None)
|
||||
truth_solution = getattr(solver.problem, 'truth_solution', None)
|
||||
if len(v) == 1:
|
||||
self._1d_plot(pts, predicted_output, method, truth_solution,
|
||||
**kwargs)
|
||||
@@ -186,37 +186,25 @@ class Plotter:
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
def plot_loss(self, pinn, label=None, log_scale=True, filename=None):
|
||||
"""
|
||||
Plot the loss function values during traininig.
|
||||
# TODO loss
|
||||
# def plot_loss(self, solver, label=None, log_scale=True):
|
||||
# """
|
||||
# Plot the loss function values during traininig.
|
||||
|
||||
:param PINN pinn: the PINN object.
|
||||
:param str label: the label to use in the legend, defaults to None.
|
||||
:param bool log_scale: If True, the y axis is in log scale. Default is
|
||||
True.
|
||||
:param str filename: the file name to save the plot. If None, the plot
|
||||
is not saved. Default is None.
|
||||
"""
|
||||
# :param SolverInterface solver: the SolverInterface object.
|
||||
# :param str label: the label to use in the legend, defaults to None.
|
||||
# :param bool log_scale: If True, the y axis is in log scale. Default is
|
||||
# True.
|
||||
# """
|
||||
|
||||
if not label:
|
||||
label = str(pinn)
|
||||
# if not label:
|
||||
# label = str(solver)
|
||||
|
||||
epochs = list(pinn.history_loss.keys())
|
||||
loss = np.array(list(pinn.history_loss.values()))
|
||||
# epochs = list(solver.history_loss.keys())
|
||||
# loss = np.array(list(solver.history_loss.values()))
|
||||
# if loss.ndim != 1:
|
||||
# loss = loss[:, 0]
|
||||
|
||||
# if multiple outputs, sum the loss
|
||||
if loss.ndim != 1:
|
||||
loss = np.sum(loss, axis=1)
|
||||
|
||||
# plot loss
|
||||
plt.plot(epochs, loss, label=label)
|
||||
plt.legend()
|
||||
if log_scale:
|
||||
plt.yscale('log')
|
||||
plt.title('Loss function')
|
||||
plt.xlabel('Epochs')
|
||||
plt.ylabel('Loss')
|
||||
|
||||
# save plot
|
||||
if filename:
|
||||
plt.savefig(filename)
|
||||
# plt.plot(epochs, loss, label=label)
|
||||
# if log_scale:
|
||||
# plt.yscale('log')
|
||||
|
||||
Reference in New Issue
Block a user