🎨 Format Python code with psf/black
This commit is contained in:
183
pina/plotter.py
183
pina/plotter.py
@@ -34,30 +34,33 @@ class Plotter:
|
||||
|
||||
if variables is None:
|
||||
variables = problem.domain.variables
|
||||
elif variables == 'spatial':
|
||||
elif variables == "spatial":
|
||||
variables = problem.spatial_domain.variables
|
||||
elif variables == 'temporal':
|
||||
elif variables == "temporal":
|
||||
variables = problem.temporal_domain.variables
|
||||
|
||||
if len(variables) not in [1, 2, 3]:
|
||||
raise ValueError('Samples can be plotted only in '
|
||||
'dimensions 1, 2 and 3.')
|
||||
raise ValueError(
|
||||
"Samples can be plotted only in " "dimensions 1, 2 and 3."
|
||||
)
|
||||
|
||||
fig = plt.figure()
|
||||
proj = '3d' if len(variables) == 3 else None
|
||||
proj = "3d" if len(variables) == 3 else None
|
||||
ax = fig.add_subplot(projection=proj)
|
||||
for location in problem.input_pts:
|
||||
coords = problem.input_pts[location].extract(variables).T.detach()
|
||||
if len(variables)==1: # 1D samples
|
||||
ax.plot(coords.flatten(),
|
||||
torch.zeros(coords.flatten().shape),
|
||||
'.',
|
||||
label=location,
|
||||
**kwargs)
|
||||
elif len(variables)==2:
|
||||
ax.plot(*coords, '.', label=location, **kwargs)
|
||||
elif len(variables)==3:
|
||||
ax.scatter(*coords, '.', label=location, **kwargs)
|
||||
if len(variables) == 1: # 1D samples
|
||||
ax.plot(
|
||||
coords.flatten(),
|
||||
torch.zeros(coords.flatten().shape),
|
||||
".",
|
||||
label=location,
|
||||
**kwargs,
|
||||
)
|
||||
elif len(variables) == 2:
|
||||
ax.plot(*coords, ".", label=location, **kwargs)
|
||||
elif len(variables) == 3:
|
||||
ax.scatter(*coords, ".", label=location, **kwargs)
|
||||
|
||||
ax.set_xlabel(variables[0])
|
||||
try:
|
||||
@@ -94,27 +97,23 @@ class Plotter:
|
||||
"""
|
||||
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
|
||||
|
||||
ax.plot(pts.extract(v), pred, label='Neural Network solution', **kwargs)
|
||||
ax.plot(pts.extract(v), pred, label="Neural Network solution", **kwargs)
|
||||
|
||||
if truth_solution:
|
||||
truth_output = truth_solution(pts).detach()
|
||||
ax.plot(pts.extract(v), truth_output,
|
||||
label='True solution', **kwargs)
|
||||
ax.plot(
|
||||
pts.extract(v), truth_output, label="True solution", **kwargs
|
||||
)
|
||||
|
||||
# TODO: pred is a torch.Tensor, so no labels is available
|
||||
# extra variable for labels should be
|
||||
# passed in the function arguments.
|
||||
# plt.ylabel(pred.labels[0])
|
||||
# plt.ylabel(pred.labels[0])
|
||||
plt.legend()
|
||||
|
||||
def _2d_plot(self,
|
||||
pts,
|
||||
pred,
|
||||
v,
|
||||
res,
|
||||
method,
|
||||
truth_solution=None,
|
||||
**kwargs):
|
||||
def _2d_plot(
|
||||
self, pts, pred, v, res, method, truth_solution=None, **kwargs
|
||||
):
|
||||
"""Plot solution for two dimensional function
|
||||
|
||||
:param pts: Points to plot the solution.
|
||||
@@ -136,44 +135,47 @@ class Plotter:
|
||||
|
||||
pred_output = pred.reshape(res, res)
|
||||
if truth_solution:
|
||||
truth_output = truth_solution(pts).float().reshape(res, res).as_subclass(torch.Tensor)
|
||||
truth_output = (
|
||||
truth_solution(pts)
|
||||
.float()
|
||||
.reshape(res, res)
|
||||
.as_subclass(torch.Tensor)
|
||||
)
|
||||
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
|
||||
|
||||
cb = getattr(ax[0], method)(*grids, pred_output,
|
||||
**kwargs)
|
||||
cb = getattr(ax[0], method)(*grids, pred_output, **kwargs)
|
||||
fig.colorbar(cb, ax=ax[0])
|
||||
ax[0].title.set_text('Neural Network prediction')
|
||||
cb = getattr(ax[1], method)(*grids, truth_output,
|
||||
**kwargs)
|
||||
ax[0].title.set_text("Neural Network prediction")
|
||||
cb = getattr(ax[1], method)(*grids, truth_output, **kwargs)
|
||||
fig.colorbar(cb, ax=ax[1])
|
||||
ax[1].title.set_text('True solution')
|
||||
cb = getattr(ax[2],
|
||||
method)(*grids,
|
||||
(truth_output - pred_output),
|
||||
**kwargs)
|
||||
ax[1].title.set_text("True solution")
|
||||
cb = getattr(ax[2], method)(
|
||||
*grids, (truth_output - pred_output), **kwargs
|
||||
)
|
||||
fig.colorbar(cb, ax=ax[2])
|
||||
ax[2].title.set_text('Residual')
|
||||
ax[2].title.set_text("Residual")
|
||||
else:
|
||||
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
||||
cb = getattr(ax, method)(*grids, pred_output,
|
||||
**kwargs)
|
||||
cb = getattr(ax, method)(*grids, pred_output, **kwargs)
|
||||
fig.colorbar(cb, ax=ax)
|
||||
ax.title.set_text('Neural Network prediction')
|
||||
ax.title.set_text("Neural Network prediction")
|
||||
|
||||
def plot(self,
|
||||
solver,
|
||||
components=None,
|
||||
fixed_variables={},
|
||||
method='contourf',
|
||||
res=256,
|
||||
filename=None,
|
||||
**kwargs):
|
||||
def plot(
|
||||
self,
|
||||
solver,
|
||||
components=None,
|
||||
fixed_variables={},
|
||||
method="contourf",
|
||||
res=256,
|
||||
filename=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Plot sample of SolverInterface output.
|
||||
|
||||
:param SolverInterface solver: The ``SolverInterface`` object instance.
|
||||
:param str | list(str) components: The output variable(s) to plot.
|
||||
If None, all the output variables of the problem are selected.
|
||||
:param str | list(str) components: The output variable(s) to plot.
|
||||
If None, all the output variables of the problem are selected.
|
||||
Default value is None.
|
||||
:param dict fixed_variables: A dictionary with all the variables that
|
||||
should be kept fixed during the plot. The keys of the dictionary
|
||||
@@ -190,23 +192,28 @@ class Plotter:
|
||||
|
||||
if components is None:
|
||||
components = solver.problem.output_variables
|
||||
|
||||
|
||||
if isinstance(components, str):
|
||||
components = [components]
|
||||
|
||||
if not isinstance(components, list):
|
||||
raise NotImplementedError('Output variables must be passed'
|
||||
'as a string or a list of strings.')
|
||||
|
||||
raise NotImplementedError(
|
||||
"Output variables must be passed"
|
||||
"as a string or a list of strings."
|
||||
)
|
||||
|
||||
if len(components) > 1:
|
||||
raise NotImplementedError('Multidimensional plots are not implemented, '
|
||||
'set components to an available components of'
|
||||
' the problem.')
|
||||
raise NotImplementedError(
|
||||
"Multidimensional plots are not implemented, "
|
||||
"set components to an available components of"
|
||||
" the problem."
|
||||
)
|
||||
v = [
|
||||
var for var in solver.problem.input_variables
|
||||
var
|
||||
for var in solver.problem.input_variables
|
||||
if var not in fixed_variables.keys()
|
||||
]
|
||||
pts = solver.problem.domain.sample(res, 'grid', variables=v)
|
||||
pts = solver.problem.domain.sample(res, "grid", variables=v)
|
||||
|
||||
fixed_pts = torch.ones(pts.shape[0], len(fixed_variables))
|
||||
fixed_pts *= torch.tensor(list(fixed_variables.values()))
|
||||
@@ -218,16 +225,20 @@ class Plotter:
|
||||
|
||||
# computing soluting and sending to cpu
|
||||
predicted_output = solver.forward(pts).extract(components)
|
||||
predicted_output = predicted_output.as_subclass(torch.Tensor).cpu().detach()
|
||||
predicted_output = (
|
||||
predicted_output.as_subclass(torch.Tensor).cpu().detach()
|
||||
)
|
||||
pts = pts.cpu()
|
||||
truth_solution = getattr(solver.problem, 'truth_solution', None)
|
||||
truth_solution = getattr(solver.problem, "truth_solution", None)
|
||||
|
||||
if len(v) == 1:
|
||||
self._1d_plot(pts, predicted_output, v, method, truth_solution,
|
||||
**kwargs)
|
||||
self._1d_plot(
|
||||
pts, predicted_output, v, method, truth_solution, **kwargs
|
||||
)
|
||||
elif len(v) == 2:
|
||||
self._2d_plot(pts, predicted_output, v, res, method, truth_solution,
|
||||
**kwargs)
|
||||
self._2d_plot(
|
||||
pts, predicted_output, v, res, method, truth_solution, **kwargs
|
||||
)
|
||||
|
||||
plt.tight_layout()
|
||||
if filename:
|
||||
@@ -236,13 +247,15 @@ class Plotter:
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
def plot_loss(self,
|
||||
trainer,
|
||||
metrics=None,
|
||||
logy=False,
|
||||
logx=False,
|
||||
filename=None,
|
||||
**kwargs):
|
||||
def plot_loss(
|
||||
self,
|
||||
trainer,
|
||||
metrics=None,
|
||||
logy=False,
|
||||
logx=False,
|
||||
filename=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Plot the loss function values during traininig.
|
||||
|
||||
@@ -260,41 +273,43 @@ class Plotter:
|
||||
|
||||
# check that MetricTracker has been used
|
||||
list_ = [
|
||||
idx for idx, s in enumerate(trainer.callbacks)
|
||||
idx
|
||||
for idx, s in enumerate(trainer.callbacks)
|
||||
if isinstance(s, MetricTracker)
|
||||
]
|
||||
if not bool(list_):
|
||||
raise FileNotFoundError(
|
||||
'MetricTracker should be used as a callback during training to'
|
||||
' use this method.')
|
||||
"MetricTracker should be used as a callback during training to"
|
||||
" use this method."
|
||||
)
|
||||
|
||||
# extract trainer metrics
|
||||
trainer_metrics = trainer.callbacks[list_[0]].metrics
|
||||
if metrics is None:
|
||||
metrics = ['mean_loss']
|
||||
metrics = ["mean_loss"]
|
||||
elif not isinstance(metrics, list):
|
||||
raise ValueError('metrics must be class list.')
|
||||
raise ValueError("metrics must be class list.")
|
||||
|
||||
# loop over metrics to plot
|
||||
for metric in metrics:
|
||||
if metric not in trainer_metrics:
|
||||
raise ValueError(
|
||||
f'{metric} not a valid metric. Available metrics are {list(trainer_metrics.keys())}.'
|
||||
f"{metric} not a valid metric. Available metrics are {list(trainer_metrics.keys())}."
|
||||
)
|
||||
loss = trainer_metrics[metric]
|
||||
epochs = range(len(loss))
|
||||
plt.plot(epochs, loss.cpu(), **kwargs)
|
||||
|
||||
# plotting
|
||||
plt.xlabel('epoch')
|
||||
plt.ylabel('loss')
|
||||
plt.xlabel("epoch")
|
||||
plt.ylabel("loss")
|
||||
plt.legend()
|
||||
|
||||
# log axis
|
||||
if logy:
|
||||
plt.yscale('log')
|
||||
plt.yscale("log")
|
||||
if logx:
|
||||
plt.xscale('log')
|
||||
plt.xscale("log")
|
||||
|
||||
# saving in file
|
||||
if filename:
|
||||
|
||||
Reference in New Issue
Block a user